Spaces:
Sleeping
Sleeping
import gc, json, torch, gradio as gr
from huggingface_hub import hf_hub_download
import tiktoken
from mingpt.model import GPT

# Inference device: prefer GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub repo that hosts config.json and the model checkpoints.
REPO_ID = "to0ony/final-thesis-plotgen"
# Module-level cache: at most one loaded model at a time, plus the shared
# GPT-2 BPE tokenizer (tiktoken) used for both encoding and decoding.
state = {"model": None, "model_name": None, "enc": tiktoken.get_encoding("gpt2")}
def load_model(model_name):
    """Download (if needed) and load a GPT checkpoint from the Hub.

    Keeps a single model cached in the module-level ``state`` dict; requesting
    the already-loaded checkpoint returns it immediately.

    Args:
        model_name: Checkpoint filename inside REPO_ID
            (e.g. "cmu-plots-model.pt").

    Returns:
        The loaded ``GPT`` model, moved to DEVICE and set to eval mode.
    """
    if state["model"] is not None and state["model_name"] == model_name:
        return state["model"]

    # Release any previously loaded model BEFORE loading the new one,
    # otherwise the old weights stay resident (on CUDA this can exhaust
    # GPU memory when the user switches checkpoints).
    if state["model"] is not None:
        state["model"] = None
        state["model_name"] = None
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    mdl_path = hf_hub_download(repo_id=REPO_ID, filename=model_name)
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    gcfg = GPT.get_default_config()
    gcfg.model_type = None  # size comes from explicit fields below, not a preset name
    gcfg.vocab_size = int(cfg["vocab_size"])
    gcfg.block_size = int(cfg["block_size"])
    gcfg.n_layer = int(cfg["n_layer"])
    gcfg.n_head = int(cfg["n_head"])
    gcfg.n_embd = int(cfg["n_embd"])

    model = GPT(gcfg)
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # downloaded file. Acceptable only because REPO_ID is a trusted first-party
    # repo; prefer weights_only=True if the checkpoint format permits it.
    sd = torch.load(mdl_path, map_location="cpu", weights_only=False)
    model.load_state_dict(sd["model_state_dict"], strict=True)
    model.to(DEVICE)
    model.eval()

    state["model"] = model
    state["model_name"] = model_name
    return model
def generate(prompt, model_choice, max_new_tokens=200, temperature=0.7, top_k=50):
    """Generate a movie-plot continuation of *prompt* with the chosen checkpoint.

    Args:
        prompt: Seed text to continue.
        model_choice: Checkpoint filename passed through to ``load_model``.
        max_new_tokens: Number of tokens to sample.
        temperature: Softmax temperature for sampling.
        top_k: Top-k cutoff; values <= 0 disable top-k filtering.

    Returns:
        The decoded text, including the original prompt.
    """
    model = load_model(model_choice)
    tokenizer = state["enc"]

    # Encode the prompt as a 1 x T batch of token ids on the inference device.
    prompt_ids = tokenizer.encode(prompt)
    x = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)

    # A non-positive slider value means "top-k off" for minGPT's sampler.
    effective_top_k = int(top_k) if top_k > 0 else None

    y = model.generate(
        x,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=effective_top_k,
        do_sample=True,
    )
    return tokenizer.decode(y[0].tolist())
# Gradio UI: wires the widgets below into the generate() handler.
with gr.Blocks(title="🎬 PlotGen") as demo:
    gr.Markdown("## 🎬 PlotGen\nUnesi prompt i generiraj radnju filma.")
    # "enchanced" [sic] — these strings must match the actual checkpoint
    # filenames in the Hub repo; do not "correct" the spelling here.
    model_choice = gr.Dropdown(
        choices=["cmu-plots-model.pt", "cmu-plots-model-enchanced.pt"],
        value="cmu-plots-model-enchanced.pt",
        label="Model"
    )
    prompt = gr.Textbox(label="Prompt", lines=5, placeholder="E.g. A young detective arrives in a coastal town...")
    # Sampling controls; values are forwarded positionally to generate().
    max_new_tokens = gr.Slider(32, 512, value=200, step=16, label="Max new tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
    top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K (0 = off)")
    btn = gr.Button("Generate")
    output = gr.Textbox(label="Output", lines=15)
    # Input order must match generate()'s parameter order.
    btn.click(generate, [prompt, model_choice, max_new_tokens, temperature, top_k], output)

if __name__ == "__main__":
    demo.launch()