Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 3 |
|
| 4 |
-
# Modelos
|
| 5 |
model_name1 = "google/flan-t5-small"
|
| 6 |
model_name2 = "google/flan-t5-base"
|
| 7 |
-
arbitro_model_name = "google/flan-t5-base"
|
| 8 |
|
| 9 |
-
# Carregar
|
| 10 |
tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
|
| 11 |
model1 = AutoModelForSeq2SeqLM.from_pretrained(model_name1)
|
| 12 |
|
|
@@ -16,65 +17,70 @@ model2 = AutoModelForSeq2SeqLM.from_pretrained(model_name2)
|
|
| 16 |
tokenizer_arbitro = AutoTokenizer.from_pretrained(arbitro_model_name)
|
| 17 |
model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
|
| 18 |
|
|
|
|
| 19 |
def gerar_resposta(model, tokenizer, pergunta):
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
outputs = model.generate(**inputs, max_length=20)
|
| 26 |
resposta = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 27 |
return resposta.strip()
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def arbitro(pergunta, resp1, resp2):
|
| 30 |
-
# Melhorar o prompt para o árbitro
|
| 31 |
prompt = (
|
| 32 |
-
|
| 33 |
-
f"
|
| 34 |
-
f"
|
| 35 |
-
f"
|
| 36 |
-
|
| 37 |
)
|
| 38 |
-
|
| 39 |
inputs = tokenizer_arbitro(prompt, return_tensors="pt")
|
| 40 |
outputs = model_arbitro.generate(**inputs, max_length=5)
|
| 41 |
escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
|
| 42 |
|
| 43 |
-
if escolha == "
|
| 44 |
-
return resp1, "Modelo 1 (flan-t5-small)"
|
| 45 |
-
elif escolha == "2":
|
| 46 |
return resp2, "Modelo 2 (flan-t5-base)"
|
| 47 |
else:
|
| 48 |
-
|
| 49 |
-
return (resp1, "Modelo 1 (flan-t5-small)") if len(resp1) > len(resp2) else (resp2, "Modelo 2 (flan-t5-base)")
|
| 50 |
|
|
|
|
| 51 |
def chatbot(pergunta):
|
| 52 |
-
# Garantir que a pergunta está formatada corretamente
|
| 53 |
-
pergunta = pergunta.strip().lower()
|
| 54 |
-
if not pergunta.endswith("?"):
|
| 55 |
-
pergunta += "?"
|
| 56 |
-
|
| 57 |
resposta1 = gerar_resposta(model1, tokenizer1, pergunta)
|
| 58 |
resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
|
| 59 |
resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
return (
|
| 62 |
-
f"Resposta selecionada:
|
| 63 |
-
f"Modelo 1 (flan-t5-small):
|
| 64 |
-
f"Modelo 2 (flan-t5-base):
|
| 65 |
)
|
| 66 |
|
|
|
|
| 67 |
iface = gr.Interface(
|
| 68 |
fn=chatbot,
|
| 69 |
-
inputs=gr.Textbox(label="Digite
|
| 70 |
outputs=[
|
| 71 |
-
gr.Textbox(label="Resposta
|
| 72 |
-
gr.Textbox(label="Resposta
|
| 73 |
-
gr.Textbox(label="Resposta
|
| 74 |
],
|
| 75 |
-
title="Chatbot em Cascata - Perguntas sobre Capitais",
|
| 76 |
-
description="
|
| 77 |
)
|
| 78 |
|
| 79 |
-
if
|
| 80 |
iface.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import re
|
| 3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 4 |
|
| 5 |
+
# Modelos
|
| 6 |
model_name1 = "google/flan-t5-small"
|
| 7 |
model_name2 = "google/flan-t5-base"
|
| 8 |
+
arbitro_model_name = "google/flan-t5-base"
|
| 9 |
|
| 10 |
+
# Carregar modelos e tokenizadores
|
| 11 |
tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
|
| 12 |
model1 = AutoModelForSeq2SeqLM.from_pretrained(model_name1)
|
| 13 |
|
|
|
|
| 17 |
tokenizer_arbitro = AutoTokenizer.from_pretrained(arbitro_model_name)
|
| 18 |
model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
|
| 19 |
|
| 20 |
+
# Geração da resposta de cada modelo com prompt reforçado
|
| 21 |
def gerar_resposta(model, tokenizer, pergunta):
|
| 22 |
+
prompt = (
|
| 23 |
+
"Answer ONLY with the name of the capital city for the following question.\n"
|
| 24 |
+
"Do NOT answer with the country name.\n"
|
| 25 |
+
f"Question: {pergunta}\n"
|
| 26 |
+
"Answer:"
|
| 27 |
+
)
|
| 28 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 29 |
outputs = model.generate(**inputs, max_length=20)
|
| 30 |
resposta = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 31 |
return resposta.strip()
|
| 32 |
|
| 33 |
+
# Função para validar se parece uma capital
|
| 34 |
+
def eh_capital_valida(resposta):
|
| 35 |
+
# Considera resposta válida se tem 1 ou 2 palavras, só letras e espaços
|
| 36 |
+
return bool(re.match(r"^[A-Za-zÀ-ÿ\s]{2,40}$", resposta.strip()))
|
| 37 |
+
|
| 38 |
+
# Árbitro decide qual resposta é melhor
|
| 39 |
def arbitro(pergunta, resp1, resp2):
|
|
|
|
| 40 |
prompt = (
|
| 41 |
+
"You are a geography expert.\n"
|
| 42 |
+
f"Question: {pergunta}\n"
|
| 43 |
+
f"Answer 1: {resp1}\n"
|
| 44 |
+
f"Answer 2: {resp2}\n"
|
| 45 |
+
"Which answer is the correct capital? Reply only with 1 or 2."
|
| 46 |
)
|
|
|
|
| 47 |
inputs = tokenizer_arbitro(prompt, return_tensors="pt")
|
| 48 |
outputs = model_arbitro.generate(**inputs, max_length=5)
|
| 49 |
escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
|
| 50 |
|
| 51 |
+
if escolha == "2":
|
|
|
|
|
|
|
| 52 |
return resp2, "Modelo 2 (flan-t5-base)"
|
| 53 |
else:
|
| 54 |
+
return resp1, "Modelo 1 (flan-t5-small)"
|
|
|
|
| 55 |
|
| 56 |
+
# Função principal do chatbot
|
| 57 |
def chatbot(pergunta):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
resposta1 = gerar_resposta(model1, tokenizer1, pergunta)
|
| 59 |
resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
|
| 60 |
resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
|
| 61 |
|
| 62 |
+
# Validação
|
| 63 |
+
if not eh_capital_valida(resposta_final):
|
| 64 |
+
resposta_final = "Não consegui identificar a capital corretamente."
|
| 65 |
+
|
| 66 |
return (
|
| 67 |
+
f"Resposta selecionada:\n{resposta_final}\n\nModelo escolhido:\n{modelo_escolhido}",
|
| 68 |
+
f"Resposta Modelo 1 (flan-t5-small):\n{resposta1}",
|
| 69 |
+
f"Resposta Modelo 2 (flan-t5-base):\n{resposta2}"
|
| 70 |
)
|
| 71 |
|
| 72 |
+
# Interface Gradio
|
| 73 |
iface = gr.Interface(
|
| 74 |
fn=chatbot,
|
| 75 |
+
inputs=gr.Textbox(label="Digite uma pergunta sobre capitais"),
|
| 76 |
outputs=[
|
| 77 |
+
gr.Textbox(label="Resposta selecionada"),
|
| 78 |
+
gr.Textbox(label="Resposta Modelo 1"),
|
| 79 |
+
gr.Textbox(label="Resposta Modelo 2")
|
| 80 |
],
|
| 81 |
+
title="Chatbot em Cascata - Perguntas sobre Capitais (sem listas)",
|
| 82 |
+
description="Insira uma pergunta como 'Qual é a capital da Alemanha?' e veja como os modelos escolhem a melhor resposta."
|
| 83 |
)
|
| 84 |
|
| 85 |
+
if _name_ == "_main_":
|
| 86 |
iface.launch()
|