import torch
import gradio as gr
from transformers import GPT2Config
from safetensors.torch import load_file

from model import GPT2LMHeadModel

# ---- LOAD YOUR MODEL ----
MODEL_REPO = "Hai929/The_GuageLLM_12M"

config = GPT2Config.from_pretrained(MODEL_REPO)
model = GPT2LMHeadModel(config)

state = load_file("model.safetensors")
model.load_state_dict(state, strict=False)
model.eval()

# ---- TOKENIZER (CHAR LEVEL) ----
def encode(text):
    # Map each character to its byte value (mod 256) as a token id; batch dim of 1.
    return torch.tensor([[ord(c) % 256 for c in text]], dtype=torch.long)

def decode(tokens):
    return "".join(chr(int(t)) for t in tokens)

# ---- GENERATION ----
@torch.no_grad()
def chat(message, history):
    ids = encode(message)
    # Autoregressively sample up to 32 new tokens from the model's softmax distribution.
    for _ in range(32):
        logits = model(ids).logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1)
        ids = torch.cat([ids, next_token], dim=1)
    text = decode(ids[0])
    # Return everything up to (and including) the first period.
    return text.split(".")[0] + "."

# ---- UI ----
gr.ChatInterface(
    fn=chat,
    title="GuageLLM",
    description="A small language model trained from scratch."
).launch()