Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
| 4 |
# Modelos escolhidos
|
| 5 |
model_name1 = "google/flan-t5-small"
|
| 6 |
model_name2 = "google/flan-t5-base"
|
| 7 |
-
arbitro_model_name = "google/flan-t5-
|
| 8 |
|
| 9 |
# Carregar tokenizadores e modelos
|
| 10 |
tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
|
|
@@ -17,21 +17,25 @@ tokenizer_arbitro = AutoTokenizer.from_pretrained(arbitro_model_name)
|
|
| 17 |
model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
|
| 18 |
|
| 19 |
def gerar_resposta(model, tokenizer, pergunta):
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
outputs = model.generate(**inputs, max_length=20)
|
| 24 |
resposta = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 25 |
return resposta.strip()
|
| 26 |
|
| 27 |
def arbitro(pergunta, resp1, resp2):
|
| 28 |
-
#
|
| 29 |
prompt = (
|
|
|
|
| 30 |
f"Pergunta: {pergunta}\n"
|
| 31 |
f"Resposta 1: {resp1}\n"
|
| 32 |
-
f"Resposta 2: {resp2}\n"
|
| 33 |
-
"
|
| 34 |
)
|
|
|
|
| 35 |
inputs = tokenizer_arbitro(prompt, return_tensors="pt")
|
| 36 |
outputs = model_arbitro.generate(**inputs, max_length=5)
|
| 37 |
escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
|
|
@@ -41,29 +45,35 @@ def arbitro(pergunta, resp1, resp2):
|
|
| 41 |
elif escolha == "2":
|
| 42 |
return resp2, "Modelo 2 (flan-t5-base)"
|
| 43 |
else:
|
| 44 |
-
#
|
| 45 |
-
return resp1, "Modelo 1 (flan-t5-small)"
|
| 46 |
|
| 47 |
def chatbot(pergunta):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
resposta1 = gerar_resposta(model1, tokenizer1, pergunta)
|
| 49 |
resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
|
| 50 |
resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
|
| 51 |
|
| 52 |
return (
|
| 53 |
-
f"Resposta selecionada:
|
| 54 |
-
f"
|
| 55 |
-
f"
|
| 56 |
)
|
| 57 |
|
| 58 |
iface = gr.Interface(
|
| 59 |
fn=chatbot,
|
| 60 |
-
inputs=gr.Textbox(label="Digite
|
| 61 |
outputs=[
|
| 62 |
-
gr.Textbox(label="Resposta
|
| 63 |
-
gr.Textbox(label="Resposta Modelo
|
| 64 |
-
gr.Textbox(label="Resposta Modelo
|
| 65 |
],
|
| 66 |
-
title="Chatbot em Cascata - Perguntas sobre Capitais"
|
|
|
|
| 67 |
)
|
| 68 |
|
| 69 |
if __name__ == "__main__":
|
|
|
|
| 4 |
# Modelos escolhidos
|
| 5 |
model_name1 = "google/flan-t5-small"
|
| 6 |
model_name2 = "google/flan-t5-base"
|
| 7 |
+
arbitro_model_name = "google/flan-t5-base" # Usando um modelo maior para arbitrar
|
| 8 |
|
| 9 |
# Carregar tokenizadores e modelos
|
| 10 |
tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
|
|
|
|
| 17 |
model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
|
| 18 |
|
| 19 |
def gerar_resposta(model, tokenizer, pergunta):
|
| 20 |
+
# Formatar a pergunta para focar na capital
|
| 21 |
+
if "capital" not in pergunta.lower():
|
| 22 |
+
pergunta = f"Qual é a capital do {pergunta}?"
|
| 23 |
+
|
| 24 |
+
inputs = tokenizer(pergunta, return_tensors="pt")
|
| 25 |
outputs = model.generate(**inputs, max_length=20)
|
| 26 |
resposta = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 27 |
return resposta.strip()
|
| 28 |
|
| 29 |
def arbitro(pergunta, resp1, resp2):
|
| 30 |
+
# Melhorar o prompt para o árbitro
|
| 31 |
prompt = (
|
| 32 |
+
f"Decida qual resposta está mais correta para esta pergunta sobre geografia:\n"
|
| 33 |
f"Pergunta: {pergunta}\n"
|
| 34 |
f"Resposta 1: {resp1}\n"
|
| 35 |
+
f"Resposta 2: {resp2}\n\n"
|
| 36 |
+
f"Responda apenas com o número 1 ou 2, sem explicações."
|
| 37 |
)
|
| 38 |
+
|
| 39 |
inputs = tokenizer_arbitro(prompt, return_tensors="pt")
|
| 40 |
outputs = model_arbitro.generate(**inputs, max_length=5)
|
| 41 |
escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
|
|
|
|
| 45 |
elif escolha == "2":
|
| 46 |
return resp2, "Modelo 2 (flan-t5-base)"
|
| 47 |
else:
|
| 48 |
+
# Se não conseguir decidir, escolhe a mais longa (heurística simples)
|
| 49 |
+
return (resp1, "Modelo 1 (flan-t5-small)") if len(resp1) > len(resp2) else (resp2, "Modelo 2 (flan-t5-base)")
|
| 50 |
|
| 51 |
def chatbot(pergunta):
|
| 52 |
+
# Garantir que a pergunta está formatada corretamente
|
| 53 |
+
pergunta = pergunta.strip().lower()
|
| 54 |
+
if not pergunta.endswith("?"):
|
| 55 |
+
pergunta += "?"
|
| 56 |
+
|
| 57 |
resposta1 = gerar_resposta(model1, tokenizer1, pergunta)
|
| 58 |
resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
|
| 59 |
resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
|
| 60 |
|
| 61 |
return (
|
| 62 |
+
f"Resposta selecionada: {resposta_final}\nModelo escolhido: {modelo_escolhido}",
|
| 63 |
+
f"Modelo 1 (flan-t5-small): {resposta1}",
|
| 64 |
+
f"Modelo 2 (flan-t5-base): {resposta2}"
|
| 65 |
)
|
| 66 |
|
| 67 |
iface = gr.Interface(
|
| 68 |
fn=chatbot,
|
| 69 |
+
inputs=gr.Textbox(label="Digite um país ou pergunta sobre capitais", placeholder="Ex: Brasil ou Qual a capital do Brasil?"),
|
| 70 |
outputs=[
|
| 71 |
+
gr.Textbox(label="Resposta Final"),
|
| 72 |
+
gr.Textbox(label="Resposta do Modelo Pequeno"),
|
| 73 |
+
gr.Textbox(label="Resposta do Modelo Base")
|
| 74 |
],
|
| 75 |
+
title="Chatbot em Cascata - Perguntas sobre Capitais",
|
| 76 |
+
description="Digite um país ou pergunta sobre capitais para comparar dois modelos de IA"
|
| 77 |
)
|
| 78 |
|
| 79 |
if __name__ == "__main__":
|