Spaces:
Sleeping
Sleeping
import json
from datetime import datetime

from langchain.agents.middleware import (
    ModelRequest,
    ModelResponse,
    wrap_model_call,
    wrap_tool_call,
)
from langchain.chat_models import BaseChatModel
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import ToolMessage
from openai import (
    BadRequestError,
    InternalServerError,
    NotFoundError,
    OpenAIError,
    RateLimitError,
)

from config import MAX_MODEL_RETRIES
from src.rag.utilclasses import AgentContext
from src.utils.logging import get_logger
| model_logger = get_logger('chain_model_call') | |
| tool_logger = get_logger('chain_tool_call') | |
| class AgentChainMiddleware: | |
| _tool_wrapper_middleware = None | |
| _model_wrapper_middleware = None | |
| def get_tool_wrapper(cls): | |
| if cls._tool_wrapper_middleware: | |
| return cls._tool_wrapper_middleware | |
| cls._tool_wrapper_middleware = wrap_tool_call(cls._tool_call_wrapper) | |
| tool_logger.info(f"Initialized tool call wrapper with call inspection") | |
| return cls._tool_wrapper_middleware | |
| def get_model_wrapper(cls): | |
| if cls._model_wrapper_middleware: | |
| return cls._model_wrapper_middleware | |
| cls._model_wrapper_middleware = wrap_model_call(cls._model_call_wrapper) | |
| model_logger.info(f"Initialized model call wrapper with maximum of {MAX_MODEL_RETRIES} retry attempts") | |
| return cls._model_wrapper_middleware | |
| def _model_call_wrapper(request: ModelRequest, handler): | |
| context: AgentContext = request.runtime.context | |
| model: BaseChatModel = request.model | |
| model_logger.info(f"{context.agent_name} is attempting to call model '{model.model_name}'...") | |
| for attempt in range(1, MAX_MODEL_RETRIES+1): | |
| try: | |
| response: ModelResponse = handler(request) | |
| model_logger.info(f"{context.agent_name} recieved response from model after {attempt} attempt{'s' if attempt > 1 else ''}") | |
| result = response.result[0] | |
| metadata = result.response_metadata | |
| # Check if any errors occured during tool call execution. | |
| # Some errors might be fatal, making the model unusable in the agent chain | |
| if hasattr(result, 'invalid_tool_calls') and result.invalid_tool_calls: | |
| for invalid_call in result.invalid_tool_calls: | |
| fail_reason = invalid_call.get('error', 'Unknown').replace('\n', '') | |
| model_logger.warning(f"Failed tool call: {invalid_call['name']}, error: {fail_reason}, retrying the call...") | |
| if 'JSONDecodeError' in fail_reason: | |
| model_logger.error(f"Model does not support current tool call architecture! Switching to the fallback model...") | |
| raise Exception("Unsupported model") | |
| elif not result.content and metadata['finish_reason'] != 'tool_calls': | |
| model_logger.warning(f"Model returned an empty response, reason - {metadata['finish_reason']}! Retrying the call...") | |
| else: | |
| return response | |
| except OpenAIError as e: | |
| match e: | |
| case InternalServerError(): | |
| model_logger.warning(f"[{e.code}] Internal difficulties on the provider side, retrying the call...") | |
| case RateLimitError(): | |
| model_logger.warning(f"[{e.code}] Model is temporary rate limited, retrying the call...") | |
| case NotFoundError(): | |
| model_logger.error(f"[{e.code}] Model cannot be used in the chain, reason: {e.body['message']}") | |
| raise e | |
| case BadRequestError(): | |
| model_logger.error(f"[400] Bad request: {e.body['message']}") | |
| raise e | |
| if attempt == MAX_MODEL_RETRIES: | |
| model_logger.warning(f"Failed to recieve response from model '{model.model_name}' after {MAX_MODEL_RETRIES} attempt{'s' if attempt > 1 else ''}, reason: {e.body['message']}") | |
| model_logger.info(f"Switching to the fallback model...") | |
| raise e | |
| except Exception as e: | |
| model_logger.error(f"An error occured during model call (possibly backend side): {e}") | |
| raise e | |
| errormsg = f"{context.agent_name} failed to perform the model call due to unknown reason!" | |
| model_logger.error(errormsg) | |
| raise RuntimeError(errormsg) | |
| def _tool_call_wrapper(request: ToolCallRequest, handler): | |
| context: AgentContext = request.runtime.context or AgentContext(agent_name="Agent") | |
| tool_call = request.tool_call | |
| tool_logger.info(f"{context.agent_name} is calling tool: {tool_call['name']} with tool call id {tool_call['id']}") | |
| try: | |
| response = handler(request) | |
| tool_logger.info(f"Recieved response from tool call {tool_call['id']}") | |
| if not response.content: | |
| tool_logger.warning("Tool returned nothing! This might be an issue on the tool side.") | |
| return response | |
| except Exception as e: | |
| tool_logger.error(f"Failed to use tool {tool_call['name']} with id {tool_call['id']}") | |
| artifact = { | |
| 'error_type': type(e).__name__, | |
| 'error_message': str(e), | |
| 'tool_name': tool_call['name'], | |
| 'tool_args': tool_call['args'], | |
| 'timestamp': datetime.now().isoformat(), | |
| } | |
| import json | |
| error_content = f"""Failed to use tool: {str(e)} | |
| Error details: | |
| {json.dumps(artifact, indent=2)}""" | |
| return ToolMessage( | |
| content=error_content, | |
| tool_call_id=tool_call['id'], | |
| artifact=artifact, | |
| ) | |