# SplatAtlas / scripts/extract_dynamics.py
# Author: KCBtheone — commit 23e73f9 ("Upload SplatAtlas benchmark pipeline code")
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tensorboard.backend.event_processing import event_accumulator
# === Publication-quality figure styling ===
# Applied once at import time so every figure this script saves shares the
# same seaborn theme, font sizes, automatic layout and 300-dpi output.
sns.set_theme(palette="muted", style="whitegrid")
plt.rc('font', size=12)
plt.rc('axes', titlesize=14, labelsize=12)
plt.rc('figure', autolayout=True, dpi=300)
def _merge_scalar_events(tag_events):
    """Merge per-tag scalar event streams into one wide, step-indexed table.

    Args:
        tag_events: Mapping from scalar tag name to an iterable of events. Each
            event only needs ``.step`` and ``.value`` attributes (as on the
            records returned by ``EventAccumulator.Scalars``).

    Returns:
        A DataFrame with an ascending ``step`` column plus one column per tag,
        outer-joined on step (tags logged at different cadences leave NaN
        gaps), or ``None`` when ``tag_events`` is empty.
    """
    frames = []
    for tag, events in tag_events.items():
        frame = pd.DataFrame(
            [(e.step, e.value) for e in events],
            columns=['step', tag],
        )
        # The same step can be logged more than once (e.g. a resumed run that
        # re-logs its last step). Keep the latest value so the index is unique;
        # pd.concat cannot align non-unique indexes and would raise.
        frame = frame.drop_duplicates(subset='step', keep='last')
        frames.append(frame.set_index('step'))
    if not frames:
        return None
    # concat(axis=1) outer-joins the step indexes but does NOT sort the union
    # (sort=False by default), so sort explicitly to keep rows chronological.
    merged = pd.concat(frames, axis=1).sort_index()
    return merged.reset_index()


def extract_tb_data(log_dir):
    """Extract every scalar series in a TensorBoard log directory as a DataFrame.

    Args:
        log_dir: Directory expected to hold ``events.out.tfevents.*`` files
            directly (the existence check does not look in sub-directories).

    Returns:
        A DataFrame with a ``step`` column and one column per scalar tag, or
        ``None`` if the directory has no event files or no scalar data.
    """
    if not glob.glob(os.path.join(log_dir, 'events.out.tfevents.*')):
        return None
    # size_guidance=0 keeps every scalar event, so no step is dropped by sampling.
    accumulator = event_accumulator.EventAccumulator(
        log_dir, size_guidance={event_accumulator.SCALARS: 0}
    )
    accumulator.Reload()
    scalar_tags = accumulator.Tags().get('scalars', [])
    return _merge_scalar_events(
        {tag: accumulator.Scalars(tag) for tag in scalar_tags}
    )
def _collect_experiments(outputs_dir):
    """Load the TensorBoard scalars of every experiment folder under *outputs_dir*.

    Each returned DataFrame (from ``extract_tb_data``) gets two extra columns:
    ``method`` (the folder name up to its first ``_``) and ``experiment`` (the
    full folder name).

    Returns:
        A list of per-experiment DataFrames; empty if none had usable data.
    """
    all_data = []
    # os.listdir order is arbitrary; sorting makes the scan order, and hence the
    # colour / legend assignment in the figures, reproducible across runs.
    for folder in sorted(os.listdir(outputs_dir)):
        log_dir = os.path.join(outputs_dir, folder)
        # Skip plain files and the quarantined "pathology_archive" folder.
        if not os.path.isdir(log_dir) or folder == "pathology_archive":
            continue
        df = extract_tb_data(log_dir)
        if df is None:
            continue
        # Folder names are assumed to follow <method>_<scene>; a folder without
        # "_" is used verbatim as its method name.
        df['method'] = folder.split('_', 1)[0]
        df['experiment'] = folder
        all_data.append(df)
        print(f" ✅ 成功提取: {folder} ({len(df)} 条时间序列记录)")
    return all_data


def _plot_marginal_return(master_df, time_col, n_col, loss_col, plots_dir):
    """Render Figure 1: densification S-curves and loss convergence.

    For each experiment, the Gaussian count is drawn as a solid line on the left
    axis and the training loss as a dashed line of the same colour on a twin
    right axis, both against ``time_col``. Saves ``fig1_marginal_return.png``
    into *plots_dir*.
    """
    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax2 = ax1.twinx()
    experiments = master_df['experiment'].unique()
    colors = sns.color_palette("husl", len(experiments))
    for exp, color in zip(experiments, colors):
        exp_df = master_df[master_df['experiment'] == exp].dropna(
            subset=[time_col, n_col, loss_col]
        )
        if exp_df.empty:
            continue
        # Draw in chronological order; unsorted rows produce zig-zag lines.
        exp_df = exp_df.sort_values('step', kind='stable')
        # Gaussian count S-curve (solid line).
        ax1.plot(exp_df[time_col], exp_df[n_col], linestyle='-', linewidth=2.5,
                 color=color, label=f"{exp} (N)")
        # Loss descent curve (dashed line).
        ax2.plot(exp_df[time_col], exp_df[loss_col], linestyle='--', linewidth=2,
                 color=color, alpha=0.6)
    xlabel = 'Cumulative GPU Time (Seconds)' if time_col != 'step' else 'Training Steps'
    ax1.set_xlabel(xlabel, fontweight='bold')
    ax1.set_ylabel('Number of Gaussians (N)', color='black', fontweight='bold')
    ax2.set_ylabel('Training Loss (L1 + SSIM)', color='gray', fontweight='bold')
    ax1.set_title('Densification Dynamics & Loss Convergence', pad=20, fontweight='bold')
    ax1.legend(loc='upper left', bbox_to_anchor=(1.15, 1))
    out_path1 = os.path.join(plots_dir, 'fig1_marginal_return.png')
    fig.savefig(out_path1, bbox_inches='tight')
    plt.close(fig)
    print(f" 📸 论文大图 1 已生成: {out_path1}")


def _plot_compensation_scatter(master_df, cos_col, gamma_col, plots_dir):
    """Render Figure 2: the photometric-geometric compensation scatter.

    X is the gradient cosine similarity, Y the anisotropy ratio; points are
    coloured by method and sized by training step. Reference lines mark x=0 and
    the gamma=18 threshold. Saves ``fig2_compensation_scatter.png`` into
    *plots_dir*; does nothing if no row has both values.
    """
    plot_df = master_df.dropna(subset=[cos_col, gamma_col]).copy()
    if plot_df.empty:
        return
    plt.figure(figsize=(9, 7))
    sns.scatterplot(
        data=plot_df,
        x=cos_col,
        y=gamma_col,
        hue='method',
        size='step',
        sizes=(20, 200),
        alpha=0.8,
        palette='Set1',
        edgecolor='black',
    )
    # Reference lines: orthogonal gradients and the pathology threshold.
    plt.axvline(x=0, color='gray', linestyle='--', alpha=0.5, label='Orthogonal Gradients (0)')
    plt.axhline(y=18, color='red', linestyle=':', linewidth=2, alpha=0.7, label=r'Pathology Threshold ($\gamma=18$)')
    plt.xlabel(r'Gradient Cosine Similarity $\cos(\theta)$ (Target vs Parasitic)', fontweight='bold')
    plt.ylabel(r'Anisotropy Distortion $\gamma$ (Max/Min Scale)', fontweight='bold')
    plt.title('Photometric-Geometric Compensation Phase Space', pad=20, fontweight='bold')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    out_path2 = os.path.join(plots_dir, 'fig2_compensation_scatter.png')
    plt.savefig(out_path2, bbox_inches='tight')
    plt.close()
    print(f" 📸 论文大图 2 已生成: {out_path2}")


def main():
    """Harvest TensorBoard scalars from ``../outputs`` and write figures to ``../plots``.

    Both directories are resolved relative to this script. Figure 1 is drawn
    only if the Gaussian-count and loss columns exist. Figure 2 is drawn only if
    the cosine-similarity and gamma columns exist.
    """
    script_dir = os.path.dirname(__file__)
    outputs_dir = os.path.join(script_dir, "..", "outputs")
    plots_dir = os.path.join(script_dir, "..", "plots")
    # Report a missing outputs dir instead of crashing in os.listdir.
    if not os.path.isdir(outputs_dir):
        print(f"⚠️ 未找到输出目录: {outputs_dir}")
        return
    os.makedirs(plots_dir, exist_ok=True)

    print("🔍 [Harvester] 正在扫描 TensorBoard 日志...")
    all_data = _collect_experiments(outputs_dir)
    if not all_data:
        print("⚠️ 未找到任何有效的 TensorBoard 数据!")
        return
    master_df = pd.concat(all_data, ignore_index=True)

    # Probe columns logged during training; prefer cumulative time as the
    # x-axis when it was recorded, otherwise fall back to the step index.
    time_col = 'hardware/cumulative_time_sec' if 'hardware/cumulative_time_sec' in master_df.columns else 'step'
    n_col = 'geometry/N_gaussians'
    loss_col = 'train/loss'
    cos_col = 'dynamics_decoupled/cos_similarity'
    gamma_col = 'pathology/gamma_median'
    print(f"\n📊 [Harvester] 正在使用 Matplotlib 绘制动力学大图,目标目录: {plots_dir}/")

    # Figure 1: marginal-return collapse and densification S-curve.
    if n_col in master_df.columns and loss_col in master_df.columns:
        _plot_marginal_return(master_df, time_col, n_col, loss_col, plots_dir)

    # Figure 2: photometric-geometric compensation scatter.
    if cos_col in master_df.columns and gamma_col in master_df.columns:
        _plot_compensation_scatter(master_df, cos_col, gamma_col, plots_dir)

    print("✅ [Harvester] 您的核心科研图表已就绪!")


if __name__ == "__main__":
    main()