| |
| """DeepGenopix Training Script — Phase 1: Biological Validation. |
| |
| Hyperparameter sweep script for TE family classification. |
| Supports baseline_v1, stride2_v1, stride8_v1, latent128_v1, latent768_v1, |
| layers4_v1, stem5_v1, stem7_v1 presets. |
| |
| Usage: |
| python scripts/train.py --preset baseline_v1 --data_dir /path/to/data |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import trackio |
| from torch.utils.data import DataLoader |
|
|
| |
# Put the src/ directory located one level above this script's folder at the
# front of the import path so the deepgenopix imports below resolve when the
# script is executed directly.
_SRC_DIR = Path(__file__).resolve().parents[1] / "src"
sys.path.insert(0, str(_SRC_DIR))
|
|
| from deepgenopix.config import ( |
| BP_PER_TOKEN, |
| DIM_FEEDFORWARD, |
| D_MODEL, |
| DROPOUT, |
| NHEAD, |
| NUM_LAYERS, |
| STEM_KERNEL_SIZE, |
| COMPRESSOR_STRIDE, |
| ) |
| from deepgenopix.dataset import DeepGenopixDataset, dynamic_collate_fn |
| from deepgenopix.io_utils import load_json |
| from deepgenopix.model import DeepGenopixClassifier, count_parameters |
| from deepgenopix.trainer import train_model |
|
|
|
|
| |
# Reference configuration. Each other preset changes a single architectural
# axis relative to it (the latent-size presets also rescale the attention
# heads / feed-forward width that go with the embedding size).
_BASELINE_V1 = {
    "description": "Stride 4 (48bp/token), 256d latent, stem kernel 3, 2 layers",
    "compressor_stride": 4,
    "d_model": 256,
    "stem_kernel": 3,
    "num_layers": 2,
    "nhead": 4,
    "dim_feedforward": 1024,
}

# Hyperparameter presets selectable with --preset.
# Overriding a key that already exists keeps its position in the dict, so all
# presets expose their fields in the same order as the baseline.
PRESETS = {
    "baseline_v1": dict(_BASELINE_V1),
    # --- compressor stride sweep ---
    "stride2_v1": {
        **_BASELINE_V1,
        "description": "Stride 2 (24bp/token). High fidelity compression test.",
        "compressor_stride": 2,
    },
    "stride8_v1": {
        **_BASELINE_V1,
        "description": "Stride 8 (96bp/token). Extreme abstraction test.",
        "compressor_stride": 8,
    },
    # --- latent width sweep ---
    "latent128_v1": {
        **_BASELINE_V1,
        "description": "128d latent vector. Compressed semantics test.",
        "d_model": 128,
        "dim_feedforward": 512,
    },
    "latent768_v1": {
        **_BASELINE_V1,
        "description": "768d latent vector. LLM-scale embedding test.",
        "d_model": 768,
        "nhead": 8,
        "dim_feedforward": 3072,
    },
    # --- encoder depth sweep ---
    "layers4_v1": {
        **_BASELINE_V1,
        "description": "4-layer Transformer Encoder. Deep grammar test.",
        "num_layers": 4,
    },
    # --- stem kernel sweep ---
    "stem5_v1": {
        **_BASELINE_V1,
        "description": "Conv1d stem kernel 5. Wider retina (60bp field).",
        "stem_kernel": 5,
    },
    "stem7_v1": {
        **_BASELINE_V1,
        "description": "Conv1d stem kernel 7. Widest retina (84bp field).",
        "stem_kernel": 7,
    },
}
|
|
|
|
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour callers had before this parameter existed. Passing a
            list lets tests or other scripts drive the parser directly.

    Returns:
        Namespace with the attributes ``preset``, ``data_dir``,
        ``output_dir``, ``epochs``, ``batch_size``, ``lr``, ``weight_decay``,
        ``max_samples``, ``num_workers`` and ``trackio_project``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="DeepGenopix Training")

    # --- experiment selection and locations ---
    parser.add_argument(
        "--preset",
        type=str,
        default="baseline_v1",
        # Restrict to the presets defined in this file so a typo fails fast.
        choices=list(PRESETS),
        help="Hyperparameter preset",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="/app/data/processed",
        help="Directory with registry.csv, classes.json, te_visuals/",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/app/data/output",
        help="Output directory for checkpoints",
    )

    # --- optimisation settings ---
    parser.add_argument(
        "--epochs",
        type=int,
        default=50,
        help="Maximum training epochs",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=1e-4,
        help="Weight decay",
    )

    # --- data loading and tracking ---
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Cap samples for dry run (None = all)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=2,
        help="DataLoader workers",
    )
    parser.add_argument(
        "--trackio_project",
        type=str,
        default="deepgenopix-v1",
        help="Trackio project name",
    )

    return parser.parse_args(argv)
|
|
|
|
def _make_loader(dataset, *, batch_size, shuffle, num_workers):
    """Wrap *dataset* in a DataLoader that collates with ``dynamic_collate_fn``."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=dynamic_collate_fn,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-GPU copies; asking for it
        # on a CPU-only machine just makes PyTorch emit a warning.
        pin_memory=torch.cuda.is_available(),
    )


def main():
    """Train the selected preset end to end and write a run summary.

    Steps, in order:
      1. Parse the CLI options and look up the chosen entry of ``PRESETS``.
      2. Load ``classes.json`` and build the train/val datasets from
         ``registry.csv`` and ``te_visuals/`` inside ``--data_dir``.
      3. Build ``DeepGenopixClassifier`` from the preset hyperparameters.
      4. Record the hyperparameters with trackio and call ``train_model``.
      5. Print the best validation metrics, send a trackio alert and write
         ``summary.json`` to ``<output_dir>/<preset>/``.

    Side effects: creates the output directory, writes ``summary.json``,
    prints progress to stdout and talks to trackio.
    """
    args = parse_args()
    preset = PRESETS[args.preset]

    rule = "=" * 70
    print(rule)
    print(f"DeepGenopix Training — Preset: {args.preset}")
    print(f" {preset['description']}")
    print(f" Data: {args.data_dir}")
    print(f" Output: {args.output_dir}")
    print(rule)

    data_dir = Path(args.data_dir)
    # Each preset gets its own sub-directory so sweep runs do not overwrite
    # each other's outputs.
    output_dir = Path(args.output_dir) / args.preset
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- class vocabulary ---
    classes = load_json(data_dir / "classes.json")
    num_classes = len(classes)
    # Order the class names by the value stored for each of them
    # (presumably the integer label index — confirm against classes.json).
    class_names = sorted(classes, key=classes.get)
    print(f"\n[Main] {num_classes} classes loaded")

    # --- datasets and loaders ---
    registry_path = data_dir / "registry.csv"
    visuals_dir = data_dir / "te_visuals"

    # Note: --max_samples is passed to both splits.
    train_dataset = DeepGenopixDataset(
        registry_path, visuals_dir, split="train", max_samples=args.max_samples,
    )
    val_dataset = DeepGenopixDataset(
        registry_path, visuals_dir, split="val", max_samples=args.max_samples,
    )
    print(f"[Main] Train: {len(train_dataset)} samples, Val: {len(val_dataset)} samples")

    train_loader = _make_loader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    val_loader = _make_loader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    # --- model ---
    model = DeepGenopixClassifier(
        num_classes=num_classes,
        stem_kernel=preset["stem_kernel"],
        compressor_stride=preset["compressor_stride"],
        d_model=preset["d_model"],
        nhead=preset["nhead"],
        num_layers=preset["num_layers"],
        dim_feedforward=preset["dim_feedforward"],
        dropout=DROPOUT,
    )

    params = count_parameters(model)
    # 12 bp per unit of stride, consistent with the preset descriptions
    # (stride 2 -> 24 bp, stride 4 -> 48 bp, stride 8 -> 96 bp per token).
    bp_per_token = 12 * preset["compressor_stride"]

    print(f"\n[Main] Model: {params['trainable']:,} params")
    print(f"[Main] BP/token: {bp_per_token}")

    # --- experiment tracking ---
    # NOTE(review): confirm that the installed trackio version accepts the
    # `run_name` keyword and exposes `log_params`; left unchanged here.
    trackio.init(
        project=args.trackio_project,
        run_name=args.preset,
    )
    trackio.log_params({
        "preset": args.preset,
        "description": preset["description"],
        "compressor_stride": preset["compressor_stride"],
        "d_model": preset["d_model"],
        "stem_kernel": preset["stem_kernel"],
        "num_layers": preset["num_layers"],
        "nhead": preset["nhead"],
        # Recorded as well so a run can be reproduced from the tracker alone:
        # dim_feedforward differs between presets, and epochs / weight_decay
        # are CLI-controlled.
        "dim_feedforward": preset["dim_feedforward"],
        "bp_per_token": bp_per_token,
        "num_classes": num_classes,
        "num_params": params["trainable"],
        "batch_size": args.batch_size,
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "epochs": args.epochs,
    })

    # --- training ---
    print("\n[Main] Starting training...")
    model, metrics = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_classes=num_classes,
        class_names=class_names,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        output_dir=output_dir,
        trackio_project=args.trackio_project,
    )

    # --- report ---
    print(f"\n{rule}")
    print(f"Training Complete: {args.preset}")
    print(f" Best Val Acc: {metrics.best_val_acc:.4f}")
    print(f" Best Val F1: {metrics.best_val_f1:.4f}")
    print(f" Best Epoch: {metrics.best_epoch}")
    print(rule)

    # NOTE(review): confirm "success" is an accepted alert level in trackio.
    trackio.alert(
        f"Training Complete: {args.preset}",
        f"Acc={metrics.best_val_acc:.4f}, F1={metrics.best_val_f1:.4f}, Epoch={metrics.best_epoch}",
        level="success",
    )

    # --- persist a machine-readable summary next to the run outputs ---
    summary = {
        "preset": args.preset,
        **preset,
        # float() so the values are plain JSON numbers whatever numeric type
        # the trainer returns.
        "best_val_acc": float(metrics.best_val_acc),
        "best_val_f1": float(metrics.best_val_f1),
        "best_epoch": metrics.best_epoch,
        "num_params": params["trainable"],
        "num_classes": num_classes,
        "bp_per_token": bp_per_token,
    }

    summary_path = output_dir / "summary.json"
    # Explicit encoding keeps the file identical across platforms/locales.
    with open(summary_path, "w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2)

    print(f"[Main] Summary saved: {summary_path}")


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|