import os
import re
from typing import List, Optional, Dict, Any, Literal
from datetime import datetime, timedelta

import httpx
import pytz
import dateparser
from dateparser.search import search_dates
from fastapi import FastAPI, Header, HTTPException, Depends, Query
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from dateutil.relativedelta import relativedelta

# ==== Cache & Env setup (important on HF Spaces) ====
# All model caches live under /data, the persistent & writable volume on Spaces.
CACHE_ROOT = os.getenv("MODEL_CACHE_DIR", "/data/.cache")
HF_HOME = os.getenv("HF_HOME", os.path.join(CACHE_ROOT, "huggingface"))
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", os.path.join(HF_HOME, "transformers"))
ST_HOME = os.getenv("SENTENCE_TRANSFORMERS_HOME", os.path.join(CACHE_ROOT, "sentence-transformers"))

os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
os.makedirs(ST_HOME, exist_ok=True)

# Re-export the resolved cache paths so downstream libraries pick them up.
os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = TRANSFORMERS_CACHE
os.environ["SENTENCE_TRANSFORMERS_HOME"] = ST_HOME

# ==== App config ====
API_KEY = os.getenv("API_KEY")  # shared secret for this API (auth disabled when unset)
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_SERVICE_ROLE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
MODEL_NAME = os.getenv("MODEL_NAME", "BAAI/bge-small-en-v1.5")
LOCAL_TZ = pytz.timezone(os.getenv("TZ", "America/New_York"))
# 'monday' or 'sunday'
DEFAULT_WEEK_START = (os.getenv("WEEK_START", "monday") or "monday").strip().lower()

# Fail fast: the service is useless without its Supabase backend.
if not (SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY):
    raise RuntimeError("Missing SUPABASE_URL or SUPABASE_SERVICE_ROLE_KEY")
# Monday=0 ... Sunday=6 (matches datetime.weekday()).
WEEKDAYS = {
    "monday": 0, "tuesday": 1, "wednesday": 2, "thursday": 3,
    "friday": 4, "saturday": 5, "sunday": 6,
}
MONTHS = [
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
]
# Ordered patterns used to spot a time-like subphrase in free text;
# matched against a lowercased copy of the input.
TIME_PATTERNS = [
    r"\blast\s+(monday|tuesday|wednesday|thursday|friday|saturday|sunday)\b",
    r"\b(this|last)\s+(week|month)\b",
    r"\b(past|last)\s+\d+\s+(?:day|days|week|weeks|month|months)\b",
    r"\bq[1-4](?:\s+\d{4})?\b",
    r"\b(today|yesterday)\b",
    r"\b(january|february|march|april|may|june|july|august|september|october|november|december)(?:\s+\d{4})?\b",
]

app = FastAPI(title="CIC Check-ins API", version="1.3.1")


# === Auth guard ===
def require_key(authorization: Optional[str] = Header(None)):
    """If API_KEY is set, require 'Authorization: Bearer <token>'.

    Raises:
        HTTPException 401: header missing or not a Bearer scheme.
        HTTPException 403: token present but does not match API_KEY.
    """
    if not API_KEY:
        return  # auth intentionally disabled when no shared secret is configured
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing bearer token")
    if authorization.split(" ", 1)[1].strip() != API_KEY:
        raise HTTPException(status_code=403, detail="Invalid token")


# === Startup / Shutdown ===
@app.on_event("startup")
async def on_startup():
    """Load the embedding model and create the shared Supabase REST client."""
    # Load embedding model with explicit cache folder (fixes /.cache permission issue)
    try:
        app.state.model = SentenceTransformer(MODEL_NAME, cache_folder=ST_HOME)
    except Exception:
        # Optional fallback to a tiny model if the specified one fails (keeps API available);
        # /healthz reports the fallback name via app.state.model_name_fallback.
        fallback = "sentence-transformers/all-MiniLM-L6-v2"
        app.state.model = SentenceTransformer(fallback, cache_folder=ST_HOME)
        app.state.model_name_fallback = fallback
    # Supabase REST client (PostgREST endpoint), authenticated with the service-role key.
    app.state.http = httpx.AsyncClient(
        base_url=f"{SUPABASE_URL}/rest/v1",
        headers={
            "apikey": SUPABASE_SERVICE_ROLE_KEY,
            "Authorization": f"Bearer {SUPABASE_SERVICE_ROLE_KEY}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        timeout=20.0,
    )


@app.on_event("shutdown")
async def on_shutdown():
    """Best-effort close of the shared HTTP client."""
    try:
        await app.state.http.aclose()
    except Exception:
        pass


# === Helpers ===
def embed_text(texts: List[str]) -> List[List[float]]:
    """Embed texts with the app's sentence-transformer; unit-normalized vectors as lists."""
    vecs = app.state.model.encode(texts, normalize_embeddings=True)
    return [v.tolist() for v in vecs]


def _day_start(dt: datetime) -> datetime:
    """Midnight at the start of dt's calendar day (tzinfo preserved)."""
    return dt.replace(hour=0, minute=0, second=0, microsecond=0)


def _week_start(dt: datetime, week_start: str) -> datetime:
    """Midnight of the first day of dt's week; week_start is 'monday' or 'sunday'."""
    idx = 0 if week_start == "monday" else 6  # monday=0, sunday=6 baseline
    delta_days = (dt.weekday() - idx) % 7
    return _day_start(dt - timedelta(days=delta_days))


def _localize(tz: pytz.BaseTzInfo, naive_dt: datetime) -> datetime:
    """Attach tz to a naive datetime via pytz's localize (never use tzinfo= with pytz)."""
    return tz.localize(naive_dt)


def to_utc_iso(local_iso: str) -> str:
    """Convert a tz-aware ISO-8601 string to a UTC ISO-8601 string."""
    return datetime.fromisoformat(local_iso).astimezone(pytz.UTC).isoformat()


def extract_time_subphrase(text: str, tz: pytz.BaseTzInfo) -> Optional[str]:
    """Return the first time-like subphrase of `text`, or None.

    Tries the explicit TIME_PATTERNS regexes first, then falls back to
    dateparser's fuzzy search. The matched substring (not a datetime) is
    returned so the caller can strip it from the query text.
    """
    s = (text or "").lower()
    for pat in TIME_PATTERNS:
        m = re.search(pat, s)
        if m:
            return m.group(0)
    settings = {
        "TIMEZONE": str(tz),
        "RETURN_AS_TIMEZONE_AWARE": True,
        "PREFER_DATES_FROM": "past",
        "RELATIVE_BASE": datetime.now(tz),
    }
    found = search_dates(s, settings=settings, languages=["en"])
    if found:
        return found[0][0]  # (matched_text, parsed_datetime) pairs; keep the text
    return None


def parse_phrase_to_range(
    phrase: str,
    *,
    tz: Optional[pytz.BaseTzInfo] = None,
    week_start: Optional[str] = None,
) -> Dict[str, str]:
    """Resolve a natural-language time phrase to a half-open [start, end) range.

    Returns local-tz ISO strings plus a 'source' tag naming the rule that
    matched ('weekday', 'day', 'week', 'month', 'relative', 'quarter',
    'dateparser'). Raises HTTPException(400) when the phrase cannot be parsed.
    """
    tz = tz or LOCAL_TZ
    week_start = (week_start or DEFAULT_WEEK_START).strip().lower()
    s_in = (phrase or "").strip()
    s = s_in.lower()
    if not s:
        raise HTTPException(400, detail="Empty phrase")
    now = datetime.now(tz)

    # "last <weekday>": most recent strictly-past occurrence of that weekday.
    m = re.fullmatch(r"last\s+(monday|tuesday|wednesday|thursday|friday|saturday|sunday)", s)
    if m:
        target = WEEKDAYS[m.group(1)]
        delta = (now.weekday() - target) % 7
        delta = 7 if delta == 0 else delta  # "last monday" said on a Monday means a week ago
        day = _day_start(now - timedelta(days=delta))
        return {"start": day.isoformat(), "end": (day + timedelta(days=1)).isoformat(), "source": "weekday"}

    if s == "today":
        start = _day_start(now)
        return {"start": start.isoformat(), "end": (start + timedelta(days=1)).isoformat(), "source": "day"}
    if s == "yesterday":
        end = _day_start(now)
        start = end - timedelta(days=1)
        return {"start": start.isoformat(), "end": end.isoformat(), "source": "day"}
    if s == "this week":
        start = _week_start(now, week_start)
        return {"start": start.isoformat(), "end": (start + timedelta(days=7)).isoformat(), "source": "week"}
    if s == "last week":
        this_start = _week_start(now, week_start)
        start = this_start - timedelta(days=7)
        return {"start": start.isoformat(), "end": (start + timedelta(days=7)).isoformat(), "source": "week"}
    if s == "this month":
        start = _localize(tz, datetime(now.year, now.month, 1))
        end = _localize(tz, datetime(now.year + (1 if now.month == 12 else 0),
                                     1 if now.month == 12 else now.month + 1, 1))
        return {"start": start.isoformat(), "end": end.isoformat(), "source": "month"}
    if s == "last month":
        first_this = _localize(tz, datetime(now.year, now.month, 1))
        # Step back one day to land in the previous month, then snap to its 1st.
        start = _day_start(first_this - timedelta(days=1)).replace(day=1)
        end = first_this
        return {"start": start.isoformat(), "end": end.isoformat(), "source": "month"}

    # Bare month name, optionally with a year ("august", "august 2024").
    m = re.fullmatch(rf"({'|'.join(MONTHS)})(?:\s+(\d{{4}}))?", s)
    if m:
        month_name, year_str = m.group(1), m.group(2)
        month_idx = MONTHS.index(month_name) + 1
        year = int(year_str) if year_str else now.year
        start = _localize(tz, datetime(year, month_idx, 1))
        end = (_localize(tz, datetime(year + 1, 1, 1)) if month_idx == 12
               else _localize(tz, datetime(year, month_idx + 1, 1)))
        return {"start": start.isoformat(), "end": end.isoformat(), "source": "month"}

    # "past/last N days|weeks|months", window ending tomorrow 00:00 (today inclusive).
    m = re.fullmatch(r"(past|last)\s+(\d+)\s*(day|days|week|weeks|month|months)", s)
    if m:
        n = int(m.group(2))
        unit = m.group(3)
        end = _day_start(now) + timedelta(days=1)
        if unit.startswith("day"):
            start = end - timedelta(days=n)
        elif unit.startswith("week"):
            start = end - timedelta(weeks=n)
        else:
            start = end - relativedelta(months=n)
        return {"start": start.isoformat(), "end": end.isoformat(), "source": "relative"}

    # Calendar quarters: "q1".."q4", optionally with a year.
    m = re.fullmatch(r"q([1-4])(?:\s+(\d{4}))?", s)
    if m:
        q = int(m.group(1))
        year = int(m.group(2)) if m.group(2) else now.year
        start_month = (q - 1) * 3 + 1
        start = _localize(tz, datetime(year, start_month, 1))
        end = start + relativedelta(months=3)
        return {"start": start.isoformat(), "end": end.isoformat(), "source": "quarter"}

    # Last resort: let dateparser interpret the phrase as a single day.
    settings = {"TIMEZONE": str(tz), "RETURN_AS_TIMEZONE_AWARE": True,
                "PREFER_DATES_FROM": "past", "RELATIVE_BASE": now}
    dt = dateparser.parse(s, settings=settings, languages=["en"])
    if not dt:
        # NOTE: previous message carried a stray trailing "\n" into the HTTP error body.
        raise HTTPException(400, detail=f"Could not parse phrase: {phrase}")
    start = _day_start(dt.astimezone(tz))
    end = start + timedelta(days=1)
    return {"start": start.isoformat(), "end": end.isoformat(), "source": "dateparser"}


def _project_rows(rows: List[Dict[str, Any]], return_fields: List[str]) -> List[Dict[str, Any]]:
    """Project RPC result rows onto the requested fields; coerce 'score' to float.

    Shared by /search and /interpret (previously duplicated inline in both).
    """
    out: List[Dict[str, Any]] = []
    for row in rows:
        item = {f: row.get(f) for f in return_fields if f in row or f == "score"}
        if "score" in item and item["score"] is not None:
            item["score"] = float(item["score"])
        out.append(item)
    return out


def _bound_to_utc(value: str) -> str:
    """Normalize a user-supplied time bound to UTC ISO.

    Full datetimes (contain 'T') are converted directly; bare dates are taken
    as local midnight in LOCAL_TZ first.
    """
    if "T" in value:
        return to_utc_iso(value)
    return to_utc_iso(LOCAL_TZ.localize(datetime.fromisoformat(value)).isoformat())


# === Schemas ===
class IngestBody(BaseModel):
    id: str
    sender: Optional[str] = None
    username: Optional[str] = None
    slack_id: Optional[str] = None
    msg: str
    timestamp: Optional[str] = Field(None, description="ISO8601; if absent, now()")
    tags: Optional[List[str]] = []
    valid_checkin: Optional[bool] = True


class SearchFilters(BaseModel):
    phrase: Optional[str] = None  # natural-language time phrase; overridden by start/end
    start: Optional[str] = None
    end: Optional[str] = None
    sender: Optional[str] = None
    valid_only: Optional[bool] = None


class SearchBody(BaseModel):
    query: str
    k: int = 20
    filters: Optional[SearchFilters] = None
    return_fields: List[str] = ["id", "ts", "sender", "username", "msg", "score"]


class InterpretDefaults(BaseModel):
    timezone: Optional[str] = None
    week_start: Optional[str] = None
    fallback_range: Optional[str] = None  # phrase used when no time is detected


class InterpretOptions(BaseModel):
    return_suggestions: bool = True
    infer_sender: Optional[str] = None
    k: int = 20
    return_fields: List[str] = ["id", "ts", "sender", "username", "msg", "score"]
    run_search: bool = True


class InterpretBody(BaseModel):
    text: str
    defaults: Optional[InterpretDefaults] = None
    options: Optional[InterpretOptions] = None


# === Routes ===
@app.get("/")
async def root():
    """Liveness check plus a quick route map."""
    return {"ok": True,
            "hint": "Use /healthz, /ingest, /search, /phrases/resolve, /interpret, /stats",
            "week_start": DEFAULT_WEEK_START}


@app.get("/healthz")
async def health():
    """Report which embedding model is actually loaded (fallback-aware)."""
    model_name = getattr(app.state, "model_name_fallback", MODEL_NAME)
    return {"ok": True, "model": model_name}


@app.get("/phrases/resolve")
async def resolve_phrase(phrase: str = Query(..., min_length=1), _: None = Depends(require_key)):
    """Resolve a time phrase to a local-tz [start, end) range without searching."""
    r = parse_phrase_to_range(phrase)
    return {"phrase": phrase, "timezone": str(LOCAL_TZ), "range": r}


@app.post("/ingest")
async def ingest(body: IngestBody, _: None = Depends(require_key)):
    """Embed a check-in message and upsert it via the Supabase RPC."""
    ts_utc = (datetime.fromisoformat(body.timestamp).astimezone(pytz.UTC).isoformat()
              if body.timestamp else datetime.now(pytz.UTC).isoformat())
    vec = embed_text([body.msg])[0]
    payload = {
        "_id": body.id,
        "_sender": body.sender,
        "_username": body.username,
        "_slack_id": body.slack_id,
        "_msg": body.msg,
        "_ts": ts_utc,
        "_tags": body.tags or [],
        # None/True -> True, False -> False (same truth table as the old ternary).
        "_valid": body.valid_checkin is not False,
        "_embedding": vec,
    }
    r = await app.state.http.post("/rpc/upsert_checkin", json=payload)
    if r.status_code >= 300:
        raise HTTPException(r.status_code, detail=f"Supabase RPC error: {r.text[:300]}")
    return {"ok": True, "id": body.id}


@app.post("/search")
async def search(body: SearchBody, _: None = Depends(require_key)):
    """Semantic search over check-ins, optionally constrained to a time window."""
    q_vec = embed_text([body.query])[0]
    start_utc = end_utc = None
    if body.filters:
        if body.filters.phrase:
            rng = parse_phrase_to_range(body.filters.phrase)
            start_utc, end_utc = to_utc_iso(rng["start"]), to_utc_iso(rng["end"])
        # Explicit bounds override the phrase-derived range.
        if body.filters.start:
            start_utc = _bound_to_utc(body.filters.start)
        if body.filters.end:
            end_utc = _bound_to_utc(body.filters.end)
    rpc_payload = {
        "q_embedding": q_vec,
        "k": max(1, min(body.k, 100)),  # clamp to a sane page size
        "start_ts": start_utc,
        "end_ts": end_utc,
        "sender_eq": body.filters.sender if body.filters and body.filters.sender else None,
        "valid_only": body.filters.valid_only if body.filters else None,
    }
    r = await app.state.http.post("/rpc/search_checkins", json=rpc_payload)
    if r.status_code >= 300:
        raise HTTPException(r.status_code, detail=f"Supabase RPC error: {r.text[:300]}")
    return {"results": _project_rows(r.json(), body.return_fields), "used": {"semantic": True}}


@app.get("/stats")
async def stats(phrase: Optional[str] = None,
                bucket: Literal["weekly", "monthly"] = "weekly",
                _: None = Depends(require_key)):
    """Bucketed stats for a phrase-derived range (default: the last 30 days)."""
    if phrase:
        rng = parse_phrase_to_range(phrase)
        start_utc, end_utc = to_utc_iso(rng["start"]), to_utc_iso(rng["end"])
    else:
        end = datetime.now(pytz.UTC)
        start = end - timedelta(days=30)
        start_utc, end_utc = start.isoformat(), end.isoformat()
    payload = {"bucket": bucket, "start_ts": start_utc, "end_ts": end_utc}
    r = await app.state.http.post("/rpc/stats_range", json=payload)
    if r.status_code >= 300:
        raise HTTPException(r.status_code, detail=f"Supabase RPC error: {r.text[:300]}")
    return {"bucket": bucket, "range": {"start": start_utc, "end": end_utc}, **r.json()}


@app.post("/interpret")
async def interpret(body: InterpretBody, _: None = Depends(require_key)):
    """NL one-shot: detect a time phrase, derive the query, optionally run the search."""
    text = (body.text or "").strip()
    if not text:
        raise HTTPException(400, detail="Missing 'text'")

    # Per-request tz / week-start overrides (a bad timezone is silently ignored).
    tz = LOCAL_TZ
    week_start = DEFAULT_WEEK_START
    if body.defaults:
        if body.defaults.timezone:
            try:
                tz = pytz.timezone(body.defaults.timezone)
            except Exception:
                pass
        if body.defaults.week_start and body.defaults.week_start.lower() in ("monday", "sunday"):
            week_start = body.defaults.week_start.lower()

    sub = extract_time_subphrase(text, tz)
    rng = None
    time_source = None
    extracted = None
    suggestions: List[str] = []
    if sub:
        extracted = sub
        parsed = parse_phrase_to_range(sub, tz=tz, week_start=week_start)
        rng = {"start": parsed["start"], "end": parsed["end"], "tz": str(tz)}
        time_source = parsed.get("source", "detected")
        # A bare month name is ambiguous between years; offer both candidates.
        m = re.fullmatch(rf"({'|'.join(MONTHS)})", sub.strip().lower())
        if m and (not body.options or body.options.return_suggestions):
            now = datetime.now(tz)
            mon = m.group(1).capitalize()
            suggestions = [f"{mon} {now.year}", f"{mon} {now.year-1}"]

    # Strip the detected time phrase from the text to get the semantic query.
    query = text
    if extracted:
        pattern = re.compile(re.escape(extracted), re.IGNORECASE)
        query = pattern.sub("", query, count=1).strip()
        query = re.sub(r"\s+", " ", query).strip()

    used_fallback = False
    if rng is None:
        if body.defaults and body.defaults.fallback_range:
            parsed = parse_phrase_to_range(body.defaults.fallback_range, tz=tz, week_start=week_start)
            rng = {"start": parsed["start"], "end": parsed["end"], "tz": str(tz), "confidence": 0.2}
            time_source = "fallback"
            used_fallback = True
        else:
            return {
                "ok": False,
                "error": {"code": "NO_TIME_FOUND",
                          "message": "No time phrase detected and no fallback_range provided."},
                "hints": ["Add 'last week', 'August', 'past 30 days'",
                          "Or pass defaults.fallback_range"],
                "query_guess": query or text,
            }

    opt = body.options or InterpretOptions()
    search_payload = {
        "query": query or text,
        "k": max(1, min(opt.k, 100)),
        "filters": {"start": rng["start"], "end": rng["end"],
                    "sender": opt.infer_sender, "valid_only": None},
        "return_fields": opt.return_fields,
    }

    results = None
    if opt.run_search:
        q_vec = embed_text([search_payload["query"]])[0]
        rpc_payload = {
            "q_embedding": q_vec,
            "k": search_payload["k"],
            "start_ts": to_utc_iso(search_payload["filters"]["start"]),
            "end_ts": to_utc_iso(search_payload["filters"]["end"]),
            "sender_eq": search_payload["filters"]["sender"],
            "valid_only": search_payload["filters"].get("valid_only"),
        }
        r2 = await app.state.http.post("/rpc/search_checkins", json=rpc_payload)
        if r2.status_code >= 300:
            raise HTTPException(r2.status_code, detail=f"Supabase RPC error: {r2.text[:300]}")
        results = _project_rows(r2.json(), opt.return_fields)

    resp: Dict[str, Any] = {
        "ok": True,
        "input": {"text": body.text, "timezone": str(tz), "week_start": week_start},
        "query": query or text,
        "time": {"phrase_raw": body.text, "phrase_extracted": extracted, "source": time_source,
                 "start": rng["start"], "end": rng["end"], "tz": rng["tz"]},
        "search_payload": search_payload,
    }
    if suggestions and (not used_fallback):
        resp["suggestions"] = suggestions
    if results is not None:
        resp["results"] = results
    return resp