# Spaces: Sleeping
# (Page-status text captured together with the source below; it is not part of the program.)
| # Sage 1B Space - rebuilt | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import math | |
| import json | |
| from huggingface_hub import hf_hub_download | |
| REPO_ID = "itriedcoding/Sage-1B" | |
class RotaryEmbedding(nn.Module):
    """Compute and cache the cos/sin tables used for rotary position embeddings.

    Args:
        dim: Size of the feature dimension that will be rotated.
        max_seq_len: Minimum number of positions computed whenever the cache
            is (re)built.
    """

    def __init__(self, dim, max_seq_len=128):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        # Lazily built tables of shape (1, 1, positions, 2 * len(inv_freq)).
        self._cos = None
        self._sin = None

    def get_cos_sin(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables covering the first ``seq_len`` positions.

        Args:
            x: Tensor whose device selects where the tables live and whose
                second dimension supplies ``seq_len`` when it is not given.
            seq_len: Number of positions required; defaults to ``x.size(1)``.

        Returns:
            Two tensors of shape ``(1, 1, seq_len, 2 * len(inv_freq))``.
        """
        if seq_len is None:
            seq_len = x.size(1)

        cache_is_stale = (
            self._cos is None
            or self._cos.size(-2) < seq_len
            or self._cos.device != x.device
        )
        if cache_is_stale:
            # Grow to the requested length when it exceeds max_seq_len;
            # otherwise the slice below would silently return too few rows.
            n_positions = max(self.max_seq_len, seq_len)
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(n_positions, device=x.device, dtype=inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)[None, None]
            self._cos = emb.cos()
            self._sin = emb.sin()

        return self._cos[..., :seq_len, :], self._sin[..., :seq_len, :]
def rotate_half(x):
    """Split the last dimension in two and return ``cat(-second, first)``."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary(x, c, s):
    """Rotate ``x`` using cosine table ``c`` and sine table ``s``.

    Computes ``x * c + rotate_half(x) * s``; ``c`` and ``s`` must broadcast
    against ``x``.
    """
    rotated = rotate_half(x)
    return x * c + rotated * s
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings applied to queries and keys.

    Args:
        h: Model (hidden) width; input and output feature size.
        nh: Number of attention heads.
        hd: Per-head width. The projections are reshaped to ``(nh, hd)``, so
            ``nh * hd`` has to equal ``h``.
    """

    def __init__(self, h, nh, hd):
        super().__init__()
        self.h = h
        self.nh = nh
        self.hd = hd
        self.q = nn.Linear(h, h, bias=False)
        self.k = nn.Linear(h, h, bias=False)
        self.v = nn.Linear(h, h, bias=False)
        self.o = nn.Linear(h, h, bias=False)

    def forward(self, x, cos, sin, mask):
        """Attend over ``x`` of shape ``(batch, seq, h)`` and return the same shape.

        ``mask`` is added to the raw attention scores before the softmax and is
        cropped to the current sequence length on its last two dimensions.
        """
        batch, seq_len, _ = x.shape

        def split_heads(projection):
            # (batch, seq, h) -> (batch, heads, seq, head_dim)
            projected = projection(x)
            return projected.reshape(batch, seq_len, self.nh, self.hd).transpose(1, 2)

        queries = apply_rotary(split_heads(self.q), cos, sin)
        keys = apply_rotary(split_heads(self.k), cos, sin)
        values = split_heads(self.v)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.hd)
        scores = scores + mask[:, :, :seq_len, :seq_len]
        weights = torch.nn.functional.softmax(scores, dim=-1)

        # Merge the heads back into a single feature dimension.
        context = weights.matmul(values).transpose(1, 2).reshape(batch, seq_len, self.h)
        return self.o(context)
class FF(nn.Module):
    """Gated feed-forward layer: ``d(silu(g(x)) * u(x))``.

    Args:
        h: Input and output feature size.
        i: Inner (expanded) feature size.
    """

    def __init__(self, h, i):
        super().__init__()
        self.g = nn.Linear(h, i, bias=False)  # gate projection
        self.u = nn.Linear(h, i, bias=False)  # up projection
        self.d = nn.Linear(i, h, bias=False)  # down projection

    def forward(self, x):
        """Apply the gated transformation to the last dimension of ``x``."""
        gate = torch.nn.functional.silu(self.g(x))
        up = self.u(x)
        return self.d(gate * up)
class Block(nn.Module):
    """One transformer layer: normalised attention and feed-forward, each with a residual.

    Args:
        h: Hidden width.
        nh: Number of attention heads.
        hd: Per-head width.
        i: Inner width of the feed-forward layer.
    """

    def __init__(self, h, nh, hd, i):
        super().__init__()
        self.an = nn.RMSNorm(h, eps=1e-6)  # norm applied before attention
        self.fn = nn.RMSNorm(h, eps=1e-6)  # norm applied before feed-forward
        self.attn = Attention(h, nh, hd)
        self.ff = FF(h, i)

    def forward(self, x, c, s, m):
        """Run the layer on ``x`` with rotary tables ``c``/``s`` and additive mask ``m``."""
        attn_out = self.attn(self.an(x), c, s, m)
        x = x + attn_out
        ff_out = self.ff(self.fn(x))
        return x + ff_out
class Sage1B(nn.Module):
    """Decoder-only transformer language model built from a config mapping.

    ``cfg`` must provide ``vocab_size``, ``hidden_size``,
    ``num_attention_heads``, ``head_dim``, ``intermediate_size``,
    ``num_hidden_layers`` and ``max_position_embeddings``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.embed = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"])
        self.layers = nn.ModuleList([
            Block(cfg["hidden_size"], cfg["num_attention_heads"],
                  cfg["head_dim"], cfg["intermediate_size"])
            for _ in range(cfg["num_hidden_layers"])
        ])
        self.norm = nn.RMSNorm(cfg["hidden_size"], eps=1e-6)
        self.head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False)
        # Size the rotary tables for the whole context window instead of the
        # helper's small default, so every in-window position is covered.
        self.rotary = RotaryEmbedding(cfg["head_dim"], cfg["max_position_embeddings"])
        self.max_seq_len = cfg["max_position_embeddings"]
        self.vocab_size = cfg["vocab_size"]
        self.hidden_size = cfg["hidden_size"]

    def forward(self, inp):
        """Return next-token logits of shape ``(batch, seq, vocab_size)``.

        Args:
            inp: Integer token ids of shape ``(batch, seq)``.
        """
        _, T = inp.shape
        x = self.embed(inp) * math.sqrt(self.hidden_size)
        cos, sin = self.rotary.get_cos_sin(x, T)
        # Additive causal mask: -inf strictly above the diagonal. It is created
        # in the activation dtype so half-precision models are not promoted.
        mask = torch.triu(
            torch.full((T, T), float("-inf"), device=x.device, dtype=x.dtype),
            diagonal=1,
        )[None, None]
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        x = self.norm(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, inp, max_new=50, temp=0.8, top_k=40, eos_id=3):
        """Autoregressively extend ``inp`` and return the full token sequence.

        Args:
            inp: Integer token ids of shape ``(batch, seq)``.
            max_new: Maximum number of tokens to append.
            temp: Sampling temperature; values ``<= 0`` select the most likely
                token instead of sampling.
            top_k: Keep only the ``top_k`` highest logits before sampling;
                ``0`` disables the filter. Values above the vocabulary size
                are clamped.
            eos_id: Token id that ends generation for a row.

        Returns:
            ``inp`` with the generated tokens appended. Only the last
            ``max_seq_len`` tokens are fed to the model, but the returned
            sequence keeps every token.
        """
        self.eval()
        finished = torch.zeros(inp.size(0), dtype=torch.bool, device=inp.device)
        for _ in range(max_new):
            # Feed a sliding window so the prompt start is not lost from the output.
            window = inp[:, -self.max_seq_len:]
            logits = self.forward(window)[:, -1, :]
            if temp <= 0:
                nxt = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temp
                if top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth = torch.topk(logits, k).values[:, -1:]
                    logits = logits.masked_fill(logits < kth, float("-inf"))
                probs = torch.nn.functional.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)
            # Rows that already ended keep emitting eos_id as padding.
            nxt = nxt.masked_fill(finished.unsqueeze(1), eos_id)
            inp = torch.cat([inp, nxt], dim=1)
            finished = finished | (nxt.squeeze(1) == eos_id)
            if bool(finished.all()):
                break
        return inp
from tokenizers import Tokenizer as Tk

# Download the config, tokenizer and weights from the Hub and build the model
# once at import time; generate_text reads the resulting `tok` and `model`.
print("Loading Sage 1B...")
cfg_p = hf_hub_download(REPO_ID, "config.json")
with open(cfg_p, encoding="utf-8") as f:
    cfg = json.load(f)
tok = Tk.from_file(hf_hub_download(REPO_ID, "tokenizer.json"))
model = Sage1B(cfg)
sd = torch.load(
    hf_hub_download(REPO_ID, "pytorch_model_state.bin"),
    map_location="cpu",
    weights_only=True,
)
# Checkpoint entries for the rotary module are skipped; the module computes
# its own inv_freq buffer when it is constructed.
_load_result = model.load_state_dict(
    {k: v for k, v in sd.items() if "rotary" not in k}, strict=False
)
# strict=False would otherwise hide a checkpoint that does not match the
# architecture, leaving randomly initialised weights in place.
_missing = [k for k in _load_result.missing_keys if "rotary" not in k]
_unexpected = list(_load_result.unexpected_keys)
if _missing or _unexpected:
    print(f"WARNING: checkpoint mismatch - missing keys: {_missing}, unexpected keys: {_unexpected}")
# The weights now live in the model; drop the second copy held by the state dict.
del sd
model.eval()
print(f"Sage 1B loaded - {sum(p.numel() for p in model.parameters()):,} params")
def generate_text(prompt, max_length, temperature):
    """Generate a continuation of ``prompt`` and return it as decoded text.

    Args:
        prompt: Text to continue; only its first 50 token ids are used.
        max_length: Maximum number of new tokens to generate (cast to int).
        temperature: Sampling temperature passed to the model.

    Returns:
        The decoded sequence (prompt plus continuation) without special tokens.
    """
    prompt_ids = tok.encode(prompt).ids[:50]
    # Token id 2 is prepended to every prompt (presumably the BOS id - confirm
    # against the tokenizer).
    model_input = torch.tensor([[2, *prompt_ids]], dtype=torch.long)
    generated = model.generate(
        model_input,
        max_new=int(max_length),
        temp=temperature,
        top_k=40,
    )
    token_ids = generated[0].tolist()
    return tok.decode(token_ids, skip_special_tokens=True)
# Input widgets, listed in the order generate_text takes its arguments.
_prompt_input = gr.Textbox(label="Prompt", value="Once upon a time")
_length_input = gr.Slider(10, 100, 30, step=1, label="Max Length")
_temperature_input = gr.Slider(0.1, 2.0, 0.8, step=0.1, label="Temperature")

# Example prompts; each is offered with 30 new tokens at temperature 0.8.
_example_prompts = ("Once upon a time", "The story begins", "In a world")

demo = gr.Interface(
    fn=generate_text,
    inputs=[_prompt_input, _length_input, _temperature_input],
    outputs=gr.Textbox(label="Generated Text"),
    title="Sage 1B",
    description="Custom 1.286B parameter language model from scratch.",
    examples=[[text, 30, 0.8] for text in _example_prompts],
)

if __name__ == "__main__":
    demo.launch()