import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download, HfApi


class TuneableAttentionMHA(nn.Module):
    """Multi-head attention with a shared low-rank projection U applied to Q and K."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.h, self.dk, self.r = h, d // h, r
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        self.proj = nn.Linear(h * self.dk, d, bias=False)

    def forward(self, x, mask=None, kv_cache=None, use_cache=False):
        B, N, _ = x.shape
        # Project Q and K into the rank-r subspace; V stays full-width.
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2) @ self.U
        k_new = self.k(x).view(B, N, self.h, self.dk).transpose(1, 2) @ self.U
        v_new = self.v(x).view(B, N, self.h, self.dk).transpose(1, 2)
        if kv_cache is not None:
            # Append the new keys/values to the cached ones along the sequence axis.
            k_cached, v_cached = kv_cache
            k = torch.cat([k_cached, k_new], dim=2)
            v = torch.cat([v_cached, v_new], dim=2)
        else:
            k, v = k_new, v_new
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        if use_cache:
            return self.proj(z), (k, v)
        return self.proj(z)


class Block(nn.Module):
    """Pre-LN transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = TuneableAttentionMHA(d, h, r)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask, kv_cache=None, use_cache=False):
        if use_cache:
            attn_out, new_kv = self.mha(self.ln1(x), mask, kv_cache, use_cache=True)
            x = x + attn_out
            return x + self.ff(self.ln2(x)), new_kv
        x = x + self.mha(self.ln1(x), mask)
        return x + self.ff(self.ln2(x))


class Encoder(nn.Module):
    def __init__(self, vocab, d, layers, heads, rank):
        super().__init__()
        self.emb = nn.Embedding(vocab, d)
        self.blocks = nn.ModuleList([Block(d, heads, rank) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)

    def forward(self, ids, mask, kv_caches=None, use_cache=False):
        x = self.emb(ids)
        if use_cache:
            # Thread one KV cache per block and collect the updated caches.
            new_kvs = []
            for i, blk in enumerate(self.blocks):
                kv = kv_caches[i] if kv_caches else None
                x, new_kv = blk(x, mask, kv, use_cache=True)
                new_kvs.append(new_kv)
            return self.ln(x), new_kvs
        for blk in self.blocks:
            x = blk(x, mask)
        return self.ln(x)


class ARHead(nn.Module):
    """Linear language-model head projecting hidden states to vocab logits."""

    def __init__(self, d, vocab):
        super().__init__()
        self.proj = nn.Linear(d, vocab)

    def forward(self, h):
        return self.proj(h)


MODEL_REPO = "OpenTransformer/AGILLM-3-large"


def get_latest_checkpoint():
    api = HfApi()
    files = api.list_repo_files(MODEL_REPO, revision="main")
    ckpts = [f for f in files if f.endswith(".pt") and "step" in f]
    if not ckpts:
        raise ValueError("No checkpoints found in repo")
    # Sort by the step number embedded in the filename, newest first.
    ckpts.sort(key=lambda x: int(x.split("step")[1].split(".")[0]), reverse=True)
    return ckpts[0]


def get_current_step_from_name(ckpt_name):
    return int(ckpt_name.split("step")[1].split(".")[0])


# Global model state
model_state = {
    "core": None,
    "ar_head": None,
    "step": 0,
    "ckpt_name": None,
    "compiled": False,
    "tokenizer": None,
    "vocab": 0,
}


def load_model(ckpt_name=None, force_download=False):
    if ckpt_name is None:
        ckpt_name = get_latest_checkpoint()
    print(f"Loading checkpoint: {ckpt_name}")
    ckpt_path = hf_hub_download(MODEL_REPO, ckpt_name, force_download=force_download)
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    cfg = ckpt["cfg"]
    D, LAYERS, HEADS, RANK = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
    step = ckpt.get("step", 0)

    core = Encoder(model_state["vocab"], D, LAYERS, HEADS, RANK)
    ar_head = ARHead(D, model_state["vocab"])
    core.load_state_dict(ckpt["core"])
    ar_head.load_state_dict(ckpt["ar"])
    core.eval()
    ar_head.eval()

    # Try to compile; fall back to eager mode if torch.compile is unavailable.
    compiled = False
    try:
        core = torch.compile(core, mode="reduce-overhead")
        ar_head = torch.compile(ar_head, mode="reduce-overhead")
        compiled = True
        print("Compilation successful!")
    except Exception as e:
        print(f"Compilation failed: {e}")

    # Warmup pass so the first user request doesn't pay the compile latency.
    try:
        with torch.no_grad():
            dummy = torch.randint(0, 1000, (1, 10))
            mask = torch.zeros((1, 1, 10, 10))
            _ = ar_head(core(dummy, mask))
    except Exception:
        pass

    model_state["core"] = core
    model_state["ar_head"] = ar_head
    model_state["step"] = step
    model_state["ckpt_name"] = ckpt_name
    model_state["compiled"] = compiled

    params = sum(p.numel() for p in core.parameters()) + sum(p.numel() for p in ar_head.parameters())
    print(f"Loaded: {params:,} params @ step {step:,}")
    return step, ckpt_name


def check_for_updates():
    """Check if a newer checkpoint exists and load it if so."""
    latest = get_latest_checkpoint()
    latest_step = get_current_step_from_name(latest)
    current_step = model_state["step"]
    if latest_step > current_step:
        # Force fresh download, bypass cache
        new_step, new_name = load_model(latest, force_download=True)
        return f"✅ Updated! Step {current_step:,} → {new_step:,}"
    return f"Already on latest (step {current_step:,})"


# Initial load
print("Loading tokenizer...")
model_state["tokenizer"] = AutoTokenizer.from_pretrained(
    MODEL_REPO, subfolder="tokenizer", trust_remote_code=True
)
model_state["vocab"] = len(model_state["tokenizer"])
print("Finding latest checkpoint...")
load_model()


@torch.no_grad()
def generate(prompt, max_tokens=50, temperature=0.8, top_p=0.9):
    tokenizer = model_state["tokenizer"]
    core = model_state["core"]
    ar_head = model_state["ar_head"]

    ids = tokenizer.encode(prompt, return_tensors="pt")
    start_time = time.time()
    tokens_generated = 0

    # Prefill: run the full prompt once with a causal mask and keep the KV caches.
    prompt_len = ids.size(1)
    mask = torch.triu(torch.full((1, 1, prompt_len, prompt_len), float("-inf")), 1)
    h, kv_caches = core(ids, mask, use_cache=True)

    for _ in range(max_tokens):
        if ids.size(1) >= 2048:
            break
        logits = ar_head(h)[:, -1, :] / max(temperature, 0.01)

        # Nucleus (top-p) sampling: mask out tokens whose preceding cumulative
        # mass already exceeds top_p; the shift by `probs` always keeps at
        # least the highest-probability token.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cumsum = torch.cumsum(probs, dim=-1)
        sorted_logits[cumsum - probs > top_p] = float("-inf")
        probs = F.softmax(sorted_logits, dim=-1)
        # Map the sampled position back to a vocabulary id.
        next_tok = sorted_idx.gather(-1, torch.multinomial(probs, 1))

        ids = torch.cat([ids, next_tok], dim=1)
        tokens_generated += 1
        if next_tok.item() == tokenizer.eos_token_id:
            break

        # Decode step: feed only the new token; the cached keys/values cover
        # the past, so the mask is all-zeros (the token may attend to everything).
        mask = torch.zeros((1, 1, 1, ids.size(1)))
        h, kv_caches = core(next_tok, mask, kv_caches, use_cache=True)

    elapsed = time.time() - start_time
    tok_s = tokens_generated / elapsed if elapsed > 0 else 0
    text = tokenizer.decode(ids[0], skip_special_tokens=True)
    text = " ".join(text.split())
    return text, tok_s, tokens_generated, elapsed


def chat(message, history, max_tokens=50, temperature=0.8):
    if max_tokens is None:
        max_tokens = 50
    if temperature is None:
        temperature = 0.8
    # history arrives as (user, assistant) pairs (Gradio tuple format).
    prompt = "".join(f"User: {h[0]}\nAssistant: {h[1]}\n" for h in history)
    prompt += f"User: {message}\nAssistant:"
    response, tok_s, n_tok, elapsed = generate(prompt, int(max_tokens), float(temperature))
    if "Assistant:" in response:
        response = response.split("Assistant:")[-1].strip()
    compiled_tag = " ⚡" if model_state["compiled"] else ""
    step = model_state["step"]
    return (
        f"{response}\n\n---\n**{tok_s:.1f} tok/s**{compiled_tag} | "
        f"{n_tok} tokens | {elapsed:.1f}s | step {step:,}"
    )


def get_status():
    step = model_state["step"]
    compiled = "⚡ compiled" if model_state["compiled"] else "eager"
    return f"Step {step:,} | {compiled}"


with gr.Blocks(title="AGILLM-3 Chat") as demo:
    gr.Markdown("# AGILLM-3 Large (698M)")
    gr.Markdown("KV-cached + torch.compile() | 0.5% pretrained | expect beautiful gibberish")

    with gr.Row():
        status = gr.Textbox(value=get_status(), label="Model Status", interactive=False)
        refresh_btn = gr.Button("🔄 Check for Updates", variant="secondary")
    refresh_btn.click(fn=check_for_updates, outputs=status)

    chatbot = gr.ChatInterface(
        chat,
        additional_inputs=[
            gr.Slider(10, 200, value=50, step=10, label="Max Tokens"),
            gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature"),
        ],
        examples=[["Hello!"], ["What is 2+2?"]],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
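# Usage sketch (hypothetical, not wired into the app): assuming the
# module-level load_model() above has already run, generate() can be driven
# directly, e.g. to profile the KV-cached decode loop without the Gradio UI:
#
#     text, tok_s, n_tok, elapsed = generate("Once upon a time", max_tokens=30)
#     print(f"{n_tok} tokens in {elapsed:.2f}s ({tok_s:.1f} tok/s)")
#     print(text)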