self-trained2 / agent /agent.py
DeepImagix's picture
Upload 2 files
6c46c0d verified
Raw
History Blame Contribute Delete
22.4 kB
"""
NeuraPrompt Agent v9.0.1 — Unlimited Power Update + Groq Fallback
-----------------------------------------------------------------
Updates over v9.0:
* Groq fallback for _call_llm — when OpenRouter returns 429 / 449 / 503 /
502 / 504, or any "LLM unavailable / rate limit / overloaded" body, the
agent automatically retries the SAME messages against Groq using the same
JSON protocol ({thought, action, input}). The rest of stream_agent is
unchanged — it never sees the difference.
* If OPENROUTE_KEY is missing but GROQ_API_KEY is present, the agent runs
directly on Groq (no OpenRouter dependency).
Original v9.0 features: document_tools (PDF, DOCX, Excel, CSV),
media_tools (charts, QR, image ops, zip).
"""
import os
import json
import time
import asyncio
import logging
import re
from typing import Optional, Dict, Any
import requests
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from .prompts.system_prompt import get_system_prompt
from .memory.memory import get_memory_manager
from .tools.registry import register_tool, get_tool, get_tool_descriptions, list_tools
from .tools import web_tools, code_tools, file_tools, vision_tools, document_tools, media_tools
from .tools.github_tools import (
github_list_repos, github_create_repo, github_get_repo,
github_list_issues, github_create_issue, github_read_file,
github_write_file, github_create_branch, github_create_pull_request,
github_search_code, github_get_user_profile
)
from .schemas.tool_schemas import (
# Existing
RunShellInput, RunPythonInput, CreateFileInput,
WebSearchInput, FetchUrlInput, AnalyzeImageInput,
# Document tools
CreatePDFInput, CreateDOCXInput, CreateExcelInput,
CreateCSVInput, CreateTextFileInput,
# Media tools
CreateBarChartInput, CreateLineChartInput, CreatePieChartInput,
CreateQRCodeInput, ResizeImageInput, CreateZipInput,
# GitHub
GitHubListReposInput, GitHubCreateRepoInput, GitHubGetRepoInput,
GitHubListIssuesInput, GitHubCreateIssueInput, GitHubReadFileInput,
GitHubWriteFileInput, GitHubCreateBranchInput, GitHubCreatePullRequestInput,
GitHubSearchCodeInput, GitHubGetUserProfileInput
)
log = logging.getLogger("agent.core.v9.0")

# ==================== CONFIG ====================
# Primary backend (OpenRouter). NOTE: the environment variable really is
# spelled "OPENROUTE_KEY".
OPENROUTER_KEY = os.getenv("OPENROUTE_KEY", "")
MAX_STEPS = 5          # default number of agent steps (LLM calls) per request
MAX_TOKENS = 8000      # completion budget for the primary model
PRIMARY_MODEL = "poolside/laguna-m.1:free"
TEMPERATURE = 0.7      # used by both the OpenRouter and the Groq request

# ---- Groq fallback config (NEW in v9.0.1) ----
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
GROQ_FALLBACK_MODEL = os.getenv("AGENT_GROQ_MODEL", "groq/compound")
GROQ_FALLBACK_TOKENS = 4096

# An OpenRouter failure is retried on Groq when the HTTP status is one of
# these ...
GROQ_FALLBACK_STATUS_CODES = {429, 449, 502, 503, 504}
# ... or when the (lower-cased) response body contains one of these phrases.
GROQ_FALLBACK_SUBSTRINGS = (
    "rate limit",
    "rate_limit",
    "rate limited",
    "llm unavailable",
    "model unavailable",
    "service unavailable",
    "temporarily unavailable",
    "overloaded",
    "capacity",
)
# ==================== TOOL SCHEMAS ====================
# Tool name -> pydantic model used by _dispatch_tool to validate the LLM's
# "input" before the tool is called. Tools missing from this map (read_file,
# write_file, list_dir) receive their input unvalidated.
TOOL_SCHEMAS: Dict[str, type[BaseModel]] = dict(
    # Existing
    run_shell=RunShellInput,
    run_python=RunPythonInput,
    create_file=CreateFileInput,
    web_search=WebSearchInput,
    fetch_url=FetchUrlInput,
    analyze_image=AnalyzeImageInput,
    # Document tools
    create_pdf=CreatePDFInput,
    create_docx=CreateDOCXInput,
    create_excel=CreateExcelInput,
    create_csv=CreateCSVInput,
    create_text_file=CreateTextFileInput,
    # Media tools
    create_bar_chart=CreateBarChartInput,
    create_line_chart=CreateLineChartInput,
    create_pie_chart=CreatePieChartInput,
    create_qr_code=CreateQRCodeInput,
    resize_image=ResizeImageInput,
    create_zip=CreateZipInput,
    # GitHub
    github_list_repos=GitHubListReposInput,
    github_create_repo=GitHubCreateRepoInput,
    github_get_repo=GitHubGetRepoInput,
    github_list_issues=GitHubListIssuesInput,
    github_create_issue=GitHubCreateIssueInput,
    github_read_file=GitHubReadFileInput,
    github_write_file=GitHubWriteFileInput,
    github_create_branch=GitHubCreateBranchInput,
    github_create_pull_request=GitHubCreatePullRequestInput,
    github_search_code=GitHubSearchCodeInput,
    github_get_user_profile=GitHubGetUserProfileInput,
)
# ==================== REGISTER TOOLS ====================
# (name, callable, description) triples, registered in this exact order.
# The loop variables are deleted afterwards so no extra module names leak.
for _tool_name, _tool_fn, _tool_desc in (
    # Existing
    ("web_search", web_tools.web_search, "Search the web for current information"),
    ("fetch_url", web_tools.fetch_url, "Fetch and extract content from a URL"),
    ("run_python", code_tools.run_python, "Execute Python code — output files are saved and download URLs returned"),
    ("run_shell", code_tools.run_shell, "Execute shell commands"),
    ("create_file", file_tools.create_file, "Create text/code files (supports batch)"),
    ("read_file", file_tools.read_file, "Read file content"),
    ("write_file", file_tools.write_file, "Write to file"),
    ("list_dir", file_tools.list_directory, "List directory contents"),
    ("analyze_image", vision_tools.analyze_image, "Analyze images with vision AI"),
    # Document tools
    ("create_pdf", document_tools.create_pdf, "Create a PDF document with title and content — returns download URL"),
    ("create_docx", document_tools.create_docx, "Create a Word DOCX document — returns download URL"),
    ("create_excel", document_tools.create_excel, "Create an Excel spreadsheet with headers and rows — returns download URL"),
    ("create_csv", document_tools.create_csv, "Create a CSV file — returns download URL"),
    ("create_text_file", document_tools.create_text_file, "Create a plain text or markdown file — returns download URL"),
    # Media tools
    ("create_bar_chart", media_tools.create_bar_chart, "Create a bar chart PNG — returns download URL"),
    ("create_line_chart", media_tools.create_line_chart, "Create a line chart PNG — returns download URL"),
    ("create_pie_chart", media_tools.create_pie_chart, "Create a pie chart PNG — returns download URL"),
    ("create_qr_code", media_tools.create_qr_code, "Generate a QR code PNG for any text or URL — returns download URL"),
    ("resize_image", media_tools.resize_image, "Resize an image from base64 — returns download URL"),
    ("create_zip", media_tools.create_zip, "Create a ZIP archive from files in agent_outputs — returns download URL"),
    # GitHub
    ("github_list_repos", github_list_repos, "List user's GitHub repositories"),
    ("github_create_repo", github_create_repo, "Create a new GitHub repository"),
    ("github_get_repo", github_get_repo, "Get repository details"),
    ("github_list_issues", github_list_issues, "List repository issues"),
    ("github_create_issue", github_create_issue, "Create a new issue"),
    ("github_read_file", github_read_file, "Read file from repository"),
    ("github_write_file", github_write_file, "Create or update file in repository"),
    ("github_create_branch", github_create_branch, "Create a new branch"),
    ("github_create_pull_request", github_create_pull_request, "Create a pull request"),
    ("github_search_code", github_search_code, "Search code across GitHub"),
    ("github_get_user_profile", github_get_user_profile, "Get user's GitHub profile"),
):
    register_tool(_tool_name, _tool_fn, _tool_desc)
del _tool_name, _tool_fn, _tool_desc
# ==================== MODELS ====================
class AgentRequest(BaseModel):
    """Request body for POST /agent/stream."""

    # Caller's user id: used to look up memory, embedded in the first user
    # message, and injected as ``user_id`` into github_* tool inputs.
    user_id: str
    # The task the agent should work on.
    goal: str
    # Maximum number of agent steps (one LLM call each); defaults to MAX_STEPS.
    # NOTE(review): not range-checked, so a client can request an arbitrarily
    # large number of LLM calls — consider bounding it.
    max_steps: int = MAX_STEPS
    # Optional extra text appended to the first user message as "Context: ...".
    context: Optional[str] = None
    # Optional text appended after the stored memory in the system prompt.
    memory_context: Optional[str] = None
# ==================== HELPERS ====================
def _build_system_prompt(memory_context: str = "") -> str:
    """Render the agent's system prompt.

    The prompt template receives the caller's memory context together with
    the descriptions of every tool currently in the registry.
    """
    descriptions = get_tool_descriptions()
    return get_system_prompt(memory_context=memory_context, tool_descriptions=descriptions)
def _extract_json(content: str) -> Optional[Dict[str, Any]]:
content = content.strip()
if content.startswith("{"):
try:
return json.loads(content)
except json.JSONDecodeError:
pass
match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', content, re.DOTALL)
if match:
try:
return json.loads(match.group(1))
except json.JSONDecodeError:
pass
match = re.search(r'(\{.*\})', content, re.DOTALL)
if match:
try:
return json.loads(match.group(1))
except json.JSONDecodeError:
pass
return None
def _is_groq_fallback_trigger(status_code: Optional[int], body: str = "") -> bool:
    """Return True when an OpenRouter failure should be retried on Groq.

    Either signal is sufficient: a status code listed in
    GROQ_FALLBACK_STATUS_CODES, or any GROQ_FALLBACK_SUBSTRINGS phrase
    appearing in *body* (compared case-insensitively). Pass ``None`` as the
    status to test the body alone.
    """
    if status_code is not None and status_code in GROQ_FALLBACK_STATUS_CODES:
        return True
    lowered = body.lower() if body else ""
    for needle in GROQ_FALLBACK_SUBSTRINGS:
        if needle in lowered:
            return True
    return False
def _parse_llm_message(message: dict) -> Dict[str, Any]:
"""Parse a chat-completion message dict into the agent's JSON protocol.
Shared by OpenRouter and Groq paths so behavior is identical."""
content = (message.get("content") or "").strip()
if not content:
# Model may have used tool_calls or reasoning blocks
reasoning = message.get("reasoning") or message.get("reasoning_content") or ""
if reasoning:
content = reasoning.strip()
else:
return {
"thought": "Processing...",
"action": "finish",
"input": {"final_answer": "I was unable to generate a response. Please try again."}
}
parsed = _extract_json(content)
if parsed:
return parsed
clean = content.strip().strip("```").strip()
return {
"thought": "Your message has been responded directly.",
"action": "finish",
"input": {"final_answer": clean}
}
# ==================== GROQ FALLBACK (NEW in v9.0.1) ====================
async def _call_groq_fallback(messages: list) -> Dict[str, Any]:
    """Send *messages* to Groq and parse the reply with the agent's JSON protocol.

    This is _call_llm's fallback backend, and its only backend when OpenRouter
    is not configured.

    Retry policy: a 429 is retried, for at most three attempts in total. The
    wait is the server's ``retry-after`` value, or 5 s / 10 s when the header
    is missing or unparseable, capped at 15 s. No wait happens after the last
    attempt. Any other HTTP error or exception stops immediately.

    Returns:
        The parsed {thought, action, input} dict — never None.

    Raises:
        HTTPException(503): Groq is not configured, or every attempt failed.
    """
    if not GROQ_API_KEY:
        raise HTTPException(503, "LLM service unavailable — Groq fallback not configured.")
    headers = {"Authorization": f"Bearer {GROQ_API_KEY}", "Content-Type": "application/json"}
    payload = {
        "model": GROQ_FALLBACK_MODEL,
        "messages": messages,
        "max_tokens": GROQ_FALLBACK_TOKENS,
        "temperature": TEMPERATURE,
    }
    attempts = 3
    last_err: Optional[Exception] = None
    for attempt in range(attempts):
        try:
            # requests is blocking; run it in a worker thread so the event loop
            # keeps serving other requests during the (up to 45 s) call.
            r = await asyncio.to_thread(
                requests.post, GROQ_API_URL, headers=headers, json=payload, timeout=45,
            )
            # Groq 429 → back off and retry
            if r.status_code == 429:
                last_err = RuntimeError(f"Groq 429 (attempt {attempt + 1})")
                if attempt == attempts - 1:
                    # Final attempt: sleeping now would only delay the 503.
                    log.warning(f"[Agent/Groq] 429 — giving up after {attempts} attempts")
                    break
                default_wait = 5 * (attempt + 1)
                raw_wait = r.headers.get("retry-after")
                try:
                    wait = float(raw_wait) if raw_wait is not None else default_wait
                except ValueError:
                    # Retry-After may also be an HTTP-date; don't crash on it.
                    wait = default_wait
                wait = max(0.0, min(wait, 15))
                log.warning(f"[Agent/Groq] 429 — waiting {wait}s (attempt {attempt + 1}/{attempts})")
                await asyncio.sleep(wait)
                continue
            # Any other HTTP error → log and stop retrying
            if r.status_code >= 400:
                log.error(f"[Agent/Groq] HTTP {r.status_code}: {r.text[:200]}")
                last_err = RuntimeError(f"Groq HTTP {r.status_code}")
                break
            message = r.json()["choices"][0]["message"]
            return _parse_llm_message(message)
        except Exception as e:
            # Network errors, malformed JSON, missing keys, ...
            last_err = e
            log.error(f"[Agent/Groq] Unexpected error: {e}")
            break
    raise HTTPException(503, f"LLM service unavailable — Groq fallback failed: {last_err}") from last_err
# ==================== LLM CALL (with Groq fallback — NEW in v9.0.1) ====================
async def _call_llm(messages: list) -> Dict[str, Any]:
    """Ask the primary model (OpenRouter) for the next agent decision.

    Falls back to Groq, using the same JSON protocol, when:

    * OpenRouter is not configured but GROQ_API_KEY is;
    * OpenRouter answers with a status in GROQ_FALLBACK_STATUS_CODES, or with
      an error body containing one of GROQ_FALLBACK_SUBSTRINGS. Groq is then
      tried even when unconfigured, which raises Groq's own 503;
    * a 200 reply is plain text (not agent JSON) that reads like a provider
      error, e.g. "rate limit exceeded";
    * any other OpenRouter failure, provided GROQ_API_KEY is set.

    Groq is called at most once per invocation.

    Returns:
        The parsed {thought, action, input} dict — never None.

    Raises:
        HTTPException(503): no backend is configured, or every backend failed.
    """
    # If OpenRouter not configured, go straight to Groq.
    if not OPENROUTER_KEY:
        if GROQ_API_KEY:
            log.warning("[Agent] OPENROUTE_KEY missing — using Groq directly.")
            return await _call_groq_fallback(messages)
        raise HTTPException(503, "OPENROUTE_KEY not configured and GROQ_API_KEY not configured.")

    headers = {
        "Authorization": f"Bearer {OPENROUTER_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": PRIMARY_MODEL,
        "messages": messages,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
    }

    # Only the OpenRouter round-trip lives inside the try. Groq is awaited after
    # it, so a Groq failure (HTTPException) is not caught by the broad handler
    # below and retried against Groq a second time.
    must_fallback = False
    last_resort_note = ""
    try:
        # requests is blocking; keep it off the event loop.
        r = await asyncio.to_thread(
            requests.post,
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers, json=payload, timeout=45,
        )
        if r.status_code in GROQ_FALLBACK_STATUS_CODES:
            log.warning(f"[Agent] OpenRouter HTTP {r.status_code} — falling back to Groq.")
            must_fallback = True
        elif r.status_code >= 400 and _is_groq_fallback_trigger(None, r.text or ""):
            log.warning(
                "[Agent] OpenRouter error body indicates LLM unavailable — falling back to Groq."
            )
            must_fallback = True
        else:
            r.raise_for_status()
            message = r.json()["choices"][0]["message"]
            content = message.get("content") or ""
            # Only plain-text replies are scanned for "unavailable" phrases. A
            # well-formed agent JSON reply may legitimately mention "capacity"
            # or "rate limit" and must not be thrown away.
            if _extract_json(content) is None and _is_groq_fallback_trigger(None, content):
                log.warning(
                    "[Agent] OpenRouter body indicates LLM unavailable — falling back to Groq."
                )
                must_fallback = True
            else:
                return _parse_llm_message(message)
    except requests.HTTPError as e:
        log.error(f"[Agent] LLM HTTPError: {e}")
        last_resort_note = "[Agent] Falling back to Groq after non-fallback HTTPError."
    except Exception as e:
        log.error(f"[Agent] LLM unexpected error: {e}")
        last_resort_note = "[Agent] Falling back to Groq after unexpected error."

    if must_fallback:
        return await _call_groq_fallback(messages)
    # Last-resort fallback if Groq is configured
    if GROQ_API_KEY:
        log.warning(last_resort_note)
        return await _call_groq_fallback(messages)
    raise HTTPException(503, "LLM service unavailable")
async def _dispatch_tool(name: str, raw_input: dict) -> str:
    """Validate *raw_input* and invoke the registered tool *name*.

    Tool-level problems never raise. An unknown tool, a schema validation
    failure, or an exception inside the tool is returned as a human-readable
    string so the LLM can react to it on the next step.

    Returns:
        The tool's output as ``str``. Non-string results (dict, list, None, ...)
        are JSON-serialised, falling back to ``str()``. Callers slice and
        measure the result and interpolate it into the conversation.
    """
    import inspect  # asyncio.iscoroutinefunction is deprecated since 3.14

    tool_info = get_tool(name)
    if not tool_info:
        return f"Unknown tool: {name}"
    schema = TOOL_SCHEMAS.get(name)
    if schema:
        try:
            raw_input = schema(**raw_input).model_dump()
        except Exception as e:
            return f"Invalid parameters for {name}: {str(e)}"
    try:
        fn = tool_info["fn"]
        if inspect.iscoroutinefunction(fn):
            result = await fn(**raw_input)
        else:
            # Synchronous tools run in a worker thread so they don't block
            # the event loop.
            result = await asyncio.to_thread(fn, **raw_input)
    except Exception as e:
        log.error(f"Tool error ({name}): {e}")
        return f"Tool error ({name}): {str(e)}"
    if isinstance(result, str):
        return result
    try:
        return json.dumps(result, ensure_ascii=False, default=str)
    except (TypeError, ValueError):
        # e.g. circular references
        return str(result)
def _make_sse(event_type: str, step: int, content: str, tool: str = "", elapsed: float = 0.0, reasoning: str = "") -> str:
data = {
"type": event_type,
"step": step,
"content": content,
"tool": tool,
"elapsed": round(elapsed, 2),
"thought": reasoning
}
return f"data: {json.dumps(data)}\n\n"
async def _execute_tool_and_update_history(
    action: str, tool_input: dict, messages: list, thought: str
) -> tuple[str, float]:
    """Run one tool call and record it in the conversation.

    Appends two messages to *messages*, which is mutated in place: first the
    assistant's JSON action, then a user turn that carries the tool result and
    asks the model to continue.

    Returns:
        ``(result, seconds_spent_in_the_tool)``.
    """
    started = time.time()
    result = await _dispatch_tool(action, tool_input)
    elapsed = time.time() - started

    assistant_turn = {
        "role": "assistant",
        "content": json.dumps({"thought": thought, "action": action, "input": tool_input}),
    }
    tool_turn = {
        "role": "user",
        "content": f"[Tool Result from {action}]\n{result}\n\nContinue with next JSON action or finish.",
    }
    messages.extend((assistant_turn, tool_turn))
    return result, elapsed
# ==================== MAIN STREAMING AGENT ====================
def _clean_final_answer(tool_input: dict, thought: str) -> str:
    """Build the user-facing answer for a "finish" action, hiding leaked JSON."""
    final = tool_input.get("final_answer", "") or thought
    if not isinstance(final, str):
        # A structured final_answer (dict/list/number) has no .strip();
        # serialise it so the leak check below can inspect it.
        final = json.dumps(final, ensure_ascii=False, default=str)
    stripped = final.strip()
    # Strip raw JSON, list dumps, thought leakage
    looks_leaked = (
        not final
        or stripped.startswith("{")
        or stripped.startswith("[{")
        or ('"action"' in final and '"thought"' in final)
    )
    if not looks_leaked:
        return final
    # Salvage download links from the thought. Quotes, angle brackets and
    # closing brackets are excluded so URLs copied out of JSON-ish text don't
    # end in '",' or '}'. Trailing sentence punctuation is trimmed as well.
    urls = [u.rstrip(".,;:)") for u in re.findall(r"https?://[^\s\"'<>\]}]+", thought)]
    if urls:
        return "Done! Here are your download links:\n" + "\n".join(urls)
    if '"final_answer"' in thought:
        return thought.split('"final_answer"')[-1].strip(' :"{}')
    return thought


async def stream_agent(
    goal: str,
    user_id: str,
    context: str = "",
    memory_context: str = "",
    max_steps: int = MAX_STEPS
):
    """Run the think→act loop for *goal*, yielding SSE frames as it goes.

    Each step asks the LLM for a {thought, action, input} decision:

    * "finish" ends the stream with a "done" frame carrying the answer;
    * any other action is dispatched as a tool, and its result is fed back
      into the conversation for the next step.

    Frames emitted:

    * "reasoning" on every step;
    * "tool_start" and "result" around each tool call;
    * a terminating "done" frame, also sent when the step budget is exhausted
      or no LLM backend is reachable.
    """
    memory = get_memory_manager(user_id)
    full_memory = memory.get_context_for_agent() + (f"\n\n{memory_context}" if memory_context else "")
    system_prompt = _build_system_prompt(full_memory)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"user_id: {user_id}\nGoal: {goal}" + (f"\nContext: {context}" if context else "")},
    ]
    total_start = time.time()
    for step in range(1, max_steps + 1):
        step_start = time.time()
        try:
            # _call_llm handles the Groq fallback internally.
            parsed = await _call_llm(messages)
        except HTTPException as exc:
            # The SSE response headers are already sent, so letting this
            # propagate would just cut the stream off. Tell the client instead,
            # using the same "done" frame the step-limit message uses.
            log.error(f"[Agent] LLM unavailable at step {step}: {exc.detail}")
            yield _make_sse(
                "done", step,
                f"Sorry, the language model is currently unavailable ({exc.detail}). Please try again later.",
                "finish", time.time() - total_start,
            )
            return

        thought = parsed.get("thought", "Thinking...")
        if not isinstance(thought, str):
            thought = "Thinking..." if thought is None else str(thought)
        action = parsed.get("action", "finish")
        if not isinstance(action, str):
            # e.g. "action": null — a missing action means the model is done.
            action = "finish" if action is None else str(action)
        tool_input = parsed.get("input", {})
        if not isinstance(tool_input, dict):
            tool_input = {}
        if action.startswith("github_"):
            tool_input.setdefault("user_id", user_id)

        yield _make_sse("reasoning", step, thought, action, time.time() - step_start, thought)

        if action == "finish":
            final = _clean_final_answer(tool_input, thought)
            yield _make_sse("done", step, final, "finish", time.time() - total_start, thought)
            return

        yield _make_sse("tool_start", step, f"Running {action}...", action, time.time() - step_start)
        result, tool_elapsed = await _execute_tool_and_update_history(action, tool_input, messages, thought)
        text = result if isinstance(result, str) else str(result)
        display = text[:1800] + "..." if len(text) > 1800 else text
        yield _make_sse("result", step, display, action, tool_elapsed)
        await asyncio.sleep(0.12)

    yield _make_sse(
        "done", max_steps + 1,
        "I reached the maximum number of steps. Please try breaking the goal into smaller parts.",
        "finish", time.time() - total_start
    )
# ==================== FASTAPI ROUTER ====================
agent_router = APIRouter(prefix="/agent")


@agent_router.post("/stream")
async def agent_stream(req: AgentRequest):
    """Start an agent run and stream its progress as Server-Sent Events.

    Fails fast with 503 when neither OpenRouter nor Groq credentials are
    configured. Having only Groq configured is enough to start.
    """
    no_backend = not (OPENROUTER_KEY or GROQ_API_KEY)
    if no_backend:
        raise HTTPException(503, "No LLM backend configured (need OPENROUTE_KEY or GROQ_API_KEY).")
    events = stream_agent(
        req.goal,
        req.user_id,
        req.context or "",
        req.memory_context or "",
        req.max_steps,
    )
    # Disable caching and proxy buffering so frames reach the client promptly.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(events, media_type="text/event-stream", headers=sse_headers)
@agent_router.get("/health")
async def health_check():
    """Report service status, version, registered tools and Groq availability."""
    groq_ready = bool(GROQ_API_KEY)
    return dict(
        status="ok",
        version="9.0.1",
        tools=list_tools(),
        groq_fallback_enabled=groq_ready,
    )
@agent_router.get("/download/{filename}")
async def download_file(filename: str):
    """Serve a file generated by the agent from /tmp/agent_outputs/.

    Only regular files located directly in that directory are served.

    Raises:
        HTTPException(400): the name contains ``..`` or ``/``, or it resolves
            outside the directory (e.g. via a symlink, or the name ``"."``).
        HTTPException(404): the file is missing or is not a regular file.
    """
    from fastapi.responses import FileResponse
    from pathlib import Path

    # Security: no path traversal
    if ".." in filename or "/" in filename:
        raise HTTPException(400, "Invalid filename")
    base_dir = Path("/tmp/agent_outputs").resolve()
    file_path = (base_dir / filename).resolve()
    # A symlink inside the output dir could still point elsewhere; "." would
    # resolve to the directory itself.
    if file_path.parent != base_dir:
        raise HTTPException(400, "Invalid filename")
    # exists() is also true for directories, which FileResponse cannot serve.
    if not file_path.is_file():
        raise HTTPException(404, "File not found or expired")
    return FileResponse(
        path=str(file_path),
        filename=filename,
        media_type="application/octet-stream"
    )