File size: 5,946 Bytes
64cea89
223832a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc28463
223832a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Sage 1B Space - rebuilt
import gradio as gr
import torch
import torch.nn as nn
import math
import json
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository from which config.json, tokenizer.json and
# pytorch_model_state.bin are downloaded at startup.
REPO_ID = "itriedcoding/Sage-1B"

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables with a lazily built cache.

    Args:
        dim: Size of the last axis the tables are applied to (the per-head
            width in this file); ``ceil(dim / 2)`` inverse frequencies are used.
        max_seq_len: Minimum number of positions to precompute. Requests for
            more positions grow the cache instead of returning a short table.
    """

    def __init__(self, dim, max_seq_len=128):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        # Plain attributes (not buffers), so they are never saved or loaded.
        self._cos = None
        self._sin = None

    def get_cos_sin(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``(1, 1, seq_len, dim)`` (even ``dim``).

        Args:
            x: Reference tensor. Its device is used for the tables and, when
                ``seq_len`` is not given, ``x.size(1)`` is the sequence length.
            seq_len: Number of positions to return.

        Returns:
            Views into the cached tables covering positions ``0 .. seq_len-1``.
        """
        seq_len = seq_len or x.size(1)
        needs_rebuild = (
            self._cos is None
            or self._cos.size(-2) < seq_len
            or self._cos.device != x.device
        )
        if needs_rebuild:
            # Build at least max_seq_len positions but never fewer than requested:
            # building exactly max_seq_len made longer inputs receive a table that
            # could not broadcast against q/k (and forced a rebuild on every call).
            n_pos = max(self.max_seq_len, seq_len)
            t = torch.arange(n_pos, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)[None, None]
            self._cos = emb.cos()
            self._sin = emb.sin()
        return self._cos[..., :seq_len, :], self._sin[..., :seq_len, :]

def rotate_half(x):
    """Rotate the last axis by halves: split it into (front, back) and return cat(-back, front).

    This is the 90-degree rotation that ``apply_rotary`` combines with the sine table.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)

def apply_rotary(x, c, s):
    """Apply rotary embedding tables: ``x * c + rotate_half(x) * s``.

    ``c`` and ``s`` must broadcast against ``x`` (for example the tables
    returned by ``RotaryEmbedding.get_cos_sin``).
    """
    rotated = rotate_half(x)
    return x * c + rotated * s

class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and an additive mask.

    Args:
        h: Model width; last-axis size of the input and output.
        nh: Number of heads.
        hd: Per-head width. All projections are h -> h, so ``nh * hd`` must
            equal ``h`` for the head reshapes to succeed.
    """

    def __init__(self, h, nh, hd):
        super().__init__()
        self.h = h
        self.nh = nh
        self.hd = hd
        self.q = nn.Linear(h, h, bias=False)
        self.k = nn.Linear(h, h, bias=False)
        self.v = nn.Linear(h, h, bias=False)
        self.o = nn.Linear(h, h, bias=False)

    def _split_heads(self, t, batch, steps):
        """Reshape ``(batch, steps, h)`` into ``(batch, nh, steps, hd)``."""
        return t.reshape(batch, steps, self.nh, self.hd).transpose(1, 2)

    def forward(self, x, cos, sin, mask):
        """Attend over ``x`` of shape ``(B, T, h)`` and return a tensor of the same shape.

        ``cos``/``sin`` are rotary tables applied to queries and keys. ``mask``
        is a 4-D additive mask sliced to ``[:, :, :T, :T]`` before the softmax.
        """
        batch, steps, _ = x.shape
        query = apply_rotary(self._split_heads(self.q(x), batch, steps), cos, sin)
        key = apply_rotary(self._split_heads(self.k(x), batch, steps), cos, sin)
        value = self._split_heads(self.v(x), batch, steps)

        scores = query @ key.transpose(-2, -1) / math.sqrt(self.hd)
        scores = scores + mask[:, :, :steps, :steps]
        weights = scores.softmax(dim=-1)

        # Merge heads back: (B, nh, T, hd) -> (B, T, nh * hd).
        merged = (weights @ value).transpose(1, 2).reshape(batch, steps, self.h)
        return self.o(merged)

class FF(nn.Module):
    """Gated feed-forward layer computing ``d(silu(g(x)) * u(x))``.

    Args:
        h: Input/output width.
        i: Inner (expanded) width of the gate and up projections.
    """

    def __init__(self, h, i):
        super().__init__()
        self.g = nn.Linear(h, i, bias=False)
        self.u = nn.Linear(h, i, bias=False)
        self.d = nn.Linear(i, h, bias=False)

    def forward(self, x):
        gate = torch.nn.functional.silu(self.g(x))
        up = self.u(x)
        return self.d(gate * up)

class Block(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention and RMSNorm -> FF, each with a residual add.

    Args:
        h: Model width.
        nh: Number of attention heads.
        hd: Per-head width.
        i: Feed-forward inner width.
    """

    def __init__(self, h, nh, hd, i):
        super().__init__()
        # Registration order (an, fn, attn, ff) is kept so state_dict keys and order match.
        self.an = nn.RMSNorm(h, eps=1e-6)
        self.fn = nn.RMSNorm(h, eps=1e-6)
        self.attn = Attention(h, nh, hd)
        self.ff = FF(h, i)

    def forward(self, x, c, s, m):
        """Run the layer; ``c``/``s`` are rotary tables and ``m`` the additive attention mask."""
        attn_out = self.attn(self.an(x), c, s, m)
        hidden = x + attn_out
        return hidden + self.ff(self.fn(hidden))

class Sage1B(nn.Module):
    """Decoder-only transformer language model served by this Space.

    Args:
        cfg: Config dict (e.g. the parsed config.json). Keys read: vocab_size,
            hidden_size, num_attention_heads, head_dim, intermediate_size,
            num_hidden_layers, max_position_embeddings.
    """

    def __init__(self, cfg):
        super().__init__()
        self.embed = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"])
        self.layers = nn.ModuleList([
            Block(cfg["hidden_size"], cfg["num_attention_heads"],
                  cfg["head_dim"], cfg["intermediate_size"])
            for _ in range(cfg["num_hidden_layers"])
        ])
        self.norm = nn.RMSNorm(cfg["hidden_size"], eps=1e-6)
        self.head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False)
        # Size the rotary cache to the model's context window. The default of 128
        # positions can be shorter than the window generate() feeds the model
        # (max_position_embeddings), which broke longer sequences.
        self.rotary = RotaryEmbedding(cfg["head_dim"],
                                      max_seq_len=cfg["max_position_embeddings"])
        self.max_seq_len = cfg["max_position_embeddings"]
        self.vocab_size = cfg["vocab_size"]
        self.hidden_size = cfg["hidden_size"]

    def forward(self, inp):
        """Compute logits for every position.

        Args:
            inp: Integer token ids of shape (B, T).

        Returns:
            Logits of shape (B, T, vocab_size).
        """
        B, T = inp.shape
        # Token embeddings are scaled by sqrt(hidden_size) before the first layer.
        x = self.embed(inp) * math.sqrt(self.hidden_size)
        cos, sin = self.rotary.get_cos_sin(x, T)
        # Additive causal mask: -inf above the diagonal hides future positions.
        mask = torch.triu(torch.full((T, T), float("-inf"), device=x.device), diagonal=1)[None, None]
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        return self.head(self.norm(x))

    @torch.no_grad()
    def generate(self, inp, max_new=50, temp=0.8, top_k=40, eos_id=3):
        """Autoregressively extend ``inp`` by sampling.

        Args:
            inp: Integer token ids of shape (B, T) (the prompt).
            max_new: Maximum number of tokens to append.
            temp: Softmax temperature. Values <= 0 select the argmax (greedy).
            top_k: If > 0, sample only among the ``top_k`` largest logits
                (clamped to the vocabulary size so ``torch.topk`` cannot fail).
            eos_id: Token id that ends generation (3 is the previously hard-coded
                value). ``None`` disables early stopping. In a batch, rows that
                already emitted it are padded with it until every row finishes.

        Returns:
            The prompt followed by the generated tokens, shape (B, T + n_new).
            Only the model's input window is cropped to ``max_seq_len``; the
            returned sequence always keeps the full prompt.
        """
        self.eval()
        out = inp
        finished = torch.zeros(inp.size(0), dtype=torch.bool, device=inp.device)
        for _ in range(max_new):
            # Feed at most max_seq_len tokens but keep the whole history in `out`.
            ctx = out[:, -self.max_seq_len:]
            logits = self.forward(ctx)[:, -1, :]
            if temp <= 0:
                nxt = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temp
                if top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth = torch.topk(logits, k).values[:, -1:]
                    logits = logits.masked_fill(logits < kth, float("-inf"))
                probs = torch.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)
            if eos_id is not None:
                nxt = nxt.masked_fill(finished[:, None], eos_id)
                finished |= nxt[:, 0] == eos_id
            out = torch.cat([out, nxt], dim=1)
            if eos_id is not None and bool(finished.all()):
                break
        return out

from tokenizers import Tokenizer as Tk
print("Loading Sage 1B...")

# Model hyper-parameters; Sage1B reads its layer sizes from this dict.
cfg_p = hf_hub_download(REPO_ID, "config.json")
with open(cfg_p, encoding="utf-8") as f:
    cfg = json.load(f)

tok = Tk.from_file(hf_hub_download(REPO_ID, "tokenizer.json"))

model = Sage1B(cfg)
# weights_only=True limits torch.load to tensors and plain containers rather
# than arbitrary pickled objects from the downloaded file.
sd = torch.load(hf_hub_download(REPO_ID, "pytorch_model_state.bin"),
                map_location="cpu", weights_only=True)
# Rotary entries are dropped from the checkpoint and recomputed from head_dim.
# strict=False tolerates that, but it would also silently accept a checkpoint
# whose other keys do not match (leaving randomly initialised weights), so any
# such mismatch is reported instead of ignored.
load_result = model.load_state_dict(
    {k: v for k, v in sd.items() if "rotary" not in k}, strict=False)
missing = [k for k in load_result.missing_keys if "rotary" not in k]
if missing or load_result.unexpected_keys:
    print(f"WARNING: checkpoint mismatch - missing keys: {missing}, "
          f"unexpected keys: {load_result.unexpected_keys}")
model.eval()
print(f"Sage 1B loaded - {sum(p.numel() for p in model.parameters()):,} params")

def generate_text(prompt, max_length, temperature):
    """Gradio callback: continue ``prompt`` with the model and return decoded text.

    Only the first 50 prompt tokens are kept. Id 2 is prepended before them
    (presumably the BOS token; confirm against tokenizer.json). The decoded
    result covers every token ``model.generate`` returns, with special tokens
    stripped.
    """
    prompt_ids = tok.encode(prompt).ids
    sequence = [2, *prompt_ids[:50]]
    batch = torch.tensor([sequence], dtype=torch.long)
    generated = model.generate(
        batch,
        max_new=int(max_length),
        temp=temperature,
        top_k=40,
    )
    return tok.decode(generated[0].tolist(), skip_special_tokens=True)

# Gradio UI: a prompt box plus length and temperature sliders, passed to
# generate_text positionally in that order.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(value="Once upon a time", label="Prompt"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    examples=[
        ["Once upon a time", 30, 0.8],
        ["The story begins", 30, 0.8],
        ["In a world", 30, 0.8],
    ],
    description="Custom 1.286B parameter language model from scratch.",
    title="Sage 1B",
)

if __name__ == "__main__":
    demo.launch()