File size: 4,316 Bytes
0ed74db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Run LoRA fine-tuning on a prepared Lambda Labs instance or local GPU.

This is the reproducible replacement for the ad hoc one-off script used for the
first Lambda A100 run. It assumes the repo and required data files are present
on the machine where the command is executed.

Example:
    python -u scripts/lambda_train_lora.py --fold 0 --epochs 1
    python -u scripts/lambda_train_lora.py --fold 0 --epochs 1 --target-preset oxygen
"""
from __future__ import annotations

import argparse
import math
from pathlib import Path

import torch

from microbe_model.train.lora_model import LoraModelConfig
from microbe_model.train.lora_trainer import OXY_LABEL_TO_INT, TrainConfig, train_lora


def _parse_oxy_class_weights(raw: str | None) -> tuple[float, ...] | None:
    if raw is None:
        return None
    weights = tuple(float(part) for part in raw.split(","))
    if len(weights) != len(OXY_LABEL_TO_INT):
        classes = ", ".join(OXY_LABEL_TO_INT)
        raise argparse.ArgumentTypeError(
            f"--oxy-class-weights must provide {len(OXY_LABEL_TO_INT)} values for: {classes}"
        )
    return weights


def _positive_int(raw: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    value = int(raw)  # argparse reports a ValueError here as an invalid value
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {raw!r}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            original zero-argument behavior.

    Returns:
        The parsed namespace. Attribute names mirror the flag names with
        dashes replaced by underscores (e.g. ``args.oxy_class_weights``).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Preserve the module docstring's line breaks so the example commands
        # are not reflowed into a single paragraph in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--fold", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--esm-model", default="facebook/esm2_t12_35M_UR50D")
    parser.add_argument("--lora-r", type=_positive_int, default=8)
    parser.add_argument("--lora-lr", type=float, default=1e-4)
    parser.add_argument("--head-lr", type=float, default=1e-3)
    parser.add_argument("--batch-size", type=_positive_int, default=2)
    parser.add_argument("--grad-accum", type=_positive_int, default=8)
    parser.add_argument("--save-dir", default="artifacts/lora")
    parser.add_argument("--sequences", default="data/marker_sequences.jsonl")
    parser.add_argument("--phenotypes", default="data/bacdive_phenotypes.parquet")
    parser.add_argument("--catalog", default="data/strain_catalog.parquet")
    parser.add_argument(
        "--target-preset",
        choices=("all", "oxygen"),
        default="all",
        help="Use all task losses, or train only the oxygen loss while still reporting all metrics.",
    )
    # Derive the documented class order from the label mapping (sorted by label
    # id) so the help text cannot drift from OXY_LABEL_TO_INT.
    oxy_classes = ",".join(sorted(OXY_LABEL_TO_INT, key=OXY_LABEL_TO_INT.get))
    parser.add_argument(
        "--oxy-class-weights",
        type=_parse_oxy_class_weights,
        default=None,
        help=(
            f"Comma-separated oxygen class weights in order {oxy_classes}. "
            "Example: 1,1.5,1,1"
        ),
    )
    parser.add_argument(
        "--no-bf16",
        action="store_true",
        help="Disable bf16 (it is enabled unless this flag is given).",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Parse flags, report the run setup, and launch one LoRA training run.

    CUDA is used when ``torch.cuda.is_available()`` is true; otherwise the run
    falls back to CPU. With ``--target-preset oxygen`` the temp/ph/salt loss
    weights are 0.0 and only the oxygen weight stays 1.0. Otherwise all four
    weights are 1.0.
    """
    args = parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    device_name = torch.cuda.get_device_name(0) if use_cuda else "cpu"

    # Per-task loss multipliers. The "oxygen" preset silences every other task.
    aux_weight = 0.0 if args.target_preset == "oxygen" else 1.0
    weights = {"temp": aux_weight, "ph": aux_weight, "salt": aux_weight, "oxy": 1.0}

    startup_lines = (
        f"[lambda-lora] fold={args.fold} epochs={args.epochs} "
        f"model={args.esm_model} preset={args.target_preset}",
        f"[lambda-lora] device={device_name}",
        f"[lambda-lora] target_weights={weights}",
        f"[lambda-lora] oxy_class_weights={args.oxy_class_weights}",
    )
    for line in startup_lines:
        print(line, flush=True)

    model_cfg = LoraModelConfig(esm_model_name=args.esm_model, lora_r=args.lora_r)
    train_cfg = TrainConfig(
        fold=args.fold,
        epochs=args.epochs,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lora_lr=args.lora_lr,
        head_lr=args.head_lr,
        save_dir=args.save_dir,
        bf16=not args.no_bf16,
        temp_weight=weights["temp"],
        ph_weight=weights["ph"],
        salt_weight=weights["salt"],
        oxy_weight=weights["oxy"],
        oxy_class_weights=args.oxy_class_weights,
    )
    results = train_lora(
        model_cfg=model_cfg,
        train_cfg=train_cfg,
        sequences_path=Path(args.sequences),
        phenotypes_path=Path(args.phenotypes),
        catalog_path=Path(args.catalog),
        device=device,
    )
    print(f"[lambda-lora] done best={results.get('best')}", flush=True)


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()