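"""Gradio chat demo for AGILLM-3 Large: KV-cached autoregressive generation with
optional torch.compile, nucleus (top-p) sampling, and hot-reloading of the latest
training checkpoint from the Hugging Face Hub (OpenTransformer/AGILLM-3-large)."""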
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download, HfApi
class TuneableAttentionMHA(nn.Module):
    """Multi-head attention with a learned low-rank projection U applied to
    queries and keys (dk -> r); values keep the full per-head dimension."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.h, self.dk, self.r = h, d // h, r
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        self.proj = nn.Linear(h * self.dk, d, bias=False)

    def forward(self, x, mask=None, kv_cache=None, use_cache=False):
        B, N, _ = x.shape
        # Project Q and K into the low-rank space; V stays at head dimension dk.
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2) @ self.U
        k_new = self.k(x).view(B, N, self.h, self.dk).transpose(1, 2) @ self.U
        v_new = self.v(x).view(B, N, self.h, self.dk).transpose(1, 2)
        if kv_cache is not None:
            # Append the new keys/values to the cached ones along the sequence axis.
            k_cached, v_cached = kv_cache
            k = torch.cat([k_cached, k_new], dim=2)
            v = torch.cat([v_cached, v_new], dim=2)
        else:
            k, v = k_new, v_new
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        if use_cache:
            return self.proj(z), (k, v)
        return self.proj(z)
class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = TuneableAttentionMHA(d, h, r)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask, kv_cache=None, use_cache=False):
        if use_cache:
            attn_out, new_kv = self.mha(self.ln1(x), mask, kv_cache, use_cache=True)
            x = x + attn_out
            return x + self.ff(self.ln2(x)), new_kv
        x = x + self.mha(self.ln1(x), mask)
        return x + self.ff(self.ln2(x))
class Encoder(nn.Module):
    def __init__(self, vocab, d, layers, heads, rank):
        super().__init__()
        self.emb = nn.Embedding(vocab, d)
        self.blocks = nn.ModuleList([Block(d, heads, rank) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)

    def forward(self, ids, mask, kv_caches=None, use_cache=False):
        x = self.emb(ids)
        if use_cache:
            # Thread one KV cache per block and return the updated caches.
            new_kvs = []
            for i, blk in enumerate(self.blocks):
                kv = kv_caches[i] if kv_caches else None
                x, new_kv = blk(x, mask, kv, use_cache=True)
                new_kvs.append(new_kv)
            return self.ln(x), new_kvs
        for blk in self.blocks:
            x = blk(x, mask)
        return self.ln(x)
class ARHead(nn.Module):
    def __init__(self, d, vocab):
        super().__init__()
        self.proj = nn.Linear(d, vocab)

    def forward(self, h):
        return self.proj(h)
MODEL_REPO = "OpenTransformer/AGILLM-3-large"

def get_latest_checkpoint():
    """Return the name of the highest-step .pt checkpoint in the model repo."""
    api = HfApi()
    files = api.list_repo_files(MODEL_REPO, revision="main")
    ckpts = [f for f in files if f.endswith(".pt") and "step" in f]
    if not ckpts:
        raise ValueError("No checkpoints found in repo")
    ckpts.sort(key=lambda x: int(x.split("step")[1].split(".")[0]), reverse=True)
    return ckpts[0]

def get_current_step_from_name(ckpt_name):
    return int(ckpt_name.split("step")[1].split(".")[0])
# Global model state
model_state = {
    "core": None,
    "ar_head": None,
    "step": 0,
    "ckpt_name": None,
    "compiled": False,
    "tokenizer": None,
    "vocab": 0,
}
def load_model(ckpt_name=None, force_download=False):
    """Download a checkpoint, rebuild the model, and (optionally) compile it."""
    if ckpt_name is None:
        ckpt_name = get_latest_checkpoint()
    print(f"Loading checkpoint: {ckpt_name}")
    ckpt_path = hf_hub_download(MODEL_REPO, ckpt_name, force_download=force_download)
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    D, LAYERS, HEADS, RANK = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
    step = ckpt.get("step", 0)
    core = Encoder(model_state["vocab"], D, LAYERS, HEADS, RANK)
    ar_head = ARHead(D, model_state["vocab"])
    core.load_state_dict(ckpt["core"])
    ar_head.load_state_dict(ckpt["ar"])
    core.eval()
    ar_head.eval()
    # Try to compile; fall back to eager execution if torch.compile fails.
    compiled = False
    try:
        core = torch.compile(core, mode="reduce-overhead")
        ar_head = torch.compile(ar_head, mode="reduce-overhead")
        compiled = True
        print("Compilation successful!")
    except Exception as e:
        print(f"Compilation failed: {e}")
    # Best-effort warmup so the first user request doesn't pay the compile cost.
    try:
        with torch.no_grad():
            dummy = torch.randint(0, 1000, (1, 10))
            mask = torch.zeros((1, 1, 10, 10))
            _ = ar_head(core(dummy, mask))
    except Exception:
        pass
    model_state["core"] = core
    model_state["ar_head"] = ar_head
    model_state["step"] = step
    model_state["ckpt_name"] = ckpt_name
    model_state["compiled"] = compiled
    params = sum(p.numel() for p in core.parameters()) + sum(p.numel() for p in ar_head.parameters())
    print(f"Loaded: {params:,} params @ step {step:,}")
    return step, ckpt_name
def check_for_updates():
    """Check if a newer checkpoint exists and load it if so."""
    latest = get_latest_checkpoint()
    latest_step = get_current_step_from_name(latest)
    current_step = model_state["step"]
    if latest_step > current_step:
        # Force a fresh download to bypass the local Hub cache.
        new_step, _ = load_model(latest, force_download=True)
        return f"✅ Updated! Step {current_step:,} → {new_step:,}"
    return f"Already on latest (step {current_step:,})"
# Initial load
print("Loading tokenizer...")
model_state["tokenizer"] = AutoTokenizer.from_pretrained(MODEL_REPO, subfolder="tokenizer", trust_remote_code=True)
model_state["vocab"] = len(model_state["tokenizer"])
print("Finding latest checkpoint...")
load_model()
def generate(prompt, max_tokens=50, temperature=0.8, top_p=0.9):
    tokenizer = model_state["tokenizer"]
    core = model_state["core"]
    ar_head = model_state["ar_head"]
    ids = tokenizer.encode(prompt, return_tensors="pt")
    start_time = time.time()
    tokens_generated = 0
    with torch.no_grad():
        # Prefill: run the full prompt once with a causal mask and keep the KV caches.
        prompt_len = ids.size(1)
        mask = torch.triu(torch.full((1, 1, prompt_len, prompt_len), float("-inf")), 1)
        h, kv_caches = core(ids, mask, use_cache=True)
        for _ in range(max_tokens):
            if ids.size(1) >= 2048:
                break
            logits = ar_head(h)[:, -1, :] / max(temperature, 0.01)
            # Nucleus (top-p) sampling: mask out tokens outside the top-p probability mass.
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cumsum = torch.cumsum(probs, dim=-1)
            sorted_logits[cumsum - probs > top_p] = float("-inf")
            probs = F.softmax(sorted_logits, dim=-1)
            next_tok = sorted_idx.gather(-1, torch.multinomial(probs, 1))
            ids = torch.cat([ids, next_tok], dim=1)
            tokens_generated += 1
            if next_tok.item() == tokenizer.eos_token_id:
                break
            # Decode step: feed only the new token; the cached keys/values cover the rest.
            mask = torch.zeros((1, 1, 1, ids.size(1)))
            h, kv_caches = core(next_tok, mask, kv_caches, use_cache=True)
    elapsed = time.time() - start_time
    tok_s = tokens_generated / elapsed if elapsed > 0 else 0
    text = tokenizer.decode(ids[0], skip_special_tokens=True)
    text = " ".join(text.split())
    return text, tok_s, tokens_generated, elapsed
def chat(message, history, max_tokens=50, temperature=0.8):
    if max_tokens is None:
        max_tokens = 50
    if temperature is None:
        temperature = 0.8
    # Rebuild the conversation as a plain-text prompt.
    prompt = "".join(f"User: {h[0]}\nAssistant: {h[1]}\n" for h in history)
    prompt += f"User: {message}\nAssistant:"
    response, tok_s, n_tok, elapsed = generate(prompt, int(max_tokens), float(temperature))
    # Keep only the text after the last "Assistant:" marker.
    if "Assistant:" in response:
        response = response.split("Assistant:")[-1].strip()
    compiled_tag = " ⚡" if model_state["compiled"] else ""
    step = model_state["step"]
    return f"{response}\n\n---\n**{tok_s:.1f} tok/s**{compiled_tag} | {n_tok} tokens | {elapsed:.1f}s | step {step:,}"
def get_status():
    step = model_state["step"]
    compiled = "⚡ compiled" if model_state["compiled"] else "eager"
    return f"Step {step:,} | {compiled}"
with gr.Blocks(title="AGILLM-3 Chat") as demo:
    gr.Markdown("# AGILLM-3 Large (698M)")
    gr.Markdown("KV-cached + torch.compile() | 0.5% pretrained | expect beautiful gibberish")
    with gr.Row():
        status = gr.Textbox(value=get_status(), label="Model Status", interactive=False)
        refresh_btn = gr.Button("🔄 Check for Updates", variant="secondary")
    refresh_btn.click(fn=check_for_updates, outputs=status)
    chatbot = gr.ChatInterface(
        chat,
        additional_inputs=[
            gr.Slider(10, 200, value=50, step=10, label="Max Tokens"),
            gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature"),
        ],
        examples=[["Hello!"], ["What is 2+2?"]],
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()