scraper / supabase_client.py
greeta's picture
Upload supabase_client.py
5ed4248 verified
"""
REST client for Supabase/PostgREST.
Uses the service key directly and avoids the Python SDK dependency tree.
"""
from __future__ import annotations
from typing import Dict, List, Optional
import logging
import httpx
# Module-level logger used by SupabaseClient for status and error messages.
logger = logging.getLogger(__name__)
# NOTE(review): this configures the *root* logger (INFO level) as an import
# side effect; per the stdlib docs it is a no-op if the root logger already
# has handlers. Consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
class SupabaseClient:
    """Minimal async client for the `fipi_tasks` table.

    Talks to PostgREST at ``{url}/rest/v1`` with plain HTTP requests
    authenticated by the service key. Every public method catches and logs
    its own errors and returns a neutral value (``None``, ``[]``, ``False``
    or zeroed stats) instead of raising.
    """

    # Columns this client may write; any other key in a task dict is dropped
    # before it is sent to PostgREST.
    ALLOWED_TASK_FIELDS = {
        "title",
        "content",
        "source_url",
        "task_type",
        "images",
        "variants",
        "task_number",
        "source_kind",
        "task_guid",
        "rubert_analysis",
        "scraped_at",
    }

    # Rows requested per page by `get_all_tasks`. PostgREST silently truncates
    # a response at the server's max-rows limit (Supabase projects default to
    # 1000), so full-table reads must page through the results.
    DEFAULT_PAGE_SIZE = 1000

    def __init__(self, url: str, key: str):
        """Store the REST endpoint and auth headers for the project at *url*.

        Args:
            url: Supabase project URL; a trailing ``/`` is tolerated.
            key: API key, sent both as ``apikey`` and as a Bearer token.
        """
        self.base_url = f"{url.rstrip('/')}/rest/v1"
        self.table_name = "fipi_tasks"
        self.headers = {
            "apikey": key,
            "Authorization": f"Bearer {key}",
            "Content-Type": "application/json",
        }

    async def _request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Dict[str, str]] = None,
        json: Optional[Dict | List[Dict]] = None,
        prefer: Optional[str] = None,
    ) -> List[Dict]:
        """Send one HTTP request to ``{base_url}/{path}`` and return its rows.

        Args:
            method: HTTP verb (``GET``, ``POST``, ``PATCH``, ``DELETE``).
            path: Path under ``/rest/v1``, e.g. the table name.
            params: Query-string parameters (PostgREST filters, select, ...).
            json: JSON request body.
            prefer: Optional value for the ``Prefer`` header.

        Returns:
            The decoded JSON as a list: ``[]`` for an empty body, and a
            single JSON object wrapped in a one-element list.

        Raises:
            httpx.HTTPStatusError: for error status codes (via
                ``raise_for_status``); other httpx errors on network failure.
        """
        headers = dict(self.headers)
        if prefer:
            headers["Prefer"] = prefer
        # NOTE(review): a new client (and connection) per call; a shared
        # client would allow connection reuse but needs explicit closing.
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.request(
                method,
                f"{self.base_url}/{path}",
                params=params,
                json=json,
                headers=headers,
            )
        response.raise_for_status()
        if not response.content:
            return []
        data = response.json()
        return data if isinstance(data, list) else [data]

    def _prepare_task_payload(self, task: Dict, *, apply_defaults: bool = True) -> Dict:
        """Return the subset of *task* whose keys are writable columns.

        Args:
            task: Arbitrary task dict; unknown keys are dropped.
            apply_defaults: Fill ``task_type``/``images``/``variants`` with
                ``"other"``/``[]``/``[]`` when missing. Meant for INSERTs only:
                applying them to a PATCH would reset those columns on rows
                that already have values.
        """
        payload = {
            key: value
            for key, value in task.items()
            if key in self.ALLOWED_TASK_FIELDS
        }
        if apply_defaults:
            payload.setdefault("task_type", "other")
            payload.setdefault("images", [])
            payload.setdefault("variants", [])
        return payload

    @staticmethod
    def _is_blank(value) -> bool:
        """Return True for ``None`` and for empty str/list/dict values."""
        return value is None or (isinstance(value, (str, list, dict)) and not value)

    def _has_enrichment_changes(self, existing: Dict, incoming: Dict) -> bool:
        """Return True if *incoming* fills in data that *existing* lacks.

        That is: a truthy ``task_number``/``source_kind``/``task_guid`` where
        the stored value is falsy, or non-empty ``variants``/``images`` where
        the stored list is empty or missing.
        """
        for field in ("task_number", "source_kind", "task_guid"):
            if incoming.get(field) and not existing.get(field):
                return True
        incoming_variants = incoming.get("variants") or []
        existing_variants = existing.get("variants") or []
        if incoming_variants and not existing_variants:
            return True
        incoming_images = incoming.get("images") or []
        existing_images = existing.get("images") or []
        if incoming_images and not existing_images:
            return True
        return False

    async def is_available(self) -> bool:
        """Return True if a one-row SELECT on the table succeeds."""
        try:
            await self._request(
                "GET",
                self.table_name,
                params={"select": "id", "limit": "1"},
            )
            return True
        except Exception as e:
            logger.error("Supabase availability check failed: %s", e)
            return False

    async def insert_task(self, task: Dict) -> Optional[Dict]:
        """Insert *task*, or enrich the existing row with the same source_url.

        Returns:
            The inserted or updated row, or ``None`` when the task already
            exists with nothing new to add, or when any error occurs.
        """
        try:
            existing = await self.get_task_by_url(task.get("source_url", ""))
            if existing:
                if self._has_enrichment_changes(existing, task):
                    logger.info("Updating existing task metadata: %s", task.get("source_url"))
                    # Drop blank incoming values (e.g. ``images: []``) so an
                    # enrichment never erases data the stored row already has.
                    updates = {
                        key: value
                        for key, value in task.items()
                        if not self._is_blank(value)
                    }
                    return await self.update_task(existing["id"], updates)
                logger.info("Task already exists: %s", task.get("source_url"))
                return None
            result = await self._request(
                "POST",
                self.table_name,
                json=self._prepare_task_payload(task),
                prefer="return=representation",
            )
            return result[0] if result else None
        except httpx.HTTPStatusError as e:
            detail = e.response.text if e.response is not None else str(e)
            logger.error("Error inserting task: %s", detail)
            return None
        except Exception as e:
            logger.error("Error inserting task: %s", e)
            return None

    async def insert_tasks_batch(self, tasks: List[Dict]) -> List[Dict]:
        """Insert *tasks* one at a time; return the rows inserted or updated."""
        saved = []
        for task in tasks:
            result = await self.insert_task(task)
            if result:
                saved.append(result)
        logger.info("Saved %s of %s tasks", len(saved), len(tasks))
        return saved

    async def get_task_by_id(self, task_id: int) -> Optional[Dict]:
        """Return the row with primary key *task_id*, or ``None``."""
        try:
            result = await self._request(
                "GET",
                self.table_name,
                params={"select": "*", "id": f"eq.{task_id}"},
            )
            return result[0] if result else None
        except Exception as e:
            logger.error("Error getting task by id: %s", e)
            return None

    async def get_task_by_url(self, url: str) -> Optional[Dict]:
        """Return the first row whose ``source_url`` equals *url*, or ``None``."""
        if not url:
            return None
        try:
            result = await self._request(
                "GET",
                self.table_name,
                params={"select": "*", "source_url": f"eq.{url}"},
            )
            return result[0] if result else None
        except Exception as e:
            logger.error("Error getting task by url: %s", e)
            return None

    async def get_latest_tasks(self, limit: int = 10) -> List[Dict]:
        """Return up to *limit* rows, newest ``scraped_at`` first."""
        try:
            return await self._request(
                "GET",
                self.table_name,
                params={
                    "select": "*",
                    "order": "scraped_at.desc",
                    "limit": str(limit),
                },
            )
        except Exception as e:
            logger.error("Error getting latest tasks: %s", e)
            return []

    async def get_all_tasks(self, page_size: Optional[int] = None) -> List[Dict]:
        """Return every row, newest ``scraped_at`` first.

        Reads the table page by page with ``limit``/``offset`` until an empty
        page comes back, so the result is not cut off by the server's
        max-rows cap (which may be lower than *page_size*).

        Args:
            page_size: Rows to request per page; defaults to
                ``DEFAULT_PAGE_SIZE``.

        Returns:
            All rows, or ``[]`` if any page request fails.
        """
        size = max(1, page_size) if page_size else self.DEFAULT_PAGE_SIZE
        rows: List[Dict] = []
        try:
            while True:
                page = await self._request(
                    "GET",
                    self.table_name,
                    params={
                        "select": "*",
                        # `id` breaks scraped_at ties so offsets stay stable
                        # from one page to the next.
                        "order": "scraped_at.desc,id.desc",
                        "limit": str(size),
                        "offset": str(len(rows)),
                    },
                )
                if not page:
                    return rows
                rows.extend(page)
        except Exception as e:
            logger.error("Error getting all tasks: %s", e)
            return []

    async def search_tasks(self, query: str) -> List[Dict]:
        """Case-insensitive substring search over ``title`` and ``content``.

        Commas and parentheses in *query* are replaced with spaces because
        they are structural characters in PostgREST's ``or=(...)`` syntax.
        """
        try:
            escaped_query = query.replace(",", " ").replace("(", " ").replace(")", " ")
            pattern = f"*{escaped_query}*"
            or_filter = f"(title.ilike.{pattern},content.ilike.{pattern})"
            return await self._request(
                "GET",
                self.table_name,
                params={
                    "select": "*",
                    "or": or_filter,
                },
            )
        except Exception as e:
            logger.error("Error searching tasks: %s", e)
            return []

    async def get_tasks_by_type(self, task_type: str) -> List[Dict]:
        """Return all rows whose ``task_type`` equals *task_type*."""
        try:
            return await self._request(
                "GET",
                self.table_name,
                params={"select": "*", "task_type": f"eq.{task_type}"},
            )
        except Exception as e:
            logger.error("Error getting tasks by type: %s", e)
            return []

    async def update_task(self, task_id: int, updates: Dict) -> Optional[Dict]:
        """PATCH only the writable fields present in *updates*.

        Unlike inserts, no defaults are added, so columns not mentioned in
        *updates* keep their stored values. If *updates* has no writable
        field, nothing is sent and the current row is returned.

        Returns:
            The updated (or current) row, or ``None`` on error / no match.
        """
        payload = self._prepare_task_payload(updates, apply_defaults=False)
        if not payload:
            return await self.get_task_by_id(task_id)
        try:
            result = await self._request(
                "PATCH",
                self.table_name,
                params={"id": f"eq.{task_id}"},
                json=payload,
                prefer="return=representation",
            )
            return result[0] if result else None
        except Exception as e:
            logger.error("Error updating task: %s", e)
            return None

    async def delete_task(self, task_id: int) -> bool:
        """Delete the row with id *task_id*.

        Returns True whenever the DELETE request succeeds (even if no row
        matched), False if it raised.
        """
        try:
            await self._request(
                "DELETE",
                self.table_name,
                params={"id": f"eq.{task_id}"},
                prefer="return=representation",
            )
            return True
        except Exception as e:
            logger.error("Error deleting task: %s", e)
            return False

    async def get_stats(self) -> Dict:
        """Return ``{"total": N, "by_type": {task_type: count}}`` over all rows.

        Rows with a missing or null ``task_type`` are counted as ``"unknown"``.
        """
        try:
            all_tasks = await self.get_all_tasks()
            by_type: Dict[str, int] = {}
            for task in all_tasks:
                task_type = task.get("task_type") or "unknown"
                by_type[task_type] = by_type.get(task_type, 0) + 1
            return {"total": len(all_tasks), "by_type": by_type}
        except Exception as e:
            logger.error("Error getting stats: %s", e)
            return {"total": 0, "by_type": {}}