import gc, json, torch, gradio as gr
from huggingface_hub import hf_hub_download
import tiktoken
from mingpt.model import GPT
# Run on GPU when available; the model and all input tensors are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub repository holding config.json and the model checkpoints.
REPO_ID = "to0ony/final-thesis-plotgen"
# Module-level cache: the currently loaded model, the checkpoint filename it
# came from, and a shared GPT-2 BPE tokenizer (tiktoken) reused by generate().
state = {"model": None, "model_name": None, "enc": tiktoken.get_encoding("gpt2")}
def load_model(model_name):
    """Return the GPT model for *model_name*, loading and caching it on demand.

    Downloads the shared ``config.json`` and the requested checkpoint from
    REPO_ID, builds a minGPT model from the config, loads the weights, and
    caches the result in ``state`` so repeated calls with the same name are
    free.

    Args:
        model_name: Checkpoint filename inside REPO_ID
            (e.g. ``"cmu-plots-model.pt"``).

    Returns:
        The ``GPT`` model in eval mode on DEVICE.
    """
    if state["model"] is not None and state["model_name"] == model_name:
        return state["model"]

    # Bug fix: when switching checkpoints, release the previously cached
    # model *before* constructing the new one, otherwise two full models are
    # resident at once (CUDA OOM risk on small Spaces GPUs). This is also
    # what the top-of-file `gc` import is for — it was previously unused.
    if state["model"] is not None:
        state["model"] = None
        state["model_name"] = None
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    mdl_path = hf_hub_download(repo_id=REPO_ID, filename=model_name)
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    gcfg = GPT.get_default_config()
    gcfg.model_type = None  # dimensions come from config.json, not a named preset
    gcfg.vocab_size = int(cfg["vocab_size"])
    gcfg.block_size = int(cfg["block_size"])
    gcfg.n_layer = int(cfg["n_layer"])
    gcfg.n_head = int(cfg["n_head"])
    gcfg.n_embd = int(cfg["n_embd"])

    model = GPT(gcfg)
    # NOTE(review): weights_only=False unpickles arbitrary objects from the
    # checkpoint. The repo is fixed and self-owned, but prefer
    # weights_only=True if the checkpoint format permits it.
    sd = torch.load(mdl_path, map_location="cpu", weights_only=False)
    model.load_state_dict(sd["model_state_dict"], strict=True)
    model.to(DEVICE)
    model.eval()

    state["model"] = model
    state["model_name"] = model_name
    return model
@torch.inference_mode()
def generate(prompt, model_choice, max_new_tokens=200, temperature=0.7, top_k=50):
    """Generate a text continuation of *prompt* with the selected checkpoint.

    Encodes the prompt with the shared GPT-2 tokenizer, samples up to
    ``max_new_tokens`` tokens from the model, and returns the decoded text
    (prompt included). ``top_k <= 0`` disables top-k filtering.
    """
    mdl = load_model(model_choice)
    tokenizer = state["enc"]

    token_ids = tokenizer.encode(prompt)
    context = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
    sample_top_k = int(top_k) if top_k > 0 else None

    sampled = mdl.generate(
        context,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=sample_top_k,
        do_sample=True,
    )
    return tokenizer.decode(sampled[0].tolist())
# Gradio UI: builds the widgets and wires them to generate() on button click.
with gr.Blocks(title="🎬 PlotGen") as demo:
    gr.Markdown("## 🎬 PlotGen\nUnesi prompt i generiraj radnju filma.")
    # Checkpoint filenames inside REPO_ID. "enchanced" matches the actual
    # file name stored on the Hub, so the typo must not be corrected here.
    model_choice = gr.Dropdown(
        choices=["cmu-plots-model.pt", "cmu-plots-model-enchanced.pt"],
        value="cmu-plots-model-enchanced.pt",
        label="Model"
    )
    prompt = gr.Textbox(label="Prompt", lines=5, placeholder="E.g. A young detective arrives in a coastal town...")
    max_new_tokens = gr.Slider(32, 512, value=200, step=16, label="Max new tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
    # Slider value 0 disables top-k filtering (generate() maps 0 -> None).
    top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K (0 = off)")
    btn = gr.Button("Generate")
    output = gr.Textbox(label="Output", lines=15)
    # Inputs map positionally onto generate(prompt, model_choice,
    # max_new_tokens, temperature, top_k); the return value fills `output`.
    btn.click(generate, [prompt, model_choice, max_new_tokens, temperature, top_k], output)

if __name__ == "__main__":
    demo.launch()