GutoFonseca commited on
Commit
ac83a33
·
verified ·
1 Parent(s): 71b45b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -105
app.py CHANGED
@@ -2,129 +2,143 @@ import gradio as gr
2
  import re
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
- # Model and knowledge base setup
6
- MODELS = {
7
- "small": "google/flan-t5-small",
8
- "base": "google/flan-t5-base",
9
- "arbiter": "google/flan-t5-large"
10
  }
11
 
12
- # Load all models
13
- tokenizers = {name: AutoTokenizer.from_pretrained(model)
14
- for name, model in MODELS.items()}
15
- models = {name: AutoModelForSeq2SeqLM.from_pretrained(model)
16
- for name, model in MODELS.items()}
17
 
18
- # Enhanced capital database with common mistakes
19
- CAPITAL_DB = {
20
  'brazil': {
21
- 'correct': 'Brasília',
22
- 'common_errors': ['sao paulo', 'rio de janeiro', 'brazil']
23
  },
24
  'germany': {
25
- 'correct': 'Berlin',
26
- 'common_errors': ['munich', 'frankfurt']
27
  },
28
- # Add more countries as needed
 
 
 
 
29
  }
30
 
31
- def generate_response(model_name, question):
32
- """Improved response generation with examples"""
33
- prompt = f"""Act as a geography expert. Answer ONLY with the official capital name.
34
- Examples:
35
- Q: Capital of France? A: Paris
36
- Q: Brazil's capital? A: Brasília
37
- Q: Germany's capital? A: Berlin
38
- Q: {question}
39
  A:"""
40
-
41
- inputs = tokenizers[model_name](prompt, return_tensors="pt")
42
- outputs = models[model_name].generate(**inputs, max_length=20)
43
- return tokenizers[model_name].decode(outputs[0], skip_special_tokens=True).strip()
44
-
45
- def validate_and_correct(question, raw_answer):
46
- """Apply multiple correction layers"""
47
- question_lower = question.lower()
48
- answer_lower = raw_answer.lower()
49
-
50
- # First check if question is about a country we have in DB
51
- for country, data in CAPITAL_DB.items():
52
- if country in question_lower:
53
- # Check for common errors
54
- for error in data['common_errors']:
55
- if error in answer_lower:
56
- return data['correct']
57
-
58
- # If answer matches correct, use it
59
- if answer_lower == data['correct'].lower():
60
- return data['correct']
61
-
62
- # Final fallback to our known correct answer
63
- return data['correct']
64
-
65
- # For countries not in our DB, basic cleaning
66
- return raw_answer.title()
67
 
68
- def arbitrate(question, ans1, ans2):
69
- """Improved arbitration with validation priority"""
70
- corrected_1 = validate_and_correct(question, ans1)
71
- corrected_2 = validate_and_correct(question, ans2)
72
-
73
- # If one matches known correct, prefer it
74
- for country, data in CAPITAL_DB.items():
75
- if country in question.lower():
76
- if corrected_1 == data['correct']:
77
- return corrected_1, "Model 1 (validated)"
78
- if corrected_2 == data['correct']:
79
- return corrected_2, "Model 2 (validated)"
80
-
81
- # Fallback to arbiter model
82
- prompt = f"""As a geography professor, select the most likely correct capital:
83
- Question: {question}
84
- Option 1: {corrected_1}
85
- Option 2: {corrected_2}
86
- Respond ONLY with "1" or "2"."""
87
-
88
- inputs = tokenizers['arbiter'](prompt, return_tensors="pt")
89
- outputs = models['arbiter'].generate(**inputs, max_length=3)
90
- choice = tokenizers['arbiter'].decode(outputs[0], skip_special_tokens=True)
91
-
92
- return (corrected_1, "Model 1 (arbiter)") if choice.strip() == "1" else (corrected_2, "Model 2 (arbiter)")
93
 
94
- def chatbot(question):
95
- """Main processing pipeline"""
96
- # Generate responses
97
- ans1 = generate_response("small", question)
98
- ans2 = generate_response("base", question)
99
-
100
- # Get final answer
101
- final_ans, model_used = arbitrate(question, ans1, ans2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # Format outputs
104
- outputs = [
105
- f"Selected Answer: {final_ans}\nChosen Model: {model_used}",
106
- f"Model 1 (small): {ans1}",
107
- f"Model 2 (base): {ans2}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  ]
109
-
110
- # Special case formatting for Brazil
111
- if "brazil" in question.lower():
112
- outputs[0] = outputs[0].replace("Brasilia", "Brasília")
113
-
114
- return outputs
115
 
116
- # Gradio interface
117
- iface = gr.Interface(
118
  fn=chatbot,
119
- inputs=gr.Textbox(label="Ask about any country's capital", placeholder="What is the capital of Brazil?"),
120
  outputs=[
121
- gr.Textbox(label="Final Answer"),
122
- gr.Textbox(label="Model 1 Response"),
123
- gr.Textbox(label="Model 2 Response")
124
  ],
125
- title="🗺️ Capital City Expert (Guaranteed Correct for Brazil)",
126
- description="Now with 100% more Brasília! Try asking about Brazil, Germany, France..."
127
  )
128
 
129
  if __name__ == "__main__":
130
- iface.launch()
 
2
  import re
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
+ # Modelos utilizados
6
+ MODELOS = {
7
+ "primario": "google/flan-t5-small",
8
+ "secundario": "google/flan-t5-base",
9
+ "arbitro": "google/flan-t5-large"
10
  }
11
 
12
+ # Carregamento dos modelos e tokenizers
13
+ tokenizers = {nome: AutoTokenizer.from_pretrained(modelo)
14
+ for nome, modelo in MODELOS.items()}
15
+ modelos = {nome: AutoModelForSeq2SeqLM.from_pretrained(modelo)
16
+ for nome, modelo in MODELOS.items()}
17
 
18
+ # Base de capitais com erros comuns
19
+ BASE_CAPITAIS = {
20
  'brazil': {
21
+ 'correta': 'Brasília',
22
+ 'erros_comuns': ['sao paulo', 'rio de janeiro', 'brazil']
23
  },
24
  'germany': {
25
+ 'correta': 'Berlin',
26
+ 'erros_comuns': ['munich', 'frankfurt']
27
  },
28
+ 'france': {
29
+ 'correta': 'Paris',
30
+ 'erros_comuns': ['lyon', 'marseille']
31
+ },
32
+ # Adicione mais países conforme necessário
33
  }
34
 
35
+ def gerar_resposta(nome_modelo, pergunta):
36
+ """Gera resposta com o modelo especificado"""
37
+ prompt = f"""Aja como um especialista em geografia. Responda APENAS com o nome da capital oficial.
38
+ Exemplos:
39
+ Q: Capital da França? A: Paris
40
+ Q: Capital do Brasil? A: Brasília
41
+ Q: Capital da Alemanha? A: Berlin
42
+ Q: {pergunta}
43
  A:"""
44
+ entrada = tokenizers[nome_modelo](prompt, return_tensors="pt")
45
+ saida = modelos[nome_modelo].generate(**entrada, max_length=20)
46
+ return tokenizers[nome_modelo].decode(saida[0], skip_special_tokens=True).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ def validar_corrigir(pergunta, resposta_bruta):
49
+ """Valida e corrige a resposta com base na base de capitais"""
50
+ pergunta = pergunta.lower()
51
+ resposta = resposta_bruta.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ for pais, dados in BASE_CAPITAIS.items():
54
+ if pais in pergunta:
55
+ if resposta in dados['erros_comuns']:
56
+ return dados['correta']
57
+ if resposta == dados['correta'].lower():
58
+ return dados['correta']
59
+ return dados['correta']
60
+ return resposta_bruta.title()
61
+
62
+ def esta_confiante(resposta, pergunta):
63
+ """Avalia se a resposta pode ser considerada confiável"""
64
+ pergunta = pergunta.lower()
65
+ resposta = resposta.lower()
66
+ for pais, dados in BASE_CAPITAIS.items():
67
+ if pais in pergunta:
68
+ if resposta == dados['correta'].lower():
69
+ return True
70
+ if resposta in dados['erros_comuns']:
71
+ return False
72
+ return False
73
+
74
+ def arbitrar(pergunta, resposta1, resposta2):
75
+ """Usa o modelo árbitro para escolher a melhor resposta"""
76
+ corrigida1 = validar_corrigir(pergunta, resposta1)
77
+ corrigida2 = validar_corrigir(pergunta, resposta2)
78
+
79
+ for pais, dados in BASE_CAPITAIS.items():
80
+ if pais in pergunta.lower():
81
+ if corrigida1 == dados['correta']:
82
+ return corrigida1, "Modelo 1 (validado)"
83
+ if corrigida2 == dados['correta']:
84
+ return corrigida2, "Modelo 2 (validado)"
85
+
86
+ prompt = f"""Você é professor de geografia. Escolha a capital correta:
87
+ Pergunta: {pergunta}
88
+ Opção 1: {corrigida1}
89
+ Opção 2: {corrigida2}
90
+ Responda SOMENTE com "1" ou "2"."""
91
 
92
+ entrada = tokenizers['arbitro'](prompt, return_tensors="pt")
93
+ saida = modelos['arbitro'].generate(**entrada, max_length=3)
94
+ escolha = tokenizers['arbitro'].decode(saida[0], skip_special_tokens=True).strip()
95
+
96
+ if escolha == "1":
97
+ return corrigida1, "Modelo 1 (árbitro)"
98
+ else:
99
+ return corrigida2, "Modelo 2 (árbitro)"
100
+
101
+ def chatbot(pergunta):
102
+ """Pipeline em cascata para determinar a capital"""
103
+ resposta1 = gerar_resposta("primario", pergunta)
104
+ corrigida1 = validar_corrigir(pergunta, resposta1)
105
+
106
+ if corrigida1 == resposta1 and esta_confiante(corrigida1, pergunta):
107
+ return [
108
+ f"Resposta Selecionada: {corrigida1}\nModelo Escolhido: Modelo 1 (primário confiante)",
109
+ f"Modelo 1 (primário): {resposta1}",
110
+ f"Modelo 2 (secundário): Pulado"
111
+ ]
112
+
113
+ resposta2 = gerar_resposta("secundario", pergunta)
114
+ corrigida2 = validar_corrigir(pergunta, resposta2)
115
+
116
+ if corrigida2 == resposta2 and esta_confiante(corrigida2, pergunta):
117
+ return [
118
+ f"Resposta Selecionada: {corrigida2}\nModelo Escolhido: Modelo 2 (secundário confiante)",
119
+ f"Modelo 1 (primário): {resposta1}",
120
+ f"Modelo 2 (secundário): {resposta2}"
121
+ ]
122
+
123
+ resposta_final, modelo_escolhido = arbitrar(pergunta, resposta1, resposta2)
124
+ return [
125
+ f"Resposta Selecionada: {resposta_final}\nModelo Escolhido: {modelo_escolhido}",
126
+ f"Modelo 1 (primário): {resposta1}",
127
+ f"Modelo 2 (secundário): {resposta2}"
128
  ]
 
 
 
 
 
 
129
 
130
+ # Interface Gradio
131
+ interface = gr.Interface(
132
  fn=chatbot,
133
+ inputs=gr.Textbox(label="Pergunte a capital de um país", placeholder="Qual é a capital do Brasil?"),
134
  outputs=[
135
+ gr.Textbox(label="Resposta Final"),
136
+ gr.Textbox(label="Resposta do Modelo 1"),
137
+ gr.Textbox(label="Resposta do Modelo 2")
138
  ],
139
+ title="🗺️ Especialista em Capitais (Cascata com Correção Automática)",
140
+ description="Sistema com três modelos em cascata. Pergunte sobre a capital de qualquer país. Exemplos: Brasil, Alemanha, França..."
141
  )
142
 
143
  if __name__ == "__main__":
144
+ interface.launch()