Spaces: Runtime error
| import torch | |
| import gradio as gr | |
| from transformers import GPT2Config | |
| from safetensors.torch import load_file | |
| from model import GPT2LMHeadModel | |
# ---- LOAD YOUR MODEL ----
# Hub repo that holds the model's GPT2-style config (architecture only).
MODEL_REPO = "Hai929/The_GuageLLM_12M"
config = GPT2Config.from_pretrained(MODEL_REPO)
# Custom GPT2LMHeadModel from the local model.py, built from that config.
model = GPT2LMHeadModel(config)
# Weights are read from a local safetensors file in the Space's working dir.
state = load_file("model.safetensors")
# NOTE(review): strict=False silently skips missing/unexpected keys — if the
# checkpoint doesn't match the architecture, the model runs with random
# weights for those layers. Verify the state dict actually matches.
model.load_state_dict(state, strict=False)
# Inference mode: disables dropout etc. for generation.
model.eval()
# ---- TOKENIZER (CHAR LEVEL) ----
def encode(text):
    """Char-level tokenizer: each character becomes its code point mod 256.

    Returns a (1, len(text)) long tensor — a single-example batch of ids.
    """
    ids = [ord(ch) % 256 for ch in text]
    return torch.tensor([ids], dtype=torch.long)
| def decode(tokens): | |
| return "".join(chr(int(t)) for t in tokens) | |
# ---- GENERATION ----
def chat(message, history):
    """Gradio chat callback: sample up to 32 tokens after the prompt.

    Args:
        message: the user's input string (tokenized char-by-char).
        history: prior chat turns, supplied by ChatInterface (unused).

    Returns:
        The decoded text (prompt echo + continuation), truncated at the
        first period, with a period re-appended.
    """
    # Fix: an empty prompt would yield a (1, 0) sequence and crash the
    # model's forward pass — return the trivial response instead.
    if not message:
        return "."
    ids = encode(message)
    # Fix: run sampling under no_grad — the original recorded an autograd
    # graph across all 32 steps on every request, wasting memory for
    # gradients that are never used at inference time.
    with torch.no_grad():
        for _ in range(32):
            # Next-token distribution from the last position only.
            logits = model(ids).logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            ids = torch.cat([ids, next_token], dim=1)
    text = decode(ids[0])
    # Keep only the first "sentence"; note the echoed prompt is included.
    return text.split(".")[0] + "."
# ---- UI ----
# Build and start the Gradio chat UI at import time (Spaces runs this file
# directly). ChatInterface calls chat(message, history) for each user turn.
gr.ChatInterface(
    fn=chat,
    title="GuageLLM",
    description="A small language model trained from scratch."
).launch()