import shutil

import torch

from .config import SmallGPTConfig
from .model import SmallGPTModel
from .tokenizer import WordTokenizer
from .trainer import create_model_and_tokenizer, set_seed, train_model


class SmallGPTService:
    """Lazy-loading wrapper around a small GPT model and its tokenizer.

    The model and tokenizer are created on first use: they are restored from
    ``config.checkpoint_path`` when that file exists, otherwise a new model is
    bootstrapped, trained for ``config.bootstrap_steps`` steps and saved.
    """

    def __init__(self, config: SmallGPTConfig):
        """Store *config* and cap torch's CPU thread count (at least 1)."""
        self.config = config
        torch.set_num_threads(max(1, self.config.cpu_threads))
        # Both stay None until the first generate()/train() call needs them.
        self.model = None
        self.tokenizer = None

    def generate(self, prompt: str, max_new_tokens: int, temperature: float, top_k: int):
        """Generate a continuation of *prompt*.

        An empty or ``None`` prompt is replaced by a default chat-style prompt.
        The sampling arguments are passed through to the model's ``generate``
        unchanged.

        Returns:
            A ``(text, status)`` tuple: the decoded output sequence and a
            human-readable status line.
        """
        clean_prompt = prompt or "User: hello\nAssistant:"
        self._ensure_ready()

        encoded = self.tokenizer.encode(clean_prompt, add_bos=True)
        # Add a leading batch dimension of size 1.
        idx = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)

        self.model.eval()
        with torch.inference_mode():
            output = self.model.generate(
                idx=idx,
                max_new_tokens=max_new_tokens,
                eos_id=self.tokenizer.eos_id,
                temperature=temperature,
                top_k=top_k,
            )

        text = self.tokenizer.decode(output[0].tolist())
        status = (
            f"Generated with small GPT Python. "
            f"Architecture=causal transformer, Vocab={self.tokenizer.vocab_size}, Layers={self.config.n_layers}."
        )
        return text, status

    def train(self, extra_text: str, steps: int) -> str:
        """Train a model on *extra_text* for at least one step and save it.

        A new model/tokenizer pair is always built from the current config.
        When a checkpoint already exists, its weights are copied into the new
        model only if the vocabulary is identical and every parameter has the
        same name and shape; otherwise training starts from the fresh weights.

        Returns:
            A multi-line summary with the step count, first and last loss and
            the checkpoint path.
        """
        steps = max(1, steps)
        checkpoint_exists = self.config.checkpoint_path.exists()
        training_text = extra_text or ""

        if checkpoint_exists:
            self._load_or_initialize(extra_text="")

        model, tokenizer, encoded = create_model_and_tokenizer(self.config, training_text)

        previous_available = (
            checkpoint_exists
            and self.model is not None
            and self.tokenizer is not None
        )
        if previous_available and tokenizer.stoi == self.tokenizer.stoi:
            previous_state = self.model.state_dict()
            # The checkpointed model was built from the hyper-parameters saved
            # in the checkpoint, which may differ from the current config. A
            # strict load_state_dict would raise on such a mismatch, so treat
            # it like a vocabulary mismatch and skip the warm start.
            if self._state_dicts_compatible(previous_state, model.state_dict()):
                model.load_state_dict(previous_state)

        losses = train_model(model, encoded, self.config, steps)

        self.model = model
        self.tokenizer = tokenizer
        self._save_checkpoint(extra_text=training_text)

        return (
            f"small GPT training finished.\n"
            f"Steps: {steps}\n"
            f"Start Loss: {losses[0]:.4f}\n"
            f"End Loss: {losses[-1]:.4f}\n"
            f"Checkpoint: {self.config.checkpoint_path}"
        )

    def reset(self) -> str:
        """Delete the checkpoint directory and drop the in-memory model.

        Returns:
            A confirmation message.
        """
        checkpoint_dir = self.config.checkpoint_path.parent
        # NOTE(review): this removes the whole directory that contains the
        # checkpoint, not just the checkpoint file. Confirm that
        # checkpoint_path always lives in a dedicated directory.
        if checkpoint_dir.exists():
            shutil.rmtree(checkpoint_dir)
        self.model = None
        self.tokenizer = None
        return "small GPT reset complete. Next train or generate call will rebuild from scratch."

    @staticmethod
    def _state_dicts_compatible(source, target) -> bool:
        """Return True when *source* can be strictly loaded into *target*'s module.

        Both arguments are state dicts. They are compatible when they hold the
        same entry names and each pair of entries has the same shape.
        """
        if source.keys() != target.keys():
            return False
        return all(
            getattr(source[name], "shape", None) == getattr(target[name], "shape", None)
            for name in source
        )

    def _ensure_ready(self) -> None:
        """Load or bootstrap the model unless both it and the tokenizer are set."""
        if self.model is not None and self.tokenizer is not None:
            return
        self._load_or_initialize(extra_text="")

    def _load_or_initialize(self, extra_text: str) -> None:
        """Restore model and tokenizer from the checkpoint, or bootstrap them.

        If the checkpoint file exists, the model is rebuilt from the
        hyper-parameters stored in it and put into eval mode. Otherwise the
        seed is set, a new model is created from the current config and
        *extra_text*, trained for ``config.bootstrap_steps`` steps and saved.
        """
        checkpoint = self.config.checkpoint_path
        if checkpoint.exists():
            # NOTE(review): torch.load may unpickle arbitrary objects,
            # depending on the installed torch version's `weights_only`
            # default. Only load checkpoints from a trusted location.
            state = torch.load(checkpoint, map_location="cpu")
            self.tokenizer = WordTokenizer.from_state_dict(state["tokenizer"])
            saved_config = state["config"]
            self.model = SmallGPTModel(
                vocab_size=saved_config["vocab_size"],
                block_size=saved_config["block_size"],
                d_model=saved_config["d_model"],
                n_heads=saved_config["n_heads"],
                n_layers=saved_config["n_layers"],
                dropout=saved_config["dropout"],
            )
            self.model.load_state_dict(state["model"])
            self.model.eval()
            return

        set_seed(self.config.seed)
        self.model, self.tokenizer, encoded = create_model_and_tokenizer(self.config, extra_text)
        train_model(self.model, encoded, self.config, self.config.bootstrap_steps)
        self._save_checkpoint(extra_text=extra_text)

    def _save_checkpoint(self, extra_text: str) -> None:
        """Write model weights, tokenizer state and hyper-parameters to disk.

        The checkpoint's parent directory is created if missing. *extra_text*
        is stored alongside the hyper-parameters under the ``"config"`` key.
        """
        checkpoint = self.config.checkpoint_path
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "model": self.model.state_dict(),
                "tokenizer": self.tokenizer.state_dict(),
                "config": {
                    "vocab_size": self.tokenizer.vocab_size,
                    "block_size": self.config.block_size,
                    "d_model": self.config.d_model,
                    "n_heads": self.config.n_heads,
                    "n_layers": self.config.n_layers,
                    "dropout": self.config.dropout,
                    "extra_text": extra_text,
                },
            },
            checkpoint,
        )