ChatBot / app.py
GutoFonseca's picture
Update app.py
ac83a33 verified
import gradio as gr
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Modelos utilizados
MODELOS = {
"primario": "google/flan-t5-small",
"secundario": "google/flan-t5-base",
"arbitro": "google/flan-t5-large"
}
# Carregamento dos modelos e tokenizers
tokenizers = {nome: AutoTokenizer.from_pretrained(modelo)
for nome, modelo in MODELOS.items()}
modelos = {nome: AutoModelForSeq2SeqLM.from_pretrained(modelo)
for nome, modelo in MODELOS.items()}
# Base de capitais com erros comuns
BASE_CAPITAIS = {
'brazil': {
'correta': 'Brasília',
'erros_comuns': ['sao paulo', 'rio de janeiro', 'brazil']
},
'germany': {
'correta': 'Berlin',
'erros_comuns': ['munich', 'frankfurt']
},
'france': {
'correta': 'Paris',
'erros_comuns': ['lyon', 'marseille']
},
# Adicione mais países conforme necessário
}
def gerar_resposta(nome_modelo, pergunta):
"""Gera resposta com o modelo especificado"""
prompt = f"""Aja como um especialista em geografia. Responda APENAS com o nome da capital oficial.
Exemplos:
Q: Capital da França? A: Paris
Q: Capital do Brasil? A: Brasília
Q: Capital da Alemanha? A: Berlin
Q: {pergunta}
A:"""
entrada = tokenizers[nome_modelo](prompt, return_tensors="pt")
saida = modelos[nome_modelo].generate(**entrada, max_length=20)
return tokenizers[nome_modelo].decode(saida[0], skip_special_tokens=True).strip()
def validar_corrigir(pergunta, resposta_bruta):
"""Valida e corrige a resposta com base na base de capitais"""
pergunta = pergunta.lower()
resposta = resposta_bruta.lower()
for pais, dados in BASE_CAPITAIS.items():
if pais in pergunta:
if resposta in dados['erros_comuns']:
return dados['correta']
if resposta == dados['correta'].lower():
return dados['correta']
return dados['correta']
return resposta_bruta.title()
def esta_confiante(resposta, pergunta):
"""Avalia se a resposta pode ser considerada confiável"""
pergunta = pergunta.lower()
resposta = resposta.lower()
for pais, dados in BASE_CAPITAIS.items():
if pais in pergunta:
if resposta == dados['correta'].lower():
return True
if resposta in dados['erros_comuns']:
return False
return False
def arbitrar(pergunta, resposta1, resposta2):
"""Usa o modelo árbitro para escolher a melhor resposta"""
corrigida1 = validar_corrigir(pergunta, resposta1)
corrigida2 = validar_corrigir(pergunta, resposta2)
for pais, dados in BASE_CAPITAIS.items():
if pais in pergunta.lower():
if corrigida1 == dados['correta']:
return corrigida1, "Modelo 1 (validado)"
if corrigida2 == dados['correta']:
return corrigida2, "Modelo 2 (validado)"
prompt = f"""Você é professor de geografia. Escolha a capital correta:
Pergunta: {pergunta}
Opção 1: {corrigida1}
Opção 2: {corrigida2}
Responda SOMENTE com "1" ou "2"."""
entrada = tokenizers['arbitro'](prompt, return_tensors="pt")
saida = modelos['arbitro'].generate(**entrada, max_length=3)
escolha = tokenizers['arbitro'].decode(saida[0], skip_special_tokens=True).strip()
if escolha == "1":
return corrigida1, "Modelo 1 (árbitro)"
else:
return corrigida2, "Modelo 2 (árbitro)"
def chatbot(pergunta):
"""Pipeline em cascata para determinar a capital"""
resposta1 = gerar_resposta("primario", pergunta)
corrigida1 = validar_corrigir(pergunta, resposta1)
if corrigida1 == resposta1 and esta_confiante(corrigida1, pergunta):
return [
f"Resposta Selecionada: {corrigida1}\nModelo Escolhido: Modelo 1 (primário confiante)",
f"Modelo 1 (primário): {resposta1}",
f"Modelo 2 (secundário): Pulado"
]
resposta2 = gerar_resposta("secundario", pergunta)
corrigida2 = validar_corrigir(pergunta, resposta2)
if corrigida2 == resposta2 and esta_confiante(corrigida2, pergunta):
return [
f"Resposta Selecionada: {corrigida2}\nModelo Escolhido: Modelo 2 (secundário confiante)",
f"Modelo 1 (primário): {resposta1}",
f"Modelo 2 (secundário): {resposta2}"
]
resposta_final, modelo_escolhido = arbitrar(pergunta, resposta1, resposta2)
return [
f"Resposta Selecionada: {resposta_final}\nModelo Escolhido: {modelo_escolhido}",
f"Modelo 1 (primário): {resposta1}",
f"Modelo 2 (secundário): {resposta2}"
]
# Interface Gradio
interface = gr.Interface(
fn=chatbot,
inputs=gr.Textbox(label="Pergunte a capital de um país", placeholder="Qual é a capital do Brasil?"),
outputs=[
gr.Textbox(label="Resposta Final"),
gr.Textbox(label="Resposta do Modelo 1"),
gr.Textbox(label="Resposta do Modelo 2")
],
title="🗺️ Especialista em Capitais (Cascata com Correção Automática)",
description="Sistema com três modelos em cascata. Pergunte sobre a capital de qualquer país. Exemplos: Brasil, Alemanha, França..."
)
if __name__ == "__main__":
interface.launch()