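"""Evaluate a fine-tuned summarization checkpoint on a prepared dataset split.

Runs Seq2SeqTrainer.predict over the chosen split, writes ROUGE (and,
optionally, BERTScore) metrics to JSON, and saves a small JSONL preview of
decoded predictions next to their sources and references.

Example invocation (module path assumed; adjust to where this file lives):

    python -m mlplo.evaluate --dataset-dir <prepared-dataset-dir> --split test
"""
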
from __future__ import annotations

import argparse
import logging
from pathlib import Path

import numpy as np
from datasets import load_from_disk
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from .common import (
    ARTIFACT_DIR,
    DEFAULT_SUMMARY_COLUMN,
    DEFAULT_TARGET_MAX_LENGTH,
    DEFAULT_TEXT_COLUMN,
    build_compute_metrics,
    ensure_project_dirs,
    existing_default_checkpoint,
    load_tokenizer,
    maybe_limit_split,
    resolve_mixed_precision,
    resolve_model_reference,
    validate_model_dir,
    write_json,
    write_jsonl,
)

LOGGER = logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Evaluate a fine-tuned summarization checkpoint."
    )
    parser.add_argument(
        "--dataset-dir", required=True, help="Path produced by mlplo.data_cleaning."
    )
    parser.add_argument("--model-path", default=existing_default_checkpoint())
    parser.add_argument(
        "--split", default="test", choices=["train", "validation", "test"]
    )
    parser.add_argument("--text-column", default=DEFAULT_TEXT_COLUMN)
    parser.add_argument("--summary-column", default=DEFAULT_SUMMARY_COLUMN)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=2)
    parser.add_argument(
        "--generation-max-length", type=int, default=DEFAULT_TARGET_MAX_LENGTH
    )
    parser.add_argument("--generation-num-beams", type=int, default=4)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--preview-rows", type=int, default=5)
    parser.add_argument(
        "--include-bertscore",
        action="store_true",
        help=(
            "Compute BERTScore F1 in addition to ROUGE. "
            "Downloads a ~400 MB model on first use."
        ),
    )
    parser.add_argument(
        "--output-file", default=str(ARTIFACT_DIR / "eval_metrics.json")
    )
    parser.add_argument(
        "--predictions-file", default=str(ARTIFACT_DIR / "sample_predictions.jsonl")
    )
    return parser.parse_args()


def main() -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
    args = parse_args()
    ensure_project_dirs()

    if not args.model_path:
        raise ValueError(
            "No model path provided and no default checkpoint exists yet. "
            "Train a model first with mlplo.train."
        )

    # ── Validate dataset path ──────────────────────────────────────────────
    dataset_path = Path(args.dataset_dir)
    if not dataset_path.exists():
        raise FileNotFoundError(f"Prepared dataset not found: {dataset_path}")

    # ── Validate model directory ───────────────────────────────────────────
    model_reference = resolve_model_reference(args.model_path)
    validate_model_dir(model_reference)

    LOGGER.info("Loading dataset from %s", dataset_path)
    tokenized_dataset = load_from_disk(str(dataset_path))
    if args.split not in tokenized_dataset:
        available = list(tokenized_dataset.keys())
        raise KeyError(
            f"Split '{args.split}' not found in dataset. Available: {available}"
        )
    evaluation_split = maybe_limit_split(
        tokenized_dataset[args.split], args.max_samples
    )

    # ── Load model ─────────────────────────────────────────────────────────
    LOGGER.info("Loading model from %s", model_reference)
    tokenizer = load_tokenizer(model_reference)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_reference)
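    # Passing the model lets the collator build decoder_input_ids from the
    # labels while padding each batch dynamically.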
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    precision = resolve_mixed_precision()
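    # `precision` holds the fp16/bf16 flags forwarded to the training
    # arguments below (see resolve_mixed_precision in common).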

    if args.include_bertscore:
        LOGGER.info(
            "BERTScore enabled. A ~400 MB model will be downloaded on first use."
        )
    compute_metrics = build_compute_metrics(
        tokenizer, include_bertscore=args.include_bertscore
    )

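    # Seq2SeqTrainingArguments expects an output_dir even for evaluation-only
    # runs, so point it at a scratch directory under ARTIFACT_DIR.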
    temp_output_dir = ARTIFACT_DIR / "eval_tmp"
    evaluation_args = Seq2SeqTrainingArguments(
        output_dir=str(temp_output_dir),
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        predict_with_generate=True,
        generation_max_length=args.generation_max_length,
        generation_num_beams=args.generation_num_beams,
        fp16=precision["fp16"],
        bf16=precision["bf16"],
        report_to="none",
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=evaluation_args,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
| LOGGER.info("Running evaluation on split '%s'β¦", args.split) | |
    prediction_output = trainer.predict(evaluation_split, metric_key_prefix=args.split)
    metrics = prediction_output.metrics
    write_json(args.output_file, metrics)
    LOGGER.info("Metrics: %s", metrics)

    # ── Decode predictions and write sample file ───────────────────────────
    generated_token_ids = prediction_output.predictions
    if isinstance(generated_token_ids, tuple):
        generated_token_ids = generated_token_ids[0]
    generated_token_ids = np.asarray(generated_token_ids)
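    # The Trainer pads generated sequences across batches with -100; replace
    # any negative sentinel ids with the pad token so batch_decode succeeds.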
    generated_token_ids = np.where(
        generated_token_ids < 0, tokenizer.pad_token_id, generated_token_ids
    )
    decoded_predictions = tokenizer.batch_decode(
        generated_token_ids, skip_special_tokens=True
    )

    # Guard against preview_rows exceeding available samples.
    n_preview = min(args.preview_rows, len(decoded_predictions), len(evaluation_split))
    preview_rows = []
    for index in range(n_preview):
        row = evaluation_split[index]
        prediction = decoded_predictions[index].strip()
        record: dict = {
            "source": row.get(args.text_column, ""),
            "reference": row.get(args.summary_column, ""),
            "prediction": prediction,
        }
        if not prediction:
            record["empty_prediction"] = True
            LOGGER.warning("Empty prediction at index %d.", index)
        preview_rows.append(record)
    write_jsonl(args.predictions_file, preview_rows)
    LOGGER.info(
        "Evaluation complete. Metrics → %s | Predictions → %s",
        args.output_file,
        args.predictions_file,
    )


if __name__ == "__main__":
    main()