# Author: to0ony
# Commit 21430f1 (verified): "updated name"
import gc, json, torch, gradio as gr
from huggingface_hub import hf_hub_download
import tiktoken
from mingpt.model import GPT
# Run on the GPU when torch can see one; otherwise stay on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Hugging Face Hub repository that load_model() downloads config.json and the
# selected checkpoint file from.
REPO_ID = "to0ony/final-thesis-plotgen"

# Process-wide cache: the currently loaded model, the checkpoint file name it
# was loaded from, and the GPT-2 BPE tokenizer used by generate().
state = dict(
    model=None,
    model_name=None,
    enc=tiktoken.get_encoding("gpt2"),
)
def load_model(model_name):
    """Return the GPT model stored as *model_name* in ``REPO_ID``, loading it on demand.

    At most one model is cached, in ``state``. Requesting the model that is
    already cached returns it unchanged. Requesting a different one first
    releases the cached model. It then downloads ``config.json`` and the
    checkpoint from the Hub, builds a minGPT ``GPT`` from the config, loads the
    weights and moves the model to ``DEVICE``.

    Args:
        model_name: File name of the checkpoint inside ``REPO_ID``
            (e.g. ``"cmu-plots-model.pt"``).

    Returns:
        The model, in eval mode, on ``DEVICE``.

    Raises:
        Whatever ``hf_hub_download`` / ``torch.load`` / ``load_state_dict`` raise.
        The previous model has already been released by then, so the next call
        simply retries the load.
    """
    if state["model"] is not None and state["model_name"] == model_name:
        return state["model"]

    # Free the old model *before* building the new one, so two models never
    # occupy (GPU) memory at the same time.
    _release_model()

    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    mdl_path = hf_hub_download(repo_id=REPO_ID, filename=model_name)
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    model = GPT(_build_gpt_config(cfg))
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # downloaded file. That is only acceptable because REPO_ID is the author's
    # own repository. Switch to weights_only=True if the checkpoint holds only
    # tensors and plain containers.
    checkpoint = torch.load(mdl_path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    # load_state_dict copies the tensors into the model. Drop the checkpoint
    # (and any extra entries it carries) before moving the model to DEVICE.
    del checkpoint
    model.to(DEVICE)
    model.eval()

    state["model"] = model
    state["model_name"] = model_name
    return model


def _build_gpt_config(cfg):
    """Translate the repo's ``config.json`` dict into a minGPT ``GPT`` config."""
    gcfg = GPT.get_default_config()
    # minGPT takes either a named preset (model_type) or explicit sizes, not
    # both. Clear the preset so the sizes below are used.
    gcfg.model_type = None
    gcfg.vocab_size = int(cfg["vocab_size"])
    gcfg.block_size = int(cfg["block_size"])
    gcfg.n_layer = int(cfg["n_layer"])
    gcfg.n_head = int(cfg["n_head"])
    gcfg.n_embd = int(cfg["n_embd"])
    return gcfg


def _release_model():
    """Drop the cached model from ``state`` and reclaim its memory."""
    if state["model"] is None:
        return
    state["model"] = None
    state["model_name"] = None
    gc.collect()
    if DEVICE == "cuda":
        # Return cached CUDA blocks so the next model has room on small GPUs.
        torch.cuda.empty_cache()
@torch.inference_mode()
def generate(prompt, model_choice, max_new_tokens=200, temperature=0.7, top_k=50):
    """Generate text from a prompt.

    Args:
        prompt: Seed text. Special-token markup typed by the user (e.g.
            ``"<|endoftext|>"``) is encoded as ordinary text. An empty prompt
            falls back to unconditional sampling, seeded with the end-of-text
            token.
        model_choice: Checkpoint file name, forwarded to ``load_model``.
        max_new_tokens: Number of tokens to sample.
        temperature: Sampling temperature.
        top_k: Sample only from the ``top_k`` most likely tokens. ``0`` (or an
            empty value) disables the filter.

    Returns:
        The decoded prompt followed by the sampled continuation. For an empty
        prompt, only the continuation is returned.
    """
    model = load_model(model_choice)
    enc = state["enc"]

    # disallowed_special=() stops tiktoken from raising ValueError when the
    # prompt happens to contain special-token text such as "<|endoftext|>".
    prompt_ids = enc.encode(prompt or "", disallowed_special=())
    # minGPT's generate reads logits[:, -1, :], which fails on a zero-length
    # context. Seed with <|endoftext|> as minGPT's unconditional demo does, and
    # skip that seed token when decoding.
    ids = prompt_ids if prompt_ids else [enc.eot_token]
    start = 0 if prompt_ids else 1

    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    k = int(top_k or 0)
    y = model.generate(
        x,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        # torch.topk with k=0 would break sampling, so treat k <= 0 as "off".
        top_k=k if k > 0 else None,
        do_sample=True,
    )
    return enc.decode(y[0, start:].tolist())
# ---------------------------------------------------------------------------
# Gradio UI: choose a checkpoint, enter a prompt, tune sampling, generate.
# ---------------------------------------------------------------------------
with gr.Blocks(title="🎬 PlotGen") as demo:
    # The UI copy is Croatian: "Enter a prompt and generate a movie plot."
    gr.Markdown(value="## 🎬 PlotGen\nUnesi prompt i generiraj radnju filma.")

    # These strings are passed verbatim to load_model() and used as Hub file
    # names, so keep the "enchanced" spelling exactly as it is.
    model_choice = gr.Dropdown(
        label="Model",
        value="cmu-plots-model-enchanced.pt",
        choices=["cmu-plots-model.pt", "cmu-plots-model-enchanced.pt"],
    )
    prompt = gr.Textbox(
        label="Prompt",
        lines=5,
        placeholder="E.g. A young detective arrives in a coastal town...",
    )

    # Sampling controls.
    max_new_tokens = gr.Slider(minimum=32, maximum=512, value=200, step=16, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.9, step=0.1, label="Temperature")
    top_k = gr.Slider(minimum=0, maximum=100, value=50, step=5, label="Top-K (0 = off)")

    btn = gr.Button(value="Generate")
    output = gr.Textbox(label="Output", lines=15)

    # Input values reach generate() positionally, in this order:
    # (prompt, model_choice, max_new_tokens, temperature, top_k).
    btn.click(
        fn=generate,
        inputs=[prompt, model_choice, max_new_tokens, temperature, top_k],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()