# lea-GEO / db_service / app.py
# hsmm's picture
# Initial commit for HF Space
# 35bdde1
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import os
import httpx
class Settings:
    """Supabase connection settings read from environment variables.

    Values are captured once, when the instance is constructed.
    """

    def __init__(self) -> None:
        env = os.environ
        # Project base URL; None when SUPABASE_URL is unset.
        self.supabase_url = env.get("SUPABASE_URL")
        # Prefer the service-role key. An unset or empty value falls back to
        # the anon key, which is itself None when unset.
        service_role_key = env.get("SUPABASE_SERVICE_ROLE_KEY")
        self.supabase_key = service_role_key if service_role_key else env.get("SUPABASE_ANON_KEY")
class SupabaseConfig(BaseModel):
    """Validated Supabase project URL and API key."""

    url: str  # project base URL, with or without a trailing slash
    key: str  # API key used to authenticate requests

    @property
    def rest_url(self) -> str:
        """PostgREST root: ``url`` with trailing slashes removed, plus ``/rest/v1``."""
        base = self.url.rstrip("/")
        return f"{base}/rest/v1"
def sb_headers(cfg: SupabaseConfig, extra: dict[str, str] | None = None) -> dict[str, str]:
    """Build the header dict for a Supabase REST request.

    ``cfg.key`` is sent both as ``apikey`` and as a bearer token, and the
    content type is JSON. Entries in ``extra`` are added after these and
    override any base header with the same key. A new dict is always returned.
    """
    base = {
        "apikey": cfg.key,
        "authorization": f"Bearer {cfg.key}",
        "content-type": "application/json",
    }
    return {**base, **(extra or {})}
async def sb_get(cfg: SupabaseConfig, path: str, params: dict[str, Any]) -> Any:
    """GET ``cfg.rest_url + path`` with ``params`` as the query string.

    Returns the decoded JSON body. Raises ``httpx.HTTPStatusError`` on a
    non-2xx response. A fresh client with a 30-second timeout is used for
    each call.
    """
    endpoint = cfg.rest_url + path
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.get(endpoint, headers=sb_headers(cfg), params=params)
    # Non-streamed httpx responses are fully read before `get` returns, so
    # the body is still available after the client is closed.
    response.raise_for_status()
    return response.json()
async def sb_upsert(cfg: SupabaseConfig, path: str, json_payload: Any, on_conflict: str) -> None:
    """POST ``json_payload`` to ``cfg.rest_url + path`` as an upsert.

    Sends ``prefer: resolution=merge-duplicates,return=minimal`` and passes
    ``on_conflict`` (the conflict-target column list) as a query parameter.
    Returns nothing. Raises ``httpx.HTTPStatusError`` on a non-2xx response.
    """
    endpoint = cfg.rest_url + path
    upsert_headers = sb_headers(cfg, {"prefer": "resolution=merge-duplicates,return=minimal"})
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(
            endpoint,
            headers=upsert_headers,
            params={"on_conflict": on_conflict},
            json=json_payload,
        )
    response.raise_for_status()
class BrandPairsResponse(BaseModel):
    """Response body of ``GET /brand-pairs/latest``."""

    # Version tag shared by the returned pairs; "0" when no version exists.
    version: str
    # Rows from `dim_brand_pairs` as returned by PostgREST (the handler selects
    # main_brand, competitor_brand, industry, priority).
    pairs: list[dict[str, Any]]
def utcnow_iso() -> str:
    """Current UTC time as ISO-8601 with whole seconds, e.g. ``2024-01-01T12:00:00+00:00``."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat(timespec="seconds")
# ASGI application; the route handlers in this module register on it via decorators.
app = FastAPI(title="GEO DB Service (Supabase)", version="0.1.0")
# Environment is read once here, at import time; later env changes are not picked up.
settings = Settings()
def get_supabase_cfg() -> SupabaseConfig:
    """Build a ``SupabaseConfig`` from the module-level ``settings``.

    Raises ``HTTPException(500, "supabase_not_configured")`` when either the
    URL or the key is missing or empty.
    """
    url, key = settings.supabase_url, settings.supabase_key
    if url and key:
        return SupabaseConfig(url=url, key=key)
    raise HTTPException(status_code=500, detail="supabase_not_configured")
@app.get("/health")
async def health():
    """Liveness probe that answers without touching Supabase."""
    status_body = {"ok": True}
    return status_body
@app.get("/brand-pairs/latest", response_model=BrandPairsResponse)
async def brand_pairs_latest():
    """Return every brand pair belonging to the most recently effective version.

    Makes two PostgREST reads against ``dim_brand_pairs``:

    1. the single row with the greatest ``effective_at`` supplies ``version``;
    2. all rows tagged with that ``version`` are returned, highest
       ``priority`` first.

    Returns ``version="0"`` with no pairs when the table is empty.
    ``get_supabase_cfg`` raises HTTPException(500) if Supabase is not
    configured, and HTTP errors from ``sb_get`` propagate.
    """
    cfg = get_supabase_cfg()
    # PostgreSQL sorts NULLs first under DESC by default. `nullslast` keeps a
    # row without `effective_at` from being mistaken for the newest version.
    latest = await sb_get(
        cfg,
        "/dim_brand_pairs",
        params={
            "select": "version,effective_at",
            "order": "effective_at.desc.nullslast",
            "limit": "1",
        },
    )
    if not latest:
        return BrandPairsResponse(version="0", pairs=[])
    version = latest[0]["version"]
    pairs = await sb_get(
        cfg,
        "/dim_brand_pairs",
        params={
            "select": "main_brand,competitor_brand,industry,priority",
            "version": f"eq.{version}",
            # Same NULL-ordering concern: unprioritised pairs go last, not first.
            "order": "priority.desc.nullslast",
        },
    )
    return BrandPairsResponse(version=version, pairs=pairs)
class RunCreateRequest(BaseModel):
    """Request body of ``POST /runs``; fields map onto columns of the ``runs`` table."""

    run_id: str  # upsert key (on_conflict="run_id")
    trigger: str
    schedule_id: str | None = None
    config_version: str
    config_snapshot_url: str | None = None
    # Free-form run parameters; stored in the `params_json` column.
    params: dict[str, Any] = Field(default_factory=dict)
@app.post("/runs")
async def create_run(req: RunCreateRequest):
    """Upsert a row into ``runs`` keyed on ``run_id``.

    ``started_at`` is stamped with the current UTC time, ``finished_at`` is
    sent as null, and ``req.params`` is stored as ``params_json``.

    NOTE(review): because the upsert sends both timestamps, re-posting an
    existing run_id presumably resets its start time and clears finished_at
    (PostgREST merge-duplicates updates the columns sent). Confirm that this
    is intended.
    """
    cfg = get_supabase_cfg()
    row = dict(
        run_id=req.run_id,
        trigger=req.trigger,
        schedule_id=req.schedule_id,
        config_version=req.config_version,
        config_snapshot_url=req.config_snapshot_url,
        params_json=req.params,
        started_at=utcnow_iso(),
        finished_at=None,
    )
    await sb_upsert(cfg, "/runs", json_payload=row, on_conflict="run_id")
    return {"ok": True}
class TaskUpsertRequest(BaseModel):
    """Request body of ``POST /tasks``."""

    run_id: str  # parent run, written to the task row's `run_id` column
    # Loose task dict. The handler reads task_id (required), question_id, site,
    # profile_id, status, reason, handoff_url and an optional `result` dict.
    task: dict[str, Any]
async def _existing_task_started_at(cfg: SupabaseConfig, task_id: Any) -> Any:
    """Return the ``started_at`` already stored for ``task_id``, or None if absent."""
    rows = await sb_get(
        cfg,
        "/tasks",
        params={"select": "started_at", "task_id": f"eq.{task_id}", "limit": "1"},
    )
    return rows[0].get("started_at") if rows else None


@app.post("/tasks")
async def upsert_task(req: TaskUpsertRequest):
    """Upsert one task row and, when a result dict is supplied, its artifacts.

    ``req.task`` must contain a truthy ``task_id``; otherwise HTTP 400
    ``missing_task_id`` is raised. ``status`` defaults to ``"running"``.
    ``finished_at`` is stamped for ``success``/``failed``/``needs_human`` and
    cleared for any other status. ``started_at`` keeps the value already stored
    for the task and is set to the current time only for a task with no
    recorded start. If ``task["result"]`` is a dict, its answer, citations,
    screenshots and URLs are upserted into ``task_artifacts`` keyed on
    ``task_id``.
    """
    cfg = get_supabase_cfg()
    task = req.task or {}
    task_id = task.get("task_id")
    if not task_id:
        raise HTTPException(status_code=400, detail="missing_task_id")

    status = task.get("status") or "running"
    now = utcnow_iso()
    # Every status update goes through this upsert. Stamping `started_at` with
    # "now" each time overwrote the real start time, so finished tasks reported
    # started_at == finished_at. Reuse the stored value when there is one.
    started_at = await _existing_task_started_at(cfg, task_id) or now
    payload_task = {
        "task_id": task_id,
        "run_id": req.run_id,
        "question_id": task.get("question_id") or "",
        "site": task.get("site") or "",
        "profile_id": task.get("profile_id") or "",
        "status": status,
        "reason": task.get("reason"),
        "handoff_url": task.get("handoff_url"),
        "started_at": started_at,
        "finished_at": now if status in ("success", "failed", "needs_human") else None,
    }
    await sb_upsert(cfg, "/tasks", json_payload=payload_task, on_conflict="task_id")

    result = task.get("result")
    if isinstance(result, dict):
        payload_art = {
            "task_id": task_id,
            "answer_text": result.get("answer_text") or "",
            "citations_json": result.get("citations") or [],
            "screenshots_json": result.get("screenshots") or [],
            "raw_html_url": result.get("raw_html_url"),
            "log_url": result.get("log_url"),
        }
        await sb_upsert(cfg, "/task_artifacts", json_payload=payload_art, on_conflict="task_id")
    return {"ok": True}
class MentionsBatchRequest(BaseModel):
    """Request body of ``POST /mentions/batch``."""

    run_id: str  # accepted but not currently read by the handler
    # Loose mention dicts. The handler reads task_id and entity_name (both
    # required per item) plus entity_type, recommended, rank, first_index and
    # citations_json.
    mentions: list[dict[str, Any]]
@app.post("/mentions/batch")
async def upsert_mentions(req: MentionsBatchRequest):
    """Bulk-upsert entity mentions into ``mentions`` keyed on ``(task_id, entity_name)``.

    Items missing either key are skipped. When the same key appears more than
    once, only the last occurrence is sent. PostgreSQL rejects an
    ``INSERT ... ON CONFLICT DO UPDATE`` that touches one row twice, which
    would otherwise fail the whole batch. Returns ``{"ok": True, "count": n}``,
    where ``n`` is the number of rows actually sent.
    """
    cfg = get_supabase_cfg()
    # The dict keeps first-seen order while letting later duplicates win.
    rows_by_key: dict[tuple[str, str], dict[str, Any]] = {}
    for m in req.mentions:
        task_id = m.get("task_id")
        entity_name = m.get("entity_name")
        if not task_id or not entity_name:
            continue
        # str() keeps the key hashable even for odd JSON values.
        rows_by_key[(str(task_id), str(entity_name))] = {
            "task_id": task_id,
            "entity_name": entity_name,
            "entity_type": m.get("entity_type") or "",
            "recommended": bool(m.get("recommended")),
            "rank": m.get("rank"),
            "first_index": m.get("first_index"),
            "citations_json": m.get("citations_json") or [],
        }
    payload = list(rows_by_key.values())
    if payload:
        await sb_upsert(cfg, "/mentions", json_payload=payload, on_conflict="task_id,entity_name")
    return {"ok": True, "count": len(payload)}