Buckets:
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| from pytorch_optimizer import SOAP | |
| import nltk | |
| from nltk.corpus import gutenberg | |
| from src.tokenizer import CharTokenizer | |
| from src.model import TinyReasonerModel | |
| import os | |
class GutenbergDataset(Dataset):
    """Next-token prediction dataset built from NLTK's Gutenberg corpus.

    The raw text of the first ``num_books`` corpus files (in
    ``gutenberg.fileids()`` order) is concatenated, truncated to at most
    ``max_chars`` characters, and encoded once with ``tokenizer``.  The token
    stream is then cut into consecutive, non-overlapping windows of
    ``seq_len`` tokens.  Each item is an ``(x, y)`` pair of int64 tensors
    where ``y`` is ``x`` shifted one token to the right.

    Args:
        tokenizer: Object whose ``encode(text)`` returns a sequence of
            integer token ids.
        seq_len: Number of tokens per window; must be >= 1.
        max_chars: Maximum number of characters of corpus text to encode.
        num_books: How many Gutenberg files to draw text from.

    Raises:
        ValueError: If ``seq_len`` is less than 1.

    NOTE(review): the NLTK Gutenberg corpus data presumably has to be
    available locally (e.g. via ``nltk.download("gutenberg")``) before this
    class is constructed — confirm for the deployment environment.
    """

    def __init__(self, tokenizer, seq_len=256, max_chars=1000000, num_books=5):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        # Collect whole books (several, for variety) until there is enough
        # text, then join once and trim instead of concatenating repeatedly.
        pieces = []
        n_chars = 0
        for fileid in gutenberg.fileids()[:num_books]:
            raw = gutenberg.raw(fileid)
            pieces.append(raw)
            n_chars += len(raw)
            if n_chars >= max_chars:
                break
        text = "".join(pieces)[:max_chars]

        self.tokens = tokenizer.encode(text)

    def __len__(self):
        """Return the number of complete (input, target) windows."""
        # Each window needs seq_len inputs plus one extra token for the final
        # target; clamp at 0 so an empty token stream gives an empty dataset
        # rather than an invalid negative length.
        return max(0, (len(self.tokens) - 1) // self.seq_len)

    def __getitem__(self, idx):
        """Return window ``idx`` as ``(x, y)`` int64 tensors of length ``seq_len``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            # Raising IndexError also lets plain ``for x, y in dataset``
            # iteration (the __getitem__ fallback protocol) terminate.
            raise IndexError(f"index out of range for dataset of length {n}")
        start = idx * self.seq_len
        end = start + self.seq_len
        x = torch.tensor(self.tokens[start:end], dtype=torch.long)
        y = torch.tensor(self.tokens[start + 1:end + 1], dtype=torch.long)
        return x, y
def _build_optimizer(model, lr):
    """Create a SOAP optimizer with a separate, Adam-like group for ``model.embedding``.

    Args:
        model: Module exposing an ``embedding`` submodule.
        lr: Learning rate shared by both parameter groups.

    Returns:
        A ``SOAP`` optimizer over every parameter of ``model``.
    """
    embedding_params = list(model.embedding.parameters())
    embedding_ids = {id(p) for p in embedding_params}
    # Partition by identity rather than by a name substring so every
    # parameter lands in exactly one group; none can be silently left
    # untrained, and none can appear in both groups.
    other_params = [p for p in model.parameters() if id(p) not in embedding_ids]
    param_groups = [
        {"params": other_params},  # default SOAP preconditioning
        # Restricting the preconditioner dimension to 1 is meant to make SOAP
        # behave like Adam for the embeddings.  NOTE(review): pytorch_optimizer's
        # SOAP names this option ``max_precondition_dim``; unknown param-group
        # keys are ignored silently, so the spelling matters — confirm against
        # the installed version.
        {"params": embedding_params, "max_precondition_dim": 1},
    ]
    return SOAP(param_groups, lr=lr)


def train(num_epochs=1, batch_size=32, lr=1e-3):
    """Pre-train ``TinyReasonerModel`` on Gutenberg text and save its weights.

    Uses a character tokenizer and a shuffled ``GutenbergDataset``, optimizes
    with SOAP (Adam-like settings for the embedding), and writes the final
    ``state_dict`` to ``models/pretrained.pt``.

    Setting the environment variable ``DRY_RUN=1`` stops each epoch after six
    optimizer steps, for a quick smoke test.

    Args:
        num_epochs: Number of passes over the dataset.
        batch_size: Number of windows per batch.
        lr: Learning rate for the optimizer.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = CharTokenizer()
    dataset = GutenbergDataset(tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = TinyReasonerModel(tokenizer.vocab_size).to(device)

    # SOAP for the recurrent/other layers, Adam-like updates for embeddings.
    optimizer = _build_optimizer(model, lr)
    criterion = nn.CrossEntropyLoss()
    dry_run = os.environ.get("DRY_RUN") == "1"

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_steps = 0
        for step, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits, _ = model(x)
            # reshape (unlike view) also works when logits are non-contiguous.
            loss = criterion(
                logits.reshape(-1, tokenizer.vocab_size), y.reshape(-1)
            )
            loss.backward()
            optimizer.step()

            step_loss = loss.item()  # single host sync per step
            total_loss += step_loss
            num_steps += 1
            if step % 10 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {step_loss:.4f}")
            # For verification dry-run
            if dry_run and step >= 5:
                break

        # Count steps explicitly: the loop variable is unbound if the
        # dataloader yields no batches.
        if num_steps == 0:
            print(f"Epoch {epoch} complete. No batches (dataset is empty).")
        else:
            avg_loss = total_loss / num_steps
            print(f"Epoch {epoch} complete. Avg Loss: {avg_loss:.4f}")

    # Save model
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), "models/pretrained.pt")
    print("Model saved to models/pretrained.pt")


if __name__ == "__main__":
    train()
Xet Storage Details
- Size: 2.97 kB
- Xet hash: bc3d0218e3b8935099d753dbab2363d0827ad0fad419210cd23314c5e54f5b1b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.