Spaces:
Sleeping
Sleeping
| from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage | |
| import logging | |
| import os | |
| import re | |
| from typing import List | |
| from args import Args, AgentPreset | |
| from llm_factory import LLMFactory | |
class IAgent:
    """Base chat agent driven by a system-prompt file and an LLM preset.

    The system prompt is read from ``<cwd>/system_prompts/<sys_prompt_filename>``.
    The model is built with ``LLMFactory.create(agent_preset)``; when tools are
    supplied they are bound to it. ``query`` sends a conversation and returns
    the reply as text. If the model emits tool calls, it instead returns the
    newline-joined results of executing those tools.
    """

    def __init__(self, sys_prompt_filename, agent_preset: AgentPreset, tools: List = None, parallel_tool_calls=False):
        """Load the system prompt and build the (optionally tool-bound) model.

        Args:
            sys_prompt_filename: File name inside ``<cwd>/system_prompts``. Also
                used to derive ``self.name`` (see ``_format_name``).
            agent_preset: Preset passed to ``LLMFactory.create``. A model name of
                ``"groot"`` puts the agent into mock mode.
            tools: Tools to bind to the model and to execute when the model
                requests them. ``None`` (the default) means no tools.
            parallel_tool_calls: Forwarded to ``bind_tools`` when tools are given.

        Raises:
            OSError: If the system prompt file cannot be opened.
        """
        # Use None rather than a shared mutable [] default: the list is stored
        # on the instance, so a shared default would leak between agents.
        if tools is None:
            tools = []
        self.name = self._format_name(sys_prompt_filename)
        self.interface = agent_preset.get_interface()
        self.mock = (agent_preset.get_model_name() == "groot")
        self.tools = tools  # kept so _handle_tool_calls can execute tool calls

        # NOTE: resolved against the current working directory, not this module.
        system_prompt_path = os.path.join(os.getcwd(), "system_prompts", sys_prompt_filename)
        with open(system_prompt_path, "r", encoding="utf-8") as file:
            self.system_prompt = file.read().strip()

        llm = LLMFactory.create(agent_preset)
        if tools:
            self.model = llm.bind_tools(tools, parallel_tool_calls=parallel_tool_calls)
        else:
            self.model = llm

    @staticmethod
    def _format_name(sys_prompt_filename: str) -> str:
        """Derive the agent name from a prompt file name.

        Drops the file extension and any leading non-letter characters,
        e.g. ``"01_manager.txt"`` -> ``"manager"``.
        """
        name_without_ext = os.path.splitext(sys_prompt_filename)[0]
        return re.sub(r'^[^a-zA-Z]+', '', name_without_ext)

    @staticmethod
    def _bake_roles(messages: List[str]) -> List[AnyMessage]:
        """Wrap plain strings in alternating chat roles, anchored at the end.

        The last message becomes a ``HumanMessage``, the one before it an
        ``AIMessage``, and so on, alternating backwards.

        Args:
            messages: Message strings, oldest first.

        Returns:
            The same messages wrapped in ``HumanMessage`` / ``AIMessage``.

        Raises:
            ValueError: If ``messages`` is empty.
        """
        if not messages:
            raise ValueError("The list of messages cannot be empty !")
        total_messages = len(messages)
        messages_with_roles = []
        for idx, msg in enumerate(messages):
            # Distance from the end: even -> Human, odd -> AI.
            reverse_idx = total_messages - idx - 1
            role = HumanMessage if reverse_idx % 2 == 0 else AIMessage
            messages_with_roles.append(role(content=msg))
        return messages_with_roles

    def get_system_prompt(self) -> str:
        """Return the system prompt loaded at construction time."""
        return self.system_prompt

    def _handle_tool_calls(self, tool_calls):
        """Execute the requested tool calls and return their results as text.

        Two call shapes are recognised:
          * ``{"function": {"name": ..., "arguments": <JSON string or dict>}}``
          * ``{"name": ..., "args": {...}}``

        Each call produces one line: ``[name]: result`` on success,
        ``[name ERROR]: message`` if the tool raised, or
        ``[name ERROR]: Tool not found`` if no bound tool has that name.

        Returns:
            The per-call lines joined with newlines.
        """
        import json  # local import: only needed when the model emits tool calls

        tool_results = []
        for call in tool_calls:
            tool_name = None
            tool_args = {}
            if "function" in call:
                function = call["function"]
                tool_name = function.get("name")
                raw_args = function.get("arguments", "{}")
                if isinstance(raw_args, dict):
                    # Some providers already hand back parsed arguments.
                    tool_args = raw_args
                else:
                    try:
                        tool_args = json.loads(raw_args)
                    except (TypeError, ValueError):
                        # Malformed or missing arguments: call with no arguments.
                        tool_args = {}
            elif "name" in call and "args" in call:
                tool_name = call["name"]
                tool_args = call["args"]

            tool = next((t for t in self.tools if getattr(t, "name", None) == tool_name), None)
            if tool is None:
                tool_results.append(f"[{tool_name} ERROR]: Tool not found")
                continue
            try:
                # "__arg1" is passed positionally for single-argument tools.
                if isinstance(tool_args, dict) and len(tool_args) == 1 and "__arg1" in tool_args:
                    result = tool.func(tool_args["__arg1"])
                elif isinstance(tool_args, dict):
                    result = tool.func(**tool_args)
                else:
                    result = tool.func(tool_args)
                tool_results.append(f"[{tool_name}]: {result}")
            except Exception as e:
                # Deliberate boundary: tool failures are reported back as text.
                tool_results.append(f"[{tool_name} ERROR]: {str(e)}")
        return "\n".join(tool_results)

    def query(self, messages: List[str]) -> str:
        """Send a conversation to the agent and return its reply as a string.

        Args:
            messages: Conversation turns, oldest first. The last entry is sent
                as the human turn (see ``_bake_roles``). The caller's list is
                not modified.

        Returns:
            Depends on the model output:
              * ``"I am GROOT !"`` in mock mode.
              * The joined tool-call results if the model requested tools.
              * Otherwise the model's text content, or ``str()`` of the raw
                output when it has no content.

        Raises:
            RuntimeError: If ``Args.LOGGER`` is not set.
            ValueError: If ``messages`` is empty.
        """
        if Args.LOGGER is None:
            raise RuntimeError("LOGGER must be defined before querying the agent.")
        if not messages:
            raise ValueError("The list of messages cannot be empty !")
        # Copy so the /no_think suffix below never leaks into the caller's list.
        messages = list(messages)

        separator = "=============================="
        Args.LOGGER.log(logging.INFO, f"\n{separator}\nAgent '{self.name}' has been queried !\nINPUT:\n{messages}\nLAST INPUT:{messages[-1]}\n")

        if self.mock:
            response = str("I am GROOT !")
            Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
            return response

        system_prompt = self.get_system_prompt()
        # Disable the thinking block for models that honour "/no_think".
        if Args.MiscParams.NO_THINK:
            messages[-1] += "\n/no_think"

        conversation = [SystemMessage(content=system_prompt)] + self._bake_roles(messages)
        raw_output = self.model.invoke(conversation)

        # Tool calls take precedence, then text content, then a plain str() fallback.
        additional_kwargs = getattr(raw_output, "additional_kwargs", None) or {}
        tool_calls = additional_kwargs.get("tool_calls", None)
        if tool_calls:
            Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' called tools !\n")
            response = self._handle_tool_calls(tool_calls)
        elif hasattr(raw_output, "content") and raw_output.content:
            response = str(raw_output.content)
        else:
            response = str(raw_output)

        Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
        return response