Spaces:
Runtime error
Runtime error
returned classic generating
Browse files
app.py
CHANGED
|
@@ -39,20 +39,21 @@ def load_model():
|
|
| 39 |
|
| 40 |
|
| 41 |
@torch.inference_mode()
|
| 42 |
-
def
|
| 43 |
-
"""
|
| 44 |
model = load_model()
|
| 45 |
enc = state["enc"]
|
|
|
|
| 46 |
x = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=DEVICE)
|
|
|
|
| 47 |
y = model.generate(
|
| 48 |
x,
|
| 49 |
max_new_tokens=int(max_new_tokens),
|
| 50 |
temperature=float(temperature),
|
| 51 |
top_k=int(top_k) if top_k > 0 else None
|
| 52 |
-
)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
yield enc.decode(y[:i+1])
|
| 56 |
|
| 57 |
# Gradio UI
|
| 58 |
with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
|
|
@@ -63,9 +64,9 @@ with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
|
|
| 63 |
temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
|
| 64 |
top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K (0 = off)")
|
| 65 |
btn = gr.Button("Generate")
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
btn.click(generate_stream, [prompt, max_new_tokens, temperature, top_k], output)
|
| 69 |
|
| 70 |
if __name__ == "__main__":
|
| 71 |
demo.launch()
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
@torch.inference_mode()
|
| 42 |
+
def generate(prompt, max_new_tokens=200, temperature=0.9, top_k=50):
|
| 43 |
+
"""Generiranje teksta iz prompta"""
|
| 44 |
model = load_model()
|
| 45 |
enc = state["enc"]
|
| 46 |
+
|
| 47 |
x = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=DEVICE)
|
| 48 |
+
|
| 49 |
y = model.generate(
|
| 50 |
x,
|
| 51 |
max_new_tokens=int(max_new_tokens),
|
| 52 |
temperature=float(temperature),
|
| 53 |
top_k=int(top_k) if top_k > 0 else None
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
return enc.decode(y[0].tolist())
|
|
|
|
| 57 |
|
| 58 |
# Gradio UI
|
| 59 |
with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
|
|
|
|
| 64 |
temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
|
| 65 |
top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K (0 = off)")
|
| 66 |
btn = gr.Button("Generate")
|
| 67 |
+
output = gr.Textbox(label="Output", lines=15)
|
| 68 |
|
| 69 |
+
btn.click(generate, [prompt, max_new_tokens, temperature, top_k], output)
|
|
|
|
| 70 |
|
| 71 |
if __name__ == "__main__":
|
| 72 |
demo.launch()
|