sidharthg committed on
Commit
471660d
·
verified ·
1 Parent(s): 723f068

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -11
app.py CHANGED
@@ -98,23 +98,17 @@ class GPT(nn.Module):
98
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
99
  return logits, loss
100
 
101
# Load model
# Prefer GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Loading model on {device}...")

config = GPTConfig()
model = GPT(config)
model_path = os.path.join("models", "best_model.pt")
# The file is loaded as a bare state_dict, so restrict unpickling to
# tensor data with weights_only=True; weights_only=False would allow
# arbitrary code execution from an untrusted .pt file.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.to(device)
model.eval()  # inference mode: disables dropout, etc.

enc = tiktoken.get_encoding('gpt2')

print(f"✅ Model loaded!")
118
 
119
  def generate(prompt: str, max_new_tokens: int = 30, top_k: int = 50, temperature: float = 1.0):
120
  tokens = enc.encode(prompt)
@@ -134,15 +128,48 @@ def generate(prompt: str, max_new_tokens: int = 30, top_k: int = 50, temperature
134
  out_tokens = x[0].tolist()
135
  return enc.decode(out_tokens)
136
 
 
 
 
 
 
 
 
 
 
137
# Gradio UI: prompt in, generated text out, with sampling controls.
with gr.Blocks() as demo:
    gr.Markdown("# GPT2-Space")

    # Prompt input and generated-text output, side by side.
    with gr.Row():
        inp = gr.Textbox(label="Prompt", lines=3, placeholder="Enter prompt here...")
        out = gr.Textbox(label="Generated", lines=10)

    # Sampling hyperparameters.
    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Max new tokens")
        topk = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Top-k")
        temp = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01, label="Temperature")

    btn = gr.Button("Generate")
    btn.click(fn=generate, inputs=[inp, max_tokens, topk, temp], outputs=out)
148
 
 
98
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
99
  return logits, loss
100
 
101
# Load model and tokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = GPTConfig()
model = GPT(config)
model_path = os.path.join("models", "best_model.pt")
# Pass weights_only explicitly: the default differs across torch versions
# (False before 2.6), and a bare state_dict only needs tensor data —
# weights_only=True avoids arbitrary-code-execution risk from the pickle.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.to(device)
model.eval()  # inference mode: disables dropout, etc.

enc = tiktoken.get_encoding('gpt2')  # GPT-2 BPE tokenizer

print(f"✅ Model loaded on {device}!")
 
112
 
113
  def generate(prompt: str, max_new_tokens: int = 30, top_k: int = 50, temperature: float = 1.0):
114
  tokens = enc.encode(prompt)
 
128
  out_tokens = x[0].tolist()
129
  return enc.decode(out_tokens)
130
 
131
# Prompts offered in the example dropdown below.
example_prompts = [
    "To be, or not to be, that is the question:",
    "O Romeo, Romeo! wherefore art thou Romeo?",
    "Once more unto the breach, dear friends, once more;",
    "All the world's a stage,",
    "The lady doth protest too much, methinks."
]

# Gradio UI: prompt in, generated text out, plus an example picker,
# a clear button, and sampling controls.
with gr.Blocks() as demo:
    gr.Markdown("# GPT-2 (124M) Shakespeare Text Generator")
    gr.Markdown(
        "GPT-2 (124M) model trained from scratch on Shakespeare's works. "
        "Start with a prompt and generate Shakespearean-style text!"
    )

    # Prompt input and generated-text output, side by side.
    with gr.Row():
        inp = gr.Textbox(label="Prompt", lines=3, placeholder="Enter prompt here...")
        out = gr.Textbox(label="Generated Text", lines=10)

    # Sampling hyperparameters.
    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Max new tokens")
        topk = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Top-k")
        temp = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01, label="Temperature")

    # Example picker and a button to blank out the output box.
    with gr.Row():
        example_dropdown = gr.Dropdown(
            label="Choose example prompt",
            choices=example_prompts,
            interactive=True
        )
        clear_btn = gr.Button("Clear output")

    def use_example(selected):
        """Copy the chosen example prompt into the prompt box."""
        return selected

    def clear_output():
        """Return an empty string to clear the generated-text box."""
        return ""

    example_dropdown.change(fn=use_example, inputs=example_dropdown, outputs=inp)
    clear_btn.click(fn=clear_output, inputs=[], outputs=out)

    btn = gr.Button("Generate")
    btn.click(fn=generate, inputs=[inp, max_tokens, topk, temp], outputs=out)
175