GutoFonseca commited on
Commit
72bba2a
·
verified ·
1 Parent(s): 5d07d27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -17
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  # Modelos escolhidos
5
  model_name1 = "google/flan-t5-small"
6
  model_name2 = "google/flan-t5-base"
7
- arbitro_model_name = "google/flan-t5-small" # Pode usar o mesmo ou outro
8
 
9
  # Carregar tokenizadores e modelos
10
  tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
@@ -17,21 +17,25 @@ tokenizer_arbitro = AutoTokenizer.from_pretrained(arbitro_model_name)
17
  model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
18
 
19
  def gerar_resposta(model, tokenizer, pergunta):
20
- # Prompt para focar na resposta correta da capital
21
- prompt = f"Responda com o nome da capital do país nesta pergunta: {pergunta}"
22
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
23
  outputs = model.generate(**inputs, max_length=20)
24
  resposta = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
  return resposta.strip()
26
 
27
  def arbitro(pergunta, resp1, resp2):
28
- # Prompt para o árbitro escolher a melhor resposta
29
  prompt = (
 
30
  f"Pergunta: {pergunta}\n"
31
  f"Resposta 1: {resp1}\n"
32
- f"Resposta 2: {resp2}\n"
33
- "Qual resposta é mais correta? Responda apenas 1 ou 2."
34
  )
 
35
  inputs = tokenizer_arbitro(prompt, return_tensors="pt")
36
  outputs = model_arbitro.generate(**inputs, max_length=5)
37
  escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
@@ -41,29 +45,35 @@ def arbitro(pergunta, resp1, resp2):
41
  elif escolha == "2":
42
  return resp2, "Modelo 2 (flan-t5-base)"
43
  else:
44
- # Caso o árbitro não 1 ou 2, escolhe a resposta 1 por padrão
45
- return resp1, "Modelo 1 (flan-t5-small)"
46
 
47
  def chatbot(pergunta):
 
 
 
 
 
48
  resposta1 = gerar_resposta(model1, tokenizer1, pergunta)
49
  resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
50
  resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
51
 
52
  return (
53
- f"Resposta selecionada:\n{resposta_final}\n\nModelo escolhido:\n{modelo_escolhido}",
54
- f"Resposta Modelo 1 (flan-t5-small):\n{resposta1}",
55
- f"Resposta Modelo 2 (flan-t5-base):\n{resposta2}"
56
  )
57
 
58
  iface = gr.Interface(
59
  fn=chatbot,
60
- inputs=gr.Textbox(label="Digite uma pergunta sobre capitais"),
61
  outputs=[
62
- gr.Textbox(label="Resposta selecionada"),
63
- gr.Textbox(label="Resposta Modelo 1"),
64
- gr.Textbox(label="Resposta Modelo 2")
65
  ],
66
- title="Chatbot em Cascata - Perguntas sobre Capitais"
 
67
  )
68
 
69
  if __name__ == "__main__":
 
4
  # Modelos escolhidos
5
  model_name1 = "google/flan-t5-small"
6
  model_name2 = "google/flan-t5-base"
7
+ arbitro_model_name = "google/flan-t5-base" # Usando um modelo maior para arbitrar
8
 
9
  # Carregar tokenizadores e modelos
10
  tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
 
17
  model_arbitro = AutoModelForSeq2SeqLM.from_pretrained(arbitro_model_name)
18
 
19
  def gerar_resposta(model, tokenizer, pergunta):
20
+ # Formatar a pergunta para focar na capital
21
+ if "capital" not in pergunta.lower():
22
+ pergunta = f"Qual é a capital do {pergunta}?"
23
+
24
+ inputs = tokenizer(pergunta, return_tensors="pt")
25
  outputs = model.generate(**inputs, max_length=20)
26
  resposta = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
  return resposta.strip()
28
 
29
  def arbitro(pergunta, resp1, resp2):
30
+ # Melhorar o prompt para o árbitro
31
  prompt = (
32
+ f"Decida qual resposta está mais correta para esta pergunta sobre geografia:\n"
33
  f"Pergunta: {pergunta}\n"
34
  f"Resposta 1: {resp1}\n"
35
+ f"Resposta 2: {resp2}\n\n"
36
+ f"Responda apenas com o número 1 ou 2, sem explicações."
37
  )
38
+
39
  inputs = tokenizer_arbitro(prompt, return_tensors="pt")
40
  outputs = model_arbitro.generate(**inputs, max_length=5)
41
  escolha = tokenizer_arbitro.decode(outputs[0], skip_special_tokens=True).strip()
 
45
  elif escolha == "2":
46
  return resp2, "Modelo 2 (flan-t5-base)"
47
  else:
48
+ # Se não conseguir decidir, escolhe a mais longa (heurística simples)
49
+ return (resp1, "Modelo 1 (flan-t5-small)") if len(resp1) > len(resp2) else (resp2, "Modelo 2 (flan-t5-base)")
50
 
51
  def chatbot(pergunta):
52
+ # Garantir que a pergunta está formatada corretamente
53
+ pergunta = pergunta.strip().lower()
54
+ if not pergunta.endswith("?"):
55
+ pergunta += "?"
56
+
57
  resposta1 = gerar_resposta(model1, tokenizer1, pergunta)
58
  resposta2 = gerar_resposta(model2, tokenizer2, pergunta)
59
  resposta_final, modelo_escolhido = arbitro(pergunta, resposta1, resposta2)
60
 
61
  return (
62
+ f"Resposta selecionada: {resposta_final}\nModelo escolhido: {modelo_escolhido}",
63
+ f"Modelo 1 (flan-t5-small): {resposta1}",
64
+ f"Modelo 2 (flan-t5-base): {resposta2}"
65
  )
66
 
67
  iface = gr.Interface(
68
  fn=chatbot,
69
+ inputs=gr.Textbox(label="Digite um país ou pergunta sobre capitais", placeholder="Ex: Brasil ou Qual a capital do Brasil?"),
70
  outputs=[
71
+ gr.Textbox(label="Resposta Final"),
72
+ gr.Textbox(label="Resposta do Modelo Pequeno"),
73
+ gr.Textbox(label="Resposta do Modelo Base")
74
  ],
75
+ title="Chatbot em Cascata - Perguntas sobre Capitais",
76
+ description="Digite um país ou pergunta sobre capitais para comparar dois modelos de IA"
77
  )
78
 
79
  if __name__ == "__main__":