Spaces:
Sleeping
Sleeping
File size: 6,063 Bytes
374588f 0a372e8 374588f 0a372e8 95cdb75 0a372e8 95cdb75 0a372e8 95cdb75 0a372e8 95cdb75 0a372e8 95cdb75 0a372e8 95cdb75 0a372e8 95cdb75 374588f 95cdb75 374588f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 | from datetime import datetime
from langchain.tools.tool_node import ToolCallRequest
from langchain.chat_models import BaseChatModel
from langchain.agents.middleware import (
ModelRequest,
ModelResponse,
wrap_model_call,
wrap_tool_call,
)
from langchain_core.messages import ToolMessage
from openai import (
BadRequestError,
OpenAIError,
InternalServerError,
NotFoundError,
RateLimitError,
)
from config import MAX_MODEL_RETRIES
from src.rag.utilclasses import AgentContext
from src.utils.logging import get_logger
# Dedicated loggers for the two middleware paths so model-call and tool-call
# events can be filtered and routed independently by the logging config.
model_logger = get_logger('chain_model_call')
tool_logger = get_logger('chain_tool_call')
class AgentChainMiddleware:
_tool_wrapper_middleware = None
_model_wrapper_middleware = None
@classmethod
def get_tool_wrapper(cls):
if cls._tool_wrapper_middleware:
return cls._tool_wrapper_middleware
cls._tool_wrapper_middleware = wrap_tool_call(cls._tool_call_wrapper)
tool_logger.info(f"Initialized tool call wrapper with call inspection")
return cls._tool_wrapper_middleware
@classmethod
def get_model_wrapper(cls):
if cls._model_wrapper_middleware:
return cls._model_wrapper_middleware
cls._model_wrapper_middleware = wrap_model_call(cls._model_call_wrapper)
model_logger.info(f"Initialized model call wrapper with maximum of {MAX_MODEL_RETRIES} retry attempts")
return cls._model_wrapper_middleware
@staticmethod
def _model_call_wrapper(request: ModelRequest, handler):
context: AgentContext = request.runtime.context
model: BaseChatModel = request.model
model_logger.info(f"{context.agent_name} is attempting to call model '{model.model_name}'...")
for attempt in range(1, MAX_MODEL_RETRIES+1):
try:
response: ModelResponse = handler(request)
model_logger.info(f"{context.agent_name} recieved response from model after {attempt} attempt{'s' if attempt > 1 else ''}")
result = response.result[0]
metadata = result.response_metadata
# Check if any errors occured during tool call execution.
# Some errors might be fatal, making the model unusable in the agent chain
if hasattr(result, 'invalid_tool_calls') and result.invalid_tool_calls:
for invalid_call in result.invalid_tool_calls:
fail_reason = invalid_call.get('error', 'Unknown').replace('\n', '')
model_logger.warning(f"Failed tool call: {invalid_call['name']}, error: {fail_reason}, retrying the call...")
if 'JSONDecodeError' in fail_reason:
model_logger.error(f"Model does not support current tool call architecture! Switching to the fallback model...")
raise Exception("Unsupported model")
elif not result.content and metadata['finish_reason'] != 'tool_calls':
model_logger.warning(f"Model returned an empty response, reason - {metadata['finish_reason']}! Retrying the call...")
else:
return response
except OpenAIError as e:
match e:
case InternalServerError():
model_logger.warning(f"[{e.code}] Internal difficulties on the provider side, retrying the call...")
case RateLimitError():
model_logger.warning(f"[{e.code}] Model is temporary rate limited, retrying the call...")
case NotFoundError():
model_logger.error(f"[{e.code}] Model cannot be used in the chain, reason: {e.body['message']}")
raise e
case BadRequestError():
model_logger.error(f"[400] Bad request: {e.body['message']}")
raise e
if attempt == MAX_MODEL_RETRIES:
model_logger.warning(f"Failed to recieve response from model '{model.model_name}' after {MAX_MODEL_RETRIES} attempt{'s' if attempt > 1 else ''}, reason: {e.body['message']}")
model_logger.info(f"Switching to the fallback model...")
raise e
except Exception as e:
model_logger.error(f"An error occured during model call (possibly backend side): {e}")
raise e
errormsg = f"{context.agent_name} failed to perform the model call due to unknown reason!"
model_logger.error(errormsg)
raise RuntimeError(errormsg)
@staticmethod
def _tool_call_wrapper(request: ToolCallRequest, handler):
context: AgentContext = request.runtime.context or AgentContext(agent_name="Agent")
tool_call = request.tool_call
tool_logger.info(f"{context.agent_name} is calling tool: {tool_call['name']} with tool call id {tool_call['id']}")
try:
response = handler(request)
tool_logger.info(f"Recieved response from tool call {tool_call['id']}")
if not response.content:
tool_logger.warning("Tool returned nothing! This might be an issue on the tool side.")
return response
except Exception as e:
tool_logger.error(f"Failed to use tool {tool_call['name']} with id {tool_call['id']}")
artifact = {
'error_type': type(e).__name__,
'error_message': str(e),
'tool_name': tool_call['name'],
'tool_args': tool_call['args'],
'timestamp': datetime.now().isoformat(),
}
import json
error_content = f"""Failed to use tool: {str(e)}
Error details:
{json.dumps(artifact, indent=2)}"""
return ToolMessage(
content=error_content,
tool_call_id=tool_call['id'],
artifact=artifact,
)
|