File size: 6,063 Bytes
374588f
0a372e8
 
 
 
 
 
 
 
 
374588f
0a372e8
95cdb75
0a372e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95cdb75
0a372e8
95cdb75
0a372e8
 
 
 
 
 
 
 
 
95cdb75
 
0a372e8
 
 
 
 
 
 
 
 
 
 
95cdb75
 
 
0a372e8
 
 
 
 
95cdb75
 
 
 
 
 
 
0a372e8
 
 
 
 
 
 
95cdb75
374588f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95cdb75
374588f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
from datetime import datetime

from langchain.tools.tool_node import ToolCallRequest
from langchain.chat_models import BaseChatModel
from langchain.agents.middleware import (
    ModelRequest,
    ModelResponse,
    wrap_model_call,
    wrap_tool_call,
)
from langchain_core.messages import ToolMessage
from openai import (
    BadRequestError,
    OpenAIError,
    InternalServerError,
    NotFoundError,
    RateLimitError,
)

from config import MAX_MODEL_RETRIES
from src.rag.utilclasses import AgentContext
from src.utils.logging import get_logger

model_logger = get_logger('chain_model_call')
tool_logger  = get_logger('chain_tool_call')

class AgentChainMiddleware:
    _tool_wrapper_middleware = None 
    _model_wrapper_middleware = None


    @classmethod 
    def get_tool_wrapper(cls):
        if cls._tool_wrapper_middleware:
            return cls._tool_wrapper_middleware

        cls._tool_wrapper_middleware = wrap_tool_call(cls._tool_call_wrapper)
        tool_logger.info(f"Initialized tool call wrapper with call inspection")
        return cls._tool_wrapper_middleware

    
    @classmethod
    def get_model_wrapper(cls):
        if cls._model_wrapper_middleware:
            return cls._model_wrapper_middleware

        cls._model_wrapper_middleware = wrap_model_call(cls._model_call_wrapper)
        model_logger.info(f"Initialized model call wrapper with maximum of {MAX_MODEL_RETRIES} retry attempts")
        return cls._model_wrapper_middleware


    @staticmethod
    def _model_call_wrapper(request: ModelRequest, handler):
        context: AgentContext  = request.runtime.context
        model:   BaseChatModel = request.model
        model_logger.info(f"{context.agent_name} is attempting to call model '{model.model_name}'...")
        for attempt in range(1, MAX_MODEL_RETRIES+1):
            try:
                response: ModelResponse = handler(request)
                model_logger.info(f"{context.agent_name} recieved response from model after {attempt} attempt{'s' if attempt > 1 else ''}")
                result = response.result[0]
                metadata = result.response_metadata
                # Check if any errors occured during tool call execution.
                # Some errors might be fatal, making the model unusable in the agent chain
                if hasattr(result, 'invalid_tool_calls') and result.invalid_tool_calls:
                    for invalid_call in result.invalid_tool_calls:
                        fail_reason = invalid_call.get('error', 'Unknown').replace('\n', '')
                        model_logger.warning(f"Failed tool call: {invalid_call['name']}, error: {fail_reason}, retrying the call...")
                        if 'JSONDecodeError' in fail_reason:
                            model_logger.error(f"Model does not support current tool call architecture! Switching to the fallback model...")
                            raise Exception("Unsupported model") 
                elif not result.content and metadata['finish_reason'] != 'tool_calls':
                    model_logger.warning(f"Model returned an empty response, reason - {metadata['finish_reason']}! Retrying the call...")
                else:
                    return response
            except OpenAIError as e:
                match e:
                    case InternalServerError():
                        model_logger.warning(f"[{e.code}] Internal difficulties on the provider side, retrying the call...")
                    case RateLimitError():
                        model_logger.warning(f"[{e.code}] Model is temporary rate limited, retrying the call...")
                    case NotFoundError():
                        model_logger.error(f"[{e.code}] Model cannot be used in the chain, reason: {e.body['message']}")
                        raise e
                    case BadRequestError():
                        model_logger.error(f"[400] Bad request: {e.body['message']}")
                        raise e

                if attempt == MAX_MODEL_RETRIES:
                    model_logger.warning(f"Failed to recieve response from model '{model.model_name}' after {MAX_MODEL_RETRIES} attempt{'s' if attempt > 1 else ''}, reason: {e.body['message']}")
                    model_logger.info(f"Switching to the fallback model...")
                    raise e
            except Exception as e:
                model_logger.error(f"An error occured during model call (possibly backend side): {e}")
                raise e
        
        errormsg = f"{context.agent_name} failed to perform the model call due to unknown reason!"
        model_logger.error(errormsg)
        raise RuntimeError(errormsg)


    @staticmethod
    def _tool_call_wrapper(request: ToolCallRequest, handler):
        context: AgentContext = request.runtime.context or AgentContext(agent_name="Agent")
        
        tool_call = request.tool_call
        tool_logger.info(f"{context.agent_name} is calling tool: {tool_call['name']} with tool call id {tool_call['id']}")
        try:
            response = handler(request)
            tool_logger.info(f"Recieved response from tool call {tool_call['id']}") 
            if not response.content:
                tool_logger.warning("Tool returned nothing! This might be an issue on the tool side.")
            return response       
        except Exception as e:
            tool_logger.error(f"Failed to use tool {tool_call['name']} with id {tool_call['id']}")
            artifact = {
                'error_type': type(e).__name__,
                'error_message': str(e),
                'tool_name': tool_call['name'],
                'tool_args': tool_call['args'],
                'timestamp': datetime.now().isoformat(),
            }

            import json
            error_content = f"""Failed to use tool: {str(e)}

Error details:
{json.dumps(artifact, indent=2)}"""

            return ToolMessage(
                content=error_content,
                tool_call_id=tool_call['id'],
                artifact=artifact,
            )