Spaces:
Running
Running
| """ | |
| NeuraPrompt Agent v9.0.1 — Unlimited Power Update + Groq Fallback | |
| ----------------------------------------------------------------- | |
| Updates over v9.0: | |
| * Groq fallback for _call_llm — when OpenRouter returns 429 / 449 / 503 / | |
| 502 / 504, or any "LLM unavailable / rate limit / overloaded" body, the | |
| agent automatically retries the SAME messages against Groq using the same | |
| JSON protocol ({thought, action, input}). The rest of stream_agent is | |
| unchanged — it never sees the difference. | |
| * If OPENROUTE_KEY is missing but GROQ_API_KEY is present, the agent runs | |
| directly on Groq (no OpenRouter dependency). | |
| Original v9.0 features: document_tools (PDF, DOCX, Excel, CSV), | |
| media_tools (charts, QR, image ops, zip). | |
| """ | |
| import os | |
| import json | |
| import time | |
| import asyncio | |
| import logging | |
| import re | |
| from typing import Optional, Dict, Any | |
| import requests | |
| from fastapi import APIRouter, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from .prompts.system_prompt import get_system_prompt | |
| from .memory.memory import get_memory_manager | |
| from .tools.registry import register_tool, get_tool, get_tool_descriptions, list_tools | |
| from .tools import web_tools, code_tools, file_tools, vision_tools, document_tools, media_tools | |
| from .tools.github_tools import ( | |
| github_list_repos, github_create_repo, github_get_repo, | |
| github_list_issues, github_create_issue, github_read_file, | |
| github_write_file, github_create_branch, github_create_pull_request, | |
| github_search_code, github_get_user_profile | |
| ) | |
| from .schemas.tool_schemas import ( | |
| # Existing | |
| RunShellInput, RunPythonInput, CreateFileInput, | |
| WebSearchInput, FetchUrlInput, AnalyzeImageInput, | |
| # Document tools | |
| CreatePDFInput, CreateDOCXInput, CreateExcelInput, | |
| CreateCSVInput, CreateTextFileInput, | |
| # Media tools | |
| CreateBarChartInput, CreateLineChartInput, CreatePieChartInput, | |
| CreateQRCodeInput, ResizeImageInput, CreateZipInput, | |
| # GitHub | |
| GitHubListReposInput, GitHubCreateRepoInput, GitHubGetRepoInput, | |
| GitHubListIssuesInput, GitHubCreateIssueInput, GitHubReadFileInput, | |
| GitHubWriteFileInput, GitHubCreateBranchInput, GitHubCreatePullRequestInput, | |
| GitHubSearchCodeInput, GitHubGetUserProfileInput | |
| ) | |
# Logger name kept as-is (still "v9.0"); renaming it could break any logging
# configuration keyed on this name.
log = logging.getLogger("agent.core.v9.0")
# ==================== CONFIG ====================
# OpenRouter key, read from the OPENROUTE_KEY environment variable (spelling as
# used throughout this module). Empty string means "not configured".
OPENROUTER_KEY = os.getenv("OPENROUTE_KEY", "")
MAX_STEPS = 5        # default number of agent iterations per request
MAX_TOKENS = 8000    # max_tokens sent with OpenRouter requests
# Overridable like AGENT_GROQ_MODEL so the OpenRouter model can be swapped
# without a code change; an unset or empty variable keeps the original default.
PRIMARY_MODEL = os.getenv("AGENT_PRIMARY_MODEL") or "poolside/laguna-m.1:free"
TEMPERATURE = 0.7    # sampling temperature shared by the OpenRouter and Groq requests
# ---- Groq fallback config (NEW in v9.0.1) ----
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
GROQ_FALLBACK_MODEL = os.getenv("AGENT_GROQ_MODEL", "groq/compound")
GROQ_FALLBACK_TOKENS = 4096  # max_tokens sent with Groq fallback requests
# HTTP status codes / body substrings that trigger Groq fallback.
# NOTE(review): 449 is not a standard HTTP status — confirm OpenRouter emits it.
GROQ_FALLBACK_STATUS_CODES = {429, 449, 503, 502, 504}
# Matched case-insensitively against error bodies by _is_groq_fallback_trigger.
GROQ_FALLBACK_SUBSTRINGS = (
    "rate limit", "rate_limit", "rate limited",
    "llm unavailable", "model unavailable",
    "service unavailable", "temporarily unavailable",
    "overloaded", "capacity",
)
# ==================== TOOL SCHEMAS ====================
# Maps tool name -> pydantic input model. _dispatch_tool validates the model's
# arguments against the entry here before calling the tool; registered tools
# with no entry (read_file, write_file, list_dir) receive their raw input.
TOOL_SCHEMAS: Dict[str, type[BaseModel]] = dict(
    # Existing
    run_shell=RunShellInput,
    run_python=RunPythonInput,
    create_file=CreateFileInput,
    web_search=WebSearchInput,
    fetch_url=FetchUrlInput,
    analyze_image=AnalyzeImageInput,
    # Document tools
    create_pdf=CreatePDFInput,
    create_docx=CreateDOCXInput,
    create_excel=CreateExcelInput,
    create_csv=CreateCSVInput,
    create_text_file=CreateTextFileInput,
    # Media tools
    create_bar_chart=CreateBarChartInput,
    create_line_chart=CreateLineChartInput,
    create_pie_chart=CreatePieChartInput,
    create_qr_code=CreateQRCodeInput,
    resize_image=ResizeImageInput,
    create_zip=CreateZipInput,
    # GitHub
    github_list_repos=GitHubListReposInput,
    github_create_repo=GitHubCreateRepoInput,
    github_get_repo=GitHubGetRepoInput,
    github_list_issues=GitHubListIssuesInput,
    github_create_issue=GitHubCreateIssueInput,
    github_read_file=GitHubReadFileInput,
    github_write_file=GitHubWriteFileInput,
    github_create_branch=GitHubCreateBranchInput,
    github_create_pull_request=GitHubCreatePullRequestInput,
    github_search_code=GitHubSearchCodeInput,
    github_get_user_profile=GitHubGetUserProfileInput,
)
# ==================== REGISTER TOOLS ====================
# (name, callable, description) triples, registered below in exactly this order.
_TOOL_REGISTRATIONS = (
    # Existing
    ("web_search", web_tools.web_search, "Search the web for current information"),
    ("fetch_url", web_tools.fetch_url, "Fetch and extract content from a URL"),
    ("run_python", code_tools.run_python, "Execute Python code — output files are saved and download URLs returned"),
    ("run_shell", code_tools.run_shell, "Execute shell commands"),
    ("create_file", file_tools.create_file, "Create text/code files (supports batch)"),
    ("read_file", file_tools.read_file, "Read file content"),
    ("write_file", file_tools.write_file, "Write to file"),
    ("list_dir", file_tools.list_directory, "List directory contents"),
    ("analyze_image", vision_tools.analyze_image, "Analyze images with vision AI"),
    # Document tools
    ("create_pdf", document_tools.create_pdf, "Create a PDF document with title and content — returns download URL"),
    ("create_docx", document_tools.create_docx, "Create a Word DOCX document — returns download URL"),
    ("create_excel", document_tools.create_excel, "Create an Excel spreadsheet with headers and rows — returns download URL"),
    ("create_csv", document_tools.create_csv, "Create a CSV file — returns download URL"),
    ("create_text_file", document_tools.create_text_file, "Create a plain text or markdown file — returns download URL"),
    # Media tools
    ("create_bar_chart", media_tools.create_bar_chart, "Create a bar chart PNG — returns download URL"),
    ("create_line_chart", media_tools.create_line_chart, "Create a line chart PNG — returns download URL"),
    ("create_pie_chart", media_tools.create_pie_chart, "Create a pie chart PNG — returns download URL"),
    ("create_qr_code", media_tools.create_qr_code, "Generate a QR code PNG for any text or URL — returns download URL"),
    ("resize_image", media_tools.resize_image, "Resize an image from base64 — returns download URL"),
    ("create_zip", media_tools.create_zip, "Create a ZIP archive from files in agent_outputs — returns download URL"),
    # GitHub
    ("github_list_repos", github_list_repos, "List user's GitHub repositories"),
    ("github_create_repo", github_create_repo, "Create a new GitHub repository"),
    ("github_get_repo", github_get_repo, "Get repository details"),
    ("github_list_issues", github_list_issues, "List repository issues"),
    ("github_create_issue", github_create_issue, "Create a new issue"),
    ("github_read_file", github_read_file, "Read file from repository"),
    ("github_write_file", github_write_file, "Create or update file in repository"),
    ("github_create_branch", github_create_branch, "Create a new branch"),
    ("github_create_pull_request", github_create_pull_request, "Create a pull request"),
    ("github_search_code", github_search_code, "Search code across GitHub"),
    ("github_get_user_profile", github_get_user_profile, "Get user's GitHub profile"),
)
for _tool_name, _tool_fn, _tool_desc in _TOOL_REGISTRATIONS:
    register_tool(_tool_name, _tool_fn, _tool_desc)
# Don't leave loop variables behind as module attributes.
del _tool_name, _tool_fn, _tool_desc
# ==================== MODELS ====================
class AgentRequest(BaseModel):
    """Request body accepted by ``agent_stream``; its fields are passed to ``stream_agent``."""
    # Used to load per-user memory and injected as ``user_id`` into github_* tool inputs.
    user_id: str
    # The task the agent is asked to accomplish (sent as "Goal: ..." to the model).
    goal: str
    # Maximum LLM/tool iterations. NOTE(review): not capped server-side in this file.
    max_steps: int = MAX_STEPS
    # Optional extra context appended to the first user message as "Context: ...".
    context: Optional[str] = None
    # Optional text appended after the stored memory context in the system prompt.
    memory_context: Optional[str] = None
# ==================== HELPERS ====================
def _build_system_prompt(memory_context: str = "") -> str:
    """Render the agent system prompt.

    Combines ``memory_context`` with the descriptions of every registered
    tool, as produced by ``get_tool_descriptions``.
    """
    descriptions = get_tool_descriptions()
    return get_system_prompt(memory_context=memory_context, tool_descriptions=descriptions)
| def _extract_json(content: str) -> Optional[Dict[str, Any]]: | |
| content = content.strip() | |
| if content.startswith("{"): | |
| try: | |
| return json.loads(content) | |
| except json.JSONDecodeError: | |
| pass | |
| match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', content, re.DOTALL) | |
| if match: | |
| try: | |
| return json.loads(match.group(1)) | |
| except json.JSONDecodeError: | |
| pass | |
| match = re.search(r'(\{.*\})', content, re.DOTALL) | |
| if match: | |
| try: | |
| return json.loads(match.group(1)) | |
| except json.JSONDecodeError: | |
| pass | |
| return None | |
def _is_groq_fallback_trigger(status_code: Optional[int], body: str = "") -> bool:
    """Return True if an OpenRouter failure should be retried on Groq.

    A failure qualifies when its status is in GROQ_FALLBACK_STATUS_CODES.
    Otherwise it qualifies when ``body``, compared case-insensitively,
    contains any phrase from GROQ_FALLBACK_SUBSTRINGS. Pass ``None`` as
    ``status_code`` to inspect only the body.
    """
    # ``None in {ints}`` is simply False, so no separate None check is needed.
    if status_code in GROQ_FALLBACK_STATUS_CODES:
        return True
    lowered = body.lower() if body else ""
    for phrase in GROQ_FALLBACK_SUBSTRINGS:
        if phrase in lowered:
            return True
    return False
def _parse_llm_message(message: dict) -> Dict[str, Any]:
    """Convert a chat-completion ``message`` dict into one agent step.

    Shared by the OpenRouter and Groq paths so both behave identically.

    The text comes from ``content``, falling back to ``reasoning`` /
    ``reasoning_content`` when content is empty. If that text contains a JSON
    object, the object is returned as-is. Otherwise the text is wrapped as a
    ``finish`` action whose ``final_answer`` is the text, with one enclosing
    code fence removed when the whole reply is a single fenced block.

    Args:
        message: the ``choices[0].message`` dict of a chat-completion response.

    Returns:
        A dict following the agent protocol (``thought`` / ``action`` / ``input``).
    """
    content = (message.get("content") or "").strip()
    if not content:
        # Model may have used tool_calls or reasoning blocks
        reasoning = message.get("reasoning") or message.get("reasoning_content") or ""
        if not reasoning:
            return {
                "thought": "Processing...",
                "action": "finish",
                "input": {"final_answer": "I was unable to generate a response. Please try again."}
            }
        content = str(reasoning).strip()
    parsed = _extract_json(content)
    if parsed:
        return parsed
    # Unwrap only a reply that is entirely one fenced block. The previous
    # strip("```") left the language tag (e.g. "markdown\n") in the answer, and
    # on normal prose ending in a code block it removed the closing fence.
    fenced = re.fullmatch(r"```(?:[\w+#.-]*[ \t]*\n)?(.*?)```", content, re.DOTALL)
    clean = fenced.group(1).strip() if fenced else content
    return {
        "thought": "Your message has been responded directly.",
        "action": "finish",
        "input": {"final_answer": clean}
    }
| # ==================== GROQ FALLBACK (NEW in v9.0.1) ==================== | |
| async def _call_groq_fallback(messages: list) -> Dict[str, Any]: | |
| """Groq fallback for _call_llm. Same JSON protocol as OpenRouter path. | |
| Retries 429 with exponential backoff. Raises HTTPException(503) only if | |
| Groq is unavailable or unconfigured — never returns None.""" | |
| if not GROQ_API_KEY: | |
| raise HTTPException(503, "LLM service unavailable — Groq fallback not configured.") | |
| headers = {"Authorization": f"Bearer {GROQ_API_KEY}", "Content-Type": "application/json"} | |
| payload = { | |
| "model": GROQ_FALLBACK_MODEL, | |
| "messages": messages, | |
| "max_tokens": GROQ_FALLBACK_TOKENS, | |
| "temperature": TEMPERATURE, | |
| } | |
| last_err: Optional[Exception] = None | |
| for attempt in range(3): | |
| try: | |
| r = requests.post( | |
| GROQ_API_URL, | |
| headers=headers, json=payload, timeout=45, | |
| ) | |
| # Groq 429 → back off and retry | |
| if r.status_code == 429: | |
| wait = min(int(r.headers.get("retry-after", 5 * (attempt + 1))), 15) | |
| log.warning(f"[Agent/Groq] 429 — waiting {wait}s (attempt {attempt + 1}/3)") | |
| await asyncio.sleep(wait) | |
| last_err = RuntimeError(f"Groq 429 (attempt {attempt + 1})") | |
| continue | |
| # Non-fallback HTTP error → log and stop retrying | |
| if r.status_code >= 400: | |
| log.error(f"[Agent/Groq] HTTP {r.status_code}: {r.text[:200]}") | |
| last_err = RuntimeError(f"Groq HTTP {r.status_code}") | |
| break | |
| r.raise_for_status() | |
| message = r.json()["choices"][0]["message"] | |
| return _parse_llm_message(message) | |
| except requests.HTTPError as e: | |
| last_err = e | |
| if e.response is not None and e.response.status_code == 429: | |
| await asyncio.sleep(5 * (attempt + 1)) | |
| continue | |
| log.error(f"[Agent/Groq] HTTPError: {e}") | |
| break | |
| except Exception as e: | |
| last_err = e | |
| log.error(f"[Agent/Groq] Unexpected error: {e}") | |
| break | |
| raise HTTPException(503, f"LLM service unavailable — Groq fallback failed: {last_err}") | |
# ==================== LLM CALL (with Groq fallback — NEW in v9.0.1) ====================
async def _call_llm(messages: list) -> Dict[str, Any]:
    """Call OpenRouter, falling back to Groq (same messages, same JSON protocol).

    Groq is used when any of these hold:
      * OPENROUTE_KEY is unset but GROQ_API_KEY is set;
      * OpenRouter returns a status in GROQ_FALLBACK_STATUS_CODES;
      * an OpenRouter error body mentions a GROQ_FALLBACK_SUBSTRINGS phrase;
      * a short, non-JSON 200 reply mentions such a phrase (i.e. it looks like
        an "unavailable" notice rather than a real answer);
      * any other OpenRouter failure occurs and GROQ_API_KEY is set.

    Args:
        messages: chat-completion message dicts (role/content).

    Returns:
        The parsed agent step from ``_parse_llm_message``. Never ``None``.

    Raises:
        HTTPException(503): when no backend produces a response.
    """
    # If OpenRouter not configured, go straight to Groq.
    if not OPENROUTER_KEY:
        if GROQ_API_KEY:
            log.warning("[Agent] OPENROUTE_KEY missing — using Groq directly.")
            return await _call_groq_fallback(messages)
        raise HTTPException(503, "OPENROUTE_KEY not configured and GROQ_API_KEY not configured.")
    headers = {
        "Authorization": f"Bearer {OPENROUTER_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": PRIMARY_MODEL,
        "messages": messages,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
    }
    # Real "model unavailable" notices are short plain text. Protocol JSON, or
    # longer prose that merely mentions e.g. "rate limit" (a user asking about
    # API rate limits), is a genuine answer and must not trigger a fallback.
    max_unavailable_notice_chars = 300
    try:
        # requests is blocking; run it in a worker thread so the event loop
        # (and every other in-flight stream) is not frozen for up to 45 s.
        r = await asyncio.to_thread(
            requests.post,
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers, json=payload, timeout=45,
        )
        if r.status_code in GROQ_FALLBACK_STATUS_CODES:
            log.warning(f"[Agent] OpenRouter HTTP {r.status_code} — falling back to Groq.")
        else:
            r.raise_for_status()
            message = r.json()["choices"][0]["message"]
            # Detect "LLM unavailable" embedded in a 200 response body.
            body_text = (message.get("content") or "").strip()
            looks_unavailable = (
                len(body_text) <= max_unavailable_notice_chars
                and _extract_json(body_text) is None
                and _is_groq_fallback_trigger(None, body_text)
            )
            if not looks_unavailable:
                return _parse_llm_message(message)
            log.warning("[Agent] OpenRouter body indicates LLM unavailable — falling back to Groq.")
    except requests.HTTPError as e:
        resp = e.response
        if resp is not None and _is_groq_fallback_trigger(resp.status_code, resp.text or ""):
            log.warning(
                f"[Agent] OpenRouter HTTPError {resp.status_code} indicates LLM unavailable — falling back to Groq."
            )
        else:
            log.error(f"[Agent] LLM HTTPError: {e}")
            if not GROQ_API_KEY:
                raise HTTPException(503, "LLM service unavailable") from e
            log.warning("[Agent] Falling back to Groq after non-fallback HTTPError.")
    except Exception as e:
        log.error(f"[Agent] LLM unexpected error: {e}")
        if not GROQ_API_KEY:
            raise HTTPException(503, "LLM service unavailable") from e
        log.warning("[Agent] Falling back to Groq after unexpected error.")
    # Deliberately outside the try. Previously a Groq HTTPException(503) raised
    # inside it was caught by ``except Exception``, which ran the whole Groq
    # retry loop a second time.
    return await _call_groq_fallback(messages)
async def _dispatch_tool(name: str, raw_input: dict) -> str:
    """Validate the arguments for tool ``name`` and run the tool.

    Problems are reported as text rather than raised, so the agent loop can
    feed them back to the model. This covers unknown tools, schema validation
    failures, and exceptions thrown by the tool itself.

    When TOOL_SCHEMAS has an entry for ``name``, the tool receives that model's
    ``model_dump()`` instead of ``raw_input``.
    NOTE(review): keys the schema does not declare (e.g. an injected
    ``user_id``) may be dropped by that round-trip — confirm the schemas.

    Coroutine-function tools are awaited directly. Plain functions run in a
    worker thread via ``asyncio.to_thread``.
    """
    entry = get_tool(name)
    if not entry:
        return f"Unknown tool: {name}"
    validator = TOOL_SCHEMAS.get(name)
    kwargs = raw_input
    if validator:
        try:
            kwargs = validator(**raw_input).model_dump()
        except Exception as e:
            return f"Invalid parameters for {name}: {str(e)}"
    try:
        func = entry["fn"]
        if not asyncio.iscoroutinefunction(func):
            return await asyncio.to_thread(func, **kwargs)
        return await func(**kwargs)
    except Exception as e:
        log.error(f"Tool error ({name}): {e}")
        return f"Tool error ({name}): {str(e)}"
| def _make_sse(event_type: str, step: int, content: str, tool: str = "", elapsed: float = 0.0, reasoning: str = "") -> str: | |
| data = { | |
| "type": event_type, | |
| "step": step, | |
| "content": content, | |
| "tool": tool, | |
| "elapsed": round(elapsed, 2), | |
| "thought": reasoning | |
| } | |
| return f"data: {json.dumps(data)}\n\n" | |
async def _execute_tool_and_update_history(
    action: str, tool_input: dict, messages: list, thought: str
) -> tuple[str, float]:
    """Run tool ``action`` and append the exchange to the conversation.

    Two messages are appended to ``messages`` in place:
      * an assistant turn replaying the step as JSON (thought/action/input);
      * a user turn carrying the tool result, prompting the next JSON action.

    Returns:
        ``(result, seconds_spent_in_the_tool)``.
    """
    started = time.time()
    result = await _dispatch_tool(action, tool_input)
    elapsed = time.time() - started
    step_record = json.dumps({"thought": thought, "action": action, "input": tool_input})
    result_prompt = f"[Tool Result from {action}]\n{result}\n\nContinue with next JSON action or finish."
    messages.extend([
        {"role": "assistant", "content": step_record},
        {"role": "user", "content": result_prompt},
    ])
    return result, elapsed
# ==================== MAIN STREAMING AGENT ====================
def _sanitize_final_answer(final: Any, thought: str) -> str:
    """Turn a model-supplied final answer into user-facing text.

    Non-string answers are JSON-serialized first. Answers that look like
    leaked protocol JSON are replaced. That means empty answers, answers
    starting with ``{`` or ``[{``, or text containing both ``"action"`` and
    ``"thought"``. The replacement is the download URLs found in ``thought``,
    or else the text after ``"final_answer"`` in ``thought``, or else
    ``thought`` itself.
    """
    if not isinstance(final, str):
        final = json.dumps(final, ensure_ascii=False)
    stripped = final.strip()
    leaked = (
        not final
        or stripped.startswith("{")
        or stripped.startswith("[{")
        or ('"action"' in final and '"thought"' in final)
    )
    if not leaked:
        return final
    # Extract any download URLs from thought as a clean message
    urls = re.findall(r'https?://\S+', thought)
    if urls:
        return "Done! Here are your download links:\n" + "\n".join(urls)
    if '"final_answer"' in thought:
        return thought.split('"final_answer"')[-1].strip(' :"{}')
    return thought


async def stream_agent(
    goal: str,
    user_id: str,
    context: str = "",
    memory_context: str = "",
    max_steps: int = MAX_STEPS
):
    """Run the agent loop for ``goal`` and yield Server-Sent Event strings.

    Each step asks the LLM for a ``{thought, action, input}`` step and emits a
    ``reasoning`` event. It then either finishes with a ``done`` event, or runs
    the named tool: ``tool_start``, then ``result`` (output truncated to 1800
    chars for display, but fed in full back to the model). If ``max_steps``
    pass without ``finish``, a final ``done`` event is sent. If no LLM backend
    can answer, a ``done`` event carrying the error is sent and the stream ends
    cleanly.

    Args:
        goal: the user's task.
        user_id: selects the memory manager; injected into github_* tool inputs.
        context: optional extra context added to the first user message.
        memory_context: optional text appended to the stored memory context.
        max_steps: maximum number of LLM iterations.

    Yields:
        SSE-formatted strings produced by ``_make_sse``.
    """
    memory = get_memory_manager(user_id)
    full_memory = memory.get_context_for_agent() + (f"\n\n{memory_context}" if memory_context else "")
    system_prompt = _build_system_prompt(full_memory)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"user_id: {user_id}\nGoal: {goal}" + (f"\nContext: {context}" if context else "")}
    ]
    total_start = time.time()
    for step in range(1, max_steps + 1):
        step_start = time.time()
        try:
            # _call_llm handles the Groq fallback internally.
            parsed = await _call_llm(messages)
        except HTTPException as exc:
            # The SSE response has already started, so an HTTPException can no
            # longer become an HTTP error status. Report it in-band and stop,
            # instead of aborting the stream with no message.
            log.error(f"[Agent] Step {step}: no LLM backend available: {exc.detail}")
            yield _make_sse(
                "done", step,
                f"Sorry, I couldn't reach a language model right now ({exc.detail}). Please try again later.",
                "finish", time.time() - total_start
            )
            return
        if not isinstance(parsed, dict):
            parsed = {}
        # Models sometimes send null or non-string fields; normalize them so the
        # string operations below cannot raise.
        thought = parsed.get("thought", "Thinking...")
        thought = "" if thought is None else str(thought)
        action = parsed.get("action", "finish")
        action = "finish" if action is None else str(action)
        tool_input = parsed.get("input", {})
        if not isinstance(tool_input, dict):
            tool_input = {}
        if action.startswith("github_"):
            tool_input.setdefault("user_id", user_id)
        yield _make_sse("reasoning", step, thought, action, time.time() - step_start, thought)
        if action == "finish":
            final = _sanitize_final_answer(tool_input.get("final_answer", "") or thought, thought)
            yield _make_sse("done", step, final, "finish", time.time() - total_start, thought)
            return
        yield _make_sse("tool_start", step, f"Running {action}...", action, time.time() - step_start)
        result, tool_elapsed = await _execute_tool_and_update_history(action, tool_input, messages, thought)
        # Tools are expected to return str; guard against other return types.
        result_text = result if isinstance(result, str) else str(result)
        display = result_text[:1800] + "..." if len(result_text) > 1800 else result_text
        yield _make_sse("result", step, display, action, tool_elapsed)
        await asyncio.sleep(0.12)
    yield _make_sse(
        "done", max_steps + 1,
        "I reached the maximum number of steps. Please try breaking the goal into smaller parts.",
        "finish", time.time() - total_start
    )
# ==================== FASTAPI ROUTER ====================
agent_router = APIRouter(prefix="/agent")


async def agent_stream(req: AgentRequest):
    """Start an agent run and stream its progress as Server-Sent Events.

    Responds 503 immediately when neither OPENROUTE_KEY nor GROQ_API_KEY is
    configured. Otherwise the response body is the ``stream_agent`` event stream.

    NOTE(review): no route decorator is visible on this handler in this view —
    confirm it is registered on ``agent_router``.
    """
    has_llm_backend = bool(OPENROUTER_KEY) or bool(GROQ_API_KEY)
    if not has_llm_backend:
        raise HTTPException(503, "No LLM backend configured (need OPENROUTE_KEY or GROQ_API_KEY).")
    events = stream_agent(
        req.goal,
        req.user_id,
        req.context or "",
        req.memory_context or "",
        req.max_steps,
    )
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        # Asks buffering reverse proxies (nginx) to pass events through immediately.
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(events, media_type="text/event-stream", headers=sse_headers)
async def health_check():
    """Return a liveness payload.

    The payload holds the status, the module version, the registered tool
    names from ``list_tools()``, and whether the Groq fallback is configured
    (GROQ_API_KEY set).

    NOTE(review): no route decorator is visible here — confirm registration.
    """
    return dict(
        status="ok",
        version="9.0.1",
        tools=list_tools(),
        groq_fallback_enabled=bool(GROQ_API_KEY),
    )
async def download_file(filename: str):
    """Serve a file generated by the agent from /tmp/agent_outputs/.

    Only a plain file name directly inside that directory is served. The path
    is resolved (following symlinks) and must still sit directly inside the
    directory. The agent can run shell commands, so a symlink placed there
    must not expose files elsewhere on disk.

    NOTE(review): no route decorator is visible here — confirm registration.

    Raises:
        HTTPException(400): empty name, path separators, "..", or NUL bytes.
        HTTPException(404): no regular file by that name inside the directory.
    """
    from fastapi.responses import FileResponse
    from pathlib import Path
    # Security: no path traversal. Empty names would resolve to the directory
    # itself, and NUL bytes make filesystem calls raise ValueError (a 500).
    if not filename or ".." in filename or "/" in filename or "\x00" in filename:
        raise HTTPException(400, "Invalid filename")
    base_dir = Path("/tmp/agent_outputs").resolve()
    file_path = (base_dir / filename).resolve()
    # Reject symlink escapes, and non-files such as "." (a directory would make
    # FileResponse fail at send time instead of producing a clean 404).
    if file_path.parent != base_dir or not file_path.is_file():
        raise HTTPException(404, "File not found or expired")
    return FileResponse(
        path=str(file_path),
        filename=filename,
        media_type="application/octet-stream"
    )