Spaces:
Sleeping
Sleeping
INFINA-RD
refactor: DI infrastructure, service decomposition, repository helpers, test suite
e066621 | from __future__ import annotations | |
| import os | |
| from typing import List, Dict | |
| from dotenv import load_dotenv | |
| try: # langchain>=1.0 moved globals helpers under langchain_core | |
| from langchain.globals import set_verbose, set_debug | |
| except ImportError: | |
| from langchain_core.globals import set_verbose, set_debug | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_groq import ChatGroq | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.constants import END | |
| from langgraph.graph import StateGraph | |
| from langgraph.prebuilt import create_react_agent | |
| from agent.prompts import cli_system_prompt | |
| from agent.states import AgentConfig, AgentGraphState, ModelBackend | |
| from agent.tools import ( | |
| write_file, | |
| read_file, | |
| get_current_directory, | |
| list_files, | |
| print_tree, | |
| search_files, | |
| summarize_project, | |
| edit_file, | |
| delete_file, | |
| run_cmd, | |
| init_project_root, | |
| ) | |
# Pull environment variables from a local .env before any flag is read.
load_dotenv()

# Values accepted as "enabled" for the AGENT_DEBUG environment toggle.
_TRUTHY = {"1", "true", "yes", "on"}

# AGENT_DEBUG flips LangChain's global debug and verbose tracing together.
DEBUG_FLAG = os.getenv("AGENT_DEBUG", "false").lower() in _TRUTHY
set_debug(DEBUG_FLAG)
set_verbose(DEBUG_FLAG)
def _build_llm(config: AgentConfig):
    """Instantiate the requested backend with sensible defaults.

    The model name comes from ``config.model`` when set, otherwise from a
    backend-specific environment variable with a hard-coded fallback.

    Raises:
        ValueError: when the chosen backend requires an API key that is
            missing from the environment.
    """
    backend = config.backend

    if backend == ModelBackend.GEMINI:
        # The Google client resolves its own credentials from the environment.
        return ChatGoogleGenerativeAI(
            model=config.model or os.getenv("GEMINI_MODEL", "gemini-2.0-flash"),
            temperature=config.temperature,
        )

    if backend == ModelBackend.OPENROUTER:
        key = os.getenv("OPENROUTER_API_KEY")
        if not key:
            raise ValueError("OPENROUTER_API_KEY is not set in the environment.")
        # OpenRouter speaks the OpenAI wire protocol, so ChatOpenAI is reused
        # with a custom base URL.
        return ChatOpenAI(
            model=config.model or os.getenv("OPENROUTER_MODEL", "openrouter/meta-llama/llama-3.1-8b-instruct"),
            base_url=os.getenv("OPENROUTER_API_URL", "https://openrouter.ai/api/v1"),
            api_key=key,
            temperature=config.temperature,
        )

    # Every other backend value falls through to Groq.
    key = os.getenv("GROQ_API_KEY")
    if not key:
        raise ValueError("GROQ_API_KEY is not set in the environment.")
    return ChatGroq(
        model=config.model or os.getenv("GROQ_MODEL", "openai/gpt-oss-120b"),
        temperature=config.temperature,
    )
class AgentRunner:
    """Gemini/Qwen-style CLI agent that keeps conversation state across turns.

    Wraps a LangGraph ReAct agent in a two-node workflow
    (``bootstrap`` -> ``react`` -> END) and holds the rolling chat history,
    project root, and cached project summary between ``invoke`` calls.
    """

    def __init__(self, config: AgentConfig) -> None:
        """Build the LLM backend, tool list, ReAct agent, and compiled graph.

        Args:
            config: Backend selection, model/temperature overrides, project
                directory, and runtime limits (see ``AgentConfig``).
        """
        self.config = config
        # Normalized project root; init_project_root presumably also prepares
        # the directory (and may set the process CWD) — TODO confirm.
        self.project_directory = init_project_root(config.project_directory)
        # Cached output of summarize_project; None until the first bootstrap
        # (or until a caller forces a refresh).
        self.project_summary: str | None = None
        # Rolling conversation as {"role", "content"} dicts. After the first
        # bootstrap, messages[0] is always the system prompt.
        self.messages: List[Dict[str, str]] = []
        self.llm = _build_llm(config)
        # Full tool belt handed to the ReAct agent (file ops, search, shell).
        self.tools = [
            read_file,
            write_file,
            edit_file,
            delete_file,
            list_files,
            print_tree,
            search_files,
            summarize_project,
            get_current_directory,
            run_cmd,
        ]
        self.react_agent = create_react_agent(self.llm, self.tools)
        self.graph = self._build_graph()

    def _build_graph(self):
        """Compile the two-node workflow: bootstrap context, then run ReAct."""
        workflow = StateGraph(AgentGraphState)
        workflow.add_node("bootstrap", self._bootstrap_node)
        workflow.add_node("react", self._react_node)
        workflow.set_entry_point("bootstrap")
        workflow.add_edge("bootstrap", "react")
        workflow.add_edge("react", END)
        return workflow.compile()

    def _bootstrap_node(self, state: AgentGraphState) -> AgentGraphState:
        """
        Ensures project root + summary exist, injects system prompt, and appends the
        current user prompt to the rolling conversation.

        Returns a state update with the pending message consumed and the
        refresh flag cleared, so a later turn does not re-trigger either.
        """
        project_root = init_project_root(state["project_directory"])
        refresh_requested = state.get("refresh_context", False)
        summary = state.get("project_summary")
        # Summarize only when auto-context is on AND we have no summary yet
        # (or the caller explicitly asked for a refresh).
        if self.config.auto_context and (refresh_requested or not summary):
            # "." is resolved by the tool — assumes CWD is the project root
            # after init_project_root; TODO confirm.
            summary = summarize_project.run(".")
        # Copy so we never mutate the state list held by the caller/runner.
        messages = list(state.get("messages") or [])
        system_prompt = cli_system_prompt(summary)
        if not messages or messages[0].get("role") != "system":
            messages.insert(0, {"role": "system", "content": system_prompt})
        else:
            # Replace the stale system prompt so a refreshed summary takes effect.
            messages[0] = {"role": "system", "content": system_prompt}
        pending = state.get("pending_user_message")
        if pending:
            messages.append({"role": "user", "content": pending})
        return {
            "project_directory": project_root,
            "project_summary": summary,
            "messages": messages,
            "pending_user_message": None,
            "refresh_context": False,
        }

    def _react_node(self, state: AgentGraphState) -> AgentGraphState:
        """Delegates to the LangGraph ReAct agent to decide tool calls."""
        result = self.react_agent.invoke({"messages": state["messages"]})
        # Pass summary/directory through unchanged; only messages are updated.
        return {
            "messages": result["messages"],
            "project_summary": state.get("project_summary"),
            "project_directory": state.get("project_directory"),
        }

    def invoke(self, user_prompt: str, *, refresh_context: bool = False, clear_history: bool = False) -> AgentGraphState:
        """
        Runs a single user turn through the agent.
        Returns the updated LangGraph state (including messages for streaming / history).

        Args:
            user_prompt: The user's message for this turn.
            refresh_context: Force a fresh project summary before answering.
            clear_history: Drop the in-memory conversation first.
        """
        if clear_history:
            self.reset_history()
        initial_state: AgentGraphState = {
            "project_directory": self.project_directory,
            # Passing None forces the bootstrap node to re-summarize.
            "project_summary": None if refresh_context else self.project_summary,
            "messages": self.messages,
            "pending_user_message": user_prompt,
            "refresh_context": refresh_context,
        }
        final_state = self.graph.invoke(initial_state, config={"recursion_limit": self.config.recursion_limit})
        # Persist the turn's results so the next invoke continues the thread.
        self.project_summary = final_state.get("project_summary", self.project_summary)
        self.messages = final_state.get("messages", [])
        return final_state

    def reset_history(self) -> None:
        """Drops the in-memory conversation."""
        self.messages = []

    def refresh_summary(self) -> str:
        """Forces a fresh project summary and returns it."""
        self.project_summary = summarize_project.run(".")
        return self.project_summary

    def conversation_history(self) -> List[Dict[str, str]]:
        """Returns a shallow copy of the chat history (the list itself is safe
        to mutate; the message dicts are shared with the runner)."""
        return list(self.messages)
def agent_factory(config: AgentConfig) -> AgentRunner:
    """Construct an AgentRunner; kept as a convenience shim for legacy imports."""
    runner = AgentRunner(config)
    return runner