Text Classification
biology
genomics
mRNA
stability-prediction
codon
fine-tuned
regression
Imranyai's picture
Add data_setup.py — dataset setup, preprocessing & documentation
ea558c0 verified
"""
data_setup.py — Download, preprocess, audit, and prepare all datasets for
CodonFM-80M mRNA stability fine-tuning and benchmarking.
Usage:
# Download, preprocess, and audit everything (also prints the vocabulary)
python data_setup.py --all
# Download only training datasets
python data_setup.py --training
# Download only benchmark datasets
python data_setup.py --benchmark
# Clean, filter, and deduplicate the training data
python data_setup.py --preprocess
# Audit datasets (inspect stats, find issues)
python data_setup.py --audit
# Export preprocessed training data to local CSV files
python data_setup.py --training --export ./processed_data
# Show codon vocabulary and tokenizer details
python data_setup.py --vocab
"""
import argparse
import json
import os
import sys
import csv
import urllib.request
from collections import Counter
from pathlib import Path
import numpy as np
try:
import pandas as pd
except ImportError:
pd = None
try:
from datasets import load_dataset
except ImportError:
load_dataset = None
# ============================================================
# 1. CODON VOCABULARY & TOKENIZER
# ============================================================
RNA_BASES = ['A', 'U', 'G', 'C']
# All 64 codons, first base varying slowest. This order defines the codon token
# IDs in CODON_TO_ID below, so reordering it changes the vocabulary.
ALL_CODONS = [b1 + b2 + b3 for b1 in RNA_BASES for b2 in RNA_BASES for b3 in RNA_BASES]
# Standard genetic code: RNA codon → three-letter amino acid ('Stop' for stop codons).
CODON_TABLE = {
    'UUU': 'Phe', 'UUC': 'Phe', 'UUA': 'Leu', 'UUG': 'Leu',
    'CUU': 'Leu', 'CUC': 'Leu', 'CUA': 'Leu', 'CUG': 'Leu',
    'AUU': 'Ile', 'AUC': 'Ile', 'AUA': 'Ile', 'AUG': 'Met/Start',
    'GUU': 'Val', 'GUC': 'Val', 'GUA': 'Val', 'GUG': 'Val',
    'UCU': 'Ser', 'UCC': 'Ser', 'UCA': 'Ser', 'UCG': 'Ser',
    'CCU': 'Pro', 'CCC': 'Pro', 'CCA': 'Pro', 'CCG': 'Pro',
    'ACU': 'Thr', 'ACC': 'Thr', 'ACA': 'Thr', 'ACG': 'Thr',
    'GCU': 'Ala', 'GCC': 'Ala', 'GCA': 'Ala', 'GCG': 'Ala',
    'UAU': 'Tyr', 'UAC': 'Tyr', 'UAA': 'Stop', 'UAG': 'Stop',
    'CAU': 'His', 'CAC': 'His', 'CAA': 'Gln', 'CAG': 'Gln',
    'AAU': 'Asn', 'AAC': 'Asn', 'AAA': 'Lys', 'AAG': 'Lys',
    'GAU': 'Asp', 'GAC': 'Asp', 'GAA': 'Glu', 'GAG': 'Glu',
    'UGU': 'Cys', 'UGC': 'Cys', 'UGA': 'Stop', 'UGG': 'Trp',
    'CGU': 'Arg', 'CGC': 'Arg', 'CGA': 'Arg', 'CGG': 'Arg',
    'AGU': 'Ser', 'AGC': 'Ser', 'AGA': 'Arg', 'AGG': 'Arg',
    'GGU': 'Gly', 'GGC': 'Gly', 'GGA': 'Gly', 'GGG': 'Gly',
}
# Token vocabulary (matches CodonFM config: vocab_size=69, pad_token_id=3).
# Special tokens take IDs 0-4; codon IDs follow contiguously in ALL_CODONS order.
SPECIAL_TOKENS = {'[CLS]': 0, '[SEP]': 1, '[MASK]': 2, '[PAD]': 3, '[UNK]': 4}
CODON_TO_ID = {codon: i + len(SPECIAL_TOKENS) for i, codon in enumerate(ALL_CODONS)}
# Reverse lookup covering both codon IDs and special-token IDs.
ID_TO_CODON = {v: k for k, v in CODON_TO_ID.items()}
ID_TO_CODON.update({v: k for k, v in SPECIAL_TOKENS.items()})
VOCAB_SIZE = len(SPECIAL_TOKENS) + len(ALL_CODONS)  # 5 + 64 = 69
assert VOCAB_SIZE == 69
assert set(CODON_TABLE) == set(ALL_CODONS)  # every codon has an amino-acid entry
# The named IDs are looked up in SPECIAL_TOKENS instead of being repeated as
# literals, so the two definitions cannot drift apart.
PAD_TOKEN_ID = SPECIAL_TOKENS['[PAD]']
CLS_TOKEN_ID = SPECIAL_TOKENS['[CLS]']
SEP_TOKEN_ID = SPECIAL_TOKENS['[SEP]']
MASK_TOKEN_ID = SPECIAL_TOKENS['[MASK]']
UNK_TOKEN_ID = SPECIAL_TOKENS['[UNK]']
def seq_to_codons(seq: str) -> list:
    """Return the complete codon triplets of *seq*.

    The sequence is upper-cased, DNA ``T`` is rewritten as RNA ``U`` and
    surrounding whitespace is stripped. Any trailing partial codon (1-2 nt)
    is discarded.
    """
    normalized = seq.upper().replace('T', 'U').strip()
    usable = len(normalized) // 3 * 3  # length rounded down to whole codons
    return [normalized[start:start + 3] for start in range(0, usable, 3)]
def tokenize_mRNA(seq: str, max_length: int = 2046) -> dict:
    """Tokenize an mRNA/DNA sequence into CodonFM token IDs.

    The sequence is split with seq_to_codons() and wrapped as
    ``[CLS] codon ... codon [SEP]``. It is truncated so the result holds at
    most *max_length* tokens. Codons not in the vocabulary map to [UNK].

    Args:
        seq: nucleotide sequence (RNA or DNA, case-insensitive).
        max_length: maximum number of tokens including [CLS] and [SEP];
            must be at least 2.

    Returns:
        dict with 'input_ids' and an all-ones 'attention_mask' of equal length.

    Raises:
        ValueError: if max_length < 2.
    """
    if max_length < 2:
        # codons[:max_length - 2] would become a negative slice here and keep
        # (almost) the whole sequence instead of truncating it.
        raise ValueError(f"max_length must be >= 2 to fit [CLS] and [SEP], got {max_length}")
    codons = seq_to_codons(seq)
    token_ids = [CLS_TOKEN_ID]
    token_ids.extend(CODON_TO_ID.get(codon, UNK_TOKEN_ID) for codon in codons[:max_length - 2])
    token_ids.append(SEP_TOKEN_ID)
    return {'input_ids': token_ids, 'attention_mask': [1] * len(token_ids)}
def validate_sequence(seq: str) -> dict:
    """Check whether *seq* can be tokenized cleanly by CodonFM.

    The sequence is upper-cased, T is converted to U and whitespace is
    stripped before checking. It is reported valid only when no issue is
    found. An issue is any of: an empty sequence, a length that is not a
    multiple of 3, characters outside A/U/G/C, a stop codon before the last
    position, or a codon not in the vocabulary. Starting with AUG and ending
    with a stop codon are reported but do not affect validity.

    Returns:
        dict with keys 'valid', 'issues', 'length_nt', 'length_codons',
        'starts_with_AUG', 'ends_with_stop', 'n_internal_stops'.
    """
    normalized = seq.upper().replace('T', 'U').strip()
    length_nt = len(normalized)
    remainder = length_nt % 3
    problems = []
    if not normalized:
        problems.append("Empty sequence")
    if remainder:
        problems.append(f"Length {length_nt} not divisible by 3 (truncated to {length_nt - remainder})")
    bad_chars = set(normalized) - {'A', 'U', 'G', 'C'}
    if bad_chars:
        problems.append(f"Invalid characters: {bad_chars}")

    codons = seq_to_codons(normalized)
    has_codons = bool(codons)
    terminators = {'UAA', 'UAG', 'UGA'}
    # The final codon is excluded: a stop there is expected, not an error.
    internal_stop_positions = [pos for pos, codon in enumerate(codons[:-1]) if codon in terminators]
    if internal_stop_positions:
        problems.append(f"Internal stop codons at positions: {internal_stop_positions}")
    unknown = [codon for codon in codons if codon not in CODON_TO_ID]
    if unknown:
        problems.append(f"Unknown codons: {set(unknown)}")

    return {
        'valid': not problems,
        'issues': problems,
        'length_nt': length_nt,
        'length_codons': len(codons),
        'starts_with_AUG': has_codons and codons[0] == 'AUG',
        'ends_with_stop': has_codons and codons[-1] in terminators,
        'n_internal_stops': len(internal_stop_positions),
    }
# ============================================================
# 2. TRAINING DATASETS
# ============================================================
# HuggingFace Hub training datasets, keyed by repo id.
# download_training_data() reads 'description', 'splits', 'seq_col' and
# 'label_col' from each entry. 'splits' maps this script's split names
# (train/val/test) to the split names used inside the HF repo. The remaining
# keys record provenance and are not used programmatically here.
TRAINING_DATASETS = {
    'mogam-ai/CDS-BART-mRNA-stability': {
        'description': 'iCodon vertebrate mRNA stability profiles (human, mouse, frog, fish)',
        'source_paper': 'Diez et al. 2022, Scientific Reports "iCodon customizes gene expression based on the codon composition"',
        'seq_col': 'seq',  # HF column holding the nucleotide sequence
        'label_col': 'y',  # HF column holding the regression target
        'splits': {'train': 'train', 'val': 'val', 'test': 'test'},
        'label_meaning': 'mRNA half-life z-score (higher = more stable, meanβ‰ˆ0, stdβ‰ˆ1)',
        'species': ['Human', 'Mouse', 'Xenopus (frog)', 'Zebrafish'],
        'notes': 'RNA sequences (A,U,G,C). All divisible by 3. Subset of GleghornLab dataset.',
    },
    'GleghornLab/mrna_stability_other': {
        'description': 'Extended multi-species mRNA stability data (superset of mogam-ai dataset)',
        'source_paper': 'Li et al. 2024, Genome Research "CodonBERT large language model for mRNA vaccines"',
        'seq_col': 'rna',
        'label_col': 'labels',
        # This repo names its validation split 'valid'.
        'splits': {'train': 'train', 'val': 'valid', 'test': 'test'},
        'label_meaning': 'mRNA half-life z-score (higher = more stable)',
        'species': ['Multiple vertebrate species'],
        'notes': 'Has extra "seqs" column (protein-encoded, not used). Contains 1 outlier sequence of 3 nt. Superset of mogam-ai.',
        'extra_col': 'seqs',  # not read by this script's code
    },
}
def download_training_data(export_dir=None):
    """Download the training datasets from the HuggingFace Hub and print summary stats.

    Each repo in TRAINING_DATASETS is loaded with load_dataset(). For every
    split, this prints the sample count, min/mean/max sequence length (in
    nucleotides and in complete codons), and label statistics.

    Args:
        export_dir: optional directory. When given, the training data is also
            run through preprocess_training_data() and written as CSVs via
            export_training_data().

    Returns:
        dict mapping repo id to the loaded DatasetDict, or None when the
        `datasets` library is not installed.
    """
    if load_dataset is None:
        print("ERROR: `datasets` library required. Run: pip install datasets")
        return None
    banner = '=' * 60
    all_data = {}
    for repo_id, info in TRAINING_DATASETS.items():
        print(f"\n{banner}")
        print(f"📦 {repo_id}")
        print(f" {info['description']}")
        print(banner)
        ds = load_dataset(repo_id)
        for split_name, hf_split in info['splits'].items():
            split_data = ds[hf_split]
            seqs = split_data[info['seq_col']]
            labels = split_data[info['label_col']]
            print(f"\n [{split_name}] {len(seqs)} samples")
            if len(seqs) == 0:
                # min()/max() over an empty split would raise ValueError.
                continue
            nt_lengths = np.array([len(s) for s in seqs])
            codon_lengths = nt_lengths // 3
            print(f" Seq lengths (nt): min={nt_lengths.min()}, "
                  f"mean={nt_lengths.mean():.0f}, "
                  f"max={nt_lengths.max()}")
            print(f" Seq lengths (cod): min={codon_lengths.min()}, "
                  f"mean={codon_lengths.mean():.0f}, "
                  f"max={codon_lengths.max()}")
            # nan-aware reductions: a single missing label would otherwise
            # turn every statistic into nan.
            labels_arr = np.array(labels, dtype=float)
            print(f" Labels: min={np.nanmin(labels_arr):.3f}, mean={np.nanmean(labels_arr):.3f}, "
                  f"std={np.nanstd(labels_arr):.3f}, max={np.nanmax(labels_arr):.3f}")
        all_data[repo_id] = ds
    if export_dir:
        # export_training_data() only accepts the preprocessed
        # {'train': {'sequences': ..., 'labels': ...}, ...} layout and rejects
        # the raw DatasetDicts collected above, so preprocess first.
        export_training_data(preprocess_training_data(), export_dir)
    return all_data
def _load_raw_training_splits(use_both_datasets):
    """Load the chosen HF stability dataset as {split: {'seqs': [...], 'labels': [...]}}."""
    print("Loading datasets...")
    if use_both_datasets:
        # GleghornLab is a superset of mogam-ai (it contains all of mogam-ai
        # plus extra data), so it is used on its own rather than concatenated.
        repo_id = "GleghornLab/mrna_stability_other"
        note = " Using GleghornLab/mrna_stability_other (superset)"
    else:
        # mogam-ai only (smaller, cleaner)
        repo_id = "mogam-ai/CDS-BART-mRNA-stability"
        note = " Using mogam-ai/CDS-BART-mRNA-stability only"
    # Column and split names come from the shared table so the two stay in sync.
    info = TRAINING_DATASETS[repo_id]
    ds = load_dataset(repo_id)
    raw_data = {
        split: {'seqs': ds[hf_split][info['seq_col']], 'labels': ds[hf_split][info['label_col']]}
        for split, hf_split in info['splits'].items()
    }
    print(note)
    return raw_data


def _clean_split(seqs, labels, min_codons, max_codons, remove_duplicates):
    """Normalize and filter one split; return (sequences, labels, removal counts)."""
    valid_bases = {'A', 'U', 'G', 'C'}
    removed = {'short': 0, 'long': 0, 'nan': 0, 'invalid': 0, 'duplicate': 0}
    clean_seqs, clean_labels = [], []
    seen = set()
    for seq, label in zip(seqs, labels):
        # Skip None/empty
        if seq is None or len(seq) == 0:
            removed['invalid'] += 1
            continue
        # Normalize: uppercase, T→U
        seq = seq.upper().replace('T', 'U').strip()
        # Missing label: None is treated like NaN (np.isnan(None) raises TypeError).
        if label is None or np.isnan(label):
            removed['nan'] += 1
            continue
        if set(seq) - valid_bases:
            removed['invalid'] += 1
            continue
        n_codons = len(seq) // 3
        if n_codons < min_codons:
            removed['short'] += 1
            continue
        if n_codons > max_codons:
            removed['long'] += 1
            continue
        if remove_duplicates:
            # Keep the first occurrence; later copies are dropped whatever their label.
            if seq in seen:
                removed['duplicate'] += 1
                continue
            seen.add(seq)
        clean_seqs.append(seq)
        clean_labels.append(float(label))
    return clean_seqs, clean_labels, removed


def _remove_train_leakage(clean_data):
    """Report split overlaps and drop train rows whose sequence is also in val/test.

    Mutates clean_data['train'] in place and returns the number of train rows removed.
    Val/test overlap is only reported.
    """
    test_seqs = set(clean_data['test']['sequences'])
    val_seqs = set(clean_data['val']['sequences'])
    train = clean_data['train']
    leakage_test = sum(1 for s in train['sequences'] if s in test_seqs)
    leakage_val = sum(1 for s in train['sequences'] if s in val_seqs)
    val_test_overlap = len(val_seqs & test_seqs)
    print(f"\n Data leakage check:")
    print(f" Train→Test overlap: {leakage_test} sequences")
    print(f" Train→Val overlap: {leakage_val} sequences")
    print(f" Val→Test overlap: {val_test_overlap} sequences")
    if not (leakage_test or leakage_val):
        return 0
    print(f" ⚠️ WARNING: Data leakage detected! Removing leaked sequences from train...")
    eval_seqs = test_seqs | val_seqs
    kept = [(s, lab) for s, lab in zip(train['sequences'], train['labels']) if s not in eval_seqs]
    n_removed = len(train['sequences']) - len(kept)
    train['sequences'] = [s for s, _ in kept]
    train['labels'] = [lab for _, lab in kept]
    print(f" Train after dedup: {len(train['sequences'])} samples")
    return n_removed


def preprocess_training_data(use_both_datasets=True, min_codons=3, max_codons=2046,
                             remove_duplicates=True, deduplicate_across_splits=True):
    """
    Preprocess training data: clean, filter, and deduplicate.

    Steps (a row is removed by the first filter it fails):
      1. Load GleghornLab/mrna_stability_other when use_both_datasets is True.
         That dataset is a superset of the mogam-ai data, so the two are not
         concatenated. Otherwise load mogam-ai/CDS-BART-mRNA-stability.
      2. Drop None/empty sequences ('invalid').
      3. Normalize sequences: uppercase, T -> U, strip whitespace.
      4. Drop rows whose label is None or NaN ('nan').
      5. Drop sequences containing characters other than A/U/G/C ('invalid').
      6. Drop sequences with fewer than min_codons or more than max_codons
         complete codons, len // 3 ('short' / 'long').
      7. If remove_duplicates, keep only the first occurrence of each
         sequence within a split ('duplicate').
      8. If deduplicate_across_splits, report train/val/test overlaps and
         remove from train any sequence also present in val or test
         ('leaked'). Val/test overlap is only reported.

    Returns:
        dict with 'train', 'val', 'test' keys, each containing 'sequences' and
        'labels' lists.

    Raises:
        ImportError: if the `datasets` library is not installed.
    """
    if load_dataset is None:
        raise ImportError("datasets library required: pip install datasets")
    raw_data = _load_raw_training_splits(use_both_datasets)
    clean_data = {}
    total_removed = {'short': 0, 'long': 0, 'nan': 0, 'invalid': 0, 'duplicate': 0, 'leaked': 0}
    for split in ['train', 'val', 'test']:
        seqs = raw_data[split]['seqs']
        labels = raw_data[split]['labels']
        orig_count = len(seqs)
        clean_seqs, clean_labels, removed = _clean_split(
            seqs, labels, min_codons, max_codons, remove_duplicates)
        for reason, count in removed.items():
            total_removed[reason] += count
        clean_data[split] = {'sequences': clean_seqs, 'labels': clean_labels}
        print(f" [{split}] {orig_count} → {len(clean_seqs)} samples "
              f"(removed {orig_count - len(clean_seqs)})")
    if deduplicate_across_splits:
        total_removed['leaked'] += _remove_train_leakage(clean_data)
    print(f"\n Removal summary: {total_removed}")
    print(f" Final sizes: train={len(clean_data['train']['sequences'])}, "
          f"val={len(clean_data['val']['sequences'])}, "
          f"test={len(clean_data['test']['sequences'])}")
    return clean_data
def export_training_data(data, export_dir):
    """Write preprocessed splits to ``<export_dir>/<split>.csv``.

    *data* must use the layout returned by preprocess_training_data():
    ``{'train': {'sequences': [...], 'labels': [...]}, 'val': ..., 'test': ...}``.
    Each split that is present becomes a CSV with the header
    ``sequence,stability_score``; absent splits are skipped. Any other input
    (e.g. raw HF DatasetDicts) is not exported and a hint is printed instead.
    The export directory is created in either case.
    """
    os.makedirs(export_dir, exist_ok=True)
    is_preprocessed = (isinstance(data, dict) and 'train' in data
                       and 'sequences' in data.get('train', {}))
    if not is_preprocessed:
        print(" Export requires preprocessed data. Run preprocess_training_data() first.")
        return
    for split in (name for name in ('train', 'val', 'test') if name in data):
        split_data = data[split]
        out_path = os.path.join(export_dir, f'{split}.csv')
        with open(out_path, 'w', newline='') as handle:
            csv_out = csv.writer(handle)
            csv_out.writerow(['sequence', 'stability_score'])
            csv_out.writerows(zip(split_data['sequences'], split_data['labels']))
        print(f" Exported {split}: {len(split_data['sequences'])} rows β†’ {out_path}")
# ============================================================
# 3. BENCHMARK DATASETS
# ============================================================
# Base URL of the CodonBERT fine-tuning CSVs (raw GitHub content).
CODONBERT_BASE_URL = "https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/master/benchmarks/CodonBERT/data/fine-tune"
# Benchmark tasks keyed by short task name.
# download_benchmark_data() uses 'url' and 'filename'; run_full_audit() uses
# 'filename'. The other fields are descriptive metadata that no code in this
# file reads. 'samples', 'seq_length' and 'label' are recorded figures, not
# values computed here.
BENCHMARK_DATASETS = {
    'stability': {
        'url': f"{CODONBERT_BASE_URL}/mRNA_Stability.csv",
        'filename': 'mRNA_Stability.csv',
        'description': 'mRNA Stability (iCodon vertebrate mRNA half-life)',
        'source': 'Diez et al. 2022, Scientific Reports',
        'samples': 65356,
        'seq_length': '3-3066 nt (1-1022 codons)',
        'label': 'Half-life z-score (continuous, meanβ‰ˆ0, stdβ‰ˆ1)',
        'metric': 'Spearman ρ',
        'columns': 'sequence, value, dataset, split',
        'species': 'Multi-vertebrate (human, mouse, frog, fish)',
    },
    'mrfp': {
        'url': f"{CODONBERT_BASE_URL}/mRFP_Expression.csv",
        'filename': 'mRFP_Expression.csv',
        'description': 'mRFP Protein Expression in E. coli',
        'source': 'Li et al. 2024, Genome Research (CodonBERT)',
        'samples': 1459,
        'seq_length': '678 nt (226 codons, fixed)',
        'label': 'Fluorescence intensity (log scale, range 7.4-11.4)',
        'metric': 'Spearman ρ',
        'columns': 'sequence, value, dataset, split',
        'species': 'E. coli (synthetic mRFP variants)',
    },
    'vaccine': {
        'url': f"{CODONBERT_BASE_URL}/CoV_Vaccine_Degradation.csv",
        'filename': 'CoV_Vaccine_Degradation.csv',
        'description': 'SARS-CoV-2 mRNA Vaccine Degradation',
        'source': 'CodonBERT benchmark (derived from Stanford OpenVaccine)',
        'samples': 2400,
        'seq_length': '81 nt (27 codons, fixed)',
        'label': 'Degradation score (z-normalized, range -7.2 to 6.5)',
        'metric': 'Spearman ρ',
        'columns': 'sequence, value, dataset, split',
        'species': 'Synthetic SARS-CoV-2 mRNA vaccine fragments',
    },
    'riboswitch': {
        'url': f"{CODONBERT_BASE_URL}/Tc-Riboswitches.csv",
        'filename': 'Tc-Riboswitches.csv',
        'description': 'Tetracycline Riboswitch Activity',
        'source': 'Groher et al. 2018 (via CodonBERT)',
        'samples': 355,
        'seq_length': '66-75 nt (22-25 codons)',
        'label': 'Switching factor (continuous, range -0.3 to 3.1)',
        'metric': 'Spearman ρ',
        'columns': 'sequence, value, dataset, split',
        'species': 'Synthetic tetracycline riboswitches',
    },
    # MLOS differs from the other tasks: its sequence column is 'cds' and it
    # has no predefined split column (run_full_audit() falls back to 'cds').
    'mlos': {
        'url': f"{CODONBERT_BASE_URL}/MLOS.csv",
        'filename': 'MLOS.csv',
        'description': 'MLOS Flu Vaccine Antigen Expression',
        'source': 'Ren et al. 2024 (HELM/MLOS)',
        'samples': 167,
        'seq_length': '~1700 nt (~567 codons)',
        'label': 'Expression level (continuous, range 0.3-2.2)',
        'metric': 'Spearman ρ',
        'columns': 'cds, value (no split column β€” uses random 70/15/15)',
        'species': 'Influenza haemagglutinin CDS variants',
        'notes': 'No pre-defined splits. Column name is "cds" not "sequence".',
    },
}
def download_benchmark_data(data_dir='./benchmark_data'):
    """Download every file listed in BENCHMARK_DATASETS into *data_dir*.

    Files that already exist are skipped. Each download goes to a temporary
    ``.part`` path and is renamed into place only after it completes. An
    interrupted or failed download therefore never leaves a truncated CSV that
    a later run would mistake for a finished one and skip. Failures are
    reported and the remaining files are still attempted (best effort).

    Returns:
        data_dir, unchanged.
    """
    os.makedirs(data_dir, exist_ok=True)
    for info in BENCHMARK_DATASETS.values():
        filename = info['filename']
        filepath = os.path.join(data_dir, filename)
        if os.path.exists(filepath):
            size = os.path.getsize(filepath)
            print(f" ✓ {filename} already exists ({size/1024:.1f} KB)")
            continue
        print(f" ↓ Downloading {filename}...")
        tmp_path = filepath + '.part'
        try:
            urllib.request.urlretrieve(info['url'], tmp_path)
            os.replace(tmp_path, filepath)  # atomic rename: filepath is either complete or absent
        except Exception as e:  # best effort: report and continue with the next file
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # no partial file was created (or it cannot be removed); filepath was never written
            print(f" ✗ Failed to download {filename}: {e}")
            continue
        size = os.path.getsize(filepath)
        print(f" ✓ Downloaded {filename} ({size/1024:.1f} KB)")
    return data_dir
# ============================================================
# 4. AUDIT
# ============================================================
def audit_dataset(sequences, labels, name="dataset"):
    """Print a human-readable quality audit of *sequences* and *labels*.

    Reports the following; nothing is returned, all output goes to stdout:
    - sequence-length statistics and a codon-length histogram;
    - nucleotide composition;
    - codon usage (first 5000 sequences);
    - start/stop codon counts;
    - label distribution;
    - duplicate count and outliers.

    Args:
        sequences: list of nucleotide strings (RNA or DNA).
        labels: numeric labels aligned with *sequences*. None/NaN entries are
            counted and excluded from the label statistics.
        name: title printed in the header.
    """
    print(f"\n{'='*60}")
    print(f"AUDIT: {name} ({len(sequences)} sequences)")
    print(f"{'='*60}")
    if len(sequences) == 0:
        print(" (empty)")
        return
    n_seqs = len(sequences)

    # ---- Sequence stats ----
    lengths_nt = [len(s) for s in sequences]
    lengths_codon = [n // 3 for n in lengths_nt]
    print(f"\n 📏 Sequence Lengths:")
    print(f" Nucleotides: min={min(lengths_nt)}, mean={np.mean(lengths_nt):.0f}, "
          f"median={np.median(lengths_nt):.0f}, max={max(lengths_nt)}")
    print(f" Codons: min={min(lengths_codon)}, mean={np.mean(lengths_codon):.0f}, "
          f"median={np.median(lengths_codon):.0f}, max={max(lengths_codon)}")
    # Length distribution buckets (np.histogram's last bin is closed: [3000, 10000])
    buckets = [0, 100, 300, 500, 1000, 2000, 3000, 10000]
    hist = np.histogram(lengths_codon, bins=buckets)[0]
    print(f" Codon length distribution:")
    for i, count in enumerate(hist):
        pct = 100 * count / n_seqs
        bar = '█' * int(pct / 2)
        print(f" {buckets[i]:>5}-{buckets[i+1]:>5} codons: {count:>6} ({pct:>5.1f}%) {bar}")
    # np.histogram silently ignores values beyond the last edge; report them explicitly.
    beyond = sum(1 for c in lengths_codon if c > buckets[-1])
    if beyond:
        print(f" >{buckets[-1]:>5} codons: {beyond:>6} ({100 * beyond / n_seqs:>5.1f}%) (outside histogram)")
    # Sequences > 2046 codons (CodonFM max)
    over_limit = sum(1 for c in lengths_codon if c > 2046)
    if over_limit > 0:
        print(f" ⚠️ {over_limit} sequences exceed CodonFM max (2046 codons) — will be truncated")

    # ---- Nucleotide composition ----
    # Characters are counted as-is (no T -> U conversion), so DNA input shows up as 'T'.
    all_chars = Counter()
    for s in sequences:
        all_chars.update(s.upper())
    total_bases = sum(all_chars.values())
    print(f"\n 🧬 Nucleotide Composition:")
    for base in ['A', 'U', 'G', 'C']:
        count = all_chars.get(base, 0)
        # Guard: total_bases is 0 when every sequence is the empty string.
        pct = 100 * count / total_bases if total_bases else 0.0
        print(f" {base}: {count:>12,} ({pct:.1f}%)")
    unexpected = {k: v for k, v in all_chars.items() if k not in 'AUGC'}
    if unexpected:
        print(f" ⚠️ Unexpected characters: {unexpected}")
    not_div3 = sum(1 for n in lengths_nt if n % 3 != 0)
    if not_div3 > 0:
        print(f" ⚠️ {not_div3} sequences not divisible by 3")

    # ---- Codon usage (first 5000 sequences, for speed) ----
    codon_counts = Counter()
    for s in sequences[:5000]:
        codon_counts.update(seq_to_codons(s))
    print(f"\n 🔀 Codon Usage (top 10 / bottom 10 from {min(5000, n_seqs)} seqs):")
    sorted_codons = codon_counts.most_common()
    for codon, count in sorted_codons[:10]:
        aa = CODON_TABLE.get(codon, '?')
        print(f" {codon} ({aa:>9s}): {count:>8,}")
    print(f" ...")
    for codon, count in sorted_codons[-10:]:
        aa = CODON_TABLE.get(codon, '?')
        print(f" {codon} ({aa:>9s}): {count:>8,}")

    # ---- Start/stop codon analysis ----
    stop_codons = {'UAA', 'UAG', 'UGA'}
    starts_with_aug = sum(1 for s in sequences if s[:3].upper().replace('T', 'U') == 'AUG')
    ends_with_stop = 0
    for s in sequences:
        codons = seq_to_codons(s)
        # Sequences shorter than 3 nt have no codons; indexing [-1] would raise IndexError.
        if codons and codons[-1] in stop_codons:
            ends_with_stop += 1
    print(f"\n 🚦 Start/Stop Codons:")
    print(f" Starts with AUG: {starts_with_aug}/{n_seqs} ({100*starts_with_aug/n_seqs:.1f}%)")
    print(f" Ends with stop: {ends_with_stop}/{n_seqs} ({100*ends_with_stop/n_seqs:.1f}%)")

    # ---- Label stats ----
    labels_arr = np.array(labels, dtype=float)  # None becomes NaN
    nan_mask = np.isnan(labels_arr)
    nan_count = int(nan_mask.sum())
    labels_clean = labels_arr[~nan_mask]
    print(f"\n 📊 Label Distribution:")
    print(f" Count: {len(labels_arr)}, NaN: {nan_count}")
    if len(labels_clean) > 0:
        std = labels_clean.std()
        # Skewness is undefined for constant labels (std == 0).
        skew = (float(((labels_clean - labels_clean.mean()) ** 3).mean() / std ** 3)
                if std > 0 else float('nan'))
        print(f" Min: {labels_clean.min():.4f}")
        print(f" Q1: {np.percentile(labels_clean, 25):.4f}")
        print(f" Median: {np.median(labels_clean):.4f}")
        print(f" Q3: {np.percentile(labels_clean, 75):.4f}")
        print(f" Max: {labels_clean.max():.4f}")
        print(f" Mean: {labels_clean.mean():.4f}")
        print(f" Std: {std:.4f}")
        print(f" Skew: {skew:.4f}")

    # ---- Duplicates ----
    unique_seqs = len(set(sequences))
    print(f"\n 🔁 Duplicates:")
    print(f" Unique sequences: {unique_seqs}")
    print(f" Duplicate sequences: {n_seqs - unique_seqs}")

    # ---- Outliers ----
    if len(labels_clean) > 0:
        q1, q3 = np.percentile(labels_clean, [25, 75])
        iqr = q3 - q1
        lower = q1 - 3 * iqr
        upper = q3 + 3 * iqr
        outliers = np.sum((labels_clean < lower) | (labels_clean > upper))
        print(f"\n ⚡ Outliers (>3 IQR):")
        print(f" Label outliers: {outliers}/{len(labels_clean)}")
    very_short = sum(1 for c in lengths_codon if c < 10)
    very_long = sum(1 for c in lengths_codon if c > 1000)
    print(f" Very short (<10 codons): {very_short}")
    print(f" Very long (>1000 codons): {very_long}")
def run_full_audit(benchmark_dir='./benchmark_data'):
    """Audit every training dataset (HF Hub) and every benchmark CSV.

    Training: each split of both HF repos is audited with audit_dataset(),
    followed by a cross-dataset overlap summary.

    Benchmarks: CSVs are downloaded into *benchmark_dir* if missing. Each one
    is audited using its sequence column ('sequence', else 'cds') and its
    'value' column.

    A section is skipped with a message when its library (`datasets` or
    `pandas`) is not installed.

    Args:
        benchmark_dir: directory holding (or receiving) the benchmark CSVs.
            The default matches download_benchmark_data()'s default.
    """
    print("=" * 70)
    print("FULL DATASET AUDIT")
    print("=" * 70)

    # Training datasets
    print("\n\n📚 TRAINING DATASETS")
    print("=" * 70)
    if load_dataset is not None:
        ds1 = load_dataset("mogam-ai/CDS-BART-mRNA-stability")
        for split in ['train', 'val', 'test']:
            audit_dataset(
                ds1[split]['seq'], ds1[split]['y'],
                f"mogam-ai/CDS-BART-mRNA-stability [{split}]"
            )
        ds2 = load_dataset("GleghornLab/mrna_stability_other")
        for split, hf_split in [('train', 'train'), ('val', 'valid'), ('test', 'test')]:
            audit_dataset(
                ds2[hf_split]['rna'], ds2[hf_split]['labels'],
                f"GleghornLab/mrna_stability_other [{split}]"
            )
        # Cross-dataset analysis
        print("\n\n📊 CROSS-DATASET ANALYSIS")
        print("=" * 60)
        ds1_all = set(ds1['train']['seq']) | set(ds1['val']['seq']) | set(ds1['test']['seq'])
        ds2_all = set(ds2['train']['rna']) | set(ds2['valid']['rna']) | set(ds2['test']['rna'])
        print(f" mogam-ai total unique: {len(ds1_all)}")
        print(f" GleghornLab total unique: {len(ds2_all)}")
        print(f" Overlap: {len(ds1_all & ds2_all)}")
        print(f" mogam-ai ⊂ GleghornLab: {ds1_all.issubset(ds2_all)}")
        print(f" GleghornLab-only: {len(ds2_all - ds1_all)}")
    else:
        print(" Skipping (datasets library not installed)")

    # Benchmark datasets
    print("\n\n📊 BENCHMARK DATASETS")
    print("=" * 70)
    if pd is None:
        print(" Skipping (pandas not installed)")
        return
    data_dir = download_benchmark_data(benchmark_dir)
    for task_name, info in BENCHMARK_DATASETS.items():
        filepath = os.path.join(data_dir, info['filename'])
        if not os.path.exists(filepath):
            continue  # download failed; download_benchmark_data() already reported it
        df = pd.read_csv(filepath)
        df.columns = [str(c).lower().strip() for c in df.columns]
        seq_col = 'sequence' if 'sequence' in df.columns else 'cds'
        val_col = 'value'
        if seq_col not in df.columns or val_col not in df.columns:
            print(f" ⚠️ Skipping {task_name}: expected '{seq_col}' and '{val_col}' columns, "
                  f"found {list(df.columns)}")
            continue
        # Empty sequence cells are read as float NaN, which audit_dataset() cannot len().
        missing = df[seq_col].isna()
        if missing.any():
            print(f" ⚠️ {task_name}: dropping {int(missing.sum())} rows with a missing sequence")
            df = df[~missing]
        audit_dataset(
            df[seq_col].tolist(), df[val_col].tolist(),
            f"Benchmark: {task_name} ({info['filename']})"
        )
# ============================================================
# 5. VOCAB DISPLAY
# ============================================================
def show_vocab():
    """Print the CodonFM tokenizer vocabulary.

    Shows:
    - the special tokens;
    - every codon token, four per row, with the amino acid it encodes
      according to CODON_TABLE;
    - a worked tokenize/decode example;
    - the model-config values this vocabulary is meant to match.

    The vocabulary size and pad ID are read from VOCAB_SIZE and PAD_TOKEN_ID
    rather than repeated as literals.
    """
    rule = '─'
    print("=" * 70)
    print(f"CodonFM Tokenizer Vocabulary (vocab_size={VOCAB_SIZE})")
    print("=" * 70)
    print("\n SPECIAL TOKENS:")
    print(f" {'ID':>4} {'Token':<10} {'Description'}")
    print(f" {rule*4} {rule*10} {rule*30}")
    descriptions = {
        '[CLS]': 'Classification token (prepended)',
        '[SEP]': 'Separator token (appended)',
        '[MASK]': 'Mask token (for MLM pretraining)',
        '[PAD]': f'Padding token (pad_token_id={PAD_TOKEN_ID})',
        '[UNK]': 'Unknown token (for invalid codons)',
    }
    for token, tid in sorted(SPECIAL_TOKENS.items(), key=lambda x: x[1]):
        print(f" {tid:>4} {token:<10} {descriptions.get(token, '')}")

    print(f"\n CODON TOKENS ({len(ALL_CODONS)} codons → 20 amino acids + 3 stop):")
    print(f" {'ID':>4} {'Codon':<6} {'Amino Acid':<12} {'ID':>4} {'Codon':<6} {'Amino Acid':<12} "
          f"{'ID':>4} {'Codon':<6} {'Amino Acid':<12} {'ID':>4} {'Codon':<6} {'Amino Acid'}")
    print(f" {rule*4} {rule*6} {rule*12} {rule*4} {rule*6} {rule*12} "
          f"{rule*4} {rule*6} {rule*12} {rule*4} {rule*6} {rule*12}")
    items = sorted(CODON_TO_ID.items(), key=lambda x: x[1])
    for i in range(0, len(items), 4):
        row_parts = [f" {tid:>4} {codon:<6} {CODON_TABLE.get(codon, '?'):<12}"
                     for codon, tid in items[i:i + 4]]
        print("".join(row_parts))

    print(f"\n TOKENIZATION EXAMPLE:")
    example = "AUGGCAGCCGAGACUCGG"
    codons = seq_to_codons(example)
    tokens = tokenize_mRNA(example)
    print(f" Input: {example}")
    print(f" Codons: {' '.join(codons)}")
    print(f" Token IDs: {tokens['input_ids']}")
    decoded = [ID_TO_CODON.get(t, '?') for t in tokens['input_ids']]
    print(f" Decoded: {' '.join(decoded)}")

    print(f"\n CONFIG (matches nvidia/NV-CodonFM-Encodon-80M-v1/config.json):")
    print(f" vocab_size: {VOCAB_SIZE}")
    print(f" pad_token_id: {PAD_TOKEN_ID} ([PAD])")
    print(f" max_position_embeddings: 2046 codons (~6138 nt)")
    print(f" position_embedding_type: rotary (RoPE, θ=10000)")
# ============================================================
# CLI
# ============================================================
def main():
    """Parse command-line flags and run the requested setup steps.

    Steps always run in this order, regardless of the order of the flags:
    1. vocabulary display;
    2. training download;
    3. benchmark download;
    4. preprocessing, with optional CSV export;
    5. full audit.

    With no action flag, the help text is printed.
    """
    parser = argparse.ArgumentParser(
        description="Dataset setup for CodonFM-80M mRNA stability fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python data_setup.py --all # Download & audit everything
python data_setup.py --training # Download training datasets
python data_setup.py --training --export ./data # Download, preprocess & export to CSV
python data_setup.py --benchmark # Download benchmark datasets
python data_setup.py --audit # Full audit of all datasets
python data_setup.py --vocab # Show codon vocabulary
python data_setup.py --preprocess # Preprocess & deduplicate
"""
    )
    parser.add_argument('--all', action='store_true', help='Download and audit everything')
    parser.add_argument('--training', action='store_true', help='Download training datasets from HF Hub')
    parser.add_argument('--benchmark', action='store_true', help='Download benchmark datasets from GitHub')
    parser.add_argument('--audit', action='store_true', help='Run full dataset audit')
    parser.add_argument('--preprocess', action='store_true', help='Preprocess training data (clean, deduplicate)')
    parser.add_argument('--vocab', action='store_true', help='Show codon vocabulary and tokenizer')
    parser.add_argument('--export', type=str, default=None, help='Export directory for preprocessed CSVs')
    parser.add_argument('--benchmark_dir', type=str, default='./benchmark_data',
                        help='Directory for benchmark data')
    # store_true with default=True means this flag can never be False; the
    # dataset choice is controlled solely by --mogam_only. Kept so existing
    # command lines that pass it keep working.
    parser.add_argument('--use_both', action='store_true', default=True,
                        help='Use the GleghornLab superset dataset (always on; '
                             'kept for backward compatibility, see --mogam_only)')
    parser.add_argument('--mogam_only', action='store_true',
                        help='Use only mogam-ai dataset (smaller, cleaner)')
    args = parser.parse_args()

    # Default: show help
    if not any([args.all, args.training, args.benchmark, args.audit, args.preprocess, args.vocab]):
        parser.print_help()
        return

    # --export writes *preprocessed* CSVs: export_training_data() rejects the raw
    # download, so "--training --export DIR" must also run preprocessing.
    run_preprocess = args.preprocess or args.all or (args.training and bool(args.export))

    if args.vocab or args.all:
        show_vocab()
    if args.training or args.all:
        print("\n\n📦 DOWNLOADING TRAINING DATASETS")
        print("=" * 60)
        # The export (if requested) happens after preprocessing below.
        download_training_data(export_dir=None)
    if args.benchmark or args.all:
        print("\n\n📦 DOWNLOADING BENCHMARK DATASETS")
        print("=" * 60)
        download_benchmark_data(args.benchmark_dir)
    if run_preprocess:
        print("\n\n🔧 PREPROCESSING TRAINING DATA")
        print("=" * 60)
        clean_data = preprocess_training_data(use_both_datasets=not args.mogam_only)
        if args.export:
            export_training_data(clean_data, args.export)
    if args.audit or args.all:
        print("\n\n🔍 RUNNING FULL AUDIT")
        run_full_audit()
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()