| |
| """ |
| FSDP trainer for CodonTranslator. |
| No frameworks, no sugar. The model computes its own loss. |
| |
| Batch invariants: |
| - codon_ids [B, T] (right-padded; EOS already in-sequence) |
| - species_ids [B] (SpeciesEmbeddingStore provides fixed-size or sequence embeddings) |
| - protein_seqs: list[str] (ESM tokenization happens inside the model) |
| |
| Rules: |
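- If your loader wraps an IterableDataset, set args.max_steps > 0 or args.steps_per_epoch > 0.
  With neither, the LR schedule is constant and each epoch runs until the stream is exhausted.
- For epoch-based training without a per-epoch step budget, use a sized dataset; we call len(dataloader).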
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import json |
| import math |
| import re |
| import shutil |
| import logging |
| import time |
| from dataclasses import dataclass |
| import datetime |
| import warnings |
| import importlib.util |
| import inspect |
| from typing import Any, Callable, Dict, Optional, Tuple, List |
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader, IterableDataset |
|
|
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.distributed.fsdp import ( |
| ShardingStrategy, |
| MixedPrecision, |
| StateDictType, |
| FullStateDictConfig, |
| FullOptimStateDictConfig, |
| ) |
| from safetensors.torch import save_file, load_file |
| import wandb |
|
|
# Module-level logger. This file configures no handlers or levels, so what is
# emitted depends on the caller's logging configuration.
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
@dataclass
class TrainingArguments:
    """Configuration consumed by ``Trainer``.

    Field order is the positional constructor order, so do not reorder fields.
    Fields marked "not referenced in the visible code" are not read by the
    Trainer methods in this chunk. They may be consumed elsewhere, for example
    by the rest of checkpoint saving or by resume logic.
    """

    # --- checkpointing ---
    output_dir: str = "checkpoints"  # parent dir; each checkpoint is a named subdirectory
    save_steps: int = 1000  # save "checkpoint-<global_step>" every N optimizer steps (<= 0 disables)
    save_total_limit: int = 3  # not referenced in the visible code; presumably a rotation limit (confirm)
    save_safetensors: bool = True  # not referenced in the visible code; the visible save path always writes model.safetensors
    ckpt_recent_window_steps: int = 0  # not referenced in the visible code
    ckpt_recent_interval: int = 0  # not referenced in the visible code
    ckpt_archive_interval: int = 0  # not referenced in the visible code

    # --- run length / LR schedule ---
    num_train_epochs: int = 1  # epoch loop upper bound; the loop runs to max(1, num_train_epochs)
    max_steps: int = -1  # > 0: stop after this many optimizer steps and use it as the LR-schedule length
    gradient_accumulation_steps: int = 1  # micro-batches per optimizer step
    warmup_ratio: float = 0.0  # fraction of total steps used for linear LR warmup
    lr_scheduler_type: str = "cosine"  # "constant", "linear", or anything else -> cosine decay

    # For IterableDataset loaders: optimizer-step budget per epoch
    # (0 = each epoch runs until the stream is exhausted).
    steps_per_epoch: int = 0

    # --- optimizer (used when Trainer builds its own AdamW) ---
    learning_rate: float = 5e-4  # peak learning rate
    weight_decay: float = 0.0  # applied to params whose names are not bias / "norm" / "ln_"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    max_grad_norm: float = 1.0  # gradient-norm clipping threshold (<= 0 disables)

    # --- batching ---
    per_device_train_batch_size: int = 8  # used for throughput / W&B reporting; loaders are built by the caller
    per_device_eval_batch_size: int = 8  # only reported to W&B in the visible code
    dataloader_num_workers: int = 0  # not referenced in the visible code

    # --- precision / parallelism ---
    fp16: bool = False  # CUDA fp16 autocast + GradScaler; fp16 FSDP mixed precision
    bf16: bool = False  # CUDA bf16 autocast; bf16 FSDP mixed precision (fp16 wins if both are set)
    fsdp: Optional[str] = None  # truthy -> wrap in FSDP; the value itself is only logged (FULL_SHARD is always used)
    gradient_checkpointing: bool = False  # sets model.gradient_checkpointing = True when the model has that attribute

    # --- sequence ---
    max_length: int = 4096  # recorded in trainer_config.json at checkpoint time

    # --- ESM settings (recorded in trainer_config.json at checkpoint time) ---
    esm_model_name: str = "esmc_300m"
    esm_device: str = "cuda"
    esm_dtype: str = "bf16"

    # --- logging / evaluation / resume ---
    logging_steps: int = 100  # log every N optimizer steps
    eval_steps: int = 0  # max eval batches for streaming eval sets (0 = whole stream)
    eval_interval: int = 0  # run evaluate() every N optimizer steps (0 disables)
    override_lr_on_resume: bool = False  # not referenced in the visible code

    # JSON data-cursor file: read by attach_dataloaders() to skip already-seen
    # samples; a cursor payload for it is assembled at checkpoint time.
    data_cursor_path: Optional[str] = None
|
|
|
|
| |
| |
| |
|
|
| class Trainer: |
| def __init__( |
| self, |
| model: nn.Module, |
| args: TrainingArguments, |
| data_collator: Optional[Callable] = None, |
| train_dataset: Optional[Any] = None, |
| eval_dataset: Optional[Any] = None, |
| tokenizer: Optional[Any] = None, |
| model_init: Optional[Callable[[], nn.Module]] = None, |
| compute_metrics: Optional[Callable] = None, |
| callbacks: Optional[list] = None, |
| optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[Any]] = (None, None), |
| preprocess_logits_for_metrics: Optional[Callable] = None, |
| species_store=None, |
| resume_from_checkpoint: Optional[str] = None, |
| ): |
| self.model = model |
| self.args = args |
| self.tokenizer = tokenizer |
| self.optimizer = optimizers[0] |
| self.lr_scheduler = optimizers[1] |
| self.species_store = species_store |
|
|
| self.train_dataloader: Optional[DataLoader] = None |
| self.eval_dataloader: Optional[DataLoader] = None |
|
|
| |
| self.local_rank = 0 |
| if torch.cuda.is_available(): |
| lr_env = os.environ.get("LOCAL_RANK") |
| if lr_env is not None: |
| self.local_rank = int(lr_env) |
| else: |
| r = int(os.environ.get("RANK", "0")) |
| ng = max(1, torch.cuda.device_count()) |
| self.local_rank = (r % ng) |
| self.device = torch.device(f"cuda:{self.local_rank}") |
| torch.cuda.set_device(self.device) |
| cd = torch.cuda.current_device() |
| nm = torch.cuda.get_device_name(cd) |
| logger.info( |
| f"[dist] RANK={os.environ.get('RANK')} LOCAL_RANK={os.environ.get('LOCAL_RANK')} WORLD_SIZE={os.environ.get('WORLD_SIZE')} " |
| f"cuda.count={torch.cuda.device_count()} select={self.device} current={cd} name={nm}" |
| ) |
| else: |
| self.device = torch.device("cpu") |
|
|
| |
| base = self._unwrap(self.model) |
| if self.args.gradient_checkpointing and hasattr(base, "gradient_checkpointing"): |
| base.gradient_checkpointing = True |
|
|
| |
| if self.args.fsdp: |
| self._setup_fsdp() |
| else: |
| self.model = self.model.to(self.device) |
|
|
| |
| self._use_amp = (self.device.type == "cuda") and (self.args.fp16 or self.args.bf16) |
| self._amp_dtype = torch.float16 if self.args.fp16 else (torch.bfloat16 if self.args.bf16 else None) |
| use_cuda = (self.device.type == "cuda") |
| self._scaler = torch.amp.GradScaler(device="cuda", enabled=(use_cuda and self.args.fp16)) |
|
|
| self.state = {"epoch": 0, "global_step": 0} |
|
|
| |
| self._resume_path: Optional[str] = resume_from_checkpoint |
|
|
| |
| def attach_dataloaders(self, train_loader: DataLoader, eval_loader: Optional[DataLoader] = None): |
| |
| self.train_dataloader = train_loader |
| self.eval_dataloader = eval_loader |
| |
| p = getattr(self.args, "data_cursor_path", None) |
| if p and os.path.exists(p): |
| with open(p, "r") as f: |
| js = json.load(f) |
| ds = getattr(self.train_dataloader, "dataset", None) |
| if hasattr(ds, "set_resume_skip"): |
| distributed = dist.is_available() and dist.is_initialized() |
| world = dist.get_world_size() if distributed else 1 |
| rank = dist.get_rank() if distributed else 0 |
|
|
| |
| |
| total: int = 0 |
| if isinstance(js, dict): |
| try: |
| total = int(js.get("skip_samples", 0) or 0) |
| except Exception: |
| total = 0 |
| if total <= 0: |
| raw = js.get("per_rank") |
| if isinstance(raw, list) and raw: |
| try: |
| total = int(sum(int(x) for x in raw)) |
| except Exception: |
| total = 0 |
|
|
| if total > 0: |
| if distributed: |
| per = total // max(world, 1) |
| rem = total % max(world, 1) |
| n_rank = per + (1 if rank < rem else 0) |
| ds.set_resume_skip(int(n_rank)) |
| if self._is_main(): |
| logger.info( |
| "resume cursor: total=%s split across world=%s → rank=%s skip=%s", |
| total, world, rank, n_rank, |
| ) |
| else: |
| ds.set_resume_skip(int(total)) |
| if self._is_main(): |
| logger.info("resume cursor: total=%s (single-process) skip=%s", total, total) |
|
|
|
|
| |
| def _create_optimizer_and_scheduler(self): |
| if self.optimizer is None: |
| decay, no_decay = [], [] |
| for n, p in self._unwrap(self.model).named_parameters(): |
| if not p.requires_grad: |
| continue |
| if n.endswith("bias") or "norm" in n.lower() or "ln_" in n.lower(): |
| no_decay.append(p) |
| else: |
| decay.append(p) |
| |
| opt_kwargs = dict( |
| lr=self.args.learning_rate, |
| betas=(self.args.adam_beta1, self.args.adam_beta2), |
| ) |
| params = [ |
| {"params": decay, "weight_decay": self.args.weight_decay}, |
| {"params": no_decay, "weight_decay": 0.0}, |
| ] |
| sig_adamw = inspect.signature(torch.optim.AdamW) |
| if torch.cuda.is_available() and "fused" in sig_adamw.parameters: |
| opt_kwargs["fused"] = True |
| self.optimizer = torch.optim.AdamW(params, **opt_kwargs) |
| |
| if self._is_main(): |
| fused_flag = None |
| foreach_flag = None |
| if hasattr(self.optimizer, "defaults"): |
| fused_flag = self.optimizer.defaults.get("fused") |
| foreach_flag = self.optimizer.defaults.get("foreach") |
| logger.info(f"AdamW configured: fused={fused_flag} foreach={foreach_flag}") |
|
|
| |
| ds = getattr(self.train_dataloader, "dataset", None) |
| ga = max(1, self.args.gradient_accumulation_steps) |
| if isinstance(ds, IterableDataset): |
| if self.args.max_steps > 0: |
| |
| steps_per_epoch = self.args.max_steps |
| total_steps = self.args.max_steps |
| elif getattr(self.args, "steps_per_epoch", 0) and self.args.steps_per_epoch > 0: |
| |
| steps_per_epoch = max(1, int(self.args.steps_per_epoch)) |
| total_steps = max(1, self.args.num_train_epochs) * steps_per_epoch |
| else: |
| |
| self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda step: 1.0) |
| return |
| else: |
| |
| steps_per_epoch = max(len(self.train_dataloader) // ga, 1) |
| total_steps = self.args.max_steps if self.args.max_steps > 0 else self.args.num_train_epochs * steps_per_epoch |
|
|
| warmup = int(self.args.warmup_ratio * total_steps) |
|
|
| if self.args.lr_scheduler_type == "constant": |
| self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda step: 1.0) |
| return |
|
|
| def lrs_lambda(step: int) -> float: |
| if step < warmup: |
| return max(float(step) / max(warmup, 1), 1e-6) |
| t = (step - warmup) / max(total_steps - warmup, 1) |
| if self.args.lr_scheduler_type == "linear": |
| return max(1.0 - t, 0.0) |
| |
| return 0.5 * (1.0 + math.cos(math.pi * t)) |
|
|
| self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lrs_lambda) |
|
|
| |
| def train(self) -> Dict[str, float]: |
| assert self.train_dataloader is not None, "Call attach_dataloaders() first" |
| |
| if getattr(self, "_resume_path", None): |
| self._resume_from(self._resume_path) |
| self._resume_path = None |
|
|
| if self.optimizer is None: |
| self._create_optimizer_and_scheduler() |
|
|
| ds = self.train_dataloader.dataset |
|
|
| |
| target_total_steps: Optional[int] = None |
| if isinstance(ds, IterableDataset) and int(self.args.max_steps) < 0: |
| spe = int(getattr(self.args, "steps_per_epoch", 0) or 0) |
| if spe > 0: |
| target_total_steps = max(1, int(self.args.num_train_epochs)) * spe |
|
|
| |
| progress_total: Optional[int] = None |
| if int(self.args.max_steps) > 0: |
| progress_total = int(self.args.max_steps) |
| elif isinstance(ds, IterableDataset): |
| if target_total_steps is not None: |
| progress_total = target_total_steps |
| else: |
| ga = max(1, self.args.gradient_accumulation_steps) |
| steps_per_epoch = max(len(self.train_dataloader) // ga, 1) |
| progress_total = max(1, int(self.args.num_train_epochs)) * steps_per_epoch |
|
|
| |
| if self._is_main(): |
| if not hasattr(self, "_wandb"): |
| proj = os.environ.get("WANDB_PROJECT", "codontranslator") |
| name = os.environ.get("WANDB_NAME") |
| run_id = os.environ.get("WANDB_RUN_ID") |
| resume = os.environ.get("WANDB_RESUME") |
| wandb_dir = os.environ.get("WANDB_DIR") |
| world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else int(os.environ.get("WORLD_SIZE", "1")) |
| init_kwargs = { |
| "project": proj, |
| "name": name, |
| "config": { |
| "lr": self.args.learning_rate, |
| "warmup_ratio": self.args.warmup_ratio, |
| "scheduler": self.args.lr_scheduler_type, |
| "batch_size": self.args.per_device_train_batch_size, |
| "eval_batch_size": self.args.per_device_eval_batch_size, |
| "grad_accum": self.args.gradient_accumulation_steps, |
| "effective_global_batch": self.args.per_device_train_batch_size * max(1, world_size) * max(1, self.args.gradient_accumulation_steps), |
| "epochs": self.args.num_train_epochs, |
| "steps_per_epoch": getattr(self.args, "steps_per_epoch", 0), |
| "max_steps": self.args.max_steps, |
| "weight_decay": self.args.weight_decay, |
| "world_size": world_size, |
| "output_dir": self.args.output_dir, |
| "fsdp": self.args.fsdp, |
| "bf16": self.args.bf16, |
| "fp16": self.args.fp16, |
| }, |
| } |
| if run_id: |
| init_kwargs["id"] = run_id |
| if resume: |
| init_kwargs["resume"] = resume |
| if wandb_dir: |
| init_kwargs["dir"] = wandb_dir |
| self._wandb = wandb.init(**init_kwargs) |
|
|
| self.model.train() |
| grad_accum = max(1, self.args.gradient_accumulation_steps) |
| progress = None |
| if self._is_main() and progress_total is not None and progress_total > 0: |
| progress = tqdm(total=progress_total, initial=int(self.state["global_step"]), desc="Train", dynamic_ncols=True) |
| if self.device.type == "cuda" and torch.cuda.is_available(): |
| torch.cuda.reset_peak_memory_stats(self.device) |
| world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else int(os.environ.get("WORLD_SIZE", "1")) |
| seqs_per_optimizer_step = ( |
| int(self.args.per_device_train_batch_size) * max(1, world_size) * grad_accum |
| ) |
| log_window_start = time.perf_counter() |
| log_window_optimizer_steps = 0 |
|
|
| for epoch in range(self.state["epoch"], max(1, self.args.num_train_epochs)): |
| self.state["epoch"] = epoch |
| running_loss = 0.0 |
| running_count = 0 |
|
|
| train_iter = iter(self.train_dataloader) |
| step = 0 |
| batches_this_epoch = 0 |
| optimizer_steps_this_epoch = 0 |
| |
| enforce_budget = False |
| epoch_budget = None |
| ds = self.train_dataloader.dataset |
| if isinstance(ds, IterableDataset): |
| spe = int(getattr(self.args, "steps_per_epoch", 0) or 0) |
| if spe > 0: |
| enforce_budget = True |
| epoch_budget = int(spe) |
|
|
| refill_attempts = 0 |
| max_refills = 64 |
|
|
| while True: |
| batch, has_batch, local_has_batch = self._next_batch_sync(train_iter) |
| if not has_batch: |
| |
| if enforce_budget and (epoch_budget is not None) and (optimizer_steps_this_epoch < epoch_budget): |
| if local_has_batch and self._is_main(): |
| logger.warning("Rank retained extra batch while peers exhausted stream; dropping to stay in sync") |
| self._barrier() |
| train_iter = iter(self.train_dataloader) |
| refill_attempts += 1 |
| if refill_attempts > max_refills: |
| if self._is_main(): |
| logger.warning( |
| "Exceeded max refills for epoch %s (steps %s/%s). Ending epoch early.", |
| epoch, optimizer_steps_this_epoch, epoch_budget, |
| ) |
| break |
| continue |
| else: |
| if local_has_batch and self._is_main(): |
| logger.warning("Rank retained extra batch while peers exhausted stream; dropping to stay in sync") |
| break |
|
|
| batch = self._prepare_batch(batch) |
| batches_this_epoch += 1 |
|
|
| codon_ids = batch["codon_ids"].to(self.device) |
| input_ids = codon_ids[:, :-1] |
| labels = codon_ids[:, :-1] |
|
|
| |
| pad_id = int(self.tokenizer.pad_token_id) if self.tokenizer is not None else 0 |
| eos_id = int(self.tokenizer.special_ids.eos) if self.tokenizer is not None else -999 |
| labels = labels.clone() |
| labels[labels == pad_id] = -100 |
| labels[labels == eos_id] = -100 |
|
|
| cond = self._build_cond(batch) |
|
|
| |
| use_cuda = (self.device.type == "cuda") |
| autocast_dtype = self._amp_dtype |
| if autocast_dtype is not None and use_cuda: |
| ctx = torch.amp.autocast(device_type="cuda", dtype=autocast_dtype) |
| else: |
| from contextlib import nullcontext |
| ctx = nullcontext() |
|
|
| with ctx: |
| out = self.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True) |
| loss = out["loss"] |
|
|
| if self._scaler.is_enabled(): |
| self._scaler.scale(loss / grad_accum).backward() |
| else: |
| (loss / grad_accum).backward() |
|
|
| running_loss += float(loss.detach().item()) |
| running_count += 1 |
|
|
| do_step = ((step + 1) % grad_accum == 0) |
| if do_step: |
| |
| if self.args.max_grad_norm and self.args.max_grad_norm > 0: |
| if isinstance(self.model, FSDP): |
| FSDP.clip_grad_norm_(self.model, self.args.max_grad_norm) |
| else: |
| if self._scaler.is_enabled(): |
| self._scaler.unscale_(self.optimizer) |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) |
|
|
| |
| if self._scaler.is_enabled(): |
| self._scaler.step(self.optimizer) |
| self._scaler.update() |
| else: |
| self.optimizer.step() |
| if self.lr_scheduler is not None: |
| self.lr_scheduler.step() |
| self.optimizer.zero_grad(set_to_none=True) |
| self.state["global_step"] += 1 |
| optimizer_steps_this_epoch += 1 |
| log_window_optimizer_steps += 1 |
|
|
| |
|
|
| |
| should_log = (self.state["global_step"] % max(1, self.args.logging_steps) == 0) |
| peak_alloc_gb = 0.0 |
| peak_reserved_gb = 0.0 |
| if should_log: |
| peak_alloc_gb, peak_reserved_gb = self._max_cuda_peak_gb() |
| if self._is_main() and should_log: |
| avg = running_loss / max(running_count, 1) |
| lr = float(self.optimizer.param_groups[0]["lr"]) |
| log_epoch = self._epoch_for_logging() |
| elapsed = max(time.perf_counter() - log_window_start, 1e-9) |
| step_time_s = elapsed / max(log_window_optimizer_steps, 1) |
| seq_per_s = (seqs_per_optimizer_step * max(log_window_optimizer_steps, 1)) / elapsed |
| msg = f"epoch {log_epoch} step {self.state['global_step']}: loss={avg:.4f} lr={lr:.6g}" |
| if isinstance(out, dict): |
| pl = out.get("prefix_len") |
| pc = out.get("per_cap") |
| if pl is not None and pc is not None: |
| msg += f" prefix_mean={float(pl.detach().float().mean().item()):.1f} cap_mean={float(pc.detach().float().mean().item()):.1f}" |
| msg += ( |
| f" step_time_s={step_time_s:.3f} seq_per_s={seq_per_s:.1f}" |
| f" peak_mem_alloc_gb={peak_alloc_gb:.1f} peak_mem_reserved_gb={peak_reserved_gb:.1f}" |
| ) |
| logger.info(msg) |
| if hasattr(self, "_wandb"): |
| wandb.log({ |
| "train/loss": float(avg), |
| "train/lr": float(lr), |
| "perf/step_time_s": float(step_time_s), |
| "perf/seq_per_s": float(seq_per_s), |
| "system/peak_mem_alloc_gb": float(peak_alloc_gb), |
| "system/peak_mem_reserved_gb": float(peak_reserved_gb), |
| }, step=self.state["global_step"]) |
| running_loss = 0.0 |
| running_count = 0 |
| log_window_start = time.perf_counter() |
| log_window_optimizer_steps = 0 |
|
|
| |
| if progress is not None: |
| progress.update(1) |
|
|
| |
| if target_total_steps is not None and self.state["global_step"] >= target_total_steps: |
| metrics = {"train_loss": running_loss / max(running_count, 1)} |
| self._save_checkpoint("final_model") |
| self._barrier() |
| return metrics |
|
|
| |
| should_eval = ( |
| self.eval_dataloader is not None and |
| self.args.eval_interval > 0 and |
| (self.state["global_step"] % self.args.eval_interval == 0) |
| ) |
| if should_eval: |
| eval_metrics = self.evaluate() |
| if self._is_main(): |
| el = float(eval_metrics.get("eval_loss", 0.0)) |
| ea = eval_metrics.get("eval_codon_acc", None) |
| aa = eval_metrics.get("eval_aa_acc", None) |
| if ea is not None and aa is not None: |
| logger.info(f"eval: loss={el:.4f} codon_acc={float(ea):.3f} aa_acc={float(aa):.3f}") |
| elif ea is not None: |
| logger.info(f"eval: loss={el:.4f} codon_acc={float(ea):.3f}") |
| elif aa is not None: |
| logger.info(f"eval: loss={el:.4f} aa_acc={float(aa):.3f}") |
| else: |
| logger.info(f"eval: loss={el:.4f}") |
| if hasattr(self, "_wandb"): |
| log_payload = {"eval/loss": el} |
| if ea is not None: |
| log_payload["eval/codon_acc"] = float(ea) |
| if aa is not None: |
| log_payload["eval/aa_acc"] = float(aa) |
| wandb.log(log_payload, step=self.state["global_step"]) |
|
|
| |
| if self.args.save_steps > 0 and (self.state["global_step"] % self.args.save_steps == 0): |
| self._save_checkpoint(f"checkpoint-{self.state['global_step']}") |
|
|
| |
| if self.args.max_steps > 0 and self.state["global_step"] >= self.args.max_steps: |
| metrics = {"train_loss": running_loss / max(running_count, 1)} |
| self._save_checkpoint("final_model") |
| self._barrier() |
| if progress is not None: |
| progress.close() |
| return metrics |
|
|
| step += 1 |
|
|
| |
| if enforce_budget and (epoch_budget is not None) and (optimizer_steps_this_epoch >= epoch_budget): |
| break |
|
|
| |
| if self._is_main(): |
| try: |
| eb = int(epoch_budget) if epoch_budget is not None else -1 |
| except Exception: |
| eb = -1 |
| logger.info( |
| "epoch %s completed: optimizer_steps=%s%s", |
| self._epoch_for_logging(), |
| optimizer_steps_this_epoch, |
| (f" / budget {eb}" if eb > 0 else ""), |
| ) |
|
|
| if dist.is_available() and dist.is_initialized(): |
| gather_device = self.device if self.device.type == "cuda" else torch.device("cpu") |
| counts_tensor = torch.tensor( |
| [batches_this_epoch, optimizer_steps_this_epoch], |
| dtype=torch.long, |
| device=gather_device, |
| ) |
| gathered = [torch.zeros_like(counts_tensor) for _ in range(dist.get_world_size())] |
| dist.all_gather(gathered, counts_tensor) |
| batch_counts = [int(t[0].item()) for t in gathered] |
| step_counts = [int(t[1].item()) for t in gathered] |
| batch_gap = max(batch_counts) - min(batch_counts) |
| step_gap = max(step_counts) - min(step_counts) |
| if self._is_main() and (batch_gap > 0 or step_gap > 0): |
| logger.warning( |
| "Epoch %s imbalance detected across ranks: batches min=%s max=%s, optimizer steps min=%s max=%s", |
| epoch, |
| min(batch_counts), |
| max(batch_counts), |
| min(step_counts), |
| max(step_counts), |
| ) |
|
|
| |
| if not isinstance(ds, IterableDataset): |
| self._save_checkpoint(f"epoch-{epoch}") |
|
|
| metrics = {"train_loss": 0.0} |
| if progress is not None: |
| progress.close() |
| self._barrier() |
| return metrics |
|
|
| |
| def evaluate(self) -> Dict[str, float]: |
| if self.eval_dataloader is None: |
| return {"eval_loss": 0.0} |
|
|
| self.model.eval() |
|
|
| loss_sum = 0.0 |
| loss_tokens = 0 |
| codon_correct = 0 |
| codon_total = 0 |
| aa_correct = 0 |
| aa_total = 0 |
|
|
| tok = self.tokenizer |
| pad_id = int(tok.pad_token_id) if tok is not None else 0 |
| eos_id = int(tok.special_ids.eos) if tok is not None and hasattr(tok, "special_ids") else -999 |
| num_special = int(tok.num_special_tokens) if tok is not None else 0 |
| codon2aa = tok.codon2aa_char_map() if tok is not None and hasattr(tok, "codon2aa_char_map") else {} |
|
|
| is_streaming = isinstance(self.eval_dataloader.dataset, IterableDataset) |
| max_batches = int(self.args.eval_steps) if (is_streaming and self.args.eval_steps > 0) else None |
|
|
| with torch.no_grad(): |
| eval_iter = iter(self.eval_dataloader) |
| b_idx = 0 |
| while True: |
| batch, has_batch, local_has_batch = self._next_batch_sync(eval_iter) |
| if not has_batch: |
| if local_has_batch and self._is_main(): |
| logger.debug("eval dataloader: discarded tail batch to stay in sync across ranks") |
| break |
|
|
| if max_batches is not None and b_idx >= max_batches: |
| break |
|
|
| batch = self._prepare_batch(batch) |
|
|
| codon_ids = batch["codon_ids"].to(self.device) |
| input_ids = codon_ids[:, :-1] |
| labels = codon_ids[:, :-1] |
|
|
| labels = labels.clone() |
| labels[labels == pad_id] = -100 |
| labels[labels == eos_id] = -100 |
|
|
| cond = self._build_cond(batch) |
|
|
| use_cuda = (self.device.type == "cuda") |
| autocast_dtype = self._amp_dtype |
| if autocast_dtype is not None and use_cuda: |
| ctx = torch.amp.autocast(device_type="cuda", dtype=autocast_dtype) |
| else: |
| from contextlib import nullcontext |
| ctx = nullcontext() |
|
|
| with ctx: |
| out = self.model(codon_ids=input_ids, cond=cond, labels=labels, return_dict=True) |
|
|
| loss = out.get("loss") |
| per_cap = out.get("per_cap") |
| logits = out.get("logits") |
|
|
| tokens_in_batch = 0 |
| if per_cap is not None: |
| tokens_in_batch = int(torch.clamp(per_cap.detach(), min=0).sum().item()) |
| loss_tokens += tokens_in_batch |
|
|
| if loss is not None and tokens_in_batch > 0: |
| loss_sum += float(loss.detach().item()) * tokens_in_batch |
|
|
| if logits is None or logits.size(1) == 0 or per_cap is None: |
| continue |
|
|
| max_cap = logits.size(1) |
| batch_size = logits.size(0) |
|
|
| labels_aligned = torch.full((batch_size, max_cap), -100, dtype=labels.dtype, device=labels.device) |
| common_cols = min(labels.size(1), max_cap) |
| if common_cols > 0: |
| labels_aligned[:, :common_cols] = labels[:, :common_cols] |
|
|
| per_cap_int = torch.clamp(per_cap.to(dtype=torch.long), min=0, max=max_cap) |
| for row in range(batch_size): |
| cap = int(per_cap_int[row].item()) |
| if cap < max_cap: |
| labels_aligned[row, cap:] = -100 |
|
|
| supervised = labels_aligned != -100 |
| if num_special > 0: |
| supervised = supervised & (labels_aligned >= num_special) |
| if not supervised.any(): |
| continue |
|
|
| preds = logits.argmax(dim=-1) |
| codon_correct += int((preds[supervised] == labels_aligned[supervised]).sum().item()) |
| codon_total += int(supervised.sum().item()) |
|
|
| if codon2aa and isinstance(batch, dict) and "protein_seqs" in batch: |
| prot_list = batch.get("protein_seqs", []) |
| for row in range(batch_size): |
| cap = int(per_cap_int[row].item()) |
| if cap <= 0: |
| continue |
| mask_row = supervised[row, :cap] |
| if not mask_row.any(): |
| continue |
| preds_row = preds[row, :cap][mask_row] |
| prot = prot_list[row] if (isinstance(prot_list, list) and row < len(prot_list)) else "" |
| if not prot: |
| continue |
| seq_len = min(len(prot), preds_row.size(0)) |
| if seq_len <= 0: |
| continue |
| pred_aa = ''.join(codon2aa.get(int(t.item()), 'X') for t in preds_row[:seq_len]) |
| truth_aa = prot[:seq_len] |
| aa_correct += sum(1 for i in range(seq_len) if pred_aa[i] == truth_aa[i]) |
| aa_total += seq_len |
|
|
| b_idx += 1 |
|
|
| totals = torch.tensor( |
| [loss_sum, loss_tokens, codon_correct, codon_total, aa_correct, aa_total], |
| dtype=torch.float64, |
| device=self.device, |
| ) |
| if dist.is_available() and dist.is_initialized(): |
| |
| |
| |
| self._barrier() |
| dist.all_reduce(totals, op=dist.ReduceOp.SUM) |
|
|
| loss_sum, loss_tokens, codon_correct, codon_total, aa_correct, aa_total = totals.tolist() |
|
|
| self.model.train() |
|
|
| metrics: Dict[str, float] = {"eval_loss": float(loss_sum) / loss_tokens if loss_tokens > 0 else 0.0} |
| if codon_total > 0: |
| metrics["eval_codon_acc"] = float(codon_correct) / codon_total |
| if aa_total > 0: |
| metrics["eval_aa_acc"] = float(aa_correct) / aa_total |
|
|
| self._barrier() |
| return metrics |
|
|
| |
| def _setup_fsdp(self): |
| |
| device = self.device |
| if dist.is_available() and not dist.is_initialized(): |
| backend = "nccl" if device.type == "cuda" else "gloo" |
| sig = inspect.signature(dist.init_process_group) |
| if "timeout" in sig.parameters: |
| dist.init_process_group(backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=30)) |
| else: |
| dist.init_process_group(backend=backend, init_method="env://") |
| mp = MixedPrecision( |
| param_dtype=(torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32), |
| reduce_dtype=(torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32), |
| buffer_dtype=torch.float32, |
| ) |
| logger.info(f"FSDP enabled: sharding={self.args.fsdp} mp_param={mp.param_dtype} mp_reduce={mp.reduce_dtype}") |
| |
| base = self._unwrap(self.model) |
| ignored = [] |
| if hasattr(base, "esm") and isinstance(base.esm, nn.Module): |
| ignored.append(base.esm) |
|
|
| self.model = FSDP( |
| self.model, |
| device_id=(self.device if device.type == "cuda" else None), |
| sharding_strategy=ShardingStrategy.FULL_SHARD, |
| mixed_precision=mp, |
| ignored_modules=(ignored if ignored else None), |
| sync_module_states=True, |
| ) |
|
|
| |
| if ignored: |
| ignored[0].to(device) |
|
|
| def _unwrap(self, module): |
| return getattr(module, "module", module) |
|
|
| def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: |
| |
| if self.species_store is not None and "species_ids" in batch: |
| sids = batch["species_ids"] |
| if torch.is_tensor(sids): |
| sids = sids.detach().cpu().tolist() |
| result = self.species_store.batch_get(sids) |
| if isinstance(result, tuple): |
| sp_tok, _ = result |
| batch["species_tok_emb"] = sp_tok.to(self.device, non_blocking=True) |
| else: |
| sp = result |
| batch["species_emb"] = sp.to(self.device, non_blocking=True) |
|
|
| |
| if "codon_ids" in batch and hasattr(batch["codon_ids"], "to"): |
| batch["codon_ids"] = batch["codon_ids"].to(self.device, non_blocking=True) |
|
|
| return batch |
|
|
| def _build_cond(self, batch: Dict[str, Any]) -> Dict[str, Any]: |
| cond: Dict[str, Any] = {"control_mode": "fixed"} |
| if "species_tok_emb" in batch: |
| cond["species_tok_emb_src"] = batch["species_tok_emb"] |
| cond["species_tok_emb_tgt"] = batch["species_tok_emb"] |
| elif "species_emb" in batch: |
| cond["species_emb_src"] = batch["species_emb"] |
| cond["species_emb_tgt"] = batch["species_emb"] |
| if "protein_seqs" in batch: |
| cond["protein_seqs"] = batch["protein_seqs"] |
| return cond |
|
|
| def _next_batch_sync(self, iterator): |
| """Fetch next batch and drop out early if any rank exhausts its loader.""" |
| try: |
| batch = next(iterator) |
| local_has_batch = True |
| except StopIteration: |
| batch = None |
| local_has_batch = False |
|
|
| distributed = dist.is_available() and dist.is_initialized() |
| has_batch = local_has_batch |
|
|
| if distributed: |
| flag_device = self.device if self.device.type == "cuda" else torch.device("cpu") |
| flag = torch.tensor([1 if local_has_batch else 0], device=flag_device) |
| dist.all_reduce(flag, op=dist.ReduceOp.MIN) |
| has_batch = bool(flag.item()) |
|
|
| if not has_batch: |
| return None, False, local_has_batch |
|
|
| return batch, True, local_has_batch |
|
|
| def _is_main(self) -> bool: |
| return (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0 |
|
|
| def _barrier(self): |
| if dist.is_available() and dist.is_initialized(): |
| |
| if self.device.type == "cuda": |
| sig = inspect.signature(dist.barrier) |
| if "device_ids" in sig.parameters: |
| dist.barrier(device_ids=[self.local_rank]) |
| return |
| dist.barrier() |
|
|
| def _max_cuda_peak_gb(self) -> Tuple[float, float]: |
| if self.device.type != "cuda" or not torch.cuda.is_available(): |
| return 0.0, 0.0 |
| vals = torch.tensor( |
| [ |
| float(torch.cuda.max_memory_allocated(self.device)), |
| float(torch.cuda.max_memory_reserved(self.device)), |
| ], |
| dtype=torch.float64, |
| device=self.device, |
| ) |
| if dist.is_available() and dist.is_initialized(): |
| dist.all_reduce(vals, op=dist.ReduceOp.MAX) |
| scale = float(1024 ** 3) |
| return float(vals[0].item() / scale), float(vals[1].item() / scale) |
|
|
| |
|
|
| def _epoch_for_logging(self) -> int: |
| steps_per_epoch = int(getattr(self.args, "steps_per_epoch", 0) or 0) |
| if steps_per_epoch > 0: |
| est = self.state.get("global_step", 0) // steps_per_epoch |
| if self.args.num_train_epochs > 0: |
| max_epoch = max(int(self.args.num_train_epochs) - 1, 0) |
| if est > max_epoch: |
| return max_epoch |
| return int(est) |
| return int(self.state.get("epoch", 0)) |
|
|
| |
def _save_checkpoint(self, name: str):
    """Write a training checkpoint to ``<args.output_dir>/<name>``.

    Every rank must call this. Under FSDP the full model/optimizer state dicts
    are gathered across ranks. The data-cursor positions are exchanged with
    ``dist.all_gather``. Non-main ranks then wait on ``self._barrier()`` while
    the main rank writes.

    Only the main rank writes files:

    - ``model.safetensors``: full model weights
    - ``optimizer.pt`` / ``scheduler.pt``: when an optimizer / LR scheduler exists
    - ``trainer_config.json``: architecture and conditioning settings
    - ``trainer_state.json``: ``epoch`` and ``global_step``
    - tokenizer vocabulary (best effort, if the tokenizer has ``save_vocabulary``)
    - the data-cursor JSON at ``args.data_cursor_path``, when that is set and
      the train dataset exposes ``get_stream_position()``

    Old checkpoints are then pruned with ``_prune_checkpoints``.

    Args:
        name: Sub-directory name under ``args.output_dir``.
    """
    # Persist the (clamped) epoch used for logging so a resume sees the same value.
    self.state["epoch"] = int(self._epoch_for_logging())

    out_dir = os.path.join(self.args.output_dir, name)
    os.makedirs(out_dir, exist_ok=True)

    # ---- Model / optimizer state (gathered across ranks under FSDP) ----
    optim_state = None
    if isinstance(self.model, FSDP):
        with warnings.catch_warnings():
            # Suppress FutureWarnings raised while using the FSDP state-dict APIs.
            warnings.filterwarnings("ignore", category=FutureWarning)
            with FSDP.state_dict_type(
                self.model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
                FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                # With rank0_only=True the full tensors are materialized on rank 0 only.
                state = self.model.state_dict()
                if self.optimizer is not None:
                    optim_state = FSDP.optim_state_dict(self.model, self.optimizer)
    else:
        state = self._unwrap(self.model).state_dict()
        if self.optimizer is not None:
            optim_state = self.optimizer.state_dict()

    # ---- Data-stream cursor: every rank contributes its local position ----
    per_rank_positions: Optional[List[int]] = None
    p = getattr(self.args, "data_cursor_path", None)
    if p:
        ds = getattr(self.train_dataloader, "dataset", None)
        if hasattr(ds, "get_stream_position"):
            local_pos = int(ds.get_stream_position())
            if dist.is_available() and dist.is_initialized():
                # Gather on the training device when it is CUDA, otherwise on CPU.
                gather_device = self.device if self.device.type == "cuda" else torch.device("cpu")
                tensor = torch.tensor([local_pos], dtype=torch.long, device=gather_device)
                gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
                dist.all_gather(gathered, tensor)
                per_rank_positions = [int(t.item()) for t in gathered]
            else:
                per_rank_positions = [local_pos]

    if not self._is_main():
        # Matches the barrier the main rank reaches after writing everything.
        self._barrier()
        return

    # ---- Main rank: write weights, optimizer and scheduler ----
    save_file(state, os.path.join(out_dir, "model.safetensors"))

    if optim_state is not None:
        torch.save(optim_state, os.path.join(out_dir, "optimizer.pt"))
    if self.lr_scheduler is not None:
        torch.save(self.lr_scheduler.state_dict(), os.path.join(out_dir, "scheduler.pt"))

    base = self._unwrap(self.model)

    # FFN expansion ratio = w1 rows / hidden size (this code treats w1.weight as
    # [ffn_dim, hidden]). Read the shape from the gathered full state dict first.
    # Under FSDP the live module parameter can be a local, possibly flattened,
    # shard whose shape is not the full weight shape.
    mlp_ratio = 4.0
    w1_name = "blocks.0.ffn.w1.weight"
    try:
        w1_shape = None
        for key, tensor in state.items():
            if key == w1_name or key.endswith("." + w1_name):
                w1_shape = tuple(tensor.shape)
                break
        if w1_shape is None and len(getattr(base, "blocks", [])) > 0:
            w1_shape = tuple(base.blocks[0].ffn.w1.weight.shape)
        if w1_shape is not None and len(w1_shape) == 2:
            hidden = int(getattr(base, "hidden_size", w1_shape[1]))
            if hidden > 0:
                mlp_ratio = float(w1_shape[0]) / float(hidden)
    except Exception:
        # Best effort: keep the 4.0 default if the model layout is unexpected.
        pass

    trainer_cfg = {
        # Sequence / prefix limits
        "max_length": int(self.args.max_length),
        "max_species_prefix": int(getattr(base, "max_species_prefix", 0)),
        "max_protein_prefix": int(getattr(base, "max_protein_prefix", 0)),

        # Transformer shape
        "hidden_size": int(getattr(base, "hidden_size", 0)),
        "num_hidden_layers": int(getattr(base, "num_layers", 0)),
        "num_attention_heads": int(getattr(base, "num_heads", 0)),
        "mlp_ratio": float(mlp_ratio),

        # Conditioning
        "prepend_species": bool(getattr(base, "prepend_species", True)),
        "prepend_protein": bool(getattr(base, "prepend_protein", False)),
        "species_embedding_dim": int(getattr(base, "species_embedding_dim", 1024)),

        # ESM settings, taken from the training arguments
        "esm_model_name": str(getattr(self.args, "esm_model_name", "")),
        "esm_device": str(getattr(self.args, "esm_device", "cuda")),
        "esm_dtype": str(getattr(self.args, "esm_dtype", "fp32")).lower(),

        # Attention implementation
        "attn_impl": str(getattr(base, "attn_impl", "gqa")),
        "num_kv_groups": int(getattr(base, "num_kv_groups", 0)),
    }
    with open(os.path.join(out_dir, "trainer_config.json"), "w") as f:
        json.dump(trainer_cfg, f, indent=2)
    with open(os.path.join(out_dir, "trainer_state.json"), "w") as f:
        json.dump({"epoch": self.state["epoch"], "global_step": self.state["global_step"]}, f, indent=2)

    if p and per_rank_positions is not None:
        payload = {
            "skip_samples": int(sum(per_rank_positions)),
            "per_rank": per_rank_positions,
            "world_size": len(per_rank_positions),
        }
        cursor_path = os.path.abspath(p)
        os.makedirs(os.path.dirname(cursor_path), exist_ok=True)
        # Write-then-rename so a crash mid-write never leaves a truncated cursor file.
        tmp_path = cursor_path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(payload, f)
        os.replace(tmp_path, cursor_path)

    try:
        if self.tokenizer is not None and hasattr(self.tokenizer, "save_vocabulary"):
            self.tokenizer.save_vocabulary(out_dir)
    except Exception as e:
        logger.warning(f"Failed to save vocabulary to {out_dir}: {e}")

    self._prune_checkpoints(self.args.output_dir, self.args.save_total_limit)
    logger.info(f"Saved checkpoint → {out_dir}")

    # Release the non-main ranks waiting above.
    self._barrier()
|
|
def _resume_from(self, ckpt_dir: str):
    """Restore model, optimizer, scheduler and trainer state from *ckpt_dir*.

    ``model.safetensors`` is required. ``optimizer.pt``, ``scheduler.pt`` and
    ``trainer_state.json`` are restored when present, unless
    ``args.override_lr_on_resume`` is set, in which case optimizer and
    scheduler states are skipped. The optimizer/scheduler are created on
    demand if the checkpoint has their files and they do not exist yet.

    Whenever the scheduler state was not restored, an existing scheduler is
    stepped forward to the restored ``global_step``. Its last LRs are then
    copied into the optimizer's param groups.

    Model weights are loaded non-strictly. Missing or unexpected keys are
    logged as a warning on the main rank.

    Raises:
        FileNotFoundError: ``model.safetensors`` is absent from *ckpt_dir*.
        RuntimeError: an FSDP optimizer state dict cannot be converted for loading.
    """
    st_path = os.path.join(ckpt_dir, "model.safetensors")
    if not os.path.exists(st_path):
        raise FileNotFoundError(f"No model.safetensors in {ckpt_dir}")
    state = load_file(st_path)

    # ---- Model weights (every rank loads the full state dict) ----
    if isinstance(self.model, FSDP):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            with FSDP.state_dict_type(
                self.model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
            ):
                load_result = self.model.load_state_dict(state, strict=False)
    else:
        load_result = self._unwrap(self.model).load_state_dict(state, strict=False)

    # strict=False tolerates key mismatches; surface them instead of hiding them.
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if (missing or unexpected) and self._is_main():
        logger.warning(
            "Non-strict load from %s: %d missing key(s) %s, %d unexpected key(s) %s",
            st_path,
            len(missing),
            missing[:10],
            len(unexpected),
            unexpected[:10],
        )

    scheduler_restored = False

    # ---- Optimizer ----
    opt_path = os.path.join(ckpt_dir, "optimizer.pt")
    if os.path.exists(opt_path):
        if self.optimizer is None:
            self._create_optimizer_and_scheduler()
        if not self.args.override_lr_on_resume:
            loaded = torch.load(opt_path, map_location="cpu")
            # optimizer.pt holds a full (unsharded) optimizer state dict; FSDP
            # must map it onto this rank's flattened/sharded parameters.
            if isinstance(self.model, FSDP):
                try:
                    loaded = FSDP.optim_state_dict_to_load(self.model, self.optimizer, loaded)
                except Exception as e:
                    msg = (
                        "Failed to convert FSDP optimizer state dict for loading. "
                        "This checkpoint likely contains an incomplete (rank0-only sharded) optimizer.pt from an older version. "
                        "Full optimizer resume is not possible from this checkpoint.\n"
                        f"Underlying error: {e}\n"
                        "Options:\n"
                        "  1) Start a fresh run (new --output_dir), or\n"
                        "  2) Re-run with --override_lr_on_resume to skip optimizer restore (not a full resume)."
                    )
                    if self._is_main():
                        logger.error(msg)
                    raise RuntimeError(msg) from e
            self.optimizer.load_state_dict(loaded)

    # ---- LR scheduler ----
    sch_path = os.path.join(ckpt_dir, "scheduler.pt")
    if os.path.exists(sch_path):
        if self.lr_scheduler is None:
            self._create_optimizer_and_scheduler()
        if self.lr_scheduler is not None and not self.args.override_lr_on_resume:
            self.lr_scheduler.load_state_dict(torch.load(sch_path, map_location="cpu"))
            scheduler_restored = True

    # ---- Trainer counters ----
    ts_path = os.path.join(ckpt_dir, "trainer_state.json")
    if os.path.exists(ts_path):
        with open(ts_path, "r") as f:
            ts = json.load(f)
        self.state["epoch"] = int(ts.get("epoch", 0))
        self.state["global_step"] = int(ts.get("global_step", 0))

    # When steps_per_epoch is known, derive the epoch from global_step,
    # clamped to the last configured epoch.
    steps_per_epoch = int(getattr(self.args, "steps_per_epoch", 0) or 0)
    if steps_per_epoch > 0:
        inferred_epoch = self.state.get("global_step", 0) // steps_per_epoch
        num_epochs = max(int(self.args.num_train_epochs), 1)
        inferred_epoch = min(inferred_epoch, num_epochs - 1)
        if inferred_epoch != self.state.get("epoch"):
            if self._is_main():
                logger.info(
                    "Adjusting epoch from %s to %s based on global_step %s and steps_per_epoch %s",
                    self.state.get("epoch"),
                    inferred_epoch,
                    self.state.get("global_step"),
                    steps_per_epoch,
                )
            self.state["epoch"] = int(inferred_epoch)

    # ---- Fast-forward a scheduler whose state was not restored ----
    if self.lr_scheduler is not None and not scheduler_restored:
        target_step = int(self.state.get("global_step", 0))
        if target_step > 0:
            try:
                # Jump directly to target_step where step() accepts a position.
                self.lr_scheduler.step(target_step)
            except TypeError:
                # Scheduler's step() takes no position: advance one step at a time.
                for _ in range(target_step):
                    self.lr_scheduler.step()
            try:
                last_lrs = self.lr_scheduler.get_last_lr()
            except Exception:
                last_lrs = [group.get("lr") for group in self.optimizer.param_groups]
            if last_lrs:
                for group, lr in zip(self.optimizer.param_groups, last_lrs):
                    group["lr"] = float(lr)

    logger.info(f"Resumed from {ckpt_dir}")
|
|
| def _checkpoint_step(self, path: str) -> Optional[int]: |
| m = re.fullmatch(r"checkpoint-(\d+)", os.path.basename(path)) |
| if not m: |
| return None |
| return int(m.group(1)) |
|
|
| def _prune_checkpoints(self, root: str, keep: int): |
| if not os.path.isdir(root): |
| return |
|
|
| try: |
| subdirs = [ |
| os.path.join(root, d) |
| for d in os.listdir(root) |
| if os.path.isdir(os.path.join(root, d)) |
| ] |
| except FileNotFoundError: |
| return |
|
|
| step_dirs: list[tuple[int, str]] = [] |
| for path in subdirs: |
| step = self._checkpoint_step(path) |
| if step is not None: |
| step_dirs.append((step, path)) |
|
|
| if not step_dirs: |
| return |
|
|
| step_dirs.sort(key=lambda item: item[0]) |
| latest_step = step_dirs[-1][0] |
|
|
| recent_window = max(0, int(getattr(self.args, "ckpt_recent_window_steps", 0) or 0)) |
| recent_interval = max(0, int(getattr(self.args, "ckpt_recent_interval", 0) or 0)) |
| archive_interval = max(0, int(getattr(self.args, "ckpt_archive_interval", 0) or 0)) |
|
|
| keep_paths: set[str] = set() |
| if recent_window > 0 and (recent_interval > 0 or archive_interval > 0): |
| if recent_interval <= 0: |
| recent_interval = max(1, int(getattr(self.args, "save_steps", 1) or 1)) |
|
|
| for step, path in step_dirs: |
| age = latest_step - step |
| if age <= recent_window: |
| interval = recent_interval |
| else: |
| interval = archive_interval |
| if interval > 0 and (step % interval == 0): |
| keep_paths.add(path) |
|
|
| if not keep_paths: |
| |
| if keep <= 0: |
| return |
| keep_paths = {path for _, path in step_dirs[-keep:]} |
| else: |
| |
| keep_paths.add(step_dirs[-1][1]) |
| if keep > 0: |
| kept = [(step, path) for step, path in step_dirs if path in keep_paths] |
| if len(kept) > keep: |
| trim = len(kept) - keep |
| for _, path in kept[:trim]: |
| keep_paths.discard(path) |
|
|
| removed = [] |
| for _, path in step_dirs: |
| if path in keep_paths: |
| continue |
| shutil.rmtree(path, ignore_errors=True) |
| removed.append(os.path.basename(path)) |
|
|
| if removed and self._is_main(): |
| logger.info( |
| "Pruned %s checkpoints (latest_step=%s, recent_window=%s, recent_interval=%s, archive_interval=%s)", |
| len(removed), |
| latest_step, |
| recent_window, |
| recent_interval, |
| archive_interval, |
| ) |
|
|