#!/usr/bin/env python
"""
Minimal, honest training script for CodonTranslator on CSV or Parquet data.
- Species conditioning: REQUIRED (precomputed embeddings)
- Protein conditioning (ESM-C): always enabled.
- Global capacity is controlled by --max_length (prefix + start + codon).
"""
import os
import math
import argparse
import logging

import torch

from src import CodonTranslatorModel, CodonTokenizer, Trainer, TrainingArguments
from src.dataset import create_precomputed_dataloaders, SpeciesEmbeddingStore

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("codontranslator.train")


def _describe_sdp_kernels() -> None:
    # Log the enabled SDPA backends (Flash/MemEff/Math) without raising on older PyTorch
    flash = None
    mem_eff = None
    mathk = None
    if hasattr(torch, "backends") and hasattr(torch.backends, "cuda"):
        tbc = torch.backends.cuda
        if hasattr(tbc, "flash_sdp_enabled"):
            flash = tbc.flash_sdp_enabled()
        if hasattr(tbc, "mem_efficient_sdp_enabled"):
            mem_eff = tbc.mem_efficient_sdp_enabled()
        if hasattr(tbc, "math_sdp_enabled"):
            mathk = tbc.math_sdp_enabled()
    logger.info(f"SDP kernels: flash={flash} mem_efficient={mem_eff} math={mathk}")


def _print_model_size(model: torch.nn.Module, bf16: bool, fp16: bool) -> None:
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    w_bytes = 2 if (bf16 or fp16) else 4
    opt_bytes = 8  # Adam moments in FP32
    weights_gb = total * w_bytes / (1024**3)
    opt_gb = trainable * opt_bytes / (1024**3)
    logger.info(
        f"Model params: total={total:,} trainable={trainable:,} "
        f"(~{weights_gb:.2f} GB weights, ~{opt_gb:.2f} GB optimizer)"
    )
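    # Rough example of this estimate: a hypothetical 500M-parameter model in bf16 works out to
    # ~0.93 GB of weights plus ~3.73 GB of optimizer state (two FP32 Adam moments per parameter).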


def _speed_toggles():
    # Opportunistic speed flags: TF32 matmuls and cuDNN autotuning where available
    if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
        torch.backends.cuda.matmul.allow_tf32 = True
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
    if hasattr(torch.backends, "cudnn") and hasattr(torch.backends.cudnn, "benchmark"):
        torch.backends.cudnn.benchmark = True


def parse_args():
    p = argparse.ArgumentParser(description="Train CodonTranslator on CSV or Parquet data")
    # Data (CSV path or Parquet glob/dir)
    p.add_argument("--train_data", type=str, default="random_sample_1000.csv",
                   help="Training data: CSV file or Parquet glob/dir (e.g., ./data/train_shards/*.parquet)")
    p.add_argument("--val_data", type=str, default=None,
                   help="Validation data: CSV file or Parquet glob/dir")
    p.add_argument("--embeddings_dir", type=str, default="embeddings",
                   help="Dir with species embeddings (species_vocab.json, *.bin/memmap)")
    # Model / capacity
    p.add_argument("--hidden", type=int, default=750, help="Model hidden size")
    p.add_argument("--layers", type=int, default=20, help="Number of transformer layers")
    p.add_argument("--heads", type=int, default=15, help="Number of attention heads")
    p.add_argument("--attn", type=str, choices=["mha", "gqa"], default="gqa",
                   help="Attention implementation: 'mha' or 'gqa'")
    p.add_argument("--num_kv_groups", type=int, default=5,
                   help="GQA: number of KV groups (0 = default/no grouping)")
    p.add_argument("--mlp_ratio", type=float, default=3.2,
                   help="FFN expansion ratio (mlp hidden = ratio * hidden)")
    p.add_argument("--max_length", type=int, default=2048,
                   help="Global max length (prefix + start + codon)")
    p.add_argument("--max_species_prefix", type=int, default=0,
                   help="Cap species prefix tokens (0 = uncapped)")
    p.add_argument("--max_protein_prefix", type=int, default=1024,
                   help="Cap protein prefix tokens (0 = uncapped)")
    # Protein conditioning: always enabled (ESM-C)
    # Training
    p.add_argument("--output_dir", type=str, default="checkpoints", help="Where to save checkpoints")
    p.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    p.add_argument("--batch_size", type=int, default=20, help="Per-device train batch size")
    p.add_argument("--eval_batch_size", type=int, default=32, help="Per-device eval batch size")
    p.add_argument("--workers", type=int, default=4, help="DataLoader workers")
    p.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
    p.add_argument("--train_shuffle_buffer", type=int, default=0,
                   help="Streaming shuffle buffer for training (set 0 when data is pre-shuffled)")
    p.add_argument("--val_shuffle_buffer", type=int, default=0,
                   help="Streaming shuffle buffer for validation (0 disables)")
    p.add_argument("--csv_chunksize", type=int, default=200_000,
                   help="Pandas read_csv chunksize for CSV inputs")
    # Optim / schedule
    p.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    p.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio for LR schedule (0.0-1.0)")
    p.add_argument(
        "--lr_scheduler",
        type=str,
        choices=["linear", "cosine", "constant"],
        default="linear",
        help="LR schedule applied after warmup; 'linear' decays to zero by the end of training",
    )
    p.add_argument("--weight_decay", type=float, default=1e-3, help="Weight decay")
    p.add_argument("--adam_beta1", type=float, default=0.9,
                   help="Adam beta1 (momentum) coefficient")
    p.add_argument("--adam_beta2", type=float, default=0.95,
                   help="Adam beta2 (squared-gradient) coefficient")
    p.add_argument("--logging_steps", type=int, default=20, help="Logging interval (steps)")
    p.add_argument("--save_steps", type=int, default=10, help="Save every N steps (0 disables step-saving)")
    p.add_argument("--save_total_limit", type=int, default=10, help="Keep at most N recent checkpoints")
    p.add_argument("--ckpt_recent_window_steps", type=int, default=0,
                   help="If >0, keep finer-grained checkpoints within this many recent steps")
    p.add_argument("--ckpt_recent_interval", type=int, default=0,
                   help="Retention interval inside the recent checkpoint window (0 disables custom retention)")
    p.add_argument("--ckpt_archive_interval", type=int, default=0,
                   help="Retention interval for checkpoints older than the recent window (0 prunes them)")
p.add_argument("--max_steps", type=int, default=-1,
help="Total training steps. REQUIRED for streaming (IterableDataset)")
p.add_argument("--steps_per_epoch", type=int, default=0,
help="For streaming datasets: shape LR schedule as epochs*steps_per_epoch when max_steps<0")
p.add_argument("--max_grad_norm", type=float, default=1.0,
help="Clip gradients to this global L2 norm; set <=0 to disable")
p.add_argument("--override_lr_on_resume", action="store_true",
help="Do not restore LR/optimizer state on resume (keep current lr)")
# Resume
p.add_argument("--resume_from", type=str, default=None,
help="Path to checkpoint dir to resume from; pass 'auto' to pick latest in output_dir")
# Evaluation scheduling
p.add_argument("--eval_interval", type=int, default=0,
help="Run evaluation every N optimizer steps on --val_data (0 disables)")
p.add_argument("--eval_steps", type=int, default=5000,
help="For streaming eval datasets: limit to this many batches (0 = full eval)")
# Hardware / precision
p.add_argument("--device", type=str, default="cuda", help="cuda or cpu")
p.add_argument("--bf16", action="store_true", help="bfloat16 mixed precision")
p.add_argument("--fp16", action="store_true", help="float16 mixed precision")
p.add_argument("--fsdp", action="store_true", help="Enable FSDP full sharding")
p.add_argument("--grad_ckpt", action="store_true", help="Enable gradient checkpointing")
return p.parse_args()


def main():
    args = parse_args()
    _speed_toggles()
    if args.device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA not available; switching to CPU")
        args.device = "cpu"

    # Tokenizer
    tok = CodonTokenizer()
    # Ensure output dir exists and persist vocab.json (used by sampler)
    os.makedirs(os.path.abspath(args.output_dir), exist_ok=True)
    tok.save_vocabulary(args.output_dir)

    # Data first — we need Ds for species embeddings
    train_loader, val_loader, species_store = create_precomputed_dataloaders(
        train_path=args.train_data,
        val_path=args.val_data,
        embeddings_dir=args.embeddings_dir,
        tokenizer=tok,
        batch_size=args.batch_size,
        num_workers=args.workers,
        species_pooling="sequence",  # prefer variable-length token sequence if available
        csv_chunksize=int(args.csv_chunksize),
        train_shuffle_buffer=int(args.train_shuffle_buffer),
        val_shuffle_buffer=int(args.val_shuffle_buffer),
    )

    # Estimate steps_per_epoch for streaming schedule shaping if not provided
    steps_per_epoch = int(getattr(args, "steps_per_epoch", 0) or 0)
    total_rows = 0
    paths: list[str] = []
    if steps_per_epoch <= 0 and int(args.max_steps) < 0:
        def _expand_paths(maybe: str | list[str]) -> list[str]:
            import glob as _glob
            from pathlib import Path as _Path

            paths: list[str] = []
            if isinstance(maybe, str):
                p = _Path(maybe)
                if p.is_dir():
                    paths.extend(sorted(str(x) for x in p.rglob("*.parquet")))
                else:
                    paths = sorted(_glob.glob(str(p)))
            else:
                for it in maybe:
                    paths.extend(_expand_paths(it))
            # de-dup while preserving order
            seen = set()
            out = []
            for x in paths:
                if x not in seen:
                    out.append(x)
                    seen.add(x)
            return out

        paths = _expand_paths(args.train_data)
        if paths:
            try:
                import pyarrow.parquet as pq

                for fp in paths:
                    if fp.lower().endswith((".parquet", ".parq")):
                        pf = pq.ParquetFile(fp)
                        md = pf.metadata
                        if md is not None:
                            total_rows += int(md.num_rows)
            except Exception:
                # Fallback: keep steps_per_epoch at 0 if pyarrow not available
                total_rows = 0
        if total_rows > 0:
            world = int(os.environ.get("WORLD_SIZE", "1"))
            ga = max(1, int(getattr(args, "grad_accum", 1)))
            denom = max(1, int(args.batch_size) * max(1, world) * ga)
            steps_per_epoch = max(1, math.ceil(total_rows / denom))
            logger.info(
                f"Estimated steps_per_epoch={steps_per_epoch} from {len(paths)} parquet files, total_rows={total_rows}"
            )
    world = int(os.environ.get("WORLD_SIZE", "1"))
    grad_accum = max(1, int(getattr(args, "grad_accum", 1)))
    effective_global_batch = int(args.batch_size) * max(1, world) * grad_accum
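    # e.g., batch_size=20 on 4 ranks with grad_accum=2 -> 20 * 4 * 2 = 160 sequences per optimizer step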
    logger.info(
        "Batch config: per_device_train_batch=%s per_device_eval_batch=%s world_size=%s grad_accum=%s effective_global_batch=%s",
        args.batch_size,
        args.eval_batch_size,
        world,
        grad_accum,
        effective_global_batch,
    )

    # Resolve per-process CUDA device for ESM (avoid defaulting to cuda:0 on all ranks)
    esm_dev = "cpu"
    if args.device == "cuda" and torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        esm_dev = f"cuda:{local_rank}"

    # Model: species and protein conditioning are always on (protein uses ESM-C)
    model = CodonTranslatorModel(
        vocab_size=tok.vocab_size,
        num_special_tokens=tok.num_special_tokens,
        special_ids=tok.special_ids,
        hidden_size=args.hidden,
        num_layers=args.layers,
        num_heads=args.heads,
        mlp_ratio=float(args.mlp_ratio),
        max_position_embeddings=args.max_length,
        prepend_species=True,
        prepend_protein=True,
        esm_model_name="esmc_300m",
        esm_device=esm_dev,
        max_protein_prefix=int(args.max_protein_prefix),
        max_species_prefix=int(args.max_species_prefix),
        dropout=0.1,
        species_embedding_dim=int(species_store.Ds()),
        attn_impl=str(args.attn),
        num_kv_groups=int(args.num_kv_groups),
    )

    # Report model size and SDPA (Flash) kernel configuration
    _print_model_size(model, bf16=bool(args.bf16), fp16=bool(args.fp16))
    _describe_sdp_kernels()

    # Trainer args
    targs = TrainingArguments(
        output_dir=args.output_dir,
        save_steps=args.save_steps,
        save_total_limit=int(args.save_total_limit),
        ckpt_recent_window_steps=int(args.ckpt_recent_window_steps),
        ckpt_recent_interval=int(args.ckpt_recent_interval),
        ckpt_archive_interval=int(args.ckpt_archive_interval),
        num_train_epochs=args.epochs,
        max_steps=int(args.max_steps),
        gradient_accumulation_steps=int(args.grad_accum),
        warmup_ratio=float(args.warmup_ratio),
        lr_scheduler_type=str(args.lr_scheduler),
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        dataloader_num_workers=args.workers,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        adam_beta1=float(args.adam_beta1),
        adam_beta2=float(args.adam_beta2),
        max_grad_norm=float(args.max_grad_norm),
        logging_steps=args.logging_steps,
        override_lr_on_resume=bool(args.override_lr_on_resume),
        data_cursor_path=os.path.join(os.path.abspath(args.output_dir), "data_cursor.json"),
        fp16=bool(args.fp16),
        bf16=bool(args.bf16),
        fsdp=("full_shard" if args.fsdp else None),
        gradient_checkpointing=bool(args.grad_ckpt),
        max_length=int(args.max_length),
        esm_model_name="esmc_300m",
        esm_device=esm_dev,
        esm_dtype=("bf16" if args.bf16 else ("fp16" if args.fp16 else "fp32")),
        # sampling eval
        eval_interval=int(args.eval_interval),
        eval_steps=int(args.eval_steps),
        steps_per_epoch=int(steps_per_epoch),
    )
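    # With the defaults, periodic evaluation is off (eval_interval=0); passing e.g. --eval_interval 2000
    # would evaluate on --val_data every 2000 optimizer steps, capped at eval_steps=5000 batches.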

    # Resolve auto-resume if requested
    resume_path = None
    if args.resume_from:
        if args.resume_from == "auto":
            root = os.path.abspath(args.output_dir)
            if os.path.isdir(root):
                try:
                    subdirs = []
                    for d in os.listdir(root):
                        path = os.path.join(root, d)
                        if not os.path.isdir(path):
                            continue
                        if not (d == "final_model" or d.startswith("checkpoint-")):
                            continue
                        if not (
                            os.path.exists(os.path.join(path, "model.safetensors"))
                            or os.path.exists(os.path.join(path, "pytorch_model.bin"))
                        ):
                            continue
                        subdirs.append(path)
                    subdirs.sort(key=lambda d: os.path.getmtime(d), reverse=True)
                    resume_path = subdirs[0] if subdirs else None
                except Exception:
                    resume_path = None
        else:
            resume_path = args.resume_from

    trainer = Trainer(
        model=model,
        args=targs,
        tokenizer=tok,
        species_store=species_store,
        resume_from_checkpoint=resume_path,
    )
    trainer.attach_dataloaders(train_loader, val_loader)

    logger.info("Starting training...")
    trainer.train()
    logger.info("Training finished.")


if __name__ == "__main__":
    main()