# manshu-init / app.py
# Author: Yuchen Li
# Initial Space app (commit f132163)
import sys
from pathlib import Path
import gradio as gr
import torch
import yaml
from huggingface_hub import snapshot_download
# Hugging Face Hub repository that ships the model code, config, vocab, and weights.
MODEL_REPO_ID = "LiManshu/nextShakespeare"
# Run on GPU when available; the model and decoding are pinned to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lazily-initialized module-level singletons, populated by _ensure_loaded().
_MODEL = None
_TOKENIZER = None
_MODEL_CFG = None
def _load_yaml(path: Path):
    """Return the deserialized contents of the YAML file at *path*."""
    with open(path, encoding="utf-8") as fh:
        return yaml.safe_load(fh)
def _ensure_loaded():
    """Download the model repository on first use and populate the
    module-level _MODEL, _TOKENIZER, and _MODEL_CFG globals.

    Subsequent calls return immediately once the model and tokenizer exist.
    """
    global _MODEL, _TOKENIZER, _MODEL_CFG
    if _MODEL is not None and _TOKENIZER is not None:
        return
    # Fetch only the files needed for inference from the Hub.
    local_dir = snapshot_download(
        repo_id=MODEL_REPO_ID,
        repo_type="model",
        allow_patterns=[
            "llm/**",
            "configs/model.yaml",
            "data/vocab/char_vocab.json",
            "checkpoints/best_model.pt",
        ],
    )
    repo_root = Path(local_dir)
    # The repo ships its own `llm` package; make it importable before the
    # deferred imports below.
    sys.path.insert(0, str(repo_root))
    from llm.data.tokenizer import CharTokenizer
    from llm.inference.generate import greedy_decode, sample_decode
    from llm.model.transformer import Transformer
    from llm.utils.checkpoint import load_model_only
    model_cfg = _load_yaml(repo_root / "configs" / "model.yaml")
    tokenizer = CharTokenizer(vocab_path=str(repo_root / "data" / "vocab" / "char_vocab.json"))
    # The effective vocab size comes from the tokenizer, overriding the config.
    model_cfg["vocab_size"] = tokenizer.vocab_size
    model = Transformer(model_cfg)
    load_model_only(model, str(repo_root / "checkpoints" / "best_model.pt"))
    model.to(DEVICE)
    model.eval()
    _MODEL = model
    _TOKENIZER = tokenizer
    # Small summary surfaced in the UI's "Model Info" box.
    _MODEL_CFG = {
        "hidden_size": model_cfg.get("hidden_size"),
        "num_hidden_layers": model_cfg.get("num_hidden_layers"),
        "vocab_size": tokenizer.vocab_size,
        "device": str(DEVICE),
    }
    # Stash the decode functions as attributes on the model instance so that
    # generate() can call them without re-importing the downloaded package.
    _MODEL.greedy_decode = greedy_decode
    _MODEL.sample_decode = sample_decode
def generate(prompt, max_length, temperature, top_k, top_p):
    """Generate text from the model for the given prompt and sampling knobs.

    Returns a (generated_text, model_info) tuple for the two output boxes.
    """
    _ensure_loaded()
    encoded = _TOKENIZER.encode(prompt or "")
    if not encoded:
        # Empty prompt: seed decoding with a single token id.
        encoded = [0]
    batch = torch.tensor([encoded], dtype=torch.long)
    steps = int(max_length)
    temp = float(temperature)
    with torch.no_grad():
        if temp == 0.0:
            # Zero temperature selects deterministic greedy decoding.
            token_ids = _MODEL.greedy_decode(
                _MODEL,
                batch,
                max_length=steps,
                device=DEVICE,
            )
        else:
            k = int(top_k)
            p = float(top_p)
            # Non-positive top_k / top_p disable the respective filtering.
            token_ids = _MODEL.sample_decode(
                _MODEL,
                batch,
                max_length=steps,
                temperature=temp,
                top_k=k if k > 0 else None,
                top_p=p if p > 0 else None,
                device=DEVICE,
            )
    info = (
        f"repo={MODEL_REPO_ID} | device={_MODEL_CFG['device']} | "
        f"layers={_MODEL_CFG['num_hidden_layers']} | hidden={_MODEL_CFG['hidden_size']}"
    )
    return _TOKENIZER.decode(token_ids[0]), info
# Gradio UI: prompt plus sampling controls in, generated text plus model info out.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", value="First Citizen:\n", lines=8),
        # Number of tokens to generate.
        gr.Slider(minimum=1, maximum=400, value=200, step=1, label="max_length"),
        # temperature == 0 selects greedy decoding in generate().
        gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.1, label="temperature"),
        # top_k == 0 disables top-k filtering (passed as None).
        gr.Slider(minimum=0, maximum=200, value=50, step=1, label="top_k"),
        # top_p == 0 disables nucleus filtering (passed as None).
        gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="top_p"),
    ],
    outputs=[
        gr.Textbox(label="Generated Text", lines=14),
        gr.Textbox(label="Model Info"),
    ],
    title="manshu-init",
    description="Online demo for LiManshu/nextShakespeare",
)
# Launch the Gradio server when run as a script (e.g. on a Hugging Face Space).
if __name__ == "__main__":
    demo.launch()