# Source: Hugging Face Space file view, uploaded by sammoftah.
# Commit 96cf94c (verified): "Deploy HF Agentic Search for Build Small"
# File size: 3.62 kB
"""HF Agentic Search - Gradio Space entrypoint."""
from __future__ import annotations
import json
import os
import sys
import threading
import uuid
# Opt out of Gradio analytics unless the environment already set a value.
# This runs ahead of the gradio import that follows.
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")

# Put this file's own directory on the import path (once) so the local
# ``backend`` package imported further down can be resolved.
_ROOT = os.path.dirname(os.path.abspath(__file__))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)
import gradio as gr
from fastapi import Request
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from backend.agent import MAX_TASK_LENGTH, weave, weave_events
# Application object; the route decorators further down register handlers on it.
app = gr.Server()
# Directory expected to hold the built frontend (index.html and assets/).
_DIST = os.path.join(_ROOT, "frontend", "dist")
# Most recent result stored per session id.
# NOTE(review): nothing visible in this file removes entries, so this mapping
# grows with every distinct session id - confirm whether eviction is needed.
_sessions: dict[str, dict] = {}
# Held around every read and write of _sessions.
_lock = threading.Lock()
def _session_id(request: Request) -> str:
    """Return the caller's session id, minting a new one when none is sent.

    The ``X-Session-Id`` header is stripped of surrounding whitespace and
    capped at 80 characters. A missing or blank header yields a fresh id
    made of ``dw-`` followed by 16 hexadecimal characters.
    """
    header_value = request.headers.get("X-Session-Id")
    candidate = header_value.strip() if header_value else ""
    if candidate:
        return candidate[:80]
    return f"dw-{uuid.uuid4().hex[:16]}"
async def _task_from_request(request: Request) -> tuple[str | None, JSONResponse | None]:
    """Extract and validate the ``task`` field from a JSON request body.

    Args:
        request: Incoming request whose body should be a JSON object
            carrying a ``task`` entry.

    Returns:
        ``(task, None)`` when the task is usable, otherwise
        ``(None, response)`` where ``response`` is the JSON error to send
        back: status 400 for an unparsable body, a body that is not a JSON
        object, or a missing/blank task; status 422 when the task is longer
        than ``MAX_TASK_LENGTH`` characters.
    """
    try:
        body = await request.json()
    except Exception:
        return None, JSONResponse({"error": "Body is not valid JSON."}, status_code=400)

    # Valid JSON is not necessarily an object: a list, string or number has
    # no ``.get`` and would otherwise surface as an unhandled server error.
    if not isinstance(body, dict):
        return None, JSONResponse({"error": "Body must be a JSON object."}, status_code=400)

    # Non-string values are coerced to text; falsy values count as missing.
    task = str(body.get("task") or "").strip()
    if not task:
        return None, JSONResponse({"error": "Task description is required."}, status_code=400)
    if len(task) > MAX_TASK_LENGTH:
        return None, JSONResponse(
            {"error": f"Task description must be {MAX_TASK_LENGTH} characters or fewer."},
            status_code=422,
        )
    return task, None
@app.post("/weave")
async def api_weave(request: Request):
    """Run a search for the posted task and return the full result as JSON.

    The validated task is passed to ``weave``; its result is stored under the
    caller's session id and returned. Validation failures return the error
    response produced by ``_task_from_request``. Any exception raised by
    ``weave`` is reported as a 502 whose body carries the exception message.
    The session id is echoed in the ``X-Session-Id`` response header.
    """
    # Imported here so this handler is self-contained with respect to the
    # file's import block.
    from fastapi.concurrency import run_in_threadpool

    task, error = await _task_from_request(request)
    if error is not None:
        return error
    sid = _session_id(request)
    try:
        # ``weave`` is a plain synchronous call; running it on a worker
        # thread keeps the event loop free to serve other requests meanwhile.
        result = await run_in_threadpool(weave, task or "")
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=502)
    with _lock:
        _sessions[sid] = result
    return JSONResponse(result, headers={"X-Session-Id": sid})
@app.post("/weave/stream")
async def api_weave_stream(request: Request):
    """Stream search progress for the posted task as newline-delimited JSON.

    Each event yielded by ``weave_events`` becomes one JSON line. When an
    event's ``type`` is ``"complete"``, its ``result`` is stored under the
    caller's session id before the line is emitted. An exception raised
    while producing events is reported as a final ``error`` line rather than
    propagating out of the generator. Validation failures return the error
    response produced by ``_task_from_request``.
    """
    task, error = await _task_from_request(request)
    if error:
        return error
    sid = _session_id(request)

    def ndjson_lines():
        try:
            for evt in weave_events(task or ""):
                is_final = evt["type"] == "complete"
                if is_final:
                    # Remember the finished result so /state can serve it.
                    with _lock:
                        _sessions[sid] = evt["result"]
                yield json.dumps(evt, ensure_ascii=True) + "\n"
        except Exception as exc:
            failure = {"type": "error", "message": str(exc)}
            yield json.dumps(failure) + "\n"

    response_headers = {"X-Session-Id": sid, "Cache-Control": "no-store"}
    return StreamingResponse(
        ndjson_lines(),
        media_type="application/x-ndjson",
        headers=response_headers,
    )
@app.get("/state")
async def get_state(request: Request):
    """Return the result stored for the caller's session.

    When nothing is stored under the session id, an empty placeholder result
    is returned instead. The session id is echoed in the ``X-Session-Id``
    response header.
    """
    sid = _session_id(request)
    with _lock:
        stored = _sessions.get(sid)

    if stored is not None:
        payload = stored
    else:
        # Empty placeholder used when the session has no stored result.
        payload = {
            "datasets": [],
            "nodes": [],
            "threads": [],
            "task": "",
            "top_pick": None,
            "fallback_used": False,
        }
    return JSONResponse(payload, headers={"X-Session-Id": sid})
# Serve the frontend's asset directory under /assets, but only when that
# directory exists on disk.
_DIST_ASSETS = os.path.join(_DIST, "assets")
if os.path.isdir(_DIST_ASSETS):
    app.mount(
        "/assets",
        StaticFiles(directory=_DIST_ASSETS),
        name="assets",
    )
@app.get("/")
def index():
    """Serve the frontend's ``index.html``, or a 503 JSON error if it is absent."""
    index_html = os.path.join(_DIST, "index.html")
    if not os.path.isfile(index_html):
        return JSONResponse({"error": "Frontend not built."}, status_code=503)
    return FileResponse(index_html)
if __name__ == "__main__":
    # Port resolution order: PORT, then GRADIO_SERVER_PORT, then 7860.
    _fallback_port = os.environ.get("GRADIO_SERVER_PORT", "7860")
    _port = int(os.environ.get("PORT", _fallback_port))
    # Bind on all interfaces.
    app.launch(server_name="0.0.0.0", server_port=_port, show_error=True)