""" Phase 3: Continue Pretraining Demonstrates how to continue pretraining GeneMamba on your own data using masked LM objective. Usage: python examples/3_continue_pretraining.py """ import torch import numpy as np from torch.utils.data import Dataset from transformers import ( AutoModelForMaskedLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling, ) class PretrainingDataset(Dataset): """ Dataset for pretraining/continued pretraining. Loads sequences and their lengths. """ def __init__(self, input_ids_list, max_length=2048): self.input_ids_list = input_ids_list self.max_length = max_length def __len__(self): return len(self.input_ids_list) def __getitem__(self, idx): input_ids = self.input_ids_list[idx] # Pad or truncate to max_length if len(input_ids) >= self.max_length: input_ids = input_ids[:self.max_length] else: input_ids = np.pad( input_ids, (0, self.max_length - len(input_ids)), constant_values=1 # Pad token ID ) return { "input_ids": torch.tensor(input_ids, dtype=torch.long), } def create_mock_pretraining_data(n_sequences=5000, seq_len=2048): """Create mock single-cell sequences for pretraining.""" print("Creating mock pretraining dataset...") # Create ranked gene sequences # In practice, these would come from your scRNA-seq data sequences = [] for _ in range(n_sequences): # Random ranked sequence seq = np.random.randint(2, 25426, seq_len) sequences.append(seq) print(f"✓ Created {n_sequences} sequences of length {seq_len}") return sequences def main(): print("=" * 80) print("GeneMamba Phase 3: Continue Pretraining") print("=" * 80) # ============================================================ # Step 1: Load pretrained model for masked LM # ============================================================ print("\n[Step 1] Loading model for masked LM...") try: model = AutoModelForMaskedLM.from_pretrained( "GeneMamba-24l-512d", trust_remote_code=True, local_files_only=True, ) tokenizer = AutoTokenizer.from_pretrained( "GeneMamba-24l-512d", trust_remote_code=True, local_files_only=True, ) except Exception as e: print(f"Note: Could not load from hub ({e})") print("Using local initialization...") # Initialize locally from configuration_genemamba import GeneMambaConfig from modeling_genemamba import GeneMambaForMaskedLM config = GeneMambaConfig( vocab_size=25426, hidden_size=512, num_hidden_layers=24, ) model = GeneMambaForMaskedLM(config) tokenizer = None print(f"✓ Model loaded") print(f" - Architecture: {model.config.num_hidden_layers} layers, " f"hidden_size={model.config.hidden_size}") # ============================================================ # Step 2: Prepare pretraining data # ============================================================ print("\n[Step 2] Preparing pretraining dataset...") sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048) # Split train/eval train_size = int(0.9 * len(sequences)) train_sequences = sequences[:train_size] eval_sequences = sequences[train_size:] train_dataset = PretrainingDataset(train_sequences) eval_dataset = PretrainingDataset(eval_sequences) print(f"✓ Datasets created:") print(f" - Training: {len(train_dataset)} samples") print(f" - Evaluation: {len(eval_dataset)} samples") # ============================================================ # Step 3: Set up data collator for MLM # ============================================================ print("\n[Step 3] Setting up data collator...") if tokenizer is not None: data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=True, 


def main():
    print("=" * 80)
    print("GeneMamba Phase 3: Continue Pretraining")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model for masked LM
    # ============================================================
    print("\n[Step 1] Loading model for masked LM...")

    try:
        model = AutoModelForMaskedLM.from_pretrained(
            "GeneMamba-24l-512d",
            trust_remote_code=True,
            local_files_only=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "GeneMamba-24l-512d",
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        print(f"Note: Could not load from hub ({e})")
        print("Using local initialization...")

        # Initialize locally
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaForMaskedLM

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
        )
        model = GeneMambaForMaskedLM(config)
        tokenizer = None

    print("✓ Model loaded")
    print(f"  - Architecture: {model.config.num_hidden_layers} layers, "
          f"hidden_size={model.config.hidden_size}")

    # ============================================================
    # Step 2: Prepare pretraining data
    # ============================================================
    print("\n[Step 2] Preparing pretraining dataset...")

    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # Split train/eval
    train_size = int(0.9 * len(sequences))
    train_sequences = sequences[:train_size]
    eval_sequences = sequences[train_size:]

    train_dataset = PretrainingDataset(train_sequences)
    eval_dataset = PretrainingDataset(eval_sequences)

    print("✓ Datasets created:")
    print(f"  - Training: {len(train_dataset)} samples")
    print(f"  - Evaluation: {len(eval_dataset)} samples")

    # ============================================================
    # Step 3: Set up data collator for MLM
    # ============================================================
    print("\n[Step 3] Setting up data collator...")

    if tokenizer is not None:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=True,
            mlm_probability=0.15,  # Mask 15% of tokens
        )
    else:
        # Custom collator if no tokenizer is available.
        # Simplified MLM: every selected position is replaced by [MASK]
        # (no 80/10/10 split as in DataCollatorForLanguageModeling).
        class CustomDataCollator:
            def __call__(self, batch):
                input_ids = torch.stack([item["input_ids"] for item in batch])

                # Create masked labels (for MLM loss)
                labels = input_ids.clone()
                # Select ~15% of positions, excluding padding (id=1)
                mask = (torch.rand(input_ids.shape) < 0.15) & (input_ids != 1)

                # Set input to [MASK] token (id=0)
                input_ids[mask] = 0
                # Set labels to -100 where not masked (loss ignores these)
                labels[~mask] = -100

                return {"input_ids": input_ids, "labels": labels}

        data_collator = CustomDataCollator()

    print("✓ Data collator ready (MLM probability: 0.15)")

    # ============================================================
    # Step 4: Set up training arguments
    # ============================================================
    print("\n[Step 4] Setting up training...")

    output_dir = "./pretrain_results"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=100,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",  # Disable W&B
        seed=42,
    )

    print("✓ Training config:")
    print(f"  - Output dir: {output_dir}")
    print(f"  - Epochs: {training_args.num_train_epochs}")
    print(f"  - Batch size: {training_args.per_device_train_batch_size}")
    print(f"  - Learning rate: {training_args.learning_rate}")
    print("  - MLM masking: 15%")

    # ============================================================
    # Step 5: Train
    # ============================================================
    print("\n[Step 5] Starting continued pretraining...")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    train_result = trainer.train()

    print("✓ Training complete!")
    print(f"  - Final training loss: {train_result.training_loss:.4f}")

    # ============================================================
    # Step 6: Evaluate
    # ============================================================
    print("\n[Step 6] Evaluating on held-out set...")

    eval_results = trainer.evaluate()

    print("✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"  - {metric}: {value:.4f}")

    # ============================================================
    # Step 7: Save model
    # ============================================================
    print("\n[Step 7] Saving continued pretrained model...")

    save_dir = "./genemamba_continued_pretrain"
    model.save_pretrained(save_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(save_dir)

    print(f"✓ Model saved to '{save_dir}'")

    # ============================================================
    # Step 8: Test model inference
    # ============================================================
    print("\n[Step 8] Testing inference on masked input...")

    model.eval()

    # Create sample input with masked tokens, on the model's device (CPU or GPU)
    sample_input = torch.randint(2, 25426, (1, 2048)).to(model.device)
    sample_input[0, :10] = 0  # Mask first 10 tokens

    with torch.no_grad():
        outputs = model(sample_input)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)

    print("✓ Sample predictions generated")
    print(f"  - Input shape: {sample_input.shape}")
    print(f"  - Output logits shape: {logits.shape}")
    print(f"  - Top predicted genes (tokens): {predictions[0, :10].tolist()}")

    print("\n" + "=" * 80)
    print("Phase 3 Complete! Model ready for downstream tasks or further training.")
    print("=" * 80)

    return model, trainer
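

# ------------------------------------------------------------------
# Hedged sketch: reloading the continued-pretrained checkpoint later.
# One possible way to pick the weights back up for further pretraining
# or downstream fine-tuning, assuming the directory written in Step 7
# contains a config that supports trust_remote_code auto-class loading
# (as the base checkpoint in Step 1 does). Illustrative only; not
# called by this script.
# ------------------------------------------------------------------
def reload_continued_model(save_dir="./genemamba_continued_pretrain"):
    """Reload the saved model the same way the base checkpoint is loaded (sketch)."""
    return AutoModelForMaskedLM.from_pretrained(save_dir, trust_remote_code=True)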


if __name__ == "__main__":
    model, trainer = main()
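
# Usage note (assumption, not from the original script): for multi-GPU continued
# pretraining, the same entry point can typically be launched with torchrun, e.g.
#   torchrun --nproc_per_node=4 examples/3_continue_pretraining.py
# Hugging Face Trainer detects the distributed environment set up by torchrun
# and shards the batches across processes automatically.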