# GuageLLM-Web / app.py
# Author: Hai929 — "Update app.py" (commit 3fc4867, verified)
import torch
import gradio as gr
from transformers import GPT2Config
from safetensors.torch import load_file
from model import GPT2LMHeadModel
# ---- LOAD YOUR MODEL ----
# Hub repo providing only the *config*; weights are loaded from a local file.
MODEL_REPO = "Hai929/The_GuageLLM_12M"
config = GPT2Config.from_pretrained(MODEL_REPO)
# Project-local GPT2LMHeadModel (defined in model.py), built from the hub config.
model = GPT2LMHeadModel(config)
# Weights come from "model.safetensors" bundled alongside this app — presumably
# exported from MODEL_REPO; verify the file matches the fetched config.
state = load_file("model.safetensors")
# NOTE(review): strict=False silently ignores missing/unexpected keys, so a
# mismatched checkpoint would load partially without any error — confirm the
# state dict actually covers the architecture.
model.load_state_dict(state, strict=False)
# Inference mode: disables dropout / switches norm layers to eval behavior.
model.eval()
# ---- TOKENIZER (CHAR LEVEL) ----
def encode(text):
    """Char-level tokenizer: map each character to an id in [0, 256).

    Code points above 255 wrap around via modulo, so the vocabulary is
    always 256. Returns a (1, len(text)) long tensor (batch of one).
    """
    token_ids = [ord(ch) % 256 for ch in text]
    return torch.tensor([token_ids], dtype=torch.long)
def decode(tokens):
    """Inverse of encode: turn an iterable of token ids back into a string."""
    chars = [chr(int(tok)) for tok in tokens]
    return "".join(chars)
# ---- GENERATION ----
@torch.no_grad()
def chat(message, history):
    """Sample a short reply from the model for a Gradio ChatInterface.

    Args:
        message: the user's latest message (plain text).
        history: prior conversation turns supplied by Gradio — currently
            unused, but required by the ChatInterface callback signature.

    Returns:
        The generated continuation, cut at its first period.
    """
    ids = encode(message)
    prompt_len = ids.shape[1]  # remember where the prompt ends
    for _ in range(32):  # fixed 32-token sampling budget per reply
        logits = model(ids).logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1)  # stochastic sampling
        ids = torch.cat([ids, next_token], dim=1)
    # Bug fix: decode only the generated continuation. The original decoded
    # the whole sequence (prompt included) and split at the first "." — so a
    # period inside the user's message made the function return a truncated
    # echo of the prompt and throw away the generation entirely.
    generated = decode(ids[0, prompt_len:])
    return generated.split(".")[0] + "."
# ---- UI ----
# Build and launch the Gradio chat UI; `chat(message, history)` matches the
# ChatInterface callback signature. launch() blocks and serves the app
# (Spaces picks up host/port from the environment).
gr.ChatInterface(
fn=chat,
title="GuageLLM",
description="A small language model trained from scratch."
).launch()