#!/usr/bin/env python """ Minimal, honest training script for CodonTranslator on CSV or Parquet data. - Species conditioning: REQUIRED (precomputed embeddings) - Protein conditioning (ESM-C): ENABLED BY DEFAULT. Disable with --no_protein. - Global capacity is controlled by --max_length (prefix + start + codon). """ import os import math import argparse import logging import torch from src import CodonTranslatorModel, CodonTokenizer, Trainer, TrainingArguments from src.dataset import create_precomputed_dataloaders, SpeciesEmbeddingStore logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger = logging.getLogger("codontranslator.train") def _describe_sdp_kernels() -> None: # Log the enabled SDPA backends (Flash/MemEff/Math) without raising on older PyTorch flash = None; mem_eff = None; mathk = None if hasattr(torch, 'backends') and hasattr(torch.backends, 'cuda'): tbc = torch.backends.cuda if hasattr(tbc, 'flash_sdp_enabled'): flash = tbc.flash_sdp_enabled() if hasattr(tbc, 'mem_efficient_sdp_enabled'): mem_eff = tbc.mem_efficient_sdp_enabled() if hasattr(tbc, 'math_sdp_enabled'): mathk = tbc.math_sdp_enabled() logger.info(f"SDP kernels: flash={flash} mem_efficient={mem_eff} math={mathk}") def _print_model_size(model: torch.nn.Module, bf16: bool, fp16: bool) -> None: total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) w_bytes = 2 if (bf16 or fp16) else 4 opt_bytes = 8 # Adam moments in FP32 weights_gb = total * w_bytes / (1024**3) opt_gb = trainable * opt_bytes / (1024**3) logger.info( f"Model params: total={total:,} trainable={trainable:,} (~{weights_gb:.2f} GB weights, ~{opt_gb:.2f} GB optimizer)" ) def _speed_toggles(): if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"): torch.backends.cuda.matmul.allow_tf32 = True if hasattr(torch, "set_float32_matmul_precision"): torch.set_float32_matmul_precision("high") if hasattr(torch.backends, "cudnn") and hasattr(torch.backends.cudnn, "benchmark"): torch.backends.cudnn.benchmark = True def parse_args(): p = argparse.ArgumentParser(description="Train CodonTranslator on CSV or Parquet data") # Data (CSV path or Parquet glob/dir) p.add_argument("--train_data", type=str, default="random_sample_1000.csv", help="Training data: CSV file or Parquet glob/dir (e.g., ./data/train_shards/*.parquet)") p.add_argument("--val_data", type=str, default=None, help="Validation data: CSV file or Parquet glob/dir") p.add_argument("--embeddings_dir", type=str, default="embeddings", help="Dir with species embeddings (species_vocab.json, *.bin/memmap)") # Model / capacity p.add_argument("--hidden", type=int, default=750, help="Model hidden size") p.add_argument("--layers", type=int, default=20, help="Number of transformer layers") p.add_argument("--heads", type=int, default=15, help="Number of attention heads") p.add_argument("--attn", type=str, choices=["mha", "gqa"], default="gqa", help="Attention implementation: 'mha' or 'gqa'") p.add_argument("--num_kv_groups", type=int, default=5, help="GQA: number of KV groups (0 = default/no grouping)") p.add_argument("--mlp_ratio", type=float, default=3.2, help="FFN expansion ratio (mlp hidden = ratio * hidden)") p.add_argument("--max_length", type=int, default=2048, help="Global max length (prefix + start + codon)") p.add_argument("--max_species_prefix", type=int, default=0, help="Cap species prefix tokens (0 = uncapped)") p.add_argument("--max_protein_prefix", type=int, default=1024, help="Cap protein prefix tokens (0 = uncapped)") # Protein conditioning: always enabled (ESM-C) # Training p.add_argument("--output_dir", type=str, default="checkpoints", help="Where to save checkpoints") p.add_argument("--epochs", type=int, default=1, help="Number of training epochs") p.add_argument("--batch_size", type=int, default=20, help="Per-device train batch size") p.add_argument("--eval_batch_size", type=int, default=32, help="Per-device eval batch size") p.add_argument("--workers", type=int, default=4, help="DataLoader workers") p.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps") p.add_argument("--train_shuffle_buffer", type=int, default=0, help="Streaming shuffle buffer for training (set 0 when data is pre-shuffled)") p.add_argument("--val_shuffle_buffer", type=int, default=0, help="Streaming shuffle buffer for validation (0 disables)") p.add_argument("--csv_chunksize", type=int, default=200_000, help="Pandas read_csv chunksize for CSV inputs") # Optim / schedule p.add_argument("--lr", type=float, default=1e-4, help="Learning rate") p.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio for LR schedule (0.0-1.0)") p.add_argument( "--lr_scheduler", type=str, choices=["linear", "cosine", "constant"], default="linear", help="LR schedule applied after warmup; 'linear' decays to zero by the end of training", ) p.add_argument("--weight_decay", type=float, default=1e-3, help="Weight decay") p.add_argument("--adam_beta1", type=float, default=0.9, help="Adam beta1 (momentum) coefficient") p.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 (squared-gradient) coefficient") p.add_argument("--logging_steps", type=int, default=20, help="Logging interval (steps)") p.add_argument("--save_steps", type=int, default=10, help="Save every N steps (0 disables step-saving)") p.add_argument("--save_total_limit", type=int, default=10, help="Keep at most N recent checkpoints") p.add_argument("--ckpt_recent_window_steps", type=int, default=0, help="If >0, keep finer-grained checkpoints within this many recent steps") p.add_argument("--ckpt_recent_interval", type=int, default=0, help="Retention interval inside the recent checkpoint window (0 disables custom retention)") p.add_argument("--ckpt_archive_interval", type=int, default=0, help="Retention interval for checkpoints older than the recent window (0 prunes them)") p.add_argument("--max_steps", type=int, default=-1, help="Total training steps. REQUIRED for streaming (IterableDataset)") p.add_argument("--steps_per_epoch", type=int, default=0, help="For streaming datasets: shape LR schedule as epochs*steps_per_epoch when max_steps<0") p.add_argument("--max_grad_norm", type=float, default=1.0, help="Clip gradients to this global L2 norm; set <=0 to disable") p.add_argument("--override_lr_on_resume", action="store_true", help="Do not restore LR/optimizer state on resume (keep current lr)") # Resume p.add_argument("--resume_from", type=str, default=None, help="Path to checkpoint dir to resume from; pass 'auto' to pick latest in output_dir") # Evaluation scheduling p.add_argument("--eval_interval", type=int, default=0, help="Run evaluation every N optimizer steps on --val_data (0 disables)") p.add_argument("--eval_steps", type=int, default=5000, help="For streaming eval datasets: limit to this many batches (0 = full eval)") # Hardware / precision p.add_argument("--device", type=str, default="cuda", help="cuda or cpu") p.add_argument("--bf16", action="store_true", help="bfloat16 mixed precision") p.add_argument("--fp16", action="store_true", help="float16 mixed precision") p.add_argument("--fsdp", action="store_true", help="Enable FSDP full sharding") p.add_argument("--grad_ckpt", action="store_true", help="Enable gradient checkpointing") return p.parse_args() def main(): args = parse_args() _speed_toggles() if args.device == "cuda" and not torch.cuda.is_available(): logger.warning("CUDA not available; switching to CPU") args.device = "cpu" # Tokenizer tok = CodonTokenizer() # Ensure output dir exists and persist vocab.json (used by sampler) os.makedirs(os.path.abspath(args.output_dir), exist_ok=True) tok.save_vocabulary(args.output_dir) # Data first — we need Ds for species embeddings train_loader, val_loader, species_store = create_precomputed_dataloaders( train_path=args.train_data, val_path=args.val_data, embeddings_dir=args.embeddings_dir, tokenizer=tok, batch_size=args.batch_size, num_workers=args.workers, species_pooling="sequence", # prefer variable-length token sequence if available csv_chunksize=int(args.csv_chunksize), train_shuffle_buffer=int(args.train_shuffle_buffer), val_shuffle_buffer=int(args.val_shuffle_buffer), ) # Estimate steps_per_epoch for streaming schedule shaping if not provided steps_per_epoch = int(getattr(args, "steps_per_epoch", 0) or 0) total_rows = 0 paths: list[str] = [] if steps_per_epoch <= 0 and int(args.max_steps) < 0: def _expand_paths(maybe: str | list[str]) -> list[str]: import glob as _glob from pathlib import Path as _Path paths: list[str] = [] if isinstance(maybe, str): p = _Path(maybe) if p.is_dir(): paths.extend(sorted(str(x) for x in p.rglob("*.parquet"))) else: paths = sorted(_glob.glob(str(p))) else: for it in maybe: paths.extend(_expand_paths(it)) # de-dup seen = set(); out = [] for x in paths: if x not in seen: out.append(x); seen.add(x) return out paths = _expand_paths(args.train_data) if paths: try: import pyarrow.parquet as pq for fp in paths: if fp.lower().endswith((".parquet", ".parq")): pf = pq.ParquetFile(fp) md = pf.metadata if md is not None: total_rows += int(md.num_rows) except Exception: # Fallback: keep steps_per_epoch at 0 if pyarrow not available total_rows = 0 if total_rows > 0: world = int(os.environ.get("WORLD_SIZE", "1")) ga = max(1, int(getattr(args, "grad_accum", 1))) denom = max(1, int(args.batch_size) * max(1, world) * ga) steps_per_epoch = max(1, math.ceil(total_rows / denom)) logger.info(f"Estimated steps_per_epoch={steps_per_epoch} from {len(paths)} parquet files, total_rows={total_rows}") world = int(os.environ.get("WORLD_SIZE", "1")) grad_accum = max(1, int(getattr(args, "grad_accum", 1))) effective_global_batch = int(args.batch_size) * max(1, world) * grad_accum logger.info( "Batch config: per_device_train_batch=%s per_device_eval_batch=%s world_size=%s grad_accum=%s effective_global_batch=%s", args.batch_size, args.eval_batch_size, world, grad_accum, effective_global_batch, ) # Resolve per-process CUDA device for ESM (avoid defaulting to cuda:0 on all ranks) esm_dev = "cpu" if args.device == "cuda" and torch.cuda.is_available(): lr = int(os.environ.get("LOCAL_RANK", "0")) esm_dev = f"cuda:{lr}" # Model — species is always on; protein defaults to ON (can be disabled with --no_protein) model = CodonTranslatorModel( vocab_size=tok.vocab_size, num_special_tokens=tok.num_special_tokens, special_ids=tok.special_ids, hidden_size=args.hidden, num_layers=args.layers, num_heads=args.heads, mlp_ratio=float(args.mlp_ratio), max_position_embeddings=args.max_length, prepend_species=True, prepend_protein=True, esm_model_name="esmc_300m", esm_device=esm_dev, max_protein_prefix=int(args.max_protein_prefix), max_species_prefix=int(args.max_species_prefix), dropout=0.1, species_embedding_dim=int(species_store.Ds()), attn_impl=str(args.attn), num_kv_groups=int(args.num_kv_groups), ) # Report model size and SDPA (Flash) kernel configuration _print_model_size(model, bf16=bool(args.bf16), fp16=bool(args.fp16)) _describe_sdp_kernels() # Trainer args targs = TrainingArguments( output_dir=args.output_dir, save_steps=args.save_steps, save_total_limit=int(args.save_total_limit), ckpt_recent_window_steps=int(args.ckpt_recent_window_steps), ckpt_recent_interval=int(args.ckpt_recent_interval), ckpt_archive_interval=int(args.ckpt_archive_interval), num_train_epochs=args.epochs, max_steps=int(args.max_steps), gradient_accumulation_steps=int(args.grad_accum), warmup_ratio=float(args.warmup_ratio), lr_scheduler_type=str(args.lr_scheduler), per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.eval_batch_size, dataloader_num_workers=args.workers, learning_rate=args.lr, weight_decay=args.weight_decay, adam_beta1=float(args.adam_beta1), adam_beta2=float(args.adam_beta2), max_grad_norm=float(args.max_grad_norm), logging_steps=args.logging_steps, override_lr_on_resume=bool(args.override_lr_on_resume), data_cursor_path=os.path.join(os.path.abspath(args.output_dir), "data_cursor.json"), fp16=bool(args.fp16), bf16=bool(args.bf16), fsdp=("full_shard" if args.fsdp else None), gradient_checkpointing=bool(args.grad_ckpt), max_length=int(args.max_length), esm_model_name="esmc_300m", esm_device=esm_dev, esm_dtype=("bf16" if args.bf16 else ("fp16" if args.fp16 else "fp32")), # sampling eval eval_interval=int(args.eval_interval), eval_steps=int(args.eval_steps), steps_per_epoch=int(steps_per_epoch), ) # Resolve auto-resume if requested resume_path = None if args.resume_from: if args.resume_from == "auto": root = os.path.abspath(args.output_dir) if os.path.isdir(root): try: subdirs = [] for d in os.listdir(root): path = os.path.join(root, d) if not os.path.isdir(path): continue if not ( d == "final_model" or d.startswith("checkpoint-") ): continue if not ( os.path.exists(os.path.join(path, "model.safetensors")) or os.path.exists(os.path.join(path, "pytorch_model.bin")) ): continue subdirs.append(path) subdirs.sort(key=lambda d: os.path.getmtime(d), reverse=True) resume_path = subdirs[0] if subdirs else None except Exception: resume_path = None else: resume_path = args.resume_from trainer = Trainer( model=model, args=targs, tokenizer=tok, species_store=species_store, resume_from_checkpoint=resume_path, ) trainer.attach_dataloaders(train_loader, val_loader) logger.info("Starting training...") trainer.train() logger.info("Training finished.") if __name__ == "__main__": main()