to0ony commited on
Commit
79f1c31
·
1 Parent(s): 8a6a35c

implemented model choice dropdown

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -7,21 +7,20 @@ from mingpt.model import GPT
7
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
8
  REPO_ID = "to0ony/final-thesis-plotgen"
9
 
10
- state = {"model": None, "enc": tiktoken.get_encoding("gpt2")}
11
 
12
- def load_model():
13
- if state["model"] is not None:
14
  return state["model"]
15
 
16
  cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
17
- mdl_path = hf_hub_download(repo_id=REPO_ID, filename="cmu-plots-model.pt")
18
 
19
  with open(cfg_path, "r", encoding="utf-8") as f:
20
  cfg = json.load(f)
21
 
22
  gcfg = GPT.get_default_config()
23
  gcfg.model_type = None
24
-
25
  gcfg.vocab_size = int(cfg["vocab_size"])
26
  gcfg.block_size = int(cfg["block_size"])
27
  gcfg.n_layer = int(cfg["n_layer"])
@@ -35,13 +34,14 @@ def load_model():
35
  model.eval()
36
 
37
  state["model"] = model
 
38
  return model
39
 
40
 
41
  @torch.inference_mode()
42
- def generate(prompt, max_new_tokens=200, temperature=0.7, top_k=50):
43
  """Generiranje teksta iz prompta"""
44
- model = load_model()
45
  enc = state["enc"]
46
 
47
  x = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=DEVICE)
@@ -56,10 +56,16 @@ def generate(prompt, max_new_tokens=200, temperature=0.7, top_k=50):
56
 
57
  return enc.decode(y[0].tolist())
58
 
 
59
  # Gradio UI
60
  with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
61
  gr.Markdown("## 🎬 Film Plot Generator\nUnesi prompt i generiraj radnju filma.")
62
 
 
 
 
 
 
63
  prompt = gr.Textbox(label="Prompt", lines=5, placeholder="E.g. A young detective arrives in a coastal town...")
64
  max_new_tokens = gr.Slider(32, 512, value=200, step=16, label="Max new tokens")
65
  temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
@@ -67,7 +73,7 @@ with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
67
  btn = gr.Button("Generate")
68
  output = gr.Textbox(label="Output", lines=15)
69
 
70
- btn.click(generate, [prompt, max_new_tokens, temperature, top_k], output)
71
 
72
  if __name__ == "__main__":
73
  demo.launch()
 
7
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
8
  REPO_ID = "to0ony/final-thesis-plotgen"
9
 
10
+ state = {"model": None, "model_name": None, "enc": tiktoken.get_encoding("gpt2")}
11
 
12
+ def load_model(model_name):
13
+ if state["model"] is not None and state["model_name"] == model_name:
14
  return state["model"]
15
 
16
  cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
17
+ mdl_path = hf_hub_download(repo_id=REPO_ID, filename=model_name)
18
 
19
  with open(cfg_path, "r", encoding="utf-8") as f:
20
  cfg = json.load(f)
21
 
22
  gcfg = GPT.get_default_config()
23
  gcfg.model_type = None
 
24
  gcfg.vocab_size = int(cfg["vocab_size"])
25
  gcfg.block_size = int(cfg["block_size"])
26
  gcfg.n_layer = int(cfg["n_layer"])
 
34
  model.eval()
35
 
36
  state["model"] = model
37
+ state["model_name"] = model_name
38
  return model
39
 
40
 
41
  @torch.inference_mode()
42
+ def generate(prompt, model_choice, max_new_tokens=200, temperature=0.7, top_k=50):
43
  """Generiranje teksta iz prompta"""
44
+ model = load_model(model_choice)
45
  enc = state["enc"]
46
 
47
  x = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=DEVICE)
 
56
 
57
  return enc.decode(y[0].tolist())
58
 
59
+
60
  # Gradio UI
61
  with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
62
  gr.Markdown("## 🎬 Film Plot Generator\nUnesi prompt i generiraj radnju filma.")
63
 
64
+ model_choice = gr.Dropdown(
65
+ choices=["cmu-plots-model.pt", "cmu-plots-model-enchanced.pt"],
66
+ value="cmu-plots-model.pt",
67
+ label="Model"
68
+ )
69
  prompt = gr.Textbox(label="Prompt", lines=5, placeholder="E.g. A young detective arrives in a coastal town...")
70
  max_new_tokens = gr.Slider(32, 512, value=200, step=16, label="Max new tokens")
71
  temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
 
73
  btn = gr.Button("Generate")
74
  output = gr.Textbox(label="Output", lines=15)
75
 
76
+ btn.click(generate, [prompt, model_choice, max_new_tokens, temperature, top_k], output)
77
 
78
  if __name__ == "__main__":
79
  demo.launch()