"""
Fine-tune NVIDIA NV-CodonFM-Encodon-80M-v1 for mRNA Stability Prediction.

Architecture: Custom BERT-style encoder with Rotary Position Embeddings (RoPE)
Dataset: mogam-ai/CDS-BART-mRNA-stability (iCodon - mRNA half-life from multiple species)
         + GleghornLab/mrna_stability_other (additional stability data)
Task: Regression (predict mRNA stability / half-life score)

Recipe based on:
- Helix-mRNA (arxiv:2502.13785): unfreeze last 2 layers, 5-30 epochs, AdamW
- BEACON (arxiv:2406.10391): lr sweep 1e-5 to 5e-3, warmup 50 steps, MSE loss
- CodonBERT: codon-level tokenization, CDS regression
"""
|
|
| import os |
| import math |
| import json |
| import re |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from safetensors.torch import load_file |
| from huggingface_hub import hf_hub_download, HfApi |
| from datasets import load_dataset, concatenate_datasets |
| from scipy.stats import spearmanr, pearsonr |
| import numpy as np |
| import trackio |
|
|
| |
| |
| |
|
|
| |
# Codon vocabulary: all 64 triplets over the RNA alphabet, enumerated with the
# first base varying slowest (AAA, AAU, AAG, AAC, AUA, ...).
RNA_BASES = ['A', 'U', 'G', 'C']
ALL_CODONS = [
    b1 + b2 + b3
    for b1 in RNA_BASES
    for b2 in RNA_BASES
    for b3 in RNA_BASES
]

# Special-token ids. This is the single mapping the rest of the module uses;
# a second, conflicting definition would silently override it, so keep only one.
# NOTE(review): these ids are assumed to match the pretrained checkpoint's
# vocabulary (and the `pad_token_id` in its config.json) -- confirm against
# the checkpoint before changing anything here.
SPECIAL_TOKENS = {
    '[CLS]': 0,
    '[SEP]': 1,
    '[MASK]': 2,
    '[PAD]': 3,
    '[UNK]': 4,
}

# Codon ids follow directly after the special tokens (5..68).
CODON_TO_ID = {
    codon: index
    for index, codon in enumerate(ALL_CODONS, start=len(SPECIAL_TOKENS))
}

# Reverse lookup covering both codons and special tokens.
ID_TO_CODON = {token_id: codon for codon, token_id in CODON_TO_ID.items()}
ID_TO_CODON.update({token_id: name for name, token_id in SPECIAL_TOKENS.items()})

PAD_TOKEN_ID = SPECIAL_TOKENS['[PAD]']
CLS_TOKEN_ID = SPECIAL_TOKENS['[CLS]']
SEP_TOKEN_ID = SPECIAL_TOKENS['[SEP]']
MASK_TOKEN_ID = SPECIAL_TOKENS['[MASK]']
UNK_TOKEN_ID = SPECIAL_TOKENS['[UNK]']

# Internal invariant: 64 codons + 5 special tokens.
assert len(CODON_TO_ID) + len(SPECIAL_TOKENS) == 69, f"Expected 69, got {len(CODON_TO_ID) + len(SPECIAL_TOKENS)}"
|
|
|
|
def tokenize_mRNA(seq: str, max_length: int = 2046) -> dict:
    """
    Tokenize an mRNA/CDS sequence into codon IDs.

    The sequence is upper-cased, DNA-style ``T`` is mapped to ``U``, and
    surrounding whitespace is stripped. It is then cut into consecutive
    triplets; a trailing partial codon (1-2 leftover bases) is dropped.
    Triplets not present in ``CODON_TO_ID`` map to ``UNK_TOKEN_ID``.

    Args:
        seq: Nucleotide sequence (RNA or DNA alphabet).
        max_length: Maximum number of tokens returned, including the leading
            [CLS] and trailing [SEP]; at most ``max_length - 2`` codons are kept.

    Returns:
        Dict with ``input_ids`` (list of int) and ``attention_mask``
        (list of 1s of the same length).

    Raises:
        ValueError: If ``max_length`` is smaller than 2. A smaller value would
            make the codon budget negative, and a negative slice bound would
            silently drop codons from the end instead of limiting the length.
    """
    if max_length < 2:
        raise ValueError(
            f"max_length must be >= 2 to fit [CLS] and [SEP], got {max_length}"
        )

    rna = seq.upper().replace('T', 'U').strip()

    # Only whole codons are tokenized, and never more than the codon budget.
    whole_codon_span = len(rna) - len(rna) % 3
    stop = min(whole_codon_span, 3 * (max_length - 2))

    token_ids = [CLS_TOKEN_ID]
    token_ids.extend(
        CODON_TO_ID.get(rna[start:start + 3], UNK_TOKEN_ID)
        for start in range(0, stop, 3)
    )
    token_ids.append(SEP_TOKEN_ID)

    return {
        'input_ids': token_ids,
        'attention_mask': [1] * len(token_ids),
    }
|
|
|
|
def pad_batch(batch, max_len, pad_id=PAD_TOKEN_ID):
    """Right-pad every tokenized item in ``batch`` to ``max_len`` tokens.

    Each item must provide ``input_ids`` and ``attention_mask`` lists.
    Ids are padded with ``pad_id`` and masks with 0. Returns a dict holding
    two ``torch.long`` tensors under ``input_ids`` and ``attention_mask``.
    """
    id_rows, mask_rows = [], []
    for example in batch:
        token_ids = example['input_ids']
        deficit = max_len - len(token_ids)
        id_rows.append(token_ids + [pad_id] * deficit)
        mask_rows.append(example['attention_mask'] + [0] * deficit)

    input_ids = torch.tensor(id_rows, dtype=torch.long)
    attention_mask = torch.tensor(mask_rows, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
|
|
|
| |
| |
| |
|
|
class RotaryEmbedding(nn.Module):
    """Produces the cosine/sine tables used for rotary position embeddings.

    ``dim`` is the per-head feature size; ``theta`` is the frequency base.
    The inverse frequencies are stored in the ``inv_freq`` buffer.
    """

    def __init__(self, dim, theta=10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents))

    def forward(self, x, seq_len):
        """Return ``(cos, sin)`` tables for positions ``0..seq_len-1``.

        ``x`` is only consulted for its device. The per-position angles are
        concatenated with themselves along the last axis, so each table has
        twice as many columns as ``inv_freq`` has entries.
        """
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        return table.cos(), table.sin()
|
|
|
|
def rotate_half(x):
    """Split the last axis into two chunks ``(a, b)`` and return ``(-b, a)``."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary transform to the query and key tensors.

    ``cos`` and ``sin`` get two leading singleton axes so they broadcast
    against ``q`` and ``k``. Returns the rotated ``(q, k)`` pair.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def rotate(states):
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return rotate(q), rotate(k)
|
|
|
|
class CodonFMAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings on Q and K.

    No output projection is applied here; the caller receives the
    concatenated per-head context vectors.
    """

    def __init__(self, hidden_size, num_heads, rotary_theta=10000.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.rotary_emb = RotaryEmbedding(self.head_dim, theta=rotary_theta)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over ``hidden_states`` of shape (batch, seq, hidden).

        ``attention_mask`` (optional) marks key positions to ignore with 0;
        their scores are set to ``-inf`` before the softmax.
        """
        batch, seq_len, width = hidden_states.shape

        def split_heads(projected):
            # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
            return projected.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(hidden_states))
        k = split_heads(self.key(hidden_states))
        v = split_heads(self.value(hidden_states))

        cos, sin = self.rotary_emb(q, seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Two singleton axes let the mask broadcast over heads and query positions.
            key_mask = attention_mask[:, None, None]
            scores = scores.masked_fill(key_mask == 0, float('-inf'))

        probs = F.softmax(scores, dim=-1)
        merged = torch.matmul(probs, v).transpose(1, 2).contiguous()
        return merged.view(batch, seq_len, width)
|
|
|
|
class CodonFMTransformerLayer(nn.Module):
    """
    One encoder layer: attention sub-block followed by a feed-forward sub-block.

    Attribute names are chosen to match the checkpoint weight keys:
    - pre_attn_layer_norm, attention (Q/K/V/rotary), post_attn_dense, post_attn_layer_norm
    - pre_ffn_layer_norm, intermediate_dense, post_ffn_layer_norm, output_dense
    """

    def __init__(self, hidden_size, num_heads, intermediate_size,
                 hidden_act='gelu', layer_norm_eps=1e-12, dropout=0.1,
                 rotary_theta=10000.0):
        super().__init__()

        def layer_norm(width):
            return nn.LayerNorm(width, eps=layer_norm_eps)

        # Attention sub-block
        self.pre_attn_layer_norm = layer_norm(hidden_size)
        self.attention = CodonFMAttention(hidden_size, num_heads, rotary_theta)
        self.post_attn_dense = nn.Linear(hidden_size, hidden_size)
        self.post_attn_layer_norm = layer_norm(hidden_size)

        # Feed-forward sub-block (note the norm over the intermediate width)
        self.pre_ffn_layer_norm = layer_norm(hidden_size)
        self.intermediate_dense = nn.Linear(hidden_size, intermediate_size)
        self.post_ffn_layer_norm = layer_norm(intermediate_size)
        self.output_dense = nn.Linear(intermediate_size, hidden_size)

        # One dropout module is shared by both sub-blocks.
        self.dropout = nn.Dropout(dropout)
        # Any activation name other than 'gelu' falls back to ReLU.
        if hidden_act == 'gelu':
            self.act = nn.GELU()
        else:
            self.act = nn.ReLU()

    def forward(self, hidden_states, attention_mask=None):
        """Run the layer and return a tensor shaped like ``hidden_states``."""
        # Attention: norm -> attention -> dense -> dropout -> residual -> norm
        attn_in = self.pre_attn_layer_norm(hidden_states)
        attn_out = self.post_attn_dense(self.attention(attn_in, attention_mask))
        after_attn = self.post_attn_layer_norm(hidden_states + self.dropout(attn_out))

        # Feed-forward: norm -> dense -> activation -> norm -> dense -> dropout -> residual
        ffn = self.intermediate_dense(self.pre_ffn_layer_norm(after_attn))
        ffn = self.post_ffn_layer_norm(self.act(ffn))
        ffn = self.dropout(self.output_dense(ffn))
        return after_attn + ffn
|
|
|
|
class CodonFMEncoder(nn.Module):
    """CodonFM encoder whose attribute layout mirrors the safetensors checkpoint.

    ``config`` is a plain dict; the keys read are ``vocab_size``,
    ``hidden_size``, ``pad_token_id``, ``layer_norm_eps``,
    ``num_attention_heads``, ``intermediate_size``, ``hidden_act``,
    ``hidden_dropout_prob``, ``rotary_theta`` and ``num_hidden_layers``.
    """

    def __init__(self, config):
        super().__init__()
        vocab_size = config['vocab_size']
        width = config['hidden_size']

        # Token embeddings followed by a layer norm
        self.word_embeddings = nn.Embedding(
            vocab_size, width, padding_idx=config['pad_token_id']
        )
        eps = config['layer_norm_eps']
        self.post_ln = nn.LayerNorm(width, eps=eps)

        # Stack of identical transformer layers
        layer_kwargs = {
            'hidden_size': width,
            'num_heads': config['num_attention_heads'],
            'intermediate_size': config['intermediate_size'],
            'hidden_act': config['hidden_act'],
            'layer_norm_eps': eps,
            'dropout': config['hidden_dropout_prob'],
            'rotary_theta': config['rotary_theta'],
        }
        self.layers = nn.ModuleList(
            [CodonFMTransformerLayer(**layer_kwargs) for _ in range(config['num_hidden_layers'])]
        )

        # Vocabulary projection head; constructed so checkpoint keys load,
        # but never called in forward().
        self.cls = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.LayerNorm(width, eps=eps),
            nn.Linear(width, vocab_size),
        )

    def forward(self, input_ids, attention_mask=None):
        """Embed ``input_ids`` and return the final layer's hidden states."""
        hidden = self.post_ln(self.word_embeddings(input_ids))
        for block in self.layers:
            hidden = block(hidden, attention_mask)
        return hidden
|
|
|
|
class CodonFMForStabilityPrediction(nn.Module):
    """CodonFM encoder topped with a small MLP that regresses one scalar."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = CodonFMEncoder(config)

        # Regression head: dropout -> dense -> tanh -> dropout -> scalar output
        width = config['hidden_size']
        head_dropout = config.get('classifier_dropout', 0.1)
        self.regression_head = nn.Sequential(
            nn.Dropout(head_dropout),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Dropout(head_dropout),
            nn.Linear(width, 1),
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Return ``{'loss', 'logits'}``; ``loss`` is None when no labels are given.

        Hidden states are mean-pooled over the sequence axis. When
        ``attention_mask`` is provided, only positions with mask 1 contribute.
        """
        hidden_states = self.encoder(input_ids, attention_mask)

        if attention_mask is None:
            pooled = hidden_states.mean(1)
        else:
            weights = attention_mask.unsqueeze(-1).float()
            # The clamp guards against division by zero for an all-zero mask row.
            pooled = (hidden_states * weights).sum(1) / weights.sum(1).clamp(min=1e-9)

        logits = self.regression_head(pooled).squeeze(-1)

        loss = F.mse_loss(logits, labels.float()) if labels is not None else None
        return {'loss': loss, 'logits': logits}

    def load_pretrained_encoder(self, checkpoint_path):
        """Load pretrained CodonFM weights into the encoder.

        Checkpoint keys are remapped before loading: a leading ``model.`` is
        removed, then a leading ``embeddings.``, and ``encoder.`` is prepended.
        NOTE(review): this mapping assumes the checkpoint's key layout; the
        printed "Missing (other)" list should be empty if it is right.

        Returns:
            ``(missing, unexpected)`` key lists from the non-strict load.
        """
        checkpoint = load_file(checkpoint_path, device='cpu')

        remapped = {}
        for name, tensor in checkpoint.items():
            target = name
            for prefix in ('model.', 'embeddings.'):
                if target.startswith(prefix):
                    target = target[len(prefix):]
            remapped['encoder.' + target] = tensor

        # strict=False: the regression head has no pretrained counterpart.
        missing, unexpected = self.load_state_dict(remapped, strict=False)

        head_missing = [k for k in missing if 'regression' in k]
        other_missing = [k for k in missing if 'regression' not in k]
        print("Loaded pretrained encoder weights.")
        print(f" Missing (new regression head params): {head_missing}")
        print(f" Missing (other): {other_missing}")
        print(f" Unexpected (MLM head etc): {unexpected[:5]}...")
        return missing, unexpected
|
|
|
|
| |
| |
| |
|
|
class mRNAStabilityDataset(Dataset):
    """Map-style dataset pairing raw sequences with scalar stability labels.

    Tokenization happens lazily in ``__getitem__``; items are unpadded
    Python lists, so batching needs a padding collate function.
    """

    def __init__(self, sequences, labels, max_length=2046):
        self.sequences = sequences
        self.labels = labels
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and a float ``label`` for item ``idx``."""
        sequence, target = self.sequences[idx], self.labels[idx]
        encoded = tokenize_mRNA(sequence, max_length=self.max_length)
        item = {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        }
        item['label'] = float(target)
        return item
|
|
|
|
def collate_fn(batch):
    """Collate dataset items into padded tensors for the model.

    Sequences are right-padded to the longest item in the batch. Padding is
    delegated to ``pad_batch`` so there is a single implementation of the
    padding rules in this module.

    Args:
        batch: List of dicts with ``input_ids``, ``attention_mask`` and ``label``.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` (``torch.long``) and
        ``labels`` (``torch.float32``).
    """
    longest = max(len(item['input_ids']) for item in batch)
    collated = pad_batch(batch, longest, pad_id=PAD_TOKEN_ID)
    collated['labels'] = torch.tensor(
        [item['label'] for item in batch], dtype=torch.float32
    )
    return collated
|
|
|
|
| |
| |
| |
|
|
def evaluate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return regression metrics.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` themselves.

    Args:
        model: Callable as ``model(input_ids, attention_mask, labels)`` and
            returning a dict with ``loss`` and ``logits``.
        dataloader: Iterable of batches with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        Dict of plain Python floats: ``loss`` (sample-weighted mean of the
        per-batch losses), ``spearman``, ``pearson`` and ``mse``.

    Raises:
        ValueError: If fewer than two samples were seen, since the
            correlation metrics are undefined in that case.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    weighted_loss = 0.0
    n_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask, labels)

            # Weight each batch by its size so a smaller final batch does not
            # skew the average the way a plain mean of batch means would.
            batch_size = labels.shape[0]
            weighted_loss += outputs['loss'].item() * batch_size
            n_samples += batch_size

            pred_chunks.append(outputs['logits'].float().cpu().numpy().reshape(-1))
            label_chunks.append(labels.float().cpu().numpy().reshape(-1))

    if n_samples < 2:
        raise ValueError(
            f"evaluate() needs at least 2 samples to compute correlations, got {n_samples}"
        )

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    spearman_rho, _ = spearmanr(all_preds, all_labels)
    pearson_r, _ = pearsonr(all_preds, all_labels)
    mse = np.mean((all_preds - all_labels) ** 2)

    # Plain floats keep the metrics JSON-serialisable for logging/saving.
    return {
        'loss': weighted_loss / n_samples,
        'spearman': float(spearman_rho),
        'pearson': float(pearson_r),
        'mse': float(mse),
    }
|
|
|
|
def _render_model_card(*, freeze_layers, num_layers, learning_rate, num_epochs,
                       batch_size, grad_accum, warmup_steps, max_length,
                       use_amp, dataset_ids, test_metrics):
    """Build the Hub README from the settings and results of the finished run."""
    dataset_notes = {
        "mogam-ai/CDS-BART-mRNA-stability": (
            "- **[mogam-ai/CDS-BART-mRNA-stability](https://hf.co/datasets/mogam-ai/CDS-BART-mRNA-stability)**:\n"
            "  iCodon-based mRNA stability profiles from humans, mice, frogs, and fish "
            "(28,770 train / 6,207 val / 6,086 test)"
        ),
        "GleghornLab/mrna_stability_other": (
            "- **[GleghornLab/mrna_stability_other](https://hf.co/datasets/GleghornLab/mrna_stability_other)**:\n"
            "  Additional mRNA stability data (45,749 train / 9,803 valid / 9,804 test)"
        ),
    }
    datasets_yaml = "\n".join(f"- {name}" for name in dataset_ids)
    datasets_section = "\n".join(dataset_notes.get(name, f"- {name}") for name in dataset_ids)
    precision = "FP16" if use_amp else "FP32 (no CUDA device available)"

    return f"""---
license: other
license_name: nvidia-open-model-license
tags:
- biology
- genomics
- mRNA
- stability-prediction
- codon
- fine-tuned
base_model: nvidia/NV-CodonFM-Encodon-80M-v1
datasets:
{datasets_yaml}
metrics:
- spearman_correlation
- pearson_correlation
- mse
---

# CodonFM-80M Fine-tuned for mRNA Stability Prediction

## Model Description

This model is a fine-tuned version of [NVIDIA NV-CodonFM-Encodon-80M-v1](https://hf.co/nvidia/NV-CodonFM-Encodon-80M-v1)
for predicting mRNA stability (half-life) from coding sequences (CDS).

**Base model:** NV-CodonFM-Encodon-80M-v1 (80M parameter BERT-style Transformer with Rotary Position Embeddings)
**Task:** Regression — predict mRNA stability score from codon sequence
**Input:** mRNA coding sequence (codon-level tokenization, truncated to {max_length - 2} codons during fine-tuning)
**Output:** Stability score (continuous float — higher = more stable)

## Training

### Datasets
{datasets_section}

### Recipe
Based on [Helix-mRNA](https://arxiv.org/abs/2502.13785) and [BEACON](https://arxiv.org/abs/2406.10391):
- **Strategy:** Freeze first {freeze_layers} of {num_layers} transformer layers, unfreeze last {num_layers - freeze_layers} + regression head
- **Optimizer:** AdamW (backbone lr={learning_rate}, head lr={learning_rate * 10})
- **Epochs:** {num_epochs}
- **Batch size:** {batch_size} × {grad_accum} gradient accumulation = {batch_size * grad_accum} effective
- **Scheduler:** Cosine with {warmup_steps}-step warmup
- **Mixed precision:** {precision}

### Results

| Metric | Test Set |
|--------|----------|
| Spearman ρ | {test_metrics['spearman']:.4f} |
| Pearson r | {test_metrics['pearson']:.4f} |
| MSE | {test_metrics['mse']:.4f} |

### Literature Comparison
| Model | Spearman ρ (mRNA Stability) |
|-------|----------------------------|
| CodonBERT | 0.35 |
| XE | 0.50 |
| Helix-mRNA | 0.52 |
| HELM | 0.53 |
| **This model** | **{test_metrics['spearman']:.4f}** |

## Usage

```python
import torch
import json
from huggingface_hub import hf_hub_download

# Load model (see train script for full model class definition)
# ...
```

## Citation

If using this model, please cite the base CodonFM model and the iCodon dataset:

```bibtex
@article{{diez2022icodon,
  title={{iCodon customizes gene expression based on the codon composition}},
  author={{Diez, Michay and others}},
  journal={{Scientific Reports}},
  year={{2022}}
}}
```
"""


def train():
    """Fine-tune CodonFM for mRNA stability regression and push it to the Hub.

    Pipeline: read hyperparameters from environment variables, download the
    base config and weights, freeze part of the encoder, load and filter the
    dataset(s), train with warmup + cosine decay (mixed precision on CUDA),
    keep the checkpoint with the best validation Spearman, evaluate it on the
    test split, then write a model card and upload everything.

    Environment variables (defaults in parentheses): HUB_MODEL_ID,
    LEARNING_RATE (5e-5), NUM_EPOCHS (20), BATCH_SIZE (16), GRAD_ACCUM (2),
    MAX_LENGTH (1024), WARMUP_STEPS (100), WEIGHT_DECAY (0.01),
    FREEZE_LAYERS (4), USE_BOTH_DATASETS (true).
    """
    # ---- Hyperparameters -------------------------------------------------
    HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "Imranyai/CodonFM-80M-mRNA-stability")
    LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "5e-5"))
    NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "20"))
    BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "16"))
    GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "2"))
    MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "1024"))
    WARMUP_STEPS = int(os.environ.get("WARMUP_STEPS", "100"))
    WEIGHT_DECAY = float(os.environ.get("WEIGHT_DECAY", "0.01"))
    FREEZE_LAYERS = int(os.environ.get("FREEZE_LAYERS", "4"))
    USE_BOTH_DATASETS = os.environ.get("USE_BOTH_DATASETS", "true").lower() == "true"

    best_dir = "/app/best_model"
    weights_path = "/app/best_model/pytorch_model.bin"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    print(f"Config: lr={LEARNING_RATE}, epochs={NUM_EPOCHS}, batch={BATCH_SIZE}, "
          f"grad_accum={GRAD_ACCUM}, max_len={MAX_LENGTH}, freeze_layers={FREEZE_LAYERS}")

    # ---- Experiment tracking --------------------------------------------
    trackio.init(
        project="codonfm-mrna-stability",
        name=f"lr{LEARNING_RATE}_ep{NUM_EPOCHS}_freeze{FREEZE_LAYERS}",
    )

    # ---- Model -----------------------------------------------------------
    config_path = hf_hub_download(
        repo_id="nvidia/NV-CodonFM-Encodon-80M-v1",
        filename="config.json"
    )
    with open(config_path) as f:
        config = json.load(f)
    print(f"Model config: {config}")

    model = CodonFMForStabilityPrediction(config)

    ckpt_path = hf_hub_download(
        repo_id="nvidia/NV-CodonFM-Encodon-80M-v1",
        filename="NV-CodonFM-Encodon-80M-v1.safetensors"
    )
    model.load_pretrained_encoder(ckpt_path)

    # Clamp so an out-of-range FREEZE_LAYERS cannot index past the last layer.
    num_layers = len(model.encoder.layers)
    FREEZE_LAYERS = max(0, min(FREEZE_LAYERS, num_layers))

    if FREEZE_LAYERS > 0:
        # Embeddings, their layer norm, and the first FREEZE_LAYERS layers.
        frozen_modules = [
            model.encoder.word_embeddings,
            model.encoder.post_ln,
            *model.encoder.layers[:FREEZE_LAYERS],
        ]
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

    # The vocabulary head is never called in forward(), so it is always frozen.
    for param in model.encoder.cls.parameters():
        param.requires_grad = False

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable_params:,} / {total_params:,} "
          f"({100*trainable_params/total_params:.1f}%)")

    model = model.to(device)

    # ---- Data --------------------------------------------------------------
    print("\nLoading datasets...")

    # Tracks which datasets were really used, for the saved config and model card.
    dataset_ids = ["mogam-ai/CDS-BART-mRNA-stability"]
    ds1 = load_dataset("mogam-ai/CDS-BART-mRNA-stability")
    print(f"mogam-ai/CDS-BART-mRNA-stability: train={len(ds1['train'])}, val={len(ds1['val'])}, test={len(ds1['test'])}")

    train_seqs = list(ds1['train']['seq'])
    train_labels = list(ds1['train']['y'])
    val_seqs = list(ds1['val']['seq'])
    val_labels = list(ds1['val']['y'])
    test_seqs = list(ds1['test']['seq'])
    test_labels = list(ds1['test']['y'])

    if USE_BOTH_DATASETS:
        # The second dataset uses different column and split names.
        ds2 = load_dataset("GleghornLab/mrna_stability_other")
        dataset_ids.append("GleghornLab/mrna_stability_other")
        print(f"GleghornLab/mrna_stability_other: train={len(ds2['train'])}, valid={len(ds2['valid'])}, test={len(ds2['test'])}")

        train_seqs += list(ds2['train']['rna'])
        train_labels += list(ds2['train']['labels'])
        val_seqs += list(ds2['valid']['rna'])
        val_labels += list(ds2['valid']['labels'])
        test_seqs += list(ds2['test']['rna'])
        test_labels += list(ds2['test']['labels'])

    print(f"\nCombined dataset sizes: train={len(train_seqs)}, val={len(val_seqs)}, test={len(test_seqs)}")

    def filter_valid(seqs, labels):
        """Keep pairs with a sequence of >= 9 characters and a non-missing label."""
        valid_seqs, valid_labels = [], []
        for seq, label in zip(seqs, labels):
            if seq is None or label is None:
                continue
            if len(seq) >= 9 and not np.isnan(label):
                valid_seqs.append(seq)
                valid_labels.append(label)
        return valid_seqs, valid_labels

    train_seqs, train_labels = filter_valid(train_seqs, train_labels)
    val_seqs, val_labels = filter_valid(val_seqs, val_labels)
    test_seqs, test_labels = filter_valid(test_seqs, test_labels)
    print(f"After filtering: train={len(train_seqs)}, val={len(val_seqs)}, test={len(test_seqs)}")

    train_labels_arr = np.array(train_labels)
    print(f"Label stats (train): mean={train_labels_arr.mean():.3f}, std={train_labels_arr.std():.3f}, "
          f"min={train_labels_arr.min():.3f}, max={train_labels_arr.max():.3f}")

    train_dataset = mRNAStabilityDataset(train_seqs, train_labels, max_length=MAX_LENGTH)
    val_dataset = mRNAStabilityDataset(val_seqs, val_labels, max_length=MAX_LENGTH)
    test_dataset = mRNAStabilityDataset(test_seqs, test_labels, max_length=MAX_LENGTH)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=collate_fn, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False,
                            collate_fn=collate_fn, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2, shuffle=False,
                             collate_fn=collate_fn, num_workers=2, pin_memory=True)

    # ---- Optimizer & schedule ---------------------------------------------
    # The freshly initialised head gets a 10x larger learning rate than the
    # pretrained backbone.
    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if 'regression_head' in name:
                head_params.append(param)
            else:
                backbone_params.append(param)

    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': LEARNING_RATE},
        {'params': head_params, 'lr': LEARNING_RATE * 10},
    ], weight_decay=WEIGHT_DECAY)

    # One optimizer step per full accumulation window, plus one for a partial
    # window at the end of each epoch, so the schedule length matches the
    # number of steps that are really taken.
    batches_per_epoch = len(train_loader)
    steps_per_epoch = math.ceil(batches_per_epoch / GRAD_ACCUM)
    total_steps = steps_per_epoch * NUM_EPOCHS

    def lr_lambda(current_step):
        """Linear warmup followed by cosine decay to zero."""
        if current_step < WARMUP_STEPS:
            return float(current_step) / float(max(1, WARMUP_STEPS))
        progress = float(current_step - WARMUP_STEPS) / float(max(1, total_steps - WARMUP_STEPS))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # Mixed precision is only used on CUDA.
    use_amp = device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda') if use_amp else None

    def save_checkpoint(val_spearman, epoch_number):
        """Write weights, run config and vocabulary to the best-model folder."""
        os.makedirs(best_dir, exist_ok=True)
        torch.save(model.state_dict(), weights_path)
        with open("/app/best_model/config.json", 'w') as f:
            json.dump({
                **config,
                "task": "mRNA_stability_regression",
                "freeze_layers": FREEZE_LAYERS,
                "max_length": MAX_LENGTH,
                "best_val_spearman": val_spearman,
                "best_epoch": epoch_number,
                "datasets": dataset_ids,
            }, f, indent=2)

        with open("/app/best_model/codon_vocab.json", 'w') as f:
            json.dump({
                "special_tokens": SPECIAL_TOKENS,
                "codon_to_id": CODON_TO_ID,
            }, f, indent=2)

    # ---- Training loop -----------------------------------------------------
    print(f"\n{'='*60}")
    print(f"Starting training for {NUM_EPOCHS} epochs")
    print(f"Total steps: {total_steps}, Warmup: {WARMUP_STEPS}")
    print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
    print(f"{'='*60}\n")

    # -inf so that any finite Spearman (even exactly -1.0) becomes the best.
    best_val_spearman = float('-inf')
    best_epoch = -1
    global_step = 0

    for epoch in range(NUM_EPOCHS):
        model.train()
        epoch_loss = 0
        n_batches = 0
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            if use_amp:
                with torch.amp.autocast('cuda'):
                    outputs = model(input_ids, attention_mask, labels)
                    loss = outputs['loss'] / GRAD_ACCUM
                scaler.scale(loss).backward()
            else:
                outputs = model(input_ids, attention_mask, labels)
                loss = outputs['loss'] / GRAD_ACCUM
                loss.backward()

            epoch_loss += outputs['loss'].item()
            n_batches += 1

            # Step on every full accumulation window, and also on the last
            # batch of the epoch: otherwise the gradients of a trailing partial
            # window are thrown away by the next epoch's zero_grad().
            window_full = (batch_idx + 1) % GRAD_ACCUM == 0
            last_batch = (batch_idx + 1) == batches_per_epoch
            if window_full or last_batch:
                if use_amp:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                if global_step % 50 == 0:
                    avg_loss = epoch_loss / n_batches
                    current_lr = optimizer.param_groups[0]['lr']
                    print(f" Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
                    trackio.log({
                        "train/loss": avg_loss,
                        "train/lr": current_lr,
                        "train/step": global_step,
                    })

        avg_train_loss = epoch_loss / max(n_batches, 1)

        val_metrics = evaluate(model, val_loader, device)

        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}:")
        print(f" Train Loss: {avg_train_loss:.4f}")
        print(f" Val Loss: {val_metrics['loss']:.4f} | Spearman: {val_metrics['spearman']:.4f} | "
              f"Pearson: {val_metrics['pearson']:.4f} | MSE: {val_metrics['mse']:.4f}")

        trackio.log({
            "train/epoch_loss": avg_train_loss,
            "val/loss": val_metrics['loss'],
            "val/spearman": val_metrics['spearman'],
            "val/pearson": val_metrics['pearson'],
            "val/mse": val_metrics['mse'],
            "epoch": epoch + 1,
        })

        # A NaN Spearman (e.g. constant predictions) compares False and is skipped.
        if val_metrics['spearman'] > best_val_spearman:
            best_val_spearman = float(val_metrics['spearman'])
            best_epoch = epoch + 1
            save_checkpoint(best_val_spearman, best_epoch)
            print(f" ★ New best model! Spearman: {best_val_spearman:.4f} (epoch {best_epoch})")

    # ---- Final evaluation ----------------------------------------------------
    print(f"\n{'='*60}")
    if best_epoch > 0:
        print(f"Loading best model from epoch {best_epoch}")
        model.load_state_dict(torch.load("/app/best_model/pytorch_model.bin", map_location=device))
    else:
        # No epoch produced a usable validation Spearman, so nothing was
        # saved. Persist the current weights so the evaluation and upload
        # below still have a checkpoint to work with.
        print("No best checkpoint was selected; using the final weights instead.")
        save_checkpoint(None, None)

    test_metrics = evaluate(model, test_loader, device)
    print(f"\nFinal Test Results:")
    print(f" Loss: {test_metrics['loss']:.4f}")
    print(f" Spearman ρ: {test_metrics['spearman']:.4f}")
    print(f" Pearson r: {test_metrics['pearson']:.4f}")
    print(f" MSE: {test_metrics['mse']:.4f}")

    trackio.log({
        "test/loss": test_metrics['loss'],
        "test/spearman": test_metrics['spearman'],
        "test/pearson": test_metrics['pearson'],
        "test/mse": test_metrics['mse'],
    })

    # ---- Publish ---------------------------------------------------------------
    print(f"\nPushing model to Hub: {HUB_MODEL_ID}")
    api = HfApi()

    # Best effort: the upload below reports the real failure if the repo is unusable.
    try:
        api.create_repo(repo_id=HUB_MODEL_ID, exist_ok=True)
    except Exception as e:
        print(f"Repo creation note: {e}")

    model_card = _render_model_card(
        freeze_layers=FREEZE_LAYERS,
        num_layers=num_layers,
        learning_rate=LEARNING_RATE,
        num_epochs=NUM_EPOCHS,
        batch_size=BATCH_SIZE,
        grad_accum=GRAD_ACCUM,
        warmup_steps=WARMUP_STEPS,
        max_length=MAX_LENGTH,
        use_amp=use_amp,
        dataset_ids=dataset_ids,
        test_metrics=test_metrics,
    )

    with open("/app/best_model/README.md", 'w') as f:
        f.write(model_card)

    api.upload_folder(
        folder_path="/app/best_model",
        repo_id=HUB_MODEL_ID,
        commit_message=f"Upload fine-tuned CodonFM-80M for mRNA stability (Spearman={test_metrics['spearman']:.4f})"
    )

    api.upload_file(
        path_or_fileobj="/app/train_codonfm_stability.py",
        path_in_repo="train_codonfm_stability.py",
        repo_id=HUB_MODEL_ID,
        commit_message="Upload training script"
    )

    print(f"\n✅ Model pushed to: https://hf.co/{HUB_MODEL_ID}")
    print(f"Best validation Spearman: {best_val_spearman:.4f} (epoch {best_epoch})")
    print(f"Test Spearman: {test_metrics['spearman']:.4f}")
    print("Done!")
|
|
|
|
# Script entry point: start the fine-tuning run only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train()
|
|