SummaryGenerator / mlplo /report.py
Adive01's picture
Upload mlplo/report.py with huggingface_hub
8cf0774 verified
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
from .common import ARTIFACT_DIR, existing_default_checkpoint
# Module-level logger named after this module; logging is configured in main().
LOGGER = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the report generator.

    Returns:
        A namespace with ``checkpoint_dir`` (default: the value returned by
        ``existing_default_checkpoint()``) and ``output_file`` (default:
        ``eval_report.html`` inside ``ARTIFACT_DIR``).
    """
    cli = argparse.ArgumentParser(description="Generate an HTML evaluation report.")

    # (flag, default, help) triples; every option is a plain string option.
    option_table = (
        (
            "--checkpoint-dir",
            existing_default_checkpoint(),
            "Path to the trained model checkpoint directory containing metrics.",
        ),
        (
            "--output-file",
            str(ARTIFACT_DIR / "eval_report.html"),
            "Output HTML file path.",
        ),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, default=default_value, help=help_text)

    return cli.parse_args()
def load_metrics(checkpoint_dir: Path) -> dict[str, dict[str, float]]:
    """Load the per-split metric files stored under ``<checkpoint_dir>/metrics``.

    For each of the splits ``train``, ``validation`` and ``test`` the file
    ``<split>_metrics.json`` is read if it exists. A split is skipped (with a
    warning) when its file cannot be read, is not valid JSON, or does not
    contain a JSON object.

    Args:
        checkpoint_dir: Checkpoint directory expected to contain ``metrics/``.

    Returns:
        Mapping from split name to the decoded metrics object. Empty when the
        ``metrics`` directory is missing or no split could be loaded.
    """
    metrics: dict[str, dict[str, float]] = {}
    metrics_dir = checkpoint_dir / "metrics"
    if not metrics_dir.exists():
        return metrics

    for split in ("train", "validation", "test"):
        file_path = metrics_dir / f"{split}_metrics.json"
        if not file_path.exists():
            continue
        try:
            payload = json.loads(file_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            LOGGER.warning(f"Failed to load {file_path}: {e}")
            continue
        if not isinstance(payload, dict):
            # The report looks metrics up by key, so anything but an object
            # (list, number, string, null) is unusable downstream.
            LOGGER.warning(
                f"Failed to load {file_path}: expected a JSON object, "
                f"got {type(payload).__name__}"
            )
            continue
        metrics[split] = payload
    return metrics
def load_predictions(checkpoint_dir: Path) -> list[dict]:
    """Load sample predictions from ``sample_predictions.jsonl``.

    Each non-blank line of the file is decoded as one JSON record. A line that
    is not valid JSON, or that does not decode to a JSON object, is skipped
    with a warning so that one bad line does not discard the remaining
    predictions.

    Args:
        checkpoint_dir: Not used by this function; the file is read from
            ``ARTIFACT_DIR`` instead (see the comment below).

    Returns:
        The decoded records in file order; empty when the file is missing or
        cannot be read.
    """
    # We look for the predictions file in the artifact directory,
    # since eval.py writes it there by default.
    pred_file = ARTIFACT_DIR / "sample_predictions.jsonl"
    preds: list[dict] = []
    if not pred_file.exists():
        return preds

    try:
        raw_text = pred_file.read_text(encoding="utf-8")
    except (OSError, ValueError) as e:
        # ValueError covers UnicodeDecodeError for files that are not UTF-8.
        LOGGER.warning(f"Failed to load predictions from {pred_file}: {e}")
        return preds

    for line_no, line in enumerate(raw_text.splitlines(), start=1):
        if not line.strip():
            continue
        try:
            record = json.loads(line)
        except ValueError as e:
            LOGGER.warning(
                f"Failed to load predictions from {pred_file}: line {line_no}: {e}"
            )
            continue
        if not isinstance(record, dict):
            LOGGER.warning(
                f"Failed to load predictions from {pred_file}: line {line_no}: "
                f"expected a JSON object, got {type(record).__name__}"
            )
            continue
        preds.append(record)
    return preds
def generate_html(checkpoint_name: str, metrics: dict, predictions: list) -> str:
    """Render the evaluation report as a standalone HTML document.

    The page contains a metrics table with one row per available split
    (``train``, ``validation``, ``test``) followed by one box per sample
    prediction. Every dynamic value (checkpoint name, metric values, source,
    reference and prediction text) is HTML-escaped, because dataset and model
    text may contain ``<``, ``>`` or ``&`` that would otherwise break the
    markup or inject script into the page.

    Args:
        checkpoint_name: Name shown in the title and header.
        metrics: Mapping of split name to a metrics dict. For ``validation``
            and ``test`` the keys are looked up with a ``<split>_`` prefix;
            for ``train`` they are looked up unprefixed. ``train_loss`` is the
            fallback when the split's own loss key is absent.
        predictions: Records with optional ``source``, ``reference``,
            ``prediction`` and ``empty_prediction`` keys.

    Returns:
        The full HTML document as a string.
    """
    # Imported locally so the block does not depend on a new file-level import.
    from html import escape

    def fmt(value: object) -> str:
        """Show floats with four decimals; stringify and escape anything else."""
        if isinstance(value, float):
            return f"{value:.4f}"
        return escape(str(value))

    safe_name = escape(str(checkpoint_name))

    # Fragments are collected and joined once instead of repeated ``+=``.
    parts = [f"""
<!DOCTYPE html>
<html>
<head>
<title>Evaluation Report - {safe_name}</title>
<style>
body {{ font-family: sans-serif; margin: 40px; color: #333; }}
h1 {{ color: #2c3e50; border-bottom: 2px solid #eee; padding-bottom: 10px; }}
h2 {{ color: #34495e; margin-top: 30px; }}
table {{ border-collapse: collapse; width: 100%; margin-bottom: 30px; }}
th, td {{ border: 1px solid #ddd; padding: 12px; text-align: left; }}
th {{ background-color: #f8f9fa; font-weight: bold; }}
tr:nth-child(even) {{ background-color: #fcfcfc; }}
.metric-val {{ font-family: monospace; font-size: 1.1em; }}
.pred-box {{ background: #f8f9fa; padding: 15px; border-radius: 5px; margin-bottom: 20px; border-left: 4px solid #3498db; }}
.pred-source {{ font-size: 0.9em; color: #666; margin-bottom: 10px; }}
.pred-ref {{ font-weight: bold; color: #27ae60; margin-bottom: 5px; }}
.pred-out {{ font-weight: bold; color: #8e44ad; }}
.empty-warn {{ color: #e74c3c; font-weight: bold; }}
</style>
</head>
<body>
<h1>Model Evaluation Report</h1>
<p><strong>Checkpoint:</strong> <code>{safe_name}</code></p>
<h2>Overall Metrics</h2>
<table>
<tr>
<th>Split</th>
<th>Loss</th>
<th>ROUGE-1</th>
<th>ROUGE-2</th>
<th>ROUGE-L</th>
<th>BERTScore F1</th>
<th>Avg Gen Length</th>
</tr>
"""]

    # Column keys after the loss column, in table order.
    metric_keys = ("rouge1", "rouge2", "rougeL", "bertscore_f1", "gen_len")

    for split in ("train", "validation", "test"):
        split_metrics = metrics.get(split)
        # Skip absent/empty splits and anything that is not a key/value mapping.
        if not isinstance(split_metrics, dict) or not split_metrics:
            continue

        prefix = "" if split == "train" else f"{split}_"
        values = [split_metrics.get(f"{prefix}loss", split_metrics.get("train_loss", "-"))]
        values.extend(split_metrics.get(f"{prefix}{key}", "-") for key in metric_keys)
        cells = "\n".join(f'<td class="metric-val">{fmt(v)}</td>' for v in values)

        parts.append(f"""
<tr>
<td><strong>{split.title()}</strong></td>
{cells}
</tr>
""")

    parts.append("""
</table>
<h2>Sample Predictions</h2>
""")

    if not predictions:
        parts.append("<p>No predictions found.</p>")
    else:
        for record in predictions:
            if not isinstance(record, dict):
                continue  # cannot look up fields on a non-mapping record
            empty_tag = (
                " <span class='empty-warn'>(EMPTY PREDICTION)</span>"
                if record.get("empty_prediction")
                else ""
            )
            source = escape(str(record.get("source", "")))
            reference = escape(str(record.get("reference", "")))
            prediction = escape(str(record.get("prediction", "")))
            parts.append(f"""
<div class="pred-box">
<div class="pred-source"><strong>Source:</strong> {source}</div>
<div class="pred-ref">Target: {reference}</div>
<div class="pred-out">Model:{empty_tag} {prediction}</div>
</div>
""")

    parts.append("""
</body>
</html>
""")
    return "".join(parts)
def main() -> None:
    """Command-line entry point: build the HTML report and write it to disk.

    Logs an error and returns without writing anything when no checkpoint
    directory was given or the given path does not exist.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    options = parse_args()

    checkpoint_arg = options.checkpoint_dir
    if not checkpoint_arg:
        LOGGER.error("No checkpoint directory provided or found.")
        return

    checkpoint_path = Path(checkpoint_arg)
    if not checkpoint_path.exists():
        LOGGER.error(f"Checkpoint directory not found: {checkpoint_path}")
        return

    # Arguments are evaluated left to right: metrics first, then predictions.
    report = generate_html(
        checkpoint_path.name,
        load_metrics(checkpoint_path),
        load_predictions(checkpoint_path),
    )

    destination = Path(options.output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(report, encoding="utf-8")
    LOGGER.info(f"Evaluation report generated at: {destination.absolute()}")
# Generate the report only when this module is run as the main program,
# not when it is imported.
if __name__ == "__main__":
    main()