"""
data_setup.py — Download, preprocess, audit, and prepare all datasets for
CodonFM-80M mRNA stability fine-tuning and benchmarking.

Usage:
    # Download, preprocess, and audit everything
    python data_setup.py --all

    # Download only training datasets
    python data_setup.py --training

    # Download only benchmark datasets
    python data_setup.py --benchmark

    # Preprocess (clean, filter, deduplicate) the training data
    python data_setup.py --preprocess

    # Audit datasets (inspect stats, find issues)
    python data_setup.py --audit

    # Export preprocessed training data to local CSV files
    python data_setup.py --training --export ./processed_data

    # Show codon vocabulary and tokenizer details
    python data_setup.py --vocab
"""
|
|
import argparse
import csv
import json
import os
import shutil
import sys
import urllib.request
from collections import Counter
from pathlib import Path

import numpy as np

try:
    import pandas as pd
except ImportError:
    pd = None

try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None
|
|
|
|
| |
| |
| |
|
|
# RNA alphabet. The nested comprehension below fixes the codon order and
# therefore the codon token IDs (AAA=5, AAU=6, ..., CCC=68).
RNA_BASES = ['A', 'U', 'G', 'C']
ALL_CODONS = [b1 + b2 + b3 for b1 in RNA_BASES for b2 in RNA_BASES for b3 in RNA_BASES]

# Standard genetic code on the RNA alphabet: codon -> three-letter amino acid
# ('Stop' for the three termination codons; AUG is also the start codon).
CODON_TABLE = {
    'UUU': 'Phe', 'UUC': 'Phe', 'UUA': 'Leu', 'UUG': 'Leu',
    'CUU': 'Leu', 'CUC': 'Leu', 'CUA': 'Leu', 'CUG': 'Leu',
    'AUU': 'Ile', 'AUC': 'Ile', 'AUA': 'Ile', 'AUG': 'Met/Start',
    'GUU': 'Val', 'GUC': 'Val', 'GUA': 'Val', 'GUG': 'Val',
    'UCU': 'Ser', 'UCC': 'Ser', 'UCA': 'Ser', 'UCG': 'Ser',
    'CCU': 'Pro', 'CCC': 'Pro', 'CCA': 'Pro', 'CCG': 'Pro',
    'ACU': 'Thr', 'ACC': 'Thr', 'ACA': 'Thr', 'ACG': 'Thr',
    'GCU': 'Ala', 'GCC': 'Ala', 'GCA': 'Ala', 'GCG': 'Ala',
    'UAU': 'Tyr', 'UAC': 'Tyr', 'UAA': 'Stop', 'UAG': 'Stop',
    'CAU': 'His', 'CAC': 'His', 'CAA': 'Gln', 'CAG': 'Gln',
    'AAU': 'Asn', 'AAC': 'Asn', 'AAA': 'Lys', 'AAG': 'Lys',
    'GAU': 'Asp', 'GAC': 'Asp', 'GAA': 'Glu', 'GAG': 'Glu',
    'UGU': 'Cys', 'UGC': 'Cys', 'UGA': 'Stop', 'UGG': 'Trp',
    'CGU': 'Arg', 'CGC': 'Arg', 'CGA': 'Arg', 'CGG': 'Arg',
    'AGU': 'Ser', 'AGC': 'Ser', 'AGA': 'Arg', 'AGG': 'Arg',
    'GGU': 'Gly', 'GGC': 'Gly', 'GGA': 'Gly', 'GGG': 'Gly',
}

# Tokenizer vocabulary: 5 special tokens (IDs 0-4) followed by the 64 codons.
SPECIAL_TOKENS = {'[CLS]': 0, '[SEP]': 1, '[MASK]': 2, '[PAD]': 3, '[UNK]': 4}
CODON_TO_ID = {codon: i + len(SPECIAL_TOKENS) for i, codon in enumerate(ALL_CODONS)}
# Reverse lookup covering both codon and special-token IDs (used for decoding).
ID_TO_CODON = {v: k for k, v in CODON_TO_ID.items()}
ID_TO_CODON.update({v: k for k, v in SPECIAL_TOKENS.items()})

VOCAB_SIZE = len(SPECIAL_TOKENS) + len(ALL_CODONS)
# Explicit check rather than `assert`, so it is not stripped under `python -O`.
if VOCAB_SIZE != 69 or set(CODON_TABLE) != set(ALL_CODONS):
    raise RuntimeError(
        f"Inconsistent codon vocabulary: VOCAB_SIZE={VOCAB_SIZE} (expected 69), "
        f"{len(CODON_TABLE)} codons in CODON_TABLE (expected all 64)"
    )

# Named special-token IDs, derived from SPECIAL_TOKENS so they cannot drift apart.
CLS_TOKEN_ID = SPECIAL_TOKENS['[CLS]']
SEP_TOKEN_ID = SPECIAL_TOKENS['[SEP]']
MASK_TOKEN_ID = SPECIAL_TOKENS['[MASK]']
PAD_TOKEN_ID = SPECIAL_TOKENS['[PAD]']
UNK_TOKEN_ID = SPECIAL_TOKENS['[UNK]']
|
|
|
|
def seq_to_codons(seq: str) -> list:
    """Return the complete codons of *seq* in reading order.

    The input is upper-cased, DNA ``T`` is rewritten as RNA ``U`` and
    surrounding whitespace is stripped. A trailing 1-2 nucleotides that do not
    form a full codon are discarded.
    """
    normalized = seq.upper().replace('T', 'U').strip()
    n_full = len(normalized) // 3
    return [normalized[3 * k:3 * k + 3] for k in range(n_full)]
|
|
|
|
def tokenize_mRNA(seq: str, max_length: int = 2046) -> dict:
    """Tokenize an mRNA/DNA sequence into CodonFM token IDs.

    The sequence is split with seq_to_codons() (T -> U, trailing partial codon
    dropped) and encoded as ``[CLS] codon_1 ... codon_k [SEP]``. Codons not in
    CODON_TO_ID become [UNK]. The codon list is truncated so that the total
    number of tokens, including [CLS] and [SEP], is at most *max_length*.

    Args:
        seq: nucleotide sequence (RNA or DNA alphabet, any case).
        max_length: maximum number of output tokens, special tokens included.

    Returns:
        dict with 'input_ids' and 'attention_mask' lists of equal length; the
        mask is all ones (no padding is added).

    Raises:
        ValueError: if max_length < 2 (no room for [CLS] and [SEP]). Without
            this guard the slice bound below would go negative and silently
            keep almost every codon instead of truncating.
    """
    if max_length < 2:
        raise ValueError(f"max_length must be >= 2 to fit [CLS] and [SEP], got {max_length}")
    codons = seq_to_codons(seq)[:max_length - 2]
    input_ids = [CLS_TOKEN_ID]
    input_ids.extend(CODON_TO_ID.get(codon, UNK_TOKEN_ID) for codon in codons)
    input_ids.append(SEP_TOKEN_ID)
    return {'input_ids': input_ids, 'attention_mask': [1] * len(input_ids)}
|
|
|
|
def validate_sequence(seq: str) -> dict:
    """Check whether *seq* can be fed cleanly to the CodonFM tokenizer.

    The sequence is upper-cased, ``T`` is rewritten to ``U`` and surrounding
    whitespace is stripped before checking. Problems are reported in this
    order: empty input, a length that is not a multiple of 3, characters
    outside A/U/G/C, stop codons before the final codon, and codons missing
    from CODON_TO_ID.

    Returns:
        dict with keys 'valid' (True when no problem was found), 'issues'
        (list of human-readable problem strings), 'length_nt',
        'length_codons', 'starts_with_AUG', 'ends_with_stop' and
        'n_internal_stops'. Internal-stop positions in the message are
        0-based codon indices.
    """
    normalized = seq.upper().replace('T', 'U').strip()
    length_nt = len(normalized)
    overhang = length_nt % 3
    problems = []

    if not length_nt:
        problems.append("Empty sequence")
    if overhang:
        problems.append(f"Length {length_nt} not divisible by 3 (truncated to {length_nt - overhang})")

    bad_chars = set(normalized) - {'A', 'U', 'G', 'C'}
    if bad_chars:
        problems.append(f"Invalid characters: {bad_chars}")

    codons = seq_to_codons(normalized)
    stops = {'UAA', 'UAG', 'UGA'}
    # Only stops before the final codon count as internal.
    premature = [pos for pos, codon in enumerate(codons[:-1]) if codon in stops]
    if codons:
        first_is_aug = codons[0] == 'AUG'
        last_is_stop = codons[-1] in stops
    else:
        first_is_aug = last_is_stop = False
    if premature:
        problems.append(f"Internal stop codons at positions: {premature}")

    unknown = [codon for codon in codons if codon not in CODON_TO_ID]
    if unknown:
        problems.append(f"Unknown codons: {set(unknown)}")

    return {
        'valid': not problems,
        'issues': problems,
        'length_nt': length_nt,
        'length_codons': len(codons),
        'starts_with_AUG': first_is_aug,
        'ends_with_stop': last_is_stop,
        'n_internal_stops': len(premature),
    }
|
|
|
|
| |
| |
| |
|
|
# Hugging Face Hub training datasets: repo ID -> metadata. 'seq_col' and
# 'label_col' name the sequence/label columns; 'splits' maps our split names
# (train/val/test) to the dataset's own split names.
TRAINING_DATASETS = {
    'mogam-ai/CDS-BART-mRNA-stability': {
        'description': 'iCodon vertebrate mRNA stability profiles (human, mouse, frog, fish)',
        'source_paper': 'Diez et al. 2022, Scientific Reports "iCodon customizes gene expression based on the codon composition"',
        'seq_col': 'seq',
        'label_col': 'y',
        'splits': {'train': 'train', 'val': 'val', 'test': 'test'},
        'label_meaning': 'mRNA half-life z-score (higher = more stable, mean≈0, std≈1)',
        'species': ['Human', 'Mouse', 'Xenopus (frog)', 'Zebrafish'],
        'notes': 'RNA sequences (A,U,G,C). All divisible by 3. Subset of GleghornLab dataset.',
    },
    'GleghornLab/mrna_stability_other': {
        'description': 'Extended multi-species mRNA stability data (superset of mogam-ai dataset)',
        'source_paper': 'Li et al. 2024, Genome Research "CodonBERT large language model for mRNA vaccines"',
        'seq_col': 'rna',
        'label_col': 'labels',
        'splits': {'train': 'train', 'val': 'valid', 'test': 'test'},
        'label_meaning': 'mRNA half-life z-score (higher = more stable)',
        'species': ['Multiple vertebrate species'],
        'notes': 'Has extra "seqs" column (protein-encoded, not used). Contains 1 outlier sequence of 3 nt. Superset of mogam-ai.',
        'extra_col': 'seqs',
    },
}
|
|
|
|
def _print_split_stats(split_name, seqs, labels):
    """Print sample count, nt/codon length summary and label summary for one split."""
    print(f"\n  [{split_name}] {len(seqs)} samples")
    if len(seqs) == 0:
        # min()/max() would raise on an empty split; nothing more to report.
        return
    lengths_nt = np.array([len(s or '') for s in seqs])
    lengths_codon = lengths_nt // 3
    for unit, lengths in (("nt", lengths_nt), ("cod", lengths_codon)):
        print(f"    Seq lengths ({unit}): min={lengths.min()}, "
              f"mean={lengths.mean():.0f}, max={lengths.max()}")
    # Missing labels (None) become NaN here; nan-aware reductions keep a few
    # missing values from turning every statistic into NaN.
    labels_arr = np.asarray(labels, dtype=float)
    print(f"    Labels: min={np.nanmin(labels_arr):.3f}, mean={np.nanmean(labels_arr):.3f}, "
          f"std={np.nanstd(labels_arr):.3f}, max={np.nanmax(labels_arr):.3f}")


def download_training_data(export_dir=None):
    """Download the training datasets from the HuggingFace Hub and print per-split stats.

    Every repo in TRAINING_DATASETS is loaded with `datasets.load_dataset`,
    and each split's size, sequence lengths and label statistics are printed.

    Args:
        export_dir: optional directory. When given, the training data is also
            cleaned with preprocess_training_data() (default settings) and
            written as train/val/test CSVs via export_training_data(). The raw
            DatasetDicts are keyed by repo ID, not the split layout
            export_training_data() needs, so they cannot be exported directly.

    Returns:
        dict mapping repo ID -> loaded DatasetDict, or None if the `datasets`
        library is not installed.
    """
    if load_dataset is None:
        print("ERROR: `datasets` library required. Run: pip install datasets")
        return None

    all_data = {}
    for repo_id, info in TRAINING_DATASETS.items():
        print(f"\n{'=' * 60}")
        print(f"📦 {repo_id}")
        print(f"   {info['description']}")
        print(f"{'=' * 60}")

        ds = load_dataset(repo_id)
        for split_name, hf_split in info['splits'].items():
            split_data = ds[hf_split]
            _print_split_stats(split_name, split_data[info['seq_col']], split_data[info['label_col']])
        all_data[repo_id] = ds

    if export_dir:
        export_training_data(preprocess_training_data(), export_dir)

    return all_data
|
|
|
|
def _parse_stability_label(label):
    """Return *label* as a float, or None if it is missing (None), NaN, or non-numeric."""
    try:
        value = float(label)
    except (TypeError, ValueError):
        return None
    return None if np.isnan(value) else value


def _load_raw_training_splits(use_both_datasets):
    """Load {'train'|'val'|'test': {'seqs': [...], 'labels': [...]}} from one HF dataset.

    Column and split names are taken from TRAINING_DATASETS.
    """
    if use_both_datasets:
        repo_id = "GleghornLab/mrna_stability_other"
        note = "  Using GleghornLab/mrna_stability_other (superset)"
    else:
        repo_id = "mogam-ai/CDS-BART-mRNA-stability"
        note = "  Using mogam-ai/CDS-BART-mRNA-stability only"
    info = TRAINING_DATASETS[repo_id]
    ds = load_dataset(repo_id)
    raw_data = {
        split: {'seqs': ds[hf_split][info['seq_col']], 'labels': ds[hf_split][info['label_col']]}
        for split, hf_split in info['splits'].items()
    }
    print(note)
    return raw_data


def _clean_split(seqs, labels, min_codons, max_codons, remove_duplicates, removed):
    """Normalize and filter one split, counting drops in *removed* (mutated in place).

    Returns:
        (clean_seqs, clean_labels) lists of equal length.
    """
    clean_seqs, clean_labels = [], []
    seen = set()
    for seq, label in zip(seqs, labels):
        # Missing, non-string or empty sequences cannot be tokenized.
        if not isinstance(seq, str) or not seq:
            removed['invalid'] += 1
            continue

        # Normalize: uppercase, DNA T -> RNA U, trim surrounding whitespace.
        seq = seq.upper().replace('T', 'U').strip()

        # HF datasets may store missing floats as None, which np.isnan rejects.
        value = _parse_stability_label(label)
        if value is None:
            removed['nan'] += 1
            continue

        if set(seq) - {'A', 'U', 'G', 'C'}:
            removed['invalid'] += 1
            continue

        n_codons = len(seq) // 3
        if n_codons < min_codons:
            removed['short'] += 1
            continue
        if n_codons > max_codons:
            removed['long'] += 1
            continue

        # Keep the first occurrence (and its label) of a repeated sequence.
        if remove_duplicates:
            if seq in seen:
                removed['duplicate'] += 1
                continue
            seen.add(seq)

        clean_seqs.append(seq)
        clean_labels.append(value)
    return clean_seqs, clean_labels


def _remove_train_leakage(clean_data, removed):
    """Report split overlaps and drop train sequences that also occur in val/test.

    Val/test overlap is only reported, not modified. Adds the number of
    dropped train samples to removed['leakage'].
    """
    train = clean_data['train']
    val_seqs = set(clean_data['val']['sequences'])
    test_seqs = set(clean_data['test']['sequences'])

    leakage_test = sum(s in test_seqs for s in train['sequences'])
    leakage_val = sum(s in val_seqs for s in train['sequences'])
    val_test_overlap = len(val_seqs & test_seqs)

    print("\n  Data leakage check:")
    print(f"    Train↔Test overlap: {leakage_test} sequences")
    print(f"    Train↔Val overlap:  {leakage_val} sequences")
    print(f"    Val↔Test overlap:   {val_test_overlap} sequences")

    if leakage_test or leakage_val:
        print("  ⚠️  WARNING: Data leakage detected! Removing leaked sequences from train...")
        eval_seqs = val_seqs | test_seqs
        kept = [(s, y) for s, y in zip(train['sequences'], train['labels']) if s not in eval_seqs]
        removed['leakage'] += len(train['sequences']) - len(kept)
        train['sequences'] = [s for s, _ in kept]
        train['labels'] = [y for _, y in kept]
        print(f"    Train after dedup: {len(train['sequences'])} samples")
    if val_test_overlap:
        print("  NOTE: val/test overlap is reported only; those sequences are not removed.")


def preprocess_training_data(use_both_datasets=True, min_codons=3, max_codons=2046,
                             remove_duplicates=True, deduplicate_across_splits=True):
    """
    Preprocess training data: clean, filter, deduplicate, and check split leakage.

    Steps:
      1. Load ONE HF dataset: GleghornLab/mrna_stability_other when
         use_both_datasets is True (it is a superset of the mogam-ai data, so
         it already covers both), otherwise mogam-ai/CDS-BART-mRNA-stability.
      2. Per sample, in order: drop missing/non-string/empty sequences;
         normalize (uppercase, T -> U, strip); drop missing/NaN/non-numeric
         labels; drop sequences with characters other than A/U/G/C; drop
         sequences with < min_codons or > max_codons complete codons; drop
         exact duplicates within the split (first occurrence kept) when
         remove_duplicates is True.
      3. When deduplicate_across_splits is True, remove train sequences that
         also occur in val/test (val/test overlap is only reported).

    Args:
        use_both_datasets: choose the superset dataset (True) or mogam-ai only.
        min_codons / max_codons: inclusive codon-count bounds.
        remove_duplicates: deduplicate within each split.
        deduplicate_across_splits: remove train/eval leakage.

    Returns:
        dict with 'train', 'val', 'test' keys, each containing 'sequences' and
        'labels' lists (labels are floats).

    Raises:
        ImportError: if the `datasets` library is not installed.
    """
    if load_dataset is None:
        raise ImportError("datasets library required: pip install datasets")

    print("Loading datasets...")
    raw_data = _load_raw_training_splits(use_both_datasets)

    clean_data = {}
    total_removed = {'short': 0, 'long': 0, 'nan': 0, 'invalid': 0, 'duplicate': 0, 'leakage': 0}

    for split in ['train', 'val', 'test']:
        seqs = raw_data[split]['seqs']
        clean_seqs, clean_labels = _clean_split(
            seqs, raw_data[split]['labels'], min_codons, max_codons,
            remove_duplicates, total_removed,
        )
        clean_data[split] = {'sequences': clean_seqs, 'labels': clean_labels}
        print(f"  [{split}] {len(seqs)} → {len(clean_seqs)} samples "
              f"(removed {len(seqs) - len(clean_seqs)})")

    if deduplicate_across_splits:
        _remove_train_leakage(clean_data, total_removed)

    print(f"\n  Removal summary: {total_removed}")
    print(f"  Final sizes: train={len(clean_data['train']['sequences'])}, "
          f"val={len(clean_data['val']['sequences'])}, "
          f"test={len(clean_data['test']['sequences'])}")

    return clean_data
|
|
|
|
def export_training_data(data, export_dir):
    """Write preprocessed splits to ``<export_dir>/<split>.csv``.

    Args:
        data: dict as returned by preprocess_training_data(), i.e. split name
            ('train'/'val'/'test') -> {'sequences': [...], 'labels': [...]}.
            Splits missing from *data* are skipped.
        export_dir: output directory, created if needed.

    Each CSV has the header ``sequence,stability_score`` followed by one row
    per sample. If *data* is not in the preprocessed layout, a message is
    printed and nothing is written (the directory is not created either).
    """
    train = data.get('train') if isinstance(data, dict) else None
    if not (isinstance(train, dict) and 'sequences' in train):
        print("  Export requires preprocessed data. Run preprocess_training_data() first.")
        return

    os.makedirs(export_dir, exist_ok=True)
    for split in ('train', 'val', 'test'):
        if split not in data:
            continue
        split_data = data[split]
        filepath = os.path.join(export_dir, f'{split}.csv')
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['sequence', 'stability_score'])
            writer.writerows(zip(split_data['sequences'], split_data['labels']))
        print(f"  Exported {split}: {len(split_data['sequences'])} rows → {filepath}")
|
|
|
|
| |
| |
| |
|
|
# Raw-file base URL of the CodonBERT fine-tuning benchmark CSVs on GitHub.
CODONBERT_BASE_URL = "https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/master/benchmarks/CodonBERT/data/fine-tune"

# Benchmark task name -> download URL, local filename and descriptive metadata
# (expected sample count, sequence lengths, label meaning, metric, CSV columns).
BENCHMARK_DATASETS = {
    'stability': {
        'url': f"{CODONBERT_BASE_URL}/mRNA_Stability.csv",
        'filename': 'mRNA_Stability.csv',
        'description': 'mRNA Stability (iCodon vertebrate mRNA half-life)',
        'source': 'Diez et al. 2022, Scientific Reports',
        'samples': 65356,
        'seq_length': '3-3066 nt (1-1022 codons)',
        'label': 'Half-life z-score (continuous, mean≈0, std≈1)',
        'metric': 'Spearman ρ',
        'columns': 'sequence, value, dataset, split',
        'species': 'Multi-vertebrate (human, mouse, frog, fish)',
    },
    'mrfp': {
        'url': f"{CODONBERT_BASE_URL}/mRFP_Expression.csv",
        'filename': 'mRFP_Expression.csv',
        'description': 'mRFP Protein Expression in E. coli',
        'source': 'Li et al. 2024, Genome Research (CodonBERT)',
        'samples': 1459,
        'seq_length': '678 nt (226 codons, fixed)',
        'label': 'Fluorescence intensity (log scale, range 7.4-11.4)',
        'metric': 'Spearman ρ',
        'columns': 'sequence, value, dataset, split',
        'species': 'E. coli (synthetic mRFP variants)',
    },
    'vaccine': {
        'url': f"{CODONBERT_BASE_URL}/CoV_Vaccine_Degradation.csv",
        'filename': 'CoV_Vaccine_Degradation.csv',
        'description': 'SARS-CoV-2 mRNA Vaccine Degradation',
        'source': 'CodonBERT benchmark (derived from Stanford OpenVaccine)',
        'samples': 2400,
        'seq_length': '81 nt (27 codons, fixed)',
        'label': 'Degradation score (z-normalized, range -7.2 to 6.5)',
        'metric': 'Spearman ρ',
        'columns': 'sequence, value, dataset, split',
        'species': 'Synthetic SARS-CoV-2 mRNA vaccine fragments',
    },
    'riboswitch': {
        'url': f"{CODONBERT_BASE_URL}/Tc-Riboswitches.csv",
        'filename': 'Tc-Riboswitches.csv',
        'description': 'Tetracycline Riboswitch Activity',
        'source': 'Groher et al. 2018 (via CodonBERT)',
        'samples': 355,
        'seq_length': '66-75 nt (22-25 codons)',
        'label': 'Switching factor (continuous, range -0.3 to 3.1)',
        'metric': 'Spearman ρ',
        'columns': 'sequence, value, dataset, split',
        'species': 'Synthetic tetracycline riboswitches',
    },
    'mlos': {
        'url': f"{CODONBERT_BASE_URL}/MLOS.csv",
        'filename': 'MLOS.csv',
        'description': 'MLOS Flu Vaccine Antigen Expression',
        'source': 'Ren et al. 2024 (HELM/MLOS)',
        'samples': 167,
        'seq_length': '~1700 nt (~567 codons)',
        'label': 'Expression level (continuous, range 0.3-2.2)',
        'metric': 'Spearman ρ',
        'columns': 'cds, value (no split column — uses random 70/15/15)',
        'species': 'Influenza haemagglutinin CDS variants',
        'notes': 'No pre-defined splits. Column name is "cds" not "sequence".',
    },
}
|
|
|
|
def download_benchmark_data(data_dir='./benchmark_data', timeout=60.0):
    """Download every benchmark CSV listed in BENCHMARK_DATASETS into *data_dir*.

    Files that already exist are left untouched. Each download is streamed to
    ``<file>.part`` and only moved into place with os.replace() once complete.
    This way a failed or interrupted transfer never leaves a truncated CSV
    behind that a later run would mistake for a finished download. Failures
    are reported and skipped (best effort).

    Args:
        data_dir: target directory, created if needed.
        timeout: socket timeout in seconds passed to urllib.request.urlopen().

    Returns:
        data_dir, unchanged.
    """
    os.makedirs(data_dir, exist_ok=True)

    for info in BENCHMARK_DATASETS.values():
        filename = info['filename']
        filepath = os.path.join(data_dir, filename)
        if os.path.exists(filepath):
            print(f"  ✓ {filename} already exists ({os.path.getsize(filepath) / 1024:.1f} KB)")
            continue

        print(f"  ↓ Downloading {filename}...")
        tmp_path = filepath + '.part'
        try:
            with urllib.request.urlopen(info['url'], timeout=timeout) as response, \
                    open(tmp_path, 'wb') as out:
                shutil.copyfileobj(response, out)
            os.replace(tmp_path, filepath)
        except Exception as e:
            # Best-effort boundary: network, HTTP and disk errors all land here;
            # drop the partial file, report, and continue with the next dataset.
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass
            print(f"  ✗ Failed to download {filename}: {e}")
            continue
        print(f"  ✓ Downloaded {filename} ({os.path.getsize(filepath) / 1024:.1f} KB)")

    return data_dir
|
|
|
|
| |
| |
| |
|
|
def _audit_lengths(sequences, max_codons):
    """Print length statistics, a codon-length histogram and length warnings.

    Returns:
        list of per-sequence complete-codon counts.
    """
    lengths_nt = [len(s) for s in sequences]
    lengths_codon = [n // 3 for n in lengths_nt]

    print("\n  📏 Sequence Lengths:")
    print(f"    Nucleotides: min={min(lengths_nt)}, mean={np.mean(lengths_nt):.0f}, "
          f"median={np.median(lengths_nt):.0f}, max={max(lengths_nt)}")
    print(f"    Codons:      min={min(lengths_codon)}, mean={np.mean(lengths_codon):.0f}, "
          f"median={np.median(lengths_codon):.0f}, max={max(lengths_codon)}")

    # np.histogram ignores values beyond the outermost edge, so the last bin
    # is open-ended to keep unusually long sequences in the counts.
    edges = [0, 100, 300, 500, 1000, 2000, 3000, np.inf]
    counts, _ = np.histogram(lengths_codon, bins=edges)
    print("    Codon length distribution:")
    for lo, hi, count in zip(edges[:-1], edges[1:], counts):
        pct = 100 * count / len(sequences)
        span = f"{lo:>5}-{hi:>5}" if np.isfinite(hi) else f"{lo:>5}+{'':>5}"
        print(f"      {span} codons: {count:>6} ({pct:>5.1f}%) {'█' * int(pct / 2)}")

    over_limit = sum(1 for c in lengths_codon if c > max_codons)
    if over_limit:
        print(f"    ⚠️  {over_limit} sequences exceed CodonFM max ({max_codons} codons) — will be truncated")

    not_div3 = sum(1 for n in lengths_nt if n % 3)
    if not_div3:
        print(f"    ⚠️  {not_div3} sequences not divisible by 3")
    return lengths_codon


def _audit_composition(sequences):
    """Print per-base nucleotide counts/percentages and any non-AUGC characters."""
    char_counts = Counter()
    for s in sequences:
        char_counts.update(s.upper())
    total = sum(char_counts.values())

    print("\n  🧬 Nucleotide Composition:")
    if not total:
        print("    (no nucleotides)")
        return
    for base in RNA_BASES:
        count = char_counts.get(base, 0)
        print(f"    {base}: {count:>12,} ({100 * count / total:.1f}%)")
    unexpected = {ch: n for ch, n in char_counts.items() if ch not in RNA_BASES}
    if unexpected:
        print(f"    ⚠️  Unexpected characters: {unexpected}")
        if 'T' in unexpected:
            print("       (T is converted to U by seq_to_codons/tokenize_mRNA)")


def _audit_codon_usage(sequences, sample_size=5000):
    """Print the 10 most and 10 least frequent codons over the first *sample_size* sequences."""
    codon_counts = Counter()
    for s in sequences[:sample_size]:
        codon_counts.update(seq_to_codons(s))
    ranked = codon_counts.most_common()

    print(f"\n  🔤 Codon Usage (top 10 / bottom 10 from {min(sample_size, len(sequences))} seqs):")
    for section, gap in ((ranked[:10], "    ..."), (ranked[-10:], None)):
        for codon, count in section:
            print(f"    {codon} ({CODON_TABLE.get(codon, '?'):>9s}): {count:>8,}")
        if gap:
            print(gap)


def _audit_start_stop(sequences):
    """Print how many sequences start with AUG and end with a stop codon."""
    stop_codons = {'UAA', 'UAG', 'UGA'}
    n = len(sequences)
    starts = ends = 0
    for s in sequences:
        codons = seq_to_codons(s)
        # Sequences shorter than one codon have neither a start nor a stop
        # (indexing codons[-1] on them raised IndexError before).
        if codons:
            starts += codons[0] == 'AUG'
            ends += codons[-1] in stop_codons
    print("\n  🚦 Start/Stop Codons:")
    print(f"    Starts with AUG: {starts}/{n} ({100 * starts / n:.1f}%)")
    print(f"    Ends with stop:  {ends}/{n} ({100 * ends / n:.1f}%)")


def _audit_labels(labels):
    """Print label summary statistics; return the non-NaN labels as a float array."""
    labels_arr = np.asarray(labels, dtype=float)
    nan_mask = np.isnan(labels_arr)
    labels_clean = labels_arr[~nan_mask]

    print("\n  📈 Label Distribution:")
    print(f"    Count: {len(labels_arr)}, NaN: {int(nan_mask.sum())}")
    if labels_clean.size:
        mean, std = labels_clean.mean(), labels_clean.std()
        print(f"    Min:    {labels_clean.min():.4f}")
        print(f"    Q1:     {np.percentile(labels_clean, 25):.4f}")
        print(f"    Median: {np.median(labels_clean):.4f}")
        print(f"    Q3:     {np.percentile(labels_clean, 75):.4f}")
        print(f"    Max:    {labels_clean.max():.4f}")
        print(f"    Mean:   {mean:.4f}")
        print(f"    Std:    {std:.4f}")
        # Population skewness; undefined (reported as nan) for constant labels.
        skew = float(((labels_clean - mean) ** 3).mean() / std ** 3) if std > 0 else float('nan')
        print(f"    Skew:   {skew:.4f}")
    return labels_clean


def audit_dataset(sequences, labels, name="dataset", max_codons=2046):
    """Print a multi-section audit report for one list of sequences and labels.

    Sections: lengths (with histogram and over-limit / not-divisible-by-3
    warnings), nucleotide composition, codon usage, start/stop codons, label
    distribution, duplicates, and outliers.

    Args:
        sequences: iterable of nucleotide strings. Non-string entries (e.g.
            None or float NaN) are reported and audited as empty strings.
        labels: numeric labels aligned with *sequences*; None/NaN count as missing.
        name: title printed in the report header.
        max_codons: codon count above which sequences are flagged as truncated.

    Returns:
        None; the report is printed.
    """
    sequences = list(sequences)
    print(f"\n{'=' * 60}")
    print(f"AUDIT: {name} ({len(sequences)} sequences)")
    print(f"{'=' * 60}")

    if not sequences:
        print("  (empty)")
        return

    n_missing = sum(1 for s in sequences if not isinstance(s, str))
    if n_missing:
        print(f"  ⚠️  {n_missing} missing/non-string sequences (audited as empty strings)")
        sequences = [s if isinstance(s, str) else '' for s in sequences]

    lengths_codon = _audit_lengths(sequences, max_codons)
    _audit_composition(sequences)
    _audit_codon_usage(sequences)
    _audit_start_stop(sequences)
    labels_clean = _audit_labels(labels)

    unique_seqs = len(set(sequences))
    print("\n  🔁 Duplicates:")
    print(f"    Unique sequences:    {unique_seqs}")
    print(f"    Duplicate sequences: {len(sequences) - unique_seqs}")

    print("\n  ⚡ Outliers:")
    if labels_clean.size:
        q1, q3 = np.percentile(labels_clean, [25, 75])
        iqr = q3 - q1
        n_outliers = int(np.sum((labels_clean < q1 - 3 * iqr) | (labels_clean > q3 + 3 * iqr)))
        print(f"    Label outliers (>3 IQR beyond Q1/Q3): {n_outliers}/{labels_clean.size}")
    print(f"    Very short (<10 codons):  {sum(1 for c in lengths_codon if c < 10)}")
    print(f"    Very long (>1000 codons): {sum(1 for c in lengths_codon if c > 1000)}")
|
|
|
|
def _audit_training_datasets():
    """Audit both HF training datasets split by split, then compare their sequence sets."""
    print("\n\n📚 TRAINING DATASETS")
    print("=" * 70)
    if load_dataset is None:
        print("  Skipping (datasets library not installed)")
        return

    ds1 = load_dataset("mogam-ai/CDS-BART-mRNA-stability")
    for split in ['train', 'val', 'test']:
        audit_dataset(ds1[split]['seq'], ds1[split]['y'],
                      f"mogam-ai/CDS-BART-mRNA-stability [{split}]")

    ds2 = load_dataset("GleghornLab/mrna_stability_other")
    for split, hf_split in [('train', 'train'), ('val', 'valid'), ('test', 'test')]:
        audit_dataset(ds2[hf_split]['rna'], ds2[hf_split]['labels'],
                      f"GleghornLab/mrna_stability_other [{split}]")

    print("\n\n🔗 CROSS-DATASET ANALYSIS")
    print("=" * 60)
    ds1_all = set(ds1['train']['seq']) | set(ds1['val']['seq']) | set(ds1['test']['seq'])
    ds2_all = set(ds2['train']['rna']) | set(ds2['valid']['rna']) | set(ds2['test']['rna'])
    print(f"  mogam-ai total unique:    {len(ds1_all)}")
    print(f"  GleghornLab total unique: {len(ds2_all)}")
    print(f"  Overlap:                  {len(ds1_all & ds2_all)}")
    print(f"  mogam-ai ⊆ GleghornLab:   {ds1_all.issubset(ds2_all)}")
    print(f"  GleghornLab-only:         {len(ds2_all - ds1_all)}")


def _audit_benchmark_datasets(data_dir):
    """Download (if needed) and audit every benchmark CSV that is available."""
    print("\n\n📊 BENCHMARK DATASETS")
    print("=" * 70)
    if pd is None:
        print("  Skipping (pandas not installed)")
        return

    data_dir = download_benchmark_data(data_dir)
    for task_name, info in BENCHMARK_DATASETS.items():
        filepath = os.path.join(data_dir, info['filename'])
        if not os.path.exists(filepath):
            print(f"  Skipping {task_name}: {info['filename']} not available")
            continue
        df = pd.read_csv(filepath)
        df.columns = [str(c).lower().strip() for c in df.columns]

        # Most benchmarks use 'sequence'; MLOS uses 'cds' (see BENCHMARK_DATASETS notes).
        seq_col = 'sequence' if 'sequence' in df.columns else 'cds'
        val_col = 'value'
        if seq_col not in df.columns or val_col not in df.columns:
            print(f"  Skipping {task_name}: expected '{seq_col}' and '{val_col}' columns, "
                  f"found {list(df.columns)}")
            continue
        audit_dataset(df[seq_col].tolist(), df[val_col].tolist(),
                      f"Benchmark: {task_name} ({info['filename']})")


def run_full_audit(data_dir='./benchmark_data'):
    """Run audit_dataset() on all training and benchmark datasets.

    Args:
        data_dir: directory holding (or receiving) the benchmark CSVs.
    """
    print("=" * 70)
    print("FULL DATASET AUDIT")
    print("=" * 70)
    _audit_training_datasets()
    _audit_benchmark_datasets(data_dir)
|
|
|
|
| |
| |
| |
|
|
def show_vocab():
    """Print the CodonFM tokenizer vocabulary, a tokenization example and key config values.

    Sections:
      - special tokens with descriptions;
      - the codon tokens, four per row, with amino acids from CODON_TABLE;
      - a round trip of an example sequence through tokenize_mRNA() and
        ID_TO_CODON;
      - a summary of vocabulary-related config values, using VOCAB_SIZE and
        PAD_TOKEN_ID rather than hard-coded numbers.
    """
    rule = '─'
    print("=" * 70)
    print(f"CodonFM Tokenizer Vocabulary (vocab_size={VOCAB_SIZE})")
    print("=" * 70)

    print("\n  SPECIAL TOKENS:")
    print(f"    {'ID':>4}  {'Token':<10}  {'Description'}")
    print(f"    {rule * 4}  {rule * 10}  {rule * 30}")
    descriptions = {
        '[CLS]': 'Classification token (prepended)',
        '[SEP]': 'Separator token (appended)',
        '[MASK]': 'Mask token (for MLM pretraining)',
        '[PAD]': f'Padding token (pad_token_id={PAD_TOKEN_ID})',
        '[UNK]': 'Unknown token (for invalid codons)',
    }
    for token, tid in sorted(SPECIAL_TOKENS.items(), key=lambda kv: kv[1]):
        print(f"    {tid:>4}  {token:<10}  {descriptions.get(token, '')}")

    print(f"\n  CODON TOKENS ({len(ALL_CODONS)} codons → 20 amino acids + 3 stop):")
    columns = 4
    print("  " + f"  {'ID':>4} {'Codon':<6} {'Amino Acid':<12}" * columns)
    print("  " + f"  {rule * 4} {rule * 6} {rule * 12}" * columns)
    items = sorted(CODON_TO_ID.items(), key=lambda kv: kv[1])
    for start in range(0, len(items), columns):
        cells = [f"  {tid:>4} {codon:<6} {CODON_TABLE.get(codon, '?'):<12}"
                 for codon, tid in items[start:start + columns]]
        print("  " + "".join(cells))

    print("\n  TOKENIZATION EXAMPLE:")
    example = "AUGGCAGCCGAGACUCGG"
    tokens = tokenize_mRNA(example)
    decoded = [ID_TO_CODON.get(t, '?') for t in tokens['input_ids']]
    print(f"    Input:     {example}")
    print(f"    Codons:    {' '.join(seq_to_codons(example))}")
    print(f"    Token IDs: {tokens['input_ids']}")
    print(f"    Decoded:   {' '.join(decoded)}")

    print("\n  CONFIG (matches nvidia/NV-CodonFM-Encodon-80M-v1/config.json):")
    print(f"    vocab_size: {VOCAB_SIZE}")
    print(f"    pad_token_id: {PAD_TOKEN_ID} ([PAD])")
    print("    max_position_embeddings: 2046 codons (~6138 nt)")
    print("    position_embedding_type: rotary (RoPE, θ=10000)")
|
|
|
|
| |
| |
| |
|
|
def main():
    """Parse command-line flags and run the requested setup steps.

    Steps run in a fixed order, each only if its flag (or --all) is given:
    show vocabulary, download training data, download benchmark data,
    preprocess (and optionally export) training data, full audit. With no
    action flag, the help text is printed and the function returns.
    """
    parser = argparse.ArgumentParser(
        description="Dataset setup for CodonFM-80M mRNA stability fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python data_setup.py --all                        # Download & audit everything
  python data_setup.py --training                   # Download training datasets
  python data_setup.py --training --export ./data   # Download & export preprocessed CSVs
  python data_setup.py --benchmark                  # Download benchmark datasets
  python data_setup.py --audit                      # Full audit of all datasets
  python data_setup.py --vocab                      # Show codon vocabulary
  python data_setup.py --preprocess                 # Preprocess & deduplicate
"""
    )

    parser.add_argument('--all', action='store_true', help='Download and audit everything')
    parser.add_argument('--training', action='store_true', help='Download training datasets from HF Hub')
    parser.add_argument('--benchmark', action='store_true', help='Download benchmark datasets from GitHub')
    parser.add_argument('--audit', action='store_true', help='Run full dataset audit')
    parser.add_argument('--preprocess', action='store_true', help='Preprocess training data (clean, deduplicate)')
    parser.add_argument('--vocab', action='store_true', help='Show codon vocabulary and tokenizer')
    parser.add_argument('--export', type=str, default=None, help='Export directory for preprocessed CSVs')
    parser.add_argument('--benchmark_dir', type=str, default='./benchmark_data',
                        help='Directory for benchmark data')
    # store_true with default=True can never be turned off, so this flag is a
    # no-op; it is kept only so existing command lines keep working.
    parser.add_argument('--use_both', action='store_true', default=True,
                        help='Use GleghornLab/mrna_stability_other, the superset of both datasets '
                             '(default; no-op kept for compatibility, use --mogam_only to change)')
    parser.add_argument('--mogam_only', action='store_true',
                        help='Use only mogam-ai dataset (smaller, cleaner)')

    args = parser.parse_args()

    if not any([args.all, args.training, args.benchmark, args.audit, args.preprocess, args.vocab]):
        parser.print_help()
        return

    run_preprocess = args.preprocess or args.all

    if args.export and not (args.training or run_preprocess):
        print("NOTE: --export only applies with --training, --preprocess or --all; ignoring it.")

    if args.vocab or args.all:
        show_vocab()

    if args.training or args.all:
        print("\n\n📦 DOWNLOADING TRAINING DATASETS")
        print("=" * 60)
        # When preprocessing also runs, let that step perform the single export.
        download_training_data(export_dir=None if run_preprocess else args.export)

    if args.benchmark or args.all:
        print("\n\n📦 DOWNLOADING BENCHMARK DATASETS")
        print("=" * 60)
        download_benchmark_data(args.benchmark_dir)

    if run_preprocess:
        print("\n\n🔧 PREPROCESSING TRAINING DATA")
        print("=" * 60)
        clean_data = preprocess_training_data(use_both_datasets=not args.mogam_only)
        if args.export:
            export_training_data(clean_data, args.export)

    if args.audit or args.all:
        print("\n\n🔍 RUNNING FULL AUDIT")
        run_full_audit()


if __name__ == '__main__':
    main()
|
|