CodonTranslator / train.py
#!/usr/bin/env python
"""
Minimal, honest training script for CodonTranslator on CSV or Parquet data.
- Species conditioning: REQUIRED (precomputed embeddings)
- Protein conditioning (ESM-C): ALWAYS ENABLED (this script exposes no flag to disable it).
- Global capacity is controlled by --max_length (prefix + start + codon).
"""
import os
import math
import argparse
import logging
import torch
from src import CodonTranslatorModel, CodonTokenizer, Trainer, TrainingArguments
from src.dataset import create_precomputed_dataloaders, SpeciesEmbeddingStore
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger("codontranslator.train")
def _describe_sdp_kernels() -> None:
# Log the enabled SDPA backends (Flash/MemEff/Math) without raising on older PyTorch
    flash = mem_eff = mathk = None
if hasattr(torch, 'backends') and hasattr(torch.backends, 'cuda'):
tbc = torch.backends.cuda
if hasattr(tbc, 'flash_sdp_enabled'):
flash = tbc.flash_sdp_enabled()
if hasattr(tbc, 'mem_efficient_sdp_enabled'):
mem_eff = tbc.mem_efficient_sdp_enabled()
if hasattr(tbc, 'math_sdp_enabled'):
mathk = tbc.math_sdp_enabled()
logger.info(f"SDP kernels: flash={flash} mem_efficient={mem_eff} math={mathk}")
def _print_model_size(model: torch.nn.Module, bf16: bool, fp16: bool) -> None:
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
w_bytes = 2 if (bf16 or fp16) else 4
opt_bytes = 8 # Adam moments in FP32
weights_gb = total * w_bytes / (1024**3)
opt_gb = trainable * opt_bytes / (1024**3)
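    # Note: this estimate covers weights and Adam moment buffers only (gradients and activations are not included).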
logger.info(
f"Model params: total={total:,} trainable={trainable:,} (~{weights_gb:.2f} GB weights, ~{opt_gb:.2f} GB optimizer)"
)
def _speed_toggles():
if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch, "set_float32_matmul_precision"):
torch.set_float32_matmul_precision("high")
if hasattr(torch.backends, "cudnn") and hasattr(torch.backends.cudnn, "benchmark"):
torch.backends.cudnn.benchmark = True
def parse_args():
p = argparse.ArgumentParser(description="Train CodonTranslator on CSV or Parquet data")
# Data (CSV path or Parquet glob/dir)
p.add_argument("--train_data", type=str, default="random_sample_1000.csv",
help="Training data: CSV file or Parquet glob/dir (e.g., ./data/train_shards/*.parquet)")
p.add_argument("--val_data", type=str, default=None,
help="Validation data: CSV file or Parquet glob/dir")
p.add_argument("--embeddings_dir", type=str, default="embeddings",
help="Dir with species embeddings (species_vocab.json, *.bin/memmap)")
# Model / capacity
p.add_argument("--hidden", type=int, default=750, help="Model hidden size")
p.add_argument("--layers", type=int, default=20, help="Number of transformer layers")
p.add_argument("--heads", type=int, default=15, help="Number of attention heads")
p.add_argument("--attn", type=str, choices=["mha", "gqa"], default="gqa", help="Attention implementation: 'mha' or 'gqa'")
p.add_argument("--num_kv_groups", type=int, default=5, help="GQA: number of KV groups (0 = default/no grouping)")
p.add_argument("--mlp_ratio", type=float, default=3.2, help="FFN expansion ratio (mlp hidden = ratio * hidden)")
p.add_argument("--max_length", type=int, default=2048,
help="Global max length (prefix + start + codon)")
p.add_argument("--max_species_prefix", type=int, default=0,
help="Cap species prefix tokens (0 = uncapped)")
p.add_argument("--max_protein_prefix", type=int, default=1024,
help="Cap protein prefix tokens (0 = uncapped)")
# Protein conditioning: always enabled (ESM-C)
# Training
p.add_argument("--output_dir", type=str, default="checkpoints", help="Where to save checkpoints")
p.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
p.add_argument("--batch_size", type=int, default=20, help="Per-device train batch size")
p.add_argument("--eval_batch_size", type=int, default=32, help="Per-device eval batch size")
p.add_argument("--workers", type=int, default=4, help="DataLoader workers")
p.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
p.add_argument("--train_shuffle_buffer", type=int, default=0,
help="Streaming shuffle buffer for training (set 0 when data is pre-shuffled)")
p.add_argument("--val_shuffle_buffer", type=int, default=0,
help="Streaming shuffle buffer for validation (0 disables)")
p.add_argument("--csv_chunksize", type=int, default=200_000,
help="Pandas read_csv chunksize for CSV inputs")
# Optim / schedule
p.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
p.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio for LR schedule (0.0-1.0)")
p.add_argument(
"--lr_scheduler",
type=str,
choices=["linear", "cosine", "constant"],
default="linear",
help="LR schedule applied after warmup; 'linear' decays to zero by the end of training",
)
p.add_argument("--weight_decay", type=float, default=1e-3, help="Weight decay")
p.add_argument("--adam_beta1", type=float, default=0.9,
help="Adam beta1 (momentum) coefficient")
p.add_argument("--adam_beta2", type=float, default=0.95,
help="Adam beta2 (squared-gradient) coefficient")
p.add_argument("--logging_steps", type=int, default=20, help="Logging interval (steps)")
p.add_argument("--save_steps", type=int, default=10, help="Save every N steps (0 disables step-saving)")
p.add_argument("--save_total_limit", type=int, default=10, help="Keep at most N recent checkpoints")
p.add_argument("--ckpt_recent_window_steps", type=int, default=0,
help="If >0, keep finer-grained checkpoints within this many recent steps")
p.add_argument("--ckpt_recent_interval", type=int, default=0,
help="Retention interval inside the recent checkpoint window (0 disables custom retention)")
p.add_argument("--ckpt_archive_interval", type=int, default=0,
help="Retention interval for checkpoints older than the recent window (0 prunes them)")
p.add_argument("--max_steps", type=int, default=-1,
help="Total training steps. REQUIRED for streaming (IterableDataset)")
p.add_argument("--steps_per_epoch", type=int, default=0,
help="For streaming datasets: shape LR schedule as epochs*steps_per_epoch when max_steps<0")
p.add_argument("--max_grad_norm", type=float, default=1.0,
help="Clip gradients to this global L2 norm; set <=0 to disable")
p.add_argument("--override_lr_on_resume", action="store_true",
help="Do not restore LR/optimizer state on resume (keep current lr)")
# Resume
p.add_argument("--resume_from", type=str, default=None,
help="Path to checkpoint dir to resume from; pass 'auto' to pick latest in output_dir")
# Evaluation scheduling
p.add_argument("--eval_interval", type=int, default=0,
help="Run evaluation every N optimizer steps on --val_data (0 disables)")
p.add_argument("--eval_steps", type=int, default=5000,
help="For streaming eval datasets: limit to this many batches (0 = full eval)")
# Hardware / precision
p.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
p.add_argument("--bf16", action="store_true", help="bfloat16 mixed precision")
p.add_argument("--fp16", action="store_true", help="float16 mixed precision")
p.add_argument("--fsdp", action="store_true", help="Enable FSDP full sharding")
p.add_argument("--grad_ckpt", action="store_true", help="Enable gradient checkpointing")
return p.parse_args()
def main():
args = parse_args()
_speed_toggles()
if args.device == "cuda" and not torch.cuda.is_available():
logger.warning("CUDA not available; switching to CPU")
args.device = "cpu"
# Tokenizer
tok = CodonTokenizer()
# Ensure output dir exists and persist vocab.json (used by sampler)
os.makedirs(os.path.abspath(args.output_dir), exist_ok=True)
tok.save_vocabulary(args.output_dir)
    # Build data loaders first — the model needs the species embedding dimension (Ds) from the store
train_loader, val_loader, species_store = create_precomputed_dataloaders(
train_path=args.train_data,
val_path=args.val_data,
embeddings_dir=args.embeddings_dir,
tokenizer=tok,
batch_size=args.batch_size,
num_workers=args.workers,
species_pooling="sequence", # prefer variable-length token sequence if available
csv_chunksize=int(args.csv_chunksize),
train_shuffle_buffer=int(args.train_shuffle_buffer),
val_shuffle_buffer=int(args.val_shuffle_buffer),
)
# Estimate steps_per_epoch for streaming schedule shaping if not provided
steps_per_epoch = int(getattr(args, "steps_per_epoch", 0) or 0)
total_rows = 0
paths: list[str] = []
if steps_per_epoch <= 0 and int(args.max_steps) < 0:
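        # _expand_paths: accept a single file path, a directory of Parquet shards, a glob pattern,
        # or a list of any of these, and return a de-duplicated list of file paths.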
def _expand_paths(maybe: str | list[str]) -> list[str]:
import glob as _glob
from pathlib import Path as _Path
paths: list[str] = []
if isinstance(maybe, str):
p = _Path(maybe)
if p.is_dir():
paths.extend(sorted(str(x) for x in p.rglob("*.parquet")))
else:
paths = sorted(_glob.glob(str(p)))
else:
for it in maybe:
paths.extend(_expand_paths(it))
# de-dup
            seen = set()
            out = []
for x in paths:
if x not in seen:
                    out.append(x)
                    seen.add(x)
return out
paths = _expand_paths(args.train_data)
if paths:
try:
import pyarrow.parquet as pq
for fp in paths:
if fp.lower().endswith((".parquet", ".parq")):
pf = pq.ParquetFile(fp)
md = pf.metadata
if md is not None:
total_rows += int(md.num_rows)
except Exception:
                # Fallback: leave the row count at 0 (steps_per_epoch stays unset) if pyarrow is unavailable or metadata cannot be read
total_rows = 0
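        # steps_per_epoch ~ ceil(total_rows / (per-device batch size * world size * grad accum steps))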
if total_rows > 0:
world = int(os.environ.get("WORLD_SIZE", "1"))
ga = max(1, int(getattr(args, "grad_accum", 1)))
denom = max(1, int(args.batch_size) * max(1, world) * ga)
steps_per_epoch = max(1, math.ceil(total_rows / denom))
logger.info(f"Estimated steps_per_epoch={steps_per_epoch} from {len(paths)} parquet files, total_rows={total_rows}")
world = int(os.environ.get("WORLD_SIZE", "1"))
grad_accum = max(1, int(getattr(args, "grad_accum", 1)))
effective_global_batch = int(args.batch_size) * max(1, world) * grad_accum
logger.info(
"Batch config: per_device_train_batch=%s per_device_eval_batch=%s world_size=%s grad_accum=%s effective_global_batch=%s",
args.batch_size,
args.eval_batch_size,
world,
grad_accum,
effective_global_batch,
)
# Resolve per-process CUDA device for ESM (avoid defaulting to cuda:0 on all ranks)
esm_dev = "cpu"
if args.device == "cuda" and torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        esm_dev = f"cuda:{local_rank}"
    # Model — species conditioning is always on; protein conditioning (ESM-C) is always on as well
model = CodonTranslatorModel(
vocab_size=tok.vocab_size,
num_special_tokens=tok.num_special_tokens,
special_ids=tok.special_ids,
hidden_size=args.hidden,
num_layers=args.layers,
num_heads=args.heads,
mlp_ratio=float(args.mlp_ratio),
max_position_embeddings=args.max_length,
prepend_species=True,
prepend_protein=True,
esm_model_name="esmc_300m",
esm_device=esm_dev,
max_protein_prefix=int(args.max_protein_prefix),
max_species_prefix=int(args.max_species_prefix),
dropout=0.1,
species_embedding_dim=int(species_store.Ds()),
attn_impl=str(args.attn),
num_kv_groups=int(args.num_kv_groups),
)
# Report model size and SDPA (Flash) kernel configuration
_print_model_size(model, bf16=bool(args.bf16), fp16=bool(args.fp16))
_describe_sdp_kernels()
# Trainer args
targs = TrainingArguments(
output_dir=args.output_dir,
save_steps=args.save_steps,
save_total_limit=int(args.save_total_limit),
ckpt_recent_window_steps=int(args.ckpt_recent_window_steps),
ckpt_recent_interval=int(args.ckpt_recent_interval),
ckpt_archive_interval=int(args.ckpt_archive_interval),
num_train_epochs=args.epochs,
max_steps=int(args.max_steps),
gradient_accumulation_steps=int(args.grad_accum),
warmup_ratio=float(args.warmup_ratio),
lr_scheduler_type=str(args.lr_scheduler),
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.eval_batch_size,
dataloader_num_workers=args.workers,
learning_rate=args.lr,
weight_decay=args.weight_decay,
adam_beta1=float(args.adam_beta1),
adam_beta2=float(args.adam_beta2),
max_grad_norm=float(args.max_grad_norm),
logging_steps=args.logging_steps,
override_lr_on_resume=bool(args.override_lr_on_resume),
data_cursor_path=os.path.join(os.path.abspath(args.output_dir), "data_cursor.json"),
fp16=bool(args.fp16),
bf16=bool(args.bf16),
fsdp=("full_shard" if args.fsdp else None),
gradient_checkpointing=bool(args.grad_ckpt),
max_length=int(args.max_length),
esm_model_name="esmc_300m",
esm_device=esm_dev,
esm_dtype=("bf16" if args.bf16 else ("fp16" if args.fp16 else "fp32")),
# sampling eval
eval_interval=int(args.eval_interval),
eval_steps=int(args.eval_steps),
steps_per_epoch=int(steps_per_epoch),
)
# Resolve auto-resume if requested
resume_path = None
if args.resume_from:
if args.resume_from == "auto":
root = os.path.abspath(args.output_dir)
if os.path.isdir(root):
try:
subdirs = []
for d in os.listdir(root):
path = os.path.join(root, d)
if not os.path.isdir(path):
continue
if not (
d == "final_model" or
d.startswith("checkpoint-")
):
continue
if not (
os.path.exists(os.path.join(path, "model.safetensors")) or
os.path.exists(os.path.join(path, "pytorch_model.bin"))
):
continue
subdirs.append(path)
subdirs.sort(key=lambda d: os.path.getmtime(d), reverse=True)
resume_path = subdirs[0] if subdirs else None
except Exception:
resume_path = None
else:
resume_path = args.resume_from
trainer = Trainer(
model=model,
args=targs,
tokenizer=tok,
species_store=species_store,
resume_from_checkpoint=resume_path,
)
trainer.attach_dataloaders(train_loader, val_loader)
logger.info("Starting training...")
trainer.train()
logger.info("Training finished.")
if __name__ == "__main__":
main()