# SplatAtlas / scripts / compute_offline_physics.py
# Uploaded by KCBtheone: "Upload SplatAtlas benchmark pipeline code" (commit 23e73f9, verified), 5.18 kB.
import os
import json
class NpEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy / PyTorch values.

    Objects exposing ``.item()`` (NumPy scalar types, single-element arrays
    and tensors) are emitted as native Python numbers. When ``.item()``
    raises because the container does not hold exactly one element
    (multi-element or empty arrays/tensors), the value falls back to
    ``.tolist()`` and is written as a (nested) JSON list. Without this
    fallback, one array-valued metric would abort the whole dump. Anything
    else is delegated to the base encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        if hasattr(obj, 'item'):
            try:
                return obj.item()
            except (ValueError, RuntimeError):
                # NumPy raises ValueError and PyTorch raises RuntimeError
                # when .item() is called on a container whose size != 1.
                if hasattr(obj, 'tolist'):
                    return obj.tolist()
                raise
        return super().default(obj)
import argparse
import torch
import numpy as np
from scipy.stats import entropy
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core.registry import METHOD_REGISTRY
import methods
def compute_gini(array):
    """Return the Gini coefficient of the strictly positive entries of *array*.

    The input is flattened first. Zero, negative and NaN entries are then
    dropped, so an empty or non-positive input yields ``0.0``. For the
    remaining ``n`` values sorted ascending (``x_1 <= ... <= x_n``) it
    evaluates

        G = sum_i (2 * i - n - 1) * x_i / (n * sum_i x_i)

    This is 0 when all kept values are equal. It grows towards
    ``(n - 1) / n`` as the mass concentrates in a single entry.
    """
    flat = array.flatten()
    positive = flat[flat > 0]
    count = len(positive)
    if count == 0:
        return 0.0
    ranked = np.sort(positive)
    # Rank weights (2i - n - 1) for 1-based ranks i of the ascending values.
    rank_weights = 2 * np.arange(1, ranked.shape[0] + 1) - count - 1
    return np.sum(rank_weights * ranked) / (count * np.sum(ranked))
def generate_standard_grid(cameras, resolution=64, device=None):
    """Build a regular 3D query grid around the camera distribution.

    The axis-aligned bounding box of all camera centres is padded by 10% of
    its extent on each side and sampled with ``resolution`` points per axis.
    The default 64^3 is about 262k points, a compromise between accuracy and
    GPU memory. Without cameras, a fixed [-5, 5]^3 box is used.

    Args:
        cameras: sequence of objects with a ``camera_center`` attribute
            holding 3 coordinates (presumably world-space xyz; confirm
            against the wrappers).
        resolution: number of samples along each axis.
        device: device for the returned tensor. ``None`` (default) selects
            CUDA when available and otherwise falls back to the CPU.

    Returns:
        Float tensor of shape ``(resolution ** 3, 3)``. Points are ordered
        with x varying slowest and z fastest (``'ij'`` meshgrid indexing).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if not cameras:
        min_bounds = torch.tensor([-5.0, -5.0, -5.0])
        max_bounds = torch.tensor([5.0, 5.0, 5.0])
    else:
        # Work on detached CPU float copies. The bounds arithmetic then does
        # not depend on the device or dtype the wrapper stores cameras with.
        cam_centers = torch.stack(
            [torch.as_tensor(c.camera_center).detach().float().cpu() for c in cameras]
        )
        min_bounds = cam_centers.amin(dim=0)
        max_bounds = cam_centers.amax(dim=0)
        extent = max_bounds - min_bounds
        # A single camera, or cameras on a line/plane (e.g. all at one
        # height), leave axes with (near-)zero extent. linspace would then
        # repeat a single coordinate and collapse the grid. Widen such axes
        # to the largest extent (1.0 if every centre coincides), keeping
        # them centred on the cameras.
        largest = float(extent.max())
        if largest <= 0.0:
            largest = 1.0
        flat_axes = extent <= largest * 1e-6
        widen = torch.where(
            flat_axes,
            torch.full_like(extent, 0.5 * largest),
            torch.zeros_like(extent),
        )
        min_bounds = min_bounds - widen
        max_bounds = max_bounds + widen
        extent = max_bounds - min_bounds
        min_bounds = min_bounds - extent * 0.1
        max_bounds = max_bounds + extent * 0.1
    # Plain Python floats make linspace independent of the bounds' device.
    axes = [
        torch.linspace(lo, hi, resolution)
        for lo, hi in zip(min_bounds.tolist(), max_bounds.tolist())
    ]
    grid_x, grid_y, grid_z = torch.meshgrid(*axes, indexing='ij')
    query_points = torch.stack([grid_x.flatten(), grid_y.flatten(), grid_z.flatten()], dim=-1)
    return query_points.to(device)
def analyze_physics(model, cameras):
stats = {}
# 1. 变体自述物理指标 (特异性指标,如畸变度、Billboard偏差等)
try:
specific_metrics = model.compute_physical_metrics(cameras=cameras)
stats.update(specific_metrics)
except NotImplementedError:
print(" ⚠️ [Physics] Wrapper 未实现 compute_physical_metrics,跳过特异性指标。")
except Exception as e:
print(f" ⚠️ [Physics] 特异性指标计算发生错误: {e}")
# 2. 宏观标量场评估 (Apple-to-Apple 的绝对公平空间测试)
try:
query_points = generate_standard_grid(cameras, resolution=64)
with torch.no_grad():
scalar_field = model.evaluate_spatial_field(query_points, cameras=cameras)
field_np = scalar_field.cpu().numpy()
stats['spatial_gini'] = compute_gini(field_np)
p = field_np / (field_np.sum() + 1e-7)
stats['density_entropy'] = entropy(p[p > 0])
stats["Field_Total_Volume"] = float(np.sum(field_np))
except NotImplementedError:
print(" ⚠️ [Physics] Wrapper 未实现 evaluate_spatial_field,跳过空间分布诊断。")
except Exception as e:
print(f" ⚠️ [Physics] 空间标量场评估发生异常: {e}")
return {k: round(v, 4) if isinstance(v, float) else v for k, v in stats.items()}
def load_cameras_from_json(cam_path):
    """Load lightweight camera stand-ins from a JSON file.

    Kept for compatibility with the older pipeline. Each returned object
    only exposes a ``camera_center`` float32 tensor, which is the one camera
    attribute the grid generation in this script reads.

    Args:
        cam_path: path to a JSON file whose top level is expected to be a
            list of camera records. Records that are dicts with a
            ``'center'`` key are used; anything else is skipped. The dict
            check matters: for string records ``'center' in c`` would be a
            substring test and ``c['center']`` would then raise TypeError.

    Returns:
        list of objects with a ``camera_center`` attribute. The list is
        empty if the file does not exist.
    """
    class DummyCam:
        def __init__(self, center):
            self.camera_center = torch.tensor(center, dtype=torch.float32)

    try:
        with open(cam_path, 'r', encoding='utf-8') as f:
            cams_data = json.load(f)
    except FileNotFoundError:
        return []
    return [DummyCam(c['center']) for c in cams_data if isinstance(c, dict) and 'center' in c]
def main():
    """Command-line entry point: run the offline physics analysis for one model.

    Steps:
    1. Parse ``--method``, ``--source_path``, ``--model_path`` and
       ``--iteration``.
    2. Build the method wrapper via ``methods.load_method`` and load the
       requested checkpoint.
    3. Run ``analyze_physics`` on the wrapper's training cameras (only if
       the wrapper has a ``scene`` attribute).
    4. Write any non-empty result to
       ``<model_dir>/offline_physics_<iteration>.json``.

    A missing model directory or a failure while building/loading the
    wrapper is printed and the function returns without raising.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--method", type=str, required=True, help="模型方法名称")
    cli.add_argument("--source_path", type=str, default="", help="原始数据集路径")
    cli.add_argument("--model_path", type=str, required=True, help="模型输出目录")
    cli.add_argument("--iteration", type=int, default=30000)
    opts = cli.parse_args()

    model_dir = os.path.abspath(opts.model_path)
    if not os.path.exists(model_dir):
        print(f"⚠️ [Physics] 跳过: 模型目录不存在 {model_dir}")
        return

    scene_name = os.path.basename(model_dir)
    print(f"🔬 [Physics] 正在精准解剖场景: {scene_name} ...")

    # Instantiate the method-specific wrapper and load its trained assets.
    dataset_config = {
        "source_path": opts.source_path,
        "model_path": opts.model_path,
        "resolution": 1,
    }
    try:
        wrapper_cls = methods.load_method(opts.method)
        model = wrapper_cls(dataset_config, hyperparams={})
        model.load(opts.model_path, opts.iteration)
    except Exception as err:
        print(f"❌ [Physics] 实例化或加载模型失败: {err}")
        return

    if hasattr(model, 'scene'):
        cameras = model.scene.getTrainCameras()
    else:
        cameras = []

    stats = analyze_physics(model, cameras)
    if not stats:
        return
    stats['experiment'] = scene_name
    out_path = os.path.join(model_dir, f"offline_physics_{opts.iteration}.json")
    with open(out_path, 'w') as f:
        json.dump(stats, f, indent=4, cls=NpEncoder)
    print(f" ✨ [Physics] {scene_name} 分析完成!提取指标数: {len(stats)}")


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()