Upload app.py

app.py CHANGED
@@ -98,23 +98,17 @@ class GPT(nn.Module):
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
         return logits, loss
 
-# Load model
+# Load model and tokenizer
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print(f"Loading model on {device}...")
 config = GPTConfig()
 model = GPT(config)
 model_path = os.path.join("models", "best_model.pt")
-
-#model.load_state_dict(checkpoint['model_state_dict'])
-
-model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
+model.load_state_dict(torch.load(model_path, map_location=device))
 model.to(device)
 model.eval()
-
 enc = tiktoken.get_encoding('gpt2')
 
-print(f"✅ Model loaded!")
-
+print(f"✅ Model loaded on {device}!")
 
 def generate(prompt: str, max_new_tokens: int = 30, top_k: int = 50, temperature: float = 1.0):
     tokens = enc.encode(prompt)
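The load call above assumes models/best_model.pt holds a plain, tensor-only state_dict (the diff does not show how it was saved). A minimal self-contained sketch of the same load pattern under that assumption, where TinyModel and tiny.pt are hypothetical stand-ins; with a tensor-only checkpoint, torch.load's stricter weights_only=True mode also works:

import torch
import torch.nn as nn

# Hypothetical stand-in for GPT(config); any nn.Module with matching
# parameter names follows the same save/load pattern.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TinyModel()
torch.save(model.state_dict(), "tiny.pt")  # a pure-tensor checkpoint (assumption for best_model.pt)

# weights_only=True refuses to unpickle arbitrary Python objects, so it
# only succeeds on plain state_dict files like the one saved above.
state_dict = torch.load("tiny.pt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()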
@@ -134,15 +128,48 @@ def generate(prompt: str, max_new_tokens: int = 30, top_k: int = 50, temperature
     out_tokens = x[0].tolist()
     return enc.decode(out_tokens)
 
+# Example prompts for dropdown
+example_prompts = [
+    "To be, or not to be, that is the question:",
+    "O Romeo, Romeo! wherefore art thou Romeo?",
+    "Once more unto the breach, dear friends, once more;",
+    "All the world's a stage,",
+    "The lady doth protest too much, methinks."
+]
+
 with gr.Blocks() as demo:
-    gr.Markdown("#
+    gr.Markdown("# GPT-2 (124M) Shakespeare Text Generator")
+    gr.Markdown(
+        "GPT-2 (124M) model trained from scratch on Shakespeare's works. "
+        "Start with a prompt and generate Shakespearean-style text!"
+    )
+
     with gr.Row():
         inp = gr.Textbox(lines=3, placeholder="Enter prompt here...", label="Prompt")
-        out = gr.Textbox(lines=10, label="Generated")
+        out = gr.Textbox(lines=10, label="Generated Text")
+
     with gr.Row():
         max_tokens = gr.Slider(1, 200, value=30, step=1, label="Max new tokens")
         topk = gr.Slider(1, 200, value=50, step=1, label="Top-k")
         temp = gr.Slider(0.01, 2.0, value=1.0, step=0.01, label="Temperature")
+
+    with gr.Row():
+        example_dropdown = gr.Dropdown(
+            choices=example_prompts,
+            label="Choose example prompt",
+            interactive=True
+        )
+        clear_btn = gr.Button("Clear output")
+
+    def use_example(prompt):
+        return prompt
+
+    def clear_output():
+        return ""
+
+    example_dropdown.change(fn=use_example, inputs=example_dropdown, outputs=inp)
+    clear_btn.click(fn=clear_output, inputs=[], outputs=out)
+
     btn = gr.Button("Generate")
     btn.click(fn=generate, inputs=[inp, max_tokens, topk, temp], outputs=out)
 
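The dropdown and clear button added here follow Gradio's standard event pattern: a component's .change or .click listener calls a plain Python function, and the handler's return value is routed into the component named in outputs. A minimal runnable sketch of that wiring in isolation (hypothetical labels and choices, not the app's components):

import gradio as gr

with gr.Blocks() as demo:
    dd = gr.Dropdown(choices=["alpha", "beta"], label="Pick one", interactive=True)
    box = gr.Textbox(label="Mirror")
    clear = gr.Button("Clear")

    # .change fires whenever the dropdown value changes; the return value
    # of the handler is written into the component listed in outputs.
    dd.change(fn=lambda v: v, inputs=dd, outputs=box)
    # An empty inputs list means the handler takes no arguments.
    clear.click(fn=lambda: "", inputs=[], outputs=box)

demo.launch()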
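Most of generate()'s body falls outside this diff; only its signature (max_new_tokens, top_k, temperature) and the encode/decode ends are visible. For reference, a hedged sketch of the standard top-k / temperature sampling step those parameters usually control, an assumed implementation rather than the app's exact loop:

import torch
import torch.nn.functional as F

def sample_next_token(logits: torch.Tensor, top_k: int = 50, temperature: float = 1.0) -> torch.Tensor:
    # logits: (batch, vocab_size) scores for the next token position.
    logits = logits / max(temperature, 1e-6)           # <1 sharpens, >1 flattens the distribution
    topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
    probs = F.softmax(topk_vals, dim=-1)               # renormalize over the k survivors
    choice = torch.multinomial(probs, num_samples=1)   # sample within the top-k set
    return torch.gather(topk_idx, -1, choice)          # map back to vocabulary ids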