# src/rag/middleware.py — agent chain middleware (logging + retry wrappers)
import json
from datetime import datetime

from langchain.agents.middleware import (
    ModelRequest,
    ModelResponse,
    wrap_model_call,
    wrap_tool_call,
)
from langchain.chat_models import BaseChatModel
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import ToolMessage
from openai import (
    BadRequestError,
    InternalServerError,
    NotFoundError,
    OpenAIError,
    RateLimitError,
)

from config import MAX_MODEL_RETRIES
from src.rag.utilclasses import AgentContext
from src.utils.logging import get_logger
model_logger = get_logger('chain_model_call')
tool_logger = get_logger('chain_tool_call')
class AgentChainMiddleware:
_tool_wrapper_middleware = None
_model_wrapper_middleware = None
@classmethod
def get_tool_wrapper(cls):
if cls._tool_wrapper_middleware:
return cls._tool_wrapper_middleware
cls._tool_wrapper_middleware = wrap_tool_call(cls._tool_call_wrapper)
tool_logger.info(f"Initialized tool call wrapper with call inspection")
return cls._tool_wrapper_middleware
@classmethod
def get_model_wrapper(cls):
if cls._model_wrapper_middleware:
return cls._model_wrapper_middleware
cls._model_wrapper_middleware = wrap_model_call(cls._model_call_wrapper)
model_logger.info(f"Initialized model call wrapper with maximum of {MAX_MODEL_RETRIES} retry attempts")
return cls._model_wrapper_middleware
@staticmethod
def _model_call_wrapper(request: ModelRequest, handler):
context: AgentContext = request.runtime.context
model: BaseChatModel = request.model
model_logger.info(f"{context.agent_name} is attempting to call model '{model.model_name}'...")
for attempt in range(1, MAX_MODEL_RETRIES+1):
try:
response: ModelResponse = handler(request)
model_logger.info(f"{context.agent_name} recieved response from model after {attempt} attempt{'s' if attempt > 1 else ''}")
result = response.result[0]
metadata = result.response_metadata
# Check if any errors occured during tool call execution.
# Some errors might be fatal, making the model unusable in the agent chain
if hasattr(result, 'invalid_tool_calls') and result.invalid_tool_calls:
for invalid_call in result.invalid_tool_calls:
fail_reason = invalid_call.get('error', 'Unknown').replace('\n', '')
model_logger.warning(f"Failed tool call: {invalid_call['name']}, error: {fail_reason}, retrying the call...")
if 'JSONDecodeError' in fail_reason:
model_logger.error(f"Model does not support current tool call architecture! Switching to the fallback model...")
raise Exception("Unsupported model")
elif not result.content and metadata['finish_reason'] != 'tool_calls':
model_logger.warning(f"Model returned an empty response, reason - {metadata['finish_reason']}! Retrying the call...")
else:
return response
except OpenAIError as e:
match e:
case InternalServerError():
model_logger.warning(f"[{e.code}] Internal difficulties on the provider side, retrying the call...")
case RateLimitError():
model_logger.warning(f"[{e.code}] Model is temporary rate limited, retrying the call...")
case NotFoundError():
model_logger.error(f"[{e.code}] Model cannot be used in the chain, reason: {e.body['message']}")
raise e
case BadRequestError():
model_logger.error(f"[400] Bad request: {e.body['message']}")
raise e
if attempt == MAX_MODEL_RETRIES:
model_logger.warning(f"Failed to recieve response from model '{model.model_name}' after {MAX_MODEL_RETRIES} attempt{'s' if attempt > 1 else ''}, reason: {e.body['message']}")
model_logger.info(f"Switching to the fallback model...")
raise e
except Exception as e:
model_logger.error(f"An error occured during model call (possibly backend side): {e}")
raise e
errormsg = f"{context.agent_name} failed to perform the model call due to unknown reason!"
model_logger.error(errormsg)
raise RuntimeError(errormsg)
@staticmethod
def _tool_call_wrapper(request: ToolCallRequest, handler):
context: AgentContext = request.runtime.context or AgentContext(agent_name="Agent")
tool_call = request.tool_call
tool_logger.info(f"{context.agent_name} is calling tool: {tool_call['name']} with tool call id {tool_call['id']}")
try:
response = handler(request)
tool_logger.info(f"Recieved response from tool call {tool_call['id']}")
if not response.content:
tool_logger.warning("Tool returned nothing! This might be an issue on the tool side.")
return response
except Exception as e:
tool_logger.error(f"Failed to use tool {tool_call['name']} with id {tool_call['id']}")
artifact = {
'error_type': type(e).__name__,
'error_message': str(e),
'tool_name': tool_call['name'],
'tool_args': tool_call['args'],
'timestamp': datetime.now().isoformat(),
}
import json
error_content = f"""Failed to use tool: {str(e)}
Error details:
{json.dumps(artifact, indent=2)}"""
return ToolMessage(
content=error_content,
tool_call_id=tool_call['id'],
artifact=artifact,
)