File size: 4,580 Bytes
740c342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from pathlib import Path
import shutil

import torch

from .config import LLMConfig
from .trainer import create_model_and_tokenizer, set_seed, train_model
from .model import TinyTransformerLM
from .tokenizer import CharTokenizer


class LocalMiniLLMService:
    """Service wrapper around a small local transformer language model.

    The model and tokenizer are created lazily: the first ``generate`` call
    loads them from ``config.checkpoint_path`` or, when no checkpoint exists,
    builds and bootstrap-trains a new model and saves it.
    """

    def __init__(self, config: "LLMConfig"):
        """Store *config* and cap torch's CPU thread count.

        Args:
            config: Settings object; this class reads ``cpu_threads``,
                ``checkpoint_path``, ``seed``, ``bootstrap_steps`` and the
                model hyper-parameters from it.
        """
        self.config = config
        # Never hand torch a thread count below 1.
        torch.set_num_threads(max(1, self.config.cpu_threads))
        self.model = None
        self.tokenizer = None

    def generate(self, prompt: str, max_new_tokens: int, temperature: float, top_k: int):
        """Continue *prompt* with the local model.

        Args:
            prompt: Text to continue; a falsy prompt is replaced by a default
                greeting exchange.
            max_new_tokens: Forwarded to the model's ``generate``.
            temperature: Forwarded to the model's ``generate``.
            top_k: Forwarded to the model's ``generate``.

        Returns:
            A ``(text, status)`` pair: the decoded model output and a
            one-line summary of the model that produced it.
        """
        clean_prompt = prompt or "User: hello\nAssistant:"
        self._ensure_ready()
        encoded = self.tokenizer.encode(clean_prompt)
        if not encoded:
            # The prompt produced no token ids; fall back to the default
            # prompt so the model still gets an input sequence.
            encoded = self.tokenizer.encode("User: hello\nAssistant:")
        # Add a leading batch dimension of size 1.
        idx = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
        self.model.eval()
        with torch.inference_mode():
            out = self.model.generate(
                idx,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
            )
        text = self.tokenizer.decode(out[0].tolist())
        status = (
            f"Generated text with local tiny transformer. "
            f"Vocab={self.tokenizer.vocab_size}, Layers={self.config.n_layers}, Hidden={self.config.d_model}."
        )
        return text, status

    def train(self, extra_text: str, steps: int):
        """Train for *steps* steps and save the result as the new checkpoint.

        A new model and tokenizer are built from the current config and
        *extra_text*. When a checkpoint exists and is compatible (same
        vocabulary mapping and same parameter shapes), its weights are used
        as the starting point; otherwise training starts from the freshly
        created weights.

        Args:
            extra_text: Additional training text; ``None`` or empty is
                treated as ``""``.
            steps: Number of training steps; values below 1 are raised to 1.

        Returns:
            A multi-line, human-readable training summary.
        """
        steps = max(1, steps)
        training_text = extra_text or ""
        checkpoint_exists = self.config.checkpoint_path.exists()

        if checkpoint_exists:
            self._load_or_initialize(extra_text="")

        model, tokenizer, encoded = create_model_and_tokenizer(self.config, training_text)
        if checkpoint_exists and self.model is not None and self.tokenizer is not None:
            same_vocab = tokenizer.stoi == self.tokenizer.stoi
            # The checkpointed model is rebuilt from the hyper-parameters
            # stored in the checkpoint, which may differ from the current
            # config. Copying weights across different shapes would make
            # load_state_dict raise, so warm-start only on an exact match.
            if same_vocab and self._same_parameter_shapes(model, self.model):
                model.load_state_dict(self.model.state_dict())

        losses = train_model(model, encoded, self.config, steps)
        self.model = model
        self.tokenizer = tokenizer
        self._save_checkpoint(extra_text=training_text)

        return (
            f"Training finished.\n"
            f"Steps: {steps}\n"
            f"Start Loss: {losses[0]:.4f}\n"
            f"End Loss: {losses[-1]:.4f}\n"
            f"Checkpoint: {self.config.checkpoint_path}"
        )

    def reset(self):
        """Delete the saved checkpoint and drop the in-memory model.

        Only the checkpoint file itself is removed. Its parent directory is
        removed afterwards only if that leaves it empty, so unrelated files
        stored next to the checkpoint are never deleted.

        Returns:
            A confirmation message.
        """
        checkpoint = self.config.checkpoint_path
        try:
            checkpoint.unlink()
        except FileNotFoundError:
            # Nothing saved yet; reset is still valid.
            pass
        try:
            checkpoint.parent.rmdir()
        except OSError:
            # Directory is missing, not empty, or not removable (for example
            # the current working directory); leaving it in place is fine.
            pass
        self.model = None
        self.tokenizer = None
        return "Model reset. Next generate/train call will rebuild from scratch."

    @staticmethod
    def _same_parameter_shapes(first, second) -> bool:
        """Return True when both modules' state dicts have identical keys and shapes."""
        first_shapes = {name: tuple(value.shape) for name, value in first.state_dict().items()}
        second_shapes = {name: tuple(value.shape) for name, value in second.state_dict().items()}
        return first_shapes == second_shapes

    def _ensure_ready(self) -> None:
        """Load or build the model and tokenizer unless both are already set."""
        if self.model is not None and self.tokenizer is not None:
            return
        self._load_or_initialize(extra_text="")

    def _load_or_initialize(self, extra_text: str) -> None:
        """Populate ``self.model`` and ``self.tokenizer``.

        When the checkpoint file exists, both are restored from it and
        *extra_text* is ignored. Otherwise the seed is set, a new model is
        created from the config and *extra_text*, trained for
        ``config.bootstrap_steps`` steps, and saved.
        """
        checkpoint = self.config.checkpoint_path
        if checkpoint.exists():
            # NOTE(review): torch.load may unpickle arbitrary objects
            # depending on the torch version's ``weights_only`` default;
            # only load checkpoints this service wrote itself.
            state = torch.load(checkpoint, map_location="cpu")
            self.tokenizer = CharTokenizer.from_state_dict(state["tokenizer"])
            # Rebuild with the hyper-parameters stored alongside the weights,
            # not the current config, so the saved weights fit the model.
            self.model = TinyTransformerLM(
                vocab_size=state["config"]["vocab_size"],
                block_size=state["config"]["block_size"],
                d_model=state["config"]["d_model"],
                n_heads=state["config"]["n_heads"],
                n_layers=state["config"]["n_layers"],
                dropout=state["config"]["dropout"],
            )
            self.model.load_state_dict(state["model"])
            self.model.eval()
            return

        set_seed(self.config.seed)
        self.model, self.tokenizer, encoded = create_model_and_tokenizer(self.config, extra_text)
        train_model(self.model, encoded, self.config, self.config.bootstrap_steps)
        self._save_checkpoint(extra_text=extra_text)

    def _save_checkpoint(self, extra_text: str) -> None:
        """Write model weights, tokenizer state and hyper-parameters to disk.

        The checkpoint's parent directory is created when missing.
        *extra_text* is stored inside the saved ``config`` entry.
        """
        checkpoint = self.config.checkpoint_path
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        state = {
            "model": self.model.state_dict(),
            "tokenizer": self.tokenizer.state_dict(),
            "config": {
                "vocab_size": self.tokenizer.vocab_size,
                "block_size": self.config.block_size,
                "d_model": self.config.d_model,
                "n_heads": self.config.n_heads,
                "n_layers": self.config.n_layers,
                "dropout": self.config.dropout,
                "extra_text": extra_text,
            },
        }
        torch.save(state, checkpoint)