Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,6 +34,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
| 34 |
from langchain_community.vectorstores import FAISS
|
| 35 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 36 |
from langchain_community.tools import DuckDuckGoSearchRun
|
|
|
|
| 37 |
|
| 38 |
# =============================================================================
|
| 39 |
# CONFIGURATION
|
|
@@ -42,6 +43,40 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
| 42 |
MAX_TURNS = 20
|
| 43 |
MAX_MESSAGE_LENGTH = 8000
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# =============================================================================
|
| 46 |
# ASR INITIALIZATION
|
| 47 |
# =============================================================================
|
|
@@ -343,57 +378,96 @@ class ScrapeInput(BaseModel):
|
|
| 343 |
@tool(args_schema=ScrapeInput)
|
| 344 |
def scrape_and_retrieve(url: str, query: str) -> str:
|
| 345 |
"""
|
| 346 |
-
Scrapes a webpage,
|
|
|
|
| 347 |
"""
|
| 348 |
if not (url.lower().startswith(('http://', 'https://'))):
|
| 349 |
return f"Error: Invalid URL. Must start with http:// or https://. Got: '{url}'"
|
| 350 |
-
if not query:
|
| 351 |
return "Error: A query is required to search the page content."
|
| 352 |
|
| 353 |
-
#
|
| 354 |
-
if
|
| 355 |
-
|
|
|
|
| 356 |
|
| 357 |
-
print(f"--- Calling RAG Scraper: {url} for query: {query} ---")
|
| 358 |
|
| 359 |
try:
|
|
|
|
| 360 |
headers = {
|
| 361 |
-
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
|
| 362 |
}
|
|
|
|
| 363 |
response = requests.get(url, headers=headers, timeout=20)
|
| 364 |
response.raise_for_status()
|
| 365 |
|
|
|
|
| 366 |
soup = BeautifulSoup(response.text, 'html.parser')
|
| 367 |
-
|
|
|
|
|
|
|
| 368 |
tag.extract()
|
| 369 |
|
| 370 |
-
|
|
|
|
|
|
|
| 371 |
if not main_content:
|
| 372 |
return "Error: Could not find main content on the page."
|
| 373 |
|
|
|
|
| 374 |
text = main_content.get_text(separator='\n', strip=True)
|
| 375 |
-
text = '\n'.join(chunk for chunk in (line.strip() for line in text.splitlines()) if chunk)
|
| 376 |
|
| 377 |
-
|
| 378 |
-
|
|
|
|
| 379 |
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
retriever = db.as_retriever(search_kwargs={"k": 5})
|
| 386 |
retrieved_docs = retriever.invoke(query)
|
| 387 |
|
| 388 |
if not retrieved_docs:
|
| 389 |
-
return "
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
|
|
|
|
|
|
| 394 |
except Exception as e:
|
| 395 |
tb_str = traceback.format_exc()
|
| 396 |
-
return f"Error
|
| 397 |
|
| 398 |
|
| 399 |
class FinalAnswerInput(BaseModel):
|
|
@@ -423,7 +497,7 @@ def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]:
|
|
| 423 |
"""
|
| 424 |
Parses malformed tool call strings from an LLM response.
|
| 425 |
"""
|
| 426 |
-
print(f"Original LLM content for fallback parsing:\n---\n{content}\n---")
|
| 427 |
tool_name = None
|
| 428 |
tool_input = None
|
| 429 |
cleaned_str = None
|
|
@@ -513,7 +587,7 @@ class AgentState(TypedDict):
|
|
| 513 |
|
| 514 |
# =============================================================================
|
| 515 |
# CONDITIONAL EDGE FUNCTION
|
| 516 |
-
|
| 517 |
def should_continue(state: AgentState):
|
| 518 |
"""
|
| 519 |
Decide whether to continue, call tools, or end.
|
|
@@ -562,26 +636,8 @@ class BasicAgent:
|
|
| 562 |
self.tools = defined_tools
|
| 563 |
|
| 564 |
# Initialize RAG Components
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
self.embeddings = HuggingFaceEmbeddings(
|
| 568 |
-
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
| 569 |
-
model_kwargs={'device': 'cpu'}
|
| 570 |
-
)
|
| 571 |
-
self.text_splitter = RecursiveCharacterTextSplitter(
|
| 572 |
-
chunk_size=1000,
|
| 573 |
-
chunk_overlap=200
|
| 574 |
-
)
|
| 575 |
-
|
| 576 |
-
# Attach to scraper tool
|
| 577 |
-
scrape_and_retrieve.embeddings = self.embeddings
|
| 578 |
-
scrape_and_retrieve.text_splitter = self.text_splitter
|
| 579 |
-
|
| 580 |
-
print("✅ RAG components initialized.")
|
| 581 |
-
except Exception as e:
|
| 582 |
-
print(f"⚠️ Warning: Could not initialize RAG components. Error: {e}")
|
| 583 |
-
self.embeddings = None
|
| 584 |
-
self.text_splitter = None
|
| 585 |
|
| 586 |
# Build tool descriptions
|
| 587 |
tool_desc_list = []
|
|
@@ -613,35 +669,24 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
|
|
| 613 |
**CRITICAL RULES:**
|
| 614 |
- **TOOL USE:** You MUST use tools to find the answer. Do NOT use your own knowledge.
|
| 615 |
- **FINAL ANSWER:** When you have the answer, use final_answer_tool. The 'answer' argument must be the answer ONLY (e.g., "42", "red, blue, green").
|
| 616 |
-
- **
|
| 617 |
-
{{"name": "tool_name", "arguments": {{"key": "value"}}}}
|
| 618 |
-
|
| 619 |
-
**EXAMPLE: CODE INTERPRETER**
|
| 620 |
-
{{"name": "code_interpreter", "arguments": {{"code": "print(1 + 1)"}}}}
|
| 621 |
-
|
| 622 |
-
**EXAMPLE: FINAL ANSWER**
|
| 623 |
-
{{"name": "final_answer_tool", "arguments": {{"answer": "28"}}}}
|
| 624 |
|
| 625 |
**TOOLS:**
|
| 626 |
{tool_descriptions}
|
| 627 |
|
| 628 |
-
**REMEMBER:** One step at a time. Use tools.
|
| 629 |
"""
|
| 630 |
|
| 631 |
print("Initializing Groq LLM...")
|
| 632 |
try:
|
|
|
|
| 633 |
self.llm_with_tools = ChatGroq(
|
| 634 |
temperature=0,
|
| 635 |
groq_api_key=GROQ_API_KEY,
|
| 636 |
-
model_name="
|
| 637 |
max_tokens=4096,
|
| 638 |
timeout=60
|
| 639 |
-
).bind_tools(
|
| 640 |
-
self.tools,
|
| 641 |
-
# This setting forces the model to call one of the bound tools.
|
| 642 |
-
# 'auto' is the default, but 'any' is stricter for an agent.
|
| 643 |
-
tool_choice="any"
|
| 644 |
-
)
|
| 645 |
print("✅ Main LLM (llama-3.3-70b-versatile with tools) initialized.")
|
| 646 |
|
| 647 |
except Exception as e:
|
|
@@ -656,7 +701,7 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
|
|
| 656 |
print('='*60)
|
| 657 |
|
| 658 |
if current_turn > MAX_TURNS:
|
| 659 |
-
return {"messages": [SystemMessage(content="Max turns reached.")]}
|
| 660 |
|
| 661 |
max_retries = 3
|
| 662 |
ai_message = None
|
|
@@ -691,7 +736,7 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
|
|
| 691 |
|
| 692 |
# Tool Node
|
| 693 |
tool_node = ToolNode(self.tools)
|
| 694 |
-
|
| 695 |
# Build Graph
|
| 696 |
print("Building Single-Agent graph...")
|
| 697 |
graph_builder = StateGraph(AgentState)
|
|
|
|
| 34 |
from langchain_community.vectorstores import FAISS
|
| 35 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 36 |
from langchain_community.tools import DuckDuckGoSearchRun
|
| 37 |
+
from langchain.docstore.document import Document
|
| 38 |
|
| 39 |
# =============================================================================
|
| 40 |
# CONFIGURATION
|
|
|
|
| 43 |
MAX_TURNS = 20
|
| 44 |
MAX_MESSAGE_LENGTH = 8000
|
| 45 |
|
| 46 |
+
# =============================================================================
# GLOBAL RAG COMPONENTS (Initialize once)
# =============================================================================
# Both objects start as None and are filled in by initialize_rag_components().
global_embeddings = None
global_text_splitter = None


def initialize_rag_components():
    """Create the shared embedding model and text splitter if they are missing.

    Each component is only built while its module-level variable is still
    None, so calling this function again after a success does no extra work.

    Returns:
        bool: True when both components are available afterwards; False when
        constructing the embedding model raised (in which case the text
        splitter is not attempted during this call).
    """
    global global_embeddings, global_text_splitter

    if global_embeddings is None:
        print("Initializing RAG embeddings...")
        try:
            embedding_model = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={'device': 'cpu'},
            )
            global_embeddings = embedding_model
            print("✅ Embeddings initialized.")
        except Exception as exc:
            # Report the failure to the caller instead of raising, so the
            # application can keep running without the RAG tool.
            print(f"⚠️ Failed to initialize embeddings: {exc}")
            return False

    if global_text_splitter is None:
        print("Initializing text splitter...")
        splitter_options = {
            "chunk_size": 1000,
            "chunk_overlap": 200,
            "length_function": len,
            # Separators are listed from coarsest to finest.
            "separators": ["\n\n", "\n", ". ", " ", ""],
        }
        global_text_splitter = RecursiveCharacterTextSplitter(**splitter_options)
        print("✅ Text splitter initialized.")

    return True
|
| 79 |
+
|
| 80 |
# =============================================================================
|
| 81 |
# ASR INITIALIZATION
|
| 82 |
# =============================================================================
|
|
|
|
| 378 |
@tool(args_schema=ScrapeInput)
def scrape_and_retrieve(url: str, query: str) -> str:
    """
    Scrapes a webpage, embeds its content using RAG, and retrieves relevant sections based on the query.
    Use this to extract specific information from web pages.
    """
    # NOTE(review): the docstring above is kept verbatim on purpose -- the
    # @tool decorator presumably publishes it to the LLM as the tool
    # description, so editing it would change agent behavior. Confirm before
    # rewording.
    #
    # Contract: always returns a string. On success it holds up to five
    # retrieved chunks (passed through truncate_if_needed); on any failure it
    # holds a human-readable message instead of raising.

    # Tool arguments produced by an LLM often carry stray surrounding
    # whitespace; strip it so an otherwise valid URL is not rejected.
    url = url.strip()
    if not url.lower().startswith(('http://', 'https://')):
        return f"Error: Invalid URL. Must start with http:// or https://. Got: '{url}'"
    if not query or not query.strip():
        return "Error: A query is required to search the page content."

    # Check if RAG components are initialized
    if global_embeddings is None or global_text_splitter is None:
        if not initialize_rag_components():
            return "Error: RAG components could not be initialized."

    print(f"--- Calling RAG Scraper: {url} for query: '{query}' ---")

    try:
        # Fetch the webpage
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        print(f"Fetching URL: {url}")
        response = requests.get(url, headers=headers, timeout=20)
        response.raise_for_status()

        # Parse HTML from the raw bytes rather than response.text: when the
        # Content-Type header carries no charset, requests decodes text/*
        # bodies as ISO-8859-1, which garbles UTF-8 pages. Only trust the HTTP
        # encoding when the server actually declared one; otherwise let
        # BeautifulSoup detect it (e.g. from a <meta charset> tag).
        content_type = response.headers.get('Content-Type', '')
        declared_encoding = response.encoding if 'charset' in content_type.lower() else None
        soup = BeautifulSoup(response.content, 'html.parser', from_encoding=declared_encoding)

        # Remove unwanted tags
        for tag in soup(["script", "style", "nav", "footer", "aside", "header", "iframe", "noscript"]):
            tag.extract()

        # Try to find main content, from most to least specific container
        main_content = soup.find('main') or soup.find('article') or soup.find('div', class_=re.compile('content|main|article', re.I)) or soup.body

        if not main_content:
            return "Error: Could not find main content on the page."

        # Extract text
        text = main_content.get_text(separator='\n', strip=True)

        # Clean up text - remove extra whitespace and empty lines
        lines = [line.strip() for line in text.splitlines()]
        text = '\n'.join(line for line in lines if line)

        if not text or len(text) < 50:
            return f"Error: Scraped content was too short or empty (length: {len(text)})."

        print(f"Scraped text length: {len(text)} characters")

        # Split text into chunks
        chunks = global_text_splitter.split_text(text)

        if not chunks:
            return "Error: Text could not be split into chunks."

        print(f"Created {len(chunks)} chunks")

        # Create Document objects
        docs = [Document(page_content=chunk, metadata={"source": url}) for chunk in chunks]

        # Create FAISS vector store (rebuilt from scratch for this page)
        print("Creating embeddings and vector store...")
        db = FAISS.from_documents(docs, global_embeddings)

        # Retrieve relevant chunks
        print(f"Searching for: '{query}'")
        retriever = db.as_retriever(search_kwargs={"k": 5})
        retrieved_docs = retriever.invoke(query)

        if not retrieved_docs:
            return f"No relevant information found on {url} for query: '{query}'\n\nThe page was successfully scraped but doesn't seem to contain information matching your query."

        print(f"Retrieved {len(retrieved_docs)} relevant chunks")

        # Combine retrieved chunks, labelling each one with its 1-based rank
        context_parts = [
            f"[Chunk {i}]\n{doc.page_content}"
            for i, doc in enumerate(retrieved_docs, 1)
        ]
        context = "\n\n---\n\n".join(context_parts)

        result = f"Successfully retrieved relevant information from {url}\n\nQuery: {query}\n\n{context}"

        return truncate_if_needed(result)

    except requests.RequestException as e:
        # Network / HTTP-level failures get a dedicated, shorter message.
        return f"Error fetching URL {url}: {str(e)}\n\nThe website may be blocking requests or may be temporarily unavailable."
    except Exception as e:
        # Anything else (parsing, embedding, FAISS) is reported with its
        # traceback so the failure can be diagnosed from the tool output.
        tb_str = traceback.format_exc()
        return f"Error processing {url}: {str(e)}\n\nDetails:\n{tb_str}"
|
| 471 |
|
| 472 |
|
| 473 |
class FinalAnswerInput(BaseModel):
|
|
|
|
| 497 |
"""
|
| 498 |
Parses malformed tool call strings from an LLM response.
|
| 499 |
"""
|
| 500 |
+
print(f"Original LLM content for fallback parsing:\n---\n{content[:500]}\n---")
|
| 501 |
tool_name = None
|
| 502 |
tool_input = None
|
| 503 |
cleaned_str = None
|
|
|
|
| 587 |
|
| 588 |
# =============================================================================
|
| 589 |
# CONDITIONAL EDGE FUNCTION
|
| 590 |
+
# =============================================================================
|
| 591 |
def should_continue(state: AgentState):
|
| 592 |
"""
|
| 593 |
Decide whether to continue, call tools, or end.
|
|
|
|
| 636 |
self.tools = defined_tools
|
| 637 |
|
| 638 |
# Initialize RAG Components
|
| 639 |
+
if not initialize_rag_components():
|
| 640 |
+
print("⚠️ Warning: RAG components failed to initialize. scrape_and_retrieve may not work.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
|
| 642 |
# Build tool descriptions
|
| 643 |
tool_desc_list = []
|
|
|
|
| 669 |
**CRITICAL RULES:**
|
| 670 |
- **TOOL USE:** You MUST use tools to find the answer. Do NOT use your own knowledge.
|
| 671 |
- **FINAL ANSWER:** When you have the answer, use final_answer_tool. The 'answer' argument must be the answer ONLY (e.g., "42", "red, blue, green").
|
| 672 |
+
- **NO CONVERSATIONAL TEXT:** Never add phrases like "The answer is" or "Based on the information". Just the answer.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
|
| 674 |
**TOOLS:**
|
| 675 |
{tool_descriptions}
|
| 676 |
|
| 677 |
+
**REMEMBER:** One step at a time. Use tools. Call final_answer_tool when done.
|
| 678 |
"""
|
| 679 |
|
| 680 |
print("Initializing Groq LLM...")
|
| 681 |
try:
|
| 682 |
+
# Changed from tool_choice="any" to "auto" for better flexibility
|
| 683 |
self.llm_with_tools = ChatGroq(
|
| 684 |
temperature=0,
|
| 685 |
groq_api_key=GROQ_API_KEY,
|
| 686 |
+
model_name="llama-3.3-70b-versatile",
|
| 687 |
max_tokens=4096,
|
| 688 |
timeout=60
|
| 689 |
+
).bind_tools(self.tools, tool_choice="auto")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
print("✅ Main LLM (llama-3.3-70b-versatile with tools) initialized.")
|
| 691 |
|
| 692 |
except Exception as e:
|
|
|
|
| 701 |
print('='*60)
|
| 702 |
|
| 703 |
if current_turn > MAX_TURNS:
|
| 704 |
+
return {"messages": [SystemMessage(content="Max turns reached.")], "turn": current_turn}
|
| 705 |
|
| 706 |
max_retries = 3
|
| 707 |
ai_message = None
|
|
|
|
| 736 |
|
| 737 |
# Tool Node
|
| 738 |
tool_node = ToolNode(self.tools)
|
| 739 |
+
|
| 740 |
# Build Graph
|
| 741 |
print("Building Single-Agent graph...")
|
| 742 |
graph_builder = StateGraph(AgentState)
|