Brain / app.py
Percy3822's picture
Create app.py
e54d33d verified
raw
history blame
9.7 kB
import asyncio
import logging
import os
import time
from typing import List, Dict, Any, Optional, Literal

import httpx
from fastapi import FastAPI, BackgroundTasks, HTTPException
from pydantic import BaseModel, Field, ValidationError
# =========================
# Config (from env vars)
# =========================
PYTHON_AI_URL = os.getenv("PYTHON_AI_URL", "") # e.g. https://<you>-python_ai_space.hf.space
TTS_URL = os.getenv("TTS_URL", "") # e.g. https://<you>-tts_space.hf.space
STT_URL = os.getenv("STT_URL", "") # (optional) future WS/HTTP
VISION_URL = os.getenv("VISION_URL", "") # (optional) OCR/thumbnail summaries
MEMORY_URL = os.getenv("MEMORY_URL", "") # (optional) external memory/RAG svc
PROMPT_BUDGET_BYTES = int(os.getenv("PROMPT_BUDGET_BYTES", "12000"))
MEMORY_BUDGET_BYTES = int(os.getenv("MEMORY_BUDGET_BYTES", "6000"))
VIEWPORT_MAX_LINES = int(os.getenv("VIEWPORT_MAX_LINES", "60")) # visible ±30 default
REQUEST_TIMEOUT_S = float(os.getenv("REQUEST_TIMEOUT_S", "60"))
CONNECT_TIMEOUT_S = float(os.getenv("CONNECT_TIMEOUT_S", "10"))
# =========================
# Models / Schemas
# =========================
class Cursor(BaseModel):
l: int = Field(..., description="line")
c: int = Field(..., description="col")
class Viewport(BaseModel):
start: int
end: int
text: str
class Diagnostic(BaseModel):
l: int
sev: str
msg: str
class Memory(BaseModel):
short: List[str] = []
sess: List[str] = []
proj: List[str] = []
class Telemetry(BaseModel):
file: str
lang: str
cursor: Cursor
viewport: Viewport
diag: List[Diagnostic] = []
term: str = ""
class CodeHelpIn(BaseModel):
utterance: str
telemetry: Telemetry
memory: Memory = Memory()
response_mode: Literal["patch","full"] = "patch"
class Need(BaseModel):
function: bool = False
xrefs: List[str] = []
page_ids: List[str] = []
class PythonAIOutput(BaseModel):
mode: Literal["patch","full","ask"]
patch: str = ""
full_text: str = ""
explanation: str = ""
confidence: float = 0.5
need: Need = Need()
class CodeHelpOut(BaseModel):
ai: PythonAIOutput
tts_audio_url: Optional[str] = None
used_memory_bytes: int
used_prompt_bytes: int
notes: Dict[str, Any] = {}
# =========================
# App & HTTP client
# =========================
app = FastAPI(title="Brain (Router)", version="1.0")
client = httpx.AsyncClient(
timeout=httpx.Timeout(REQUEST_TIMEOUT_S, connect=CONNECT_TIMEOUT_S),
headers={"User-Agent": "BrainRouter/1.0"}
)
# =========================
# Small Utilities
# =========================
def _truncate_bytes(s: str, budget: int) -> str:
"""Truncate a string to a byte budget (UTF-8 safe)."""
b = s.encode("utf-8")
if len(b) <= budget:
return s
return b[:budget].decode("utf-8", errors="ignore")
def _shrink_lines_to_max(window: Viewport, max_lines: int) -> Viewport:
    """Keep at most the last *max_lines* lines of the viewport text.

    The viewport's `end` line is unchanged; `start` moves down so the line
    numbering still matches the retained tail (never above the original
    start). Returns the input unchanged when it already fits.
    """
    all_lines = window.text.splitlines()
    if len(all_lines) <= max_lines:
        return window
    tail = all_lines[-max_lines:]
    new_start = max(window.start, window.end - max_lines + 1)
    return Viewport(start=new_start, end=window.end, text="\n".join(tail))
async def _safe_post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST *payload* as JSON to *url* and return the parsed JSON response.

    Any failure — connect/transport error, non-2xx status, or an unparseable
    body — is surfaced uniformly as a 502 HTTPException so callers see a
    single gateway-error contract.
    """
    try:
        response = await client.post(url, json=payload)
        response.raise_for_status()
        return response.json()
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"POST {url} failed: {exc}")
# =========================
# Priority Queue (P0/P1/P2)
# =========================
# P0: speech/telemetry (not used yet in this minimal Brain, reserved)
# P1: code model + TTS (interactive)
# P2: thumbnails / heavy analysis (future)
TASK_Q: "asyncio.PriorityQueue[tuple[int,float,dict]]" = asyncio.PriorityQueue()
async def worker_loop():
    """Drain TASK_Q forever, awaiting each task's async handler.

    A task is a dict shaped like ``{"handler": coroutine_fn, "args": {...}}``;
    items without a handler are skipped. Handler failures are logged rather
    than raised so a single bad task cannot kill the worker, and the queue
    item is always marked done so ``TASK_Q.join()`` can complete.
    """
    logger = logging.getLogger(__name__)
    while True:
        priority, ts, task = await TASK_Q.get()
        try:
            handler = task.get("handler")
            if handler:
                await handler(**task.get("args", {}))
        except Exception:
            # Keep the worker resilient, but surface the failure for
            # debugging instead of swallowing it silently.
            logger.exception("task failed (priority=%s, ts=%s)", priority, ts)
        finally:
            TASK_Q.task_done()
@app.on_event("startup")
async def _startup():
    """Spawn two background queue workers.

    References to the created tasks are kept on ``app.state``: the event
    loop holds only weak references to tasks, so an unreferenced task can
    be garbage-collected mid-flight (per the asyncio.create_task docs).
    """
    app.state.worker_tasks = [asyncio.create_task(worker_loop()) for _ in range(2)]
@app.on_event("shutdown")
async def _shutdown():
try:
await client.aclose()
except Exception:
pass
# =========================
# Health & Warmup
# =========================
@app.get("/health")
async def health():
    """Liveness probe: report which downstream service URLs are configured."""
    configured = {
        name: bool(url)
        for name, url in (
            ("python_ai", PYTHON_AI_URL),
            ("tts", TTS_URL),
            ("stt", STT_URL),
            ("vision", VISION_URL),
            ("memory", MEMORY_URL),
        )
    }
    return {"ok": True, "deps": configured, "version": "1.0"}
@app.post("/warmup")
async def warmup():
    """Ping configured downstream services so later requests skip cold starts.

    Each probe is best-effort: failures are recorded in the notes dict and
    never fail the endpoint itself.
    """
    notes = {}
    if PYTHON_AI_URL:
        # Minimal but schema-shaped payload; the AI space has no /health,
        # so we exercise /code_help directly.
        ping_payload = {
            "intent": "ping",
            "file": "_warmup_.py",
            "lang": "python",
            "cursor": {"l": 1, "c": 1},
            "viewport": {"start": 1, "end": 1, "text": "print('warmup')\n"},
            "diag": [],
            "term": "",
            "mem": {"short": [], "sess": [], "proj": []},
        }
        try:
            reply = await _safe_post_json(f"{PYTHON_AI_URL}/code_help", ping_payload)
            notes["python_ai"] = "ok" if reply else "no-response"
        except Exception as exc:
            notes["python_ai"] = f"err: {exc}"
    if TTS_URL:
        try:
            reply = await _safe_post_json(f"{TTS_URL}/speak", {"text": "warming up"})
            notes["tts"] = "ok" if "audio_path" in reply else "no-audio"
        except Exception as exc:
            notes["tts"] = f"err: {exc}"
    return {"ok": True, "notes": notes}
# =========================
# Core: Code Help endpoint
# =========================
def _enforce_budgets(t: Telemetry, m: Memory) -> tuple[Telemetry, Memory, int, int]:
    """Shrink telemetry and memory to fit the configured byte budgets.

    Returns ``(telemetry, memory, used_memory_bytes, used_prompt_bytes)``.
    The viewport is first capped at VIEWPORT_MAX_LINES; if the estimated
    prompt size still exceeds PROMPT_BUDGET_BYTES, the viewport is halved
    (floor 20 lines) and the terminal tail clipped, and the size is
    re-measured once. Anything still over budget is accepted as-is.
    """
    def _prompt_size(tel: Telemetry, mem_bytes: int) -> int:
        # Approximate prompt cost: sums len() of the text fields (character
        # counts, not encoded bytes) plus the memory byte count. This is the
        # same approximation the original budgeting used.
        return (
            len(tel.file) + len(tel.lang) +
            len(tel.viewport.text) + sum(len(d.msg) for d in tel.diag) +
            len(tel.term) + mem_bytes
        )

    # Shrink viewport to the configured maximum number of lines.
    t2 = Telemetry(
        file=t.file, lang=t.lang, cursor=t.cursor,
        viewport=_shrink_lines_to_max(t.viewport, VIEWPORT_MAX_LINES),
        diag=t.diag[:5],  # cap diagnostics at the first five
        term=t.term
    )

    # Flatten all memory bullets into one string and clip to the byte budget.
    mem_text = " | ".join(m.short + m.sess + m.proj)
    mem_text = _truncate_bytes(mem_text, MEMORY_BUDGET_BYTES)
    # Bucket structure is intentionally collapsed: the single clipped string
    # is carried in 'sess'.
    m2 = Memory(short=[], sess=[mem_text] if mem_text else [], proj=[])
    used_mem = len(mem_text.encode("utf-8"))

    prompt_bytes = _prompt_size(t2, used_mem)
    if prompt_bytes > PROMPT_BUDGET_BYTES:
        # Still too large: tighten the viewport again by half and clip the
        # terminal output, then re-measure.
        t2 = Telemetry(
            file=t2.file, lang=t2.lang, cursor=t2.cursor,
            viewport=_shrink_lines_to_max(t2.viewport, max(20, VIEWPORT_MAX_LINES // 2)),
            diag=t2.diag, term=_truncate_bytes(t2.term, 1024)
        )
        prompt_bytes = _prompt_size(t2, used_mem)

    return t2, m2, used_mem, prompt_bytes
async def _route_python_ai(payload: Dict[str, Any]) -> PythonAIOutput:
    """Forward the compact contract to the Python AI space and validate its reply.

    Raises 500 when the service is unconfigured, 502 when the downstream
    call fails or its reply does not match the PythonAIOutput schema.
    """
    if not PYTHON_AI_URL:
        raise HTTPException(status_code=500, detail="PYTHON_AI_URL not configured")
    raw = await _safe_post_json(f"{PYTHON_AI_URL}/code_help", payload)
    try:
        parsed = PythonAIOutput(**raw)
    except ValidationError as ve:
        # Downstream answered, but not with the agreed schema — report it as
        # a bad gateway rather than a bug in this service.
        raise HTTPException(status_code=502, detail=f"Bad AI JSON schema: {ve}")
    return parsed
async def _send_tts(text: str) -> Optional[str]:
    """Ask the TTS space to speak *text*; return a browser-openable audio URL.

    Best-effort: returns None when TTS is unconfigured, *text* is empty, or
    anything downstream fails or replies without an audio path.
    """
    if not text or not TTS_URL:
        return None
    try:
        reply = await _safe_post_json(f"{TTS_URL}/speak", {"text": text})
        audio_path = reply.get("audio_path")
        if not audio_path:
            return None
        # HF Spaces serve generated files under /file/<name>; rebuild an
        # absolute URL from the configured TTS base so the browser can open it.
        filename = audio_path.split("/")[-1]
        base = TTS_URL.rstrip("/")
        return f"{base}/file/{filename}"
    except Exception:
        return None
@app.post("/code_help", response_model=CodeHelpOut)
async def code_help(x: CodeHelpIn):
# 1) enforce budgets / shrink context
t2, m2, used_mem, used_prompt = _enforce_budgets(x.telemetry, x.memory)
# 2) build compact contract for Python AI (as agreed)
py_in = {
"intent": x.utterance,
"file": t2.file,
"lang": t2.lang,
"cursor": {"l": t2.cursor.l, "c": t2.cursor.c},
"viewport": {"start": t2.viewport.start, "end": t2.viewport.end, "text": t2.viewport.text},
"diag": [{"l": d.l, "sev": d.sev, "msg": d.msg} for d in t2.diag],
"term": t2.term,
"mem": {"short": m2.short, "sess": m2.sess, "proj": m2.proj}
}
# 3) call Python AI (async)
ai_out = await _route_python_ai(py_in)
# 4) send short voice summary in parallel (explanation only)
tts_url = await _send_tts(ai_out.explanation)
# 5) respond
return CodeHelpOut(
ai=ai_out,
tts_audio_url=tts_url,
used_memory_bytes=used_mem,
used_prompt_bytes=used_prompt,
notes={"response_mode": x.response_mode}
)