""" Phase 2: Downstream Task - Fine-tune for Classification Demonstrates cell type annotation and other sequence classification tasks. Usage: python examples/2_finetune_classification.py """ import torch import numpy as np from torch.utils.data import Dataset, DataLoader from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments class GeneExpressionDataset(Dataset): """ Simple dataset for gene expression classification. In practice, this would load from h5ad or other single-cell formats. """ def __init__(self, input_ids, labels, max_length=2048): self.input_ids = input_ids self.labels = labels self.max_length = max_length def __len__(self): return len(self.input_ids) def __getitem__(self, idx): input_id = self.input_ids[idx] label = self.labels[idx] return { "input_ids": torch.tensor(input_id, dtype=torch.long), "labels": torch.tensor(label, dtype=torch.long), } def create_mock_data(n_samples=1000, n_features=2048, n_classes=5): """Create mock single-cell data for demonstration.""" print("Creating mock dataset...") # Create random ranked gene sequences input_ids = np.random.randint(2, 25426, (n_samples, n_features)) # Create random labels (e.g., cell types) labels = np.random.randint(0, n_classes, n_samples) # Split into train/val/test train_size = int(0.7 * n_samples) val_size = int(0.15 * n_samples) train_ids = input_ids[:train_size] train_labels = labels[:train_size] val_ids = input_ids[train_size:train_size + val_size] val_labels = labels[train_size:train_size + val_size] test_ids = input_ids[train_size + val_size:] test_labels = labels[train_size + val_size:] print(f"✓ Dataset created:") print(f" - Train: {len(train_ids)} samples") print(f" - Val: {len(val_ids)} samples") print(f" - Test: {len(test_ids)} samples") print(f" - Classes: {n_classes}") return ( GeneExpressionDataset(train_ids, train_labels), GeneExpressionDataset(val_ids, val_labels), GeneExpressionDataset(test_ids, test_labels), ) def main(): print("=" * 80) print("GeneMamba Phase 2: Downstream Classification") print("=" * 80) # ============================================================ # Step 1: Load pretrained model with classification head # ============================================================ print("\n[Step 1] Loading pretrained model with classification head...") num_classes = 5 try: model = AutoModelForSequenceClassification.from_pretrained( "GeneMamba-24l-512d", num_labels=num_classes, trust_remote_code=True, local_files_only=True, ) except Exception as e: print(f"Note: Could not load from hub ({e})") print("Using local initialization...") # Initialize locally from configuration_genemamba import GeneMambaConfig from modeling_genemamba import GeneMambaForSequenceClassification config = GeneMambaConfig( vocab_size=25426, hidden_size=512, num_hidden_layers=24, num_labels=num_classes, ) model = GeneMambaForSequenceClassification(config) print(f"✓ Model loaded") print(f" - Classification head: input={model.config.hidden_size} → output={num_classes}") # ============================================================ # Step 2: Prepare data # ============================================================ print("\n[Step 2] Preparing dataset...") train_dataset, val_dataset, test_dataset = create_mock_data( n_samples=1000, n_features=2048, n_classes=num_classes, ) # ============================================================ # Step 3: Set up training arguments # ============================================================ print("\n[Step 3] Setting up training...") output_dir = "./classification_results" training_args = TrainingArguments( output_dir=output_dir, num_train_epochs=3, per_device_train_batch_size=16, per_device_eval_batch_size=16, learning_rate=2e-5, weight_decay=0.01, warmup_steps=100, logging_steps=50, eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="accuracy", report_to="none", # Disable W&B logging seed=42, ) print(f"✓ Training config:") print(f" - Output dir: {output_dir}") print(f" - Epochs: {training_args.num_train_epochs}") print(f" - Batch size: {training_args.per_device_train_batch_size}") print(f" - Learning rate: {training_args.learning_rate}") # ============================================================ # Step 4: Train using Trainer # ============================================================ print("\n[Step 4] Training model...") from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score def compute_metrics(eval_pred): """Compute evaluation metrics.""" predictions, labels = eval_pred predictions = np.argmax(predictions, axis=1) return { "accuracy": accuracy_score(labels, predictions), "f1": f1_score(labels, predictions, average="weighted", zero_division=0), "precision": precision_score(labels, predictions, average="weighted", zero_division=0), "recall": recall_score(labels, predictions, average="weighted", zero_division=0), } trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=compute_metrics, ) train_result = trainer.train() print(f"✓ Training complete!") print(f" - Final training loss: {train_result.training_loss:.4f}") # ============================================================ # Step 5: Evaluate on test set # ============================================================ print("\n[Step 5] Evaluating on test set...") test_results = trainer.evaluate(test_dataset) print(f"✓ Test Results:") for metric, value in test_results.items(): if isinstance(value, float): print(f" - {metric}: {value:.4f}") # ============================================================ # Step 6: Make predictions # ============================================================ print("\n[Step 6] Making predictions...") predictions = trainer.predict(test_dataset) predicted_classes = np.argmax(predictions.predictions, axis=1) print(f"✓ Predictions made:") print(f" - Predicted classes: {len(predicted_classes)} samples") print(f" - Class distribution: {np.bincount(predicted_classes)}") # ============================================================ # Step 7: Save model # ============================================================ print("\n[Step 7] Saving model...") save_dir = "./my_genemamba_classifier" model.save_pretrained(save_dir) print(f"✓ Model saved to '{save_dir}'") # ============================================================ # Step 8: Load and test saved model # ============================================================ print("\n[Step 8] Testing model reloading...") loaded_model = AutoModelForSequenceClassification.from_pretrained( save_dir, trust_remote_code=True, ) loaded_model.eval() # Test on a single batch with torch.no_grad(): sample_input = torch.randint(2, 25426, (1, 2048)) output = loaded_model(sample_input) logits = output.logits prediction = torch.argmax(logits, dim=1) print(f"✓ Loaded model test prediction: class {prediction.item()}") print("\n" + "=" * 80) print("Phase 2 Complete! Model ready for deployment.") print("=" * 80) return model, trainer if __name__ == "__main__": model, trainer = main()