# rag_agent_api/openrouter_agent.py
"""
OpenRouter Agent module for the RAG Agent and API Layer system.
This module provides functionality for creating and managing an OpenRouter agent
that generates responses based on retrieved context.
"""
import asyncio
import logging
from typing import Any, Dict, List
import httpx
from .config import get_config
from .schemas import AgentContext, AgentResponse, SourceChunkSchema
from .utils import format_confidence_score
class OpenRouterAgent:
"""
A class to manage the OpenRouter agent for generating responses based on context.
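
    Example (a minimal sketch; assumes OPENROUTER_API_KEY is set and that an
    AgentContext has already been built from retrieval results elsewhere):

        agent = OpenRouterAgent()
        response = await agent.generate_response(context)
        print(response.raw_response, response.confidence_score)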
"""
def __init__(self, model_name: str = "arcee-ai/trinity-mini:free"):
"""
Initialize the OpenRouter agent with configuration.
Args:
model_name: Name of the OpenRouter model to use (default: arcee-ai/trinity-mini:free)
"""
config = get_config()
api_key = config.openrouter_api_key
if not api_key:
raise ValueError("OPENROUTER_API_KEY environment variable not set")
self.api_key = api_key
self.model_name = model_name
self.base_url = "https://openrouter.ai/api/v1"
self.default_temperature = config.default_temperature
logging.info(f"OpenRouter agent initialized with model: {model_name}")
async def generate_response(self, context: AgentContext) -> AgentResponse:
"""
Generate a response based on the provided context.
Args:
context: AgentContext containing the query and retrieved context chunks
Returns:
AgentResponse with the generated answer and metadata
"""
# Check if retrieved context is empty (no chunks at all)
if not context.retrieved_chunks:
return AgentResponse(
raw_response="I could not find this information in the book.",
used_sources=[],
confidence_score=0.0,
is_valid=True,
validation_details="No context chunks retrieved from the database",
unsupported_claims=[]
)
# Check if context is insufficient (very short content)
total_context_length = sum(len(chunk.content) for chunk in context.retrieved_chunks)
        if total_context_length < 10:  # Fewer than 10 characters of total context is effectively empty
return AgentResponse(
raw_response="I could not find this information in the book.",
used_sources=[],
confidence_score=0.0,
is_valid=True,
validation_details="No sufficient context provided to answer the question",
unsupported_claims=[]
)
try:
# Prepare the system message with instructions for grounding responses
system_message = self._create_system_message(context)
# Prepare the user message with the query
user_message = self._create_user_message(context)
# Prepare the payload for OpenRouter API
payload = {
"model": self.model_name,
"messages": [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
"temperature": context.source_policy if hasattr(context, 'temperature') else self.default_temperature,
"max_tokens": 1000
}
# Make the API call to OpenRouter
async with httpx.AsyncClient(timeout=30.0) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = await client.post(
f"{self.base_url}/chat/completions",
json=payload,
headers=headers
)
if response.status_code != 200:
logging.error(f"OpenRouter API error: {response.status_code} - {response.text}")
return AgentResponse(
raw_response="I could not find this information in the book.",
used_sources=[],
confidence_score=0.0,
is_valid=False,
validation_details=f"API error: {response.status_code}",
unsupported_claims=[]
)
response_data = response.json()
raw_response = response_data["choices"][0]["message"]["content"]
# If the response indicates no information was found, return the exact message
if "I could not find this information in the book" in raw_response:
return AgentResponse(
raw_response="I could not find this information in the book.",
used_sources=[],
confidence_score=0.0,
is_valid=True,
validation_details="No relevant information found in the provided context",
unsupported_claims=[]
)
# Determine which sources were used (this is a simplified approach)
used_sources = self._identify_used_sources(raw_response, context.retrieved_chunks)
# Calculate confidence score (based on similarity scores of used sources)
confidence_score = self._calculate_confidence_score(used_sources, context.retrieved_chunks)
# Validate that the response is grounded in the provided context
grounding_validation = self._validate_response_grounding(
raw_response, context.retrieved_chunks, context.query
)
# Create and return the agent response
agent_response = AgentResponse(
raw_response=raw_response,
used_sources=used_sources,
confidence_score=confidence_score,
is_valid=grounding_validation["is_valid"],
validation_details=grounding_validation["details"],
unsupported_claims=grounding_validation["unsupported_claims"]
)
logging.info(f"Agent response generated successfully. Confidence: {confidence_score:.2f}")
return agent_response
except Exception as e:
logging.error(f"Error generating response from OpenRouter agent: {e}", exc_info=True)
# Return the specific message when there's an error
return AgentResponse(
raw_response="I could not find this information in the book.",
used_sources=[],
confidence_score=0.0,
is_valid=False,
validation_details=f"Error generating response: {str(e)}",
unsupported_claims=[]
)
def _create_system_message(self, context: AgentContext) -> str:
"""
Create the system message that instructs the agent on how to behave.
Args:
context: AgentContext containing the query and retrieved context chunks
Returns:
Formatted system message string
"""
system_prompt = """You are a documentation-based assistant.
Answer ONLY using the provided context from the book
"Physical AI & Humanoid Robotics".
If the answer is not found, reply EXACTLY:
"I could not find this information in the book."""
return system_prompt
def _create_user_message(self, context: AgentContext) -> str:
"""
Create the user message containing the query.
Args:
context: AgentContext containing the query and retrieved context chunks
Returns:
Formatted user message string
"""
return f"""CONTEXT:
{self._format_context_chunks(context.retrieved_chunks)}
QUESTION:
{context.query}"""
def _format_context_chunks(self, chunks: List[SourceChunkSchema]) -> str:
"""
Format the context chunks for the prompt.
Args:
chunks: List of source chunks to format
Returns:
Formatted context string
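
        Example output for two chunks (content elided):

            [Chunk 1]
            ...first chunk content...
            [/Chunk 1]
            [Chunk 2]
            ...second chunk content...
            [/Chunk 2]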
"""
if not chunks:
return ""
formatted_chunks = []
for i, chunk in enumerate(chunks):
formatted_chunks.append(f"[Chunk {i+1}]\n{chunk.content}\n[/Chunk {i+1}]")
return "\n".join(formatted_chunks)
def _identify_used_sources(self, response: str, chunks: List[SourceChunkSchema]) -> List[str]:
"""
Identify which sources were likely used in the response.
This is a simplified approach - in a real implementation, you might use
more sophisticated techniques like semantic similarity.
Args:
response: The agent's response text
chunks: List of source chunks that were provided to the agent
Returns:
List of source IDs that were likely used
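
        Example: if more than two of the first 20 words of a chunk (after
        lowercasing) also appear in the response, that chunk's id is included.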
"""
used_sources = []
response_lower = response.lower()
for chunk in chunks:
# Check if any significant words from the chunk appear in the response
content_words = set(chunk.content.lower().split()[:20]) # Check first 20 words
response_words = set(response_lower.split())
# If there's significant overlap, consider this chunk as used
overlap = content_words.intersection(response_words)
if len(overlap) > 2: # Arbitrary threshold
used_sources.append(chunk.id)
# If no sources were identified, return all sources (conservative approach)
if not used_sources:
used_sources = [chunk.id for chunk in chunks]
return used_sources
def _calculate_confidence_score(self, used_sources: List[str], chunks: List[SourceChunkSchema]) -> float:
"""
Calculate a confidence score based on the quality of the used sources.
Args:
used_sources: List of source IDs that were used
chunks: List of all source chunks that were provided to the agent
Returns:
Confidence score between 0.0 and 1.0
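
        Example: two used chunks with similarity scores 0.8 and 0.6 average to
        0.7, which is then passed through format_confidence_score.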
"""
if not used_sources:
return 0.1 # Low confidence if no sources were used
# Calculate average similarity score of used sources
total_similarity = 0.0
used_count = 0
for chunk in chunks:
if chunk.id in used_sources:
total_similarity += chunk.similarity_score
used_count += 1
if used_count == 0:
return 0.1 # Low confidence if no matching chunks found
avg_similarity = total_similarity / used_count
        # Very low similarity scores can stem from embedding issues rather than
        # irrelevant content, so keep a small non-zero confidence floor when we
        # do have content to work with.
        if avg_similarity < 0.1:
            return 0.3
        # Otherwise, higher average similarity maps to higher confidence.
        return format_confidence_score(avg_similarity)
def _validate_response_grounding(self, response: str, chunks: List[SourceChunkSchema], query: str) -> Dict[str, Any]:
"""
Validate that the response is grounded in the provided context.
Args:
response: The agent's response text
chunks: List of source chunks that were provided to the agent
query: The original query
Returns:
            Dictionary with validation results
        """
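        # Example: a response with 20 distinct words, 12 of which appear in the
        # context, has an overlap ratio of 0.6 and clears the 0.3 threshold below.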
"""
# Check if the response contains elements from the provided context
response_lower = response.lower()
context_text = " ".join([chunk.content.lower() for chunk in chunks])
# Simple heuristic: check if response contains significant terms from context
response_words = set(response_lower.split())
context_words = set(context_text.split())
# Calculate overlap between response and context
overlap = response_words.intersection(context_words)
total_response_words = len(response_words)
overlap_count = len(overlap)
        # If less than 30% of the response's words appear in the context, flag
        # the response as potentially ungrounded.
        unsupported_claims = []
        grounding_ratio = (
            overlap_count / total_response_words if total_response_words > 0 else 0.0
        )
        is_grounded = total_response_words == 0 or grounding_ratio >= 0.3
        # This is a basic word-overlap check; a more sophisticated implementation
        # would analyze individual claims against the context.
        details = f"Response grounding validation completed. Context overlap ratio: {grounding_ratio:.2f}"
return {
"is_valid": is_grounded,
"details": details,
"unsupported_claims": unsupported_claims
}
async def validate_response_quality(self, response: str, context: AgentContext) -> bool:
"""
Validate the quality of the agent's response.
Args:
response: The agent's response text
context: AgentContext containing the query and retrieved context chunks
Returns:
True if response meets quality standards, False otherwise
"""
# Check for common signs of poor quality responses
if not response or response.strip() == "":
logging.warning("Agent returned an empty response")
return False
# Check if response contains generic fallback phrases
lower_response = response.lower()
if "i don't know" in lower_response or "i don't have" in lower_response:
# This might be a valid response if there's no relevant context
if len(context.retrieved_chunks) == 0:
return True # Valid response if no context was provided
else:
# Check if the response is justified given the context
# For now, we'll consider it valid if it acknowledges the lack of relevant information
return True
# In a more sophisticated implementation, you'd validate against the context more rigorously
return True
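

if __name__ == "__main__":
    # Minimal usage sketch, not part of the API surface. It assumes
    # OPENROUTER_API_KEY is set and that AgentContext/SourceChunkSchema accept
    # the field names shown below (they are defined in .schemas and may differ
    # in your deployment). Run as `python -m rag_agent_api.openrouter_agent`
    # so the relative imports resolve.
    async def _demo() -> None:
        chunk = SourceChunkSchema(
            id="chunk-1",
            content="Physical AI combines robotics hardware with learned control policies.",
            similarity_score=0.82,
        )
        context = AgentContext(query="What is Physical AI?", retrieved_chunks=[chunk])
        agent = OpenRouterAgent()
        response = await agent.generate_response(context)
        print(response.raw_response)
        print(f"Confidence: {response.confidence_score:.2f}")

    asyncio.run(_demo())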