Text Classification
biology
genomics
mRNA
stability-prediction
codon
fine-tuned
regression
CodonFM-80M-mRNA-stability / train_codonfm_stability.py
Imranyai's picture
Add train_codonfm_stability.py
41780cc verified
"""
Fine-tune NVIDIA NV-CodonFM-Encodon-80M-v1 for mRNA Stability Prediction.
Architecture: Custom BERT-style encoder with Rotary Position Embeddings (RoPE)
Dataset: mogam-ai/CDS-BART-mRNA-stability (iCodon - mRNA half-life from multiple species)
+ GleghornLab/mrna_stability_other (additional stability data)
Task: Regression (predict mRNA stability / half-life score)
Recipe based on:
- Helix-mRNA (arxiv:2502.13785): unfreeze last 2 layers, 5-30 epochs, AdamW
- BEACON (arxiv:2406.10391): lr sweep 1e-5 to 5e-3, warmup 50 steps, MSE loss
- CodonBERT: codon-level tokenization, CDS regression
"""
import os
import math
import json
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, HfApi
from datasets import load_dataset, concatenate_datasets
from scipy.stats import spearmanr, pearsonr
import numpy as np
import trackio
# ============================================================
# 1. CODON TOKENIZER
# ============================================================
# Codon vocabulary: 5 special tokens (ids 0-4) followed by the 64 RNA codons
# (ids 5-68), i.e. 69 entries in total.
#
# NOTE(review): this layout is inferred from the base model's config.json
# (vocab_size=69, pad_token_id=3) rather than read from an official tokenizer
# file. Every id must match the row order of the pretrained embedding table,
# otherwise the pretrained weights are applied to the wrong codons -- confirm
# against NVIDIA's reference tokenizer.
RNA_BASES = ['A', 'U', 'G', 'C']

# All 64 codons, first base varying slowest: AAA, AAU, AAG, AAC, AUA, ..., CCC.
ALL_CODONS = [b1 + b2 + b3 for b1 in RNA_BASES for b2 in RNA_BASES for b3 in RNA_BASES]

SPECIAL_TOKENS = {
    '[CLS]': 0,
    '[SEP]': 1,
    '[MASK]': 2,
    '[PAD]': 3,  # must equal pad_token_id=3 from the base model config
    '[UNK]': 4,
}

# Codons take the ids directly after the special tokens.
CODON_TO_ID = {codon: i + len(SPECIAL_TOKENS) for i, codon in enumerate(ALL_CODONS)}

# Reverse lookup covering both codons and special tokens.
ID_TO_CODON = {v: k for k, v in CODON_TO_ID.items()}
ID_TO_CODON.update({v: k for k, v in SPECIAL_TOKENS.items()})

PAD_TOKEN_ID = SPECIAL_TOKENS['[PAD]']
CLS_TOKEN_ID = SPECIAL_TOKENS['[CLS]']
SEP_TOKEN_ID = SPECIAL_TOKENS['[SEP]']
MASK_TOKEN_ID = SPECIAL_TOKENS['[MASK]']
UNK_TOKEN_ID = SPECIAL_TOKENS['[UNK]']

# Sanity checks: the vocabulary must exactly fill the 69-row embedding table
# with no id used twice. Explicit raises (not `assert`) so they also run
# under `python -O`.
if len(CODON_TO_ID) + len(SPECIAL_TOKENS) != 69:
    raise RuntimeError(f"Expected 69, got {len(CODON_TO_ID) + len(SPECIAL_TOKENS)}")
if len(ID_TO_CODON) != 69:
    raise RuntimeError(f"Token ids are not unique: {len(ID_TO_CODON)} distinct ids for 69 tokens")
def tokenize_mRNA(seq: str, max_length: int = 2046) -> dict:
    """
    Tokenize a coding sequence into codon token ids.

    Preprocessing:
    - The sequence is upper-cased.
    - DNA ``T`` is converted to RNA ``U``.
    - All whitespace is removed, including line breaks inside the sequence.

    The sequence is then split into non-overlapping triplets starting at the
    first base. A trailing partial codon (``len % 3`` leftover bases) is
    dropped. Any triplet that is not one of the 64 A/U/G/C codons maps to
    ``[UNK]``.

    Args:
        seq: nucleotide sequence (RNA or DNA alphabet).
        max_length: maximum number of tokens returned, *including* the
            ``[CLS]`` and ``[SEP]`` markers. Longer sequences are truncated at
            the 3' end. Must be at least 2.

    Returns:
        dict with ``input_ids`` (``[CLS] + codons + [SEP]``) and an all-ones
        ``attention_mask`` of the same length.

    Raises:
        ValueError: if ``max_length`` < 2.
    """
    if max_length < 2:
        # A negative slice bound would otherwise silently keep the wrong codons.
        raise ValueError(f"max_length must be >= 2 to fit [CLS] and [SEP], got {max_length}")
    # Remove *all* whitespace, not only the ends. An embedded space or newline
    # would otherwise shift the reading frame of every downstream codon.
    seq = ''.join(seq.split()).upper().replace('T', 'U')
    # Reserve two positions for [CLS] and [SEP].
    n_codons = min(len(seq) // 3, max_length - 2)
    token_ids = [CLS_TOKEN_ID]
    token_ids.extend(
        CODON_TO_ID.get(seq[i:i + 3], UNK_TOKEN_ID) for i in range(0, 3 * n_codons, 3)
    )
    token_ids.append(SEP_TOKEN_ID)
    return {
        'input_ids': token_ids,
        'attention_mask': [1] * len(token_ids),
    }
def pad_batch(batch, max_len=None, pad_id=PAD_TOKEN_ID):
    """
    Right-pad tokenized sequences to a common length and stack them.

    Args:
        batch: sequence of dicts with ``input_ids`` and ``attention_mask``
            token lists (as produced by ``tokenize_mRNA``).
        max_len: target length. ``None`` pads to the longest sequence in
            ``batch``.
        pad_id: token id used to fill ``input_ids``. The attention mask is
            always filled with 0.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` as ``torch.long``
        tensors of shape ``[len(batch), max_len]``.

    Raises:
        ValueError: if a sequence is longer than ``max_len``. Without this
            check the rows would be ragged and ``torch.tensor`` would fail
            with an unrelated-looking error.
    """
    if max_len is None:
        max_len = max((len(item['input_ids']) for item in batch), default=0)
    padded_ids = []
    padded_masks = []
    for item in batch:
        ids = list(item['input_ids'])
        mask = list(item['attention_mask'])
        pad_len = max_len - len(ids)
        if pad_len < 0:
            raise ValueError(f"sequence of length {len(ids)} exceeds max_len={max_len}")
        padded_ids.append(ids + [pad_id] * pad_len)
        padded_masks.append(mask + [0] * pad_len)
    return {
        'input_ids': torch.tensor(padded_ids, dtype=torch.long),
        'attention_mask': torch.tensor(padded_masks, dtype=torch.long),
    }
# ============================================================
# 2. MODEL ARCHITECTURE (matched to safetensors weight keys)
# ============================================================
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) table generator.

    Stores the inverse frequencies ``theta ** (-2i / dim)`` for ``i`` in
    ``[0, dim/2)`` as the ``inv_freq`` buffer. Each call returns the cosine
    and sine tables for positions ``0 .. seq_len-1``. Both tables have shape
    ``[seq_len, dim]``, with the frequency block repeated twice along the last
    axis.
    """

    def __init__(self, dim, theta=10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents))

    def forward(self, x, seq_len):
        """Return ``(cos, sin)`` tables for ``seq_len`` positions.

        Only ``x.device`` is used from ``x``. The tables take the dtype of
        ``inv_freq``.
        """
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        angles = angles.repeat(1, 2)
        return torch.cos(angles), torch.sin(angles)
def rotate_half(x):
    """Split the last dimension into two halves ``[a, b]`` and return ``[-b, a]``.

    This is the rotated term that RoPE multiplies by the sine table.
    """
    first, second = torch.chunk(x, chunks=2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors.

    ``cos`` and ``sin`` gain two leading singleton axes so they broadcast over
    the batch and head axes of ``q``/``k``. In ``CodonFMAttention`` these are
    ``[B, heads, L, head_dim]`` and the tables are ``[L, head_dim]``.

    Returns:
        ``(q_rot, k_rot)``, each computed as ``t * cos + rotate_half(t) * sin``.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
class CodonFMAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Submodule names (``query``, ``key``, ``value``, ``rotary_emb``) are kept so
    the pretrained checkpoint keys load unchanged. The output projection is
    not applied here; the enclosing layer applies it (``post_attn_dense``).
    """

    def __init__(self, hidden_size, num_heads, rotary_theta=10000.0):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Fail early: the per-head reshape in forward() would fail anyway.
            raise ValueError(
                f"hidden_size={hidden_size} is not divisible by num_heads={num_heads}"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.rotary_emb = RotaryEmbedding(self.head_dim, theta=rotary_theta)

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: ``[B, L, H]`` tensor.
            attention_mask: optional ``[B, L]`` mask. Positions equal to 0 are
                excluded as attention *keys*. Queries at those positions are
                still computed and must be ignored downstream.

        Returns:
            ``[B, L, H]`` context vectors, before the output projection.
        """
        B, L, H = hidden_states.shape

        def _split_heads(t):
            # [B, L, H] -> [B, heads, L, head_dim]
            return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        q = _split_heads(self.query(hidden_states))
        k = _split_heads(self.key(hidden_states))
        v = _split_heads(self.value(hidden_states))

        cos, sin = self.rotary_emb(q, L)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # The float32 RoPE tables can promote q/k (e.g. to fp32 under AMP).
        # SDPA needs one dtype, so match v; this is a no-op in plain fp32.
        q, k = q.to(v.dtype), k.to(v.dtype)

        attn_mask = None
        if attention_mask is not None:
            # [B, L] -> [B, 1, 1, L]; True = this key may be attended to.
            attn_mask = (attention_mask != 0)[:, None, None, :]

        # SDPA computes softmax(q k^T / sqrt(head_dim)) v with masked keys
        # excluded, i.e. the same math as the explicit version. When a fused
        # backend is eligible it avoids materialising the [B, heads, L, L]
        # score matrix.
        context = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return context.transpose(1, 2).contiguous().view(B, L, H)
class CodonFMTransformerLayer(nn.Module):
    """
    One CodonFM encoder layer.

    Submodule names match the checkpoint weight keys:
      - pre_attn_layer_norm, attention (Q/K/V/rotary), post_attn_dense, post_attn_layer_norm
      - pre_ffn_layer_norm, intermediate_dense, post_ffn_layer_norm, output_dense

    Forward order implemented here:
      x = post_attn_layer_norm(x + dropout(post_attn_dense(attn(pre_attn_layer_norm(x)))))
      x = x + dropout(output_dense(post_ffn_layer_norm(act(intermediate_dense(pre_ffn_layer_norm(x))))))

    NOTE(review): the key names only fix *which* norms exist, not where they
    are applied. Here ``post_ffn_layer_norm`` sits inside the FFN branch, but
    ``post_attn_layer_norm`` is applied *after* the residual add. In
    NormFormer-style layers the post-attention norm is applied to the attention
    output before the add. Confirm this order against the reference
    implementation: a mismatch still loads the weights but applies them in the
    wrong place.
    """

    def __init__(self, hidden_size, num_heads, intermediate_size,
                 hidden_act='gelu', layer_norm_eps=1e-12, dropout=0.1,
                 rotary_theta=10000.0):
        super().__init__()
        # Attention sub-block
        self.pre_attn_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attention = CodonFMAttention(hidden_size, num_heads, rotary_theta)
        self.post_attn_dense = nn.Linear(hidden_size, hidden_size)
        self.post_attn_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # FFN sub-block (the norm inside the FFN acts on the intermediate width)
        self.pre_ffn_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.intermediate_dense = nn.Linear(hidden_size, intermediate_size)
        self.post_ffn_layer_norm = nn.LayerNorm(intermediate_size, eps=layer_norm_eps)
        self.output_dense = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.act = self._build_activation(hidden_act)

    @staticmethod
    def _build_activation(name):
        """Return the activation module for a Hugging Face-style ``hidden_act`` string.

        Unknown names raise ``ValueError`` instead of silently falling back to
        a different nonlinearity than the one the checkpoint was trained with.
        """
        key = str(name).lower()
        if key == 'gelu':
            return nn.GELU()
        if key in ('gelu_new', 'gelu_pytorch_tanh', 'gelu_fast'):
            return nn.GELU(approximate='tanh')
        if key == 'relu':
            return nn.ReLU()
        if key in ('silu', 'swish'):
            return nn.SiLU()
        raise ValueError(f"Unsupported hidden_act: {name!r}")

    def forward(self, hidden_states, attention_mask=None):
        """Apply the layer to ``[B, L, hidden_size]`` states.

        ``attention_mask`` is passed unchanged to the attention module.
        """
        # Attention: pre-norm, project, dropout, residual add, then post-norm.
        residual = hidden_states
        attn_output = self.attention(self.pre_attn_layer_norm(hidden_states), attention_mask)
        attn_output = self.dropout(self.post_attn_dense(attn_output))
        hidden_states = self.post_attn_layer_norm(residual + attn_output)

        # FFN: pre-norm, expand, activate, norm, project back, dropout, residual add.
        residual = hidden_states
        ffn = self.intermediate_dense(self.pre_ffn_layer_norm(hidden_states))
        ffn = self.post_ffn_layer_norm(self.act(ffn))
        ffn = self.dropout(self.output_dense(ffn))
        return residual + ffn
class CodonFMEncoder(nn.Module):
    """CodonFM encoder backbone, laid out to match the safetensors checkpoint.

    Data path: ``word_embeddings`` -> ``post_ln`` -> each module in ``layers``.
    ``cls`` holds the pretrained MLM head so its checkpoint weights have
    somewhere to load into; ``forward`` never uses it.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config['hidden_size']
        eps = config['layer_norm_eps']

        # Token embedding followed by a LayerNorm.
        self.word_embeddings = nn.Embedding(
            config['vocab_size'], hidden, padding_idx=config['pad_token_id']
        )
        self.post_ln = nn.LayerNorm(hidden, eps=eps)

        # Stack of num_hidden_layers transformer layers.
        self.layers = nn.ModuleList(
            CodonFMTransformerLayer(
                hidden_size=hidden,
                num_heads=config['num_attention_heads'],
                intermediate_size=config['intermediate_size'],
                hidden_act=config['hidden_act'],
                layer_norm_eps=eps,
                dropout=config['hidden_dropout_prob'],
                rotary_theta=config['rotary_theta'],
            )
            for _ in range(config['num_hidden_layers'])
        )

        # MLM head; the Sequential indices give checkpoint keys cls.0 .. cls.3.
        head_modules = [
            nn.Linear(hidden, hidden),            # cls.0
            nn.GELU(),                            # cls.1 (no weights)
            nn.LayerNorm(hidden, eps=eps),        # cls.2
            nn.Linear(hidden, config['vocab_size']),  # cls.3
        ]
        self.cls = nn.Sequential(*head_modules)

    def forward(self, input_ids, attention_mask=None):
        """Return the final hidden states ``[B, L, hidden_size]``; the MLM head is not applied."""
        hidden = self.post_ln(self.word_embeddings(input_ids))
        for block in self.layers:
            hidden = block(hidden, attention_mask)
        return hidden
class CodonFMForStabilityPrediction(nn.Module):
    """CodonFM encoder + mean-pooled regression head for mRNA stability prediction.

    ``forward`` returns ``{'loss', 'logits'}``. ``logits`` holds the ``[B]``
    predicted scores and ``loss`` is the MSE against ``labels`` (``None`` when
    no labels are given).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = CodonFMEncoder(config)

        hidden_size = config['hidden_size']
        # HF-style configs often store `classifier_dropout: null`. `.get(key, 0.1)`
        # does not cover an explicit None, and nn.Dropout(None) raises.
        dropout = config.get('classifier_dropout')
        if dropout is None:
            dropout = 0.1
        self.regression_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Args:
            input_ids: ``[B, L]`` token ids.
            attention_mask: optional ``[B, L]`` mask (1 = real token, 0 = padding).
            labels: optional ``[B]`` regression targets.

        Returns:
            dict with ``logits`` (``[B]`` predictions) and ``loss`` (MSE, or
            ``None`` when ``labels`` is not given).
        """
        hidden_states = self.encoder(input_ids, attention_mask)  # [B, L, H]
        if attention_mask is not None:
            # Mean over non-padding positions ([CLS]/[SEP] included). The clamp
            # avoids dividing by zero for an all-padding row.
            mask = attention_mask.unsqueeze(-1).float()  # [B, L, 1]
            pooled = (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        else:
            pooled = hidden_states.mean(1)
        logits = self.regression_head(pooled).squeeze(-1)  # [B]
        loss = F.mse_loss(logits, labels.float()) if labels is not None else None
        return {'loss': loss, 'logits': logits}

    @staticmethod
    def _remap_checkpoint_key(key):
        """Map a pretrained checkpoint key to this model's parameter name.

        Examples:
            ``model.embeddings.word_embeddings.weight`` -> ``encoder.word_embeddings.weight``
            ``model.embeddings.post_ln.weight``         -> ``encoder.post_ln.weight``
            ``model.layers.X.*``                        -> ``encoder.layers.X.*``
            ``model.cls.*``                             -> ``encoder.cls.*``
        """
        if key.startswith('model.'):
            key = key[len('model.'):]
        # Only a *leading* 'embeddings.' is stripped.
        if key.startswith('embeddings.'):
            key = key[len('embeddings.'):]
        return 'encoder.' + key

    def load_pretrained_encoder(self, checkpoint_path):
        """Load pretrained CodonFM weights from a safetensors file into ``self.encoder``.

        Loading is non-strict because ``regression_head`` does not exist in the
        checkpoint. Missing and unexpected keys are reported so that a broken
        key mapping is visible. Any missing tensor outside the regression head
        keeps its freshly initialised value.

        Returns:
            ``(missing_keys, unexpected_keys)`` from ``load_state_dict``.
        """
        state_dict = load_file(checkpoint_path, device='cpu')
        remapped = {self._remap_checkpoint_key(k): v for k, v in state_dict.items()}
        missing, unexpected = self.load_state_dict(remapped, strict=False)

        missing_head = [k for k in missing if 'regression' in k]
        missing_other = [k for k in missing if 'regression' not in k]
        more = '...' if len(unexpected) > 5 else ''
        print("Loaded pretrained encoder weights.")
        print(f"  Missing (new regression head params): {missing_head}")
        print(f"  Missing (other): {missing_other}")
        print(f"  Unexpected ({len(unexpected)} keys): {unexpected[:5]}{more}")
        if missing_other:
            print(f"  WARNING: {len(missing_other)} encoder tensors were not found in the "
                  f"checkpoint and keep their freshly initialised values.")
        return missing, unexpected
# ============================================================
# 3. DATASET
# ============================================================
class mRNAStabilityDataset(Dataset):
    """Map-style dataset of (sequence, stability label) pairs.

    Each sequence is tokenized on access with ``tokenize_mRNA``, truncated to
    ``max_length`` tokens. Items are dicts with ``input_ids``,
    ``attention_mask`` and a float ``label``.
    """

    def __init__(self, sequences, labels, max_length=2046):
        self.sequences, self.labels, self.max_length = sequences, labels, max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence, target = self.sequences[idx], self.labels[idx]
        encoded = tokenize_mRNA(sequence, max_length=self.max_length)
        item = {}
        item['input_ids'] = encoded['input_ids']
        item['attention_mask'] = encoded['attention_mask']
        item['label'] = float(target)
        return item
def collate_fn(batch):
    """
    DataLoader collate function for ``mRNAStabilityDataset`` items.

    Pads ``input_ids`` and ``attention_mask`` to the longest sequence in the
    batch by delegating to ``pad_batch``, so padding is defined in one place.
    Padding uses ``PAD_TOKEN_ID`` for ids and 0 for the mask. The scalar
    labels are stacked.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (``torch.long``,
        ``[B, max_len]``) and ``labels`` (``torch.float32``, ``[B]``).

    Raises:
        ValueError: if ``batch`` is empty (from ``max``).
    """
    max_len = max(len(item['input_ids']) for item in batch)
    padded = pad_batch(batch, max_len, pad_id=PAD_TOKEN_ID)
    padded['labels'] = torch.tensor([item['label'] for item in batch], dtype=torch.float32)
    return padded
# ============================================================
# 4. TRAINING LOOP
# ============================================================
def evaluate(model, dataloader, device):
    """
    Evaluate a regression model over ``dataloader`` without gradient tracking.

    Switches ``model`` to eval mode and does not switch it back. Each batch
    must be a dict with ``input_ids``, ``attention_mask`` and ``labels``
    tensors. ``model(input_ids, attention_mask, labels)`` must return a dict
    with ``loss`` (batch mean) and ``logits`` (one prediction per sample).

    Returns:
        dict with
          - ``loss``: model loss averaged over *samples*. Batch losses are
            weighted by batch size, so a short final batch is not overweighted.
          - ``spearman`` / ``pearson``: correlation between predictions and
            labels. NaN with fewer than two samples.
          - ``mse``: mean squared error of the predictions.
        All values are NaN when the dataloader yields no samples.
    """
    model.eval()
    pred_chunks, label_chunks = [], []
    loss_sum, n_samples = 0.0, 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask, labels)
            batch_size = labels.shape[0]
            loss_sum += outputs['loss'].item() * batch_size
            n_samples += batch_size
            pred_chunks.append(outputs['logits'].float().reshape(-1).cpu())
            label_chunks.append(labels.float().reshape(-1).cpu())

    nan = float('nan')
    if n_samples == 0:
        return {'loss': nan, 'spearman': nan, 'pearson': nan, 'mse': nan}

    all_preds = torch.cat(pred_chunks).numpy().astype(np.float64)
    all_labels = torch.cat(label_chunks).numpy().astype(np.float64)
    if len(all_preds) >= 2:
        spearman_rho, _ = spearmanr(all_preds, all_labels)
        pearson_r, _ = pearsonr(all_preds, all_labels)
    else:
        # Correlation is undefined for a single sample (pearsonr would raise).
        spearman_rho = pearson_r = nan
    return {
        'loss': loss_sum / n_samples,
        'spearman': float(spearman_rho),
        'pearson': float(pearson_r),
        'mse': float(np.mean((all_preds - all_labels) ** 2)),
    }
# Short descriptions of the datasets this script can train on (used in the model card).
_DATASET_DESCRIPTIONS = {
    "mogam-ai/CDS-BART-mRNA-stability":
        "iCodon-based mRNA stability profiles from humans, mice, frogs, and fish",
    "GleghornLab/mrna_stability_other": "Additional mRNA stability data",
}

_CITATION_BIBTEX = """@article{diez2022icodon,
  title={iCodon customizes gene expression based on the codon composition},
  author={Diez, Michay and others},
  journal={Scientific Reports},
  year={2022}
}"""


def _filter_valid(seqs, labels, min_length=9):
    """Keep (sequence, label) pairs that have a usable sequence and label.

    A pair is dropped when the sequence is missing or shorter than
    ``min_length`` nucleotides (9 = three codons), or when the label is
    missing or NaN.
    """
    kept_seqs, kept_labels = [], []
    for seq, label in zip(seqs, labels):
        if seq is None or len(seq) < min_length:
            continue
        if label is None or math.isnan(float(label)):
            continue
        kept_seqs.append(seq)
        kept_labels.append(label)
    return kept_seqs, kept_labels


def _cosine_with_warmup(warmup_steps, total_steps):
    """Return a ``LambdaLR`` multiplier: linear warmup, then cosine decay to 0.

    Progress is clamped at 1, so the multiplier stays at 0 instead of rising
    again if more than ``total_steps`` steps are taken.
    """
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        progress = min(progress, 1.0)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return lr_lambda


def _save_checkpoint(model, out_dir, metadata):
    """Write the model state dict, config metadata and codon vocabulary to ``out_dir``."""
    os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(out_dir, "pytorch_model.bin"))
    with open(os.path.join(out_dir, "config.json"), 'w') as f:
        json.dump(metadata, f, indent=2)
    with open(os.path.join(out_dir, "codon_vocab.json"), 'w') as f:
        json.dump({
            "special_tokens": SPECIAL_TOKENS,
            "codon_to_id": CODON_TO_ID,
        }, f, indent=2)


def _build_model_card(*, dataset_names, split_sizes, num_frozen, num_layers,
                      learning_rate, num_epochs, batch_size, grad_accum,
                      warmup_steps, max_length, use_amp, test_metrics):
    """Render the Hub README (YAML front matter + Markdown) from the settings actually used.

    ``split_sizes`` maps each dataset name to ``{'train': n, 'val': n, 'test': n}``
    (raw split sizes). ``max_length`` counts the [CLS] and [SEP] tokens.
    """
    datasets_yaml = "\n".join(f"- {name}" for name in dataset_names)
    dataset_lines = "\n".join(
        f"- **[{name}](https://hf.co/datasets/{name})**: "
        f"{_DATASET_DESCRIPTIONS.get(name, 'mRNA stability data')} "
        f"({split_sizes[name]['train']:,} train / {split_sizes[name]['val']:,} val / "
        f"{split_sizes[name]['test']:,} test)"
        for name in dataset_names
    )
    if num_frozen > 0:
        strategy = (f"Freeze embeddings and the first {num_frozen} of {num_layers} transformer "
                    f"layers; train the last {num_layers - num_frozen} + regression head")
    else:
        strategy = f"Fine-tune all {num_layers} transformer layers + regression head"
    precision = "FP16 autocast" if use_amp else "disabled (FP32)"
    return f"""---
license: other
license_name: nvidia-open-model-license
tags:
- biology
- genomics
- mRNA
- stability-prediction
- codon
- fine-tuned
base_model: nvidia/NV-CodonFM-Encodon-80M-v1
datasets:
{datasets_yaml}
metrics:
- spearman_correlation
- pearson_correlation
- mse
---

# CodonFM-80M Fine-tuned for mRNA Stability Prediction

## Model Description

This model is a fine-tuned version of [NVIDIA NV-CodonFM-Encodon-80M-v1](https://hf.co/nvidia/NV-CodonFM-Encodon-80M-v1)
for predicting mRNA stability (half-life) from coding sequences (CDS).

- **Base model:** NV-CodonFM-Encodon-80M-v1 (80M parameter BERT-style Transformer with Rotary Position Embeddings)
- **Task:** Regression — predict mRNA stability score from codon sequence
- **Input:** mRNA coding sequence (codon-level tokenization, max {max_length - 2} codons; longer sequences are truncated)
- **Output:** Stability score (continuous float — higher = more stable)

## Training

### Datasets

{dataset_lines}

### Recipe

Based on [Helix-mRNA](https://arxiv.org/abs/2502.13785) and [BEACON](https://arxiv.org/abs/2406.10391):

- **Strategy:** {strategy}
- **Optimizer:** AdamW (backbone lr={learning_rate:g}, head lr={learning_rate * 10:g})
- **Epochs:** {num_epochs}
- **Batch size:** {batch_size} × {grad_accum} gradient accumulation = {batch_size * grad_accum} effective
- **Scheduler:** Cosine with {warmup_steps}-step warmup
- **Mixed precision:** {precision}

### Results

| Metric | Test Set |
|--------|----------|
| Spearman ρ | {test_metrics['spearman']:.4f} |
| Pearson r | {test_metrics['pearson']:.4f} |
| MSE | {test_metrics['mse']:.4f} |

### Literature Comparison

| Model | Spearman ρ (mRNA Stability) |
|-------|----------------------------|
| CodonBERT | 0.35 |
| XE | 0.50 |
| Helix-mRNA | 0.52 |
| HELM | 0.53 |
| **This model** | **{test_metrics['spearman']:.4f}** |

## Usage

```python
import torch
import json
from huggingface_hub import hf_hub_download
# Load model (see train script for full model class definition)
# ...
```

## Citation

If using this model, please cite the base CodonFM model and the iCodon dataset:

```bibtex
{_CITATION_BIBTEX}
```
"""


def train():
    """
    Fine-tune NV-CodonFM-Encodon-80M-v1 for mRNA stability regression and
    push the best checkpoint to the Hugging Face Hub.

    Settings are read from environment variables: HUB_MODEL_ID,
    LEARNING_RATE, NUM_EPOCHS, BATCH_SIZE, GRAD_ACCUM, MAX_LENGTH,
    WARMUP_STEPS, WEIGHT_DECAY, FREEZE_LAYERS, USE_BOTH_DATASETS, and
    OUTPUT_DIR (default ``/app/best_model``).

    The checkpoint with the best finite validation Spearman is saved to
    OUTPUT_DIR. If no epoch yields a finite score, the final weights are saved
    instead. That checkpoint is reloaded for the test evaluation and uploaded
    together with a generated model card and this script.
    """
    # ---- Config ----
    HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "Imranyai/CodonFM-80M-mRNA-stability")
    LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "5e-5"))
    NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "20"))
    BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "16"))
    GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "2"))
    # Max tokens per sequence *including* [CLS] and [SEP], i.e. MAX_LENGTH - 2 codons.
    MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "1024"))
    WARMUP_STEPS = int(os.environ.get("WARMUP_STEPS", "100"))
    WEIGHT_DECAY = float(os.environ.get("WEIGHT_DECAY", "0.01"))
    # Number of leading transformer layers to freeze (embeddings are frozen too when > 0).
    FREEZE_LAYERS = int(os.environ.get("FREEZE_LAYERS", "4"))
    USE_BOTH_DATASETS = os.environ.get("USE_BOTH_DATASETS", "true").lower() == "true"
    OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/app/best_model")
    if GRAD_ACCUM < 1:
        raise ValueError(f"GRAD_ACCUM must be >= 1, got {GRAD_ACCUM}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda'
    print(f"Device: {device}")
    print(f"Config: lr={LEARNING_RATE}, epochs={NUM_EPOCHS}, batch={BATCH_SIZE}, "
          f"grad_accum={GRAD_ACCUM}, max_len={MAX_LENGTH}, freeze_layers={FREEZE_LAYERS}")

    # ---- Init tracking ----
    trackio.init(
        project="codonfm-mrna-stability",
        name=f"lr{LEARNING_RATE}_ep{NUM_EPOCHS}_freeze{FREEZE_LAYERS}",
    )

    # ---- Load model config and pretrained weights ----
    config_path = hf_hub_download(
        repo_id="nvidia/NV-CodonFM-Encodon-80M-v1",
        filename="config.json",
    )
    with open(config_path) as f:
        config = json.load(f)
    print(f"Model config: {config}")

    model = CodonFMForStabilityPrediction(config)
    ckpt_path = hf_hub_download(
        repo_id="nvidia/NV-CodonFM-Encodon-80M-v1",
        filename="NV-CodonFM-Encodon-80M-v1.safetensors",
    )
    model.load_pretrained_encoder(ckpt_path)

    # ---- Freeze early layers (keep the last ones trainable) ----
    num_layers = len(model.encoder.layers)
    num_frozen = min(max(FREEZE_LAYERS, 0), num_layers)
    if num_frozen != FREEZE_LAYERS:
        print(f"WARNING: FREEZE_LAYERS={FREEZE_LAYERS} clamped to {num_frozen} "
              f"(model has {num_layers} transformer layers)")
    if num_frozen > 0:
        for module in (model.encoder.word_embeddings, model.encoder.post_ln,
                       *model.encoder.layers[:num_frozen]):
            module.requires_grad_(False)
    # The MLM head is not used by the regression forward pass and never gets
    # gradients; freeze it so it is not reported as trainable.
    model.encoder.cls.requires_grad_(False)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable_params:,} / {total_params:,} "
          f"({100 * trainable_params / total_params:.1f}%)")
    model = model.to(device)

    # ---- Load datasets ----
    print("\nLoading datasets...")
    ds1 = load_dataset("mogam-ai/CDS-BART-mRNA-stability")
    print(f"mogam-ai/CDS-BART-mRNA-stability: train={len(ds1['train'])}, "
          f"val={len(ds1['val'])}, test={len(ds1['test'])}")
    # (name, dataset, (train, val, test) split names, sequence column, label column)
    sources = [("mogam-ai/CDS-BART-mRNA-stability", ds1, ('train', 'val', 'test'), 'seq', 'y')]
    if USE_BOTH_DATASETS:
        ds2 = load_dataset("GleghornLab/mrna_stability_other")
        print(f"GleghornLab/mrna_stability_other: train={len(ds2['train'])}, "
              f"valid={len(ds2['valid'])}, test={len(ds2['test'])}")
        # The nucleotide sequence is in 'rna'. 'seqs' appears to be a different
        # encoding (TODO confirm against the dataset card).
        sources.append(("GleghornLab/mrna_stability_other", ds2,
                        ('train', 'valid', 'test'), 'rna', 'labels'))
    # NOTE(review): labels from different sources are concatenated as-is. Verify
    # they are on comparable scales; otherwise a single MSE head and the pooled
    # correlation mix incompatible targets.

    dataset_names = [name for name, *_ in sources]
    raw_splits = {'train': ([], []), 'val': ([], []), 'test': ([], [])}
    split_sizes = {}
    for name, ds, split_names, seq_col, label_col in sources:
        split_sizes[name] = {}
        for canonical, split in zip(('train', 'val', 'test'), split_names):
            seqs, labels = raw_splits[canonical]
            seqs.extend(ds[split][seq_col])
            labels.extend(ds[split][label_col])
            split_sizes[name][canonical] = len(ds[split])
    print(f"\nCombined dataset sizes: train={len(raw_splits['train'][0])}, "
          f"val={len(raw_splits['val'][0])}, test={len(raw_splits['test'][0])}")

    train_seqs, train_labels = _filter_valid(*raw_splits['train'])
    val_seqs, val_labels = _filter_valid(*raw_splits['val'])
    test_seqs, test_labels = _filter_valid(*raw_splits['test'])
    print(f"After filtering: train={len(train_seqs)}, val={len(val_seqs)}, test={len(test_seqs)}")
    if not train_seqs:
        raise RuntimeError("No training examples left after filtering")

    train_labels_arr = np.array(train_labels, dtype=np.float64)
    print(f"Label stats (train): mean={train_labels_arr.mean():.3f}, std={train_labels_arr.std():.3f}, "
          f"min={train_labels_arr.min():.3f}, max={train_labels_arr.max():.3f}")

    train_dataset = mRNAStabilityDataset(train_seqs, train_labels, max_length=MAX_LENGTH)
    val_dataset = mRNAStabilityDataset(val_seqs, val_labels, max_length=MAX_LENGTH)
    test_dataset = mRNAStabilityDataset(test_seqs, test_labels, max_length=MAX_LENGTH)
    pin = device.type == 'cuda'  # pinned memory only helps host->GPU copies
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=collate_fn, num_workers=2, pin_memory=pin)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False,
                            collate_fn=collate_fn, num_workers=2, pin_memory=pin)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2, shuffle=False,
                             collate_fn=collate_fn, num_workers=2, pin_memory=pin)

    # ---- Optimizer & scheduler (differential LR: new head learns 10x faster) ----
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            (head_params if 'regression_head' in name else backbone_params).append(param)
    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': LEARNING_RATE},
        {'params': head_params, 'lr': LEARNING_RATE * 10},
    ], weight_decay=WEIGHT_DECAY)

    batches_per_epoch = len(train_loader)
    # One optimizer step per GRAD_ACCUM micro-batches, plus one for a trailing
    # partial window, so this matches the number of scheduler.step() calls.
    total_steps = math.ceil(batches_per_epoch / GRAD_ACCUM) * NUM_EPOCHS
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=_cosine_with_warmup(WARMUP_STEPS, total_steps)
    )
    scaler = torch.amp.GradScaler('cuda') if use_amp else None

    # ---- Training ----
    print(f"\n{'=' * 60}")
    print(f"Starting training for {NUM_EPOCHS} epochs")
    print(f"Total steps: {total_steps}, Warmup: {WARMUP_STEPS}")
    print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
    print(f"{'=' * 60}\n")

    best_val_spearman = float('-inf')
    best_epoch = -1
    global_step = 0

    def checkpoint_metadata():
        # Reads best_val_spearman / best_epoch at call time.
        return {
            **config,
            "task": "mRNA_stability_regression",
            "freeze_layers": num_frozen,
            "max_length": MAX_LENGTH,
            "best_val_spearman": best_val_spearman,
            "best_epoch": best_epoch,
            "datasets": dataset_names,
        }

    for epoch in range(NUM_EPOCHS):
        model.train()
        epoch_loss = 0.0
        n_batches = 0
        optimizer.zero_grad()
        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            if use_amp:
                with torch.amp.autocast('cuda'):
                    outputs = model(input_ids, attention_mask, labels)
                    loss = outputs['loss'] / GRAD_ACCUM
                scaler.scale(loss).backward()
            else:
                outputs = model(input_ids, attention_mask, labels)
                loss = outputs['loss'] / GRAD_ACCUM
                loss.backward()
            epoch_loss += outputs['loss'].item()
            n_batches += 1

            # Step at the end of every accumulation window *and* on the last
            # batch of the epoch. Otherwise a trailing partial window is wiped
            # by the zero_grad() at the start of the next epoch. That window's
            # gradient is scaled by (its size / GRAD_ACCUM).
            window_done = (batch_idx + 1) % GRAD_ACCUM == 0
            if window_done or batch_idx + 1 == batches_per_epoch:
                if use_amp:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                if global_step % 50 == 0:
                    avg_loss = epoch_loss / n_batches
                    current_lr = optimizer.param_groups[0]['lr']
                    print(f"  Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
                    trackio.log({
                        "train/loss": avg_loss,
                        "train/lr": current_lr,
                        "train/step": global_step,
                    })

        avg_train_loss = epoch_loss / max(n_batches, 1)
        val_metrics = evaluate(model, val_loader, device)
        print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {val_metrics['loss']:.4f} | Spearman: {val_metrics['spearman']:.4f} | "
              f"Pearson: {val_metrics['pearson']:.4f} | MSE: {val_metrics['mse']:.4f}")
        trackio.log({
            "train/epoch_loss": avg_train_loss,
            "val/loss": val_metrics['loss'],
            "val/spearman": val_metrics['spearman'],
            "val/pearson": val_metrics['pearson'],
            "val/mse": val_metrics['mse'],
            "epoch": epoch + 1,
        })

        # NaN compares False with everything, so skip it explicitly.
        val_rho = val_metrics['spearman']
        if not math.isnan(val_rho) and val_rho > best_val_spearman:
            best_val_spearman = val_rho
            best_epoch = epoch + 1
            _save_checkpoint(model, OUTPUT_DIR, checkpoint_metadata())
            print(f"  ★ New best model! Spearman: {best_val_spearman:.4f} (epoch {best_epoch})")

    if best_epoch < 0:
        # No finite validation Spearman (e.g. constant predictions or NUM_EPOCHS=0).
        # Keep the final weights so the reload below has a file to read.
        print("WARNING: no epoch produced a finite validation Spearman; saving the final model.")
        best_val_spearman = float('nan')
        best_epoch = NUM_EPOCHS
        _save_checkpoint(model, OUTPUT_DIR, checkpoint_metadata())

    # ---- Final test evaluation ----
    print(f"\n{'=' * 60}")
    print(f"Loading best model from epoch {best_epoch}")
    model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, "pytorch_model.bin"),
                                     map_location=device, weights_only=True))
    test_metrics = evaluate(model, test_loader, device)
    print("\nFinal Test Results:")
    print(f"  Loss: {test_metrics['loss']:.4f}")
    print(f"  Spearman ρ: {test_metrics['spearman']:.4f}")
    print(f"  Pearson r: {test_metrics['pearson']:.4f}")
    print(f"  MSE: {test_metrics['mse']:.4f}")
    trackio.log({
        "test/loss": test_metrics['loss'],
        "test/spearman": test_metrics['spearman'],
        "test/pearson": test_metrics['pearson'],
        "test/mse": test_metrics['mse'],
    })

    # ---- Push to Hub ----
    print(f"\nPushing model to Hub: {HUB_MODEL_ID}")
    api = HfApi()
    try:
        api.create_repo(repo_id=HUB_MODEL_ID, exist_ok=True)
    except Exception as e:  # best effort: upload below reports real failures
        print(f"Repo creation note: {e}")

    model_card = _build_model_card(
        dataset_names=dataset_names,
        split_sizes=split_sizes,
        num_frozen=num_frozen,
        num_layers=num_layers,
        learning_rate=LEARNING_RATE,
        num_epochs=NUM_EPOCHS,
        batch_size=BATCH_SIZE,
        grad_accum=GRAD_ACCUM,
        warmup_steps=WARMUP_STEPS,
        max_length=MAX_LENGTH,
        use_amp=use_amp,
        test_metrics=test_metrics,
    )
    with open(os.path.join(OUTPUT_DIR, "README.md"), 'w') as f:
        f.write(model_card)

    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=HUB_MODEL_ID,
        commit_message=f"Upload fine-tuned CodonFM-80M for mRNA stability (Spearman={test_metrics['spearman']:.4f})"
    )
    # Upload the running script itself (not a hard-coded path) for reproducibility.
    api.upload_file(
        path_or_fileobj=os.path.abspath(__file__),
        path_in_repo="train_codonfm_stability.py",
        repo_id=HUB_MODEL_ID,
        commit_message="Upload training script"
    )
    print(f"\n✅ Model pushed to: https://hf.co/{HUB_MODEL_ID}")
    print(f"Best validation Spearman: {best_val_spearman:.4f} (epoch {best_epoch})")
    print(f"Test Spearman: {test_metrics['spearman']:.4f}")
    print("Done!")


if __name__ == "__main__":
    train()