Spaces:
Running on Zero
Running on Zero
YingMusic-SingerGPU / src /third_party /MusicSourceSeparationTraining /scripts /watch_all_metrics.py
| """ | |
Script to print the per-epoch metrics stored under the "all_metrics" key of a new-format checkpoint file.
| """ | |
import argparse
from typing import Any, Dict, List

import numpy as np
import torch
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the metrics viewer.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        Namespace with one attribute, ``start_check_point`` (str): the value
        given on the command line, or ``""`` when the option is omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--start_check_point",
        type=str,
        default="",
        # The previous help text described a training option; this script
        # only reads the checkpoint and prints what is stored in it.
        help="Path to the checkpoint file whose stored metrics are printed",
    )
    return parser.parse_args(argv)
def _epoch_sort_key(epoch_key):
    """Return a sort key that orders ``<prefix>_<int>`` keys by their integer suffix.

    Keys whose last ``_``-separated field is not an integer are placed after
    all numbered keys (ordered by their string form) instead of raising.
    """
    text = str(epoch_key)
    try:
        return (0, int(text.split("_")[-1]), "")
    except ValueError:
        return (1, 0, text)


def _fmt_list(xs, n: int, p: int) -> str:
    """Format the first ``n`` numbers of ``xs`` with ``p`` decimals.

    ``xs`` may be a list, a NumPy array, a scalar or a nested sequence; it is
    flattened to a plain list of floats first, so the emptiness test below
    never hits the ambiguous truth value of a multi-element array.

    Returns:
        ``"[]"`` for empty input, ``"[a, b, … k more]"`` when more than ``n``
        values are present, otherwise ``"[a, b]"``.
    """
    vals = np.asarray(xs, dtype=float).ravel().tolist()
    if not vals:
        return "[]"
    head = ", ".join(f"{v:.{p}f}" for v in vals[:n])
    hidden = len(vals) - n
    if hidden > 0:
        return f"[{head}, … {hidden} more]"
    return f"[{head}]"


def pretty_metrics(
    all_time_all_metrics: Dict[str, Any], *, precision: int = 4, show_items: int = 6
) -> str:
    """Render the nested metrics dictionary as a human-readable report.

    The input is walked as three nested mappings: epoch key -> metric name ->
    inner key (presumably an instrument/stem name) -> collection of values.

    Args:
        all_time_all_metrics: Mapping described above. Epoch keys are sorted
            by the integer after their last ``_``; other keys come last.
        precision: Number of decimals used for every printed float.
        show_items: Maximum number of raw values previewed per entry.

    Returns:
        One string containing, per epoch, a ``=== key ===`` header, one line
        per metric name, and one line per inner key with the mean, the
        population standard deviation, the value count and a value preview.
        Empty or ``None`` value collections report ``nan`` statistics and
        ``n=0``.
    """
    lines: List[str] = []
    for epoch_key in sorted(all_time_all_metrics, key=_epoch_sort_key):
        lines.append(f"\n=== {epoch_key} ===")
        for metric_name, per_instr in all_time_all_metrics[epoch_key].items():
            lines.append(f"{metric_name}:")
            for instr, values in per_instr.items():
                # Flatten once so statistics and preview see the same data.
                arr = np.asarray(
                    [] if values is None else values, dtype=float
                ).ravel()
                if arr.size:
                    mean, std = float(arr.mean()), float(arr.std())
                else:
                    # Avoid NumPy's empty-slice warnings; report nan explicitly.
                    mean = std = float("nan")
                preview = _fmt_list(arr, show_items, precision)
                lines.append(
                    f" {str(instr):>10s}: mean={mean:.{precision}f} std={std:.{precision}f} "
                    f"n={arr.size} values={preview}"
                )
    return "\n".join(lines)


def main():
    """Load the checkpoint named on the command line and print its metrics.

    Raises:
        SystemExit: If ``--start_check_point`` was not provided.
        KeyError: If the loaded checkpoint has no ``"all_metrics"`` entry.
    """
    args = parse_args()
    if not args.start_check_point:
        # Fail early with a readable message instead of an opaque loader error.
        raise SystemExit("error: --start_check_point must point to a checkpoint file")

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only use this on trusted checkpoints.
    checkpoint = torch.load(
        args.start_check_point, weights_only=False, map_location="cpu"
    )
    try:
        all_metrics = checkpoint["all_metrics"]
    except KeyError as err:
        raise KeyError(
            f"checkpoint {args.start_check_point!r} has no 'all_metrics' entry"
        ) from err

    print(pretty_metrics(all_metrics, precision=4, show_items=6))
# Run the viewer only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()