# GeneMamba/examples/2_finetune_classification.py
# Uploaded by mineself2016 ("Upload GeneMamba model", commit 54cd552, verified).
"""
Phase 2: Downstream Task - Fine-tune for Classification
Demonstrates cell type annotation and other sequence classification tasks.
Usage:
python examples/2_finetune_classification.py
"""
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
class GeneExpressionDataset(Dataset):
    """
    Simple dataset for gene expression classification.

    In practice, this would load from h5ad or other single-cell formats.

    Args:
        input_ids: Per-cell ranked gene-ID sequences (any indexable of
            integer sequences).
        labels: Integer class labels aligned with ``input_ids``.
        max_length: Maximum sequence length; longer inputs are truncated.
    """

    def __init__(self, input_ids, labels, max_length=2048):
        self.input_ids = input_ids
        self.labels = labels
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        # Fix: ``max_length`` was previously stored but never used, so
        # over-long sequences passed through untruncated. Truncate here
        # so every sample respects the configured limit.
        input_id = self.input_ids[idx][:self.max_length]
        label = self.labels[idx]
        return {
            "input_ids": torch.tensor(input_id, dtype=torch.long),
            "labels": torch.tensor(label, dtype=torch.long),
        }
def create_mock_data(n_samples=1000, n_features=2048, n_classes=5, seed=None):
    """Create mock single-cell data for demonstration.

    Args:
        n_samples: Total number of cells to generate.
        n_features: Length of each ranked gene-ID sequence.
        n_classes: Number of distinct label values (e.g. cell types).
        seed: Optional RNG seed for reproducible datasets. ``None`` (the
            default) keeps the original non-deterministic behavior.

    Returns:
        A ``(train, val, test)`` tuple of ``GeneExpressionDataset`` objects
        split 70/15/15.
    """
    print("Creating mock dataset...")
    # A dedicated RandomState avoids mutating NumPy's global RNG;
    # RandomState(None) seeds from OS entropy, matching the old behavior.
    rng = np.random.RandomState(seed)
    # Create random ranked gene sequences. IDs start at 2 — presumably 0/1
    # are reserved special tokens and 25426 is the model vocab size (matches
    # the config in main) — TODO confirm against the tokenizer.
    input_ids = rng.randint(2, 25426, (n_samples, n_features))
    # Create random labels (e.g., cell types)
    labels = rng.randint(0, n_classes, n_samples)
    # Split into train/val/test (70% / 15% / remainder)
    train_size = int(0.7 * n_samples)
    val_size = int(0.15 * n_samples)
    train_ids = input_ids[:train_size]
    train_labels = labels[:train_size]
    val_ids = input_ids[train_size:train_size + val_size]
    val_labels = labels[train_size:train_size + val_size]
    test_ids = input_ids[train_size + val_size:]
    test_labels = labels[train_size + val_size:]
    print(f"✓ Dataset created:")
    print(f" - Train: {len(train_ids)} samples")
    print(f" - Val: {len(val_ids)} samples")
    print(f" - Test: {len(test_ids)} samples")
    print(f" - Classes: {n_classes}")
    return (
        GeneExpressionDataset(train_ids, train_labels),
        GeneExpressionDataset(val_ids, val_labels),
        GeneExpressionDataset(test_ids, test_labels),
    )
def main():
    """Run the Phase-2 fine-tuning demo end to end.

    Loads (or locally initializes) a GeneMamba sequence-classification
    model, trains it on mock data with the HF ``Trainer``, evaluates on a
    held-out test split, saves the fine-tuned weights, and verifies the
    saved model can be reloaded and used for inference.

    Returns:
        Tuple ``(model, trainer)`` for further interactive use.
    """
    print("=" * 80)
    print("GeneMamba Phase 2: Downstream Classification")
    print("=" * 80)
    # ============================================================
    # Step 1: Load pretrained model with classification head
    # ============================================================
    print("\n[Step 1] Loading pretrained model with classification head...")
    num_classes = 5
    try:
        # local_files_only=True: expects the checkpoint on disk; any failure
        # (missing files, etc.) falls through to local initialization below.
        model = AutoModelForSequenceClassification.from_pretrained(
            "GeneMamba-24l-512d",
            num_labels=num_classes,
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        print(f"Note: Could not load from hub ({e})")
        print("Using local initialization...")
        # Initialize locally from the project-local model definition files
        # (randomly initialized weights — no pretraining in this path).
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaForSequenceClassification
        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            num_labels=num_classes,
        )
        model = GeneMambaForSequenceClassification(config)
    print(f"✓ Model loaded")
    print(f" - Classification head: input={model.config.hidden_size} → output={num_classes}")
    # ============================================================
    # Step 2: Prepare data
    # ============================================================
    print("\n[Step 2] Preparing dataset...")
    # Mock data only; replace with real single-cell data for actual use.
    train_dataset, val_dataset, test_dataset = create_mock_data(
        n_samples=1000,
        n_features=2048,
        n_classes=num_classes,
    )
    # ============================================================
    # Step 3: Set up training arguments
    # ============================================================
    print("\n[Step 3] Setting up training...")
    output_dir = "./classification_results"
    # NOTE(review): `eval_strategy` requires a recent transformers release
    # (older versions use `evaluation_strategy`) — confirm pinned version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=100,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to="none",  # Disable W&B logging
        seed=42,
    )
    print(f"✓ Training config:")
    print(f" - Output dir: {output_dir}")
    print(f" - Epochs: {training_args.num_train_epochs}")
    print(f" - Batch size: {training_args.per_device_train_batch_size}")
    print(f" - Learning rate: {training_args.learning_rate}")
    # ============================================================
    # Step 4: Train using Trainer
    # ============================================================
    print("\n[Step 4] Training model...")
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

    def compute_metrics(eval_pred):
        """Compute evaluation metrics."""
        # eval_pred is (logits, labels); argmax over the class axis yields
        # predicted class ids. Weighted averaging handles class imbalance;
        # zero_division=0 avoids warnings for classes absent in a split.
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return {
            "accuracy": accuracy_score(labels, predictions),
            "f1": f1_score(labels, predictions, average="weighted", zero_division=0),
            "precision": precision_score(labels, predictions, average="weighted", zero_division=0),
            "recall": recall_score(labels, predictions, average="weighted", zero_division=0),
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )
    train_result = trainer.train()
    print(f"✓ Training complete!")
    print(f" - Final training loss: {train_result.training_loss:.4f}")
    # ============================================================
    # Step 5: Evaluate on test set
    # ============================================================
    print("\n[Step 5] Evaluating on test set...")
    test_results = trainer.evaluate(test_dataset)
    print(f"✓ Test Results:")
    for metric, value in test_results.items():
        if isinstance(value, float):
            print(f" - {metric}: {value:.4f}")
    # ============================================================
    # Step 6: Make predictions
    # ============================================================
    print("\n[Step 6] Making predictions...")
    predictions = trainer.predict(test_dataset)
    predicted_classes = np.argmax(predictions.predictions, axis=1)
    print(f"✓ Predictions made:")
    print(f" - Predicted classes: {len(predicted_classes)} samples")
    print(f" - Class distribution: {np.bincount(predicted_classes)}")
    # ============================================================
    # Step 7: Save model
    # ============================================================
    print("\n[Step 7] Saving model...")
    save_dir = "./my_genemamba_classifier"
    model.save_pretrained(save_dir)
    print(f"✓ Model saved to '{save_dir}'")
    # ============================================================
    # Step 8: Load and test saved model
    # ============================================================
    print("\n[Step 8] Testing model reloading...")
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )
    loaded_model.eval()
    # Test on a single batch: a random sequence of token ids in [2, 25426),
    # shaped (batch=1, seq_len=2048) to match the training data.
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (1, 2048))
        output = loaded_model(sample_input)
        logits = output.logits
        prediction = torch.argmax(logits, dim=1)
    print(f"✓ Loaded model test prediction: class {prediction.item()}")
    print("\n" + "=" * 80)
    print("Phase 2 Complete! Model ready for deployment.")
    print("=" * 80)
    return model, trainer
if __name__ == "__main__":
    # Keep references to the trained artifacts at module level for
    # interactive inspection (e.g. `python -i`).
    model, trainer = main()