| """ |
| Phase 2: Downstream Task - Fine-tune for Classification |
| Demonstrates cell type annotation and other sequence classification tasks. |
| |
| Usage: |
| python examples/2_finetune_classification.py |
| """ |
|
|
| import torch |
| import numpy as np |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments |
|
|
|
|
class GeneExpressionDataset(Dataset):
    """
    Simple dataset for gene expression classification.

    Wraps pre-tokenized gene-ID sequences and integer class labels and
    serves each sample as a dict of ``torch.long`` tensors. In practice,
    this would load from h5ad or other single-cell formats.

    Args:
        input_ids: Sequence of per-cell token-id sequences.
        labels: Integer class label for each sample (same length as
            ``input_ids``).
        max_length: Maximum sequence length. Fix: previously this value
            was stored but never applied; it is now enforced by
            truncating each sequence in ``__getitem__``.
    """

    def __init__(self, input_ids, labels, max_length: int = 2048):
        self.input_ids = input_ids
        self.labels = labels
        self.max_length = max_length

    def __len__(self) -> int:
        # One sample per token sequence.
        return len(self.input_ids)

    def __getitem__(self, idx):
        # Truncate so the stored max_length limit is actually honored
        # (slicing is a no-op for sequences already within the limit).
        input_id = self.input_ids[idx][: self.max_length]
        label = self.labels[idx]

        return {
            "input_ids": torch.tensor(input_id, dtype=torch.long),
            "labels": torch.tensor(label, dtype=torch.long),
        }
|
|
|
|
def create_mock_data(n_samples=1000, n_features=2048, n_classes=5, seed=None):
    """Create mock single-cell data for demonstration.

    Generates random gene-token sequences with random class labels and
    splits them 70/15/15 into train/validation/test datasets.

    Args:
        n_samples: Total number of mock cells to generate.
        n_features: Tokens per cell (sequence length).
        n_classes: Number of distinct class labels.
        seed: Optional RNG seed for reproducible data and splits
            (new, backward-compatible; default ``None`` keeps the
            previous non-deterministic behavior).

    Returns:
        Tuple of ``(train, val, test)`` ``GeneExpressionDataset`` objects.
    """
    print("Creating mock dataset...")

    rng = np.random.default_rng(seed)

    # Token ids in [2, 25426): ids 0/1 presumably reserved for special
    # tokens — TODO confirm against the tokenizer's vocabulary layout.
    input_ids = rng.integers(2, 25426, (n_samples, n_features))
    labels = rng.integers(0, n_classes, n_samples)

    # 70/15/15 split; any rounding remainder falls into the test split.
    train_size = int(0.7 * n_samples)
    val_size = int(0.15 * n_samples)

    train_ids = input_ids[:train_size]
    train_labels = labels[:train_size]

    val_ids = input_ids[train_size:train_size + val_size]
    val_labels = labels[train_size:train_size + val_size]

    test_ids = input_ids[train_size + val_size:]
    test_labels = labels[train_size + val_size:]

    print("✓ Dataset created:")
    print(f" - Train: {len(train_ids)} samples")
    print(f" - Val: {len(val_ids)} samples")
    print(f" - Test: {len(test_ids)} samples")
    print(f" - Classes: {n_classes}")

    return (
        GeneExpressionDataset(train_ids, train_labels),
        GeneExpressionDataset(val_ids, val_labels),
        GeneExpressionDataset(test_ids, test_labels),
    )
|
|
|
|
def main():
    """Run the full fine-tuning workflow end to end.

    Steps: load (or locally initialize) a classification model, build
    mock datasets, configure and run the Hugging Face Trainer, evaluate
    on the test split, make predictions, save the model, and verify it
    reloads.

    Returns:
        Tuple of ``(model, trainer)`` for further interactive use.
    """
    print("=" * 80)
    print("GeneMamba Phase 2: Downstream Classification")
    print("=" * 80)

    # ------------------------------------------------------------------
    # Step 1: model — checkpoint if available, random init otherwise.
    # ------------------------------------------------------------------
    print("\n[Step 1] Loading pretrained model with classification head...")

    num_classes = 5

    try:
        # NOTE(review): trust_remote_code executes Python shipped with the
        # checkpoint — acceptable only because this is a trusted local
        # artifact; local_files_only avoids any network access.
        model = AutoModelForSequenceClassification.from_pretrained(
            "GeneMamba-24l-512d",
            num_labels=num_classes,
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        # Fall back to a randomly initialized model built from the
        # project-local configuration/modeling modules.
        print(f"Note: Could not load from hub ({e})")
        print("Using local initialization...")

        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaForSequenceClassification

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            num_labels=num_classes,
        )
        model = GeneMambaForSequenceClassification(config)

    print(f"✓ Model loaded")
    print(f" - Classification head: input={model.config.hidden_size} → output={num_classes}")

    # ------------------------------------------------------------------
    # Step 2: data — mock train/val/test splits.
    # ------------------------------------------------------------------
    print("\n[Step 2] Preparing dataset...")

    train_dataset, val_dataset, test_dataset = create_mock_data(
        n_samples=1000,
        n_features=2048,
        n_classes=num_classes,
    )

    # ------------------------------------------------------------------
    # Step 3: training configuration.
    # ------------------------------------------------------------------
    print("\n[Step 3] Setting up training...")

    output_dir = "./classification_results"

    # NOTE(review): `eval_strategy` is the current transformers argument
    # name (older releases spell it `evaluation_strategy`) — confirm the
    # pinned transformers version supports it.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=100,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Best-checkpoint selection uses the eval metric below (Trainer
        # prepends "eval_" to the name automatically when needed).
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to="none",
        seed=42,
    )

    print(f"✓ Training config:")
    print(f" - Output dir: {output_dir}")
    print(f" - Epochs: {training_args.num_train_epochs}")
    print(f" - Batch size: {training_args.per_device_train_batch_size}")
    print(f" - Learning rate: {training_args.learning_rate}")

    # ------------------------------------------------------------------
    # Step 4: training with periodic evaluation.
    # ------------------------------------------------------------------
    print("\n[Step 4] Training model...")

    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

    def compute_metrics(eval_pred):
        """Compute evaluation metrics."""
        predictions, labels = eval_pred
        # Logits -> hard class predictions.
        predictions = np.argmax(predictions, axis=1)

        # Weighted averages account for class imbalance; zero_division=0
        # suppresses warnings when a class never appears in predictions.
        return {
            "accuracy": accuracy_score(labels, predictions),
            "f1": f1_score(labels, predictions, average="weighted", zero_division=0),
            "precision": precision_score(labels, predictions, average="weighted", zero_division=0),
            "recall": recall_score(labels, predictions, average="weighted", zero_division=0),
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    train_result = trainer.train()

    print(f"✓ Training complete!")
    print(f" - Final training loss: {train_result.training_loss:.4f}")

    # ------------------------------------------------------------------
    # Step 5: held-out test-set evaluation.
    # ------------------------------------------------------------------
    print("\n[Step 5] Evaluating on test set...")

    test_results = trainer.evaluate(test_dataset)

    print(f"✓ Test Results:")
    for metric, value in test_results.items():
        # Skip non-numeric entries (e.g. epoch counters / runtime info).
        if isinstance(value, float):
            print(f" - {metric}: {value:.4f}")

    # ------------------------------------------------------------------
    # Step 6: raw predictions on the test split.
    # ------------------------------------------------------------------
    print("\n[Step 6] Making predictions...")

    predictions = trainer.predict(test_dataset)
    predicted_classes = np.argmax(predictions.predictions, axis=1)

    print(f"✓ Predictions made:")
    print(f" - Predicted classes: {len(predicted_classes)} samples")
    print(f" - Class distribution: {np.bincount(predicted_classes)}")

    # ------------------------------------------------------------------
    # Step 7: persist the fine-tuned model.
    # ------------------------------------------------------------------
    print("\n[Step 7] Saving model...")

    save_dir = "./my_genemamba_classifier"
    model.save_pretrained(save_dir)
    print(f"✓ Model saved to '{save_dir}'")

    # ------------------------------------------------------------------
    # Step 8: round-trip sanity check — reload and run one forward pass.
    # ------------------------------------------------------------------
    print("\n[Step 8] Testing model reloading...")

    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )
    loaded_model.eval()

    # Single inference on a random token sequence to confirm the
    # reloaded model produces usable logits.
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (1, 2048))
        output = loaded_model(sample_input)
        logits = output.logits
        prediction = torch.argmax(logits, dim=1)

    print(f"✓ Loaded model test prediction: class {prediction.item()}")

    print("\n" + "=" * 80)
    print("Phase 2 Complete! Model ready for deployment.")
    print("=" * 80)

    return model, trainer
|
|
|
|
if __name__ == "__main__":
    # Keep the trained model and trainer bound at module level so they
    # remain inspectable in interactive runs (e.g. `python -i`).
    model, trainer = main()
|
|