Spaces:
Running
Running
| """Training loop for PhenoLoRAModel — multi-task, masked, group-K-fold compatible.""" | |
| from __future__ import annotations | |
| import json | |
| import math | |
| import time | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from sklearn.metrics import f1_score | |
| from torch import nn, optim | |
| from microbe_model.train.lora_model import ( | |
| CATEGORIES, | |
| OXYGEN_CLASSES, | |
| LoraModelConfig, | |
| PhenoLoRAModel, | |
| masked_multitask_loss, | |
| ) | |
# Oxygen-requirement label -> integer class index, i.e. the label's position in
# OXYGEN_CLASSES (used as the classification target for the "oxy" head).
OXY_LABEL_TO_INT = {label: position for position, label in enumerate(OXYGEN_CLASSES)}
@dataclass
class TrainConfig:
    """Hyper-parameters for one LoRA fine-tuning run on a single CV fold.

    Must be a dataclass: ``train_lora`` serialises it with ``dataclasses.asdict``
    and callers construct it with keyword arguments.
    """

    fold: int = 0  # index of the GroupKFold split held out for validation
    epochs: int = 3
    batch_size: int = 2  # genomes per forward pass
    grad_accum: int = 8  # micro-batches accumulated per optimizer step
    lora_lr: float = 1e-4  # peak LR for trainable params outside ``heads.``
    head_lr: float = 1e-3  # peak LR for trainable params under ``heads.``
    weight_decay: float = 0.01  # AdamW weight decay, shared by both param groups
    warmup_frac: float = 0.05  # fraction of optimizer steps spent in linear warm-up
    bf16: bool = True  # bf16 autocast; only enabled when training on CUDA
    # Not read by the training loop in this file; presumably consumed elsewhere
    # (e.g. by the model/data pipeline) — TODO confirm.
    max_proteins_per_category: int = 16
    save_dir: str = "artifacts/lora"  # checkpoint + results JSON directory
    grad_clip: float = 1.0  # max global gradient norm before each optimizer step
    # Per-target multipliers forwarded to masked_multitask_loss as target_weights.
    temp_weight: float = 1.0
    ph_weight: float = 1.0
    salt_weight: float = 1.0
    oxy_weight: float = 1.0
    # Optional per-class weights for the oxygen head, forwarded to the loss.
    oxy_class_weights: tuple[float, ...] | None = None
def _is_missing(value) -> bool:
    """Return True for None, NaN/NA scalars and blank strings."""
    if value is None:
        return True
    if isinstance(value, str):
        return not value.strip()
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # pd.isna on a list-like returns an array; treat containers as present.
        return False


def _labelled(p_row: dict, col: str) -> tuple[object, int]:
    """Return ``(value, 1)`` if ``col`` holds a usable value in ``p_row``, else ``(None, 0)``."""
    value = p_row.get(col)
    return (None, 0) if _is_missing(value) else (value, 1)


def _group_label(p_row: dict) -> str:
    """CV grouping key: family, else genus, else first token of species, else ``"__unk__"``.

    NaN (e.g. produced by the left merge with the catalog) counts as missing.
    A bare ``a or b`` chain treats NaN as truthy, which let float NaN groups
    through (and crashed on ``nan.split()``); mixed str/float group labels can
    also make GroupKFold fail when it sorts them.
    """
    for col in ("family", "genus"):
        value = p_row.get(col)
        if not _is_missing(value):
            return str(value)
    species = p_row.get("species")
    if not _is_missing(species):
        return str(species).split()[0]
    return "__unk__"


def _build_dataset(
    sequences_path: Path,
    phenotypes_path: Path,
    catalog_path: Path,  # kept for symmetry; only read if pheno lacks family/genus
) -> list[dict]:
    """Join marker sequences with phenotype labels + family groups -> list of records.

    Args:
        sequences_path: JSON-lines file; each line needs ``bacdive_id``,
            ``genome_accession`` and ``by_category``. Lines that fail to parse
            are skipped.
        phenotypes_path: Parquet table keyed by ``bacdive_id`` with the label
            columns ``optimal_temperature_c``, ``optimal_ph``,
            ``salt_tolerance_pct`` and ``oxygen_requirement``.
        catalog_path: Parquet table used to left-join any missing
            ``family``/``genus``/``species`` columns onto the phenotypes.

    Returns:
        One dict per sequence record with a phenotype row (the first row is
        used when a ``bacdive_id`` repeats). Missing labels get a 0 placeholder
        and a 0 entry in ``label_mask``; oxygen labels outside
        ``OXY_LABEL_TO_INT`` are treated as missing.
    """
    pheno = pd.read_parquet(phenotypes_path)
    if "family" not in pheno.columns or "genus" not in pheno.columns:
        catalog = pd.read_parquet(catalog_path)
        keep = [c for c in ("family", "genus", "species") if c not in pheno.columns]
        pheno = pheno.merge(
            catalog[["bacdive_id", *keep]].drop_duplicates("bacdive_id"),
            on="bacdive_id",
            how="left",
        )

    # Index once instead of scanning the whole table per sequence record;
    # keep="first" reproduces the previous `sub.iloc[0]` choice.
    first_rows = pheno.drop_duplicates("bacdive_id", keep="first")
    by_id = {rec["bacdive_id"]: rec for rec in first_rows.to_dict("records")}

    rows: list[dict] = []
    with open(sequences_path, encoding="utf-8") as fh:
        for line in fh:
            try:
                r = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate blank / truncated lines
            bacdive_id = int(r["bacdive_id"])
            p_row = by_id.get(bacdive_id)
            if p_row is None:
                continue

            temp_v, temp_m = _labelled(p_row, "optimal_temperature_c")
            ph_v, ph_m = _labelled(p_row, "optimal_ph")
            salt_v, salt_m = _labelled(p_row, "salt_tolerance_pct")
            oxy_raw, oxy_m = _labelled(p_row, "oxygen_requirement")
            if oxy_m and oxy_raw not in OXY_LABEL_TO_INT:
                # Unknown oxygen category: mask it out rather than fail.
                oxy_m = 0
                oxy_raw = None

            rows.append({
                "bacdive_id": bacdive_id,
                "genome_accession": r["genome_accession"],
                "by_category": r["by_category"],
                "group": _group_label(p_row),
                "labels": {
                    "temp": float(temp_v) if temp_m else 0.0,
                    "ph": float(ph_v) if ph_m else 0.0,
                    "salt": float(salt_v) if salt_m else 0.0,
                    "oxy": OXY_LABEL_TO_INT[oxy_raw] if oxy_m else 0,
                },
                "label_mask": {
                    "temp": temp_m, "ph": ph_m, "salt": salt_m, "oxy": oxy_m,
                },
            })
    return rows
| def _group_kfold_split(rows: list[dict], n_splits: int, fold: int): | |
| from sklearn.model_selection import GroupKFold | |
| groups = [r["group"] for r in rows] | |
| indices = np.arange(len(rows)) | |
| gkf = GroupKFold(n_splits=n_splits) | |
| splits = list(gkf.split(indices, groups=groups)) | |
| train_idx, val_idx = splits[fold] | |
| train = [rows[i] for i in train_idx] | |
| val = [rows[i] for i in val_idx] | |
| return train, val | |
| def _collate(batch: list[dict]) -> dict: | |
| genomes = [r["by_category"] for r in batch] | |
| labels = { | |
| k: torch.tensor([r["labels"][k] for r in batch], dtype=torch.float32) | |
| for k in ("temp", "ph", "salt") | |
| } | |
| labels["oxy"] = torch.tensor([r["labels"]["oxy"] for r in batch], dtype=torch.long) | |
| label_mask = { | |
| k: torch.tensor([r["label_mask"][k] for r in batch], dtype=torch.float32) | |
| for k in ("temp", "ph", "salt", "oxy") | |
| } | |
| return {"genomes": genomes, "labels": labels, "label_mask": label_mask} | |
def _iter_batches(rows: list[dict], batch_size: int, shuffle: bool):
    """Lazily yield collated batches of up to ``batch_size`` records.

    The final batch may be smaller. With ``shuffle`` the visit order is
    permuted using the global ``random`` module state, so the order is only
    reproducible if the caller seeds ``random`` beforehand.
    """
    visit_order = list(range(len(rows)))
    if shuffle:
        import random

        random.shuffle(visit_order)
    starts = range(0, len(visit_order), batch_size)
    yield from (
        _collate([rows[idx] for idx in visit_order[s : s + batch_size]])
        for s in starts
    )
def run_validation(model: PhenoLoRAModel, val_rows: list[dict], device: torch.device, batch_size: int) -> dict:
    """Compute validation metrics in inference mode (no grad).

    Returns:
        ``{"temp"|"ph"|"salt": {"mae": float | None, "n": int},
        "oxy": {"f1_macro": float | None, "n": int}}``, where ``n`` counts the
        records whose label mask is set and the metric is ``None`` when
        ``n == 0`` (including an empty ``val_rows``).

    The model's previous train/eval mode is restored before returning.
    """
    regression_keys = ("temp", "ph", "salt")
    all_keys = (*regression_keys, "oxy")
    pred_lists: dict[str, list] = {k: [] for k in all_keys}
    label_lists: dict[str, list] = {k: [] for k in all_keys}
    mask_lists: dict[str, list] = {k: [] for k in all_keys}

    was_training = model.training
    model.eval()
    try:
        # no_grad is required, not just an optimisation: outputs that depend on
        # trainable LoRA/head params require grad, and Tensor.numpy() raises
        # on such tensors.
        with torch.no_grad():
            for batch in _iter_batches(val_rows, batch_size, shuffle=False):
                preds = model(batch["genomes"], device=device)
                for k in regression_keys:
                    pred_lists[k].append(preds[k].float().cpu().numpy())
                pred_lists["oxy"].append(preds["oxy"].argmax(dim=-1).cpu().numpy())
                for k in all_keys:
                    label_lists[k].append(batch["labels"][k].cpu().numpy())
                    mask_lists[k].append(batch["label_mask"][k].cpu().numpy())
    finally:
        model.train(was_training)

    def _stacked(chunks: list) -> np.ndarray:
        # np.concatenate([]) raises, so an empty validation set yields an empty array.
        return np.concatenate(chunks) if chunks else np.empty(0)

    out: dict = {}
    for k in regression_keys:
        masks_arr = _stacked(mask_lists[k]).astype(bool)
        n = int(masks_arr.sum())
        if n == 0:
            out[k] = {"mae": None, "n": 0}
            continue
        errors = np.abs(_stacked(pred_lists[k])[masks_arr] - _stacked(label_lists[k])[masks_arr])
        out[k] = {"mae": float(np.mean(errors)), "n": n}

    masks_oxy = _stacked(mask_lists["oxy"]).astype(bool)
    n_oxy = int(masks_oxy.sum())
    if n_oxy == 0:
        out["oxy"] = {"f1_macro": None, "n": 0}
    else:
        y_true = _stacked(label_lists["oxy"])[masks_oxy]
        y_pred = _stacked(pred_lists["oxy"])[masks_oxy]
        f1 = float(f1_score(y_true, y_pred, average="macro"))
        out["oxy"] = {"f1_macro": f1, "n": n_oxy}
    return out
def _split_param_groups(model: nn.Module) -> tuple[list[nn.Parameter], list[nn.Parameter]]:
    """Split trainable parameters into ``(non_head, head)`` lists by the ``heads.`` name prefix."""
    lora_params: list[nn.Parameter] = []
    head_params: list[nn.Parameter] = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (head_params if name.startswith("heads.") else lora_params).append(p)
    return lora_params, head_params


def _warmup_cosine(warmup_steps: int, total_steps: int):
    """Build a LambdaLR multiplier: linear warm-up, then cosine decay to 0.

    Progress is clamped at ``total_steps`` so the multiplier stays at 0 rather
    than climbing back up the cosine if extra steps are ever taken.
    """

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

    return lr_lambda


def _selection_score(val_metrics: dict) -> float:
    """Lower-is-better checkpoint score: sum of available MAEs minus oxygen macro-F1."""
    mae_total = sum(
        val_metrics[k]["mae"]
        for k in ("temp", "ph", "salt")
        if val_metrics[k]["mae"] is not None
    )
    return mae_total - (val_metrics["oxy"]["f1_macro"] or 0.0)


def train_lora(
    *,
    model_cfg: LoraModelConfig,
    train_cfg: TrainConfig,
    sequences_path: Path,
    phenotypes_path: Path,
    catalog_path: Path,
    device: torch.device | None = None,
) -> dict:
    """Fine-tune a PhenoLoRAModel on one 5-split group-K-fold partition.

    Each epoch trains with gradient accumulation, runs validation, and
    overwrites ``<save_dir>/fold{fold}_best.pt`` when the selection score
    improves. Only state-dict entries whose key contains "lora"
    (case-insensitive) or starts with "heads." are checkpointed.

    Args:
        device: Target device (``torch.device`` or a string such as
            ``"cuda:0"``); defaults to CUDA when available, else CPU.

    Returns:
        The dict also written to ``<save_dir>/fold{fold}_results.json``, with
        keys ``model_cfg``, ``train_cfg``, ``history`` (per-epoch records) and
        ``best``.
    """
    device = (
        torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device is None
        else torch.device(device)
    )
    print(f"[lora] device = {device}", flush=True)

    rows = _build_dataset(sequences_path, phenotypes_path, catalog_path)
    print(f"[lora] loaded {len(rows):,} records with sequences + labels", flush=True)
    train_rows, val_rows = _group_kfold_split(rows, n_splits=5, fold=train_cfg.fold)
    print(f"[lora] fold {train_cfg.fold}: {len(train_rows):,} train / {len(val_rows):,} val",
          flush=True)

    model = PhenoLoRAModel(model_cfg).to(device)
    trainable, total = model.trainable_param_count()
    print(f"[lora] trainable params: {trainable:,} / total: {total:,} "
          f"({100 * trainable / max(total, 1):.2f}%)", flush=True)

    lora_params, head_params = _split_param_groups(model)
    optimizer = optim.AdamW(
        [
            {"params": lora_params, "lr": train_cfg.lora_lr},
            {"params": head_params, "lr": train_cfg.head_lr},
        ],
        weight_decay=train_cfg.weight_decay,
    )

    # Use one sanitised accumulation factor everywhere (grad_accum=0 used to
    # divide by zero in the step check).
    accum = max(train_cfg.grad_accum, 1)
    n_train_batches = math.ceil(len(train_rows) / train_cfg.batch_size)
    # The trailing partial window of each epoch is stepped too, so the schedule
    # length must count ceil(batches / accum) steps per epoch.
    steps_per_epoch = math.ceil(n_train_batches / accum)
    total_steps = max(1, steps_per_epoch * train_cfg.epochs)
    warmup_steps = max(1, int(total_steps * train_cfg.warmup_frac))
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=_warmup_cosine(warmup_steps, total_steps)
    )

    save_dir = Path(train_cfg.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    autocast_dtype = torch.bfloat16 if train_cfg.bf16 else torch.float32
    target_weights = {
        "temp": train_cfg.temp_weight,
        "ph": train_cfg.ph_weight,
        "salt": train_cfg.salt_weight,
        "oxy": train_cfg.oxy_weight,
    }
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    history: list[dict] = []
    best = {"epoch": -1, "val": None, "score": float("inf")}
    global_step = 0
    for epoch in range(train_cfg.epochs):
        model.train()
        t0 = time.time()
        running_loss = 0.0
        running_n = 0
        for batch_idx, batch in enumerate(_iter_batches(train_rows, train_cfg.batch_size, shuffle=True)):
            # Scale by the size of the window this batch belongs to so a short
            # final window contributes a properly averaged gradient.
            window_start = batch_idx - batch_idx % accum
            window_size = min(accum, n_train_batches - window_start)
            with torch.autocast(device_type=device.type, dtype=autocast_dtype, enabled=(device.type == "cuda")):
                preds = model(batch["genomes"], device=device)
                loss, _per_target = masked_multitask_loss(
                    preds,
                    {k: v.to(device) for k, v in batch["labels"].items()},
                    {k: v.to(device) for k, v in batch["label_mask"].items()},
                    target_weights=target_weights,
                    oxy_class_weights=train_cfg.oxy_class_weights,
                )
            (loss / window_size).backward()
            running_loss += float(loss.detach().cpu())
            running_n += 1

            # Step at the end of every full window AND at the last batch of the
            # epoch, so leftover gradients never leak into the next epoch.
            end_of_window = (batch_idx + 1) % accum == 0 or batch_idx + 1 == n_train_batches
            if not end_of_window:
                continue
            nn.utils.clip_grad_norm_(trainable_params, max_norm=train_cfg.grad_clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1
            if global_step % 50 == 0:
                print(f"  ep {epoch+1} step {global_step}: "
                      f"loss={running_loss/max(running_n,1):.4f} "
                      f"lr_lora={scheduler.get_last_lr()[0]:.2e}",
                      flush=True)

        val_metrics = run_validation(model, val_rows, device, train_cfg.batch_size)
        elapsed = time.time() - t0
        record = {
            "epoch": epoch + 1,
            "train_loss": running_loss / max(running_n, 1),
            "val": val_metrics,
            "elapsed_s": elapsed,
        }
        history.append(record)
        print(f"[lora] epoch {epoch+1} done in {elapsed:.0f}s val={val_metrics}", flush=True)

        score = _selection_score(val_metrics)
        if score < best["score"]:
            best = {"epoch": epoch + 1, "val": val_metrics, "score": score}
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_cfg": asdict(model_cfg),
                    "train_cfg": asdict(train_cfg),
                    "state_dict": {
                        k: v
                        for k, v in model.state_dict().items()
                        if "lora" in k.lower() or k.startswith("heads.")
                    },
                },
                save_dir / f"fold{train_cfg.fold}_best.pt",
            )

    results = {
        "model_cfg": asdict(model_cfg),
        "train_cfg": asdict(train_cfg),
        "history": history,
        "best": best,
    }
    out_json = save_dir / f"fold{train_cfg.fold}_results.json"
    with open(out_json, "w") as fh:
        json.dump(results, fh, indent=2)
    print(f"[lora] wrote {out_json}", flush=True)
    return results