File size: 1,764 Bytes
c24eab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7de4ee6
c24eab6
0c256f7
0d1b49e
525c7c3
7de4ee6
c24eab6
2e05bfe
c24eab6
2e05bfe
c24eab6
 
 
 
 
 
2e05bfe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# import torch
# from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# import gradio as gr

# # Load your custom model and tokenizer
# model_name = "MiVaCod/mbart-neutralization"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# # Function to correct sentences
# def predict(sentence):
#     inputs = tokenizer.encode("correction: " + sentence, return_tensors="pt", max_length=512, truncation=True)
#     outputs = model.generate(inputs, max_length=128, num_beams=4, early_stopping=True)
#     corrected_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return corrected_sentence

# # Gradio Interface
# iface = gr.Interface(
#     fn=correct_sentence,
#     inputs="text",
#     outputs="text",
#     title="Sentence Correction",
#     description="Enter a sentence to be corrected:",
#     theme="compact"
# )

# # Launch the interface
# gr.Interface(fn=predict, inputs=gr.inputs.Textbox, outputs=gr.outputs.Textbox).launch(share=False)

from transformers import MBartForConditionalGeneration, MBart50Tokenizer
import gradio as grad

# Hugging Face Hub checkpoint identifier shared by the model and its tokenizer.
model_name = "MiVaCod/mbart-neutralization"

# Load the sequence-to-sequence model and the matching mBART-50 tokenizer
# from the same checkpoint; both are used by text2text_paraphrase below.
mdl = MBartForConditionalGeneration.from_pretrained(model_name)
text2text_tkn = MBart50Tokenizer.from_pretrained(model_name)

def text2text_paraphrase(sentence1):
    """Rewrite *sentence1* with the fine-tuned mBART model and return the text.

    The sentence is prefixed with ``"rte sentence1: "``, tokenised with the
    module-level ``text2text_tkn`` tokenizer, passed to ``mdl.generate`` with
    its default generation settings, and the generated sequence is decoded
    back to plain text.

    Args:
        sentence1: The sentence entered in the Gradio input textbox.

    Returns:
        The decoded model output as a single string with special tokens
        (such as ``</s>``) removed, suitable for a ``Textbox`` output.
    """
    # NOTE(review): "rte sentence1: " looks like the T5 GLUE/RTE task prefix;
    # confirm the MiVaCod/mbart-neutralization checkpoint was trained with it.
    inp1 = "rte sentence1: " + sentence1
    enc = text2text_tkn(inp1, return_tensors="pt")
    tokens = mdl.generate(**enc)
    # skip_special_tokens strips </s>, <s> and language-code tokens from the
    # text. A single input yields a single generated sequence, so return it
    # as a str: a list would be shown in the Textbox as "['...']".
    response = text2text_tkn.batch_decode(tokens, skip_special_tokens=True)
    return response[0]

# Gradio UI: a single-line box for the input sentence and one for the result.
sent1 = grad.Textbox(
    lines=1,
    label="Frase misógina",
    placeholder="Introduce una frase misógina",
)
out = grad.Textbox(lines=1, label="Frase corregida")

# Wire the model function to the components and start the web server.
demo = grad.Interface(fn=text2text_paraphrase, inputs=[sent1], outputs=out)
demo.launch()