thewh1teagle
Fix CSV path in unvocalized mock phonemes command and sort DataFrame by file_id
d14d04e
| """ | |
| uv run plot_over_time.py --input ./phonikud_enhanced_out/ ./phonikud_vocalized_out/ ./unvocalized_mock_out/ ./vocalized_mock_out/ --metrics wer | |
| """ | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import re | |
| import matplotlib.pyplot as plt | |
def extract_epoch_step(name: str):
    """Parse the epoch and step numbers embedded in a checkpoint name.

    A name such as "epoch=4679-step=3107342.onnx" yields (4679, 3107342).
    Whichever of the two fields is not present in *name* is reported as -1.

    Returns:
        A 2-tuple ``(epoch, step)`` of ints.
    """
    parsed = []
    # Same lookup for both fields: first "<field>=<digits>" occurrence wins.
    for pattern in (r'epoch=(\d+)', r'step=(\d+)'):
        hit = re.search(pattern, name)
        parsed.append(-1 if hit is None else int(hit.group(1)))
    return tuple(parsed)
def load_reports_from_folder(folder: Path):
    """Load the per-checkpoint metrics stored in ``<folder>/overview.json``.

    Every top-level key except "overall" is turned into one report dict with
    the keys "epoch" and "step" (parsed from the key by extract_epoch_step)
    and "wer" / "cer" (read from "mean_wer" / "mean_cer", 0 when absent).

    Returns:
        The report dicts ordered by (epoch, step), or an empty list when the
        folder has no overview.json.
    """
    overview_path = folder / "overview.json"
    if not overview_path.exists():
        print(f"No overview.json found in {folder}")
        return []

    with open(overview_path, "r", encoding="utf-8") as handle:
        overview = json.load(handle)

    def as_report(checkpoint_name, scores):
        # The checkpoint name carries the training position; scores carry the metrics.
        epoch, step = extract_epoch_step(checkpoint_name)
        return {
            "epoch": epoch,
            "step": step,
            "wer": scores.get("mean_wer", 0),
            "cer": scores.get("mean_cer", 0)
        }

    # "overall" is skipped: it is not tied to a single checkpoint name.
    collected = [
        as_report(checkpoint_name, scores)
        for checkpoint_name, scores in overview.items()
        if checkpoint_name != "overall"
    ]
    print(f"Found {len(collected)} reports in folder '{folder.name}'")

    # Chronological order: epoch first, step as the tie-breaker.
    return sorted(collected, key=lambda entry: (entry["epoch"], entry["step"]))
def plot_all_models(folder_paths, output_path: Path, metrics: list):
    """Draw the selected error metrics of every folder on one figure and save it.

    Args:
        folder_paths: Iterable of folder paths (str or Path). Each folder is
            read with load_reports_from_folder and labelled in the legend by
            its folder name; folders that yield no reports are skipped.
        output_path: Destination handed to ``plt.savefig``.
        metrics: Metric names to draw; 'wer' and 'cer' are recognised.
    """
    figure = plt.figure(figsize=(12, 6))
    try:
        drew_series = False
        for folder in folder_paths:
            folder = Path(folder)
            label = folder.name
            reports = load_reports_from_folder(folder)
            if not reports:
                print(f"No reports found in {label}")
                continue

            # The x axis is the epoch of each report (the step is available as well).
            x = [r["epoch"] for r in reports]
            # Only build the series that was actually requested.
            if 'wer' in metrics:
                plt.plot(x, [r["wer"] for r in reports], marker='o', label=f"{label} WER")
                drew_series = True
            if 'cer' in metrics:
                plt.plot(x, [r["cer"] for r in reports], marker='x', label=f"{label} CER")
                drew_series = True

        plt.title("Model WER/CER Progress Over Time")
        plt.xlabel("Epoch")
        plt.ylabel("Error Rate")
        plt.grid(True)
        # legend() warns when there is no labelled artist, so skip it for an empty plot.
        if drew_series:
            plt.legend()
        plt.tight_layout()
        plt.savefig(output_path)
        print(f"Plot saved to {output_path}")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it even when plotting or saving fails.
        plt.close(figure)
def main():
    """Parse the command-line arguments and render the progress plot.

    Options:
        --input    one or more folders, each read through plot_all_models
        --output   image path to write (default: progress.png)
        --metrics  any of 'cer' / 'wer' (default: both)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        nargs='+',
        required=True,
        # The loader reads overview.json from each folder, so the help text names that file.
        help='One or more folders containing an overview.json file',
    )
    parser.add_argument('--output', default='progress.png', help='Path to output image')
    parser.add_argument(
        '--metrics',
        nargs='+',
        choices=['cer', 'wer'],
        default=['cer', 'wer'],
        help='Which metrics to plot: cer, wer, or both (default: both)'
    )
    args = parser.parse_args()

    output_path = Path(args.output)
    plot_all_models(args.input, output_path, args.metrics)


# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()