Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
| 5 |
# Modelos
|
| 6 |
model_name1 = "google/flan-t5-small"
|
| 7 |
model_name2 = "google/flan-t5-base"
|
| 8 |
-
arbitro_model_name = "google/flan-t5-
|
| 9 |
|
| 10 |
# Carregar modelos e tokenizadores
|
| 11 |
tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
|
|
@@ -17,11 +17,29 @@ model2 = AutoModelForSeq2SeqLM.from_pretrained(model_name2)
|
|
| 17 |
tokenizer_arbitro = AutoTokenizer.from_pretrained(arbitro_model_name)
|
| 18 |
model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# Geração da resposta de cada modelo com prompt reforçado
|
| 21 |
def gerar_resposta(model, tokenizer, pergunta):
|
| 22 |
prompt = (
|
| 23 |
-
"
|
| 24 |
-
"
|
|
|
|
|
|
|
|
|
|
| 25 |
f"Question: {pergunta}\n"
|
| 26 |
"Answer:"
|
| 27 |
)
|
|
@@ -32,26 +50,56 @@ def gerar_resposta(model, tokenizer, pergunta):
|
|
| 32 |
|
| 33 |
# Função para validar se parece uma capital
|
| 34 |
def eh_capital_valida(resposta):
|
| 35 |
-
#
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
# Árbitro decide qual resposta é melhor
|
| 39 |
def arbitro(pergunta, resp1, resp2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
prompt = (
|
| 41 |
-
"You are a geography expert.\n"
|
| 42 |
f"Question: {pergunta}\n"
|
| 43 |
-
f"
|
| 44 |
-
f"
|
| 45 |
-
"Which
|
| 46 |
)
|
| 47 |
inputs = tokenizer_arbitro(prompt, return_tensors="pt")
|
| 48 |
outputs = model_arbitro.generate(**inputs, max_length=5)
|
| 49 |
escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
|
| 50 |
|
| 51 |
if escolha == "2":
|
| 52 |
-
return
|
| 53 |
else:
|
| 54 |
-
return
|
| 55 |
|
| 56 |
# Função principal do chatbot
|
| 57 |
def chatbot(pergunta):
|
|
@@ -59,7 +107,7 @@ def chatbot(pergunta):
|
|
| 59 |
resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
|
| 60 |
resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
|
| 61 |
|
| 62 |
-
# Validação
|
| 63 |
if not eh_capital_valida(resposta_final):
|
| 64 |
resposta_final = "Não consegui identificar a capital corretamente."
|
| 65 |
|
|
@@ -78,7 +126,7 @@ iface = gr.Interface(
|
|
| 78 |
gr.Textbox(label="Resposta Modelo 1"),
|
| 79 |
gr.Textbox(label="Resposta Modelo 2")
|
| 80 |
],
|
| 81 |
-
title="Chatbot em Cascata - Perguntas sobre Capitais (
|
| 82 |
description="Insira uma pergunta como 'Qual é a capital da Alemanha?' e veja como os modelos escolhem a melhor resposta."
|
| 83 |
)
|
| 84 |
|
|
|
|
| 5 |
# Modelos
|
| 6 |
model_name1 = "google/flan-t5-small"
|
| 7 |
model_name2 = "google/flan-t5-base"
|
| 8 |
+
arbitro_model_name = "google/flan-t5-large" # Modelo maior para arbitragem
|
| 9 |
|
| 10 |
# Carregar modelos e tokenizadores
|
| 11 |
tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
|
|
|
|
| 17 |
tokenizer_arbitro = AutoTokenizer.from_pretrained(arbitro_model_name)
|
| 18 |
model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
|
| 19 |
|
| 20 |
+
# Lista de capitais conhecidas para validação
|
| 21 |
+
CAPITAIS_CONHECIDAS = {
|
| 22 |
+
'brasil': 'Brasília',
|
| 23 |
+
'alemanha': 'Berlim',
|
| 24 |
+
'frança': 'Paris',
|
| 25 |
+
'japão': 'Tóquio',
|
| 26 |
+
'itália': 'Roma',
|
| 27 |
+
'espanha': 'Madri',
|
| 28 |
+
'portugal': 'Lisboa',
|
| 29 |
+
'argentina': 'Buenos Aires',
|
| 30 |
+
'estados unidos': 'Washington',
|
| 31 |
+
'canadá': 'Ottawa',
|
| 32 |
+
# Adicione mais conforme necessário
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
# Geração da resposta de cada modelo com prompt reforçado
|
| 36 |
def gerar_resposta(model, tokenizer, pergunta):
|
| 37 |
prompt = (
|
| 38 |
+
"I will ask you about capital cities. Always respond with just the capital name.\n"
|
| 39 |
+
"Examples:\n"
|
| 40 |
+
"Question: What is the capital of France? Answer: Paris\n"
|
| 41 |
+
"Question: Capital of Japan? Answer: Tokyo\n"
|
| 42 |
+
"Question: Qual é a capital do Brasil? Answer: Brasília\n"
|
| 43 |
f"Question: {pergunta}\n"
|
| 44 |
"Answer:"
|
| 45 |
)
|
|
|
|
| 50 |
|
| 51 |
# Função para validar se parece uma capital
|
| 52 |
def eh_capital_valida(resposta):
|
| 53 |
+
# Verifica se está na lista de capitais conhecidas
|
| 54 |
+
resposta_limpa = resposta.strip().lower()
|
| 55 |
+
for capital in CAPITAIS_CONHECIDAS.values():
|
| 56 |
+
if capital.lower() == resposta_limpa:
|
| 57 |
+
return True
|
| 58 |
+
return False
|
| 59 |
+
|
| 60 |
+
# Correção de respostas óbvias
|
| 61 |
+
def corrigir_resposta(pergunta, resposta):
|
| 62 |
+
pergunta_lower = pergunta.lower()
|
| 63 |
+
resposta_lower = resposta.lower()
|
| 64 |
+
|
| 65 |
+
# Verifica se a resposta é igual ao país mencionado na pergunta
|
| 66 |
+
for pais, capital in CAPITAIS_CONHECIDAS.items():
|
| 67 |
+
if pais in pergunta_lower and resposta_lower == pais:
|
| 68 |
+
return capital
|
| 69 |
+
|
| 70 |
+
# Correções específicas
|
| 71 |
+
if "brasil" in pergunta_lower and resposta_lower == "brasil":
|
| 72 |
+
return "Brasília"
|
| 73 |
+
|
| 74 |
+
return resposta
|
| 75 |
|
| 76 |
# Árbitro decide qual resposta é melhor
|
| 77 |
def arbitro(pergunta, resp1, resp2):
|
| 78 |
+
# Primeiro verifica se alguma resposta está na lista de capitais conhecidas
|
| 79 |
+
resp1_corrigida = corrigir_resposta(pergunta, resp1)
|
| 80 |
+
resp2_corrigida = corrigir_resposta(pergunta, resp2)
|
| 81 |
+
|
| 82 |
+
if eh_capital_valida(resp1_corrigida) and not eh_capital_valida(resp2_corrigida):
|
| 83 |
+
return resp1_corrigida, "Modelo 1 (corrigido)"
|
| 84 |
+
elif eh_capital_valida(resp2_corrigida) and not eh_capital_valida(resp1_corrigida):
|
| 85 |
+
return resp2_corrigida, "Modelo 2 (corrigido)"
|
| 86 |
+
|
| 87 |
+
# Se ambas ou nenhuma for válida, usa o árbitro
|
| 88 |
prompt = (
|
| 89 |
+
"You are a geography expert. Choose the correct capital city.\n"
|
| 90 |
f"Question: {pergunta}\n"
|
| 91 |
+
f"Option 1: {resp1_corrigida}\n"
|
| 92 |
+
f"Option 2: {resp2_corrigida}\n"
|
| 93 |
+
"Which option is the correct capital? Reply only with 1 or 2."
|
| 94 |
)
|
| 95 |
inputs = tokenizer_arbitro(prompt, return_tensors="pt")
|
| 96 |
outputs = model_arbitro.generate(**inputs, max_length=5)
|
| 97 |
escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
|
| 98 |
|
| 99 |
if escolha == "2":
|
| 100 |
+
return resp2_corrigida, "Modelo 2 (flan-t5-base)"
|
| 101 |
else:
|
| 102 |
+
return resp1_corrigida, "Modelo 1 (flan-t5-small)"
|
| 103 |
|
| 104 |
# Função principal do chatbot
|
| 105 |
def chatbot(pergunta):
|
|
|
|
| 107 |
resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
|
| 108 |
resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
|
| 109 |
|
| 110 |
+
# Validação final
|
| 111 |
if not eh_capital_valida(resposta_final):
|
| 112 |
resposta_final = "Não consegui identificar a capital corretamente."
|
| 113 |
|
|
|
|
| 126 |
gr.Textbox(label="Resposta Modelo 1"),
|
| 127 |
gr.Textbox(label="Resposta Modelo 2")
|
| 128 |
],
|
| 129 |
+
title="Chatbot em Cascata - Perguntas sobre Capitais (Melhorado)",
|
| 130 |
description="Insira uma pergunta como 'Qual é a capital da Alemanha?' e veja como os modelos escolhem a melhor resposta."
|
| 131 |
)
|
| 132 |
|