Update app.py
Browse files
app.py
CHANGED
|
@@ -37,7 +37,7 @@ m.eval()
|
|
| 37 |
#print(
|
| 38 |
# "Model with {:.2f}M parameters".format(sum(p.numel() for p in m.parameters()) / 1e6)
|
| 39 |
#)
|
| 40 |
-
def model_generate(text):
|
| 41 |
# generate some output based on the context
|
| 42 |
#context = torch.tensor(np.array(encode("Hello! My name is ", tokenizer)))
|
| 43 |
#context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
|
|
@@ -47,7 +47,7 @@ def model_generate(text):
|
|
| 47 |
context = torch.from_numpy(context_np)
|
| 48 |
#print(context)
|
| 49 |
|
| 50 |
-
return decode(enc_sec=m.generate(idx=context, max_new_tokens=
|
| 51 |
|
| 52 |
-
iface = gr.Interface(fn=model_generate, inputs="text", outputs="text")
|
| 53 |
iface.launch()
|
|
|
|
| 37 |
#print(
|
| 38 |
# "Model with {:.2f}M parameters".format(sum(p.numel() for p in m.parameters()) / 1e6)
|
| 39 |
#)
|
| 40 |
+
def model_generate(text, number):
|
| 41 |
# generate some output based on the context
|
| 42 |
#context = torch.tensor(np.array(encode("Hello! My name is ", tokenizer)))
|
| 43 |
#context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
|
|
|
|
| 47 |
context = torch.from_numpy(context_np)
|
| 48 |
#print(context)
|
| 49 |
|
| 50 |
+
return decode(enc_sec=m.generate(idx=context, max_new_tokens=number, block_size=BLOCK_SIZE)[0], tokenizer=tokenizer)
|
| 51 |
|
| 52 |
+
iface = gr.Interface(fn=model_generate, inputs=["text", gr.Slider(10, 1000)], outputs="text")
|
| 53 |
iface.launch()
|