File size: 2,627 Bytes
aff8dfb
8fccddb
 
 
aff8dfb
8fccddb
 
 
 
79f1c31
8fccddb
79f1c31
 
8fccddb
 
 
79f1c31
8fccddb
 
 
 
 
08585a9
 
 
 
 
 
8fccddb
 
b4fee78
520c6fd
8fccddb
 
 
 
79f1c31
8fccddb
 
aff8dfb
8fccddb
79f1c31
656652b
79f1c31
8fccddb
656652b
8fccddb
656652b
747470f
 
 
 
8a6a35c
 
656652b
 
 
8fccddb
79f1c31
8fccddb
21430f1
 
8fccddb
79f1c31
 
56fc36d
79f1c31
 
8fccddb
 
56fc36d
8fccddb
 
656652b
8fccddb
79f1c31
8fccddb
 
aff8dfb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gc, json, torch, gradio as gr
from huggingface_hub import hf_hub_download
import tiktoken

from mingpt.model import GPT 

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Hugging Face Hub repository holding config.json and the .pt checkpoints.
REPO_ID = "to0ony/final-thesis-plotgen"

# Process-wide cache: the currently loaded model, the checkpoint file it came
# from, and the GPT-2 BPE tokenizer used to encode prompts / decode outputs.
state = dict(
    model=None,
    model_name=None,
    enc=tiktoken.get_encoding("gpt2"),
)

def _release_model():
    """Drop the cached model from ``state`` and hand its memory back.

    Used before a different checkpoint is built so that the old and the new
    weights are never resident at the same time.
    """
    if state["model"] is None:
        return
    state["model"] = None
    state["model_name"] = None
    # Collect garbage so the released tensors are actually freed before the
    # CUDA cache is emptied.
    gc.collect()
    if torch.cuda.is_available():
        # Return cached CUDA blocks so the next model can reuse that memory.
        torch.cuda.empty_cache()


def load_model(model_name):
    """Return the GPT model stored as ``model_name`` in ``REPO_ID``.

    Only one model is kept at a time. If ``model_name`` is the model already
    cached in ``state`` it is returned directly. Otherwise the architecture
    config and the checkpoint are downloaded from the Hub, the previously
    cached model is released, and the new model is built, loaded, moved to
    ``DEVICE`` and put in eval mode.

    Args:
        model_name: File name of the checkpoint inside ``REPO_ID``.

    Returns:
        The loaded ``GPT`` instance, also stored in ``state["model"]``.
    """
    if state["model"] is not None and state["model_name"] == model_name:
        return state["model"]

    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    mdl_path = hf_hub_download(repo_id=REPO_ID, filename=model_name)

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    gcfg = GPT.get_default_config()
    # No named preset; the architecture sizes come from config.json below.
    gcfg.model_type = None
    gcfg.vocab_size = int(cfg["vocab_size"])
    gcfg.block_size = int(cfg["block_size"])
    gcfg.n_layer    = int(cfg["n_layer"])
    gcfg.n_head     = int(cfg["n_head"])
    gcfg.n_embd     = int(cfg["n_embd"])

    # Downloads and config parsing succeeded; only now free the previous model.
    # Doing it here keeps the old model usable if a download fails, and ensures
    # the old and new weights never occupy memory together.
    _release_model()

    model = GPT(gcfg)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. This is acceptable only because the file comes from the
    # fixed REPO_ID; never point this at an untrusted repository.
    ckpt = torch.load(mdl_path, map_location="cpu", weights_only=False)
    # Training checkpoints wrap the weights under "model_state_dict"; a bare
    # state dict (torch.save(model.state_dict())) is accepted as well.
    model.load_state_dict(ckpt.get("model_state_dict", ckpt), strict=True)
    model.to(DEVICE)
    model.eval()

    state["model"] = model
    state["model_name"] = model_name
    return model


@torch.inference_mode()
def generate(prompt, model_choice, max_new_tokens=200, temperature=0.7, top_k=50):
    """Generate text from a prompt.

    Args:
        prompt: Text to continue; must be non-empty.
        model_choice: Checkpoint file name, passed to ``load_model``.
        max_new_tokens: Number of tokens to sample after the prompt.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Sample only from the ``top_k`` most likely tokens; ``0``,
            ``None`` or any value whose integer part is <= 0 disables this.

    Returns:
        The decoded first sequence returned by ``model.generate``.

    Raises:
        gr.Error: If ``prompt`` is empty (shown as an error in the Gradio UI).
    """
    # An empty prompt encodes to zero tokens, leaving the model no context to
    # condition on; report it in the UI instead of failing inside the model.
    if not prompt:
        raise gr.Error("Please enter a prompt.")

    model = load_model(model_choice)
    enc = state["enc"]

    # disallowed_special=() encodes user text such as "<|endoftext|>" as plain
    # characters instead of making tiktoken raise ValueError.
    ids = enc.encode(prompt, disallowed_special=())
    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)

    # Gradio sliders may deliver floats; convert before the positivity check so
    # a value like 0.5 disables top-k rather than requesting k=0.
    k = int(top_k) if top_k else 0

    y = model.generate(
        x,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=k if k > 0 else None,
        do_sample=True,
    )

    return enc.decode(y[0].tolist())


# Gradio UI: one column of controls whose values are fed to generate().
with gr.Blocks(title="🎬 PlotGen") as demo:
    # Croatian subtitle: "Enter a prompt and generate a movie plot."
    gr.Markdown("## 🎬 PlotGen\nUnesi prompt i generiraj radnju filma.")

    # Checkpoint files in REPO_ID; the file names must match the Hub exactly.
    model_choice = gr.Dropdown(
        label="Model",
        choices=["cmu-plots-model.pt", "cmu-plots-model-enchanced.pt"],
        value="cmu-plots-model-enchanced.pt",
    )
    prompt = gr.Textbox(
        label="Prompt",
        lines=5,
        placeholder="E.g. A young detective arrives in a coastal town...",
    )

    # Sampling controls.
    max_new_tokens = gr.Slider(
        minimum=32, maximum=512, value=200, step=16, label="Max new tokens"
    )
    temperature = gr.Slider(
        minimum=0.1, maximum=1.5, value=0.9, step=0.1, label="Temperature"
    )
    top_k = gr.Slider(
        minimum=0, maximum=100, value=50, step=5, label="Top-K (0 = off)"
    )

    btn = gr.Button("Generate")
    output = gr.Textbox(label="Output", lines=15)

    # Order must match generate(prompt, model_choice, max_new_tokens,
    # temperature, top_k).
    generate_inputs = [prompt, model_choice, max_new_tokens, temperature, top_k]
    btn.click(fn=generate, inputs=generate_inputs, outputs=output)

if __name__ == "__main__":
    demo.launch()