# Sage 1B Space - rebuilt
import gradio as gr
import torch
import torch.nn as nn
import math
import json
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository the config, tokenizer and weights are downloaded from.
REPO_ID = "itriedcoding/Sage-1B"
class RotaryEmbedding(nn.Module):
    """Compute and cache the cos/sin tables used for rotary position embeddings.

    Args:
        dim: Feature size being rotated; one inverse frequency is created for
            every second index in ``range(dim)``.
        max_seq_len: Minimum number of positions to precompute. Requests for
            longer sequences grow the cache on demand.
    """

    def __init__(self, dim, max_seq_len=128):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        # Lazily built tables; plain attributes, so they are not part of the state dict.
        self._cos = None
        self._sin = None

    def get_cos_sin(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables covering the first ``seq_len`` positions.

        Args:
            x: Tensor whose device decides where the tables are built; its
                size along dim 1 is used when ``seq_len`` is not given.
            seq_len: Number of positions required. Falls back to
                ``x.size(1)`` when ``None`` (or 0).

        Returns:
            Two tensors shaped ``(1, 1, seq_len, 2 * len(inv_freq))``
            (the last dimension equals ``dim`` when ``dim`` is even).
        """
        seq_len = seq_len or x.size(1)
        needs_rebuild = (
            self._cos is None
            or self._cos.size(-2) < seq_len
            or self._cos.device != x.device
        )
        if needs_rebuild:
            # Build at least ``seq_len`` positions: sizing the cache only by
            # ``max_seq_len`` would hand back truncated tables for longer inputs.
            n_pos = max(self.max_seq_len, seq_len)
            inv_freq = self.inv_freq.to(x.device)
            positions = torch.arange(n_pos, device=x.device).type_as(inv_freq)
            freqs = torch.einsum("i,j->ij", positions, inv_freq)
            # Duplicate the frequencies so the table spans both rotated halves,
            # and add two leading broadcast dimensions.
            emb = torch.cat((freqs, freqs), dim=-1)[None, None]
            self._cos = emb.cos()
            self._sin = emb.sin()
        return self._cos[..., :seq_len, :], self._sin[..., :seq_len, :]
def rotate_half(x):
    """Split the last dimension in two and return ``(-second, first)`` concatenated."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary(x, c, s):
    """Combine ``x`` scaled by ``c`` with its half-rotated copy scaled by ``s``."""
    scaled = x * c
    rotated = rotate_half(x) * s
    return scaled + rotated
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings applied to queries and keys.

    Args:
        h: Width of the input and of every projection.
        nh: Number of heads the projections are split into.
        hd: Per-head width; ``nh * hd`` must equal ``h`` for the reshape to work.
    """

    def __init__(self, h, nh, hd):
        super().__init__()
        self.h, self.nh, self.hd = h, nh, hd
        # Bias-free projections, created in q/k/v/o order.
        self.q = nn.Linear(h, h, bias=False)
        self.k = nn.Linear(h, h, bias=False)
        self.v = nn.Linear(h, h, bias=False)
        self.o = nn.Linear(h, h, bias=False)

    def forward(self, x, cos, sin, mask):
        """Attend over ``x`` and return a tensor with the same shape.

        Args:
            x: 3-D input; the first two dimensions are read as batch and
                sequence length.
            cos: Cosine table passed to ``apply_rotary``.
            sin: Sine table passed to ``apply_rotary``.
            mask: 4-D tensor added to the attention scores before the softmax;
                only its leading ``[:T, :T]`` corner of the last two
                dimensions is used.
        """
        batch, steps, _ = x.shape
        per_head = (batch, steps, self.nh, self.hd)

        # Project, then move the head axis in front of the sequence axis.
        queries = self.q(x).reshape(per_head).transpose(1, 2)
        keys = self.k(x).reshape(per_head).transpose(1, 2)
        values = self.v(x).reshape(per_head).transpose(1, 2)

        queries = apply_rotary(queries, cos, sin)
        keys = apply_rotary(keys, cos, sin)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.hd)
        scores = scores + mask[:, :, :steps, :steps]
        weights = torch.softmax(scores, dim=-1)

        # Merge the heads back into a single feature axis before the output projection.
        merged = torch.matmul(weights, values).transpose(1, 2).reshape(batch, steps, self.h)
        return self.o(merged)
class FF(nn.Module):
    """Gated feed-forward layer computing ``d(silu(g(x)) * u(x))``.

    Args:
        h: Input and output width.
        i: Width of the intermediate (gated) projection.
    """

    def __init__(self, h, i):
        super().__init__()
        # Bias-free projections: gate, up and down.
        self.g = nn.Linear(h, i, bias=False)
        self.u = nn.Linear(h, i, bias=False)
        self.d = nn.Linear(i, h, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` and map it back to width ``h``."""
        gate = nn.functional.silu(self.g(x))
        up = self.u(x)
        return self.d(gate * up)
class Block(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual.

    Args:
        h: Hidden width.
        nh: Number of attention heads.
        hd: Per-head width.
        i: Intermediate width of the feed-forward layer.
    """

    def __init__(self, h, nh, hd, i):
        super().__init__()
        # One RMSNorm in front of each sub-layer.
        self.an = nn.RMSNorm(h, eps=1e-6)
        self.fn = nn.RMSNorm(h, eps=1e-6)
        self.attn = Attention(h, nh, hd)
        self.ff = FF(h, i)

    def forward(self, x, c, s, m):
        """Run both sub-layers on ``x``; ``c``, ``s`` and ``m`` are forwarded to attention."""
        attended = self.attn(self.an(x), c, s, m)
        x = x + attended
        transformed = self.ff(self.fn(x))
        return x + transformed
class Sage1B(nn.Module):
    """Decoder-only language model built from ``Block`` layers.

    Args:
        cfg: Mapping providing ``vocab_size``, ``hidden_size``,
            ``num_attention_heads``, ``head_dim``, ``intermediate_size``,
            ``num_hidden_layers`` and ``max_position_embeddings``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.embed = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"])
        self.layers = nn.ModuleList([
            Block(cfg["hidden_size"], cfg["num_attention_heads"],
                  cfg["head_dim"], cfg["intermediate_size"])
            for _ in range(cfg["num_hidden_layers"])
        ])
        self.norm = nn.RMSNorm(cfg["hidden_size"], eps=1e-6)
        self.head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False)
        self.max_seq_len = cfg["max_position_embeddings"]
        # Size the rotary tables to the configured context length; the
        # RotaryEmbedding default (128) is too short for longer inputs.
        self.rotary = RotaryEmbedding(cfg["head_dim"], max_seq_len=self.max_seq_len)
        self.vocab_size = cfg["vocab_size"]
        self.hidden_size = cfg["hidden_size"]

    def forward(self, inp):
        """Return logits of shape ``(B, T, vocab_size)`` for token ids ``inp`` of shape ``(B, T)``."""
        _, T = inp.shape
        # Embeddings are scaled by sqrt(hidden_size) before the first layer.
        x = self.embed(inp) * math.sqrt(self.hidden_size)
        cos, sin = self.rotary.get_cos_sin(x, T)
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.triu(
            torch.full((T, T), float("-inf"), device=x.device), diagonal=1
        )[None, None]
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        x = self.norm(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, inp, max_new=50, temp=0.8, top_k=40, eos_id=3):
        """Sample up to ``max_new`` tokens after ``inp``.

        Args:
            inp: Token ids of shape ``(B, T)``.
            max_new: Maximum number of tokens to append.
            temp: Softmax temperature; values ``<= 0`` select the most likely
                token instead of sampling.
            top_k: Keep only the ``top_k`` highest logits before sampling
                (clamped to the vocabulary size); ``0`` disables the filter.
            eos_id: Token id that ends a sequence.

        Returns:
            The full sequence (prompt plus generated tokens). Only the last
            ``max_seq_len`` tokens are fed to the model at each step, but
            nothing is dropped from the returned tensor. Generation stops
            early once every sequence in the batch has produced ``eos_id``;
            sequences that finished earlier are padded with ``eos_id``.
        """
        self.eval()
        finished = torch.zeros(inp.size(0), dtype=torch.bool, device=inp.device)
        for _ in range(max_new):
            # Feed a sliding window so the model never sees more than it supports.
            window = inp[:, -self.max_seq_len:]
            logits = self(window)[:, -1, :]
            if temp <= 0:
                # Dividing by a non-positive temperature is meaningless; go greedy.
                nxt = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temp
                if top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth = torch.topk(logits, k).values[:, -1:]
                    logits = logits.masked_fill(logits < kth, float("-inf"))
                probs = torch.nn.functional.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)
            # Rows that already ended keep emitting EOS so the batch stays rectangular.
            nxt = nxt.masked_fill(finished.unsqueeze(1), eos_id)
            inp = torch.cat([inp, nxt], dim=1)
            finished = finished | (nxt.squeeze(1) == eos_id)
            if bool(finished.all()):
                break
        return inp
from tokenizers import Tokenizer as Tk
# Download the config, tokenizer and weights from the Hub and build the model on CPU.
print("Loading Sage 1B...")
cfg_p = hf_hub_download(REPO_ID, "config.json")
with open(cfg_p, encoding="utf-8") as f:
    cfg = json.load(f)
tok = Tk.from_file(hf_hub_download(REPO_ID, "tokenizer.json"))
model = Sage1B(cfg)
sd = torch.load(
    hf_hub_download(REPO_ID, "pytorch_model_state.bin"),
    map_location="cpu",
    weights_only=True,
)
# Rotary entries are dropped from the checkpoint: the module computes its own
# inverse frequencies at construction time.
_load_result = model.load_state_dict(
    {k: v for k, v in sd.items() if "rotary" not in k}, strict=False
)
# strict=False hides mismatches, so surface them: a missing weight means that
# parameter silently stays at its random initialisation.
_missing = [k for k in _load_result.missing_keys if "rotary" not in k]
if _missing:
    print(f"Warning: {len(_missing)} model key(s) not found in checkpoint, e.g. {_missing[:5]}")
if _load_result.unexpected_keys:
    print(
        f"Warning: {len(_load_result.unexpected_keys)} unexpected checkpoint key(s), "
        f"e.g. {_load_result.unexpected_keys[:5]}"
    )
model.eval()
print(f"Sage 1B loaded - {sum(p.numel() for p in model.parameters()):,} params")
def generate_text(prompt, max_length, temperature):
    """Encode ``prompt``, sample a continuation from the model and decode it.

    Args:
        prompt: Text to continue; only its first 50 token ids are used.
        max_length: Maximum number of new tokens (converted with ``int``).
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded sequence with special tokens skipped.
    """
    prompt_ids = tok.encode(prompt).ids
    prompt_ids = prompt_ids[:50]
    # Id 2 is placed in front of the prompt (presumably the BOS token - confirm
    # against the tokenizer).
    batch = torch.tensor([[2, *prompt_ids]], dtype=torch.long)
    generated = model.generate(
        batch,
        max_new=int(max_length),
        temp=temperature,
        top_k=40,
    )
    token_ids = generated[0].tolist()
    return tok.decode(token_ids, skip_special_tokens=True)
# Gradio UI: one prompt box and two sliders feed generate_text; the result is
# shown in a single text box.
demo = gr.Interface(
    fn=generate_text,
    # Input order matches generate_text(prompt, max_length, temperature).
    inputs=[
        gr.Textbox(label="Prompt", value="Once upon a time"),
        # Slider positionals are presumably (minimum, maximum, initial value) -
        # confirm against the installed Gradio version.
        gr.Slider(10, 100, 30, step=1, label="Max Length"),
        gr.Slider(0.1, 2.0, 0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Sage 1B",
    description="Custom 1.286B parameter language model from scratch.",
    # Each example row supplies one value per input, in the same order.
    examples=[["Once upon a time", 30, 0.8],
              ["The story begins", 30, 0.8],
              ["In a world", 30, 0.8]],
)
# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|