# Sage 1B Space - rebuilt
"""Gradio demo for the Sage 1B language model.

Defines a small decoder-only transformer (rotary position embeddings,
RMSNorm, gated SiLU feed-forward), loads its config, tokenizer and weights
from the Hugging Face Hub at import time, and exposes a text-generation UI.
"""
import json
import math

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer as Tk

REPO_ID = "itriedcoding/Sage-1B"

# Token id prepended to every prompt (presumably BOS - confirm against the
# tokenizer's special-token table).
START_TOKEN_ID = 2
# Token id that ends generation when sampled.
STOP_TOKEN_ID = 3
# Prompts are cut to this many tokens before generation.
MAX_PROMPT_TOKENS = 50


class RotaryEmbedding(nn.Module):
    """Lazily built, cached cos/sin tables for rotary position embeddings."""

    def __init__(self, dim, max_seq_len=128):
        """Create the inverse-frequency buffer.

        Args:
            dim: Size of the last dimension the rotation is applied to.
            max_seq_len: Minimum number of positions to precompute when the
                cache is (re)built.
        """
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        # Plain attributes (not buffers): rebuilt on demand, never saved.
        self._cos = None
        self._sin = None

    def get_cos_sin(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables, each shaped ``(1, 1, seq_len, dim)``.

        Args:
            x: Tensor whose device is used for the position index and whose
                second dimension supplies ``seq_len`` when it is not given.
            seq_len: Number of positions required.
        """
        seq_len = seq_len or x.size(1)
        stale = (
            self._cos is None
            or self._cos.size(-2) < seq_len
            or self._cos.device != x.device
        )
        if stale:
            # Grow past max_seq_len when asked for more positions; otherwise
            # the slice below would return fewer rows than requested and the
            # rotation would fail with a shape mismatch.
            n_pos = max(self.max_seq_len, seq_len)
            t = torch.arange(n_pos, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)[None, None]
            self._cos = emb.cos()
            self._sin = emb.sin()
        return self._cos[..., :seq_len, :], self._sin[..., :seq_len, :]


def rotate_half(x):
    """Split the last dimension in two halves ``(x1, x2)`` and return ``(-x2, x1)``."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x, c, s):
    """Rotate ``x`` using cos table ``c`` and sin table ``s``."""
    return (x * c) + (rotate_half(x) * s)


class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings on queries and keys."""

    def __init__(self, h, nh, hd):
        """
        Args:
            h: Hidden size. The reshape in ``forward`` requires ``h == nh * hd``.
            nh: Number of attention heads.
            hd: Per-head dimension.
        """
        super().__init__()
        self.h = h
        self.nh = nh
        self.hd = hd
        self.q = nn.Linear(h, h, bias=False)
        self.k = nn.Linear(h, h, bias=False)
        self.v = nn.Linear(h, h, bias=False)
        self.o = nn.Linear(h, h, bias=False)

    def forward(self, x, cos, sin, mask):
        """Attend over ``x`` of shape ``(B, T, h)`` and return the same shape.

        Args:
            x: Input activations, ``(B, T, h)``.
            cos: Rotary cos table broadcastable to ``(B, nh, T, hd)``.
            sin: Rotary sin table broadcastable to ``(B, nh, T, hd)``.
            mask: Additive 4-D mask; its ``[:, :, :T, :T]`` slice is added to
                the attention scores before the softmax.
        """
        B, T, _ = x.shape
        # (B, T, h) -> (B, nh, T, hd)
        q = self.q(x).reshape(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k(x).reshape(B, T, self.nh, self.hd).transpose(1, 2)
        v = self.v(x).reshape(B, T, self.nh, self.hd).transpose(1, 2)
        q, k = apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)
        a = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.hd)
        a = a + mask[:, :, :T, :T]
        a = torch.nn.functional.softmax(a, dim=-1)
        # (B, nh, T, hd) -> (B, T, h)
        return self.o(a.matmul(v).transpose(1, 2).reshape(B, T, self.h))


class FF(nn.Module):
    """Gated feed-forward layer: ``d(silu(g(x)) * u(x))``."""

    def __init__(self, h, i):
        """
        Args:
            h: Hidden (input and output) size.
            i: Intermediate size.
        """
        super().__init__()
        self.g = nn.Linear(h, i, bias=False)
        self.u = nn.Linear(h, i, bias=False)
        self.d = nn.Linear(i, h, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` and project back to size ``h``."""
        return self.d(torch.nn.functional.silu(self.g(x)) * self.u(x))


class Block(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each residual."""

    def __init__(self, h, nh, hd, i):
        """
        Args:
            h: Hidden size.
            nh: Number of attention heads.
            hd: Per-head dimension.
            i: Feed-forward intermediate size.
        """
        super().__init__()
        self.an = nn.RMSNorm(h, eps=1e-6)
        self.fn = nn.RMSNorm(h, eps=1e-6)
        self.attn = Attention(h, nh, hd)
        self.ff = FF(h, i)

    def forward(self, x, c, s, m):
        """Run the block on ``x`` with rotary tables ``c``/``s`` and mask ``m``."""
        x = x + self.attn(self.an(x), c, s, m)
        x = x + self.ff(self.fn(x))
        return x


class Sage1B(nn.Module):
    """Decoder-only language model built from ``Block`` layers."""

    def __init__(self, cfg):
        """Build the model from a config mapping.

        Args:
            cfg: Mapping with keys ``vocab_size``, ``hidden_size``,
                ``num_attention_heads``, ``head_dim``, ``intermediate_size``,
                ``num_hidden_layers`` and ``max_position_embeddings``.
        """
        super().__init__()
        self.embed = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"])
        self.layers = nn.ModuleList([
            Block(
                cfg["hidden_size"],
                cfg["num_attention_heads"],
                cfg["head_dim"],
                cfg["intermediate_size"],
            )
            for _ in range(cfg["num_hidden_layers"])
        ])
        self.norm = nn.RMSNorm(cfg["hidden_size"], eps=1e-6)
        self.head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False)
        # Size the rotary table for the configured context window rather than
        # relying on RotaryEmbedding's default of 128 positions.
        self.rotary = RotaryEmbedding(
            cfg["head_dim"], max_seq_len=cfg["max_position_embeddings"]
        )
        self.max_seq_len = cfg["max_position_embeddings"]
        self.vocab_size = cfg["vocab_size"]
        self.hidden_size = cfg["hidden_size"]

    def forward(self, inp):
        """Return logits of shape ``(B, T, vocab_size)`` for token ids ``inp`` of shape ``(B, T)``."""
        B, T = inp.shape
        # Embeddings are scaled by sqrt(hidden_size) before the first block.
        x = self.embed(inp) * math.sqrt(self.hidden_size)
        cos, sin = self.rotary.get_cos_sin(x, T)
        # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.triu(
            torch.full((T, T), float("-inf"), device=x.device), diagonal=1
        )[None, None]
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        x = self.norm(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, inp, max_new=50, temp=0.8, top_k=40, eos_id=STOP_TOKEN_ID):
        """Sample up to ``max_new`` tokens after ``inp``.

        Args:
            inp: Long tensor of token ids, shape ``(B, T)``.
            max_new: Maximum number of tokens to append.
            temp: Softmax temperature; must be positive.
            top_k: Keep only the ``top_k`` highest logits before sampling;
                ``0`` or less disables the filter. Values above the
                vocabulary size are clamped.
            eos_id: Token id that marks a row as finished.

        Returns:
            The full sequence (prompt plus sampled tokens), shape
            ``(B, T + n)`` with ``n <= max_new``. Only the most recent
            ``max_seq_len`` tokens are fed to the model, but nothing is
            dropped from the returned tensor.

        Raises:
            ValueError: If ``temp`` is not positive.
        """
        if temp <= 0:
            raise ValueError(f"temp must be positive, got {temp}")
        self.eval()
        done = torch.zeros(inp.size(0), dtype=torch.bool, device=inp.device)
        for _ in range(max_new):
            # Sliding context window; the full history stays in `inp`.
            ctx = inp[:, -self.max_seq_len:]
            logits = self.forward(ctx)[:, -1, :] / temp
            if top_k > 0:
                k = min(top_k, logits.size(-1))
                kth = torch.topk(logits, k).values[:, -1:]
                logits = logits.masked_fill(logits < kth, float("-inf"))
            probs = torch.nn.functional.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            # Rows that already finished keep emitting eos_id as padding.
            nxt = nxt.masked_fill(done.unsqueeze(1), eos_id)
            inp = torch.cat([inp, nxt], dim=1)
            done |= nxt.squeeze(1) == eos_id
            if bool(done.all()):
                break
        return inp


print("Loading Sage 1B...")
cfg_p = hf_hub_download(REPO_ID, "config.json")
with open(cfg_p, encoding="utf-8") as f:
    cfg = json.load(f)
tok = Tk.from_file(hf_hub_download(REPO_ID, "tokenizer.json"))
model = Sage1B(cfg)
sd = torch.load(
    hf_hub_download(REPO_ID, "pytorch_model_state.bin"),
    map_location="cpu",
    weights_only=True,
)
# Rotary entries in the checkpoint are skipped: inv_freq is recomputed in
# RotaryEmbedding.__init__. That is why strict=False is required here.
_load_result = model.load_state_dict(
    {k: v for k, v in sd.items() if "rotary" not in k}, strict=False
)
_missing = [k for k in _load_result.missing_keys if "rotary" not in k]
_unexpected = list(_load_result.unexpected_keys)
if _missing or _unexpected:
    # strict=False would otherwise hide a checkpoint/architecture mismatch
    # and leave some weights at their random initial values.
    print(
        f"Warning: checkpoint mismatch - missing keys: {_missing}; "
        f"unexpected keys: {_unexpected}"
    )
model.eval()
print(f"Sage 1B loaded - {sum(p.numel() for p in model.parameters()):,} params")


def generate_text(prompt, max_length, temperature):
    """Gradio callback: continue ``prompt`` and return the decoded text.

    Args:
        prompt: User text; only its first ``MAX_PROMPT_TOKENS`` tokens are used.
        max_length: Maximum number of new tokens (converted to ``int``).
        temperature: Sampling temperature (converted to ``float``).

    Returns:
        The decoded prompt plus continuation, with special tokens skipped.
    """
    tokens = tok.encode(prompt or "").ids[:MAX_PROMPT_TOKENS]
    inp = torch.tensor([[START_TOKEN_ID] + tokens], dtype=torch.long)
    out = model.generate(
        inp, max_new=int(max_length), temp=float(temperature), top_k=40
    )
    return tok.decode(out[0].tolist(), skip_special_tokens=True)


demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Once upon a time"),
        gr.Slider(10, 100, 30, step=1, label="Max Length"),
        gr.Slider(0.1, 2.0, 0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Sage 1B",
    description="Custom 1.286B parameter language model from scratch.",
    examples=[
        ["Once upon a time", 30, 0.8],
        ["The story begins", 30, 0.8],
        ["In a world", 30, 0.8],
    ],
)

if __name__ == "__main__":
    demo.launch()