Adive01 committed on
Commit
fdcc442
·
verified ·
1 Parent(s): 8beebbb

Upload mlplo/eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlplo/eval.py +189 -0
mlplo/eval.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import logging
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from datasets import load_from_disk
9
+ from transformers import (
10
+ AutoModelForSeq2SeqLM,
11
+ DataCollatorForSeq2Seq,
12
+ Seq2SeqTrainer,
13
+ Seq2SeqTrainingArguments,
14
+ )
15
+
16
+ from .common import (
17
+ ARTIFACT_DIR,
18
+ DEFAULT_SUMMARY_COLUMN,
19
+ DEFAULT_TARGET_MAX_LENGTH,
20
+ DEFAULT_TEXT_COLUMN,
21
+ build_compute_metrics,
22
+ ensure_project_dirs,
23
+ existing_default_checkpoint,
24
+ load_tokenizer,
25
+ maybe_limit_split,
26
+ resolve_mixed_precision,
27
+ resolve_model_reference,
28
+ validate_model_dir,
29
+ write_json,
30
+ write_jsonl,
31
+ )
32
+
33
# Module-level logger; handlers and format are configured in main() via
# logging.basicConfig, so importing this module has no logging side effects.
LOGGER = logging.getLogger(__name__)
34
+
35
+
36
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the evaluation CLI and parse its arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on; passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a fine-tuned summarization checkpoint."
    )

    # ── Inputs ────────────────────────────────────────────────────────────────
    parser.add_argument(
        "--dataset-dir", required=True, help="Path produced by mlplo.data_cleaning."
    )
    # The default is resolved once, while the parser is being built; main()
    # rejects a falsy value with an explicit error.
    parser.add_argument("--model-path", default=existing_default_checkpoint())
    parser.add_argument(
        "--split", default="test", choices=["train", "validation", "test"]
    )
    parser.add_argument("--text-column", default=DEFAULT_TEXT_COLUMN)
    parser.add_argument("--summary-column", default=DEFAULT_SUMMARY_COLUMN)

    # ── Generation / batching ─────────────────────────────────────────────────
    parser.add_argument("--per-device-eval-batch-size", type=int, default=2)
    parser.add_argument(
        "--generation-max-length", type=int, default=DEFAULT_TARGET_MAX_LENGTH
    )
    parser.add_argument("--generation-num-beams", type=int, default=4)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--preview-rows", type=int, default=5)
    parser.add_argument(
        "--include-bertscore",
        action="store_true",
        help=(
            "Compute BERTScore F1 in addition to ROUGE. "
            "Downloads a ~400 MB model on first use."
        ),
    )

    # ── Outputs ───────────────────────────────────────────────────────────────
    parser.add_argument(
        "--output-file", default=str(ARTIFACT_DIR / "eval_metrics.json")
    )
    parser.add_argument(
        "--predictions-file", default=str(ARTIFACT_DIR / "sample_predictions.jsonl")
    )
    return parser.parse_args(argv)
71
+
72
+
73
def _build_preview_rows(
    predictions,
    tokenizer,
    evaluation_split,
    *,
    text_column: str,
    summary_column: str,
    limit: int,
) -> list[dict]:
    """Decode up to ``limit`` generated sequences into preview records.

    Each record holds the ``source`` and ``reference`` values read from the
    matching row of ``evaluation_split`` plus the decoded ``prediction``.
    Rows whose prediction is empty after stripping are flagged with
    ``empty_prediction`` and logged as a warning.
    """
    # The predictions may arrive as a tuple; the first element is taken as
    # the generated token ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    token_ids = np.asarray(predictions)

    # Clamp to what is actually available (and to zero, so a negative
    # --preview-rows simply yields no preview instead of a surprising slice).
    n_preview = max(0, min(limit, len(token_ids), len(evaluation_split)))

    # Only the preview rows are written out, so only they are decoded.
    token_ids = token_ids[:n_preview]
    # Negative ids cannot be decoded (presumably the -100 padding sentinel —
    # confirm against the trainer); swap them for the pad token first.
    token_ids = np.where(token_ids < 0, tokenizer.pad_token_id, token_ids)
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    records: list[dict] = []
    for index, text in enumerate(decoded):
        row = evaluation_split[index]
        prediction = text.strip()
        record: dict = {
            "source": row.get(text_column, ""),
            "reference": row.get(summary_column, ""),
            "prediction": prediction,
        }
        if not prediction:
            record["empty_prediction"] = True
            LOGGER.warning("Empty prediction at index %d.", index)
        records.append(record)
    return records


def main() -> None:
    """Evaluate a summarization checkpoint and write metrics plus a preview.

    Parses the CLI arguments, validates the dataset path and model reference,
    runs ``Seq2SeqTrainer.predict`` with generation enabled on the requested
    split, hands the metrics to ``write_json`` and a handful of decoded
    sample predictions to ``write_jsonl``.

    Raises:
        ValueError: If no model path was given and no default is available.
        FileNotFoundError: If ``--dataset-dir`` does not exist.
        KeyError: If the requested split is not present in the dataset.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
    args = parse_args()
    ensure_project_dirs()

    if not args.model_path:
        raise ValueError(
            "No model path provided and no default checkpoint exists yet. "
            "Train a model first with mlplo.train."
        )

    # ── Validate dataset path ─────────────────────────────────────────────────
    dataset_path = Path(args.dataset_dir)
    if not dataset_path.exists():
        raise FileNotFoundError(f"Prepared dataset not found: {dataset_path}")

    # ── Validate model directory ──────────────────────────────────────────────
    model_reference = resolve_model_reference(args.model_path)
    validate_model_dir(model_reference)

    LOGGER.info("Loading dataset from %s", dataset_path)
    tokenized_dataset = load_from_disk(str(dataset_path))

    if args.split not in tokenized_dataset:
        available = list(tokenized_dataset.keys())
        raise KeyError(
            f"Split '{args.split}' not found in dataset. Available: {available}"
        )

    evaluation_split = maybe_limit_split(
        tokenized_dataset[args.split], args.max_samples
    )

    # ── Load model ────────────────────────────────────────────────────────────
    LOGGER.info("Loading model from %s", model_reference)
    tokenizer = load_tokenizer(model_reference)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_reference)
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    precision = resolve_mixed_precision()

    if args.include_bertscore:
        LOGGER.info(
            "BERTScore enabled. A ~400 MB model will be downloaded on first use."
        )

    compute_metrics = build_compute_metrics(
        tokenizer, include_bertscore=args.include_bertscore
    )

    # ── Run generation-based evaluation ───────────────────────────────────────
    # Seq2SeqTrainingArguments requires an output_dir even for predict-only
    # runs, so a scratch directory under the artifact dir is used.
    temp_output_dir = ARTIFACT_DIR / "eval_tmp"
    evaluation_args = Seq2SeqTrainingArguments(
        output_dir=str(temp_output_dir),
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        predict_with_generate=True,
        generation_max_length=args.generation_max_length,
        generation_num_beams=args.generation_num_beams,
        fp16=precision["fp16"],
        bf16=precision["bf16"],
        report_to="none",
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=evaluation_args,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    LOGGER.info("Running evaluation on split '%s'…", args.split)
    prediction_output = trainer.predict(evaluation_split, metric_key_prefix=args.split)
    metrics = prediction_output.metrics
    write_json(args.output_file, metrics)
    LOGGER.info("Metrics: %s", metrics)

    # ── Decode predictions and write sample file ──────────────────────────────
    preview_rows = _build_preview_rows(
        prediction_output.predictions,
        tokenizer,
        evaluation_split,
        text_column=args.text_column,
        summary_column=args.summary_column,
        limit=args.preview_rows,
    )

    write_jsonl(args.predictions_file, preview_rows)
    LOGGER.info(
        "Evaluation complete. Metrics → %s | Predictions → %s",
        args.output_file,
        args.predictions_file,
    )
186
+
187
+
188
# Run the evaluation only when executed as a script (e.g. `python -m mlplo.eval`),
# never on import.
if __name__ == "__main__":
    main()