Adive01 commited on
Commit
8cf0774
·
verified ·
1 Parent(s): fdcc442

Upload mlplo/report.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlplo/report.py +179 -0
mlplo/report.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ from .common import ARTIFACT_DIR, existing_default_checkpoint
9
+
10
+ LOGGER = logging.getLogger(__name__)
11
+
12
+
13
+ def parse_args() -> argparse.Namespace:
14
+ parser = argparse.ArgumentParser(description="Generate an HTML evaluation report.")
15
+ parser.add_argument(
16
+ "--checkpoint-dir",
17
+ default=existing_default_checkpoint(),
18
+ help="Path to the trained model checkpoint directory containing metrics.",
19
+ )
20
+ parser.add_argument(
21
+ "--output-file",
22
+ default=str(ARTIFACT_DIR / "eval_report.html"),
23
+ help="Output HTML file path.",
24
+ )
25
+ return parser.parse_args()
26
+
27
+
28
+ def load_metrics(checkpoint_dir: Path) -> dict[str, dict[str, float]]:
29
+ metrics = {}
30
+ metrics_dir = checkpoint_dir / "metrics"
31
+ if not metrics_dir.exists():
32
+ return metrics
33
+
34
+ for split in ["train", "validation", "test"]:
35
+ file_path = metrics_dir / f"{split}_metrics.json"
36
+ if file_path.exists():
37
+ try:
38
+ metrics[split] = json.loads(file_path.read_text(encoding="utf-8"))
39
+ except Exception as e:
40
+ LOGGER.warning(f"Failed to load {file_path}: {e}")
41
+ return metrics
42
+
43
+
44
+ def load_predictions(checkpoint_dir: Path) -> list[dict]:
45
+ # We look for the predictions file in the artifact directory,
46
+ # since eval.py writes it there by default.
47
+ pred_file = ARTIFACT_DIR / "sample_predictions.jsonl"
48
+ preds = []
49
+ if pred_file.exists():
50
+ try:
51
+ for line in pred_file.read_text(encoding="utf-8").splitlines():
52
+ if line.strip():
53
+ preds.append(json.loads(line))
54
+ except Exception as e:
55
+ LOGGER.warning(f"Failed to load predictions from {pred_file}: {e}")
56
+ return preds
57
+
58
+
59
+ def generate_html(checkpoint_name: str, metrics: dict, predictions: list) -> str:
60
+ html = f"""
61
+ <!DOCTYPE html>
62
+ <html>
63
+ <head>
64
+ <title>Evaluation Report - {checkpoint_name}</title>
65
+ <style>
66
+ body {{ font-family: sans-serif; margin: 40px; color: #333; }}
67
+ h1 {{ color: #2c3e50; border-bottom: 2px solid #eee; padding-bottom: 10px; }}
68
+ h2 {{ color: #34495e; margin-top: 30px; }}
69
+ table {{ border-collapse: collapse; width: 100%; margin-bottom: 30px; }}
70
+ th, td {{ border: 1px solid #ddd; padding: 12px; text-align: left; }}
71
+ th {{ background-color: #f8f9fa; font-weight: bold; }}
72
+ tr:nth-child(even) {{ background-color: #fcfcfc; }}
73
+ .metric-val {{ font-family: monospace; font-size: 1.1em; }}
74
+ .pred-box {{ background: #f8f9fa; padding: 15px; border-radius: 5px; margin-bottom: 20px; border-left: 4px solid #3498db; }}
75
+ .pred-source {{ font-size: 0.9em; color: #666; margin-bottom: 10px; }}
76
+ .pred-ref {{ font-weight: bold; color: #27ae60; margin-bottom: 5px; }}
77
+ .pred-out {{ font-weight: bold; color: #8e44ad; }}
78
+ .empty-warn {{ color: #e74c3c; font-weight: bold; }}
79
+ </style>
80
+ </head>
81
+ <body>
82
+ <h1>Model Evaluation Report</h1>
83
+ <p><strong>Checkpoint:</strong> <code>{checkpoint_name}</code></p>
84
+
85
+ <h2>Overall Metrics</h2>
86
+ <table>
87
+ <tr>
88
+ <th>Split</th>
89
+ <th>Loss</th>
90
+ <th>ROUGE-1</th>
91
+ <th>ROUGE-2</th>
92
+ <th>ROUGE-L</th>
93
+ <th>BERTScore F1</th>
94
+ <th>Avg Gen Length</th>
95
+ </tr>
96
+ """
97
+
98
+ for split in ["train", "validation", "test"]:
99
+ m = metrics.get(split, {})
100
+ if not m:
101
+ continue
102
+
103
+ prefix = split + "_" if split != "train" else ""
104
+
105
+ loss = m.get(f"{prefix}loss", m.get("train_loss", "-"))
106
+ r1 = m.get(f"{prefix}rouge1", "-")
107
+ r2 = m.get(f"{prefix}rouge2", "-")
108
+ rl = m.get(f"{prefix}rougeL", "-")
109
+ bf1 = m.get(f"{prefix}bertscore_f1", "-")
110
+ glen = m.get(f"{prefix}gen_len", "-")
111
+
112
+ def fmt(v):
113
+ return f"{v:.4f}" if isinstance(v, float) else str(v)
114
+
115
+ html += f"""
116
+ <tr>
117
+ <td><strong>{split.title()}</strong></td>
118
+ <td class="metric-val">{fmt(loss)}</td>
119
+ <td class="metric-val">{fmt(r1)}</td>
120
+ <td class="metric-val">{fmt(r2)}</td>
121
+ <td class="metric-val">{fmt(rl)}</td>
122
+ <td class="metric-val">{fmt(bf1)}</td>
123
+ <td class="metric-val">{fmt(glen)}</td>
124
+ </tr>
125
+ """
126
+
127
+ html += """
128
+ </table>
129
+
130
+ <h2>Sample Predictions</h2>
131
+ """
132
+
133
+ if not predictions:
134
+ html += "<p>No predictions found.</p>"
135
+ else:
136
+ for i, p in enumerate(predictions):
137
+ empty_tag = " <span class='empty-warn'>(EMPTY PREDICTION)</span>" if p.get("empty_prediction") else ""
138
+ html += f"""
139
+ <div class="pred-box">
140
+ <div class="pred-source"><strong>Source:</strong> {p.get("source", "")}</div>
141
+ <div class="pred-ref">Target: {p.get("reference", "")}</div>
142
+ <div class="pred-out">Model:{empty_tag} {p.get("prediction", "")}</div>
143
+ </div>
144
+ """
145
+
146
+ html += """
147
+ </body>
148
+ </html>
149
+ """
150
+ return html
151
+
152
+
153
+ def main() -> None:
154
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
155
+ args = parse_args()
156
+
157
+ if not args.checkpoint_dir:
158
+ LOGGER.error("No checkpoint directory provided or found.")
159
+ return
160
+
161
+ checkpoint_path = Path(args.checkpoint_dir)
162
+ if not checkpoint_path.exists():
163
+ LOGGER.error(f"Checkpoint directory not found: {checkpoint_path}")
164
+ return
165
+
166
+ metrics = load_metrics(checkpoint_path)
167
+ predictions = load_predictions(checkpoint_path)
168
+
169
+ html_content = generate_html(checkpoint_path.name, metrics, predictions)
170
+
171
+ out_file = Path(args.output_file)
172
+ out_file.parent.mkdir(parents=True, exist_ok=True)
173
+ out_file.write_text(html_content, encoding="utf-8")
174
+
175
+ LOGGER.info(f"Evaluation report generated at: {out_file.absolute()}")
176
+
177
+
178
+ if __name__ == "__main__":
179
+ main()