# sage-1b-space / app.py
# itriedcoding's picture
# Upload app.py with huggingface_hub
# dc28463 verified
# Sage 1B Space - rebuilt
import gradio as gr
import torch
import torch.nn as nn
import math
import json
from huggingface_hub import hf_hub_download
REPO_ID = "itriedcoding/Sage-1B"
class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    Args:
        dim: Size of the rotated feature dimension; ``dim // 2`` (rounded up)
            inverse frequencies are generated with base 10000.
        max_seq_len: Minimum number of positions to precompute. The table
            grows on demand when a longer sequence is requested.
    """

    def __init__(self, dim, max_seq_len=128):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        # Lazily built tables of shape (1, 1, positions, 2 * len(inv_freq)).
        self._cos = None
        self._sin = None

    def get_cos_sin(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables sliced to ``seq_len`` positions.

        Args:
            x: Tensor whose device selects where the table lives and whose
                second dimension is used as the length when ``seq_len`` is
                not given.
            seq_len: Number of positions wanted; defaults to ``x.size(1)``.

        Returns:
            Two tensors of shape ``(1, 1, seq_len, 2 * len(inv_freq))``.
        """
        if seq_len is None:
            seq_len = x.size(1)
        needs_rebuild = (
            self._cos is None
            or self._cos.size(-2) < seq_len
            # The cache is a plain attribute, so Module.to() does not move it.
            or self._cos.device != x.device
        )
        if needs_rebuild:
            # Cover the requested length even when it exceeds max_seq_len;
            # otherwise the slice below would silently come back too short.
            positions = max(self.max_seq_len, seq_len)
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(positions, device=x.device).type_as(inv_freq)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)[None, None]
            self._cos = emb.cos()
            self._sin = emb.sin()
        return self._cos[..., :seq_len, :], self._sin[..., :seq_len, :]
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` split along the last axis, the result is ``[-b, a]``.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary(x, c, s):
    """Combine ``x`` with its half-rotated copy using cos table ``c`` and sin table ``s``.

    Computes ``x * c + rotate_half(x) * s``; ``c`` and ``s`` must broadcast
    against ``x``.
    """
    straight = x * c
    turned = rotate_half(x) * s
    return straight + turned
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings applied to queries and keys.

    Args:
        h: Model (hidden) width; every projection is ``h -> h`` without bias.
        nh: Number of heads.
        hd: Per-head width. The projections are reshaped to ``(nh, hd)``, so
            ``nh * hd`` has to equal ``h``.
    """

    def __init__(self, h, nh, hd):
        super().__init__()
        self.h, self.nh, self.hd = h, nh, hd
        # Creation order (q, k, v, o) is kept so parameter ordering is unchanged.
        self.q = nn.Linear(h, h, bias=False)
        self.k = nn.Linear(h, h, bias=False)
        self.v = nn.Linear(h, h, bias=False)
        self.o = nn.Linear(h, h, bias=False)

    def forward(self, x, cos, sin, mask):
        """Attend over ``x`` of shape ``(batch, seq, h)``.

        ``cos``/``sin`` are passed to ``apply_rotary`` for the per-head
        queries and keys; ``mask`` is added to the raw scores after being
        sliced to ``[:, :, :seq, :seq]``.
        """
        batch, seq, _ = x.shape

        def split_heads(projection):
            # (batch, seq, h) -> (batch, nh, seq, hd)
            return projection(x).reshape(batch, seq, self.nh, self.hd).transpose(1, 2)

        queries = split_heads(self.q)
        keys = split_heads(self.k)
        values = split_heads(self.v)
        queries = apply_rotary(queries, cos, sin)
        keys = apply_rotary(keys, cos, sin)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.hd)
        scores = scores + mask[:, :, :seq, :seq]
        weights = torch.nn.functional.softmax(scores, dim=-1)

        # (batch, nh, seq, hd) -> (batch, seq, h) before the output projection.
        merged = weights.matmul(values).transpose(1, 2).reshape(batch, seq, self.h)
        return self.o(merged)
class FF(nn.Module):
    """Gated feed-forward layer: ``d(silu(g(x)) * u(x))``.

    Args:
        h: Input and output width.
        i: Inner width of the gate (``g``) and up (``u``) projections.
    """

    def __init__(self, h, i):
        super().__init__()
        # Creation order (g, u, d) is kept so parameter ordering is unchanged.
        self.g = nn.Linear(h, i, bias=False)
        self.u = nn.Linear(h, i, bias=False)
        self.d = nn.Linear(i, h, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back to width ``h``."""
        gate = torch.nn.functional.silu(self.g(x))
        up = self.u(x)
        return self.d(gate * up)
class Block(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual add.

    Args:
        h: Hidden width.
        nh: Number of attention heads.
        hd: Per-head width.
        i: Inner width of the feed-forward layer.
    """

    def __init__(self, h, nh, hd, i):
        super().__init__()
        # Registration order (an, fn, attn, ff) is kept so the module's
        # parameter ordering is unchanged.
        self.an = nn.RMSNorm(h, eps=1e-6)
        self.fn = nn.RMSNorm(h, eps=1e-6)
        self.attn = Attention(h, nh, hd)
        self.ff = FF(h, i)

    def forward(self, x, c, s, m):
        """Run one layer; ``c``, ``s`` and ``m`` are forwarded to the attention module."""
        attended = self.attn(self.an(x), c, s, m)
        x = x + attended
        transformed = self.ff(self.fn(x))
        return x + transformed
class Sage1B(nn.Module):
    """Decoder-only language model: embedding, a stack of ``Block`` layers, final norm, LM head.

    Args:
        cfg: Mapping providing ``vocab_size``, ``hidden_size``,
            ``num_attention_heads``, ``head_dim``, ``intermediate_size``,
            ``num_hidden_layers`` and ``max_position_embeddings``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.embed = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"])
        self.layers = nn.ModuleList([
            Block(cfg["hidden_size"], cfg["num_attention_heads"],
                  cfg["head_dim"], cfg["intermediate_size"])
            for _ in range(cfg["num_hidden_layers"])
        ])
        self.norm = nn.RMSNorm(cfg["hidden_size"], eps=1e-6)
        self.head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False)
        # Size the rotary table from the config rather than the 128-position
        # default, so every context length generate() allows is covered.
        self.rotary = RotaryEmbedding(
            cfg["head_dim"], max_seq_len=cfg["max_position_embeddings"]
        )
        self.max_seq_len = cfg["max_position_embeddings"]
        self.vocab_size = cfg["vocab_size"]
        self.hidden_size = cfg["hidden_size"]

    def forward(self, inp):
        """Return logits of shape ``(batch, seq, vocab_size)`` for token ids ``inp`` of shape ``(batch, seq)``."""
        B, T = inp.shape
        # Embeddings are scaled by sqrt(hidden_size) before entering the stack.
        x = self.embed(inp) * math.sqrt(self.hidden_size)
        cos, sin = self.rotary.get_cos_sin(x, T)
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.triu(
            torch.full((T, T), float("-inf"), device=x.device), diagonal=1
        )[None, None]
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        x = self.norm(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, inp, max_new=50, temp=0.8, top_k=40, eos_id=3):
        """Sample up to ``max_new`` tokens after ``inp``.

        Args:
            inp: Long tensor of token ids, shape ``(batch, seq)``.
            max_new: Maximum number of tokens to append.
            temp: Softmax temperature; ``temp <= 0`` selects the highest-scoring
                token instead of sampling.
            top_k: Keep only the ``top_k`` highest logits before sampling
                (``0`` disables the filter; values above the vocabulary size
                are clamped).
            eos_id: Token id that ends generation.

        Returns:
            ``inp`` with the generated tokens appended. Generation stops once
            every row has produced ``eos_id``; rows that finished earlier keep
            receiving tokens until then.
        """
        self.eval()
        done = torch.zeros(inp.size(0), dtype=torch.bool, device=inp.device)
        for _ in range(max_new):
            # Only the model's context is windowed; the returned sequence
            # keeps every token, including the start of the prompt.
            context = inp[:, -self.max_seq_len:]
            logits = self.forward(context)[:, -1, :]
            if temp > 0:
                logits = logits / temp
                if top_k > 0:
                    # torch.topk raises when k exceeds the last dimension.
                    k = min(top_k, logits.size(-1))
                    kth = torch.topk(logits, k).values[:, -1:]
                    logits[logits < kth] = float("-inf")
                probs = torch.nn.functional.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)
            else:
                # Dividing by a non-positive temperature is meaningless;
                # fall back to greedy decoding.
                nxt = torch.topk(logits, 1).indices
            inp = torch.cat([inp, nxt], dim=1)
            # Tensor.item() only works for one element, so track EOS per row.
            done = done | (nxt.squeeze(1) == eos_id)
            if bool(done.all()):
                break
        return inp
from tokenizers import Tokenizer as Tk
# Download the config, tokenizer and weights from the Hub and build the model.
print("Loading Sage 1B...")
cfg_p = hf_hub_download(REPO_ID, "config.json")
# Read the JSON config as UTF-8 instead of the platform's locale encoding.
with open(cfg_p, encoding="utf-8") as f:
    cfg = json.load(f)
tok = Tk.from_file(hf_hub_download(REPO_ID, "tokenizer.json"))
model = Sage1B(cfg)
sd = torch.load(hf_hub_download(REPO_ID, "pytorch_model_state.bin"),
                map_location="cpu", weights_only=True)
# Rotary buffers are recomputed locally, so checkpoint entries for them are
# dropped; strict=False tolerates their absence.
load_report = model.load_state_dict(
    {k: v for k, v in sd.items() if "rotary" not in k}, strict=False
)
# strict=False also hides real mismatches, which would leave layers randomly
# initialised. Surface anything other than the intentionally skipped buffers.
_missing = [k for k in load_report.missing_keys if "rotary" not in k]
_unexpected = list(load_report.unexpected_keys)
if _missing or _unexpected:
    print(f"WARNING: checkpoint mismatch - missing keys: {_missing}; "
          f"unexpected keys: {_unexpected}")
model.eval()
print(f"Sage 1B loaded - {sum(p.numel() for p in model.parameters()):,} params")
def generate_text(prompt, max_length, temperature):
    """Gradio callback: continue ``prompt`` with the loaded model and return the decoded text.

    Args:
        prompt: Text to continue; only its first 50 token ids are used.
        max_length: Maximum number of new tokens (cast to ``int``).
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded sequence (prompt tokens plus generated tokens) with
        special tokens skipped.
    """
    prompt_ids = tok.encode(prompt).ids[:50]
    # Token id 2 is prepended to every prompt (presumably BOS; confirm
    # against the tokenizer).
    batch = torch.tensor([[2] + prompt_ids], dtype=torch.long)
    generated = model.generate(
        batch,
        max_new=int(max_length),
        temp=temperature,
        top_k=40,
    )
    first_row = generated[0].tolist()
    return tok.decode(first_row, skip_special_tokens=True)
# Gradio UI: a prompt box and two sliders wired to generate_text.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Once upon a time"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Sage 1B",
    description="Custom 1.286B parameter language model from scratch.",
    # Each example row is [prompt, max length, temperature].
    examples=[
        [example_prompt, 30, 0.8]
        for example_prompt in ("Once upon a time", "The story begins", "In a world")
    ],
)

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()