deepgenopix / scripts /train.py
vedatonuryilmaz's picture
Upload scripts/train.py
c83c966 verified
#!/usr/bin/env python3
"""DeepGenopix Training Script — Phase 1: Biological Validation.
Hyperparameter sweep script for TE family classification.
Supports baseline_v1, stride2_v1, stride8_v1, latent128_v1, latent768_v1,
layers4_v1, stem5_v1, stem7_v1 presets.
Usage:
python scripts/train.py --preset baseline_v1 --data_dir /path/to/data
"""
import argparse
import json
import os
import sys
from pathlib import Path
import torch
import trackio
from torch.utils.data import DataLoader
# Add src to path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
from deepgenopix.config import (
BP_PER_TOKEN,
DIM_FEEDFORWARD,
D_MODEL,
DROPOUT,
NHEAD,
NUM_LAYERS,
STEM_KERNEL_SIZE,
COMPRESSOR_STRIDE,
)
from deepgenopix.dataset import DeepGenopixDataset, dynamic_collate_fn
from deepgenopix.io_utils import load_json
from deepgenopix.model import DeepGenopixClassifier, count_parameters
from deepgenopix.trainer import train_model
# ── Preset Definitions ────────────────────────────────────────────────────────
# Architecture shared by every preset unless the preset overrides a field.
# These are exactly the baseline_v1 values; key order here fixes the key order
# of every preset dict (and therefore of the summary.json each run writes).
_BASELINE_ARCH = {
    "compressor_stride": 4,
    "d_model": 256,
    "stem_kernel": 3,
    "num_layers": 2,
    "nhead": 4,
    "dim_feedforward": 1024,
}


def _preset(description, **overrides):
    """Build one preset: *description* plus the baseline architecture with *overrides*.

    Overrides only replace values of existing baseline keys, so every preset
    keeps the same key order: ``description`` first, then the architecture
    fields in ``_BASELINE_ARCH`` order. Each call returns a fresh dict.
    """
    return {"description": description, **_BASELINE_ARCH, **overrides}


# Sweep presets, keyed by the name accepted by ``--preset``. Each one changes a
# single architectural axis relative to baseline_v1 (latent sizes also rescale
# the feed-forward width, and latent768 uses more attention heads).
PRESETS = {
    "baseline_v1": _preset(
        "Stride 4 (48bp/token), 256d latent, stem kernel 3, 2 layers",
    ),
    "stride2_v1": _preset(
        "Stride 2 (24bp/token). High fidelity compression test.",
        compressor_stride=2,
    ),
    "stride8_v1": _preset(
        "Stride 8 (96bp/token). Extreme abstraction test.",
        compressor_stride=8,
    ),
    "latent128_v1": _preset(
        "128d latent vector. Compressed semantics test.",
        d_model=128,
        dim_feedforward=512,
    ),
    "latent768_v1": _preset(
        "768d latent vector. LLM-scale embedding test.",
        d_model=768,
        nhead=8,
        dim_feedforward=3072,
    ),
    "layers4_v1": _preset(
        "4-layer Transformer Encoder. Deep grammar test.",
        num_layers=4,
    ),
    "stem5_v1": _preset(
        "Conv1d stem kernel 5. Wider retina (60bp field).",
        stem_kernel=5,
    ),
    "stem7_v1": _preset(
        "Conv1d stem kernel 7. Widest retina (84bp field).",
        stem_kernel=7,
    ),
}
def _int_at_least(minimum):
    """Return an argparse ``type`` callable that parses an int and rejects values < *minimum*.

    Raising ``argparse.ArgumentTypeError`` makes argparse print a usage error
    naming the offending option, instead of the bad value surfacing later as an
    obscure failure inside the DataLoader or training loop.
    """

    def parse(text):
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value

    return parse


def parse_args(argv=None):
    """Parse the command-line options for one preset training run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so plain ``parse_args()`` calls behave as before;
            passing a list allows programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` with ``preset``, ``data_dir``, ``output_dir``,
        ``epochs``, ``batch_size``, ``lr``, ``weight_decay``, ``max_samples``,
        ``num_workers`` and ``trackio_project``.

    Exits with status 2 (via argparse) on an unknown preset or a non-positive
    epoch count, batch size or sample cap, or a negative worker count.
    """
    parser = argparse.ArgumentParser(description="DeepGenopix Training")
    parser.add_argument(
        "--preset",
        type=str,
        default="baseline_v1",
        choices=list(PRESETS.keys()),
        help="Hyperparameter preset",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="/app/data/processed",
        help="Directory with registry.csv, classes.json, te_visuals/",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/app/data/output",
        help="Output directory for checkpoints",
    )
    parser.add_argument(
        "--epochs",
        type=_int_at_least(1),
        default=50,
        help="Maximum training epochs",
    )
    parser.add_argument(
        "--batch_size",
        type=_int_at_least(1),
        default=32,
        help="Batch size",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=1e-4,
        help="Weight decay",
    )
    parser.add_argument(
        "--max_samples",
        # A cap of 0 would produce empty datasets; None (the default) means "all".
        type=_int_at_least(1),
        default=None,
        help="Cap samples for dry run (None = all)",
    )
    parser.add_argument(
        "--num_workers",
        # 0 is valid: it loads data in the main process.
        type=_int_at_least(0),
        default=2,
        help="DataLoader workers",
    )
    parser.add_argument(
        "--trackio_project",
        type=str,
        default="deepgenopix-v1",
        help="Trackio project name",
    )
    return parser.parse_args(argv)
def _make_loader(dataset, batch_size, *, shuffle, num_workers):
    """Wrap *dataset* in a DataLoader that batches with ``dynamic_collate_fn``.

    The train and validation loaders differ only in ``shuffle``. Building both
    here keeps their remaining settings from drifting apart.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=dynamic_collate_fn,
        num_workers=num_workers,
        pin_memory=True,
    )


def main():
    """Train the selected preset end to end and save its ``summary.json``.

    Steps:

    * Load ``classes.json`` from ``--data_dir``.
    * Build train/val ``DeepGenopixDataset`` objects from ``registry.csv`` and
      ``te_visuals/``.
    * Instantiate a ``DeepGenopixClassifier`` from the preset's hyperparameters.
    * Log the run configuration to Trackio.
    * Run ``train_model``.
    * Report the best validation metrics (stdout + Trackio alert) and write them
      to ``<output_dir>/<preset>/summary.json``.

    Raises:
        ValueError: if the training split is empty (e.g. a wrong ``--data_dir``
            or a registry with no ``train`` rows).
    """
    args = parse_args()
    preset = PRESETS[args.preset]

    print("=" * 70)
    print(f"DeepGenopix Training — Preset: {args.preset}")
    print(f" {preset['description']}")
    print(f" Data: {args.data_dir}")
    print(f" Output: {args.output_dir}")
    print("=" * 70)

    data_dir = Path(args.data_dir)
    # Each preset writes into its own subdirectory so sweep runs don't collide.
    output_dir = Path(args.output_dir) / args.preset
    output_dir.mkdir(parents=True, exist_ok=True)

    # ── Load class mapping ─────────────────────────────────────────────────
    classes = load_json(data_dir / "classes.json")
    num_classes = len(classes)
    # Order class names by their mapped value, so that class_names[i] names the
    # class mapped to i. This assumes the values are 0..num_classes-1 label ids;
    # confirm against the data-prep step.
    class_names = sorted(classes, key=classes.get)
    print(f"\n[Main] {num_classes} classes loaded")

    # ── Create datasets ────────────────────────────────────────────────────
    registry_path = data_dir / "registry.csv"
    visuals_dir = data_dir / "te_visuals"
    train_dataset = DeepGenopixDataset(
        registry_path, visuals_dir, split="train", max_samples=args.max_samples,
    )
    val_dataset = DeepGenopixDataset(
        registry_path, visuals_dir, split="val", max_samples=args.max_samples,
    )
    print(f"[Main] Train: {len(train_dataset)} samples, Val: {len(val_dataset)} samples")
    if len(train_dataset) == 0:
        # A shuffled DataLoader over zero samples fails with an opaque sampler
        # error; report the likely cause instead.
        raise ValueError(
            f"No training samples found in {registry_path}; "
            "check --data_dir and the registry's train split."
        )

    # ── DataLoaders ────────────────────────────────────────────────────────
    train_loader = _make_loader(
        train_dataset, args.batch_size, shuffle=True, num_workers=args.num_workers,
    )
    val_loader = _make_loader(
        val_dataset, args.batch_size, shuffle=False, num_workers=args.num_workers,
    )

    # ── Build model ────────────────────────────────────────────────────────
    model = DeepGenopixClassifier(
        num_classes=num_classes,
        stem_kernel=preset["stem_kernel"],
        compressor_stride=preset["compressor_stride"],
        d_model=preset["d_model"],
        nhead=preset["nhead"],
        num_layers=preset["num_layers"],
        dim_feedforward=preset["dim_feedforward"],
        dropout=DROPOUT,
    )
    params = count_parameters(model)
    # Each compressor stride step spans 12 bp. Strides 2/4/8 therefore give
    # 24/48/96 bp per token, matching the preset descriptions.
    bp_per_stride_step = 12
    bp_per_token = bp_per_stride_step * preset["compressor_stride"]
    print(f"\n[Main] Model: {params['trainable']:,} params")
    print(f"[Main] BP/token: {bp_per_token}")

    # ── Trackio init ───────────────────────────────────────────────────────
    # NOTE(review): confirm the installed trackio version accepts `run_name=`
    # and exposes `log_params`/`alert(level=...)`. Some releases name the run
    # via `name=` and take hyperparameters through `config=` in `init()`.
    trackio.init(
        project=args.trackio_project,
        run_name=args.preset,
    )
    trackio.log_params({
        "preset": args.preset,
        "description": preset["description"],
        "compressor_stride": preset["compressor_stride"],
        "d_model": preset["d_model"],
        "stem_kernel": preset["stem_kernel"],
        "num_layers": preset["num_layers"],
        "nhead": preset["nhead"],
        # Varies across presets (latent128/latent768), so it must be logged to
        # tell those runs apart.
        "dim_feedforward": preset["dim_feedforward"],
        "bp_per_token": bp_per_token,
        "num_classes": num_classes,
        "num_params": params["trainable"],
        "batch_size": args.batch_size,
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "epochs": args.epochs,
    })

    # ── Train ──────────────────────────────────────────────────────────────
    print("\n[Main] Starting training...")
    model, metrics = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_classes=num_classes,
        class_names=class_names,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        output_dir=output_dir,
        trackio_project=args.trackio_project,
    )

    # ── Final summary ──────────────────────────────────────────────────────
    print(f"\n{'='*70}")
    print(f"Training Complete: {args.preset}")
    print(f" Best Val Acc: {metrics.best_val_acc:.4f}")
    print(f" Best Val F1: {metrics.best_val_f1:.4f}")
    print(f" Best Epoch: {metrics.best_epoch}")
    print(f"{'='*70}")
    trackio.alert(
        f"Training Complete: {args.preset}",
        f"Acc={metrics.best_val_acc:.4f}, F1={metrics.best_val_f1:.4f}, Epoch={metrics.best_epoch}",
        level="success",
    )

    # ── Save final artifacts ───────────────────────────────────────────────
    summary = {
        "preset": args.preset,
        **preset,
        # Cast explicitly so json.dump does not choke on non-builtin numeric
        # types coming out of the metrics object.
        "best_val_acc": float(metrics.best_val_acc),
        "best_val_f1": float(metrics.best_val_f1),
        "best_epoch": metrics.best_epoch,
        "num_params": params["trainable"],
        "num_classes": num_classes,
        "bp_per_token": bp_per_token,
    }
    summary_path = output_dir / "summary.json"
    with open(summary_path, "w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2)
    print(f"[Main] Summary saved: {summary_path}")
# Script entry point. main() returns None, so SystemExit(None) exits with status
# 0, exactly as falling off the end of the module would.
if __name__ == "__main__":
    raise SystemExit(main())