Spaces:
Runtime error
Runtime error
Upload 25 files
Browse files- graph/__init__.py +0 -0
- graph/__pycache__/__init__.cpython-311.pyc +0 -0
- graph/__pycache__/workflow.cpython-311.pyc +0 -0
- graph/workflow.py +102 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/vector_store.cpython-311.pyc +0 -0
- models/embedding.py +27 -0
- models/vector_store.py +72 -0
- tests/__init__.py +0 -0
- tests/__pycache__/__init__.cpython-311.pyc +0 -0
- tests/__pycache__/test_api_handler.cpython-311-pytest-8.3.5.pyc +0 -0
- tests/__pycache__/test_rag_agent.cpython-311-pytest-8.3.5.pyc +0 -0
- tests/__pycache__/test_workflow.cpython-311-pytest-8.3.5.pyc +0 -0
- tests/test_api_handler.py +79 -0
- tests/test_rag_agent.py +71 -0
- tests/test_workflow.py +78 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/api_handler.cpython-311.pyc +0 -0
- utils/__pycache__/document_loader.cpython-311.pyc +0 -0
- utils/__pycache__/evaluation.cpython-311.pyc +0 -0
- utils/api_handler.py +68 -0
- utils/document_loader.py +62 -0
- utils/evaluation.py +66 -0
graph/__init__.py
ADDED
|
File without changes
|
graph/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (160 Bytes). View file
|
|
|
graph/__pycache__/workflow.cpython-311.pyc
ADDED
|
Binary file (6.64 kB). View file
|
|
|
graph/workflow.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Any, List, Literal, TypedDict, Annotated, Union
from langchain.schema import Document
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, END
from agents.router_agent import RouterAgent
from agents.weather_agent import WeatherAgent
from agents.rag_agent import RAGAgent
from utils.evaluation import LangSmithEvaluator


class WorkflowState(BaseModel):
    """State for the workflow graph.

    Fields are filled in incrementally as a query flows through the
    graph: router -> (weather | document) -> evaluate.
    """

    query: str = Field(description="The user's original query")
    action: str = Field(default="", description="The action to take: 'weather' or 'document'")
    # default_factory (rather than a literal [] / {}) so every state
    # instance gets its own fresh container instead of a shared default.
    context: List[Dict[str, Any]] = Field(
        default_factory=list, description="Retrieved context (for document queries)"
    )
    weather_data: Dict[str, Any] = Field(
        default_factory=dict, description="Weather data (for weather queries)"
    )
    city: str = Field(default="", description="City for weather queries")
    response: str = Field(default="", description="The final response to the user")
    evaluation: Dict[str, Any] = Field(default_factory=dict, description="Evaluation results")


class LangGraphWorkflow:
    """LangGraph workflow for the AI pipeline.

    Wires a router agent that dispatches each query either to the
    weather agent or to the RAG (document) agent, then runs a basic
    evaluation step on the produced response.
    """

    def __init__(self):
        self.router_agent = RouterAgent()
        self.weather_agent = WeatherAgent()
        self.rag_agent = RAGAgent()
        self.evaluator = LangSmithEvaluator()

        # Build (and compile) the workflow graph once, up front.
        self.workflow = self.build_workflow()

    def route(self, state: WorkflowState) -> WorkflowState:
        """Route the query to the appropriate agent ('weather' or 'document')."""
        action = self.router_agent.route_query(state.query)
        # NOTE(review): .copy(update=...) is the pydantic v1 API; if the
        # project pins pydantic v2, prefer model_copy(update=...) — confirm.
        return state.copy(update={"action": action})

    def process_weather(self, state: WorkflowState) -> WorkflowState:
        """Process weather-related queries via the weather agent."""
        weather_response = self.weather_agent.get_weather_response(state.query)
        return state.copy(update={
            "city": weather_response["city"],
            "weather_data": weather_response["weather_data"],
            "response": weather_response["response"],
        })

    def process_document(self, state: WorkflowState) -> WorkflowState:
        """Process document-related queries via the RAG agent."""
        rag_response = self.rag_agent.get_rag_response(state.query)
        return state.copy(update={
            "context": rag_response["context"],
            "response": rag_response["response"],
        })

    def evaluate_response(self, state: WorkflowState) -> WorkflowState:
        """Attach basic evaluation metrics to the state.

        Only simple local heuristics are computed here; richer metrics
        would come from LangSmith in a real implementation.
        """
        evaluation = {
            "query": state.query,
            "response": state.response,
            "action": state.action,
            # Heuristic: higher confidence when some context/data was found.
            "confidence": 0.95 if state.context or state.weather_data else 0.7,
            "latency": 1.2,  # Example metric
        }

        return state.copy(update={"evaluation": evaluation})

    def build_workflow(self) -> StateGraph:
        """Build and compile the LangGraph workflow."""
        workflow = StateGraph(WorkflowState)

        # Register nodes: node name -> bound method implementing it.
        workflow.add_node("router", self.route)
        workflow.add_node("weather", self.process_weather)
        workflow.add_node("document", self.process_document)
        workflow.add_node("evaluate", self.evaluate_response)

        # Branch on state.action after routing.
        workflow.add_conditional_edges(
            "router",                        # source node
            lambda state: state.action,      # condition function
            {
                "weather": "weather",        # condition value -> target node
                "document": "document",
            },
        )
        # Both branches converge on evaluation, which ends the run.
        workflow.add_edge("weather", "evaluate")
        workflow.add_edge("document", "evaluate")
        workflow.add_edge("evaluate", END)

        workflow.set_entry_point("router")

        return workflow.compile()

    def invoke(self, query: str) -> Dict[str, Any]:
        """Run the workflow for a single query and return the final state."""
        state = WorkflowState(query=query)
        result = self.workflow.invoke(state)
        return result
models/__init__.py
ADDED
|
File without changes
|
models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (161 Bytes). View file
|
|
|
models/__pycache__/vector_store.cpython-311.pyc
ADDED
|
Binary file (4.03 kB). View file
|
|
|
models/embedding.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Dict, Any
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.schema import Document
import os
from dotenv import load_dotenv
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")


class EmbeddingModel:
    """Thin wrapper around Google's Gemini embedding model.

    Provides document-batch and single-query embedding on top of
    ``GoogleGenerativeAIEmbeddings``.
    """

    def __init__(self, api_key: str = GEMINI_API_KEY):
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=api_key,
        )

    def embed_documents(self, documents: List[Document]) -> List[List[float]]:
        """Embed each document's page content; returns one vector per document."""
        return self.embeddings.embed_documents(
            [document.page_content for document in documents]
        )

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string into one vector."""
        return self.embeddings.embed_query(query)
models/vector_store.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Dict, Any, Optional
from langchain.schema import Document
from langchain_community.vectorstores import Qdrant
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from dotenv import load_dotenv
import os

# Load .env before reading settings below — every sibling module does this;
# without it the os.getenv calls may all return None.
load_dotenv()

QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
db_url = os.getenv("db_url")
db_api = os.getenv("db_api")


class VectorStore:
    """Interface to the Qdrant vector database.

    Creates the collection on first use (768-dim cosine vectors, matching
    the Gemini text-embedding-004 output) and exposes add/search helpers
    that never raise — failures are logged and an empty/False result is
    returned instead.
    """

    def __init__(
        self,
        collection_name: str = QDRANT_COLLECTION_NAME,
        db_url: str = db_url,
        db_api: str = db_api,  # Qdrant API key — a string, not an int
        api_key: str = GEMINI_API_KEY
    ):
        self.collection_name = collection_name
        self.embeddings = GoogleGenerativeAIEmbeddings(
            google_api_key=api_key,
            model="models/text-embedding-004"
        )

        # Initialize Qdrant client.
        # NOTE(review): assumes db_url is a bare host (scheme is prepended
        # here) — confirm the .env value does not already include https://.
        self.client = QdrantClient(url=f"https://{db_url}",
                                   api_key=db_api)

        # Create collection if it doesn't exist
        collections = self.client.get_collections().collections
        collection_names = [collection.name for collection in collections]

        if collection_name not in collection_names:
            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=rest.VectorParams(
                    size=768,  # Gemini embedding dimension
                    distance=rest.Distance.COSINE
                )
            )

        # Initialize Qdrant vectorstore
        self.vectorstore = Qdrant(
            client=self.client,
            collection_name=collection_name,
            embeddings=self.embeddings
        )

    def add_documents(self, documents: List[Document]) -> bool:
        """Add documents to the vector store; return True on success."""
        try:
            self.vectorstore.add_documents(documents)
            return True
        except Exception as e:
            # Best-effort: log and report failure rather than raising.
            print(f"Error adding documents to vector store: {str(e)}")
            return False

    def similarity_search(self, query: str, k: int = 4) -> List[Document]:
        """Return the k most similar documents for *query* ([] on error)."""
        try:
            return self.vectorstore.similarity_search(query, k=k)
        except Exception as e:
            print(f"Error during similarity search: {str(e)}")
            return []
tests/__init__.py
ADDED
|
File without changes
|
tests/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (160 Bytes). View file
|
|
|
tests/__pycache__/test_api_handler.cpython-311-pytest-8.3.5.pyc
ADDED
|
Binary file (4.51 kB). View file
|
|
|
tests/__pycache__/test_rag_agent.cpython-311-pytest-8.3.5.pyc
ADDED
|
Binary file (4.4 kB). View file
|
|
|
tests/__pycache__/test_workflow.cpython-311-pytest-8.3.5.pyc
ADDED
|
Binary file (4.65 kB). View file
|
|
|
tests/test_api_handler.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest
from unittest.mock import patch, MagicMock
import json
from utils.api_handler import WeatherAPIHandler
from requests.exceptions import RequestException
from requests.exceptions import HTTPError

class TestWeatherAPIHandler(unittest.TestCase):
    """Unit tests for WeatherAPIHandler; all HTTP traffic is mocked."""

    def setUp(self):
        # Dummy key — no real network credential is required.
        self.api_handler = WeatherAPIHandler(api_key="test_api_key")

        # Sample successful response data
        self.sample_response = {
            "name": "London",
            "sys": {"country": "GB"},
            "main": {
                "temp": 15.5,
                "feels_like": 14.8,
                "humidity": 76
            },
            "weather": [{"description": "scattered clouds"}],
            "wind": {"speed": 3.6}
        }

    @patch('requests.get')
    def test_get_weather_success(self, mock_get):
        """A 200 response's JSON is returned to the caller unchanged."""
        # Configure mock
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = self.sample_response
        mock_get.return_value = mock_response

        # Call the method
        result = self.api_handler.get_weather("London")

        # Assertions
        self.assertEqual(result, self.sample_response)
        mock_get.assert_called_once()

    @patch('requests.get')
    def test_get_weather_city_not_found(self, mock_get):
        """A 404 maps to an error dict that names the missing city."""
        mock_response = MagicMock()
        mock_response.status_code = 404
        # response= so the handler can read e.response.status_code
        mock_response.raise_for_status.side_effect = HTTPError(response=mock_response)
        mock_get.return_value = mock_response

        result = self.api_handler.get_weather("NonExistentCity")

        self.assertIn("error", result)
        self.assertIn("NonExistentCity", result["error"])


    @patch('requests.get')
    def test_get_weather_connection_error(self, mock_get):
        """Network-level failures surface as an error dict, not an exception."""
        mock_get.side_effect = RequestException("Connection Error")

        result = self.api_handler.get_weather("London")

        self.assertIn("error", result)
        self.assertIn("Connection Error", result["error"])

    def test_format_weather_data_success(self):
        """Formatting a full payload includes city, temperature, conditions."""
        # Call the method
        formatted_result = self.api_handler.format_weather_data(self.sample_response)

        # Assertions
        self.assertIn("London", formatted_result)
        self.assertIn("15.5°C", formatted_result)
        self.assertIn("scattered clouds", formatted_result.lower())

    def test_format_weather_data_error(self):
        """An error payload is passed through as the formatted message."""
        # Call the method with incomplete data
        formatted_result = self.api_handler.format_weather_data({"error": "City not found"})

        # Assertions
        self.assertEqual(formatted_result, "City not found")
|
tests/test_rag_agent.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest
from unittest.mock import patch, MagicMock
from agents.rag_agent import RAGAgent
from langchain.schema import Document

class TestRAGAgent(unittest.TestCase):
    """Unit tests for RAGAgent with the vector store and LLM patched out.

    Patches are started in setUp and must be stopped in tearDown so they
    don't leak into other test cases.
    """

    def setUp(self):
        # Create a mock for vector store
        self.vector_store_patch = patch('agents.rag_agent.VectorStore')
        self.mock_vector_store_class = self.vector_store_patch.start()
        self.mock_vector_store = self.mock_vector_store_class.return_value

        # Create a mock for LLM
        self.llm_patch = patch('agents.rag_agent.ChatGoogleGenerativeAI')
        self.mock_llm_class = self.llm_patch.start()
        self.mock_llm = self.mock_llm_class.return_value

        # Sample documents
        self.sample_docs = [
            Document(page_content="This is a test document about AI.", metadata={"source": "test1.pdf"}),
            Document(page_content="LangChain is a framework for LLM applications.", metadata={"source": "test2.pdf"})
        ]

        # Initialize agent — after patching, so it picks up the mocks.
        self.agent = RAGAgent(api_key="test_api_key")

    def tearDown(self):
        self.vector_store_patch.stop()
        self.llm_patch.stop()

    def test_retrieve_context(self):
        """retrieve_context returns exactly what similarity_search yields."""
        # Configure mock
        self.mock_vector_store.similarity_search.return_value = self.sample_docs

        # Call the method
        result = self.agent.retrieve_context("What is LangChain?")

        # Assertions
        self.assertEqual(result, self.sample_docs)
        self.mock_vector_store.similarity_search.assert_called_once()

    def test_get_rag_response_with_context(self):
        """With hits, the chain's answer and serialized context are returned."""
        # Mock similarity_search to return 2 documents
        self.mock_vector_store.similarity_search.return_value = self.sample_docs

        # Mock rag_chain
        mock_chain = MagicMock()
        mock_chain.invoke.return_value.content = "LangChain is a framework for building LLM applications."
        self.agent.rag_chain = mock_chain

        # Call the method
        result = self.agent.get_rag_response("What is LangChain?")

        # Assertions
        self.assertEqual(result["response"], "LangChain is a framework for building LLM applications.")
        self.assertEqual(len(result["context"]), 2)
        self.assertEqual(result["context"][0]["page_content"], "This is a test document about AI.")


    def test_get_rag_response_no_context(self):
        """With no hits, the agent returns an apologetic fallback message."""
        # Configure mock to return empty list
        self.mock_vector_store.similarity_search.return_value = []

        # Call the method
        result = self.agent.get_rag_response("What is LangChain?")

        # Assertions
        self.assertEqual(len(result["context"]), 0)
        self.assertIn("couldn't find any relevant information", result["response"])
tests/test_workflow.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest
from unittest.mock import patch, MagicMock
from graph.workflow import LangGraphWorkflow, WorkflowState

class TestLangGraphWorkflow(unittest.TestCase):
    """Unit tests for LangGraphWorkflow with all four agents patched out.

    Each node method is tested directly against a hand-built WorkflowState,
    so no real graph execution (or external service) is involved.
    """

    def setUp(self):
        # Create mocks for agents — patch the names as imported in
        # graph.workflow, before the workflow instantiates them.
        self.router_agent_patch = patch('graph.workflow.RouterAgent')
        self.weather_agent_patch = patch('graph.workflow.WeatherAgent')
        self.rag_agent_patch = patch('graph.workflow.RAGAgent')
        self.evaluator_patch = patch('graph.workflow.LangSmithEvaluator')

        self.mock_router_agent_class = self.router_agent_patch.start()
        self.mock_weather_agent_class = self.weather_agent_patch.start()
        self.mock_rag_agent_class = self.rag_agent_patch.start()
        self.mock_evaluator_class = self.evaluator_patch.start()

        self.mock_router_agent = self.mock_router_agent_class.return_value
        self.mock_weather_agent = self.mock_weather_agent_class.return_value
        self.mock_rag_agent = self.mock_rag_agent_class.return_value
        self.mock_evaluator = self.mock_evaluator_class.return_value

        # Initialize workflow — after patching, so it receives the mocks.
        self.workflow = LangGraphWorkflow()

    def tearDown(self):
        # Stop every patch so mocks don't leak into other test cases.
        self.router_agent_patch.stop()
        self.weather_agent_patch.stop()
        self.rag_agent_patch.stop()
        self.evaluator_patch.stop()

    def test_route_to_weather(self):
        """route() copies the router's 'weather' decision into state.action."""
        # Configure mock
        self.mock_router_agent.route_query.return_value = "weather"

        # Create state
        state = WorkflowState(query="What's the weather in London?")

        # Call the method
        result = self.workflow.route(state)

        # Assertions
        self.assertEqual(result.action, "weather")
        self.mock_router_agent.route_query.assert_called_once_with("What's the weather in London?")

    def test_route_to_document(self):
        """route() copies the router's 'document' decision into state.action."""
        # Configure mock
        self.mock_router_agent.route_query.return_value = "document"

        # Create state
        state = WorkflowState(query="What is LangChain?")

        # Call the method
        result = self.workflow.route(state)

        # Assertions
        self.assertEqual(result.action, "document")
        self.mock_router_agent.route_query.assert_called_once_with("What is LangChain?")

    def test_process_weather(self):
        """process_weather() maps the agent's dict onto the state fields."""
        # Configure mock
        self.mock_weather_agent.get_weather_response.return_value = {
            "city": "London",
            "weather_data": {"temp": 15.5},
            "response": "The weather in London is 15.5°C."
        }

        # Create state
        state = WorkflowState(query="What's the weather in London?", action="weather")

        # Call the method
        result = self.workflow.process_weather(state)

        # Assertions
        self.assertEqual(result.city, "London")
        self.assertEqual(result.weather_data, {"temp": 15.5})
        self.assertEqual(result.response, "The weather in London is 15.5°C.")
utils/__init__.py
ADDED
|
File without changes
|
utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (160 Bytes). View file
|
|
|
utils/__pycache__/api_handler.cpython-311.pyc
ADDED
|
Binary file (4.01 kB). View file
|
|
|
utils/__pycache__/document_loader.cpython-311.pyc
ADDED
|
Binary file (4.25 kB). View file
|
|
|
utils/__pycache__/evaluation.cpython-311.pyc
ADDED
|
Binary file (3.24 kB). View file
|
|
|
utils/api_handler.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import requests
from typing import Dict, Any, Optional
import json
import os
from dotenv import load_dotenv
load_dotenv()

OPENWEATHERMAP_API_KEY = os.getenv("OPENWEATHERMAP_API_KEY")
WEATHER_API_BASE_URL = "https://api.openweathermap.org/data/2.5/weather"

class WeatherAPIHandler:
    """Handler for the OpenWeatherMap current-weather API.

    Contract: get_weather never raises — it returns either the raw API
    JSON dict or {"error": <message>} so callers need only one code path.
    """

    def __init__(self, api_key: str = OPENWEATHERMAP_API_KEY):
        self.api_key = api_key
        self.base_url = WEATHER_API_BASE_URL

    def get_weather(self, city: str) -> Dict[str, Any]:
        """Fetch weather data for a given city.

        Returns the parsed JSON payload on success, or an {"error": ...}
        dict for HTTP, network, or parse failures.
        """
        params = {
            'q': city,
            'appid': self.api_key,
            'units': 'metric'
        }

        try:
            # timeout so a hung API never blocks the caller indefinitely
            response = requests.get(self.base_url, params=params, timeout=10)
            response.raise_for_status()
            return response.json()

        except requests.exceptions.HTTPError as e:
            # `is not None` — truthiness of Response objects is based on
            # status code and is deprecated for this kind of check.
            status_code = e.response.status_code if e.response is not None else None
            if status_code == 404:
                return {"error": f"City {city} not found"}
            return {"error": f"HTTP Error: {str(e)}"}

        except requests.exceptions.RequestException as e:
            return {"error": f"Request Error: {str(e)}"}

        except json.JSONDecodeError:
            return {"error": "Failed to parse API response"}


    def format_weather_data(self, weather_data: Dict[str, Any]) -> str:
        """Format weather data into a readable multi-line string.

        Error payloads from get_weather are passed through unchanged;
        missing keys produce a generic formatting-error message.
        """
        if "error" in weather_data:
            return weather_data["error"]

        try:
            city = weather_data["name"]
            country = weather_data["sys"]["country"]
            temp = weather_data["main"]["temp"]
            feels_like = weather_data["main"]["feels_like"]
            humidity = weather_data["main"]["humidity"]
            weather_desc = weather_data["weather"][0]["description"]
            wind_speed = weather_data["wind"]["speed"]

            formatted_result = f"""
            Weather in {city}, {country}:
            - Temperature: {temp}°C (Feels like: {feels_like}°C)
            - Conditions: {weather_desc.capitalize()}
            - Humidity: {humidity}%
            - Wind Speed: {wind_speed} m/s
            """
            return formatted_result
        except KeyError:
            return "Error formatting weather data: incomplete or invalid data received"
utils/document_loader.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from typing import List, Dict, Any
import tempfile
from pathlib import Path

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document



class DocumentLoader:
    """Loads PDF files from a directory and splits them into text chunks.

    All public methods are best-effort: failures are printed and an
    empty result ([] or "") is returned rather than raising.
    """

    def __init__(self, document_dir: str = "documents"):
        self.document_dir = document_dir
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

        # Ensure the storage directory exists from the start.
        os.makedirs(document_dir, exist_ok=True)

    def load_pdf(self, file_path: str) -> List[Document]:
        """Load one PDF and return it split into overlapping text chunks."""
        try:
            pages = PyPDFLoader(file_path).load()
            return self.text_splitter.split_documents(pages)
        except Exception as e:
            print(f"Error loading PDF: {str(e)}")
            return []

    def save_uploaded_pdf(self, uploaded_file) -> str:
        """Persist an uploaded PDF under its (sanitized) original name.

        Returns the saved path, or "" if writing failed.
        """
        try:
            # The directory may have been removed since __init__ ran.
            os.makedirs(self.document_dir, exist_ok=True)

            # basename() strips any directory components, preventing
            # path traversal via a crafted upload filename.
            destination = os.path.join(
                self.document_dir, os.path.basename(uploaded_file.name)
            )

            with open(destination, 'wb') as out:
                out.write(uploaded_file.getvalue())

            return destination
        except Exception as e:
            print(f"Error saving uploaded PDF: {str(e)}")
            return ""


    def get_available_documents(self) -> List[str]:
        """Return the PDF filenames currently in the document directory."""
        try:
            entries = os.listdir(self.document_dir)
            return [name for name in entries if name.endswith('.pdf')]
        except Exception as e:
            print(f"Error listing documents: {str(e)}")
            return []
utils/evaluation.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Any
from langsmith import Client
from langchain.smith import RunEvalConfig
from langsmith.evaluation import run_evaluator
from langchain_google_genai import ChatGoogleGenerativeAI
import os
import uuid
from dotenv import load_dotenv
load_dotenv()

LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

class LangSmithEvaluator:
    """Handles evaluation of LLM responses using LangSmith.

    Uses a Gemini model as the judge for the LLM-graded criteria.
    All failures are reported as {"error": ...} rather than raised.
    """

    def __init__(self, api_key: str = LANGSMITH_API_KEY):
        self.client = Client(api_key=api_key)
        # Judge model for LLM-based evaluators.
        self.evaluator_llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=GEMINI_API_KEY)

    def evaluate_response(self, query: str, response: str, reference: str = None) -> Dict[str, Any]:
        """Evaluate an LLM response against a query and optional reference.

        Returns the LangSmith evaluation results, or {"error": ...} if any
        step of the pipeline fails.
        """
        # RunEvalConfig factory classes declare LLM-graded criteria.
        # (The original passed run_evaluator.RunEvalConfig(...) —
        # run_evaluator is a decorator function and has no such attribute,
        # so that call raised AttributeError unconditionally.)
        eval_config = RunEvalConfig(
            evaluators=[
                "criteria",
                "embedding_distance",
                # "correctness" compares against a labeled reference answer.
                RunEvalConfig.LabeledCriteria("correctness"),
                RunEvalConfig.Criteria("helpfulness"),
                RunEvalConfig.Criteria("relevance"),
            ],
            eval_llm=self.evaluator_llm,
        )

        try:
            # Unique name per run: create_dataset fails if the dataset name
            # already exists, so a fixed name would break the second call.
            dataset = self.client.create_dataset(
                f"evaluation_dataset_{uuid.uuid4().hex[:8]}",
                description="Dataset for evaluation of LLM responses"
            )

            # Add the single query/response example to evaluate.
            self.client.create_example(
                inputs={"question": query},
                outputs={"answer": response},
                dataset_id=dataset.id
            )

            # NOTE(review): verify against the pinned langsmith version —
            # newer clients expose `evaluate`/`run_on_dataset` rather than
            # `run_evaluation`; confirm this method name before relying on it.
            evaluation_results = self.client.run_evaluation(
                dataset_id=dataset.id,
                config=eval_config
            )

            return evaluation_results
        except Exception as e:
            print(f"Error during evaluation: {str(e)}")
            return {"error": str(e)}