|
|
""" |
|
|
BERT-Thetis Colab Training Script |
|
|
---------------------------------- |
|
|
Pretrain BERT-Thetis on WikiText-103 with Masked Language Modeling. |
|
|
|
|
|
In a Colab cell above this script, run the install commands below, then start training.
|
|
|
|
|
try: |
|
|
!pip uninstall -qy geometricvocab |
|
|
except: |
|
|
pass |
|
|
|
|
|
!pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git |
|
|
|
|
|
|
|
|
Designed for Google Colab with: |
|
|
- Easy setup and installation |
|
|
- HuggingFace Hub integration |
|
|
- Memory-efficient training |
|
|
- Progress tracking and logging |
|
|
- Automatic checkpointing |
|
|
|
|
|
Author: AbstractPhil + Claude Sonnet 4.5 |
|
|
License: MIT |
|
|
""" |
|
|
|
|
|
import os |
|
|
import math |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import Optional, Dict, Any |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import OneCycleLR |
|
|
|
|
|
from datasets import load_dataset |
|
|
from transformers import AutoTokenizer |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
|
|
|
from geovocab2.train.model.core.bert_thetis import ( |
|
|
ThetisConfig, |
|
|
ThetisForMaskedLM |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Settings for one BERT-Thetis masked-language-model training run.

    Instantiating the config has side effects: ``output_dir`` and
    ``cache_dir`` are created if missing, ``hub_token`` falls back to the
    ``HF_TOKEN`` environment variable when it was not supplied, and a short
    summary of the configuration is printed.
    """

    # --- Model ---
    model_name: str = "bert-thetis-tiny-wikitext103"
    crystal_dim: int = 256
    num_layers: int = 4
    num_attention_heads: int = 4
    intermediate_size: int = 1024
    vocab_size: int = 30522
    beatrix_levels: int = 16
    max_position_embeddings: int = 512

    # --- Data ---
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "bert-base-uncased"
    max_length: int = 128
    mlm_probability: float = 0.15

    # --- Optimisation ---
    num_epochs: int = 10
    batch_size: int = 64
    gradient_accumulation_steps: int = 2
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0

    # --- Hardware ---
    # NOTE: this default is evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2
    pin_memory: bool = True
    mixed_precision: bool = True

    # --- Logging / checkpoint cadence (counted in optimizer steps) ---
    save_steps: int = 1000
    eval_steps: int = 500
    logging_steps: int = 100
    save_total_limit: int = 3

    # --- HuggingFace Hub ---
    push_to_hub: bool = True
    hub_model_id: str = "AbstractPhil/bert-thetis-tiny-wikitext103"
    hub_token: Optional[str] = None

    # --- Paths ---
    output_dir: str = "./thetis-outputs"
    cache_dir: str = "./cache"

    def __post_init__(self):
        """Create working directories, resolve the Hub token, print a summary."""
        for directory in (self.output_dir, self.cache_dir):
            os.makedirs(directory, exist_ok=True)

        # An explicitly supplied token always wins over the environment.
        if self.hub_token is None:
            self.hub_token = os.environ.get("HF_TOKEN")

        summary = [
            "π’ BERT-Thetis Training Configuration",
            f" Device: {self.device}",
            f" Mixed Precision: {self.mixed_precision}",
            f" Model: {self.model_name}",
            f" Dataset: {self.dataset_name}/{self.dataset_config}",
            f" Output: {self.output_dir}",
            f" Push to Hub: {self.push_to_hub}",
        ]
        if self.push_to_hub:
            summary.append(f" Hub Repo: {self.hub_model_id}")

        for line in summary:
            print(line)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaskedLMDataset(Dataset):
    """Map-style dataset that tokenizes text and applies MLM masking per item.

    Each item is tokenized on access, padded/truncated to ``max_length``, and
    masked afresh, so the same index yields different masks on different reads.
    """

    def __init__(
        self,
        texts,
        tokenizer,
        max_length: int = 128,
        mlm_probability: float = 0.15
    ):
        """Store the raw texts and masking settings.

        Args:
            texts: Indexable collection of strings.
            tokenizer: Callable tokenizer that also provides
                ``get_special_tokens_mask``, ``mask_token_id`` and ``len()``.
            max_length: Length every sequence is padded/truncated to.
            mlm_probability: Chance that a non-special token is selected
                as a prediction target.
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_probability = mlm_probability

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one text.

        ``labels`` holds the original token id at selected positions and
        ``-100`` everywhere else. Of the selected positions, roughly 80% are
        replaced by the mask token, 10% by a random token id, and 10% are
        left unchanged.
        """
        encoded = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # The tokenizer returns a batch of one; drop the batch dimension.
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        shape = labels.shape

        # Special tokens get selection probability zero.
        selection_prob = torch.full(shape, self.mlm_probability)
        is_special = torch.tensor(
            self.tokenizer.get_special_tokens_mask(
                labels.tolist(), already_has_special_tokens=True
            ),
            dtype=torch.bool,
        )
        selection_prob.masked_fill_(is_special, value=0.0)

        selected = torch.bernoulli(selection_prob).bool()
        labels[~selected] = -100

        # 80% of the selected positions become the mask token.
        to_mask = torch.bernoulli(torch.full(shape, 0.8)).bool() & selected
        input_ids[to_mask] = self.tokenizer.mask_token_id

        # Half of the remaining selected positions (10% overall) get a random id.
        coin_flip = torch.bernoulli(torch.full(shape, 0.5)).bool()
        to_randomize = coin_flip & selected & ~to_mask
        random_ids = torch.randint(len(self.tokenizer), shape, dtype=torch.long)
        input_ids[to_randomize] = random_ids[to_randomize]

        # The rest of the selected positions keep their original token.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
|
|
|
|
|
|
|
|
def prepare_datasets(config: TrainingConfig):
    """Load the configured corpus and wrap its splits for MLM training.

    Downloads (or reads from ``config.cache_dir``) the dataset and tokenizer
    named in ``config``, drops rows whose text is empty or whitespace-only,
    and builds one ``MaskedLMDataset`` for each of the "train" and
    "validation" splits.

    Returns:
        A ``(train_dataset, val_dataset, tokenizer)`` tuple.
    """
    print(f"\nπ Loading {config.dataset_name}...")

    raw_splits = load_dataset(
        config.dataset_name,
        config.dataset_config,
        cache_dir=config.cache_dir
    )
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_name,
        cache_dir=config.cache_dir
    )

    def non_blank_texts(split_name):
        # Keep the original (unstripped) text; only blank rows are dropped.
        return [row["text"] for row in raw_splits[split_name] if row["text"].strip()]

    train_texts = non_blank_texts("train")
    val_texts = non_blank_texts("validation")

    print(f" Train samples: {len(train_texts):,}")
    print(f" Val samples: {len(val_texts):,}")

    def wrap(texts):
        return MaskedLMDataset(
            texts,
            tokenizer,
            config.max_length,
            config.mlm_probability
        )

    train_dataset = wrap(train_texts)
    val_dataset = wrap(val_texts)
    return train_dataset, val_dataset, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ThetisTrainer:
    """Training loop for a BERT-Thetis masked-language model.

    Bundles the model with its data loaders, an AdamW optimizer, a OneCycleLR
    schedule and (when mixed precision is active on CUDA) a gradient scaler.
    Provides gradient accumulation, periodic logging, validation and
    checkpointing, and an optional upload to the HuggingFace Hub.

    The model is called as ``model(token_ids=..., attention_mask=...,
    labels=...)`` and must return a pair whose first element is the loss.
    """

    def __init__(
        self,
        model: "ThetisForMaskedLM",
        train_dataset: Dataset,
        val_dataset: Dataset,
        config: "TrainingConfig"
    ):
        """Move the model to the target device and build loaders/optimizer/schedule.

        Args:
            model: Model to train; moved to ``config.device``.
            train_dataset: Dataset iterated (shuffled) during training.
            val_dataset: Dataset iterated (in order) during evaluation.
            config: Training settings read throughout the trainer.
        """
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config

        self.model.to(config.device)

        # Accept indexed devices such as "cuda:0" as well as plain "cuda".
        self.use_amp = bool(config.mixed_precision) and str(config.device).startswith("cuda")

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        decay_params = []
        no_decay_params = []
        for name, param in model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": config.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

        # The optimizer steps once per *complete* accumulation window, and a
        # trailing partial window at the end of an epoch never steps. Count
        # exactly those steps so the schedule finishes when training does.
        steps_per_epoch = len(self.train_loader) // config.gradient_accumulation_steps
        total_steps = max(1, steps_per_epoch * config.num_epochs)
        warmup_steps = int(total_steps * config.warmup_ratio)

        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_ratio,
            anneal_strategy="cos"
        )

        self.scaler = torch.amp.GradScaler('cuda') if self.use_amp else None

        self.global_step = 0  # number of optimizer steps taken so far
        self.epoch = 0
        self.best_val_loss = float("inf")

        print(f"\nπ― Training Setup")
        print(f" Total steps: {total_steps:,}")
        print(f" Warmup steps: {warmup_steps:,}")
        print(f" Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")

    def _forward_loss(self, batch):
        """Move ``batch`` to the device and return the model's loss for it."""
        batch = {k: v.to(self.config.device) for k, v in batch.items()}
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            loss, _ = self.model(
                token_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"]
            )
        return loss

    def _optimizer_step(self):
        """Clip gradients, apply one optimizer + scheduler step, clear gradients."""
        if self.scaler is not None:
            # Gradients must be unscaled before clipping by norm.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    def _update_best(self, val_loss):
        """Save a "best" checkpoint if ``val_loss`` improves on the best so far.

        Returns:
            True when a new best checkpoint was written, otherwise False.
        """
        if not val_loss < self.best_val_loss:
            return False
        self.best_val_loss = val_loss
        self.save_checkpoint("best")
        return True

    def train_epoch(self):
        """Train for one epoch.

        Logging, validation and step checkpoints are evaluated only right
        after an optimizer step, so each fires once per matching
        ``global_step`` (never at step 0 and never repeatedly while a
        gradient-accumulation window is still filling).
        """
        cfg = self.config
        accum = cfg.gradient_accumulation_steps

        self.model.train()
        # Discard gradients left by a trailing partial accumulation window of
        # the previous epoch so they do not leak into this epoch's first step.
        self.optimizer.zero_grad()

        window_loss = 0.0  # sum of (micro-batch loss / accum) since the last log
        window_steps = 0   # optimizer steps since the last log

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch + 1}")

        for step, batch in enumerate(progress_bar):
            # Scale so the accumulated gradient is the mean over the window.
            loss = self._forward_loss(batch) / accum

            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            micro_loss = loss.item()
            window_loss += micro_loss

            took_step = (step + 1) % accum == 0
            if took_step:
                self._optimizer_step()
                window_steps += 1

            progress_bar.set_postfix({
                "loss": f"{micro_loss * accum:.4f}",
                "lr": f"{self.scheduler.get_last_lr()[0]:.2e}"
            })

            if not took_step:
                continue

            if self.global_step % cfg.logging_steps == 0:
                # Divide by the steps actually accumulated; after an epoch
                # boundary this can be fewer than ``logging_steps``.
                avg_loss = window_loss / window_steps
                print(f"\n Step {self.global_step}: loss={avg_loss:.4f}, lr={self.scheduler.get_last_lr()[0]:.2e}")
                window_loss = 0.0
                window_steps = 0

            if self.global_step % cfg.eval_steps == 0:
                val_loss = self.evaluate()
                print(f" Validation loss: {val_loss:.4f}")

                if self._update_best(val_loss):
                    print(f" β New best model saved!")

                # evaluate() switches the model to eval mode.
                self.model.train()

            if self.global_step % cfg.save_steps == 0:
                self.save_checkpoint(f"step-{self.global_step}")

    @torch.no_grad()
    def evaluate(self):
        """Return the mean per-batch loss over the validation loader.

        Leaves the model in eval mode. Returns ``inf`` when the validation
        loader yields no batches, so an empty set can never become "best".
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for batch in tqdm(self.val_loader, desc="Evaluating", leave=False):
            total_loss += self._forward_loss(batch).item()
            num_batches += 1

        if num_batches == 0:
            return float("inf")
        return total_loss / num_batches

    def train(self):
        """Run the full training loop.

        For every epoch: train, validate, update the best checkpoint if the
        validation loss improved, and save an ``epoch-N`` checkpoint. After
        the last epoch a final validation is run, a ``final`` checkpoint is
        saved, and the results are pushed to the Hub when configured.
        """
        print(f"\nπ Starting Training")
        print("=" * 70)

        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            self.epoch = epoch
            print(f"\nπ Epoch {epoch + 1}/{self.config.num_epochs}")

            self.train_epoch()

            val_loss = self.evaluate()
            print(f"\n Epoch {epoch + 1} validation loss: {val_loss:.4f}")

            # Epoch-end validation also counts towards the best model, so a
            # run shorter than ``eval_steps`` still produces one.
            if self._update_best(val_loss):
                print(f" β New best model saved!")

            self.save_checkpoint(f"epoch-{epoch + 1}")

        final_val_loss = self.evaluate()
        print(f"\nβ Training Complete!")
        print(f" Final validation loss: {final_val_loss:.4f}")
        print(f" Best validation loss: {self.best_val_loss:.4f}")
        print(f" Total time: {(time.time() - start_time) / 3600:.2f} hours")

        self.save_checkpoint("final")

        if self.config.push_to_hub:
            self.push_to_hub()

    def save_checkpoint(self, name: str):
        """Write weights, model config and training state to ``output_dir/name``.

        Produces three files: ``pytorch_model.bin`` (model state dict),
        ``config.json`` (architecture settings) and ``training_state.pt``
        (global step, epoch and best validation loss). Optimizer, scheduler
        and scaler state are not saved.
        """
        import json

        output_dir = Path(self.config.output_dir) / name
        output_dir.mkdir(parents=True, exist_ok=True)

        torch.save(self.model.state_dict(), output_dir / "pytorch_model.bin")

        config_dict = {
            "crystal_dim": self.config.crystal_dim,
            "num_layers": self.config.num_layers,
            "num_attention_heads": self.config.num_attention_heads,
            "intermediate_size": self.config.intermediate_size,
            "vocab_size": self.config.vocab_size,
            "beatrix_levels": self.config.beatrix_levels,
            "max_position_embeddings": self.config.max_position_embeddings,
        }
        with open(output_dir / "config.json", "w") as f:
            json.dump(config_dict, f, indent=2)

        state = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "best_val_loss": self.best_val_loss,
        }
        torch.save(state, output_dir / "training_state.pt")

    def push_to_hub(self):
        """Upload the "best" and "final" checkpoints to the HuggingFace Hub.

        Best-effort: does nothing without a token, and any failure is
        reported on stdout instead of being raised. The "best" folder is
        uploaded to the repository root, "final" to a ``final/`` sub-folder.
        """
        if not self.config.hub_token:
            print("β οΈ No HuggingFace token found. Skipping push to hub.")
            return

        print(f"\nπ€ Pushing to HuggingFace Hub: {self.config.hub_model_id}")

        try:
            from huggingface_hub import HfApi, create_repo

            api = HfApi(token=self.config.hub_token)

            # Repo creation failing is not fatal on its own: report it and
            # still attempt the upload.
            try:
                create_repo(
                    repo_id=self.config.hub_model_id,
                    token=self.config.hub_token,
                    exist_ok=True
                )
            except Exception as e:
                print(f" Repo creation: {e}")

            best_dir = Path(self.config.output_dir) / "best"
            if best_dir.exists():
                api.upload_folder(
                    folder_path=str(best_dir),
                    repo_id=self.config.hub_model_id,
                    token=self.config.hub_token
                )
                print(f" β Best model uploaded!")

            final_dir = Path(self.config.output_dir) / "final"
            if final_dir.exists():
                api.upload_folder(
                    folder_path=str(final_dir),
                    repo_id=self.config.hub_model_id,
                    path_in_repo="final",
                    token=self.config.hub_token
                )
                print(f" β Final model uploaded!")

        except Exception as e:
            print(f"β οΈ Failed to push to hub: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Build the config, data, model and trainer, then run training end to end."""
    config = TrainingConfig()

    # The tokenizer is returned by prepare_datasets but is not needed here.
    train_dataset, val_dataset, _tokenizer = prepare_datasets(config)

    print(f"\nποΈ Creating BERT-Thetis model...")
    architecture = {
        "crystal_dim": config.crystal_dim,
        "num_vertices": 5,
        "num_layers": config.num_layers,
        "num_attention_heads": config.num_attention_heads,
        "intermediate_size": config.intermediate_size,
        "vocab_size": config.vocab_size,
        "beatrix_levels": config.beatrix_levels,
        "max_position_embeddings": config.max_position_embeddings,
    }
    model = ThetisForMaskedLM(ThetisConfig(**architecture))

    # Count all parameters and the subset that will receive gradients.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")

    trainer = ThetisTrainer(model, train_dataset, val_dataset, config)
    trainer.train()

    print("\nπ All done! BERT-Thetis is ready to sail!")


if __name__ == "__main__":
    main()