"""Evaluate a LoRA checkpoint's oxygen classification behavior in detail. Default inputs match the first all-task Lambda run: python scripts/38_eval_lora_checkpoint.py Writes: artifacts/lora/fold0_oxygen_diagnostics.json artifacts/lora/fold0_oxygen_diagnostics.md The CLI path needs the LoRA/embedding dependencies (`torch`, `transformers`, `peft`). The metric helpers are intentionally dependency-light so they can be unit-tested without loading ESM-2. """ from __future__ import annotations import argparse import json from pathlib import Path from typing import Any import numpy as np DEFAULT_CLASSES = ["aerobe", "anaerobe", "facultative_anaerobe", "microaerobe"] def _round_float(value: float, ndigits: int = 6) -> float: return round(float(value), ndigits) def compute_oxygen_diagnostics( probabilities: np.ndarray, labels: np.ndarray, rows: list[dict[str, Any]], classes: list[str], *, top_n_errors: int = 25, checkpoint: str | None = None, ) -> dict[str, Any]: """Compute confusion matrix, per-class scores, and confident mistakes.""" if probabilities.ndim != 2: raise ValueError("probabilities must be a 2D array") if probabilities.shape[0] != labels.shape[0] or labels.shape[0] != len(rows): raise ValueError("probabilities, labels, and rows must have matching lengths") if probabilities.shape[1] != len(classes): raise ValueError("probabilities width must match number of classes") preds = probabilities.argmax(axis=1) n_classes = len(classes) confusion = np.zeros((n_classes, n_classes), dtype=int) for true_idx, pred_idx in zip(labels.astype(int), preds.astype(int), strict=True): confusion[true_idx, pred_idx] += 1 per_class: dict[str, dict[str, float | int]] = {} f1_values: list[float] = [] supported_f1_values: list[float] = [] for idx, name in enumerate(classes): tp = int(confusion[idx, idx]) support = int(confusion[idx, :].sum()) predicted = int(confusion[:, idx].sum()) precision = tp / predicted if predicted else 0.0 recall = tp / support if support else 0.0 f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0 f1_values.append(f1) if support: supported_f1_values.append(f1) per_class[name] = { "precision": _round_float(precision), "recall": _round_float(recall), "f1": _round_float(f1), "support": support, "predicted": predicted, } wrong_predictions: list[dict[str, Any]] = [] for i, (true_idx, pred_idx) in enumerate(zip(labels.astype(int), preds.astype(int), strict=True)): if true_idx == pred_idx: continue pred_prob = float(probabilities[i, pred_idx]) true_prob = float(probabilities[i, true_idx]) row = rows[i] wrong_predictions.append({ "bacdive_id": row.get("bacdive_id"), "genome_accession": row.get("genome_accession"), "group": row.get("group"), "true": classes[true_idx], "pred": classes[pred_idx], "confidence": _round_float(pred_prob), "true_probability": _round_float(true_prob), "margin": _round_float(pred_prob - true_prob), }) wrong_predictions.sort(key=lambda item: (item["confidence"], item["margin"]), reverse=True) n = int(labels.shape[0]) accuracy = float((preds == labels).mean()) if n else 0.0 out: dict[str, Any] = { "checkpoint": checkpoint, "n": n, "classes": classes, "accuracy": _round_float(accuracy), "macro_f1": _round_float( float(np.mean(supported_f1_values)) if supported_f1_values else 0.0 ), "macro_f1_all_classes": _round_float(float(np.mean(f1_values)) if f1_values else 0.0), "confusion_matrix": confusion.tolist(), "per_class": per_class, "wrong_predictions": wrong_predictions[:top_n_errors], } return out def render_markdown(diagnostics: dict[str, Any]) -> str: """Render diagnostics as a compact Markdown report.""" classes = diagnostics["classes"] lines = [ "# LoRA Oxygen Diagnostics", "", f"Checkpoint: `{diagnostics.get('checkpoint')}`", "", f"- Labeled validation rows: `{diagnostics['n']}`", f"- Accuracy: `{diagnostics['accuracy']:.4f}`", f"- Macro F1 (supported classes): `{diagnostics['macro_f1']:.4f}`", f"- Macro F1 (all configured classes): `{diagnostics['macro_f1_all_classes']:.4f}`", "", "## Per-Class Metrics", "", "| Class | Precision | Recall | F1 | Support | Predicted |", "|---|---:|---:|---:|---:|---:|", ] for cls in classes: m = diagnostics["per_class"][cls] lines.append( f"| {cls} | {m['precision']:.4f} | {m['recall']:.4f} | " f"{m['f1']:.4f} | {m['support']} | {m['predicted']} |" ) lines.extend([ "", "## Confusion Matrix", "", "| True \\ Pred | " + " | ".join(classes) + " |", "|---" + "|---:" * len(classes) + "|", ]) for cls, row in zip(classes, diagnostics["confusion_matrix"], strict=True): lines.append("| " + cls + " | " + " | ".join(str(int(v)) for v in row) + " |") lines.extend([ "", "## High-Confidence Wrong Predictions", "", "| BacDive ID | Genome | Group | True | Pred | Confidence | True Prob. | Margin |", "|---:|---|---|---|---|---:|---:|---:|", ]) wrong = diagnostics.get("wrong_predictions", []) if not wrong: lines.append("| - | - | - | - | - | - | - | - |") for item in wrong: lines.append( f"| {item.get('bacdive_id')} | {item.get('genome_accession')} | " f"{item.get('group') or ''} | {item['true']} | {item['pred']} | " f"{item['confidence']:.4f} | {item['true_probability']:.4f} | " f"{item['margin']:.4f} |" ) return "\n".join(lines) + "\n" def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--checkpoint", default="artifacts/lora/fold0_best.pt") parser.add_argument("--fold", type=int, default=None) parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--sequences", default="data/marker_sequences.jsonl") parser.add_argument("--phenotypes", default="data/bacdive_phenotypes.parquet") parser.add_argument("--catalog", default="data/strain_catalog.parquet") parser.add_argument("--top-n-errors", type=int, default=25) parser.add_argument("--output-json", default="artifacts/lora/fold0_oxygen_diagnostics.json") parser.add_argument("--output-md", default="artifacts/lora/fold0_oxygen_diagnostics.md") parser.add_argument("--device", default=None, help="Defaults to cuda when available, else cpu.") return parser.parse_args() def _evaluate_checkpoint(args: argparse.Namespace) -> dict[str, Any]: import torch from microbe_model.train.lora_model import LoraModelConfig, OXYGEN_CLASSES, PhenoLoRAModel from microbe_model.train.lora_trainer import _build_dataset, _collate, _group_kfold_split device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) checkpoint_path = Path(args.checkpoint) checkpoint = torch.load(checkpoint_path, map_location=device) model_cfg = LoraModelConfig(**checkpoint["model_cfg"]) fold = args.fold if args.fold is not None else int(checkpoint.get("train_cfg", {}).get("fold", 0)) rows = _build_dataset(Path(args.sequences), Path(args.phenotypes), Path(args.catalog)) _, val_rows = _group_kfold_split(rows, n_splits=5, fold=fold) val_rows = [row for row in val_rows if row["label_mask"].get("oxy")] model = PhenoLoRAModel(model_cfg).to(device) model.load_state_dict(checkpoint["state_dict"], strict=False) model.eval() probs_out: list[np.ndarray] = [] labels_out: list[int] = [] rows_out: list[dict[str, Any]] = [] with torch.no_grad(): for start in range(0, len(val_rows), args.batch_size): chunk = val_rows[start : start + args.batch_size] batch = _collate(chunk) preds = model(batch["genomes"], device=device) probs = torch.softmax(preds["oxy"], dim=-1).detach().cpu().float().numpy() labels = batch["labels"]["oxy"].cpu().numpy().astype(int) masks = batch["label_mask"]["oxy"].cpu().numpy().astype(bool) if masks.any(): probs_out.append(probs[masks]) labels_out.extend(labels[masks].tolist()) rows_out.extend([row for row, keep in zip(chunk, masks, strict=True) if keep]) if probs_out: probabilities = np.concatenate(probs_out, axis=0) else: probabilities = np.zeros((0, len(OXYGEN_CLASSES)), dtype=float) labels = np.array(labels_out, dtype=int) return compute_oxygen_diagnostics( probabilities, labels, rows_out, list(OXYGEN_CLASSES), top_n_errors=args.top_n_errors, checkpoint=str(checkpoint_path), ) def main() -> None: args = parse_args() diagnostics = _evaluate_checkpoint(args) out_json = Path(args.output_json) out_md = Path(args.output_md) out_json.parent.mkdir(parents=True, exist_ok=True) out_md.parent.mkdir(parents=True, exist_ok=True) out_json.write_text(json.dumps(diagnostics, indent=2) + "\n") out_md.write_text(render_markdown(diagnostics)) print(f"Wrote {out_json}") print(f"Wrote {out_md}") print(f"oxygen macro_f1={diagnostics['macro_f1']:.4f} accuracy={diagnostics['accuracy']:.4f}") if __name__ == "__main__": main()