"""
Minimal, honest training script for CodonTranslator on CSV or Parquet data.

- Species conditioning: REQUIRED (precomputed embeddings)
- Protein conditioning (ESM-C): ENABLED BY DEFAULT. Disable with --no_protein.
- Global capacity is controlled by --max_length (prefix + start + codon).
"""
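
# A minimal launch sketch (hypothetical filename and data paths; only flags that
# parse_args() below actually defines are used):
#
#   python train.py \
#       --train_data "./data/train_shards/*.parquet" \
#       --val_data "./data/val_shards/*.parquet" \
#       --embeddings_dir embeddings \
#       --output_dir checkpoints \
#       --bf16 --max_steps 100000 --eval_interval 2000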

import os
import math
import argparse
import logging
import torch

from src import CodonTranslatorModel, CodonTokenizer, Trainer, TrainingArguments
from src.dataset import create_precomputed_dataloaders, SpeciesEmbeddingStore

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("codontranslator.train")


def _describe_sdp_kernels() -> None:
    """Log which scaled dot-product attention kernels this torch build reports as enabled."""
    flash = None
    mem_eff = None
    mathk = None
    if hasattr(torch, "backends") and hasattr(torch.backends, "cuda"):
        tbc = torch.backends.cuda
        if hasattr(tbc, "flash_sdp_enabled"):
            flash = tbc.flash_sdp_enabled()
        if hasattr(tbc, "mem_efficient_sdp_enabled"):
            mem_eff = tbc.mem_efficient_sdp_enabled()
        if hasattr(tbc, "math_sdp_enabled"):
            mathk = tbc.math_sdp_enabled()
    logger.info(f"SDP kernels: flash={flash} mem_efficient={mem_eff} math={mathk}")


def _print_model_size(model: torch.nn.Module, bf16: bool, fp16: bool) -> None:
    """Log parameter counts and rough memory footprints for weights and optimizer state."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    w_bytes = 2 if (bf16 or fp16) else 4  # 2 bytes/param in half precision, 4 in fp32
    opt_bytes = 8  # Adam keeps two fp32 moment tensors per trainable parameter
    weights_gb = total * w_bytes / (1024**3)
    opt_gb = trainable * opt_bytes / (1024**3)
    logger.info(
        f"Model params: total={total:,} trainable={trainable:,} "
        f"(~{weights_gb:.2f} GB weights, ~{opt_gb:.2f} GB optimizer)"
    )
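    # Worked example (back-of-envelope, matching the formula above, hypothetical size):
    # a 500M-parameter model in bf16 needs roughly 500e6 * 2 / 1024**3 ≈ 0.93 GB for
    # weights and 500e6 * 8 / 1024**3 ≈ 3.7 GB for Adam state; activations not counted.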


def _speed_toggles():
    """Enable TF32 matmuls and cuDNN autotuning when the installed torch supports them."""
    if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
        torch.backends.cuda.matmul.allow_tf32 = True
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
    if hasattr(torch.backends, "cudnn") and hasattr(torch.backends.cudnn, "benchmark"):
        torch.backends.cudnn.benchmark = True


def parse_args():
    p = argparse.ArgumentParser(description="Train CodonTranslator on CSV or Parquet data")

    # Data
    p.add_argument("--train_data", type=str, default="random_sample_1000.csv",
                   help="Training data: CSV file or Parquet glob/dir (e.g., ./data/train_shards/*.parquet)")
    p.add_argument("--val_data", type=str, default=None,
                   help="Validation data: CSV file or Parquet glob/dir")
    p.add_argument("--embeddings_dir", type=str, default="embeddings",
                   help="Dir with species embeddings (species_vocab.json, *.bin/memmap)")

    # Model
    p.add_argument("--hidden", type=int, default=750, help="Model hidden size")
    p.add_argument("--layers", type=int, default=20, help="Number of transformer layers")
    p.add_argument("--heads", type=int, default=15, help="Number of attention heads")
    p.add_argument("--attn", type=str, choices=["mha", "gqa"], default="gqa",
                   help="Attention implementation: 'mha' or 'gqa'")
    p.add_argument("--num_kv_groups", type=int, default=5,
                   help="GQA: number of KV groups (0 = default/no grouping)")
    p.add_argument("--mlp_ratio", type=float, default=3.2,
                   help="FFN expansion ratio (mlp hidden = ratio * hidden)")
    p.add_argument("--max_length", type=int, default=2048,
                   help="Global max length (prefix + start + codon)")
    p.add_argument("--max_species_prefix", type=int, default=0,
                   help="Cap species prefix tokens (0 = uncapped)")
    p.add_argument("--max_protein_prefix", type=int, default=1024,
                   help="Cap protein prefix tokens (0 = uncapped)")

    # Output, epochs, batching, dataloading
    p.add_argument("--output_dir", type=str, default="checkpoints", help="Where to save checkpoints")
    p.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    p.add_argument("--batch_size", type=int, default=20, help="Per-device train batch size")
    p.add_argument("--eval_batch_size", type=int, default=32, help="Per-device eval batch size")
    p.add_argument("--workers", type=int, default=4, help="DataLoader workers")
    p.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
    p.add_argument("--train_shuffle_buffer", type=int, default=0,
                   help="Streaming shuffle buffer for training (set 0 when data is pre-shuffled)")
    p.add_argument("--val_shuffle_buffer", type=int, default=0,
                   help="Streaming shuffle buffer for validation (0 disables)")
    p.add_argument("--csv_chunksize", type=int, default=200_000,
                   help="Pandas read_csv chunksize for CSV inputs")

    # Optimization, logging, checkpointing
    p.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    p.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio for LR schedule (0.0-1.0)")
    p.add_argument(
        "--lr_scheduler",
        type=str,
        choices=["linear", "cosine", "constant"],
        default="linear",
        help="LR schedule applied after warmup; 'linear' decays to zero by the end of training",
    )
    p.add_argument("--weight_decay", type=float, default=1e-3, help="Weight decay")
    p.add_argument("--adam_beta1", type=float, default=0.9,
                   help="Adam beta1 (momentum) coefficient")
    p.add_argument("--adam_beta2", type=float, default=0.95,
                   help="Adam beta2 (squared-gradient) coefficient")
    p.add_argument("--logging_steps", type=int, default=20, help="Logging interval (steps)")
    p.add_argument("--save_steps", type=int, default=10, help="Save every N steps (0 disables step-saving)")
    p.add_argument("--save_total_limit", type=int, default=10, help="Keep at most N recent checkpoints")
    p.add_argument("--ckpt_recent_window_steps", type=int, default=0,
                   help="If >0, keep finer-grained checkpoints within this many recent steps")
    p.add_argument("--ckpt_recent_interval", type=int, default=0,
                   help="Retention interval inside the recent checkpoint window (0 disables custom retention)")
    p.add_argument("--ckpt_archive_interval", type=int, default=0,
                   help="Retention interval for checkpoints older than the recent window (0 prunes them)")
    p.add_argument("--max_steps", type=int, default=-1,
                   help="Total training steps. REQUIRED for streaming (IterableDataset)")
    p.add_argument("--steps_per_epoch", type=int, default=0,
                   help="For streaming datasets: shape LR schedule as epochs*steps_per_epoch when max_steps<0")
    p.add_argument("--max_grad_norm", type=float, default=1.0,
                   help="Clip gradients to this global L2 norm; set <=0 to disable")
    p.add_argument("--override_lr_on_resume", action="store_true",
                   help="Do not restore LR/optimizer state on resume (keep current lr)")

    # Resume
    p.add_argument("--resume_from", type=str, default=None,
                   help="Path to checkpoint dir to resume from; pass 'auto' to pick latest in output_dir")

    # Evaluation
    p.add_argument("--eval_interval", type=int, default=0,
                   help="Run evaluation every N optimizer steps on --val_data (0 disables)")
    p.add_argument("--eval_steps", type=int, default=5000,
                   help="For streaming eval datasets: limit to this many batches (0 = full eval)")

    # Device and precision
    p.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
    p.add_argument("--bf16", action="store_true", help="bfloat16 mixed precision")
    p.add_argument("--fp16", action="store_true", help="float16 mixed precision")
    p.add_argument("--fsdp", action="store_true", help="Enable FSDP full sharding")
    p.add_argument("--grad_ckpt", action="store_true", help="Enable gradient checkpointing")
    return p.parse_args()


def main():
    args = parse_args()
    _speed_toggles()

    if args.device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA not available; switching to CPU")
        args.device = "cpu"

    # Tokenizer; the vocabulary is saved next to checkpoints so runs are reproducible.
    tok = CodonTokenizer()
    os.makedirs(os.path.abspath(args.output_dir), exist_ok=True)
    tok.save_vocabulary(args.output_dir)

    # Streaming dataloaders backed by precomputed species embeddings.
    train_loader, val_loader, species_store = create_precomputed_dataloaders(
        train_path=args.train_data,
        val_path=args.val_data,
        embeddings_dir=args.embeddings_dir,
        tokenizer=tok,
        batch_size=args.batch_size,
        num_workers=args.workers,
        species_pooling="sequence",
        csv_chunksize=int(args.csv_chunksize),
        train_shuffle_buffer=int(args.train_shuffle_buffer),
        val_shuffle_buffer=int(args.val_shuffle_buffer),
    )

    # When neither --steps_per_epoch nor --max_steps is given, estimate steps per epoch
    # from Parquet metadata so the LR schedule has a defined length.
    steps_per_epoch = int(getattr(args, "steps_per_epoch", 0) or 0)
    total_rows = 0
    paths: list[str] = []
    if steps_per_epoch <= 0 and int(args.max_steps) < 0:
        def _expand_paths(maybe: str | list[str]) -> list[str]:
            import glob as _glob
            from pathlib import Path as _Path
            paths: list[str] = []
            if isinstance(maybe, str):
                p = _Path(maybe)
                if p.is_dir():
                    paths.extend(sorted(str(x) for x in p.rglob("*.parquet")))
                else:
                    paths = sorted(_glob.glob(str(p)))
            else:
                for it in maybe:
                    paths.extend(_expand_paths(it))
            # De-duplicate while preserving order.
            seen = set()
            out = []
            for x in paths:
                if x not in seen:
                    out.append(x)
                    seen.add(x)
            return out

        paths = _expand_paths(args.train_data)
        if paths:
            try:
                import pyarrow.parquet as pq
                for fp in paths:
                    if fp.lower().endswith((".parquet", ".parq")):
                        pf = pq.ParquetFile(fp)
                        md = pf.metadata
                        if md is not None:
                            total_rows += int(md.num_rows)
            except Exception:
                # Metadata unreadable (or pyarrow missing); leave the schedule unshaped.
                total_rows = 0
        if total_rows > 0:
            world = int(os.environ.get("WORLD_SIZE", "1"))
            ga = max(1, int(getattr(args, "grad_accum", 1)))
            denom = max(1, int(args.batch_size) * max(1, world) * ga)
            steps_per_epoch = max(1, math.ceil(total_rows / denom))
            logger.info(
                f"Estimated steps_per_epoch={steps_per_epoch} from {len(paths)} parquet files, "
                f"total_rows={total_rows}"
            )
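
    # Worked example of the estimate above (hypothetical numbers): 1,000,000 rows read
    # from Parquet metadata with batch_size=20, WORLD_SIZE=4 and --grad_accum 2 give
    # ceil(1_000_000 / (20 * 4 * 2)) = 6250 optimizer steps per epoch.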

    world = int(os.environ.get("WORLD_SIZE", "1"))
    grad_accum = max(1, int(getattr(args, "grad_accum", 1)))
    effective_global_batch = int(args.batch_size) * max(1, world) * grad_accum
    logger.info(
        "Batch config: per_device_train_batch=%s per_device_eval_batch=%s world_size=%s "
        "grad_accum=%s effective_global_batch=%s",
        args.batch_size,
        args.eval_batch_size,
        world,
        grad_accum,
        effective_global_batch,
    )

    # Place the ESM-C encoder on this rank's GPU when available, otherwise on CPU.
    esm_dev = "cpu"
    if args.device == "cuda" and torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        esm_dev = f"cuda:{local_rank}"

    model = CodonTranslatorModel(
        vocab_size=tok.vocab_size,
        num_special_tokens=tok.num_special_tokens,
        special_ids=tok.special_ids,
        hidden_size=args.hidden,
        num_layers=args.layers,
        num_heads=args.heads,
        mlp_ratio=float(args.mlp_ratio),
        max_position_embeddings=args.max_length,
        prepend_species=True,
        prepend_protein=True,
        esm_model_name="esmc_300m",
        esm_device=esm_dev,
        max_protein_prefix=int(args.max_protein_prefix),
        max_species_prefix=int(args.max_species_prefix),
        dropout=0.1,
        species_embedding_dim=int(species_store.Ds()),
        attn_impl=str(args.attn),
        num_kv_groups=int(args.num_kv_groups),
    )
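
    # With the defaults (hidden=750, heads=15, num_kv_groups=5) the per-head dimension
    # is 750 / 15 = 50 and, assuming the usual grouped-query attention layout, each of
    # the 5 KV groups is shared by 15 / 5 = 3 query heads; adjust --heads and
    # --num_kv_groups together so these divisions stay exact.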

    _print_model_size(model, bf16=bool(args.bf16), fp16=bool(args.fp16))
    _describe_sdp_kernels()

    targs = TrainingArguments(
        output_dir=args.output_dir,
        save_steps=args.save_steps,
        save_total_limit=int(args.save_total_limit),
        ckpt_recent_window_steps=int(args.ckpt_recent_window_steps),
        ckpt_recent_interval=int(args.ckpt_recent_interval),
        ckpt_archive_interval=int(args.ckpt_archive_interval),
        num_train_epochs=args.epochs,
        max_steps=int(args.max_steps),
        gradient_accumulation_steps=int(args.grad_accum),
        warmup_ratio=float(args.warmup_ratio),
        lr_scheduler_type=str(args.lr_scheduler),
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        dataloader_num_workers=args.workers,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        adam_beta1=float(args.adam_beta1),
        adam_beta2=float(args.adam_beta2),
        max_grad_norm=float(args.max_grad_norm),
        logging_steps=args.logging_steps,
        override_lr_on_resume=bool(args.override_lr_on_resume),
        data_cursor_path=os.path.join(os.path.abspath(args.output_dir), "data_cursor.json"),
        fp16=bool(args.fp16),
        bf16=bool(args.bf16),
        fsdp=("full_shard" if args.fsdp else None),
        gradient_checkpointing=bool(args.grad_ckpt),
        max_length=int(args.max_length),
        esm_model_name="esmc_300m",
        esm_device=esm_dev,
        esm_dtype=("bf16" if args.bf16 else ("fp16" if args.fp16 else "fp32")),
        # Periodic evaluation on --val_data.
        eval_interval=int(args.eval_interval),
        eval_steps=int(args.eval_steps),
        steps_per_epoch=int(steps_per_epoch),
    )

    # Resolve the checkpoint to resume from; 'auto' picks the most recently modified
    # checkpoint-*/final_model directory in output_dir that actually contains weights.
    resume_path = None
    if args.resume_from:
        if args.resume_from == "auto":
            root = os.path.abspath(args.output_dir)
            if os.path.isdir(root):
                try:
                    subdirs = []
                    for d in os.listdir(root):
                        path = os.path.join(root, d)
                        if not os.path.isdir(path):
                            continue
                        if not (d == "final_model" or d.startswith("checkpoint-")):
                            continue
                        if not (
                            os.path.exists(os.path.join(path, "model.safetensors"))
                            or os.path.exists(os.path.join(path, "pytorch_model.bin"))
                        ):
                            continue
                        subdirs.append(path)
                    subdirs.sort(key=lambda d: os.path.getmtime(d), reverse=True)
                    resume_path = subdirs[0] if subdirs else None
                except Exception:
                    resume_path = None
        else:
            resume_path = args.resume_from

    trainer = Trainer(
        model=model,
        args=targs,
        tokenizer=tok,
        species_store=species_store,
        resume_from_checkpoint=resume_path,
    )
    trainer.attach_dataloaders(train_loader, val_loader)

    logger.info("Starting training...")
    trainer.train()
    logger.info("Training finished.")


if __name__ == "__main__":
    main()