""" uv run plot_over_time.py --input ./phonikud_enhanced_out/ ./phonikud_vocalized_out/ ./unvocalized_mock_out/ ./vocalized_mock_out/ --metrics wer """ import argparse import json from pathlib import Path import re import matplotlib.pyplot as plt def extract_epoch_step(name: str): # example key: "epoch=4679-step=3107342.onnx" epoch_match = re.search(r'epoch=(\d+)', name) step_match = re.search(r'step=(\d+)', name) epoch = int(epoch_match.group(1)) if epoch_match else -1 step = int(step_match.group(1)) if step_match else -1 return epoch, step def load_reports_from_folder(folder: Path): overview_path = folder / "overview.json" if not overview_path.exists(): print(f"No overview.json found in {folder}") return [] with open(overview_path, "r", encoding="utf-8") as f: data = json.load(f) reports = [] for key, values in data.items(): if key == "overall": continue epoch, step = extract_epoch_step(key) reports.append({ "epoch": epoch, "step": step, "wer": values.get("mean_wer", 0), "cer": values.get("mean_cer", 0) }) print(f"Found {len(reports)} reports in folder '{folder.name}'") # sort by epoch or step reports.sort(key=lambda x: (x["epoch"], x["step"])) return reports def plot_all_models(folder_paths, output_path: Path, metrics: list): plt.figure(figsize=(12, 6)) for folder in folder_paths: folder = Path(folder) label = folder.name reports = load_reports_from_folder(folder) if not reports: print(f"No reports found in {label}") continue # You can choose to plot by epoch or step, here by epoch: x = [r["epoch"] for r in reports] wers = [r["wer"] for r in reports] cers = [r["cer"] for r in reports] if 'wer' in metrics: plt.plot(x, wers, marker='o', label=f"{label} WER") if 'cer' in metrics: plt.plot(x, cers, marker='x', label=f"{label} CER") plt.title("Model WER/CER Progress Over Time") plt.xlabel("Epoch") plt.ylabel("Error Rate") plt.grid(True) plt.legend() plt.tight_layout() plt.savefig(output_path) print(f"Plot saved to {output_path}") def main(): parser = argparse.ArgumentParser() parser.add_argument('--input', nargs='+', required=True, help='One or more folders with *_report.json files') parser.add_argument('--output', default='progress.png', help='Path to output image') parser.add_argument( '--metrics', nargs='+', choices=['cer', 'wer'], default=['cer', 'wer'], help='Which metrics to plot: cer, wer, or both (default: both)' ) args = parser.parse_args() output_path = Path(args.output) plot_all_models(args.input, output_path, args.metrics) if __name__ == "__main__": main()