File size: 9,482 Bytes
b1b9230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238


from langchain_core.output_parsers import PydanticOutputParser
from typing import Callable, Dict, List, Any
import time
import json
from groq_api import grok_get_llm_response, API_llama_get_llm_response, open_oss_get_llm_response, openai_get_llm_response, deepseekapi_get_llm_response
from local_templates import llama3_get_llm_response, mistral_get_llm_response, qwen_get_llm_response, deepseek_get_llm_response, grape_get_llm_response
import os
import re


# Maximum number of Thought/Action/Observe iterations per ImprovedAgent.run() call.
max_steps = 15

# Directory containing this module; prompt .txt files are loaded relative to it.
base_dir = os.path.dirname(os.path.abspath(__file__))


def select_model(model_type: str):
    """Resolve *model_type* to its LLM response function.

    Raises:
        ValueError: if *model_type* is not a recognized key.
    """
    dispatch = {
        "groq_api": grok_get_llm_response,
        "llama_api": API_llama_get_llm_response,
        "oss_api": open_oss_get_llm_response,
        "openai_api": openai_get_llm_response,
        "deepseek_api": deepseekapi_get_llm_response,
        "llama3": llama3_get_llm_response,
        "mistral": mistral_get_llm_response,
        "qwen3": qwen_get_llm_response,
        "deepseek": deepseek_get_llm_response,
        "grape": grape_get_llm_response,
    }

    handler = dispatch.get(model_type)
    if handler is None:
        raise ValueError(f"Unknown model_type: {model_type}")
    return handler


def format_gaia_response(model_type, last_observation, question_out):
    """Produce the final GAIA-format answer from the agent's last observation.

    Args:
        model_type: key understood by select_model() (e.g. "openai_api").
        last_observation: text of the agent's final observation step.
        question_out: the original user question.

    Returns:
        The LLM's formatted final answer string.
    """
    get_llm_response = select_model(model_type)

    # Load the answer-formatting system prompt shipped next to this module.
    with open(os.path.join(base_dir, "system_prompt_final.txt"), "r") as f:
        final_sys_prompt = f.read()

    gaia_prompt = (
        f"{final_sys_prompt}\n\n"
        f"User Question:\n{question_out}\n\n"
        f"Last Observation:\n{last_observation}\n\n"
        # Typo fix: "obervation" -> "observation".
        "Please review user questions and the last observation and respond with the correct answer, in the correct format. No extra text, just the answer."
    )

    final_answer_out = get_llm_response(final_sys_prompt, gaia_prompt, reasoning_format='hidden')

    return final_answer_out


class ImprovedAgent:
    """Plan / Thought / Action / Observe loop agent that drives tool calls via an LLM.

    Each stage sends a stage-specific system prompt plus the running history to
    the LLM, parses the JSON reply, and appends it to ``self.history``.
    ``run`` is the public entry point.
    """

    def __init__(self, tools: Dict[str, Callable], model_type: str):
        """Set up the tool registry, the LLM backend, and the stage prompts.

        Args:
            tools: mapping of tool name -> callable invoked with the
                LLM-chosen keyword arguments.
            model_type: key understood by select_model() (e.g. "openai_api").
        """
        self.tools = tools
        self.history: List[Dict] = []
        self.get_llm_response = select_model(model_type)

        # Load the per-stage system prompts shipped next to this module.
        self.system_prompt_plan = self.load_prompt(base_dir + "/system_prompt_planning.txt")
        self.system_prompt_thought = self.load_prompt(base_dir + "/system_prompt_thought.txt")
        self.system_prompt_action = self.load_prompt(base_dir + "/system_prompt_action.txt")
        self.system_prompt_observe = self.load_prompt(base_dir + "/system_prompt_observe.txt")

    def load_prompt(self, filepath: str) -> str:
        """Read a prompt file and return its full text."""
        # Explicit encoding so prompt files parse identically across platforms.
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read()

    def reset(self):
        """Clear the accumulated history before a new run."""
        self.history = []

    def strip_markdown_code_block(self, text: str) -> str:
        """Remove leading/trailing markdown code-fence markers like ```json or ```."""
        # Remove leading ```json or ``` (case-insensitive).
        text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
        # Remove trailing ```.
        text = re.sub(r"\s*```$", "", text)
        return text.strip()

    def parse_json_response(self, response_text: str) -> Dict:
        """Parse an LLM JSON reply safely.

        Returns the parsed dict, or ``{"error": ...}`` when the text cannot be
        decoded as JSON.
        """
        try:
            cleaned = self.strip_markdown_code_block(response_text.strip())
            json_text = self.extract_json_string(cleaned)
            # LLMs sometimes escape single quotes, which is invalid JSON.
            json_text = json_text.replace("\\'", "'")
            return json.loads(json_text)
        except json.JSONDecodeError as e:
            print(f"[ERROR] JSON Parse Error: {e}")
            print(f"[DEBUG] Raw response: {response_text}")
            return {"error": f"Invalid JSON response: {str(e)}"}

    def extract_json_string(self, text: str) -> str:
        """Extract the first ``{...}`` span from *text*; fall back to the raw text."""
        match = re.search(r'\{.*\}', text, re.DOTALL)
        return match.group(0) if match else text

    def _parse_action_json(self, action_response_text: str) -> Dict:
        """Parse the action agent's reply, tolerating a <think>...</think> preamble.

        Reasoning models may emit a think block before the JSON payload; only
        the text after ``</think>`` is searched for a JSON object.  Returns
        ``{"error": ...}`` when no JSON object can be recovered.
        """
        try:
            if '<think>' in action_response_text and '</think>' in action_response_text:
                json_part = action_response_text.split('</think>')[1].strip()
            else:
                json_part = action_response_text.strip()

            json_match = re.search(r'\{.*\}', json_part)
            if json_match:
                return json.loads(json_match.group())
            return {'error': 'No JSON found in response'}
        except Exception as e:
            return {'error': f'JSON parsing failed: {str(e)}'}

    def build_prompt_from_history(self, query: str) -> str:
        """Render the user query plus the full history as the next stage's input."""
        return f"""User Query: {query}
History: {json.dumps(self.history, indent=2)}
"""

    def run(self, query: str):
        """Execute the agent loop for *query*.

        Returns a 3-tuple ``(history, last_observation_text, final_answer)``.
        When the step limit is reached, the last element is the parsed
        observation dict instead of a final answer string.
        """
        self.reset()

        # Step 1: Planning Agent
        planning_input = f"User Query: {query}"
        print("-----Stage Plan-----")

        plan_response = self.get_llm_response(self.system_prompt_plan, planning_input)
        print("-----Plan Text-----")
        print(plan_response)
        print("-------------------")
        print("-----Plan Parsed-----")
        parsed_plan = self.parse_json_response(plan_response)
        print(parsed_plan)
        print("---------------------")
        self.history.append(parsed_plan)

        current_input = self.build_prompt_from_history(query)

        for step in range(max_steps):  # hard cap of max_steps Thought/Action/Observe iterations

            print(f"-----Iteration {step}-----")
            # Step 2: Thought Agent
            print("-----Stage Thought-----")

            thought_response = self.get_llm_response(self.system_prompt_thought, current_input)
            print(thought_response)
            parsed_thought = self.parse_json_response(thought_response)
            print("-----Thought Parsed-----")
            print(parsed_thought)
            print("-----------------")
            self.history.append(parsed_thought)

            # Step 3: Action Agent
            if "thought" not in parsed_thought:
                # BUGFIX: was a 2-tuple; every other exit from run() is a 3-tuple.
                return self.history, "[ERROR] Thought agent did not return 'thought'. Ending.", ""
            action_input = json.dumps({"thought": parsed_thought["thought"]})
            print("-----Stage Action-----")

            action_response_text = self.get_llm_response(self.system_prompt_action, action_input)
            parsed_action = self._parse_action_json(action_response_text)
            print(parsed_action)
            print("-----------------")
            self.history.append(parsed_action)

            # Step 4: Tool Execution
            tool_name = parsed_action.get("action")
            tool_args = parsed_action.get("action_input", {})
            if not tool_name or tool_name not in self.tools:
                observation = f"[ERROR] Invalid or missing tool: {tool_name}"
            else:
                try:
                    result = self.tools[tool_name](**tool_args)
                    observation = f"Tool `{tool_name}` executed successfully. Output: {result}"
                    print("-----Tool Observation OK-----")
                    print(observation)
                    print("-----------------")

                except Exception as e:
                    observation = f"[ERROR] Tool `{tool_name}` execution failed: {str(e)}"
                    print("-----Tool Observation Fail-----")
                    print(observation)
                    print("-----------------")

            # Store the attempted tool call explicitly in history.
            self.history.append({
                "tool_name": tool_name,
                "tool_args": tool_args,
            })

            # Step 5: Observation Agent

            observation_input = f"""User Query: {query}
                                    Plan: {json.dumps(self.history[0], indent=2)}
                                    History: {json.dumps(self.history, indent=2)}
                                    Tool Output: {observation}
                                    """
            print("-----Stage Observe-----")
            observation_response_text = self.get_llm_response(self.system_prompt_observe, observation_input)

            print("-----Observation Parsed-----")
            parsed_observation = self.parse_json_response(observation_response_text)
            print(parsed_observation)
            print("-----------------")
            self.history.append(parsed_observation)

            # Step 6: Check for final answer
            if "final_answer" in parsed_observation:
                print(parsed_observation["final_answer"])
                return self.history, observation_response_text, parsed_observation["final_answer"]

            # Step 7: Update prompt for next loop
            current_input = self.build_prompt_from_history(query)

        print('ERROR LOOP LIMIT REACHED')
        return self.history, observation_response_text + "This is our last observation. Make your best estimation given the question.",  parsed_observation