Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import pandas as pd
|
|
| 7 |
from langchain_core.messages import HumanMessage
|
| 8 |
#from agent import build_graph
|
| 9 |
from agent_simple import build_graph
|
| 10 |
-
|
| 11 |
|
| 12 |
|
| 13 |
# (Keep Constants as is)
|
|
@@ -24,14 +24,51 @@ class BasicAgent:
|
|
| 24 |
print("BasicAgent initialized.")
|
| 25 |
self.graph = build_graph()
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def __call__(self, question: str) -> str:
|
| 28 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 29 |
# Wrap the question in a HumanMessage from langchain_core
|
| 30 |
messages = [HumanMessage(content=question)]
|
| 31 |
messages = self.graph.invoke({"messages": messages})
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
def run_and_submit_all(nb_questions: int, profile: gr.OAuthProfile | None):
|
| 37 |
"""
|
|
|
|
| 7 |
from langchain_core.messages import HumanMessage
|
| 8 |
#from agent import build_graph
|
| 9 |
from agent_simple import build_graph
|
| 10 |
+
import re
|
| 11 |
|
| 12 |
|
| 13 |
# (Keep Constants as is)
|
|
|
|
| 24 |
print("BasicAgent initialized.")
|
| 25 |
self.graph = build_graph()
|
| 26 |
|
| 27 |
+
def extract_final_answer(self, response: str) -> str:
|
| 28 |
+
"""Extract the final answer from the agent's response."""
|
| 29 |
+
# Look for patterns like "Answer: X" or "**Answer:** X"
|
| 30 |
+
answer_patterns = [
|
| 31 |
+
r'\*\*Answer:\*\*\s*(.+?)(?:\n|$)',
|
| 32 |
+
r'Answer:\s*(.+?)(?:\n|$)',
|
| 33 |
+
r'\*\*(.+?)\*\*\s*(?:studio albums?|albums?)',
|
| 34 |
+
r'(?:the answer is|therefore|thus|so)\s*[:\-]?\s*\*\*?(\d+)\*\*?',
|
| 35 |
+
r'(?:count is|total (?:would be|is))\s*\*\*?(\d+)\*\*?',
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
for pattern in answer_patterns:
|
| 39 |
+
match = re.search(pattern, response, re.IGNORECASE | re.MULTILINE)
|
| 40 |
+
if match:
|
| 41 |
+
answer = match.group(1).strip()
|
| 42 |
+
# Clean up common formatting
|
| 43 |
+
answer = re.sub(r'\*\*', '', answer)
|
| 44 |
+
answer = re.sub(r'^"|"$', '', answer) # Remove quotes
|
| 45 |
+
return answer
|
| 46 |
+
|
| 47 |
+
# If no pattern matches, try to extract the last meaningful sentence
|
| 48 |
+
sentences = response.split('.')
|
| 49 |
+
for sentence in reversed(sentences):
|
| 50 |
+
sentence = sentence.strip()
|
| 51 |
+
if sentence and len(sentence) < 100: # Reasonable answer length
|
| 52 |
+
# Look for numbers or short phrases
|
| 53 |
+
if re.search(r'\d+|[a-zA-Z]{1,20}', sentence):
|
| 54 |
+
return sentence
|
| 55 |
+
|
| 56 |
+
# Fallback: return the last 50 characters, cleaned up
|
| 57 |
+
fallback = response.strip()[-50:].strip()
|
| 58 |
+
return fallback
|
| 59 |
+
|
| 60 |
def __call__(self, question: str) -> str:
|
| 61 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 62 |
# Wrap the question in a HumanMessage from langchain_core
|
| 63 |
messages = [HumanMessage(content=question)]
|
| 64 |
messages = self.graph.invoke({"messages": messages})
|
| 65 |
+
full_response = messages['messages'][-1].content
|
| 66 |
+
|
| 67 |
+
# Extract just the final answer
|
| 68 |
+
final_answer = self.extract_final_answer(full_response)
|
| 69 |
+
print(f"Extracted answer: {final_answer}")
|
| 70 |
+
|
| 71 |
+
return final_answer
|
| 72 |
|
| 73 |
def run_and_submit_all(nb_questions: int, profile: gr.OAuthProfile | None):
|
| 74 |
"""
|