rubenml committed on
Commit
d861b0a
·
verified ·
1 Parent(s): 1f6d6c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -12
app.py CHANGED
@@ -4,6 +4,9 @@ import requests
4
  import pandas as pd
5
  from transformers import pipeline
6
 
 
 
 
7
  # --- Constants ---
8
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
9
 
@@ -18,26 +21,54 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
18
  # --- Basic Agent Definition ---
19
  class GeneralAgent:
20
  def __init__(self):
21
- print("Initializing general QA agent with improved model...")
22
- # Usamos el modelo de RoBERTa entrenado con SQuAD2.0
23
- self.qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
 
24
 
25
  def __call__(self, question: str, context: str = None) -> str:
26
  """
27
- Process the question and return an answer based on the given context.
28
- Uses a prompt template to provide a clear structure for the answer.
29
  """
30
  if context is None:
31
  return "FINAL ANSWER: No context provided."
32
 
33
- try:
34
- result = self.qa_pipeline(question=question, context=context)
35
- answer = result["answer"]
36
- except Exception as e:
37
- print(f"Error during QA: {e}")
38
- answer = "Error processing question."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- return f"FINAL ANSWER: {answer}"
 
 
41
 
42
 
43
 
 
4
  import pandas as pd
5
  from transformers import pipeline
6
 
7
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
8
+ import torch
9
+
10
  # --- Constants ---
11
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
12
 
 
21
  # --- Basic Agent Definition ---
22
  class GeneralAgent:
23
  def __init__(self):
24
+ print("Initializing GPT-2 based QA agent...")
25
+ # Cargar modelo y tokenizador de GPT-2
26
+ self.model = GPT2LMHeadModel.from_pretrained("gpt2")
27
+ self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
28
 
29
  def __call__(self, question: str, context: str = None) -> str:
30
  """
31
+ Procesa la pregunta y genera una respuesta basada en el contexto proporcionado.
32
+ Usa un prompt específico para guiar la respuesta del modelo GPT-2.
33
  """
34
  if context is None:
35
  return "FINAL ANSWER: No context provided."
36
 
37
+ # Crear el prompt para el modelo GPT-2
38
+ prompt = f"""
39
+ You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
40
+ Question: {question}
41
+ Context: {context}
42
+ """
43
+
44
+ # Tokenizar el prompt
45
+ inputs = self.tokenizer.encode(prompt, return_tensors="pt")
46
+
47
+ # Generar la respuesta con GPT-2
48
+ outputs = self.model.generate(inputs, max_length=500, num_return_sequences=1, no_repeat_ngram_size=2, early_stopping=True)
49
+
50
+ # Decodificar la salida
51
+ answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+
53
+ # Extraer la respuesta final
54
+ final_answer = self._extract_final_answer(answer)
55
+
56
+ return f"FINAL ANSWER: {final_answer}"
57
+
58
+ def _extract_final_answer(self, answer: str) -> str:
59
+ """
60
+ Extrae la parte relevante de la respuesta generada por GPT-2 según el formato solicitado.
61
+ """
62
+ # Buscar la sección que comienza con "FINAL ANSWER:"
63
+ final_answer_start = "FINAL ANSWER:"
64
+ start_idx = answer.find(final_answer_start)
65
+
66
+ if start_idx == -1:
67
+ return "Error processing question."
68
 
69
+ # Extraer la respuesta que sigue a "FINAL ANSWER:"
70
+ final_answer = answer[start_idx + len(final_answer_start):].strip()
71
+ return final_answer.strip()
72
 
73
 
74