File size: 2,953 Bytes
d14d04e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
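Plot mean WER/CER per checkpoint epoch for one or more evaluation folders (each containing an overview.json).

Usage:
uv run plot_over_time.py --input ./phonikud_enhanced_out/ ./phonikud_vocalized_out/ ./unvocalized_mock_out/ ./vocalized_mock_out/ --metrics wer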
"""
import argparse
import json
from pathlib import Path
import re

import matplotlib.pyplot as plt


def extract_epoch_step(name: str):
    """Return ``(epoch, step)`` parsed from a checkpoint name.

    Example: ``"epoch=4679-step=3107342.onnx"`` -> ``(4679, 3107342)``.
    Either number is reported as ``-1`` when its ``key=<digits>`` field is
    absent from *name*.
    """
    def _first_number(pattern):
        # First capture group of the first match, or -1 when nothing matches.
        hit = re.search(pattern, name)
        return -1 if hit is None else int(hit.group(1))

    return _first_number(r'epoch=(\d+)'), _first_number(r'step=(\d+)')

def load_reports_from_folder(folder: Path):
    """Load per-checkpoint WER/CER results from ``folder / "overview.json"``.

    Every top-level key except ``"overall"`` is treated as a checkpoint name
    (e.g. ``"epoch=4679-step=3107342.onnx"``) whose value is a dict that may
    hold ``"mean_wer"`` and ``"mean_cer"``.

    Args:
        folder: Directory expected to contain ``overview.json`` (``str`` or
            ``Path``).

    Returns:
        A list of ``{"epoch", "step", "wer", "cer"}`` dicts sorted by
        ``(epoch, step)``, or an empty list when ``overview.json`` is missing.
        ``epoch``/``step`` are ``-1`` when they cannot be parsed from the key.
        A missing metric is ``float("nan")`` so that it appears as a gap in
        the plot instead of a misleading error rate of 0.
    """
    folder = Path(folder)
    overview_path = folder / "overview.json"
    if not overview_path.exists():
        print(f"No overview.json found in {folder}")
        return []

    with open(overview_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    reports = []
    for key, values in data.items():
        # "overall" is not a per-checkpoint entry; skip it.
        if key == "overall":
            continue
        epoch, step = extract_epoch_step(key)
        reports.append({
            "epoch": epoch,
            "step": step,
            # NaN (not 0) for missing metrics: 0 would look like a perfect score.
            "wer": values.get("mean_wer", float("nan")),
            "cer": values.get("mean_cer", float("nan")),
        })

    print(f"Found {len(reports)} reports in folder '{folder.name}'")

    # Chronological order: by epoch, then by step within an epoch.
    reports.sort(key=lambda x: (x["epoch"], x["step"]))
    return reports




def plot_all_models(folder_paths, output_path: Path, metrics: list):
    """Plot WER and/or CER versus epoch for each folder and save one image.

    Args:
        folder_paths: Iterable of folder paths (``str`` or ``Path``). Each
            folder is expected to hold an ``overview.json``, and its name is
            used as the legend label. Folders without reports are skipped
            with a message.
        output_path: File the figure is written to. Missing parent
            directories are created.
        metrics: Which series to draw; may contain ``"wer"`` and/or ``"cer"``.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for folder_path in folder_paths:
            folder = Path(folder_path)
            label = folder.name
            reports = load_reports_from_folder(folder)
            if not reports:
                print(f"No reports found in {label}")
                continue

            # Plot against epoch; switch to r["step"] to plot by step instead.
            x = [r["epoch"] for r in reports]
            if 'wer' in metrics:
                ax.plot(x, [r["wer"] for r in reports], marker='o', label=f"{label} WER")
            if 'cer' in metrics:
                ax.plot(x, [r["cer"] for r in reports], marker='x', label=f"{label} CER")

        ax.set_title("Model WER/CER Progress Over Time")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Error Rate")
        ax.grid(True)
        # Only add a legend when something labelled was drawn; otherwise
        # matplotlib warns that there are no artists with labels.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        fig.tight_layout()

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Plot saved to {output_path}")



def main():
    """Parse command-line arguments and write the WER/CER progress plot."""
    parser = argparse.ArgumentParser(
        description="Plot WER/CER over training epochs for one or more evaluation folders."
    )
    # The loader reads <folder>/overview.json, so document that here.
    parser.add_argument('--input', nargs='+', required=True, help='One or more folders, each containing an overview.json file')
    parser.add_argument('--output', default='progress.png', help='Path to output image')
    parser.add_argument(
        '--metrics',
        nargs='+',
        choices=['cer', 'wer'],
        default=['cer', 'wer'],
        help='Which metrics to plot: cer, wer, or both (default: both)'
    )
    args = parser.parse_args()

    output_path = Path(args.output)
    plot_all_models(args.input, output_path, args.metrics)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()