# microbe-model / scripts / lambda_train_lora.py
# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""Run LoRA fine-tuning on a prepared Lambda Labs instance or local GPU.
This is the reproducible replacement for the ad hoc one-off script used for the
first Lambda A100 run. It assumes the repo and required data files are present
on the machine where the command is executed.
Examples:
    python -u scripts/lambda_train_lora.py --fold 0 --epochs 1
    python -u scripts/lambda_train_lora.py --fold 0 --epochs 1 --target-preset oxygen
"""
from __future__ import annotations

import argparse
import math
from pathlib import Path

import torch

from microbe_model.train.lora_model import LoraModelConfig
from microbe_model.train.lora_trainer import OXY_LABEL_TO_INT, TrainConfig, train_lora
def _parse_oxy_class_weights(raw: str | None) -> tuple[float, ...] | None:
    """Parse the ``--oxy-class-weights`` value into a tuple of floats.

    This is used as an argparse ``type=`` callable. Every validation failure is
    raised as ``argparse.ArgumentTypeError``, so argparse reports the message
    verbatim. Otherwise it would show a generic "invalid ... value" error.

    Args:
        raw: Comma-separated numbers, one per oxygen class, or ``None``.

    Returns:
        ``None`` when *raw* is ``None``. Otherwise, the parsed weights in the
        order they were given.

    Raises:
        argparse.ArgumentTypeError: If an entry is not a number, is negative
            or non-finite, the number of entries differs from
            ``len(OXY_LABEL_TO_INT)``, or every entry is zero.
    """
    if raw is None:
        return None

    weights: list[float] = []
    for part in raw.split(","):
        try:
            weight = float(part)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"--oxy-class-weights entries must be numbers, got {part.strip()!r} in {raw!r}"
            ) from None
        # NaN, inf, or negative class weights are never a meaningful loss weighting.
        if not math.isfinite(weight) or weight < 0:
            raise argparse.ArgumentTypeError(
                f"--oxy-class-weights entries must be finite and >= 0, got {part.strip()!r}"
            )
        weights.append(weight)

    if len(weights) != len(OXY_LABEL_TO_INT):
        classes = ", ".join(OXY_LABEL_TO_INT)
        raise argparse.ArgumentTypeError(
            f"--oxy-class-weights must provide {len(OXY_LABEL_TO_INT)} values for: {classes}"
        )
    # NOTE(review): these presumably become per-class loss weights inside train_lora.
    # An all-zero vector would either silence the oxygen loss or produce NaN under
    # weighted-mean reduction, so reject it here.
    if not any(weights):
        raise argparse.ArgumentTypeError(
            "--oxy-class-weights needs at least one positive value"
        )
    return tuple(weights)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser for this script and parse arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        The parsed namespace. Attribute names follow the flags with dashes
        replaced by underscores (e.g. ``--grad-accum`` -> ``args.grad_accum``).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Preserve the module docstring's line breaks so the example commands
        # are not reflowed into a single paragraph in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--fold", type=int, default=0, help="Data fold index to train on (default: %(default)s)."
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of training epochs (default: %(default)s)."
    )
    parser.add_argument(
        "--esm-model",
        default="facebook/esm2_t12_35M_UR50D",
        help="Base ESM model name for LoraModelConfig (default: %(default)s).",
    )
    parser.add_argument(
        "--lora-r", type=int, default=8, help="LoRA rank (default: %(default)s)."
    )
    parser.add_argument(
        "--lora-lr",
        type=float,
        default=1e-4,
        help="Learning rate for the LoRA parameters (default: %(default)s).",
    )
    parser.add_argument(
        "--head-lr",
        type=float,
        default=1e-3,
        help="Learning rate for the prediction heads (default: %(default)s).",
    )
    parser.add_argument(
        "--batch-size", type=int, default=2, help="Batch size (default: %(default)s)."
    )
    parser.add_argument(
        "--grad-accum",
        type=int,
        default=8,
        help="Gradient accumulation steps (default: %(default)s).",
    )
    parser.add_argument(
        "--save-dir",
        default="artifacts/lora",
        help="Output directory passed to TrainConfig (default: %(default)s).",
    )
    parser.add_argument(
        "--sequences",
        default="data/marker_sequences.jsonl",
        help="Marker sequence JSONL file (default: %(default)s).",
    )
    parser.add_argument(
        "--phenotypes",
        default="data/bacdive_phenotypes.parquet",
        help="Phenotype table (default: %(default)s).",
    )
    parser.add_argument(
        "--catalog",
        default="data/strain_catalog.parquet",
        help="Strain catalog table (default: %(default)s).",
    )
    parser.add_argument(
        "--target-preset",
        choices=("all", "oxygen"),
        default="all",
        help="Use all task losses, or train only the oxygen loss while still reporting all metrics.",
    )
    parser.add_argument(
        "--oxy-class-weights",
        type=_parse_oxy_class_weights,
        default=None,
        help=(
            "Comma-separated oxygen class weights in order "
            "aerobe,anaerobe,facultative_anaerobe,microaerobe. Example: 1,1.5,1,1"
        ),
    )
    parser.add_argument(
        "--no-bf16", action="store_true", help="Disable bf16 (sets TrainConfig.bf16=False)."
    )
    return parser.parse_args(argv)
# Per-task loss weights for each ``--target-preset`` choice. "oxygen" zeroes the
# temperature, pH and salt losses so that only the oxygen loss is trained.
TARGET_PRESET_WEIGHTS: dict[str, dict[str, float]] = {
    "all": {"temp": 1.0, "ph": 1.0, "salt": 1.0, "oxy": 1.0},
    "oxygen": {"temp": 0.0, "ph": 0.0, "salt": 0.0, "oxy": 1.0},
}


def main() -> None:
    """Parse CLI arguments, check the input files, and run ``train_lora``.

    Uses CUDA when available and falls back to CPU otherwise. Logs the run
    configuration and the best result reported by ``train_lora``.

    Raises:
        SystemExit: If any of the ``--sequences``, ``--phenotypes`` or
            ``--catalog`` paths does not exist. This check runs before
            training starts, so the run fails fast.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_name = torch.cuda.get_device_name(0) if device.type == "cuda" else "cpu"
    # Copy so that nothing downstream can mutate the shared preset table.
    weights = dict(TARGET_PRESET_WEIGHTS[args.target_preset])

    print(
        f"[lambda-lora] fold={args.fold} epochs={args.epochs} "
        f"model={args.esm_model} preset={args.target_preset}",
        flush=True,
    )
    print(f"[lambda-lora] device={device_name}", flush=True)
    print(f"[lambda-lora] target_weights={weights}", flush=True)
    print(f"[lambda-lora] oxy_class_weights={args.oxy_class_weights}", flush=True)

    data_paths = {
        "--sequences": Path(args.sequences),
        "--phenotypes": Path(args.phenotypes),
        "--catalog": Path(args.catalog),
    }
    missing = [f"{flag}={path}" for flag, path in data_paths.items() if not path.exists()]
    if missing:
        # Report every missing input at once rather than failing on the
        # first one later inside train_lora.
        raise SystemExit("[lambda-lora] missing input data: " + ", ".join(missing))

    results = train_lora(
        model_cfg=LoraModelConfig(esm_model_name=args.esm_model, lora_r=args.lora_r),
        train_cfg=TrainConfig(
            fold=args.fold,
            epochs=args.epochs,
            batch_size=args.batch_size,
            grad_accum=args.grad_accum,
            lora_lr=args.lora_lr,
            head_lr=args.head_lr,
            save_dir=args.save_dir,
            bf16=not args.no_bf16,
            temp_weight=weights["temp"],
            ph_weight=weights["ph"],
            salt_weight=weights["salt"],
            oxy_weight=weights["oxy"],
            oxy_class_weights=args.oxy_class_weights,
        ),
        sequences_path=data_paths["--sequences"],
        phenotypes_path=data_paths["--phenotypes"],
        catalog_path=data_paths["--catalog"],
        device=device,
    )
    print(f"[lambda-lora] done best={results.get('best')}", flush=True)
if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with
    # status 0 once training completes.
    raise SystemExit(main())