GutoFonseca commited on
Commit
cf558d7
·
verified ·
1 Parent(s): 72bba2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -35
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import gradio as gr
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
- # Modelos escolhidos
5
  model_name1 = "google/flan-t5-small"
6
  model_name2 = "google/flan-t5-base"
7
- arbitro_model_name = "google/flan-t5-base" # Usando um modelo maior para arbitrar
8
 
9
- # Carregar tokenizadores e modelos
10
  tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
11
  model1 = AutoModelForSeq2SeqLM.from_pretrained(model_name1)
12
 
@@ -16,65 +17,70 @@ model2 = AutoModelForSeq2SeqLM.from_pretrained(model_name2)
16
  tokenizer_arbitro = AutoTokenizer.from_pretrained(arbitro_model_name)
17
  model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
18
 
 
19
  def gerar_resposta(model, tokenizer, pergunta):
20
- # Formatar a pergunta para focar na capital
21
- if "capital" not in pergunta.lower():
22
- pergunta = f"Qual é a capital do {pergunta}?"
23
-
24
- inputs = tokenizer(pergunta, return_tensors="pt")
 
 
25
  outputs = model.generate(**inputs, max_length=20)
26
  resposta = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
  return resposta.strip()
28
 
 
 
 
 
 
 
29
  def arbitro(pergunta, resp1, resp2):
30
- # Melhorar o prompt para o árbitro
31
  prompt = (
32
- f"Decida qual resposta está mais correta para esta pergunta sobre geografia:\n"
33
- f"Pergunta: {pergunta}\n"
34
- f"Resposta 1: {resp1}\n"
35
- f"Resposta 2: {resp2}\n\n"
36
- f"Responda apenas com o número 1 ou 2, sem explicações."
37
  )
38
-
39
  inputs = tokenizer_arbitro(prompt, return_tensors="pt")
40
  outputs = model_arbitro.generate(**inputs, max_length=5)
41
  escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
42
 
43
- if escolha == "1":
44
- return resp1, "Modelo 1 (flan-t5-small)"
45
- elif escolha == "2":
46
  return resp2, "Modelo 2 (flan-t5-base)"
47
  else:
48
- # Se não conseguir decidir, escolhe a mais longa (heurística simples)
49
- return (resp1, "Modelo 1 (flan-t5-small)") if len(resp1) > len(resp2) else (resp2, "Modelo 2 (flan-t5-base)")
50
 
 
51
  def chatbot(pergunta):
52
- # Garantir que a pergunta está formatada corretamente
53
- pergunta = pergunta.strip().lower()
54
- if not pergunta.endswith("?"):
55
- pergunta += "?"
56
-
57
  resposta1 = gerar_resposta(model1, tokenizer1, pergunta)
58
  resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
59
  resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
60
 
 
 
 
 
61
  return (
62
- f"Resposta selecionada: {resposta_final}\nModelo escolhido: {modelo_escolhido}",
63
- f"Modelo 1 (flan-t5-small): {resposta1}",
64
- f"Modelo 2 (flan-t5-base): {resposta2}"
65
  )
66
 
 
67
  iface = gr.Interface(
68
  fn=chatbot,
69
- inputs=gr.Textbox(label="Digite um país ou pergunta sobre capitais", placeholder="Ex: Brasil ou Qual a capital do Brasil?"),
70
  outputs=[
71
- gr.Textbox(label="Resposta Final"),
72
- gr.Textbox(label="Resposta do Modelo Pequeno"),
73
- gr.Textbox(label="Resposta do Modelo Base")
74
  ],
75
- title="Chatbot em Cascata - Perguntas sobre Capitais",
76
- description="Digite um país ou pergunta sobre capitais para comparar dois modelos de IA"
77
  )
78
 
79
- if __name__ == "__main__":
80
  iface.launch()
 
1
  import gradio as gr
2
+ import re
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
+ # Modelos
6
  model_name1 = "google/flan-t5-small"
7
  model_name2 = "google/flan-t5-base"
8
+ arbitro_model_name = "google/flan-t5-base"
9
 
10
+ # Carregar modelos e tokenizadores
11
  tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
12
  model1 = AutoModelForSeq2SeqLM.from_pretrained(model_name1)
13
 
 
17
  tokenizer_arbitro = AutoTokenizer.from_pretrained(arbitro_model_name)
18
  model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
19
 
20
+ # Geração da resposta de cada modelo com prompt reforçado
21
  def gerar_resposta(model, tokenizer, pergunta):
22
+ prompt = (
23
+ "Answer ONLY with the name of the capital city for the following question.\n"
24
+ "Do NOT answer with the country name.\n"
25
+ f"Question: {pergunta}\n"
26
+ "Answer:"
27
+ )
28
+ inputs = tokenizer(prompt, return_tensors="pt")
29
  outputs = model.generate(**inputs, max_length=20)
30
  resposta = tokenizer.decode(outputs[0], skip_special_tokens=True)
31
  return resposta.strip()
32
 
33
+ # Função para validar se parece uma capital
34
+ def eh_capital_valida(resposta):
35
+ # Considera resposta válida se tem 1 ou 2 palavras, só letras e espaços
36
+ return bool(re.match(r"^[A-Za-zÀ-ÿ\s]{2,40}$", resposta.strip()))
37
+
38
+ # Árbitro decide qual resposta é melhor
39
  def arbitro(pergunta, resp1, resp2):
 
40
  prompt = (
41
+ "You are a geography expert.\n"
42
+ f"Question: {pergunta}\n"
43
+ f"Answer 1: {resp1}\n"
44
+ f"Answer 2: {resp2}\n"
45
+ "Which answer is the correct capital? Reply only with 1 or 2."
46
  )
 
47
  inputs = tokenizer_arbitro(prompt, return_tensors="pt")
48
  outputs = model_arbitro.generate(**inputs, max_length=5)
49
  escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
50
 
51
+ if escolha == "2":
 
 
52
  return resp2, "Modelo 2 (flan-t5-base)"
53
  else:
54
+ return resp1, "Modelo 1 (flan-t5-small)"
 
55
 
56
+ # Função principal do chatbot
57
  def chatbot(pergunta):
 
 
 
 
 
58
  resposta1 = gerar_resposta(model1, tokenizer1, pergunta)
59
  resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
60
  resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
61
 
62
+ # Validação
63
+ if not eh_capital_valida(resposta_final):
64
+ resposta_final = "Não consegui identificar a capital corretamente."
65
+
66
  return (
67
+ f"Resposta selecionada:\n{resposta_final}\n\nModelo escolhido:\n{modelo_escolhido}",
68
+ f"Resposta Modelo 1 (flan-t5-small):\n{resposta1}",
69
+ f"Resposta Modelo 2 (flan-t5-base):\n{resposta2}"
70
  )
71
 
72
+ # Interface Gradio
73
  iface = gr.Interface(
74
  fn=chatbot,
75
+ inputs=gr.Textbox(label="Digite uma pergunta sobre capitais"),
76
  outputs=[
77
+ gr.Textbox(label="Resposta selecionada"),
78
+ gr.Textbox(label="Resposta Modelo 1"),
79
+ gr.Textbox(label="Resposta Modelo 2")
80
  ],
81
+ title="Chatbot em Cascata - Perguntas sobre Capitais (sem listas)",
82
+ description="Insira uma pergunta como 'Qual é a capital da Alemanha?' e veja como os modelos escolhem a melhor resposta."
83
  )
84
 
85
+ if _name_ == "_main_":
86
  iface.launch()