import gradio as gr import re from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Modelos utilizados MODELOS = { "primario": "google/flan-t5-small", "secundario": "google/flan-t5-base", "arbitro": "google/flan-t5-large" } # Carregamento dos modelos e tokenizers tokenizers = {nome: AutoTokenizer.from_pretrained(modelo) for nome, modelo in MODELOS.items()} modelos = {nome: AutoModelForSeq2SeqLM.from_pretrained(modelo) for nome, modelo in MODELOS.items()} # Base de capitais com erros comuns BASE_CAPITAIS = { 'brazil': { 'correta': 'Brasília', 'erros_comuns': ['sao paulo', 'rio de janeiro', 'brazil'] }, 'germany': { 'correta': 'Berlin', 'erros_comuns': ['munich', 'frankfurt'] }, 'france': { 'correta': 'Paris', 'erros_comuns': ['lyon', 'marseille'] }, # Adicione mais países conforme necessário } def gerar_resposta(nome_modelo, pergunta): """Gera resposta com o modelo especificado""" prompt = f"""Aja como um especialista em geografia. Responda APENAS com o nome da capital oficial. Exemplos: Q: Capital da França? A: Paris Q: Capital do Brasil? A: Brasília Q: Capital da Alemanha? A: Berlin Q: {pergunta} A:""" entrada = tokenizers[nome_modelo](prompt, return_tensors="pt") saida = modelos[nome_modelo].generate(**entrada, max_length=20) return tokenizers[nome_modelo].decode(saida[0], skip_special_tokens=True).strip() def validar_corrigir(pergunta, resposta_bruta): """Valida e corrige a resposta com base na base de capitais""" pergunta = pergunta.lower() resposta = resposta_bruta.lower() for pais, dados in BASE_CAPITAIS.items(): if pais in pergunta: if resposta in dados['erros_comuns']: return dados['correta'] if resposta == dados['correta'].lower(): return dados['correta'] return dados['correta'] return resposta_bruta.title() def esta_confiante(resposta, pergunta): """Avalia se a resposta pode ser considerada confiável""" pergunta = pergunta.lower() resposta = resposta.lower() for pais, dados in BASE_CAPITAIS.items(): if pais in pergunta: if resposta == dados['correta'].lower(): return True if resposta in dados['erros_comuns']: return False return False def arbitrar(pergunta, resposta1, resposta2): """Usa o modelo árbitro para escolher a melhor resposta""" corrigida1 = validar_corrigir(pergunta, resposta1) corrigida2 = validar_corrigir(pergunta, resposta2) for pais, dados in BASE_CAPITAIS.items(): if pais in pergunta.lower(): if corrigida1 == dados['correta']: return corrigida1, "Modelo 1 (validado)" if corrigida2 == dados['correta']: return corrigida2, "Modelo 2 (validado)" prompt = f"""Você é professor de geografia. Escolha a capital correta: Pergunta: {pergunta} Opção 1: {corrigida1} Opção 2: {corrigida2} Responda SOMENTE com "1" ou "2".""" entrada = tokenizers['arbitro'](prompt, return_tensors="pt") saida = modelos['arbitro'].generate(**entrada, max_length=3) escolha = tokenizers['arbitro'].decode(saida[0], skip_special_tokens=True).strip() if escolha == "1": return corrigida1, "Modelo 1 (árbitro)" else: return corrigida2, "Modelo 2 (árbitro)" def chatbot(pergunta): """Pipeline em cascata para determinar a capital""" resposta1 = gerar_resposta("primario", pergunta) corrigida1 = validar_corrigir(pergunta, resposta1) if corrigida1 == resposta1 and esta_confiante(corrigida1, pergunta): return [ f"Resposta Selecionada: {corrigida1}\nModelo Escolhido: Modelo 1 (primário confiante)", f"Modelo 1 (primário): {resposta1}", f"Modelo 2 (secundário): Pulado" ] resposta2 = gerar_resposta("secundario", pergunta) corrigida2 = validar_corrigir(pergunta, resposta2) if corrigida2 == resposta2 and esta_confiante(corrigida2, pergunta): return [ f"Resposta Selecionada: {corrigida2}\nModelo Escolhido: Modelo 2 (secundário confiante)", f"Modelo 1 (primário): {resposta1}", f"Modelo 2 (secundário): {resposta2}" ] resposta_final, modelo_escolhido = arbitrar(pergunta, resposta1, resposta2) return [ f"Resposta Selecionada: {resposta_final}\nModelo Escolhido: {modelo_escolhido}", f"Modelo 1 (primário): {resposta1}", f"Modelo 2 (secundário): {resposta2}" ] # Interface Gradio interface = gr.Interface( fn=chatbot, inputs=gr.Textbox(label="Pergunte a capital de um país", placeholder="Qual é a capital do Brasil?"), outputs=[ gr.Textbox(label="Resposta Final"), gr.Textbox(label="Resposta do Modelo 1"), gr.Textbox(label="Resposta do Modelo 2") ], title="🗺️ Especialista em Capitais (Cascata com Correção Automática)", description="Sistema com três modelos em cascata. Pergunte sobre a capital de qualquer país. Exemplos: Brasil, Alemanha, França..." ) if __name__ == "__main__": interface.launch()