File size: 4,609 Bytes
79078fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import shutil

import torch

from .config import SmallGPTConfig
from .model import SmallGPTModel
from .tokenizer import WordTokenizer
from .trainer import create_model_and_tokenizer, set_seed, train_model


class SmallGPTService:
    """Service wrapper that loads, trains, samples from, and resets a small GPT.

    The model and tokenizer are created lazily: they stay ``None`` until the
    first ``generate`` or ``train`` call, at which point they are restored
    from ``config.checkpoint_path`` or bootstrapped and then checkpointed.
    """

    def __init__(self, config: "SmallGPTConfig"):
        """Store *config* and cap the torch CPU thread count.

        Args:
            config: Settings object; this class reads ``cpu_threads``,
                ``checkpoint_path``, ``seed``, ``bootstrap_steps`` and the
                model hyperparameters from it.
        """
        self.config = config
        # Never ask torch for fewer than one thread.
        torch.set_num_threads(max(1, self.config.cpu_threads))
        self.model = None
        self.tokenizer = None

    def generate(self, prompt: str, max_new_tokens: int, temperature: float, top_k: int):
        """Sample a continuation for *prompt*.

        Args:
            prompt: Text to continue; an empty or ``None`` prompt falls back
                to a default greeting exchange.
            max_new_tokens: Forwarded unchanged to ``model.generate``.
            temperature: Forwarded unchanged to ``model.generate``.
            top_k: Forwarded unchanged to ``model.generate``.

        Returns:
            A ``(text, status)`` tuple: the decoded first row of the model's
            output, and a one-line human-readable status string.
        """
        clean_prompt = prompt or "User: hello\nAssistant:"
        self._ensure_ready()
        encoded = self.tokenizer.encode(clean_prompt, add_bos=True)
        # Add a leading batch dimension of size one.
        idx = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
        self.model.eval()

        with torch.inference_mode():
            output = self.model.generate(
                idx=idx,
                max_new_tokens=max_new_tokens,
                eos_id=self.tokenizer.eos_id,
                temperature=temperature,
                top_k=top_k,
            )

        text = self.tokenizer.decode(output[0].tolist())
        status = (
            "Generated with small GPT Python. "
            f"Architecture=causal transformer, Vocab={self.tokenizer.vocab_size}, Layers={self.config.n_layers}."
        )
        return text, status

    def train(self, extra_text: str, steps: int):
        """Train for *steps* steps on *extra_text* and write a checkpoint.

        A fresh model/tokenizer pair is always built from the current config
        and text.  When a checkpoint already exists and both its vocabulary
        and its parameter shapes match the fresh model, the checkpointed
        weights are copied in so that training continues instead of
        restarting.

        Args:
            extra_text: Additional training text; ``None`` is treated as "".
            steps: Number of training steps; coerced to an int and clamped to
                at least 1.

        Returns:
            A multi-line summary with the step count, first and last loss,
            and the checkpoint path.
        """
        # UI widgets commonly hand over floats (e.g. 50.0); normalise once.
        steps = max(1, int(steps))
        checkpoint_exists = self.config.checkpoint_path.exists()
        training_text = extra_text or ""

        if checkpoint_exists:
            self._load_or_initialize(extra_text="")

        model, tokenizer, encoded = create_model_and_tokenizer(self.config, training_text)
        if checkpoint_exists and self.model is not None and self.tokenizer is not None:
            previous_weights = self.model.state_dict()
            # Warm-start only when the weights are guaranteed to fit: same
            # vocabulary mapping AND identical parameter names/shapes.  A
            # config change since the checkpoint was written would otherwise
            # make load_state_dict raise; in that case train from scratch.
            if tokenizer.stoi == self.tokenizer.stoi and self._shapes_match(
                previous_weights, model.state_dict()
            ):
                model.load_state_dict(previous_weights)

        losses = train_model(model, encoded, self.config, steps)
        self.model = model
        self.tokenizer = tokenizer
        self._save_checkpoint(extra_text=training_text)

        return (
            "small GPT training finished.\n"
            f"Steps: {steps}\n"
            f"Start Loss: {losses[0]:.4f}\n"
            f"End Loss: {losses[-1]:.4f}\n"
            f"Checkpoint: {self.config.checkpoint_path}"
        )

    def reset(self):
        """Delete the checkpoint and drop the in-memory model and tokenizer.

        Only the checkpoint file itself is removed.  Its directory is pruned
        afterwards solely when that leaves it empty, so unrelated files that
        share the folder (or the working directory, when the checkpoint path
        has no directory component) are never deleted.

        Returns:
            A confirmation message.
        """
        checkpoint = self.config.checkpoint_path
        checkpoint.unlink(missing_ok=True)
        try:
            checkpoint.parent.rmdir()
        except OSError:
            # Directory is missing, not empty, or not removable: leave it.
            pass
        self.model = None
        self.tokenizer = None
        return "small GPT reset complete. Next train or generate call will rebuild from scratch."

    @staticmethod
    def _shapes_match(source, target) -> bool:
        """Return True when two state dicts share the same keys and shapes."""
        if source.keys() != target.keys():
            return False
        return all(source[name].shape == target[name].shape for name in source)

    def _ensure_ready(self):
        """Load or bootstrap the model unless both parts are already in memory."""
        if self.model is not None and self.tokenizer is not None:
            return
        self._load_or_initialize(extra_text="")

    def _load_or_initialize(self, extra_text: str):
        """Restore model and tokenizer from the checkpoint, or bootstrap them.

        When the checkpoint file exists, the model is rebuilt from the
        hyperparameters stored inside it and put in eval mode.  Otherwise the
        seed is set, a new model is trained for ``config.bootstrap_steps``
        steps on *extra_text*, and the result is checkpointed.

        Args:
            extra_text: Text used only on the bootstrap path.
        """
        checkpoint = self.config.checkpoint_path
        if checkpoint.exists():
            # SECURITY NOTE(review): torch.load unpickles the file, so only
            # checkpoints from a trusted location should ever be loaded here.
            # Consider weights_only=True if the tokenizer state is made of
            # plain containers - TODO confirm against WordTokenizer.
            state = torch.load(checkpoint, map_location="cpu")
            saved = state["config"]
            self.tokenizer = WordTokenizer.from_state_dict(state["tokenizer"])
            self.model = SmallGPTModel(
                vocab_size=saved["vocab_size"],
                block_size=saved["block_size"],
                d_model=saved["d_model"],
                n_heads=saved["n_heads"],
                n_layers=saved["n_layers"],
                dropout=saved["dropout"],
            )
            self.model.load_state_dict(state["model"])
            self.model.eval()
            return

        set_seed(self.config.seed)
        self.model, self.tokenizer, encoded = create_model_and_tokenizer(self.config, extra_text)
        train_model(self.model, encoded, self.config, self.config.bootstrap_steps)
        self._save_checkpoint(extra_text=extra_text)

    def _save_checkpoint(self, extra_text: str):
        """Write model weights, tokenizer state and hyperparameters to disk.

        The payload is first written to a sibling ``<name>.tmp`` file and then
        renamed over the real checkpoint, so an interrupted save cannot leave
        a truncated checkpoint behind.

        Args:
            extra_text: Training text recorded alongside the hyperparameters.
        """
        checkpoint = self.config.checkpoint_path
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "model": self.model.state_dict(),
            "tokenizer": self.tokenizer.state_dict(),
            "config": {
                "vocab_size": self.tokenizer.vocab_size,
                "block_size": self.config.block_size,
                "d_model": self.config.d_model,
                "n_heads": self.config.n_heads,
                "n_layers": self.config.n_layers,
                "dropout": self.config.dropout,
                "extra_text": extra_text,
            },
        }
        staging = checkpoint.with_name(checkpoint.name + ".tmp")
        try:
            torch.save(payload, staging)
            staging.replace(checkpoint)
        finally:
            # No-op after a successful rename; cleans up after a failed save.
            staging.unlink(missing_ok=True)