| |
|
| |
|
| | from langchain_core.output_parsers import PydanticOutputParser |
| | from typing import Callable, Dict, List, Any |
| | import time |
| | import json |
| | from groq_api import grok_get_llm_response, API_llama_get_llm_response, open_oss_get_llm_response, openai_get_llm_response, deepseekapi_get_llm_response |
| | from local_templates import llama3_get_llm_response, mistral_get_llm_response, qwen_get_llm_response, deepseek_get_llm_response, grape_get_llm_response |
| | import os |
| | import re |
| |
|
| |
|
# Maximum number of Thought -> Action -> Observe iterations before the agent
# gives up and returns its best-effort answer.
max_steps = 15

# Directory containing this file; all system-prompt templates are loaded
# relative to it.
base_dir = os.path.dirname(os.path.abspath(__file__))
| |
|
| |
|
def select_model(model_type: str):
    """Return the correct LLM response function for a given model_type."""
    dispatch = {
        "groq_api": grok_get_llm_response,
        "llama_api": API_llama_get_llm_response,
        "oss_api": open_oss_get_llm_response,
        "openai_api": openai_get_llm_response,
        "deepseek_api": deepseekapi_get_llm_response,
        "llama3": llama3_get_llm_response,
        "mistral": mistral_get_llm_response,
        "qwen3": qwen_get_llm_response,
        "deepseek": deepseek_get_llm_response,
        "grape": grape_get_llm_response,
    }
    # EAFP: attempt the lookup and translate a miss into the same ValueError
    # the caller already expects.
    try:
        return dispatch[model_type]
    except KeyError:
        raise ValueError(f"Unknown model_type: {model_type}") from None
| |
|
| |
|
def format_gaia_response(model_type, last_observation, question_out):
    """Produce the final, strictly formatted answer from the last observation.

    Parameters
    ----------
    model_type : str
        Key identifying which LLM backend to use (see ``select_model``).
    last_observation : str
        Final observation text produced by the agent loop.
    question_out : str
        The original user question.

    Returns
    -------
    The LLM backend's response, expected to contain only the answer text.
    """
    get_llm_response = select_model(model_type)

    # Build the path with os.path.join instead of string concatenation; read
    # with an explicit encoding so the prompt loads identically everywhere.
    with open(os.path.join(base_dir, "system_prompt_final.txt"), "r", encoding="utf-8") as f:
        final_sys_prompt = f.read()

    gaia_prompt = (
        f"{final_sys_prompt}\n\n"
        f"User Question:\n{question_out}\n\n"
        f"Last Observation:\n{last_observation}\n\n"
        # Fixed typo in the instruction text: "obervation" -> "observation".
        "Please review user questions and the last observation and respond with the correct answer, in the correct format. No extra text, just the answer."
    )

    final_answer_out = get_llm_response(final_sys_prompt, gaia_prompt, reasoning_format='hidden')

    return final_answer_out
| |
|
| |
|
class ImprovedAgent:
    """LLM-driven agent running a Plan -> (Thought -> Action -> Observe) loop.

    The agent first asks the LLM for an overall plan, then iterates up to
    ``max_steps`` times: generate a thought, turn it into a tool action,
    execute the tool, and ask the LLM to interpret the result.  The loop
    ends as soon as the Observe stage emits a ``final_answer`` key.
    """

    def __init__(self, tools: Dict[str, Callable], model_type: str):
        # Tool registry: name -> callable invoked by the Action stage.
        self.tools = tools
        # Transcript of parsed stage outputs, replayed into later prompts.
        self.history: List[Dict] = []
        # Resolve the LLM backend once; raises ValueError on unknown model_type.
        self.get_llm_response = select_model(model_type)

        # Stage-specific system prompts, stored next to this module.
        self.system_prompt_plan = self.load_prompt(os.path.join(base_dir, "system_prompt_planning.txt"))
        self.system_prompt_thought = self.load_prompt(os.path.join(base_dir, "system_prompt_thought.txt"))
        self.system_prompt_action = self.load_prompt(os.path.join(base_dir, "system_prompt_action.txt"))
        self.system_prompt_observe = self.load_prompt(os.path.join(base_dir, "system_prompt_observe.txt"))

    def load_prompt(self, filepath: str) -> str:
        """Read and return the entire prompt file at *filepath*."""
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read()

    def reset(self):
        """Clear the conversation history before a new run."""
        self.history = []

    def strip_markdown_code_block(self, text: str) -> str:
        """Remove leading/trailing markdown code fences such as ```json or ```."""
        text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
        text = re.sub(r"\s*```$", "", text)
        return text.strip()

    def extract_json_string(self, text: str) -> str:
        """Return the first {...} span in *text*, or *text* unchanged if none."""
        match = re.search(r'\{.*\}', text, re.DOTALL)
        return match.group(0) if match else text

    def parse_json_response(self, response_text: str) -> Dict:
        """Parse an LLM response into a dict, tolerating fences and noise.

        Returns an ``{"error": ...}`` dict instead of raising when the text
        does not contain valid JSON.
        """
        try:
            cleaned = self.strip_markdown_code_block(response_text.strip())
            json_text = self.extract_json_string(cleaned)
            # Some models escape single quotes inside JSON strings; undo that.
            json_text = json_text.replace("\\'", "'")
            return json.loads(json_text)
        except json.JSONDecodeError as e:
            print(f"[ERROR] JSON Parse Error: {e}")
            print(f"[DEBUG] Raw response: {response_text}")
            return {"error": f"Invalid JSON response: {str(e)}"}

    def _parse_action_response(self, action_response_text: str) -> Dict:
        """Parse the Action stage output, skipping any <think>...</think> preamble.

        Returns an ``{"error": ...}`` dict rather than raising on bad input.
        """
        try:
            # Reasoning models may prepend a <think> block; keep only the tail.
            if '<think>' in action_response_text and '</think>' in action_response_text:
                json_part = action_response_text.split('</think>')[1].strip()
            else:
                json_part = action_response_text.strip()
            # re.DOTALL so multi-line JSON objects are matched (the original
            # single-line search silently failed on pretty-printed JSON).
            json_match = re.search(r'\{.*\}', json_part, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
            return {'error': 'No JSON found in response'}
        except Exception as e:
            return {'error': f'JSON parsing failed: {str(e)}'}

    def build_prompt_from_history(self, query: str) -> str:
        """Render the user query plus accumulated history as the next prompt."""
        return f"""User Query: {query}
History: {json.dumps(self.history, indent=2)}
"""

    def run(self, query: str):
        """Execute the full agent loop for *query*.

        Returns ``(history, last_observe_text, final_answer)`` on success,
        or a best-effort 3-tuple when an error occurs or the step limit is
        reached.
        """
        self.reset()

        # --- Plan stage: ask the LLM for an overall plan once, up front. ---
        planning_input = f"User Query: {query}"
        print("-----Stage Plan-----")
        plan_response = self.get_llm_response(self.system_prompt_plan, planning_input)
        print("-----Plan Text-----")
        print(plan_response)
        print("-------------------")
        print("-----Plan Parsed-----")
        parsed_plan = self.parse_json_response(plan_response)
        print(parsed_plan)
        print("---------------------")
        self.history.append(parsed_plan)

        current_input = self.build_prompt_from_history(query)

        for step in range(max_steps):
            print(f"-----Iteration {step}-----")

            # --- Thought stage ---
            print("-----Stage Thought-----")
            thought_response = self.get_llm_response(self.system_prompt_thought, current_input)
            print(thought_response)
            parsed_thought = self.parse_json_response(thought_response)
            print("-----Thought Parsed-----")
            print(parsed_thought)
            print("-----------------")
            self.history.append(parsed_thought)

            if "thought" not in parsed_thought:
                # Keep the three-element return shape used by every other
                # exit path (the original returned only two items here).
                return self.history, "[ERROR] Thought agent did not return 'thought'. Ending.", ""

            # --- Action stage ---
            action_input = json.dumps({"thought": parsed_thought["thought"]})
            print("-----Stage Action-----")
            action_response_text = self.get_llm_response(self.system_prompt_action, action_input)
            parsed_action = self._parse_action_response(action_response_text)
            print(parsed_action)
            print("-----------------")
            self.history.append(parsed_action)

            # --- Tool execution ---
            tool_name = parsed_action.get("action")
            tool_args = parsed_action.get("action_input", {})
            if not tool_name or tool_name not in self.tools:
                observation = f"[ERROR] Invalid or missing tool: {tool_name}"
            else:
                try:
                    result = self.tools[tool_name](**tool_args)
                    observation = f"Tool `{tool_name}` executed successfully. Output: {result}"
                    print("-----Tool Observation OK-----")
                    print(observation)
                    print("-----------------")
                except Exception as e:
                    observation = f"[ERROR] Tool `{tool_name}` execution failed: {str(e)}"
                    print("-----Tool Observation Fail-----")
                    print(observation)
                    print("-----------------")

            # NOTE(review): the tool observation text itself is not stored in
            # history (it is only fed to the Observe prompt below) — confirm
            # this is the intended history shape before changing it.
            self.history.append({
                "tool_name": tool_name,
                "tool_args": tool_args,
            })

            # --- Observe stage ---
            observation_input = f"""User Query: {query}
Plan: {json.dumps(self.history[0], indent=2)}
History: {json.dumps(self.history, indent=2)}
Tool Output: {observation}
"""
            print("-----Stage Observe-----")
            observation_response_text = self.get_llm_response(self.system_prompt_observe, observation_input)
            print("-----Observation Parsed-----")
            parsed_observation = self.parse_json_response(observation_response_text)
            print(parsed_observation)
            print("-----------------")
            self.history.append(parsed_observation)

            # The Observe stage signals completion by emitting "final_answer".
            if "final_answer" in parsed_observation:
                print(parsed_observation["final_answer"])
                return self.history, observation_response_text, parsed_observation["final_answer"]

            current_input = self.build_prompt_from_history(query)

        print('ERROR LOOP LIMIT REACHED')
        return self.history, observation_response_text + "This is our last observation. Make your best estimation given the question.", parsed_observation