GutoFonseca commited on
Commit
5ee2042
·
verified ·
1 Parent(s): 54c470c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -59
app.py CHANGED
@@ -1,64 +1,144 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  ],
 
 
60
  )
61
 
62
-
63
  if __name__ == "__main__":
64
- demo.launch()
 
1
  import gradio as gr
2
+ import re
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ # Modelos utilizados
6
+ MODELOS = {
7
+ "primario": "google/flan-t5-small",
8
+ "secundario": "google/flan-t5-base",
9
+ "arbitro": "google/flan-t5-large"
10
+ }
11
+
12
+ # Carregamento dos modelos e tokenizers
13
+ tokenizers = {nome: AutoTokenizer.from_pretrained(modelo)
14
+ for nome, modelo in MODELOS.items()}
15
+ modelos = {nome: AutoModelForSeq2SeqLM.from_pretrained(modelo)
16
+ for nome, modelo in MODELOS.items()}
17
+
18
+ # Base de capitais com erros comuns
19
+ BASE_CAPITAIS = {
20
+ 'brazil': {
21
+ 'correta': 'Brasília',
22
+ 'erros_comuns': ['sao paulo', 'rio de janeiro', 'brazil']
23
+ },
24
+ 'germany': {
25
+ 'correta': 'Berlin',
26
+ 'erros_comuns': ['munich', 'frankfurt']
27
+ },
28
+ 'france': {
29
+ 'correta': 'Paris',
30
+ 'erros_comuns': ['lyon', 'marseille']
31
+ },
32
+ # Adicione mais países conforme necessário
33
+ }
34
+
35
+ def gerar_resposta(nome_modelo, pergunta):
36
+ """Gera resposta com o modelo especificado"""
37
+ prompt = f"""Aja como um especialista em geografia. Responda APENAS com o nome da capital oficial.
38
+ Exemplos:
39
+ Q: Capital da França? A: Paris
40
+ Q: Capital do Brasil? A: Brasília
41
+ Q: Capital da Alemanha? A: Berlin
42
+ Q: {pergunta}
43
+ A:"""
44
+ entrada = tokenizers[nome_modelo](prompt, return_tensors="pt")
45
+ saida = modelos[nome_modelo].generate(**entrada, max_length=20)
46
+ return tokenizers[nome_modelo].decode(saida[0], skip_special_tokens=True).strip()
47
+
48
+ def validar_corrigir(pergunta, resposta_bruta):
49
+ """Valida e corrige a resposta com base na base de capitais"""
50
+ pergunta = pergunta.lower()
51
+ resposta = resposta_bruta.lower()
52
+
53
+ for pais, dados in BASE_CAPITAIS.items():
54
+ if pais in pergunta:
55
+ if resposta in dados['erros_comuns']:
56
+ return dados['correta']
57
+ if resposta == dados['correta'].lower():
58
+ return dados['correta']
59
+ return dados['correta']
60
+ return resposta_bruta.title()
61
+
62
+ def esta_confiante(resposta, pergunta):
63
+ """Avalia se a resposta pode ser considerada confiável"""
64
+ pergunta = pergunta.lower()
65
+ resposta = resposta.lower()
66
+ for pais, dados in BASE_CAPITAIS.items():
67
+ if pais in pergunta:
68
+ if resposta == dados['correta'].lower():
69
+ return True
70
+ if resposta in dados['erros_comuns']:
71
+ return False
72
+ return False
73
+
74
+ def arbitrar(pergunta, resposta1, resposta2):
75
+ """Usa o modelo árbitro para escolher a melhor resposta"""
76
+ corrigida1 = validar_corrigir(pergunta, resposta1)
77
+ corrigida2 = validar_corrigir(pergunta, resposta2)
78
+
79
+ for pais, dados in BASE_CAPITAIS.items():
80
+ if pais in pergunta.lower():
81
+ if corrigida1 == dados['correta']:
82
+ return corrigida1, "Modelo 1 (validado)"
83
+ if corrigida2 == dados['correta']:
84
+ return corrigida2, "Modelo 2 (validado)"
85
+
86
+ prompt = f"""Você é professor de geografia. Escolha a capital correta:
87
+ Pergunta: {pergunta}
88
+ Opção 1: {corrigida1}
89
+ Opção 2: {corrigida2}
90
+ Responda SOMENTE com "1" ou "2"."""
91
+
92
+ entrada = tokenizers['arbitro'](prompt, return_tensors="pt")
93
+ saida = modelos['arbitro'].generate(**entrada, max_length=3)
94
+ escolha = tokenizers['arbitro'].decode(saida[0], skip_special_tokens=True).strip()
95
+
96
+ if escolha == "1":
97
+ return corrigida1, "Modelo 1 (árbitro)"
98
+ else:
99
+ return corrigida2, "Modelo 2 (árbitro)"
100
+
101
+ def chatbot(pergunta):
102
+ """Pipeline em cascata para determinar a capital"""
103
+ resposta1 = gerar_resposta("primario", pergunta)
104
+ corrigida1 = validar_corrigir(pergunta, resposta1)
105
+
106
+ if corrigida1 == resposta1 and esta_confiante(corrigida1, pergunta):
107
+ return [
108
+ f"Resposta Selecionada: {corrigida1}\nModelo Escolhido: Modelo 1 (primário confiante)",
109
+ f"Modelo 1 (primário): {resposta1}",
110
+ f"Modelo 2 (secundário): Pulado"
111
+ ]
112
+
113
+ resposta2 = gerar_resposta("secundario", pergunta)
114
+ corrigida2 = validar_corrigir(pergunta, resposta2)
115
+
116
+ if corrigida2 == resposta2 and esta_confiante(corrigida2, pergunta):
117
+ return [
118
+ f"Resposta Selecionada: {corrigida2}\nModelo Escolhido: Modelo 2 (secundário confiante)",
119
+ f"Modelo 1 (primário): {resposta1}",
120
+ f"Modelo 2 (secundário): {resposta2}"
121
+ ]
122
+
123
+ resposta_final, modelo_escolhido = arbitrar(pergunta, resposta1, resposta2)
124
+ return [
125
+ f"Resposta Selecionada: {resposta_final}\nModelo Escolhido: {modelo_escolhido}",
126
+ f"Modelo 1 (primário): {resposta1}",
127
+ f"Modelo 2 (secundário): {resposta2}"
128
+ ]
129
+
130
+ # Interface Gradio
131
+ interface = gr.Interface(
132
+ fn=chatbot,
133
+ inputs=gr.Textbox(label="Pergunte a capital de um país", placeholder="Qual é a capital do Brasil?"),
134
+ outputs=[
135
+ gr.Textbox(label="Resposta Final"),
136
+ gr.Textbox(label="Resposta do Modelo 1"),
137
+ gr.Textbox(label="Resposta do Modelo 2")
138
  ],
139
+ title="🗺️ Especialista em Capitais (Cascata com Correção Automática)",
140
+ description="Sistema com três modelos em cascata. Pergunte sobre a capital de qualquer país. Exemplos: Brasil, Alemanha, França..."
141
  )
142
 
 
143
  if __name__ == "__main__":
144
+ interface.launch()