# microbe-model / scripts /38_eval_lora_checkpoint.py
# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""Evaluate a LoRA checkpoint's oxygen classification behavior in detail.

Default inputs match the first all-task Lambda run:

    python scripts/38_eval_lora_checkpoint.py

Writes:

    artifacts/lora/fold0_oxygen_diagnostics.json
    artifacts/lora/fold0_oxygen_diagnostics.md

The CLI path needs the LoRA/embedding dependencies (`torch`, `transformers`,
`peft`). The metric helpers are intentionally dependency-light so they can be
unit-tested without loading ESM-2.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import numpy as np
# Oxygen-tolerance class names in class-index order.
# NOTE(review): not referenced elsewhere in this file; presumably a
# dependency-free mirror of OXYGEN_CLASSES from microbe_model.train.lora_model
# for use by tests -- verify the two lists stay in sync.
DEFAULT_CLASSES = ["aerobe", "anaerobe", "facultative_anaerobe", "microaerobe"]
def _round_float(value: float, ndigits: int = 6) -> float:
    """Return ``value`` as a built-in ``float`` rounded to ``ndigits`` decimals.

    The explicit ``float`` coercion turns NumPy scalars (and anything else
    ``float()`` accepts) into plain Python floats before rounding.
    """
    as_builtin = float(value)
    return round(as_builtin, ndigits)
def compute_oxygen_diagnostics(
    probabilities: np.ndarray,
    labels: np.ndarray,
    rows: list[dict[str, Any]],
    classes: list[str],
    *,
    top_n_errors: int = 25,
    checkpoint: str | None = None,
) -> dict[str, Any]:
    """Compute confusion matrix, per-class scores, and confident mistakes.

    Args:
        probabilities: 2D array-like of per-class scores, one row per example
            and one column per entry in ``classes``. The predicted class is
            the arg-max of each row.
        labels: 1D array-like of true class indices into ``classes``.
        rows: Per-example metadata dicts, aligned with ``probabilities``. The
            keys ``bacdive_id``, ``genome_accession`` and ``group`` are copied
            into the error listing when present.
        classes: Class names in index order.
        top_n_errors: Maximum number of wrong predictions to report; values
            below zero are treated as zero.
        checkpoint: Free-form identifier stored verbatim in the output.

    Returns:
        A dict with ``checkpoint``, ``n``, ``classes``, ``accuracy``,
        ``macro_f1`` (mean F1 over classes that have at least one true
        example), ``macro_f1_all_classes`` (mean F1 over every class),
        ``confusion_matrix`` (rows = true class, columns = predicted class),
        ``per_class`` metrics, and ``wrong_predictions`` sorted by descending
        confidence, then descending margin.

    Raises:
        ValueError: If the inputs have inconsistent shapes, or if any label is
            not a valid index into ``classes``.
    """
    probabilities = np.asarray(probabilities)
    labels = np.asarray(labels)

    if probabilities.ndim != 2:
        raise ValueError("probabilities must be a 2D array")
    if labels.ndim != 1:
        # A (n, 1) label array would broadcast against the 1D predictions and
        # silently corrupt the accuracy computation.
        raise ValueError("labels must be a 1D array")
    if probabilities.shape[0] != labels.shape[0] or labels.shape[0] != len(rows):
        raise ValueError("probabilities, labels, and rows must have matching lengths")
    if probabilities.shape[1] != len(classes):
        raise ValueError("probabilities width must match number of classes")

    n = int(labels.shape[0])
    n_classes = len(classes)
    true_indices = labels.astype(int)
    # Negative labels (e.g. a -1 "missing" sentinel) would otherwise wrap
    # around and be counted as the last class; reject them up front together
    # with labels that are too large.
    if n and (int(true_indices.min()) < 0 or int(true_indices.max()) >= n_classes):
        raise ValueError("labels must be class indices in the range [0, len(classes))")

    # argmax over a zero-width axis raises, so short-circuit the degenerate
    # "no classes, no rows" case.
    if n_classes:
        pred_indices = probabilities.argmax(axis=1).astype(int)
    else:
        pred_indices = np.zeros(n, dtype=int)

    # Rows index the true class, columns the predicted class. np.add.at
    # accumulates correctly when the same (true, pred) pair repeats.
    confusion = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(confusion, (true_indices, pred_indices), 1)

    per_class: dict[str, dict[str, float | int]] = {}
    f1_values: list[float] = []
    supported_f1_values: list[float] = []
    for idx, name in enumerate(classes):
        tp = int(confusion[idx, idx])
        support = int(confusion[idx, :].sum())
        predicted = int(confusion[:, idx].sum())
        # Undefined ratios (no predictions / no true examples) are reported as 0.
        precision = tp / predicted if predicted else 0.0
        recall = tp / support if support else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_values.append(f1)
        if support:
            supported_f1_values.append(f1)
        per_class[name] = {
            "precision": _round_float(precision),
            "recall": _round_float(recall),
            "f1": _round_float(f1),
            "support": support,
            "predicted": predicted,
        }

    wrong_predictions: list[dict[str, Any]] = []
    index_pairs = zip(true_indices.tolist(), pred_indices.tolist(), strict=True)
    for i, (true_idx, pred_idx) in enumerate(index_pairs):
        if true_idx == pred_idx:
            continue
        pred_prob = float(probabilities[i, pred_idx])
        true_prob = float(probabilities[i, true_idx])
        row = rows[i]
        wrong_predictions.append({
            "bacdive_id": row.get("bacdive_id"),
            "genome_accession": row.get("genome_accession"),
            "group": row.get("group"),
            "true": classes[true_idx],
            "pred": classes[pred_idx],
            "confidence": _round_float(pred_prob),
            "true_probability": _round_float(true_prob),
            "margin": _round_float(pred_prob - true_prob),
        })
    # Most confident mistakes first; ties are broken by the larger margin.
    wrong_predictions.sort(key=lambda item: (item["confidence"], item["margin"]), reverse=True)

    accuracy = float((pred_indices == true_indices).mean()) if n else 0.0
    # A negative limit would slice entries off the end rather than return none.
    error_limit = max(top_n_errors, 0)
    out: dict[str, Any] = {
        "checkpoint": checkpoint,
        "n": n,
        "classes": classes,
        "accuracy": _round_float(accuracy),
        "macro_f1": _round_float(
            float(np.mean(supported_f1_values)) if supported_f1_values else 0.0
        ),
        "macro_f1_all_classes": _round_float(float(np.mean(f1_values)) if f1_values else 0.0),
        "confusion_matrix": confusion.tolist(),
        "per_class": per_class,
        "wrong_predictions": wrong_predictions[:error_limit],
    }
    return out
def render_markdown(diagnostics: dict[str, Any]) -> str:
"""Render diagnostics as a compact Markdown report."""
classes = diagnostics["classes"]
lines = [
"# LoRA Oxygen Diagnostics",
"",
f"Checkpoint: `{diagnostics.get('checkpoint')}`",
"",
f"- Labeled validation rows: `{diagnostics['n']}`",
f"- Accuracy: `{diagnostics['accuracy']:.4f}`",
f"- Macro F1 (supported classes): `{diagnostics['macro_f1']:.4f}`",
f"- Macro F1 (all configured classes): `{diagnostics['macro_f1_all_classes']:.4f}`",
"",
"## Per-Class Metrics",
"",
"| Class | Precision | Recall | F1 | Support | Predicted |",
"|---|---:|---:|---:|---:|---:|",
]
for cls in classes:
m = diagnostics["per_class"][cls]
lines.append(
f"| {cls} | {m['precision']:.4f} | {m['recall']:.4f} | "
f"{m['f1']:.4f} | {m['support']} | {m['predicted']} |"
)
lines.extend([
"",
"## Confusion Matrix",
"",
"| True \\ Pred | " + " | ".join(classes) + " |",
"|---" + "|---:" * len(classes) + "|",
])
for cls, row in zip(classes, diagnostics["confusion_matrix"], strict=True):
lines.append("| " + cls + " | " + " | ".join(str(int(v)) for v in row) + " |")
lines.extend([
"",
"## High-Confidence Wrong Predictions",
"",
"| BacDive ID | Genome | Group | True | Pred | Confidence | True Prob. | Margin |",
"|---:|---|---|---|---|---:|---:|---:|",
])
wrong = diagnostics.get("wrong_predictions", [])
if not wrong:
lines.append("| - | - | - | - | - | - | - | - |")
for item in wrong:
lines.append(
f"| {item.get('bacdive_id')} | {item.get('genome_accession')} | "
f"{item.get('group') or ''} | {item['true']} | {item['pred']} | "
f"{item['confidence']:.4f} | {item['true_probability']:.4f} | "
f"{item['margin']:.4f} |"
)
return "\n".join(lines) + "\n"
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the diagnostics script.

    The module docstring is used as the ``--help`` description. Options are
    registered in the order listed below, which is also the order shown in
    the help text.
    """
    option_table = [
        ("--checkpoint", {"default": "artifacts/lora/fold0_best.pt"}),
        ("--fold", {"type": int, "default": None}),
        ("--batch-size", {"type": int, "default": 2}),
        ("--sequences", {"default": "data/marker_sequences.jsonl"}),
        ("--phenotypes", {"default": "data/bacdive_phenotypes.parquet"}),
        ("--catalog", {"default": "data/strain_catalog.parquet"}),
        ("--top-n-errors", {"type": int, "default": 25}),
        ("--output-json", {"default": "artifacts/lora/fold0_oxygen_diagnostics.json"}),
        ("--output-md", {"default": "artifacts/lora/fold0_oxygen_diagnostics.md"}),
        ("--device", {"default": None, "help": "Defaults to cuda when available, else cpu."}),
    ]

    cli = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _evaluate_checkpoint(args: argparse.Namespace) -> dict[str, Any]:
    """Run the checkpoint on oxygen-labeled validation rows and score it.

    Args:
        args: Parsed CLI options. Reads ``checkpoint``, ``fold``,
            ``batch_size``, ``sequences``, ``phenotypes``, ``catalog``,
            ``top_n_errors`` and ``device``.

    Returns:
        The diagnostics dict produced by ``compute_oxygen_diagnostics``.

    Raises:
        ValueError: If ``args.batch_size`` is smaller than 1.
    """
    # Validate before the heavy imports and model load: a negative step makes
    # range() below yield nothing (silently producing an empty report), and a
    # zero step only fails after the model has already been loaded.
    if args.batch_size < 1:
        raise ValueError("--batch-size must be a positive integer")

    # Heavy / project-local dependencies are imported lazily so the metric
    # helpers in this module stay importable without them.
    import torch
    from microbe_model.train.lora_model import LoraModelConfig, OXYGEN_CLASSES, PhenoLoRAModel
    from microbe_model.train.lora_trainer import _build_dataset, _collate, _group_kfold_split

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    checkpoint_path = Path(args.checkpoint)
    # NOTE(security): torch.load may unpickle arbitrary objects depending on
    # the installed PyTorch version's ``weights_only`` default; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_cfg = LoraModelConfig(**checkpoint["model_cfg"])
    # An explicit --fold wins; otherwise use the fold recorded in the
    # checkpoint's training config, falling back to 0.
    fold = args.fold if args.fold is not None else int(checkpoint.get("train_cfg", {}).get("fold", 0))

    rows = _build_dataset(Path(args.sequences), Path(args.phenotypes), Path(args.catalog))
    # NOTE(review): n_splits is hard-coded to 5 here; presumably this must
    # match the split used during training -- confirm against the trainer.
    _, val_rows = _group_kfold_split(rows, n_splits=5, fold=fold)
    # Keep only rows whose label mask marks the oxygen label as present.
    val_rows = [row for row in val_rows if row["label_mask"].get("oxy")]

    model = PhenoLoRAModel(model_cfg).to(device)
    # NOTE(review): strict=False ignores missing/unexpected keys; presumably
    # the checkpoint stores only part of the model's weights -- verify.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    model.eval()

    probs_out: list[np.ndarray] = []
    labels_out: list[int] = []
    rows_out: list[dict[str, Any]] = []
    with torch.no_grad():
        for start in range(0, len(val_rows), args.batch_size):
            chunk = val_rows[start : start + args.batch_size]
            batch = _collate(chunk)
            preds = model(batch["genomes"], device=device)
            probs = torch.softmax(preds["oxy"], dim=-1).detach().cpu().float().numpy()
            labels = batch["labels"]["oxy"].cpu().numpy().astype(int)
            masks = batch["label_mask"]["oxy"].cpu().numpy().astype(bool)
            # Collect only the examples the collated mask marks as labeled,
            # keeping probabilities, labels and metadata rows aligned.
            if masks.any():
                probs_out.append(probs[masks])
                labels_out.extend(labels[masks].tolist())
                rows_out.extend([row for row, keep in zip(chunk, masks, strict=True) if keep])

    if probs_out:
        probabilities = np.concatenate(probs_out, axis=0)
    else:
        # No labeled rows: keep a correctly-shaped empty matrix so the
        # diagnostics shape checks still pass.
        probabilities = np.zeros((0, len(OXYGEN_CLASSES)), dtype=float)
    labels = np.array(labels_out, dtype=int)
    return compute_oxygen_diagnostics(
        probabilities,
        labels,
        rows_out,
        list(OXYGEN_CLASSES),
        top_n_errors=args.top_n_errors,
        checkpoint=str(checkpoint_path),
    )
def main() -> None:
    """CLI entry point: evaluate the checkpoint and write both reports.

    Writes the diagnostics as indented JSON and as Markdown to the paths given
    by ``--output-json`` and ``--output-md`` (creating parent directories as
    needed), then prints the output paths and a one-line metric summary.
    """
    args = parse_args()
    diagnostics = _evaluate_checkpoint(args)

    out_json = Path(args.output_json)
    out_md = Path(args.output_md)
    out_json.parent.mkdir(parents=True, exist_ok=True)
    out_md.parent.mkdir(parents=True, exist_ok=True)

    # Write with an explicit encoding: the Markdown embeds values taken from
    # the dataset rows, and the platform default encoding may not be able to
    # represent them.
    out_json.write_text(json.dumps(diagnostics, indent=2) + "\n", encoding="utf-8")
    out_md.write_text(render_markdown(diagnostics), encoding="utf-8")

    print(f"Wrote {out_json}")
    print(f"Wrote {out_md}")
    print(f"oxygen macro_f1={diagnostics['macro_f1']:.4f} accuracy={diagnostics['accuracy']:.4f}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()