Spaces:
Sleeping
Sleeping
| import shutil | |
| import torch | |
| from .config import SmallGPTConfig | |
| from .model import SmallGPTModel | |
| from .tokenizer import WordTokenizer | |
| from .trainer import create_model_and_tokenizer, set_seed, train_model | |
class SmallGPTService:
    """Own a small GPT model and its tokenizer, persisted to one checkpoint file.

    The model is created lazily. The first ``generate`` call either loads
    ``config.checkpoint_path`` or bootstraps a new model: it seeds the RNG,
    trains for ``config.bootstrap_steps`` steps and saves the result.
    """

    def __init__(self, config: SmallGPTConfig):
        """Store ``config`` and set torch's CPU thread count (minimum 1).

        Note: ``torch.set_num_threads`` is process-global, not per-instance.
        """
        self.config = config
        torch.set_num_threads(max(1, self.config.cpu_threads))
        self.model = None
        self.tokenizer = None

    def generate(self, prompt: str, max_new_tokens: int, temperature: float, top_k: int):
        """Run the model on ``prompt`` and return ``(text, status)``.

        An empty or ``None`` prompt falls back to a default chat-style prompt.
        ``text`` is the decoded first row of ``model.generate``'s output.
        Whether that includes the prompt depends on ``SmallGPTModel.generate``
        (presumably prompt + continuation; confirm).
        """
        prompt_text = prompt or "User: hello\nAssistant:"
        self._ensure_ready()
        token_ids = self.tokenizer.encode(prompt_text, add_bos=True)
        # unsqueeze(0) adds a leading batch dimension of size 1.
        idx = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        self.model.eval()
        with torch.inference_mode():
            output = self.model.generate(
                idx=idx,
                max_new_tokens=max_new_tokens,
                eos_id=self.tokenizer.eos_id,
                temperature=temperature,
                top_k=top_k,
            )
        text = self.tokenizer.decode(output[0].tolist())
        status = (
            f"Generated with small GPT Python. "
            f"Architecture=causal transformer, Vocab={self.tokenizer.vocab_size}, Layers={self.config.n_layers}."
        )
        return text, status

    def train(self, extra_text: str, steps: int) -> str:
        """Train a model on ``extra_text`` for ``steps`` steps, save it, and return a summary.

        A new model and tokenizer are built from ``extra_text``. When a
        checkpoint exists and its tokenizer vocabulary and parameter shapes
        both match the new model, the checkpoint weights are loaded first
        (warm start). Otherwise training starts from the newly created
        weights.

        Raises:
            ValueError / TypeError: if ``steps`` cannot be converted to ``int``.
        """
        # Coerce in case the caller passes a float (e.g. from a UI slider);
        # a non-int step count must not reach the training loop.
        steps = max(1, int(steps))
        checkpoint_exists = self.config.checkpoint_path.exists()
        training_text = extra_text or ""
        if checkpoint_exists:
            self._load_or_initialize(extra_text="")
        model, tokenizer, encoded = create_model_and_tokenizer(self.config, training_text)
        if checkpoint_exists and self.model is not None and self.tokenizer is not None:
            previous_state = self.model.state_dict()
            # Vocab equality alone is not enough: if the architecture config
            # changed since the checkpoint was written, load_state_dict would
            # raise a size-mismatch RuntimeError.
            if tokenizer.stoi == self.tokenizer.stoi and self._same_parameter_layout(
                previous_state, model.state_dict()
            ):
                model.load_state_dict(previous_state)
        losses = train_model(model, encoded, self.config, steps)
        self.model = model
        self.tokenizer = tokenizer
        self._save_checkpoint(extra_text=training_text)
        return (
            f"small GPT training finished.\n"
            f"Steps: {steps}\n"
            f"Start Loss: {losses[0]:.4f}\n"
            f"End Loss: {losses[-1]:.4f}\n"
            f"Checkpoint: {self.config.checkpoint_path}"
        )

    def reset(self) -> str:
        """Delete the persisted checkpoint and drop the in-memory model.

        Only the files this class writes are removed: the checkpoint and its
        temporary sibling. The containing directory is removed only if it is
        then empty. Other files that share the directory are kept. An earlier
        version rmtree'd the whole parent directory, which could wipe the
        working directory when the checkpoint lived at e.g. ``./model.pt``.
        """
        checkpoint = self.config.checkpoint_path
        for path in (checkpoint, self._temp_checkpoint_path(checkpoint)):
            path.unlink(missing_ok=True)
        try:
            checkpoint.parent.rmdir()
        except OSError:
            # Deliberate best effort: the directory is not empty, is missing,
            # or cannot be removed (e.g. the current directory). Leave it.
            pass
        self.model = None
        self.tokenizer = None
        return "small GPT reset complete. Next train or generate call will rebuild from scratch."

    def _ensure_ready(self):
        """Load or bootstrap the model unless both model and tokenizer are already in memory."""
        if self.model is not None and self.tokenizer is not None:
            return
        self._load_or_initialize(extra_text="")

    def _load_or_initialize(self, extra_text: str):
        """Populate ``self.model``/``self.tokenizer`` from the checkpoint, or bootstrap and save new ones.

        When loading, the architecture comes from the checkpoint's stored
        ``config`` entry, not from ``self.config``. When no checkpoint exists,
        the RNG is seeded with ``config.seed`` and the new model is trained
        for ``config.bootstrap_steps`` steps before being saved.
        """
        checkpoint = self.config.checkpoint_path
        if checkpoint.exists():
            # torch.load unpickles; only point this at checkpoints this
            # service wrote itself, never at untrusted files.
            state = torch.load(checkpoint, map_location="cpu")
            arch = state["config"]
            self.tokenizer = WordTokenizer.from_state_dict(state["tokenizer"])
            self.model = SmallGPTModel(
                vocab_size=arch["vocab_size"],
                block_size=arch["block_size"],
                d_model=arch["d_model"],
                n_heads=arch["n_heads"],
                n_layers=arch["n_layers"],
                dropout=arch["dropout"],
            )
            self.model.load_state_dict(state["model"])
            self.model.eval()
            return
        set_seed(self.config.seed)
        self.model, self.tokenizer, encoded = create_model_and_tokenizer(self.config, extra_text)
        train_model(self.model, encoded, self.config, self.config.bootstrap_steps)
        self._save_checkpoint(extra_text=extra_text)

    @staticmethod
    def _temp_checkpoint_path(checkpoint):
        """Return the sibling path used for in-progress checkpoint writes."""
        return checkpoint.with_name(checkpoint.name + ".tmp")

    @staticmethod
    def _same_parameter_layout(source_state, target_state):
        """Return True if ``source_state`` has exactly the keys and tensor shapes of ``target_state``."""
        if source_state.keys() != target_state.keys():
            return False
        return all(
            getattr(source_state[name], "shape", None) == getattr(target_state[name], "shape", None)
            for name in target_state
        )

    def _save_checkpoint(self, extra_text: str) -> None:
        """Save model weights, tokenizer state and architecture config to ``config.checkpoint_path``.

        The payload is written to a temporary sibling file first and then
        moved into place with ``Path.replace``. An interrupted save therefore
        never leaves a truncated checkpoint that later loads would crash on.
        """
        checkpoint = self.config.checkpoint_path
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "model": self.model.state_dict(),
            "tokenizer": self.tokenizer.state_dict(),
            "config": {
                "vocab_size": self.tokenizer.vocab_size,
                "block_size": self.config.block_size,
                "d_model": self.config.d_model,
                "n_heads": self.config.n_heads,
                "n_layers": self.config.n_layers,
                "dropout": self.config.dropout,
                "extra_text": extra_text,
            },
        }
        tmp_path = self._temp_checkpoint_path(checkpoint)
        torch.save(payload, tmp_path)
        tmp_path.replace(checkpoint)