# mochaV2 / app.py
# (Hugging Face Space by theguywhosucks — commit 69f32b6, "Update app.py")
import json
import torch
from transformers import GPT2LMHeadModel, GPT2Config
import gradio as gr
# -----------------------------
# Load tokenizer manually
# -----------------------------
# Character-level vocab: maps each character to an integer id.
# Read as UTF-8 explicitly so non-ASCII vocab entries load the same on every platform.
with open("vocab.json", "r", encoding="utf-8") as f:
    stoi = json.load(f)  # str -> id
itos = {i: s for s, i in stoi.items()}  # id -> str (inverse mapping)
def encode(text):
    """Map a string to a list of token ids; unknown characters become id 0."""
    ids = []
    for ch in text:
        ids.append(stoi.get(ch, 0))
    return ids
def decode(ids):
    """Map a list of token ids back to a string; unknown ids are skipped."""
    chars = (itos.get(token_id, "") for token_id in ids)
    return "".join(chars)
# -----------------------------
# Load model manually
# -----------------------------
# Build the model from the repo's config.json and load local weights.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

config = GPT2Config(
    vocab_size=cfg["vocab_size"],
    n_positions=cfg["n_positions"],
    n_ctx=cfg["n_ctx"],
    n_embd=cfg["n_embd"],
    n_layer=cfg["n_layer"],
    n_head=cfg["n_head"],
    activation_function=cfg["activation_function"],
)

model = GPT2LMHeadModel(config)
# weights_only=True: the checkpoint is a plain state_dict, so refuse to
# unpickle arbitrary objects (untrusted .bin files can execute code otherwise).
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference only — disable dropout
# -----------------------------
# Generation
# -----------------------------
def complete_sentence(prompt, max_new_tokens=50):
    """Greedy-complete `prompt` with up to `max_new_tokens` extra tokens.

    Args:
        prompt: input text, encoded character-by-character.
        max_new_tokens: generation budget; Gradio's Slider passes a float,
            so it is coerced to int before reaching `generate`.

    Returns:
        The decoded text (prompt plus the generated continuation).
    """
    ids = torch.tensor([encode(prompt)])
    with torch.no_grad():  # no gradients needed for generation
        outputs = model.generate(
            ids,
            max_new_tokens=int(max_new_tokens),
            # Silences the "no pad token" warning; NOTE(review): config.eos_token_id
            # is GPT2Config's default (50256) — confirm it is valid for this vocab.
            pad_token_id=config.eos_token_id,
        )
    return decode(outputs[0].tolist())
# -----------------------------
# Gradio app
# -----------------------------
# Wire the completion function into a simple Gradio UI and start serving.
demo = gr.Interface(
    fn=complete_sentence,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(10, 200, value=50, step=10, label="Max tokens"),
    ],
    outputs=gr.Textbox(label="Completed Text"),
)
demo.launch()