""" data_setup.py β€” Download, preprocess, audit, and prepare all datasets for CodonFM-80M mRNA stability fine-tuning and benchmarking. Usage: # Download and preprocess everything python data_setup.py --all # Download only training datasets python data_setup.py --training # Download only benchmark datasets python data_setup.py --benchmark # Audit datasets (inspect stats, find issues) python data_setup.py --audit # Export preprocessed training data to local files python data_setup.py --training --export ./processed_data # Show codon vocabulary and tokenizer details python data_setup.py --vocab """ import argparse import json import os import sys import csv import urllib.request from collections import Counter from pathlib import Path import numpy as np try: import pandas as pd except ImportError: pd = None try: from datasets import load_dataset except ImportError: load_dataset = None # ============================================================ # 1. CODON VOCABULARY & TOKENIZER # ============================================================ RNA_BASES = ['A', 'U', 'G', 'C'] ALL_CODONS = [b1 + b2 + b3 for b1 in RNA_BASES for b2 in RNA_BASES for b3 in RNA_BASES] # Biological codon table (RNA) β†’ Amino Acid CODON_TABLE = { 'UUU': 'Phe', 'UUC': 'Phe', 'UUA': 'Leu', 'UUG': 'Leu', 'CUU': 'Leu', 'CUC': 'Leu', 'CUA': 'Leu', 'CUG': 'Leu', 'AUU': 'Ile', 'AUC': 'Ile', 'AUA': 'Ile', 'AUG': 'Met/Start', 'GUU': 'Val', 'GUC': 'Val', 'GUA': 'Val', 'GUG': 'Val', 'UCU': 'Ser', 'UCC': 'Ser', 'UCA': 'Ser', 'UCG': 'Ser', 'CCU': 'Pro', 'CCC': 'Pro', 'CCA': 'Pro', 'CCG': 'Pro', 'ACU': 'Thr', 'ACC': 'Thr', 'ACA': 'Thr', 'ACG': 'Thr', 'GCU': 'Ala', 'GCC': 'Ala', 'GCA': 'Ala', 'GCG': 'Ala', 'UAU': 'Tyr', 'UAC': 'Tyr', 'UAA': 'Stop', 'UAG': 'Stop', 'CAU': 'His', 'CAC': 'His', 'CAA': 'Gln', 'CAG': 'Gln', 'AAU': 'Asn', 'AAC': 'Asn', 'AAA': 'Lys', 'AAG': 'Lys', 'GAU': 'Asp', 'GAC': 'Asp', 'GAA': 'Glu', 'GAG': 'Glu', 'UGU': 'Cys', 'UGC': 'Cys', 'UGA': 'Stop', 'UGG': 'Trp', 'CGU': 'Arg', 'CGC': 
'Arg', 'CGA': 'Arg', 'CGG': 'Arg', 'AGU': 'Ser', 'AGC': 'Ser', 'AGA': 'Arg', 'AGG': 'Arg', 'GGU': 'Gly', 'GGC': 'Gly', 'GGA': 'Gly', 'GGG': 'Gly', } # Token vocabulary (matches CodonFM config: vocab_size=69, pad_token_id=3) SPECIAL_TOKENS = {'[CLS]': 0, '[SEP]': 1, '[MASK]': 2, '[PAD]': 3, '[UNK]': 4} CODON_TO_ID = {codon: i + 5 for i, codon in enumerate(ALL_CODONS)} ID_TO_CODON = {v: k for k, v in CODON_TO_ID.items()} ID_TO_CODON.update({v: k for k, v in SPECIAL_TOKENS.items()}) VOCAB_SIZE = len(SPECIAL_TOKENS) + len(ALL_CODONS) # 5 + 64 = 69 assert VOCAB_SIZE == 69 PAD_TOKEN_ID = 3 CLS_TOKEN_ID = 0 SEP_TOKEN_ID = 1 MASK_TOKEN_ID = 2 UNK_TOKEN_ID = 4 def seq_to_codons(seq: str) -> list: """Split an mRNA/DNA sequence into codon triplets.""" seq = seq.upper().replace('T', 'U').strip() return [seq[i:i+3] for i in range(0, len(seq) - len(seq) % 3, 3)] def tokenize_mRNA(seq: str, max_length: int = 2046) -> dict: """Tokenize an mRNA/DNA sequence into CodonFM token IDs.""" codons = seq_to_codons(seq) token_ids = [CLS_TOKEN_ID] for codon in codons[:max_length - 2]: token_ids.append(CODON_TO_ID.get(codon, UNK_TOKEN_ID)) token_ids.append(SEP_TOKEN_ID) attention_mask = [1] * len(token_ids) return {'input_ids': token_ids, 'attention_mask': attention_mask} def validate_sequence(seq: str) -> dict: """Validate an mRNA/DNA sequence for CodonFM compatibility.""" seq_clean = seq.upper().replace('T', 'U').strip() issues = [] if len(seq_clean) == 0: issues.append("Empty sequence") if len(seq_clean) % 3 != 0: issues.append(f"Length {len(seq_clean)} not divisible by 3 (truncated to {len(seq_clean) - len(seq_clean) % 3})") invalid_chars = set(seq_clean) - {'A', 'U', 'G', 'C'} if invalid_chars: issues.append(f"Invalid characters: {invalid_chars}") codons = seq_to_codons(seq_clean) n_codons = len(codons) # Check for start codon starts_with_aug = codons[0] == 'AUG' if codons else False # Check for stop codons stop_codons = {'UAA', 'UAG', 'UGA'} internal_stops = [i for i, c in 
enumerate(codons[:-1]) if c in stop_codons] ends_with_stop = codons[-1] in stop_codons if codons else False if internal_stops: issues.append(f"Internal stop codons at positions: {internal_stops}") # Unknown codons unk_codons = [c for c in codons if c not in CODON_TO_ID] if unk_codons: issues.append(f"Unknown codons: {set(unk_codons)}") return { 'valid': len(issues) == 0, 'issues': issues, 'length_nt': len(seq_clean), 'length_codons': n_codons, 'starts_with_AUG': starts_with_aug, 'ends_with_stop': ends_with_stop, 'n_internal_stops': len(internal_stops), } # ============================================================ # 2. TRAINING DATASETS # ============================================================ TRAINING_DATASETS = { 'mogam-ai/CDS-BART-mRNA-stability': { 'description': 'iCodon vertebrate mRNA stability profiles (human, mouse, frog, fish)', 'source_paper': 'Diez et al. 2022, Scientific Reports "iCodon customizes gene expression based on the codon composition"', 'seq_col': 'seq', 'label_col': 'y', 'splits': {'train': 'train', 'val': 'val', 'test': 'test'}, 'label_meaning': 'mRNA half-life z-score (higher = more stable, meanβ‰ˆ0, stdβ‰ˆ1)', 'species': ['Human', 'Mouse', 'Xenopus (frog)', 'Zebrafish'], 'notes': 'RNA sequences (A,U,G,C). All divisible by 3. Subset of GleghornLab dataset.', }, 'GleghornLab/mrna_stability_other': { 'description': 'Extended multi-species mRNA stability data (superset of mogam-ai dataset)', 'source_paper': 'Li et al. 2024, Genome Research "CodonBERT large language model for mRNA vaccines"', 'seq_col': 'rna', 'label_col': 'labels', 'splits': {'train': 'train', 'val': 'valid', 'test': 'test'}, 'label_meaning': 'mRNA half-life z-score (higher = more stable)', 'species': ['Multiple vertebrate species'], 'notes': 'Has extra "seqs" column (protein-encoded, not used). Contains 1 outlier sequence of 3 nt. 
Superset of mogam-ai.', 'extra_col': 'seqs', }, } def download_training_data(export_dir=None): """Download and inspect training datasets from HuggingFace Hub.""" if load_dataset is None: print("ERROR: `datasets` library required. Run: pip install datasets") return None all_data = {} for repo_id, info in TRAINING_DATASETS.items(): print(f"\n{'='*60}") print(f"πŸ“¦ {repo_id}") print(f" {info['description']}") print(f"={'='*60}") ds = load_dataset(repo_id) for split_name, hf_split in info['splits'].items(): split_data = ds[hf_split] seqs = split_data[info['seq_col']] labels = split_data[info['label_col']] print(f"\n [{split_name}] {len(seqs)} samples") print(f" Seq lengths (nt): min={min(len(s) for s in seqs)}, " f"mean={np.mean([len(s) for s in seqs]):.0f}, " f"max={max(len(s) for s in seqs)}") print(f" Seq lengths (cod): min={min(len(s)//3 for s in seqs)}, " f"mean={np.mean([len(s)//3 for s in seqs]):.0f}, " f"max={max(len(s)//3 for s in seqs)}") labels_arr = np.array(labels) print(f" Labels: min={labels_arr.min():.3f}, mean={labels_arr.mean():.3f}, " f"std={labels_arr.std():.3f}, max={labels_arr.max():.3f}") all_data[repo_id] = ds if export_dir: export_training_data(all_data, export_dir) return all_data def preprocess_training_data(use_both_datasets=True, min_codons=3, max_codons=2046, remove_duplicates=True, deduplicate_across_splits=True): """ Preprocess training data: clean, filter, deduplicate, and combine. Steps: 1. Load both HF datasets 2. Use GleghornLab as primary (superset) OR combine both 3. Filter: remove sequences < min_codons or > max_codons codons 4. Filter: remove sequences with NaN labels 5. Filter: remove sequences with invalid characters 6. Deduplicate: remove exact sequence duplicates within each split 7. Deduplicate: ensure no train sequences appear in val/test (data leakage check) 8. 
Return clean {train, val, test} dictionaries Returns: dict with 'train', 'val', 'test' keys, each containing 'sequences' and 'labels' lists """ if load_dataset is None: raise ImportError("datasets library required: pip install datasets") print("Loading datasets...") if use_both_datasets: # Use GleghornLab (superset) β€” it contains ALL of mogam-ai plus extra data ds = load_dataset("GleghornLab/mrna_stability_other") raw_data = { 'train': {'seqs': ds['train']['rna'], 'labels': ds['train']['labels']}, 'val': {'seqs': ds['valid']['rna'], 'labels': ds['valid']['labels']}, 'test': {'seqs': ds['test']['rna'], 'labels': ds['test']['labels']}, } print(f" Using GleghornLab/mrna_stability_other (superset)") else: # Use mogam-ai only (smaller, cleaner) ds = load_dataset("mogam-ai/CDS-BART-mRNA-stability") raw_data = { 'train': {'seqs': ds['train']['seq'], 'labels': ds['train']['y']}, 'val': {'seqs': ds['val']['seq'], 'labels': ds['val']['y']}, 'test': {'seqs': ds['test']['seq'], 'labels': ds['test']['y']}, } print(f" Using mogam-ai/CDS-BART-mRNA-stability only") clean_data = {} total_removed = {'short': 0, 'long': 0, 'nan': 0, 'invalid': 0, 'duplicate': 0} for split in ['train', 'val', 'test']: seqs = raw_data[split]['seqs'] labels = raw_data[split]['labels'] orig_count = len(seqs) clean_seqs = [] clean_labels = [] seen = set() for seq, label in zip(seqs, labels): # Skip None/empty if seq is None or len(seq) == 0: total_removed['invalid'] += 1 continue # Normalize: uppercase, Tβ†’U seq = seq.upper().replace('T', 'U').strip() # Check NaN label if np.isnan(label): total_removed['nan'] += 1 continue # Check invalid characters if set(seq) - {'A', 'U', 'G', 'C'}: total_removed['invalid'] += 1 continue # Check length n_codons = len(seq) // 3 if n_codons < min_codons: total_removed['short'] += 1 continue if n_codons > max_codons: total_removed['long'] += 1 continue # Deduplicate within split if remove_duplicates: if seq in seen: total_removed['duplicate'] += 1 continue seen.add(seq) 
clean_seqs.append(seq) clean_labels.append(float(label)) clean_data[split] = {'sequences': clean_seqs, 'labels': clean_labels} print(f" [{split}] {orig_count} β†’ {len(clean_seqs)} samples " f"(removed {orig_count - len(clean_seqs)})") # Cross-split deduplication: check for train/test leakage if deduplicate_across_splits: test_seqs = set(clean_data['test']['sequences']) val_seqs = set(clean_data['val']['sequences']) leakage_test = sum(1 for s in clean_data['train']['sequences'] if s in test_seqs) leakage_val = sum(1 for s in clean_data['train']['sequences'] if s in val_seqs) val_test_overlap = len(val_seqs & test_seqs) print(f"\n Data leakage check:") print(f" Trainβ†’Test overlap: {leakage_test} sequences") print(f" Trainβ†’Val overlap: {leakage_val} sequences") print(f" Valβ†’Test overlap: {val_test_overlap} sequences") if leakage_test > 0 or leakage_val > 0: print(f" ⚠️ WARNING: Data leakage detected! Removing leaked sequences from train...") eval_seqs = test_seqs | val_seqs filtered_train = [(s, l) for s, l in zip(clean_data['train']['sequences'], clean_data['train']['labels']) if s not in eval_seqs] clean_data['train']['sequences'] = [x[0] for x in filtered_train] clean_data['train']['labels'] = [x[1] for x in filtered_train] print(f" Train after dedup: {len(clean_data['train']['sequences'])} samples") print(f"\n Removal summary: {total_removed}") print(f" Final sizes: train={len(clean_data['train']['sequences'])}, " f"val={len(clean_data['val']['sequences'])}, " f"test={len(clean_data['test']['sequences'])}") return clean_data def export_training_data(data, export_dir): """Export preprocessed data to CSV files.""" os.makedirs(export_dir, exist_ok=True) if isinstance(data, dict) and 'train' in data and 'sequences' in data.get('train', {}): # Already preprocessed format for split in ['train', 'val', 'test']: if split not in data: continue filepath = os.path.join(export_dir, f'{split}.csv') with open(filepath, 'w', newline='') as f: writer = csv.writer(f) 
writer.writerow(['sequence', 'stability_score']) for seq, label in zip(data[split]['sequences'], data[split]['labels']): writer.writerow([seq, label]) print(f" Exported {split}: {len(data[split]['sequences'])} rows β†’ {filepath}") else: print(" Export requires preprocessed data. Run preprocess_training_data() first.") # ============================================================ # 3. BENCHMARK DATASETS # ============================================================ CODONBERT_BASE_URL = "https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/master/benchmarks/CodonBERT/data/fine-tune" BENCHMARK_DATASETS = { 'stability': { 'url': f"{CODONBERT_BASE_URL}/mRNA_Stability.csv", 'filename': 'mRNA_Stability.csv', 'description': 'mRNA Stability (iCodon vertebrate mRNA half-life)', 'source': 'Diez et al. 2022, Scientific Reports', 'samples': 65356, 'seq_length': '3-3066 nt (1-1022 codons)', 'label': 'Half-life z-score (continuous, meanβ‰ˆ0, stdβ‰ˆ1)', 'metric': 'Spearman ρ', 'columns': 'sequence, value, dataset, split', 'species': 'Multi-vertebrate (human, mouse, frog, fish)', }, 'mrfp': { 'url': f"{CODONBERT_BASE_URL}/mRFP_Expression.csv", 'filename': 'mRFP_Expression.csv', 'description': 'mRFP Protein Expression in E. coli', 'source': 'Li et al. 2024, Genome Research (CodonBERT)', 'samples': 1459, 'seq_length': '678 nt (226 codons, fixed)', 'label': 'Fluorescence intensity (log scale, range 7.4-11.4)', 'metric': 'Spearman ρ', 'columns': 'sequence, value, dataset, split', 'species': 'E. 
coli (synthetic mRFP variants)', }, 'vaccine': { 'url': f"{CODONBERT_BASE_URL}/CoV_Vaccine_Degradation.csv", 'filename': 'CoV_Vaccine_Degradation.csv', 'description': 'SARS-CoV-2 mRNA Vaccine Degradation', 'source': 'CodonBERT benchmark (derived from Stanford OpenVaccine)', 'samples': 2400, 'seq_length': '81 nt (27 codons, fixed)', 'label': 'Degradation score (z-normalized, range -7.2 to 6.5)', 'metric': 'Spearman ρ', 'columns': 'sequence, value, dataset, split', 'species': 'Synthetic SARS-CoV-2 mRNA vaccine fragments', }, 'riboswitch': { 'url': f"{CODONBERT_BASE_URL}/Tc-Riboswitches.csv", 'filename': 'Tc-Riboswitches.csv', 'description': 'Tetracycline Riboswitch Activity', 'source': 'Groher et al. 2018 (via CodonBERT)', 'samples': 355, 'seq_length': '66-75 nt (22-25 codons)', 'label': 'Switching factor (continuous, range -0.3 to 3.1)', 'metric': 'Spearman ρ', 'columns': 'sequence, value, dataset, split', 'species': 'Synthetic tetracycline riboswitches', }, 'mlos': { 'url': f"{CODONBERT_BASE_URL}/MLOS.csv", 'filename': 'MLOS.csv', 'description': 'MLOS Flu Vaccine Antigen Expression', 'source': 'Ren et al. 2024 (HELM/MLOS)', 'samples': 167, 'seq_length': '~1700 nt (~567 codons)', 'label': 'Expression level (continuous, range 0.3-2.2)', 'metric': 'Spearman ρ', 'columns': 'cds, value (no split column β€” uses random 70/15/15)', 'species': 'Influenza haemagglutinin CDS variants', 'notes': 'No pre-defined splits. 
Column name is "cds" not "sequence".', }, } def download_benchmark_data(data_dir='./benchmark_data'): """Download all benchmark datasets.""" os.makedirs(data_dir, exist_ok=True) for task_name, info in BENCHMARK_DATASETS.items(): filepath = os.path.join(data_dir, info['filename']) if os.path.exists(filepath): size = os.path.getsize(filepath) print(f" βœ“ {info['filename']} already exists ({size/1024:.1f} KB)") else: print(f" ↓ Downloading {info['filename']}...") try: urllib.request.urlretrieve(info['url'], filepath) size = os.path.getsize(filepath) print(f" βœ“ Downloaded {info['filename']} ({size/1024:.1f} KB)") except Exception as e: print(f" βœ— Failed to download {info['filename']}: {e}") return data_dir # ============================================================ # 4. AUDIT # ============================================================ def audit_dataset(sequences, labels, name="dataset"): """Run a comprehensive audit on a list of sequences and labels.""" print(f"\n{'='*60}") print(f"AUDIT: {name} ({len(sequences)} sequences)") print(f"{'='*60}") if len(sequences) == 0: print(" (empty)") return # ---- Sequence stats ---- lengths_nt = [len(s) for s in sequences] lengths_codon = [len(s) // 3 for s in sequences] print(f"\n πŸ“ Sequence Lengths:") print(f" Nucleotides: min={min(lengths_nt)}, mean={np.mean(lengths_nt):.0f}, " f"median={np.median(lengths_nt):.0f}, max={max(lengths_nt)}") print(f" Codons: min={min(lengths_codon)}, mean={np.mean(lengths_codon):.0f}, " f"median={np.median(lengths_codon):.0f}, max={max(lengths_codon)}") # Length distribution buckets buckets = [0, 100, 300, 500, 1000, 2000, 3000, 10000] hist = np.histogram(lengths_codon, bins=buckets)[0] print(f" Codon length distribution:") for i, count in enumerate(hist): pct = 100 * count / len(sequences) bar = 'β–ˆ' * int(pct / 2) print(f" {buckets[i]:>5}-{buckets[i+1]:>5} codons: {count:>6} ({pct:>5.1f}%) {bar}") # Sequences > 2046 codons (CodonFM max) over_limit = sum(1 for c in lengths_codon if c 
> 2046) if over_limit > 0: print(f" ⚠️ {over_limit} sequences exceed CodonFM max (2046 codons) β€” will be truncated") # ---- Nucleotide composition ---- all_chars = Counter() for s in sequences: all_chars.update(s.upper()) total_bases = sum(all_chars.values()) print(f"\n 🧬 Nucleotide Composition:") for base in ['A', 'U', 'G', 'C']: count = all_chars.get(base, 0) pct = 100 * count / total_bases print(f" {base}: {count:>12,} ({pct:.1f}%)") unexpected = {k: v for k, v in all_chars.items() if k not in 'AUGC'} if unexpected: print(f" ⚠️ Unexpected characters: {unexpected}") # Not divisible by 3 not_div3 = sum(1 for s in sequences if len(s) % 3 != 0) if not_div3 > 0: print(f" ⚠️ {not_div3} sequences not divisible by 3") # ---- Codon usage ---- codon_counts = Counter() for s in sequences[:5000]: # sample for speed codons = seq_to_codons(s) codon_counts.update(codons) print(f"\n πŸ”€ Codon Usage (top 10 / bottom 10 from {min(5000, len(sequences))} seqs):") sorted_codons = codon_counts.most_common() for codon, count in sorted_codons[:10]: aa = CODON_TABLE.get(codon, '?') print(f" {codon} ({aa:>9s}): {count:>8,}") print(f" ...") for codon, count in sorted_codons[-10:]: aa = CODON_TABLE.get(codon, '?') print(f" {codon} ({aa:>9s}): {count:>8,}") # Start/stop codon analysis starts_with_aug = sum(1 for s in sequences if s[:3].upper().replace('T', 'U') == 'AUG') stop_codons = {'UAA', 'UAG', 'UGA'} ends_with_stop = sum(1 for s in sequences if seq_to_codons(s)[-1] in stop_codons) if sequences else 0 print(f"\n 🚦 Start/Stop Codons:") print(f" Starts with AUG: {starts_with_aug}/{len(sequences)} ({100*starts_with_aug/len(sequences):.1f}%)") print(f" Ends with stop: {ends_with_stop}/{len(sequences)} ({100*ends_with_stop/len(sequences):.1f}%)") # ---- Label stats ---- labels_arr = np.array(labels, dtype=float) nan_count = np.isnan(labels_arr).sum() labels_clean = labels_arr[~np.isnan(labels_arr)] print(f"\n πŸ“Š Label Distribution:") print(f" Count: {len(labels_arr)}, NaN: 
{nan_count}") if len(labels_clean) > 0: print(f" Min: {labels_clean.min():.4f}") print(f" Q1: {np.percentile(labels_clean, 25):.4f}") print(f" Median: {np.median(labels_clean):.4f}") print(f" Q3: {np.percentile(labels_clean, 75):.4f}") print(f" Max: {labels_clean.max():.4f}") print(f" Mean: {labels_clean.mean():.4f}") print(f" Std: {labels_clean.std():.4f}") print(f" Skew: {float(((labels_clean - labels_clean.mean()) ** 3).mean() / labels_clean.std() ** 3):.4f}") # ---- Duplicates ---- unique_seqs = len(set(sequences)) dup_count = len(sequences) - unique_seqs print(f"\n πŸ” Duplicates:") print(f" Unique sequences: {unique_seqs}") print(f" Duplicate sequences: {dup_count}") # ---- Outliers ---- if len(labels_clean) > 0: q1, q3 = np.percentile(labels_clean, [25, 75]) iqr = q3 - q1 lower = q1 - 3 * iqr upper = q3 + 3 * iqr outliers = np.sum((labels_clean < lower) | (labels_clean > upper)) print(f"\n ⚑ Outliers (>3 IQR):") print(f" Label outliers: {outliers}/{len(labels_clean)}") very_short = sum(1 for c in lengths_codon if c < 10) very_long = sum(1 for c in lengths_codon if c > 1000) print(f" Very short (<10 codons): {very_short}") print(f" Very long (>1000 codons): {very_long}") def run_full_audit(): """Run audit on all training and benchmark datasets.""" print("=" * 70) print("FULL DATASET AUDIT") print("=" * 70) # Training datasets print("\n\nπŸ“š TRAINING DATASETS") print("=" * 70) if load_dataset is not None: ds1 = load_dataset("mogam-ai/CDS-BART-mRNA-stability") for split in ['train', 'val', 'test']: audit_dataset( ds1[split]['seq'], ds1[split]['y'], f"mogam-ai/CDS-BART-mRNA-stability [{split}]" ) ds2 = load_dataset("GleghornLab/mrna_stability_other") for split, hf_split in [('train', 'train'), ('val', 'valid'), ('test', 'test')]: audit_dataset( ds2[hf_split]['rna'], ds2[hf_split]['labels'], f"GleghornLab/mrna_stability_other [{split}]" ) # Cross-dataset analysis print("\n\nπŸ“Š CROSS-DATASET ANALYSIS") print("=" * 60) ds1_all = set(ds1['train']['seq']) | 
set(ds1['val']['seq']) | set(ds1['test']['seq']) ds2_all = set(ds2['train']['rna']) | set(ds2['valid']['rna']) | set(ds2['test']['rna']) print(f" mogam-ai total unique: {len(ds1_all)}") print(f" GleghornLab total unique: {len(ds2_all)}") print(f" Overlap: {len(ds1_all & ds2_all)}") print(f" mogam-ai βŠ‚ GleghornLab: {ds1_all.issubset(ds2_all)}") print(f" GleghornLab-only: {len(ds2_all - ds1_all)}") else: print(" Skipping (datasets library not installed)") # Benchmark datasets print("\n\nπŸ“Š BENCHMARK DATASETS") print("=" * 70) if pd is not None: data_dir = download_benchmark_data() for task_name, info in BENCHMARK_DATASETS.items(): filepath = os.path.join(data_dir, info['filename']) if not os.path.exists(filepath): continue df = pd.read_csv(filepath) df.columns = [c.lower().strip() for c in df.columns] seq_col = 'sequence' if 'sequence' in df.columns else 'cds' val_col = 'value' if seq_col in df.columns and val_col in df.columns: audit_dataset( df[seq_col].tolist(), df[val_col].tolist(), f"Benchmark: {task_name} ({info['filename']})" ) else: print(" Skipping (pandas not installed)") # ============================================================ # 5. 
# ============================================================
# 5. VOCAB DISPLAY
# ============================================================

def show_vocab():
    """Print the CodonFM vocabulary and tokenizer details.

    Shows the special tokens, all 64 codon tokens with their amino acids
    (four per row), a tokenize/decode round-trip example, and key model
    config values.
    """
    rule = "=" * 70
    print(rule)
    print(f"CodonFM Tokenizer Vocabulary (vocab_size={VOCAB_SIZE})")
    print(rule)

    print("\n  SPECIAL TOKENS:")
    print(f"  {'ID':>4}  {'Token':<10}  Description")
    print(f"  {'─' * 4}  {'─' * 10}  {'─' * 30}")
    descriptions = {
        '[CLS]': 'Classification token (prepended)',
        '[SEP]': 'Separator token (appended)',
        '[MASK]': 'Mask token (for MLM pretraining)',
        '[PAD]': 'Padding token (pad_token_id=3)',
        '[UNK]': 'Unknown token (for invalid codons)',
    }
    for token, tid in sorted(SPECIAL_TOKENS.items(), key=lambda kv: kv[1]):
        print(f"  {tid:>4}  {token:<10}  {descriptions.get(token, '')}")

    cols = 4
    print("\n  CODON TOKENS (64 codons → 20 amino acids + 3 stop):")
    print("".join(f"  {'ID':>4}  {'Codon':<6}  {'Amino Acid':<12}" for _ in range(cols)).rstrip())
    print("".join(f"  {'─' * 4}  {'─' * 6}  {'─' * 12}" for _ in range(cols)))
    items = sorted(CODON_TO_ID.items(), key=lambda kv: kv[1])
    for start in range(0, len(items), cols):
        print("".join(
            f"  {tid:>4}  {codon:<6}  {CODON_TABLE.get(codon, '?'):<12}"
            for codon, tid in items[start:start + cols]
        ))

    print("\n  TOKENIZATION EXAMPLE:")
    example = "AUGGCAGCCGAGACUCGG"
    tokens = tokenize_mRNA(example)
    decoded = [ID_TO_CODON.get(t, '?') for t in tokens['input_ids']]
    print(f"    Input:     {example}")
    print(f"    Codons:    {' '.join(seq_to_codons(example))}")
    print(f"    Token IDs: {tokens['input_ids']}")
    print(f"    Decoded:   {' '.join(decoded)}")

    print("\n  CONFIG (matches nvidia/NV-CodonFM-Encodon-80M-v1/config.json):")
    print(f"    vocab_size: {VOCAB_SIZE}")
    print(f"    pad_token_id: {PAD_TOKEN_ID} ([PAD])")
    print("    max_position_embeddings: 2046 codons (~6138 nt)")
    print("    position_embedding_type: rotary (RoPE, θ=10000)")


# ============================================================
# CLI
# ============================================================

def main():
    """Command-line entry point; run with ``--help`` for the available actions.

    ``--export DIR`` writes *preprocessed* training CSVs, so it takes effect
    whenever preprocessing runs: with ``--preprocess``, ``--all``, or (as the
    usage examples promise) ``--training``.
    """
    parser = argparse.ArgumentParser(
        description="Dataset setup for CodonFM-80M mRNA stability fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python data_setup.py --all                        # Download & audit everything
  python data_setup.py --training                   # Download training datasets
  python data_setup.py --training --export ./data   # Download, preprocess & export to CSV
  python data_setup.py --benchmark                  # Download benchmark datasets
  python data_setup.py --audit                      # Full audit of all datasets
  python data_setup.py --vocab                      # Show codon vocabulary
  python data_setup.py --preprocess                 # Preprocess & deduplicate
        """,
    )
    parser.add_argument('--all', action='store_true', help='Download and audit everything')
    parser.add_argument('--training', action='store_true', help='Download training datasets from HF Hub')
    parser.add_argument('--benchmark', action='store_true', help='Download benchmark datasets from GitHub')
    parser.add_argument('--audit', action='store_true', help='Run full dataset audit')
    parser.add_argument('--preprocess', action='store_true',
                        help='Preprocess training data (clean, deduplicate)')
    parser.add_argument('--vocab', action='store_true', help='Show codon vocabulary and tokenizer')
    parser.add_argument('--export', type=str, default=None,
                        help='Export directory for preprocessed CSVs '
                             '(used with --training, --preprocess or --all)')
    parser.add_argument('--benchmark_dir', type=str, default='./benchmark_data',
                        help='Directory for benchmark data')
    # Kept for backward compatibility only: it is always True and has no
    # effect; dataset selection is controlled by --mogam_only.
    parser.add_argument('--use_both', action='store_true', default=True,
                        help='Use both training datasets (default: True; no effect, '
                             'see --mogam_only)')
    parser.add_argument('--mogam_only', action='store_true',
                        help='Use only mogam-ai dataset (smaller, cleaner)')
    args = parser.parse_args()

    # Default: show help
    if not any([args.all, args.training, args.benchmark, args.audit, args.preprocess, args.vocab]):
        parser.print_help()
        return

    # Exporting needs preprocessed data, so --training --export implies preprocessing.
    run_preprocess = args.preprocess or args.all or (args.training and bool(args.export))

    if args.vocab or args.all:
        show_vocab()

    if args.training or args.all:
        print("\n\n📦 DOWNLOADING TRAINING DATASETS")
        print("=" * 60)
        download_training_data()  # export is handled by the preprocessing step below

    if args.benchmark or args.all:
        print("\n\n📦 DOWNLOADING BENCHMARK DATASETS")
        print("=" * 60)
        download_benchmark_data(args.benchmark_dir)

    if run_preprocess:
        print("\n\n🔧 PREPROCESSING TRAINING DATA")
        print("=" * 60)
        clean_data = preprocess_training_data(use_both_datasets=not args.mogam_only)
        if args.export:
            export_training_data(clean_data, args.export)

    if args.audit or args.all:
        print("\n\n🔍 RUNNING FULL AUDIT")
        run_full_audit()


if __name__ == '__main__':
    main()