""" Fine-tune NVIDIA NV-CodonFM-Encodon-80M-v1 for mRNA Stability Prediction. Architecture: Custom BERT-style encoder with Rotary Position Embeddings (RoPE) Dataset: mogam-ai/CDS-BART-mRNA-stability (iCodon - mRNA half-life from multiple species) + GleghornLab/mrna_stability_other (additional stability data) Task: Regression (predict mRNA stability / half-life score) Recipe based on: - Helix-mRNA (arxiv:2502.13785): unfreeze last 2 layers, 5-30 epochs, AdamW - BEACON (arxiv:2406.10391): lr sweep 1e-5 to 5e-3, warmup 50 steps, MSE loss - CodonBERT: codon-level tokenization, CDS regression """ import os import math import json import re import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from safetensors.torch import load_file from huggingface_hub import hf_hub_download, HfApi from datasets import load_dataset, concatenate_datasets from scipy.stats import spearmanr, pearsonr import numpy as np import trackio # ============================================================ # 1. 
# CODON TOKENIZER
# ============================================================

# Codon vocabulary: all 64 RNA triplets (stop codons included), enumerated in
# A, U, G, C order, placed after the five special tokens.
RNA_BASES = ['A', 'U', 'G', 'C']
ALL_CODONS = [
    b1 + b2 + b3
    for b1 in RNA_BASES
    for b2 in RNA_BASES
    for b3 in RNA_BASES
]  # 64 codons

# Special tokens occupy ids 0..4. The model config declares pad_token_id=3,
# so [PAD] must map to 3.
# NOTE(review): the ids of the other special tokens and the codon enumeration
# order are inferred from vocab_size=69 and pad_token_id=3 only -- confirm
# against the upstream tokenizer before trusting pretrained embeddings.
SPECIAL_TOKENS = {
    '[CLS]': 0,
    '[SEP]': 1,
    '[MASK]': 2,
    '[PAD]': 3,
    '[UNK]': 4,
}

# Codon ids start right after the special tokens (5..68).
CODON_TO_ID = {
    codon: index + len(SPECIAL_TOKENS)
    for index, codon in enumerate(ALL_CODONS)
}

ID_TO_CODON = {v: k for k, v in CODON_TO_ID.items()}
ID_TO_CODON.update({v: k for k, v in SPECIAL_TOKENS.items()})

PAD_TOKEN_ID = SPECIAL_TOKENS['[PAD]']
CLS_TOKEN_ID = SPECIAL_TOKENS['[CLS]']
SEP_TOKEN_ID = SPECIAL_TOKENS['[SEP]']
MASK_TOKEN_ID = SPECIAL_TOKENS['[MASK]']
UNK_TOKEN_ID = SPECIAL_TOKENS['[UNK]']

# Verify vocab size
assert len(CODON_TO_ID) + len(SPECIAL_TOKENS) == 69, f"Expected 69, got {len(CODON_TO_ID) + len(SPECIAL_TOKENS)}"


def tokenize_mRNA(seq: str, max_length: int = 2046) -> dict:
    """Tokenize an mRNA (or cDNA) sequence into codon ids.

    The sequence is upper-cased, ``T`` is converted to ``U`` and *all*
    whitespace (including internal newlines, e.g. from FASTA-wrapped input)
    is removed so it cannot shift the reading frame. The sequence is then read
    in-frame from position 0; a trailing partial codon is dropped and any
    triplet that is not in the vocabulary becomes ``[UNK]``.

    Args:
        seq: Nucleotide sequence.
        max_length: Maximum number of tokens returned, including the leading
            ``[CLS]`` and trailing ``[SEP]``. Must be at least 2.

    Returns:
        A dict with ``input_ids`` (list of ints, ``[CLS] codons... [SEP]``) and
        ``attention_mask`` (list of 1s of the same length).

    Raises:
        ValueError: If ``max_length`` is smaller than 2, which would leave no
            room for the two special tokens.
    """
    if max_length < 2:
        raise ValueError(
            f"max_length must be >= 2 to hold [CLS] and [SEP], got {max_length}"
        )

    # str.split() with no argument splits on any whitespace run, so joining
    # the pieces removes leading, trailing and internal whitespace alike.
    seq = "".join(seq.split()).upper().replace('T', 'U')

    # Reserve two positions for [CLS] and [SEP]; ignore a trailing partial codon.
    n_codons = min(len(seq) // 3, max_length - 2)

    token_ids = [CLS_TOKEN_ID]
    token_ids.extend(
        CODON_TO_ID.get(seq[start:start + 3], UNK_TOKEN_ID)
        for start in range(0, 3 * n_codons, 3)
    )
    token_ids.append(SEP_TOKEN_ID)

    return {
        'input_ids': token_ids,
        'attention_mask': [1] * len(token_ids),
    }


def pad_batch(batch, max_len, pad_id=PAD_TOKEN_ID):
    """Right-pad a batch of tokenized sequences to ``max_len``.

    Args:
        batch: Iterable of dicts with ``input_ids`` and ``attention_mask``
            lists, as produced by ``tokenize_mRNA``.
        max_len: Target length of every row.
        pad_id: Token id used to fill ``input_ids``; the attention mask is
            filled with 0 at the same positions.

    Returns:
        A dict with ``input_ids`` and ``attention_mask`` as ``torch.long``
        tensors, one row per batch item, each row of length ``max_len``.

    Raises:
        ValueError: If any sequence is longer than ``max_len`` (padding cannot
            shorten it, and silently keeping it would yield ragged rows).
    """
    padded_ids = []
    padded_masks = []
    for item in batch:
        ids = item['input_ids']
        mask = item['attention_mask']
        pad_len = max_len - len(ids)
        if pad_len < 0:
            raise ValueError(
                f"Sequence of length {len(ids)} exceeds max_len={max_len}"
            )
        padded_ids.append(ids + [pad_id] * pad_len)
        padded_masks.append(mask + [0] * pad_len)

    return {
        'input_ids': torch.tensor(padded_ids, dtype=torch.long),
        'attention_mask': torch.tensor(padded_masks, dtype=torch.long),
    }


# ============================================================
# 2.
# MODEL ARCHITECTURE (matched to safetensors weight keys)
# ============================================================

class RotaryEmbedding(nn.Module):
    """Precompute rotary position embedding (RoPE) cos/sin tables."""

    def __init__(self, dim, theta=10000.0):
        """
        Args:
            dim: Per-head feature dimension the rotation is applied to.
            theta: Base of the geometric frequency progression.
        """
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Registered as a (persistent) buffer so it follows .to(device) and
        # shows up in the state dict under "inv_freq".
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_len):
        """Return ``(cos, sin)`` tables of shape ``[seq_len, dim]``.

        ``x`` is only used to pick the device the tables are created on.
        """
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


def rotate_half(x):
    """Split the last dimension in two halves ``(x1, x2)`` and return ``(-x2, x1)``."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q, k: Tensors whose last two dimensions are ``[seq_len, dim]``.
        cos, sin: Tables of shape ``[seq_len, dim]`` from ``RotaryEmbedding``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class CodonFMAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings."""

    def __init__(self, hidden_size, num_heads, rotary_theta=10000.0):
        """
        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_heads``
                (the per-head reshape in ``forward`` would otherwise fail
                with an opaque shape error).
        """
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.rotary_emb = RotaryEmbedding(self.head_dim, theta=rotary_theta)

    def forward(self, hidden_states, attention_mask=None):
        """Compute self-attention.

        Args:
            hidden_states: Tensor of shape ``[B, L, H]``.
            attention_mask: Optional ``[B, L]`` tensor; positions equal to 0
                are excluded as attention keys.

        Returns:
            Context tensor of shape ``[B, L, H]`` (no output projection; the
            caller applies ``post_attn_dense``).
        """
        B, L, H = hidden_states.shape

        def split_heads(projected):
            # [B, L, H] -> [B, num_heads, L, head_dim]
            return projected.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(hidden_states))
        k = split_heads(self.key(hidden_states))
        v = split_heads(self.value(hidden_states))

        cos, sin = self.rotary_emb(q, L)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        scale = math.sqrt(self.head_dim)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / scale

        if attention_mask is not None:
            # attention_mask: [B, L] -> [B, 1, 1, L] so it broadcasts over
            # heads and query positions and masks out padded *keys*.
            attn_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(attn_mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_weights, dim=-1)
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(B, L, H)
        return context


class CodonFMTransformerLayer(nn.Module):
    """
    One encoder layer. Sub-module names match the checkpoint weight keys:
    - pre_attn_layer_norm, attention (Q/K/V/rotary), post_attn_dense, post_attn_layer_norm
    - pre_ffn_layer_norm, intermediate_dense, post_ffn_layer_norm, output_dense

    NOTE(review): the order in which these sub-modules are applied in
    ``forward`` is inferred from the key names -- confirm against the upstream
    implementation.
    """

    def __init__(self, hidden_size, num_heads, intermediate_size,
                 hidden_act='gelu', layer_norm_eps=1e-12, dropout=0.1,
                 rotary_theta=10000.0):
        super().__init__()
        # Pre-attention layer norm
        self.pre_attn_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # Attention
        self.attention = CodonFMAttention(hidden_size, num_heads, rotary_theta)
        # Post-attention projection + layer norm
        self.post_attn_dense = nn.Linear(hidden_size, hidden_size)
        self.post_attn_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # FFN
        self.pre_ffn_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.intermediate_dense = nn.Linear(hidden_size, intermediate_size)
        self.post_ffn_layer_norm = nn.LayerNorm(intermediate_size, eps=layer_norm_eps)
        self.output_dense = nn.Linear(intermediate_size, hidden_size)

        self.dropout = nn.Dropout(dropout)
        if hidden_act == 'gelu':
            self.act = nn.GELU()
        else:
            if hidden_act != 'relu':
                # Previously any non-'gelu' value silently became ReLU. The
                # fallback is kept for compatibility but made visible, since a
                # mismatched activation invalidates the pretrained weights.
                import warnings
                warnings.warn(
                    f"Unsupported hidden_act={hidden_act!r}; falling back to ReLU",
                    stacklevel=2,
                )
            self.act = nn.ReLU()

    def forward(self, hidden_states, attention_mask=None):
        """Apply the attention sub-block then the feed-forward sub-block.

        Args:
            hidden_states: Tensor of shape ``[B, L, H]``.
            attention_mask: Optional ``[B, L]`` mask forwarded to attention.

        Returns:
            Tensor of shape ``[B, L, H]``.
        """
        # Attention sub-block: norm -> attention -> dense -> dropout -> residual -> norm
        residual = hidden_states
        hidden_states = self.pre_attn_layer_norm(hidden_states)
        attn_output = self.attention(hidden_states, attention_mask)
        attn_output = self.post_attn_dense(attn_output)
        attn_output = self.dropout(attn_output)
        hidden_states = residual + attn_output
        hidden_states = self.post_attn_layer_norm(hidden_states)

        # FFN sub-block: norm -> dense -> activation -> norm -> dense -> dropout -> residual
        residual = hidden_states
        hidden_states = self.pre_ffn_layer_norm(hidden_states)
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.post_ffn_layer_norm(hidden_states)
        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class CodonFMEncoder(nn.Module):
    """CodonFM Encoder that matches the safetensors checkpoint structure."""

    def __init__(self, config):
        """
        Args:
            config: Dict with ``vocab_size``, ``hidden_size``, ``pad_token_id``,
                ``layer_norm_eps``, ``num_attention_heads``,
                ``intermediate_size``, ``hidden_act``, ``hidden_dropout_prob``,
                ``rotary_theta`` and ``num_hidden_layers``.
        """
        super().__init__()
        # Embeddings
        self.word_embeddings = nn.Embedding(
            config['vocab_size'], config['hidden_size'],
            padding_idx=config['pad_token_id']
        )
        self.post_ln = nn.LayerNorm(config['hidden_size'], eps=config['layer_norm_eps'])

        # Transformer layers
        self.layers = nn.ModuleList([
            CodonFMTransformerLayer(
                hidden_size=config['hidden_size'],
                num_heads=config['num_attention_heads'],
                intermediate_size=config['intermediate_size'],
                hidden_act=config['hidden_act'],
                layer_norm_eps=config['layer_norm_eps'],
                dropout=config['hidden_dropout_prob'],
                rotary_theta=config['rotary_theta'],
            )
            for _ in range(config['num_hidden_layers'])
        ])

        # MLM head (cls). It is not used in forward(); it exists so the
        # checkpoint's cls.* weights have somewhere to load.
        self.cls = nn.Sequential(
            nn.Linear(config['hidden_size'], config['hidden_size']),  # cls.0
            nn.GELU(),  # cls.1 (activation, no weights)
            nn.LayerNorm(config['hidden_size'], eps=config['layer_norm_eps']),  # cls.2
            nn.Linear(config['hidden_size'], config['vocab_size']),  # cls.3
        )

    def forward(self, input_ids, attention_mask=None):
        """Embed ``input_ids`` and run every transformer layer.

        Returns:
            Hidden states of shape ``[B, L, hidden_size]``.
        """
        # Embeddings
        x = self.word_embeddings(input_ids)
        x = self.post_ln(x)

        # Transformer layers
        for layer in self.layers:
            x = layer(x, attention_mask)

        return x  # [B, L, hidden_size]


class CodonFMForStabilityPrediction(nn.Module):
    """CodonFM encoder + regression head for mRNA stability prediction."""

    def __init__(self, config):
        """
        Args:
            config: Encoder config dict (see ``CodonFMEncoder``). An optional
                ``classifier_dropout`` sets the regression-head dropout.
        """
        super().__init__()
        self.config = config
        self.encoder = CodonFMEncoder(config)

        # Regression head
        hidden_size = config['hidden_size']
        # A config may carry "classifier_dropout": null; dict.get() would then
        # return None (the default only applies to a *missing* key) and
        # nn.Dropout(None) raises. Treat None the same as missing.
        dropout = config.get('classifier_dropout')
        if dropout is None:
            dropout = 0.1
        self.regression_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Predict one scalar per sequence.

        Args:
            input_ids: ``[B, L]`` token ids.
            attention_mask: Optional ``[B, L]`` mask (1 = real token, 0 = pad).
            labels: Optional ``[B]`` regression targets.

        Returns:
            Dict with ``logits`` (``[B]``) and ``loss`` (MSE against
            ``labels``, or ``None`` when no labels are given).
        """
        hidden_states = self.encoder(input_ids, attention_mask)  # [B, L, H]

        # Mean pooling over non-pad tokens
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()  # [B, L, 1]
            # clamp avoids division by zero for an all-padding row
            pooled = (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        else:
            pooled = hidden_states.mean(1)

        logits = self.regression_head(pooled).squeeze(-1)  # [B]

        loss = None
        if labels is not None:
            loss = F.mse_loss(logits, labels.float())

        return {'loss': loss, 'logits': logits}

    def load_pretrained_encoder(self, checkpoint_path):
        """Load pretrained CodonFM weights into the encoder.

        Checkpoint keys are remapped as follows:
        - a leading ``model.`` prefix is stripped,
        - a leading ``embeddings.`` prefix is then stripped,
        - ``encoder.`` is prepended.

        Args:
            checkpoint_path: Path to a safetensors checkpoint.

        Returns:
            ``(missing, unexpected)`` key lists as reported by
            ``load_state_dict(strict=False)``.

        Raises:
            RuntimeError: If not a single checkpoint tensor matched a model
                parameter. With ``strict=False`` that situation would
                otherwise pass silently and fine-tuning would start from
                random weights.
        """
        state_dict = load_file(checkpoint_path, device='cpu')

        new_state_dict = {}
        for key, value in state_dict.items():
            new_key = key
            # Strip 'model.' prefix
            if new_key.startswith('model.'):
                new_key = new_key[len('model.'):]
            # Map embeddings: strip only the leading 'embeddings.' prefix
            if new_key.startswith('embeddings.'):
                new_key = new_key[len('embeddings.'):]
            new_state_dict['encoder.' + new_key] = value

        # strict=False because the regression head is new and has no
        # counterpart in the checkpoint.
        missing, unexpected = self.load_state_dict(new_state_dict, strict=False)

        unexpected_keys = set(unexpected)
        n_loaded = sum(1 for key in new_state_dict if key not in unexpected_keys)
        if n_loaded == 0:
            raise RuntimeError(
                f"No tensors from {checkpoint_path!r} matched the model; "
                "the checkpoint key mapping is wrong and no pretrained weights "
                "were loaded."
            )

        print(f"Loaded pretrained encoder weights ({n_loaded}/{len(new_state_dict)} tensors).")
        print(f"  Missing (new regression head params): {[k for k in missing if 'regression' in k]}")
        print(f"  Missing (other): {[k for k in missing if 'regression' not in k]}")
        print(f"  Unexpected (ignored checkpoint keys): {unexpected[:5]}...")

        return missing, unexpected


# ============================================================
# 3.
# DATASET
# ============================================================

class mRNAStabilityDataset(Dataset):
    """Dataset for mRNA stability regression.

    Sequences are tokenized lazily in ``__getitem__``.
    """

    def __init__(self, sequences, labels, max_length=2046):
        """
        Args:
            sequences: Indexable collection of nucleotide strings.
            labels: Indexable collection of numeric targets, aligned with
                ``sequences``.
            max_length: Maximum token count passed to ``tokenize_mRNA``.

        Raises:
            ValueError: If ``sequences`` and ``labels`` differ in length
                (otherwise the mismatch only surfaces as an IndexError in the
                middle of an epoch).
        """
        if len(sequences) != len(labels):
            raise ValueError(
                f"sequences and labels must have the same length, got "
                f"{len(sequences)} and {len(labels)}"
            )
        self.sequences = sequences
        self.labels = labels
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the tokenized example at ``idx``.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (as returned by
            ``tokenize_mRNA``) and ``label`` (float).
        """
        tokens = tokenize_mRNA(self.sequences[idx], max_length=self.max_length)
        return {
            'input_ids': tokens['input_ids'],
            'attention_mask': tokens['attention_mask'],
            'label': float(self.labels[idx]),
        }


def collate_fn(batch):
    """Collate dataset items into a padded batch.

    Every item is right-padded to the longest sequence in the batch by the
    shared ``pad_batch`` helper, so padding rules live in one place.

    Args:
        batch: Non-empty list of dicts with ``input_ids``, ``attention_mask``
            and ``label``.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` (``torch.long``) and
        ``labels`` (``torch.float32``).

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    max_len = max(len(item['input_ids']) for item in batch)
    collated = pad_batch(batch, max_len, pad_id=PAD_TOKEN_ID)
    collated['labels'] = torch.tensor(
        [item['label'] for item in batch], dtype=torch.float32
    )
    return collated


# ============================================================
# 4.
# TRAINING LOOP
# ============================================================

def evaluate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return regression metrics.

    Args:
        model: Callable ``model(input_ids, attention_mask, labels)`` returning
            a dict with ``loss`` and ``logits``.
        dataloader: Iterable of batches with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        device: Device the batches are moved to.

    Returns:
        Dict of Python floats: ``loss`` (sample-weighted mean of the batch
        losses), ``spearman``, ``pearson`` and ``mse``. Correlations are NaN
        when they are undefined (fewer than two samples, or constant
        predictions/labels); every metric is NaN for an empty dataloader.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    total_loss = 0.0
    n_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask, labels)

            # Weight by batch size so a smaller final batch does not skew the
            # average (keeps 'loss' consistent with 'mse' below).
            batch_size = labels.shape[0]
            total_loss += outputs['loss'].item() * batch_size
            n_samples += batch_size

            pred_chunks.append(outputs['logits'].detach().float().cpu().numpy().reshape(-1))
            label_chunks.append(labels.detach().float().cpu().numpy().reshape(-1))

    nan = float('nan')
    if n_samples == 0:
        return {'loss': nan, 'spearman': nan, 'pearson': nan, 'mse': nan}

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    # Correlation is undefined for < 2 points or when either side is constant.
    if n_samples < 2 or np.ptp(all_preds) == 0 or np.ptp(all_labels) == 0:
        spearman_rho, pearson_r = nan, nan
    else:
        spearman_rho, _ = spearmanr(all_preds, all_labels)
        pearson_r, _ = pearsonr(all_preds, all_labels)

    return {
        'loss': total_loss / n_samples,
        'spearman': float(spearman_rho),
        'pearson': float(pearson_r),
        'mse': float(np.mean((all_preds - all_labels) ** 2)),
    }


def _filter_valid(seqs, labels):
    """Drop examples with a missing/too-short sequence or a missing/NaN label."""
    kept_seqs, kept_labels = [], []
    for seq, label in zip(seqs, labels):
        if seq is None or label is None:
            continue
        if len(seq) < 9 or np.isnan(label):  # min 3 codons
            continue
        kept_seqs.append(seq)
        kept_labels.append(label)
    return kept_seqs, kept_labels


def _cosine_with_warmup(warmup_steps, total_steps):
    """Return an LR multiplier: linear warmup, then cosine decay to zero."""
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        # Clamp so the cosine cannot climb back up if stepped past total_steps.
        progress = min(1.0, progress)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return lr_lambda


def _save_checkpoint(model, output_dir, metadata):
    """Write model weights, run metadata and the tokenizer vocab to ``output_dir``."""
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    with open(os.path.join(output_dir, "config.json"), 'w') as f:
        json.dump(metadata, f, indent=2)
    # Save tokenizer vocab
    with open(os.path.join(output_dir, "codon_vocab.json"), 'w') as f:
        json.dump({
            "special_tokens": SPECIAL_TOKENS,
            "codon_to_id": CODON_TO_ID,
        }, f, indent=2)


def _build_model_card(*, datasets_used, max_length, n_frozen, num_layers,
                      learning_rate, num_epochs, batch_size, grad_accum,
                      warmup_steps, use_amp, test_metrics):
    """Render the Hub README so that it reflects the settings of this run."""
    dataset_notes = {
        "mogam-ai/CDS-BART-mRNA-stability": (
            "- **[mogam-ai/CDS-BART-mRNA-stability](https://hf.co/datasets/mogam-ai/CDS-BART-mRNA-stability)**: "
            "iCodon-based mRNA stability profiles from humans, mice, frogs, and fish "
            "(28,770 train / 6,207 val / 6,086 test)"
        ),
        "GleghornLab/mrna_stability_other": (
            "- **[GleghornLab/mrna_stability_other](https://hf.co/datasets/GleghornLab/mrna_stability_other)**: "
            "Additional mRNA stability data (45,749 train / 9,803 valid / 9,804 test)"
        ),
    }
    # Built outside the f-string: backslashes are not allowed inside f-string
    # expressions before Python 3.12.
    datasets_yaml = "\n".join(f"- {name}" for name in datasets_used)
    datasets_md = "\n".join(dataset_notes[name] for name in datasets_used)
    precision = "FP16" if use_amp else "disabled (FP32)"

    return f"""---
license: other
license_name: nvidia-open-model-license
tags:
- biology
- genomics
- mRNA
- stability-prediction
- codon
- fine-tuned
base_model: nvidia/NV-CodonFM-Encodon-80M-v1
datasets:
{datasets_yaml}
metrics:
- spearman_correlation
- pearson_correlation
- mse
---

# CodonFM-80M Fine-tuned for mRNA Stability Prediction

## Model Description

This model is a fine-tuned version of [NVIDIA NV-CodonFM-Encodon-80M-v1](https://hf.co/nvidia/NV-CodonFM-Encodon-80M-v1) for predicting mRNA stability (half-life) from coding sequences (CDS).

**Base model:** NV-CodonFM-Encodon-80M-v1 (80M parameter BERT-style Transformer with Rotary Position Embeddings)
**Task:** Regression — predict mRNA stability score from codon sequence
**Input:** mRNA coding sequence (codon-level tokenization, max {max_length} tokens including [CLS] and [SEP])
**Output:** Stability score (continuous float — higher = more stable)

## Training

### Datasets

{datasets_md}

### Recipe

Based on [Helix-mRNA](https://arxiv.org/abs/2502.13785) and [BEACON](https://arxiv.org/abs/2406.10391):

- **Strategy:** Freeze first {n_frozen} of {num_layers} transformer layers, unfreeze last {num_layers - n_frozen} + regression head
- **Optimizer:** AdamW (backbone lr={learning_rate}, head lr={learning_rate * 10})
- **Epochs:** {num_epochs}
- **Batch size:** {batch_size} × {grad_accum} gradient accumulation = {batch_size * grad_accum} effective
- **Scheduler:** Cosine with {warmup_steps}-step warmup
- **Mixed precision:** {precision}

### Results

| Metric | Test Set |
|--------|----------|
| Spearman ρ | {test_metrics['spearman']:.4f} |
| Pearson r | {test_metrics['pearson']:.4f} |
| MSE | {test_metrics['mse']:.4f} |

### Literature Comparison

| Model | Spearman ρ (mRNA Stability) |
|-------|----------------------------|
| CodonBERT | 0.35 |
| XE | 0.50 |
| Helix-mRNA | 0.52 |
| HELM | 0.53 |
| **This model** | **{test_metrics['spearman']:.4f}** |

## Usage

```python
import torch
import json
from huggingface_hub import hf_hub_download

# Load model (see train script for full model class definition)
# ...
```

## Citation

If using this model, please cite the base CodonFM model and the iCodon dataset:

```bibtex
@article{{diez2022icodon,
  title={{iCodon customizes gene expression based on the codon composition}},
  author={{Diez, Michay and others}},
  journal={{Scientific Reports}},
  year={{2022}}
}}
```
"""


def train():
    """Fine-tune CodonFM for mRNA stability regression and push it to the Hub.

    All settings are read from environment variables (defaults in
    parentheses): ``HUB_MODEL_ID``, ``LEARNING_RATE`` (5e-5), ``NUM_EPOCHS``
    (20), ``BATCH_SIZE`` (16), ``GRAD_ACCUM`` (2), ``MAX_LENGTH`` (1024),
    ``WARMUP_STEPS`` (100), ``WEIGHT_DECAY`` (0.01), ``FREEZE_LAYERS`` (4),
    ``USE_BOTH_DATASETS`` (true) and ``OUTPUT_DIR`` (/app/best_model).

    The checkpoint with the best validation Spearman is kept, evaluated on the
    test split and uploaded together with a generated model card.

    Raises:
        RuntimeError: If no training example survives filtering.
    """
    # ---- Config ----
    HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "Imranyai/CodonFM-80M-mRNA-stability")
    LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "5e-5"))
    NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "20"))
    BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "16"))
    # Guard against 0/negative values, which would divide by zero below.
    GRAD_ACCUM = max(1, int(os.environ.get("GRAD_ACCUM", "2")))
    MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "1024"))  # tokens, incl. [CLS]/[SEP]
    WARMUP_STEPS = int(os.environ.get("WARMUP_STEPS", "100"))
    WEIGHT_DECAY = float(os.environ.get("WEIGHT_DECAY", "0.01"))
    FREEZE_LAYERS = int(os.environ.get("FREEZE_LAYERS", "4"))  # Freeze first 4 layers, unfreeze last 2
    USE_BOTH_DATASETS = os.environ.get("USE_BOTH_DATASETS", "true").lower() == "true"
    OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/app/best_model")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    print(f"Config: lr={LEARNING_RATE}, epochs={NUM_EPOCHS}, batch={BATCH_SIZE}, "
          f"grad_accum={GRAD_ACCUM}, max_len={MAX_LENGTH}, freeze_layers={FREEZE_LAYERS}")

    # ---- Init tracking ----
    trackio.init(
        project="codonfm-mrna-stability",
        name=f"lr{LEARNING_RATE}_ep{NUM_EPOCHS}_freeze{FREEZE_LAYERS}",
    )

    # ---- Load model config ----
    config_path = hf_hub_download(
        repo_id="nvidia/NV-CodonFM-Encodon-80M-v1",
        filename="config.json"
    )
    with open(config_path) as f:
        config = json.load(f)
    print(f"Model config: {config}")

    # ---- Build model ----
    model = CodonFMForStabilityPrediction(config)

    # Load pretrained weights
    ckpt_path = hf_hub_download(
        repo_id="nvidia/NV-CodonFM-Encodon-80M-v1",
        filename="NV-CodonFM-Encodon-80M-v1.safetensors"
    )
    model.load_pretrained_encoder(ckpt_path)

    # The MLM head never takes part in the regression forward pass, so it is
    # frozen regardless of FREEZE_LAYERS.
    for param in model.encoder.cls.parameters():
        param.requires_grad = False

    # Freeze early layers (keep last N layers trainable). Clamp so an
    # oversized FREEZE_LAYERS cannot index past the last layer.
    num_layers = len(model.encoder.layers)
    n_frozen = min(max(FREEZE_LAYERS, 0), num_layers)
    if n_frozen > 0:
        # Freeze embeddings
        for param in model.encoder.word_embeddings.parameters():
            param.requires_grad = False
        for param in model.encoder.post_ln.parameters():
            param.requires_grad = False
        # Freeze first n_frozen transformer layers
        for layer in model.encoder.layers[:n_frozen]:
            for param in layer.parameters():
                param.requires_grad = False

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable_params:,} / {total_params:,} "
          f"({100*trainable_params/total_params:.1f}%)")

    model = model.to(device)

    # ---- Load datasets ----
    print("\nLoading datasets...")

    # Primary dataset: mogam-ai/CDS-BART-mRNA-stability (iCodon-based)
    # ds1 has 'seq' (RNA) and 'y' (label)
    datasets_used = ["mogam-ai/CDS-BART-mRNA-stability"]
    ds1 = load_dataset("mogam-ai/CDS-BART-mRNA-stability")
    print(f"mogam-ai/CDS-BART-mRNA-stability: train={len(ds1['train'])}, val={len(ds1['val'])}, test={len(ds1['test'])}")

    train_seqs = list(ds1['train']['seq'])
    train_labels = list(ds1['train']['y'])
    val_seqs = list(ds1['val']['seq'])
    val_labels = list(ds1['val']['y'])
    test_seqs = list(ds1['test']['seq'])
    test_labels = list(ds1['test']['y'])

    if USE_BOTH_DATASETS:
        # Secondary dataset: GleghornLab/mrna_stability_other
        # ds2 has 'seqs' (protein-encoded?), 'rna' (actual RNA), 'labels';
        # the 'rna' column is the one used here.
        datasets_used.append("GleghornLab/mrna_stability_other")
        ds2 = load_dataset("GleghornLab/mrna_stability_other")
        print(f"GleghornLab/mrna_stability_other: train={len(ds2['train'])}, valid={len(ds2['valid'])}, test={len(ds2['test'])}")

        train_seqs += list(ds2['train']['rna'])
        train_labels += list(ds2['train']['labels'])
        val_seqs += list(ds2['valid']['rna'])
        val_labels += list(ds2['valid']['labels'])
        test_seqs += list(ds2['test']['rna'])
        test_labels += list(ds2['test']['labels'])

    print(f"\nCombined dataset sizes: train={len(train_seqs)}, val={len(val_seqs)}, test={len(test_seqs)}")

    # Filter out sequences that are too short or have issues
    train_seqs, train_labels = _filter_valid(train_seqs, train_labels)
    val_seqs, val_labels = _filter_valid(val_seqs, val_labels)
    test_seqs, test_labels = _filter_valid(test_seqs, test_labels)

    print(f"After filtering: train={len(train_seqs)}, val={len(val_seqs)}, test={len(test_seqs)}")
    if not train_seqs:
        raise RuntimeError("No training examples left after filtering")

    # Dataset stats
    train_labels_arr = np.array(train_labels)
    print(f"Label stats (train): mean={train_labels_arr.mean():.3f}, std={train_labels_arr.std():.3f}, "
          f"min={train_labels_arr.min():.3f}, max={train_labels_arr.max():.3f}")

    # Create datasets
    train_dataset = mRNAStabilityDataset(train_seqs, train_labels, max_length=MAX_LENGTH)
    val_dataset = mRNAStabilityDataset(val_seqs, val_labels, max_length=MAX_LENGTH)
    test_dataset = mRNAStabilityDataset(test_seqs, test_labels, max_length=MAX_LENGTH)

    # Pinned memory only helps host->GPU copies; on CPU it just warns.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=collate_fn, num_workers=2, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False,
                            collate_fn=collate_fn, num_workers=2, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2, shuffle=False,
                             collate_fn=collate_fn, num_workers=2, pin_memory=pin_memory)

    # ---- Optimizer & Scheduler ----
    # Differential learning rates: backbone slower, head faster
    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if 'regression_head' in name:
                head_params.append(param)
            else:
                backbone_params.append(param)

    param_groups = [
        {'params': backbone_params, 'lr': LEARNING_RATE},
        {'params': head_params, 'lr': LEARNING_RATE * 10},  # 10x for new head
    ]
    # Drop an empty group (e.g. every encoder layer frozen).
    param_groups = [group for group in param_groups if group['params']]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=WEIGHT_DECAY)

    # One optimizer step per accumulation group, including a trailing partial
    # group, so the schedule length matches the number of steps actually taken.
    n_train_batches = len(train_loader)
    steps_per_epoch = math.ceil(n_train_batches / GRAD_ACCUM)
    total_steps = steps_per_epoch * NUM_EPOCHS

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=_cosine_with_warmup(WARMUP_STEPS, total_steps)
    )

    # Enable mixed precision
    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None
    use_amp = device.type == 'cuda'

    # ---- Training ----
    print(f"\n{'='*60}")
    print(f"Starting training for {NUM_EPOCHS} epochs")
    print(f"Total steps: {total_steps}, Warmup: {WARMUP_STEPS}")
    print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
    print(f"{'='*60}\n")

    best_val_spearman = float('-inf')
    best_epoch = -1
    global_step = 0
    weights_path = os.path.join(OUTPUT_DIR, "pytorch_model.bin")

    def checkpoint_metadata():
        # NaN/inf are not valid JSON, so an undefined score is stored as null.
        score = best_val_spearman if math.isfinite(best_val_spearman) else None
        return {
            **config,
            "task": "mRNA_stability_regression",
            "freeze_layers": FREEZE_LAYERS,
            "max_length": MAX_LENGTH,
            "best_val_spearman": score,
            "best_epoch": best_epoch,
            "datasets": datasets_used,
        }

    for epoch in range(NUM_EPOCHS):
        model.train()
        epoch_loss = 0
        n_batches = 0
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Number of micro-batches in the current accumulation group; the
            # last group of an epoch may be shorter than GRAD_ACCUM.
            group_start = (batch_idx // GRAD_ACCUM) * GRAD_ACCUM
            group_size = min(GRAD_ACCUM, n_train_batches - group_start)

            if use_amp:
                with torch.amp.autocast('cuda'):
                    outputs = model(input_ids, attention_mask, labels)
                    loss = outputs['loss'] / group_size
                scaler.scale(loss).backward()
            else:
                outputs = model(input_ids, attention_mask, labels)
                loss = outputs['loss'] / group_size
                loss.backward()

            epoch_loss += outputs['loss'].item()
            n_batches += 1

            is_last_batch = (batch_idx + 1) == n_train_batches
            if (batch_idx + 1) % GRAD_ACCUM == 0 or is_last_batch:
                if use_amp:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scale_before = scaler.get_scale()
                    scaler.step(optimizer)
                    scaler.update()
                    # The scale only shrinks when inf/NaN gradients made the
                    # scaler skip optimizer.step(); do not advance the LR
                    # schedule for a skipped step.
                    optimizer_stepped = scaler.get_scale() >= scale_before
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    optimizer_stepped = True
                if optimizer_stepped:
                    scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # Log every 50 steps
                if global_step % 50 == 0:
                    avg_loss = epoch_loss / n_batches
                    current_lr = optimizer.param_groups[0]['lr']
                    print(f"  Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
                    trackio.log({
                        "train/loss": avg_loss,
                        "train/lr": current_lr,
                        "train/step": global_step,
                    })

        avg_train_loss = epoch_loss / max(n_batches, 1)

        # Evaluate
        val_metrics = evaluate(model, val_loader, device)

        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {val_metrics['loss']:.4f} | Spearman: {val_metrics['spearman']:.4f} | "
              f"Pearson: {val_metrics['pearson']:.4f} | MSE: {val_metrics['mse']:.4f}")

        trackio.log({
            "train/epoch_loss": avg_train_loss,
            "val/loss": val_metrics['loss'],
            "val/spearman": val_metrics['spearman'],
            "val/pearson": val_metrics['pearson'],
            "val/mse": val_metrics['mse'],
            "epoch": epoch + 1,
        })

        # Save best model. A NaN Spearman (e.g. constant predictions) ranks
        # below every real score, and the first epoch is always saved so a
        # checkpoint exists for the final test evaluation.
        val_score = val_metrics['spearman']
        if math.isnan(val_score):
            val_score = float('-inf')
        if best_epoch < 0 or val_score > best_val_spearman:
            best_val_spearman = val_score
            best_epoch = epoch + 1
            _save_checkpoint(model, OUTPUT_DIR, checkpoint_metadata())
            print(f"  ★ New best model! Spearman: {best_val_spearman:.4f} (epoch {best_epoch})")

    if best_epoch < 0:
        # NUM_EPOCHS == 0: nothing was trained; still emit a checkpoint so the
        # evaluation and upload below have files to work with.
        _save_checkpoint(model, OUTPUT_DIR, checkpoint_metadata())

    # ---- Final Test Evaluation ----
    print(f"\n{'='*60}")
    print(f"Loading best model from epoch {best_epoch}")
    # weights_only=True: the file holds only tensors, so refuse arbitrary pickles.
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))

    test_metrics = evaluate(model, test_loader, device)
    print(f"\nFinal Test Results:")
    print(f"  Loss: {test_metrics['loss']:.4f}")
    print(f"  Spearman ρ: {test_metrics['spearman']:.4f}")
    print(f"  Pearson r: {test_metrics['pearson']:.4f}")
    print(f"  MSE: {test_metrics['mse']:.4f}")

    trackio.log({
        "test/loss": test_metrics['loss'],
        "test/spearman": test_metrics['spearman'],
        "test/pearson": test_metrics['pearson'],
        "test/mse": test_metrics['mse'],
    })

    # ---- Push to Hub ----
    print(f"\nPushing model to Hub: {HUB_MODEL_ID}")

    api = HfApi()

    # Create repo if needed (best effort: the upload below fails loudly if the
    # repository really is unusable).
    try:
        api.create_repo(repo_id=HUB_MODEL_ID, exist_ok=True)
    except Exception as e:
        print(f"Repo creation note: {e}")

    # Write model card
    model_card = _build_model_card(
        datasets_used=datasets_used,
        max_length=MAX_LENGTH,
        n_frozen=n_frozen,
        num_layers=num_layers,
        learning_rate=LEARNING_RATE,
        num_epochs=NUM_EPOCHS,
        batch_size=BATCH_SIZE,
        grad_accum=GRAD_ACCUM,
        warmup_steps=WARMUP_STEPS,
        use_amp=use_amp,
        test_metrics=test_metrics,
    )

    with open(os.path.join(OUTPUT_DIR, "README.md"), 'w') as f:
        f.write(model_card)

    # Upload all files
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=HUB_MODEL_ID,
        commit_message=f"Upload fine-tuned CodonFM-80M for mRNA stability (Spearman={test_metrics['spearman']:.4f})"
    )

    # Also upload training script for reproducibility. Use this file's real
    # location instead of a hard-coded path so a relocated script does not
    # crash the run after training has finished.
    script_path = globals().get('__file__')
    if script_path and os.path.isfile(script_path):
        api.upload_file(
            path_or_fileobj=os.path.abspath(script_path),
            path_in_repo="train_codonfm_stability.py",
            repo_id=HUB_MODEL_ID,
            commit_message="Upload training script"
        )
    else:
        print("Training script not found on disk; skipping script upload.")

    print(f"\n✅ Model pushed to: https://hf.co/{HUB_MODEL_ID}")
    print(f"Best validation Spearman: {best_val_spearman:.4f} (epoch {best_epoch})")
    print(f"Test Spearman: {test_metrics['spearman']:.4f}")
    print("Done!")


if __name__ == "__main__":
    train()