Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
|
| 2 |
-
|
| 3 |
import gradio as gr
|
| 4 |
from torch.nn import functional as F
|
| 5 |
import seaborn
|
|
@@ -60,7 +60,9 @@ def generate(model_name, text):
|
|
| 60 |
input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
|
| 61 |
outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
|
| 62 |
output = tokenizer.decode(outputs[0])
|
| 63 |
-
return ".".join(output.split(".")[:-1]) + "."
|
|
|
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
output_text = gr.outputs.Textbox()
|
|
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
|
| 2 |
+
import re
|
| 3 |
import gradio as gr
|
| 4 |
from torch.nn import functional as F
|
| 5 |
import seaborn
|
|
|
|
| 60 |
input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
|
| 61 |
outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
|
| 62 |
output = tokenizer.decode(outputs[0])
|
| 63 |
+
#return ".".join(output.split(".")[:-1]) + "."
|
| 64 |
+
sent = ".".join(output.split(".")[:-1]) + "."
|
| 65 |
+
return re.match(r'<pad> ([^<>]*)', sent).group(1)
|
| 66 |
|
| 67 |
|
| 68 |
output_text = gr.outputs.Textbox()
|