phonikud-experiments / ablation /plot_over_time.py
thewh1teagle
Fix CSV path in unvocalized mock phonemes command and sort DataFrame by file_id
d14d04e
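"""
Plot WER/CER progress over training epochs for one or more evaluation output folders.

Each input folder must contain an overview.json report. Example:

uv run plot_over_time.py --input ./phonikud_enhanced_out/ ./phonikud_vocalized_out/ ./unvocalized_mock_out/ ./vocalized_mock_out/ --metrics wer
"""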
import argparse
import json
from pathlib import Path
import re
import matplotlib.pyplot as plt
def extract_epoch_step(name: str):
    """Parse the epoch and step numbers out of a checkpoint name.

    Example: ``"epoch=4679-step=3107342.onnx"`` -> ``(4679, 3107342)``.
    Each value independently falls back to ``-1`` when its ``epoch=`` /
    ``step=`` tag is not present in ``name``.
    """

    def _first_int(pattern):
        # First capture group of the first match, or -1 when there is none.
        found = re.search(pattern, name)
        return -1 if found is None else int(found.group(1))

    return _first_int(r'epoch=(\d+)'), _first_int(r'step=(\d+)')
def load_reports_from_folder(folder: Path):
    """Load per-checkpoint WER/CER scores from ``<folder>/overview.json``.

    Every top-level key except ``"overall"`` is treated as a checkpoint name.
    Its epoch/step are parsed with ``extract_epoch_step``, and its
    ``mean_wer`` / ``mean_cer`` values are collected.

    Args:
        folder: Evaluation output directory (a ``str`` is also accepted).

    Returns:
        A list of ``{"epoch", "step", "wer", "cer"}`` dicts sorted by
        ``(epoch, step)``, or an empty list when ``overview.json`` is missing.
        A metric absent from an entry is reported as NaN rather than 0, so it
        shows up as a gap in a plot instead of a fake perfect score.
    """
    folder = Path(folder)
    overview_path = folder / "overview.json"
    if not overview_path.exists():
        print(f"No overview.json found in {folder}")
        return []
    with open(overview_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Sentinel for missing metrics: NaN is skipped by plotting instead of
    # being drawn as a 0% error rate.
    missing = float("nan")
    reports = []
    for key, values in data.items():
        # "overall" is not a checkpoint entry (presumably an aggregate) - skip it.
        if key == "overall":
            continue
        epoch, step = extract_epoch_step(key)
        reports.append({
            "epoch": epoch,
            "step": step,
            "wer": values.get("mean_wer", missing),
            "cer": values.get("mean_cer", missing),
        })
    print(f"Found {len(reports)} reports in folder '{folder.name}'")
    # Chronological order: by epoch, then by step within an epoch.
    reports.sort(key=lambda r: (r["epoch"], r["step"]))
    return reports
def plot_all_models(folder_paths, output_path: Path, metrics: list):
    """Plot WER and/or CER against epoch for each folder and save the figure.

    Args:
        folder_paths: Iterable of evaluation output directories. Each folder's
            name is used as its legend label; folders without reports are
            skipped with a message.
        output_path: Image file to write (``str`` or ``Path``). Missing parent
            directories are created.
        metrics: Metric names to draw, ``"wer"`` and/or ``"cer"``.
    """
    output_path = Path(output_path)
    # (report key, point marker, display name), in drawing order.
    styles = (("wer", "o", "WER"), ("cer", "x", "CER"))
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        plotted = False
        for folder in folder_paths:
            folder = Path(folder)
            label = folder.name
            reports = load_reports_from_folder(folder)
            if not reports:
                print(f"No reports found in {label}")
                continue
            # The x axis is the epoch; use r["step"] here to plot per step instead.
            x = [r["epoch"] for r in reports]
            for key, marker, name in styles:
                if key in metrics:
                    ax.plot(x, [r[key] for r in reports], marker=marker, label=f"{label} {name}")
                    plotted = True
        # The title names only the metrics that were requested.
        selected = "/".join(name for key, _, name in styles if key in metrics) or "WER/CER"
        ax.set_title(f"Model {selected} Progress Over Time")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Error Rate")
        ax.grid(True)
        # A legend with no labelled artists only produces a warning.
        if plotted:
            ax.legend()
        fig.tight_layout()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    print(f"Plot saved to {output_path}")
def main():
    """Parse command-line arguments and render the progress plot.

    Options:
        --input: one or more evaluation folders, each with an overview.json.
        --output: image path to write (default: progress.png).
        --metrics: any of "cer" / "wer" (default: both).
    """
    parser = argparse.ArgumentParser(
        description="Plot WER/CER over training epochs for one or more evaluation output folders."
    )
    # The loader reads <folder>/overview.json, not *_report.json files.
    parser.add_argument('--input', nargs='+', required=True,
                        help='One or more folders, each containing an overview.json report')
    parser.add_argument('--output', default='progress.png', help='Path to output image')
    parser.add_argument(
        '--metrics',
        nargs='+',
        choices=['cer', 'wer'],
        default=['cer', 'wer'],
        help='Which metrics to plot: cer, wer, or both (default: both)'
    )
    args = parser.parse_args()
    output_path = Path(args.output)
    plot_all_models(args.input, output_path, args.metrics)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()