""" BERT-Thetis Colab Training Script ---------------------------------- Pretrain BERT-Thetis on WikiText-103 with Masked Language Modeling. In a cell above this in colab run this install here; and then begin the training. try: !pip uninstall -qy geometricvocab except: pass !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git Designed for Google Colab with: - Easy setup and installation - HuggingFace Hub integration - Memory-efficient training - Progress tracking and logging - Automatic checkpointing Author: AbstractPhil + Claude Sonnet 4.5 License: MIT """ import os import math import time from pathlib import Path from typing import Optional, Dict, Any from dataclasses import dataclass, field import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from torch.optim import AdamW from torch.optim.lr_scheduler import OneCycleLR from datasets import load_dataset from transformers import AutoTokenizer from tqdm.auto import tqdm # Import BERT-Thetis from geovocab2.train.model.core.bert_thetis import ( ThetisConfig, ThetisForMaskedLM ) # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Configuration # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ @dataclass class TrainingConfig: """Training configuration for Colab.""" # Model model_name: str = "bert-thetis-tiny-wikitext103" crystal_dim: int = 256 num_layers: int = 4 num_attention_heads: int = 4 intermediate_size: int = 1024 vocab_size: int = 30522 beatrix_levels: int = 16 max_position_embeddings: int = 512 # Dataset dataset_name: str = "wikitext" dataset_config: str = "wikitext-103-raw-v1" tokenizer_name: str = "bert-base-uncased" max_length: int = 128 mlm_probability: float = 0.15 # Training num_epochs: int = 10 batch_size: int = 64 gradient_accumulation_steps: int = 2 learning_rate: float = 5e-4 weight_decay: float = 0.01 warmup_ratio: float = 0.1 max_grad_norm: float = 1.0 # Hardware device: str = "cuda" if torch.cuda.is_available() else "cpu" num_workers: int = 2 pin_memory: bool = True mixed_precision: bool = True # Use AMP for faster training # Checkpointing save_steps: int = 1000 eval_steps: int = 500 logging_steps: int = 100 save_total_limit: int = 3 # HuggingFace Hub push_to_hub: bool = True hub_model_id: str = "AbstractPhil/bert-thetis-tiny-wikitext103" hub_token: Optional[str] = None # Will read from HF_TOKEN env var # Paths output_dir: str = "./thetis-outputs" cache_dir: str = "./cache" def __post_init__(self): """Setup paths and device.""" os.makedirs(self.output_dir, exist_ok=True) os.makedirs(self.cache_dir, exist_ok=True) # Get HF token from environment if not provided if self.hub_token is None: self.hub_token = os.environ.get("HF_TOKEN") print(f"🚢 BERT-Thetis Training Configuration") print(f" Device: {self.device}") print(f" Mixed Precision: {self.mixed_precision}") print(f" Model: {self.model_name}") print(f" Dataset: {self.dataset_name}/{self.dataset_config}") print(f" Output: {self.output_dir}") print(f" Push to Hub: {self.push_to_hub}") if self.push_to_hub: print(f" Hub Repo: {self.hub_model_id}") # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Dataset # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ class MaskedLMDataset(Dataset): """Dataset for Masked Language Modeling.""" def __init__( self, texts, tokenizer, max_length: int = 128, mlm_probability: float = 0.15 ): self.texts = texts self.tokenizer = tokenizer self.max_length = max_length self.mlm_probability = mlm_probability def 

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Dataset
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

class MaskedLMDataset(Dataset):
    """Dataset for Masked Language Modeling."""

    def __init__(
        self,
        texts,
        tokenizer,
        max_length: int = 128,
        mlm_probability: float = 0.15
    ):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_probability = mlm_probability

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Create masked version
        labels = input_ids.clone()

        # Mask tokens
        probability_matrix = torch.full(labels.shape, self.mlm_probability)

        # Don't mask special tokens (pass the whole list, not individual tokens)
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(
            labels.tolist(),
            already_has_special_tokens=True
        )
        probability_matrix.masked_fill_(
            torch.tensor(special_tokens_mask, dtype=torch.bool),
            value=0.0
        )

        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # Only compute loss on masked tokens

        # 80% of the time, replace with [MASK]
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, replace with random token
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        # 10% of the time, keep original token unchanged

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }


def prepare_datasets(config: TrainingConfig):
    """Load and prepare WikiText-103 datasets."""
    print(f"\n📚 Loading {config.dataset_name}...")

    # Load dataset
    dataset = load_dataset(
        config.dataset_name,
        config.dataset_config,
        cache_dir=config.cache_dir
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_name,
        cache_dir=config.cache_dir
    )

    # Filter out empty texts
    def is_valid(example):
        return len(example["text"].strip()) > 0

    train_texts = [ex["text"] for ex in dataset["train"] if is_valid(ex)]
    val_texts = [ex["text"] for ex in dataset["validation"] if is_valid(ex)]

    print(f"   Train samples: {len(train_texts):,}")
    print(f"   Val samples: {len(val_texts):,}")

    # Create datasets
    train_dataset = MaskedLMDataset(
        train_texts, tokenizer, config.max_length, config.mlm_probability
    )
    val_dataset = MaskedLMDataset(
        val_texts, tokenizer, config.max_length, config.mlm_probability
    )

    return train_dataset, val_dataset, tokenizer


# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Training Loop
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

class ThetisTrainer:
    """Trainer for BERT-Thetis with MLM."""

    def __init__(
        self,
        model: ThetisForMaskedLM,
        train_dataset: Dataset,
        val_dataset: Dataset,
        config: TrainingConfig
    ):
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config

        # Move model to device
        self.model.to(config.device)

        # Data loaders
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=config.batch_size * 2,  # Larger batch for eval
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )

        # Optimizer: no weight decay on biases and LayerNorm weights
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": config.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

        # Scheduler
        total_steps = len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps
        warmup_steps = int(total_steps * config.warmup_ratio)

        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_ratio,
            anneal_strategy="cos"
        )

        # Mixed precision
        self.scaler = (
            torch.amp.GradScaler('cuda')
            if config.mixed_precision and config.device == 'cuda'
            else None
        )

        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float("inf")

        print(f"\n🎯 Training Setup")
        print(f"   Total steps: {total_steps:,}")
        print(f"   Warmup steps: {warmup_steps:,}")
        print(f"   Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")

    def train_epoch(self):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch + 1}")

        for step, batch in enumerate(progress_bar):
            # Move to device
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            # Forward pass
            with torch.amp.autocast('cuda', enabled=self.config.mixed_precision and self.config.device == 'cuda'):
                loss, _ = self.model(
                    token_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )
                loss = loss / self.config.gradient_accumulation_steps

            # Backward pass
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            total_loss += loss.item()

            # Update weights once per accumulation window
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                if self.scaler is not None:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    self.optimizer.step()

                self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

                # Update progress bar
                progress_bar.set_postfix({
                    "loss": f"{loss.item() * self.config.gradient_accumulation_steps:.4f}",
                    "lr": f"{self.scheduler.get_last_lr()[0]:.2e}"
                })

                # Logging
                if self.global_step % self.config.logging_steps == 0:
                    avg_loss = total_loss / self.config.logging_steps
                    print(f"\n   Step {self.global_step}: loss={avg_loss:.4f}, lr={self.scheduler.get_last_lr()[0]:.2e}")
                    total_loss = 0

                # Evaluation
                if self.global_step % self.config.eval_steps == 0:
                    val_loss = self.evaluate()
                    print(f"   Validation loss: {val_loss:.4f}")

                    # Save best model
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.save_checkpoint("best")
                        print(f"   ✓ New best model saved!")

                    self.model.train()

                # Save checkpoint
                if self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint(f"step-{self.global_step}")

    @torch.no_grad()
    def evaluate(self):
        """Evaluate on validation set."""
        self.model.eval()
        total_loss = 0
        total_steps = 0

        for batch in tqdm(self.val_loader, desc="Evaluating", leave=False):
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            with torch.amp.autocast('cuda', enabled=self.config.mixed_precision and self.config.device == 'cuda'):
                loss, _ = self.model(
                    token_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )

            total_loss += loss.item()
            total_steps += 1

        return total_loss / total_steps

    def train(self):
        """Full training loop."""
        print(f"\n🚀 Starting Training")
        print("=" * 70)

        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            self.epoch = epoch
print(f"\n📖 Epoch {epoch + 1}/{self.config.num_epochs}") self.train_epoch() # Epoch evaluation val_loss = self.evaluate() print(f"\n Epoch {epoch + 1} validation loss: {val_loss:.4f}") # Save epoch checkpoint self.save_checkpoint(f"epoch-{epoch + 1}") # Final evaluation final_val_loss = self.evaluate() print(f"\n✅ Training Complete!") print(f" Final validation loss: {final_val_loss:.4f}") print(f" Best validation loss: {self.best_val_loss:.4f}") print(f" Total time: {(time.time() - start_time) / 3600:.2f} hours") # Save final model self.save_checkpoint("final") # Push to hub if self.config.push_to_hub: self.push_to_hub() def save_checkpoint(self, name: str): """Save model checkpoint.""" output_dir = Path(self.config.output_dir) / name output_dir.mkdir(parents=True, exist_ok=True) # Save model torch.save(self.model.state_dict(), output_dir / "pytorch_model.bin") # Save config config_dict = { "crystal_dim": self.config.crystal_dim, "num_layers": self.config.num_layers, "num_attention_heads": self.config.num_attention_heads, "intermediate_size": self.config.intermediate_size, "vocab_size": self.config.vocab_size, "beatrix_levels": self.config.beatrix_levels, "max_position_embeddings": self.config.max_position_embeddings, } import json with open(output_dir / "config.json", "w") as f: json.dump(config_dict, f, indent=2) # Save training state state = { "global_step": self.global_step, "epoch": self.epoch, "best_val_loss": self.best_val_loss, } torch.save(state, output_dir / "training_state.pt") def push_to_hub(self): """Push model to HuggingFace Hub.""" if not self.config.hub_token: print("⚠️ No HuggingFace token found. Skipping push to hub.") return print(f"\n📤 Pushing to HuggingFace Hub: {self.config.hub_model_id}") try: from huggingface_hub import HfApi, create_repo api = HfApi(token=self.config.hub_token) # Create repo if it doesn't exist try: create_repo( repo_id=self.config.hub_model_id, token=self.config.hub_token, exist_ok=True ) except Exception as e: print(f" Repo creation: {e}") # Upload best checkpoint best_dir = Path(self.config.output_dir) / "best" if best_dir.exists(): api.upload_folder( folder_path=str(best_dir), repo_id=self.config.hub_model_id, token=self.config.hub_token ) print(f" ✓ Best model uploaded!") # Upload final checkpoint final_dir = Path(self.config.output_dir) / "final" if final_dir.exists(): api.upload_folder( folder_path=str(final_dir), repo_id=self.config.hub_model_id, path_in_repo="final", token=self.config.hub_token ) print(f" ✓ Final model uploaded!") except Exception as e: print(f"⚠️ Failed to push to hub: {e}") # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Main Entry Point # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ def main(): """Main training function.""" # Configuration config = TrainingConfig() # Prepare datasets train_dataset, val_dataset, tokenizer = prepare_datasets(config) # Create model print(f"\n🏗️ Creating BERT-Thetis model...") model_config = ThetisConfig( crystal_dim=config.crystal_dim, num_vertices=5, num_layers=config.num_layers, num_attention_heads=config.num_attention_heads, intermediate_size=config.intermediate_size, vocab_size=config.vocab_size, beatrix_levels=config.beatrix_levels, max_position_embeddings=config.max_position_embeddings, ) model = ThetisForMaskedLM(model_config) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f" Total parameters: {total_params:,}") print(f" Trainable 
    print(f"   Trainable parameters: {trainable_params:,}")

    # Create trainer
    trainer = ThetisTrainer(model, train_dataset, val_dataset, config)

    # Train
    trainer.train()

    print("\n🎉 All done! BERT-Thetis is ready to sail!")


if __name__ == "__main__":
    main()
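
# Reference sketch (commented out, untested): reloading a checkpoint written by
# ThetisTrainer.save_checkpoint above. It assumes ThetisConfig accepts the same
# keyword arguments that save_checkpoint writes to config.json; note that
# num_vertices is not stored there, so it falls back to the class default and
# may need to be passed explicitly to match the trained model.
#
#     import json
#     from pathlib import Path
#     import torch
#     from geovocab2.train.model.core.bert_thetis import ThetisConfig, ThetisForMaskedLM
#
#     ckpt_dir = Path("./thetis-outputs/best")
#     with open(ckpt_dir / "config.json") as f:
#         cfg = ThetisConfig(**json.load(f))
#
#     model = ThetisForMaskedLM(cfg)
#     state_dict = torch.load(ckpt_dir / "pytorch_model.bin", map_location="cpu")
#     model.load_state_dict(state_dict)
#     model.eval()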