Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import re | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| # Modelos utilizados | |
| MODELOS = { | |
| "primario": "google/flan-t5-small", | |
| "secundario": "google/flan-t5-base", | |
| "arbitro": "google/flan-t5-large" | |
| } | |
| # Carregamento dos modelos e tokenizers | |
| tokenizers = {nome: AutoTokenizer.from_pretrained(modelo) | |
| for nome, modelo in MODELOS.items()} | |
| modelos = {nome: AutoModelForSeq2SeqLM.from_pretrained(modelo) | |
| for nome, modelo in MODELOS.items()} | |
| # Base de capitais com erros comuns | |
| BASE_CAPITAIS = { | |
| 'brazil': { | |
| 'correta': 'Brasília', | |
| 'erros_comuns': ['sao paulo', 'rio de janeiro', 'brazil'] | |
| }, | |
| 'germany': { | |
| 'correta': 'Berlin', | |
| 'erros_comuns': ['munich', 'frankfurt'] | |
| }, | |
| 'france': { | |
| 'correta': 'Paris', | |
| 'erros_comuns': ['lyon', 'marseille'] | |
| }, | |
| # Adicione mais países conforme necessário | |
| } | |
| def gerar_resposta(nome_modelo, pergunta): | |
| """Gera resposta com o modelo especificado""" | |
| prompt = f"""Aja como um especialista em geografia. Responda APENAS com o nome da capital oficial. | |
| Exemplos: | |
| Q: Capital da França? A: Paris | |
| Q: Capital do Brasil? A: Brasília | |
| Q: Capital da Alemanha? A: Berlin | |
| Q: {pergunta} | |
| A:""" | |
| entrada = tokenizers[nome_modelo](prompt, return_tensors="pt") | |
| saida = modelos[nome_modelo].generate(**entrada, max_length=20) | |
| return tokenizers[nome_modelo].decode(saida[0], skip_special_tokens=True).strip() | |
| def validar_corrigir(pergunta, resposta_bruta): | |
| """Valida e corrige a resposta com base na base de capitais""" | |
| pergunta = pergunta.lower() | |
| resposta = resposta_bruta.lower() | |
| for pais, dados in BASE_CAPITAIS.items(): | |
| if pais in pergunta: | |
| if resposta in dados['erros_comuns']: | |
| return dados['correta'] | |
| if resposta == dados['correta'].lower(): | |
| return dados['correta'] | |
| return dados['correta'] | |
| return resposta_bruta.title() | |
| def esta_confiante(resposta, pergunta): | |
| """Avalia se a resposta pode ser considerada confiável""" | |
| pergunta = pergunta.lower() | |
| resposta = resposta.lower() | |
| for pais, dados in BASE_CAPITAIS.items(): | |
| if pais in pergunta: | |
| if resposta == dados['correta'].lower(): | |
| return True | |
| if resposta in dados['erros_comuns']: | |
| return False | |
| return False | |
| def arbitrar(pergunta, resposta1, resposta2): | |
| """Usa o modelo árbitro para escolher a melhor resposta""" | |
| corrigida1 = validar_corrigir(pergunta, resposta1) | |
| corrigida2 = validar_corrigir(pergunta, resposta2) | |
| for pais, dados in BASE_CAPITAIS.items(): | |
| if pais in pergunta.lower(): | |
| if corrigida1 == dados['correta']: | |
| return corrigida1, "Modelo 1 (validado)" | |
| if corrigida2 == dados['correta']: | |
| return corrigida2, "Modelo 2 (validado)" | |
| prompt = f"""Você é professor de geografia. Escolha a capital correta: | |
| Pergunta: {pergunta} | |
| Opção 1: {corrigida1} | |
| Opção 2: {corrigida2} | |
| Responda SOMENTE com "1" ou "2".""" | |
| entrada = tokenizers['arbitro'](prompt, return_tensors="pt") | |
| saida = modelos['arbitro'].generate(**entrada, max_length=3) | |
| escolha = tokenizers['arbitro'].decode(saida[0], skip_special_tokens=True).strip() | |
| if escolha == "1": | |
| return corrigida1, "Modelo 1 (árbitro)" | |
| else: | |
| return corrigida2, "Modelo 2 (árbitro)" | |
| def chatbot(pergunta): | |
| """Pipeline em cascata para determinar a capital""" | |
| resposta1 = gerar_resposta("primario", pergunta) | |
| corrigida1 = validar_corrigir(pergunta, resposta1) | |
| if corrigida1 == resposta1 and esta_confiante(corrigida1, pergunta): | |
| return [ | |
| f"Resposta Selecionada: {corrigida1}\nModelo Escolhido: Modelo 1 (primário confiante)", | |
| f"Modelo 1 (primário): {resposta1}", | |
| f"Modelo 2 (secundário): Pulado" | |
| ] | |
| resposta2 = gerar_resposta("secundario", pergunta) | |
| corrigida2 = validar_corrigir(pergunta, resposta2) | |
| if corrigida2 == resposta2 and esta_confiante(corrigida2, pergunta): | |
| return [ | |
| f"Resposta Selecionada: {corrigida2}\nModelo Escolhido: Modelo 2 (secundário confiante)", | |
| f"Modelo 1 (primário): {resposta1}", | |
| f"Modelo 2 (secundário): {resposta2}" | |
| ] | |
| resposta_final, modelo_escolhido = arbitrar(pergunta, resposta1, resposta2) | |
| return [ | |
| f"Resposta Selecionada: {resposta_final}\nModelo Escolhido: {modelo_escolhido}", | |
| f"Modelo 1 (primário): {resposta1}", | |
| f"Modelo 2 (secundário): {resposta2}" | |
| ] | |
| # Interface Gradio | |
| interface = gr.Interface( | |
| fn=chatbot, | |
| inputs=gr.Textbox(label="Pergunte a capital de um país", placeholder="Qual é a capital do Brasil?"), | |
| outputs=[ | |
| gr.Textbox(label="Resposta Final"), | |
| gr.Textbox(label="Resposta do Modelo 1"), | |
| gr.Textbox(label="Resposta do Modelo 2") | |
| ], | |
| title="🗺️ Especialista em Capitais (Cascata com Correção Automática)", | |
| description="Sistema com três modelos em cascata. Pergunte sobre a capital de qualquer país. Exemplos: Brasil, Alemanha, França..." | |
| ) | |
| if __name__ == "__main__": | |
| interface.launch() | |