Spaces:
Sleeping
Sleeping
| import torch | |
| import logging | |
| import os | |
| from datetime import datetime | |
# Global dataset splits shared across this module.
# Both are 1-D torch.long tensors of token ids once prepare_data() has run;
# they stay None until then. get_batch() reads them to sample batches.
train_data = None
val_data = None
def setup_logging(log_dir="logs"):
    """Configure root logging to a timestamped file and to the console.

    Args:
        log_dir: Directory for log files; created if it does not exist.

    Returns:
        The logger for this module (``logging.getLogger(__name__)``).
    """
    os.makedirs(log_dir, exist_ok=True)

    # One file per run, named after the moment logging was set up.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{stamp}.log")

    # Send records both to the run's file and to stdout/stderr.
    handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=handlers,
    )
    logging.info(f"Logging to {log_file}")
    return logging.getLogger(__name__)
def count_parameters(model, trainable_only=False):
    """Return the total number of scalar parameters in `model`.

    Generalized (backward-compatibly) to optionally count only trainable
    parameters — useful when parts of the model are frozen.

    Args:
        model: Any ``torch.nn.Module``.
        trainable_only: If True, count only parameters with
            ``requires_grad`` set. Default False preserves the original
            count-everything behavior.

    Returns:
        The parameter count as an int.
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
def get_batch(split, block_size, batch_size):
    """Sample a random batch of input/target sequences from a data split.

    Args:
        split: "train" selects the training split; anything else validation.
        block_size: Length of each sampled sequence.
        batch_size: Number of sequences in the batch.

    Returns:
        Tensors ``(x, y)`` of shape (batch_size, block_size), where y is x
        shifted one token to the right (next-token targets).
    """
    source = train_data if split == "train" else val_data
    # Random start offsets, leaving room for the one-token target shift.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s : s + block_size] for s in starts])
    targets = torch.stack([source[s + 1 : s + 1 + block_size] for s in starts])
    return inputs, targets
def prepare_data(text, tokenizer):
    """Tokenize `text` and populate the global 90/10 train/val splits.

    Args:
        text: Raw corpus to encode.
        tokenizer: Object providing ``encode(str) -> list[int]``.

    Side effects:
        Sets the module globals ``train_data`` and ``val_data``.
    """
    global train_data, val_data
    encoded = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    # First 90% trains; the trailing 10% is held out for validation.
    split_at = int(0.9 * len(encoded))
    train_data, val_data = encoded[:split_at], encoded[split_at:]
def generate(model, tokenizer, prompt, max_tokens, device):
    """Autoregressively sample `max_tokens` tokens conditioned on `prompt`.

    Args:
        model: Callable returning ``(logits, loss)``; must expose
            ``config.block_size``.
        tokenizer: Provides ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        prompt: Seed text to condition generation on.
        max_tokens: Number of new tokens to sample.
        device: Device the context tensor is moved to.

    Returns:
        The generated continuation, with the prompt text stripped off.
    """
    model.eval()
    context = torch.tensor(tokenizer.encode(prompt), dtype=torch.long)
    context = context[None].to(device)
    window = model.config.block_size
    for _ in range(max_tokens):
        # Crop to the model's context window before each forward pass.
        with torch.no_grad():
            logits, _ = model(context[:, -window:])
        # Sample the next token from the distribution at the last position.
        last = logits[:, -1, :]
        dist = torch.softmax(last, dim=-1)
        choice = torch.multinomial(dist, num_samples=1)
        context = torch.cat([context, choice], dim=1)
    decoded = tokenizer.decode(context[0].tolist())
    return decoded[len(prompt) :]