| from __future__ import annotations |
|
|
| from datetime import datetime, timezone |
| from typing import Any |
|
|
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, Field |
|
|
| import os |
| import httpx |
|
|
|
|
class Settings:
    """Supabase connection settings captured from the process environment.

    The environment is read once, when the instance is created.
    ``supabase_url`` is ``None`` when ``SUPABASE_URL`` is unset.
    """

    def __init__(self) -> None:
        env = os.environ
        self.supabase_url = env.get("SUPABASE_URL")
        # The service-role key takes precedence; the anon key is consulted
        # only when the service-role key is unset or empty.
        service_key = env.get("SUPABASE_SERVICE_ROLE_KEY")
        self.supabase_key = service_key if service_key else env.get("SUPABASE_ANON_KEY")
|
|
|
|
class SupabaseConfig(BaseModel):
    # Connection details for one Supabase project.
    url: str  # project base URL; trailing "/" characters are tolerated
    key: str  # API key placed in the request headers by sb_headers

    @property
    def rest_url(self) -> str:
        """Return the REST root: ``url`` without trailing slashes, plus ``/rest/v1``."""
        base = self.url.rstrip("/")
        return f"{base}/rest/v1"
|
|
|
|
def sb_headers(cfg: SupabaseConfig, extra: dict[str, str] | None = None) -> dict[str, str]:
    """Build the HTTP headers for a Supabase REST request.

    Args:
        cfg: Supplies the API key, sent both as ``apikey`` and as a bearer token.
        extra: Optional headers layered on top; for an identical key the
            value from ``extra`` wins.

    Returns:
        A new dict; neither argument is modified.
    """
    base = {
        "apikey": cfg.key,
        "authorization": f"Bearer {cfg.key}",
        "content-type": "application/json",
    }
    return {**base, **(extra or {})}
|
|
|
|
async def sb_get(cfg: SupabaseConfig, path: str, params: dict[str, Any]) -> Any:
    """GET ``path`` under the Supabase REST root and return the decoded JSON body.

    Args:
        cfg: Connection details (REST root and API key).
        path: Resource path appended to ``cfg.rest_url``, e.g. ``"/runs"``.
        params: Query-string parameters passed through to the request.

    Raises:
        httpx.HTTPStatusError: If the response status is 4xx or 5xx.
    """
    # A fresh client per call, closed once the response has been decoded.
    async with httpx.AsyncClient(timeout=30) as client:
        endpoint = cfg.rest_url + path
        request_headers = sb_headers(cfg)
        response = await client.get(endpoint, headers=request_headers, params=params)
        response.raise_for_status()
        return response.json()
|
|
|
|
async def sb_upsert(cfg: SupabaseConfig, path: str, json_payload: Any, on_conflict: str) -> None:
    """POST ``json_payload`` to ``path`` as an upsert.

    Args:
        cfg: Connection details (REST root and API key).
        path: Resource path appended to ``cfg.rest_url``, e.g. ``"/tasks"``.
        json_payload: JSON-serialisable body; callers in this file pass one
            row (dict) or a list of rows.
        on_conflict: Comma-separated column names sent as the ``on_conflict``
            query parameter.

    Raises:
        httpx.HTTPStatusError: If the response status is 4xx or 5xx.
    """
    # The Prefer header requests merge-on-duplicate and no response body.
    upsert_headers = sb_headers(cfg, {"prefer": "resolution=merge-duplicates,return=minimal"})
    query = {"on_conflict": on_conflict}
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(
            cfg.rest_url + path,
            headers=upsert_headers,
            params=query,
            json=json_payload,
        )
        response.raise_for_status()
|
|
|
|
class BrandPairsResponse(BaseModel):
    # Response body of GET /brand-pairs/latest.
    version: str  # version of the returned pair set; "0" when no rows exist
    pairs: list[dict[str, Any]]  # rows read from dim_brand_pairs for that version
|
|
|
|
def utcnow_iso() -> str:
    """Return the current UTC time as ISO 8601 with whole-second precision.

    The result carries an explicit ``+00:00`` offset.
    """
    now = datetime.now(tz=timezone.utc)
    # timespec="seconds" drops the fractional part instead of rounding it.
    return now.isoformat(timespec="seconds")
|
|
|
|
# Application object; the route decorators further down register handlers on it.
app = FastAPI(title="GEO DB Service (Supabase)", version="0.1.0")
# Environment-derived settings, read once when this module is imported.
settings = Settings()
|
|
|
|
def get_supabase_cfg() -> SupabaseConfig:
    """Return the Supabase connection config built from the module-level settings.

    Raises:
        HTTPException: 500 ``supabase_not_configured`` when the URL or the
            key is missing or empty.
    """
    url, key = settings.supabase_url, settings.supabase_key
    if url and key:
        return SupabaseConfig(url=url, key=key)
    raise HTTPException(status_code=500, detail="supabase_not_configured")
|
|
|
|
@app.get("/health")
async def health():
    # Liveness probe: fixed payload, no call to Supabase.
    status = {"ok": True}
    return status
|
|
|
|
@app.get("/brand-pairs/latest", response_model=BrandPairsResponse)
async def brand_pairs_latest():
    # Returns the brand pairs belonging to the most recent version in
    # dim_brand_pairs.
    #
    # Two queries are issued: the first fetches the single row with the
    # greatest effective_at to learn the current version; the second fetches
    # every pair carrying that version, highest priority first.
    # When the first query returns nothing the response is version "0" with
    # an empty pair list.
    table = "/dim_brand_pairs"
    cfg = get_supabase_cfg()

    newest_query = {"select": "version,effective_at", "order": "effective_at.desc", "limit": "1"}
    newest_rows = await sb_get(cfg, table, params=newest_query)
    if not newest_rows:
        return BrandPairsResponse(version="0", pairs=[])

    current_version = newest_rows[0]["version"]
    pairs_query = {
        "select": "main_brand,competitor_brand,industry,priority",
        "version": f"eq.{current_version}",
        "order": "priority.desc",
    }
    rows = await sb_get(cfg, table, params=pairs_query)
    return BrandPairsResponse(version=current_version, pairs=rows)
|
|
|
|
class RunCreateRequest(BaseModel):
    # Request body of POST /runs; create_run copies each field into the runs row.
    run_id: str  # conflict key of the upsert
    trigger: str
    schedule_id: str | None = None
    config_version: str
    config_snapshot_url: str | None = None
    params: dict[str, Any] = Field(default_factory=dict)  # stored under "params_json"
|
|
|
|
@app.post("/runs")
async def create_run(req: RunCreateRequest):
    # Registers a run by upserting one row into runs, keyed on run_id.
    # started_at is stamped with the current UTC time and finished_at is
    # sent as null on every call.
    cfg = get_supabase_cfg()

    # Fields copied verbatim from the request, in column order.
    copied = ("run_id", "trigger", "schedule_id", "config_version", "config_snapshot_url")
    row: dict[str, Any] = {name: getattr(req, name) for name in copied}
    row["params_json"] = req.params
    row["started_at"] = utcnow_iso()
    row["finished_at"] = None

    await sb_upsert(cfg, "/runs", json_payload=row, on_conflict="run_id")
    return {"ok": True}
|
|
|
|
class TaskUpsertRequest(BaseModel):
    # Request body of POST /tasks.
    run_id: str  # written to the task row as-is
    # Free-form task record. upsert_task requires "task_id" and also reads
    # "status", "question_id", "site", "profile_id", "reason", "handoff_url"
    # and "result".
    task: dict[str, Any]
|
|
|
|
@app.post("/tasks")
async def upsert_task(req: TaskUpsertRequest):
    # Upserts one row into tasks (keyed on task_id) and, when the task
    # carries a dict under "result", one row into task_artifacts as well.
    # Responds 400 "missing_task_id" when the task has no task_id.
    #
    # NOTE(review): started_at is stamped with the current time on every
    # call, so a later status update for the same task_id appears to replace
    # the original start time - confirm whether that is intended.
    cfg = get_supabase_cfg()
    fields = req.task or {}
    task_id = fields.get("task_id")
    if not task_id:
        raise HTTPException(status_code=400, detail="missing_task_id")

    status = fields.get("status") or "running"
    started_at = utcnow_iso()
    # Only these statuses get a finish timestamp; anything else stays open.
    is_finished = status in ("success", "failed", "needs_human")
    finished_at = utcnow_iso() if is_finished else None
    task_row = {
        "task_id": task_id,
        "run_id": req.run_id,
        "question_id": fields.get("question_id") or "",
        "site": fields.get("site") or "",
        "profile_id": fields.get("profile_id") or "",
        "status": status,
        "reason": fields.get("reason"),
        "handoff_url": fields.get("handoff_url"),
        "started_at": started_at,
        "finished_at": finished_at,
    }
    await sb_upsert(cfg, "/tasks", json_payload=task_row, on_conflict="task_id")

    outcome = fields.get("result")
    if not isinstance(outcome, dict):
        # Missing or non-dict result: nothing to store in task_artifacts.
        return {"ok": True}

    artifact_row = {
        "task_id": task_id,
        "answer_text": outcome.get("answer_text") or "",
        "citations_json": outcome.get("citations") or [],
        "screenshots_json": outcome.get("screenshots") or [],
        "raw_html_url": outcome.get("raw_html_url"),
        "log_url": outcome.get("log_url"),
    }
    await sb_upsert(cfg, "/task_artifacts", json_payload=artifact_row, on_conflict="task_id")
    return {"ok": True}
|
|
|
|
class MentionsBatchRequest(BaseModel):
    # Request body of POST /mentions/batch.
    run_id: str  # required by the schema; upsert_mentions does not write it to the rows
    # One dict per mention. upsert_mentions skips entries lacking "task_id" or
    # "entity_name" and also reads "entity_type", "recommended", "rank",
    # "first_index" and "citations_json".
    mentions: list[dict[str, Any]]
|
|
|
|
@app.post("/mentions/batch")
async def upsert_mentions(req: MentionsBatchRequest):
    # Bulk-upserts mentions into the mentions table, keyed on
    # (task_id, entity_name).
    #
    # Entries without a task_id or entity_name are skipped. When the batch
    # names the same (task_id, entity_name) more than once, only the last
    # such entry is sent. Nothing is sent when no entry survives.
    # Returns {"ok": True, "count": n}, n being the number of rows sent.
    cfg = get_supabase_cfg()

    # PostgreSQL rejects an INSERT ... ON CONFLICT DO UPDATE that touches the
    # same conflict key twice in one statement, which would fail the whole
    # batch. Collapsing duplicates here keeps one row per key. Keys are
    # compared as strings so unhashable JSON values cannot break the lookup.
    rows_by_key: dict[tuple[str, str], dict[str, Any]] = {}
    for m in req.mentions:
        task_id = m.get("task_id")
        entity_name = m.get("entity_name")
        if not task_id or not entity_name:
            continue
        rows_by_key[(str(task_id), str(entity_name))] = {
            "task_id": task_id,
            "entity_name": entity_name,
            "entity_type": m.get("entity_type") or "",
            "recommended": bool(m.get("recommended")),
            "rank": m.get("rank"),
            "first_index": m.get("first_index"),
            "citations_json": m.get("citations_json") or [],
        }

    payload = list(rows_by_key.values())
    if payload:
        await sb_upsert(cfg, "/mentions", json_payload=payload, on_conflict="task_id,entity_name")
    return {"ok": True, "count": len(payload)}
|
|