Adive01 committed on
Commit
fecd2b2
·
verified ·
1 Parent(s): 6aef09e

Upload mlplo/compare.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlplo/compare.py +222 -0
mlplo/compare.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import logging
5
+ from pathlib import Path
6
+
7
+ import evaluate
8
+ import numpy as np
9
+ import torch
10
+ from datasets import load_from_disk
11
+ from transformers import AutoModelForSeq2SeqLM
12
+
13
+ from .common import (
14
+ ARTIFACT_DIR,
15
+ DEFAULT_SUMMARY_COLUMN,
16
+ DEFAULT_TEXT_COLUMN,
17
+ ensure_project_dirs,
18
+ load_tokenizer,
19
+ maybe_limit_split,
20
+ resolve_model_reference,
21
+ validate_model_dir,
22
+ )
23
+
24
+ LOGGER = logging.getLogger(__name__)
25
+
26
+
27
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        The parsed namespace with ``model_a``, ``model_b``, ``dataset_dir``,
        ``split``, ``max_samples``, ``text_column``, ``summary_column`` and
        ``output_file`` attributes.
    """
    cli = argparse.ArgumentParser(
        description="Compare two models side-by-side on a test set."
    )

    # The three mandatory inputs share the same shape: a flag plus help text.
    required_flags = (
        ("--model-a", "Path to Model A checkpoint."),
        ("--model-b", "Path to Model B checkpoint."),
        ("--dataset-dir", "Prepared dataset directory."),
    )
    for flag, help_text in required_flags:
        cli.add_argument(flag, required=True, help=help_text)

    # Optional knobs, registered in the same order as before so that the
    # generated ``--help`` output is unchanged.
    cli.add_argument("--split", default="test")
    cli.add_argument("--max-samples", type=int, default=20)
    cli.add_argument("--text-column", default=DEFAULT_TEXT_COLUMN)
    cli.add_argument("--summary-column", default=DEFAULT_SUMMARY_COLUMN)

    default_report = str(ARTIFACT_DIR / "comparison.html")
    cli.add_argument("--output-file", default=default_report)

    return cli.parse_args()
44
+
45
+
46
@torch.inference_mode()
def generate_summaries(
    model_path: str,
    dataset,
    text_col: str,
    device: torch.device,
    *,
    max_input_length: int = 512,
    max_summary_length: int = 128,
    num_beams: int = 4,
) -> list[str]:
    """Generate one summary per dataset item with the given seq2seq model.

    Args:
        model_path: Model reference as given on the command line; it is passed
            through ``resolve_model_reference`` and ``validate_model_dir``
            (project helpers; presumably they resolve and sanity-check a local
            checkpoint directory -- confirm in ``.common``).
        dataset: Iterable of mappings; ``item[text_col]`` is the text to
            summarize.
        text_col: Key of the source-text field in each item.
        device: Device the model and the tokenized inputs are moved to.
        max_input_length: Token limit used when truncating the source text.
        max_summary_length: ``max_length`` passed to ``model.generate``.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The decoded, whitespace-stripped predictions, in dataset order.
    """
    ref = resolve_model_reference(model_path)
    validate_model_dir(ref)

    LOGGER.info("Loading %s...", ref)
    tokenizer = load_tokenizer(ref)
    model = AutoModelForSeq2SeqLM.from_pretrained(ref).to(device)
    model.eval()

    predictions: list[str] = []
    for item in dataset:
        inputs = tokenizer(
            item[text_col],
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
        ).to(device)
        out = model.generate(
            **inputs, max_length=max_summary_length, num_beams=num_beams
        )
        # One input sequence per call, so only the first output row is decoded.
        predictions.append(
            tokenizer.decode(out[0], skip_special_tokens=True).strip()
        )

    # Drop the local reference and release cached GPU memory before the caller
    # loads the next model.
    del model
    torch.cuda.empty_cache()
    return predictions
71
+
72
+
73
def score_predictions(predictions: list[str], references: list[str]) -> dict:
    """Score predictions against references with ROUGE and BERTScore.

    Args:
        predictions: Model outputs.
        references: Reference summaries, aligned with ``predictions``.

    Returns:
        A dict with keys ``rouge1``, ``rouge2``, ``rougeL`` (taken from the
        ``evaluate`` ROUGE result) and ``bertscore`` (mean BERTScore F1 as a
        Python float).
    """
    rouge_metric = evaluate.load("rouge")
    rouge_result = rouge_metric.compute(
        predictions=predictions,
        references=references,
        use_stemmer=True,
    )

    # Imported here, after ROUGE has been computed, exactly as the module has
    # always done it.
    from bert_score import score as bert_score_fn

    def _or_placeholder(text: str) -> str:
        # Blank strings are swapped for a placeholder before BERTScore.
        return text if text.strip() else "..."

    safe_preds = [_or_placeholder(p) for p in predictions]
    safe_refs = [_or_placeholder(r) for r in references]

    LOGGER.info("Computing BERTScore...")
    bert_result = bert_score_fn(safe_preds, safe_refs, lang="en", verbose=False)
    f1 = bert_result[2]  # (precision, recall, f1)

    summary = {name: rouge_result[name] for name in ("rouge1", "rouge2", "rougeL")}
    summary["bertscore"] = float(f1.mean().item())
    return summary
92
+
93
+
94
def generate_html(
    model_a_name: str,
    model_b_name: str,
    scores_a: dict,
    scores_b: dict,
    dataset,
    preds_a: list[str],
    preds_b: list[str],
    text_col: str,
    sum_col: str,
) -> str:
    """Render the comparison report as a standalone HTML document.

    The report contains an aggregate-score table (ROUGE-1/2/L and BERTScore)
    followed by a side-by-side table of source, reference and both predictions.

    All dynamic text (model names, source text, references, predictions) is
    HTML-escaped, so characters such as ``<`` or ``&`` in the data are shown
    literally instead of being interpreted as markup.

    Args:
        model_a_name: Display name for Model A.
        model_b_name: Display name for Model B.
        scores_a: Mapping with ``rouge1``, ``rouge2``, ``rougeL`` and
            ``bertscore`` numeric entries for Model A.
        scores_b: Same mapping for Model B.
        dataset: Iterable of mappings indexed by ``text_col`` and ``sum_col``.
        preds_a: Model A predictions, aligned with ``dataset``.
        preds_b: Model B predictions, aligned with ``dataset``.
        text_col: Key of the source-text field.
        sum_col: Key of the reference-summary field.

    Returns:
        The HTML document as a string.

    Raises:
        IndexError: If a prediction list is shorter than ``dataset``.
        KeyError: If a score or dataset field is missing.
    """
    # Function-scope import keeps the module-level import block untouched.
    from html import escape

    def cell(value) -> str:
        """Return ``value`` as text that is safe to place inside an element."""
        return escape(str(value))

    def highlight(mine: float, other: float) -> str:
        """Return the CSS class for a score cell; ties highlight neither side."""
        return "better" if mine > other else ""

    # Static part of the page: kept as a plain string so the CSS braces need
    # no f-string escaping.
    page_head = """
<!DOCTYPE html>
<html>
<head>
    <title>Model Comparison</title>
    <style>
        body { font-family: sans-serif; margin: 40px; color: #333; }
        table { border-collapse: collapse; width: 100%; margin-bottom: 30px; }
        th, td { border: 1px solid #ddd; padding: 12px; text-align: left; vertical-align: top; }
        th { background-color: #f8f9fa; font-weight: bold; }
        .better { background-color: #e8f5e9; font-weight: bold; color: #2e7d32; }
        .source-col { width: 30%; font-size: 0.9em; color: #555; }
        .ref-col { width: 20%; font-size: 0.9em; background: #fafafa; }
        .pred-col { width: 25%; }
    </style>
</head>
<body>
    <h1>Model Comparison</h1>

    <h2>Aggregate Scores</h2>
    <table>
        <tr>
            <th>Metric</th>
"""

    # Fragments are collected and joined once instead of repeated `+=`.
    parts = [
        page_head,
        f"            <th>Model A: {cell(model_a_name)}</th>\n",
        f"            <th>Model B: {cell(model_b_name)}</th>\n",
        "        </tr>\n",
    ]

    for metric in ("rouge1", "rouge2", "rougeL", "bertscore"):
        va = scores_a[metric]
        vb = scores_b[metric]
        parts.append(
            f"""
        <tr>
            <td><strong>{metric.upper()}</strong></td>
            <td class="{highlight(va, vb)}">{va:.4f}</td>
            <td class="{highlight(vb, va)}">{vb:.4f}</td>
        </tr>
"""
        )

    parts.append(
        """
    </table>

    <h2>Side-by-Side Predictions</h2>
    <table>
        <tr>
            <th>Source</th>
            <th>Reference</th>
            <th>Model A</th>
            <th>Model B</th>
        </tr>
"""
    )

    for i, item in enumerate(dataset):
        parts.append(
            f"""
        <tr>
            <td class="source-col">{cell(item[text_col])}</td>
            <td class="ref-col">{cell(item[sum_col])}</td>
            <td class="pred-col">{cell(preds_a[i])}</td>
            <td class="pred-col">{cell(preds_b[i])}</td>
        </tr>
"""
        )

    parts.append(
        """
    </table>
</body>
</html>
"""
    )
    return "".join(parts)
175
+
176
+
177
def main() -> None:
    """Run the full comparison: load data, score both models, write the report.

    Reads its configuration from the command line via ``parse_args`` and writes
    the HTML report to ``--output-file``, creating parent directories as needed.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    args = parse_args()
    ensure_project_dirs()

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    LOGGER.info(f"Loading dataset {args.dataset_dir} (split: {args.split})...")
    all_splits = load_from_disk(args.dataset_dir)
    dataset = maybe_limit_split(all_splits[args.split], args.max_samples)

    summary_column = args.summary_column
    text_column = args.text_column
    refs = [row[summary_column] for row in dataset]

    # Model A is fully generated and scored before Model B is loaded.
    LOGGER.info("--- Processing Model A ---")
    preds_a = generate_summaries(args.model_a, dataset, text_column, device)
    scores_a = score_predictions(preds_a, refs)

    LOGGER.info("--- Processing Model B ---")
    preds_b = generate_summaries(args.model_b, dataset, text_column, device)
    scores_b = score_predictions(preds_b, refs)

    # Only the final path component is used as the display name.
    name_a, name_b = (Path(p).name for p in (args.model_a, args.model_b))

    LOGGER.info("Generating HTML report...")
    report = generate_html(
        name_a,
        name_b,
        scores_a,
        scores_b,
        dataset,
        preds_a,
        preds_b,
        text_column,
        summary_column,
    )

    out_file = Path(args.output_file)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    out_file.write_text(report, encoding="utf-8")

    LOGGER.info(f"Comparison report written to {out_file.absolute()}")


if __name__ == "__main__":
    main()