ethnmcl commited on
Commit
6636e90
·
verified ·
1 Parent(s): 3d803ee

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +461 -0
main.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import List, Optional, Dict, Any, Literal
4
+ from datetime import datetime, timedelta
5
+
6
+ import httpx
7
+ import pytz
8
+ import dateparser
9
+ from dateparser.search import search_dates
10
+ from fastapi import FastAPI, Header, HTTPException, Depends, Query
11
+ from pydantic import BaseModel, Field
12
+ from sentence_transformers import SentenceTransformer
13
+ from dateutil.relativedelta import relativedelta
14
+
15
+ # === Environment ===
16
+ API_KEY = os.getenv("API_KEY") # shared secret for this API (set as a Space secret/variable)
17
+ SUPABASE_URL = os.getenv("SUPABASE_URL")
18
+ SUPABASE_SERVICE_ROLE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
19
+ MODEL_NAME = os.getenv("MODEL_NAME", "BAAI/bge-small-en-v1.5")
20
+ LOCAL_TZ = pytz.timezone(os.getenv("TZ", "America/New_York"))
21
+ DEFAULT_WEEK_START = (os.getenv("WEEK_START", "monday") or "monday").strip().lower() # 'monday' or 'sunday'
22
+
23
+ if not (SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY):
24
+ raise RuntimeError("Missing SUPABASE_URL or SUPABASE_SERVICE_ROLE_KEY")
25
+
26
+ # Monday=0 ... Sunday=6
27
+ WEEKDAYS = {
28
+ "monday": 0, "tuesday": 1, "wednesday": 2,
29
+ "thursday": 3, "friday": 4, "saturday": 5, "sunday": 6
30
+ }
31
+ MONTHS = [
32
+ "january","february","march","april","may","june",
33
+ "july","august","september","october","november","december"
34
+ ]
35
+ TIME_PATTERNS = [
36
+ r"\blast\s+(monday|tuesday|wednesday|thursday|friday|saturday|sunday)\b",
37
+ r"\b(this|last)\s+(week|month)\b",
38
+ r"\b(past|last)\s+\d+\s+(?:day|days|week|weeks|month|months)\b",
39
+ r"\bq[1-4](?:\s+\d{4})?\b",
40
+ r"\b(today|yesterday)\b",
41
+ r"\b(january|february|march|april|may|june|july|august|september|october|november|december)(?:\s+\d{4})?\b",
42
+ ]
43
+
44
+ app = FastAPI(title="CIC Check-ins API", version="1.3.0")
45
+
46
+ # === Auth guard ===
47
+ def require_key(authorization: Optional[str] = Header(None)):
48
+ """If API_KEY is set, require 'Authorization: Bearer <API_KEY>'."""
49
+ if not API_KEY:
50
+ return
51
+ if not authorization or not authorization.startswith("Bearer "):
52
+ raise HTTPException(status_code=401, detail="Missing bearer token")
53
+ if authorization.split(" ", 1)[1].strip() != API_KEY:
54
+ raise HTTPException(status_code=403, detail="Invalid token")
55
+
56
+ # === Startup / Shutdown ===
57
+ @app.on_event("startup")
58
+ async def on_startup():
59
+ # Load embedding model once
60
+ app.state.model = SentenceTransformer(MODEL_NAME)
61
+ # Supabase REST client (uses service role for RPCs)
62
+ app.state.http = httpx.AsyncClient(
63
+ base_url=f"{SUPABASE_URL}/rest/v1",
64
+ headers={
65
+ "apikey": SUPABASE_SERVICE_ROLE_KEY,
66
+ "Authorization": f"Bearer {SUPABASE_SERVICE_ROLE_KEY}",
67
+ "Content-Type": "application/json",
68
+ "Accept": "application/json",
69
+ },
70
+ timeout=20.0,
71
+ )
72
+
73
+ @app.on_event("shutdown")
74
+ async def on_shutdown():
75
+ try:
76
+ await app.state.http.aclose()
77
+ except Exception:
78
+ pass
79
+
80
+ # === Helpers ===
81
+ def embed_text(texts: List[str]) -> List[List[float]]:
82
+ vecs = app.state.model.encode(texts, normalize_embeddings=True)
83
+ return [v.tolist() for v in vecs]
84
+
85
+ def _day_start(dt: datetime) -> datetime:
86
+ return dt.replace(hour=0, minute=0, second=0, microsecond=0)
87
+
88
+ def _week_start(dt: datetime, week_start: str) -> datetime:
89
+ idx = 0 if week_start == "monday" else 6 # monday=0, sunday=6 baseline
90
+ delta_days = (dt.weekday() - idx) % 7
91
+ return _day_start(dt - timedelta(days=delta_days))
92
+
93
+ def _localize(tz: pytz.BaseTzInfo, naive_dt: datetime) -> datetime:
94
+ return tz.localize(naive_dt)
95
+
96
+ def to_utc_iso(local_iso: str) -> str:
97
+ return datetime.fromisoformat(local_iso).astimezone(pytz.UTC).isoformat()
98
+
99
+ def extract_time_subphrase(text: str, tz: pytz.BaseTzInfo) -> Optional[str]:
100
+ s = (text or "").lower()
101
+ # 1) Regex heuristics
102
+ for pat in TIME_PATTERNS:
103
+ m = re.search(pat, s)
104
+ if m:
105
+ return m.group(0)
106
+ # 2) Fallback: search any date in text
107
+ settings = {
108
+ "TIMEZONE": str(tz),
109
+ "RETURN_AS_TIMEZONE_AWARE": True,
110
+ "PREFER_DATES_FROM": "past",
111
+ "RELATIVE_BASE": datetime.now(tz)
112
+ }
113
+ found = search_dates(s, settings=settings, languages=["en"])
114
+ if found:
115
+ return found[0][0]
116
+ return None
117
+
118
+ def parse_phrase_to_range(
119
+ phrase: str,
120
+ *,
121
+ tz: Optional[pytz.BaseTzInfo] = None,
122
+ week_start: Optional[str] = None
123
+ ) -> Dict[str, str]:
124
+ """Parse human phrase into [start, end) in tz. Returns {start, end, source}."""
125
+ tz = tz or LOCAL_TZ
126
+ week_start = (week_start or DEFAULT_WEEK_START).strip().lower()
127
+ s_in = (phrase or "").strip()
128
+ s = s_in.lower()
129
+ if not s:
130
+ raise HTTPException(400, detail="Empty phrase")
131
+
132
+ now = datetime.now(tz)
133
+
134
+ # last <weekday>
135
+ m = re.fullmatch(r"last\s+(monday|tuesday|wednesday|thursday|friday|saturday|sunday)", s)
136
+ if m:
137
+ target = WEEKDAYS[m.group(1)]
138
+ delta = (now.weekday() - target) % 7
139
+ delta = 7 if delta == 0 else delta
140
+ day = _day_start(now - timedelta(days=delta))
141
+ return {"start": day.isoformat(), "end": (day + timedelta(days=1)).isoformat(), "source": "weekday"}
142
+
143
+ # today / yesterday
144
+ if s == "today":
145
+ start = _day_start(now)
146
+ return {"start": start.isoformat(), "end": (start + timedelta(days=1)).isoformat(), "source": "day"}
147
+ if s == "yesterday":
148
+ end = _day_start(now)
149
+ start = end - timedelta(days=1)
150
+ return {"start": start.isoformat(), "end": end.isoformat(), "source": "day"}
151
+
152
+ # this/last week
153
+ if s == "this week":
154
+ start = _week_start(now, week_start)
155
+ return {"start": start.isoformat(), "end": (start + timedelta(days=7)).isoformat(), "source": "week"}
156
+ if s == "last week":
157
+ this_start = _week_start(now, week_start)
158
+ start = this_start - timedelta(days=7)
159
+ return {"start": start.isoformat(), "end": (start + timedelta(days=7)).isoformat(), "source": "week"}
160
+
161
+ # this/last month
162
+ if s == "this month":
163
+ start = _localize(tz, datetime(now.year, now.month, 1))
164
+ end = _localize(tz, datetime(now.year + (1 if now.month == 12 else 0),
165
+ 1 if now.month == 12 else now.month + 1, 1))
166
+ return {"start": start.isoformat(), "end": end.isoformat(), "source": "month"}
167
+ if s == "last month":
168
+ first_this = _localize(tz, datetime(now.year, now.month, 1))
169
+ start = _day_start(first_this - timedelta(days=1)).replace(day=1)
170
+ end = first_this
171
+ return {"start": start.isoformat(), "end": end.isoformat(), "source": "month"}
172
+
173
+ # <month> [year]?
174
+ m = re.fullmatch(rf"({'|'.join(MONTHS)})(?:\s+(\d{{4}}))?", s)
175
+ if m:
176
+ month_name, year_str = m.group(1), m.group(2)
177
+ month_idx = MONTHS.index(month_name) + 1
178
+ year = int(year_str) if year_str else now.year
179
+ start = _localize(tz, datetime(year, month_idx, 1))
180
+ end = _localize(tz, datetime(year + 1, 1, 1)) if month_idx == 12 else _localize(tz, datetime(year, month_idx + 1, 1))
181
+ return {"start": start.isoformat(), "end": end.isoformat(), "source": "month"}
182
+
183
+ # (past|last) <N> (days|weeks|months)
184
+ m = re.fullmatch(r"(past|last)\s+(\d+)\s*(day|days|week|weeks|month|months)", s)
185
+ if m:
186
+ n = int(m.group(2))
187
+ unit = m.group(3)
188
+ end = _day_start(now) + timedelta(days=1) # through end of today
189
+ if unit.startswith("day"):
190
+ start = end - timedelta(days=n)
191
+ elif unit.startswith("week"):
192
+ start = end - timedelta(weeks=n)
193
+ else:
194
+ start = end - relativedelta(months=n)
195
+ return {"start": start.isoformat(), "end": end.isoformat(), "source": "relative"}
196
+
197
+ # quarters: Q1..Q4 [year]?
198
+ m = re.fullmatch(r"q([1-4])(?:\s+(\d{4}))?", s)
199
+ if m:
200
+ q = int(m.group(1))
201
+ year = int(m.group(2)) if m.group(2) else now.year
202
+ start_month = (q - 1) * 3 + 1
203
+ start = _localize(tz, datetime(year, start_month, 1))
204
+ end = start + relativedelta(months=3)
205
+ return {"start": start.isoformat(), "end": end.isoformat(), "source": "quarter"}
206
+
207
+ # fallback: dateparser -> day range
208
+ settings = {
209
+ "TIMEZONE": str(tz),
210
+ "RETURN_AS_TIMEZONE_AWARE": True,
211
+ "PREFER_DATES_FROM": "past",
212
+ "RELATIVE_BASE": now
213
+ }
214
+ dt = dateparser.parse(s, settings=settings, languages=["en"])
215
+ if not dt:
216
+ raise HTTPException(400, detail=f"Could not parse phrase: {phrase}\n")
217
+ start = _day_start(dt.astimezone(tz))
218
+ end = start + timedelta(days=1)
219
+ return {"start": start.isoformat(), "end": end.isoformat(), "source": "dateparser"}
220
+
221
+ # === Schemas ===
222
+ class IngestBody(BaseModel):
223
+ id: str
224
+ sender: Optional[str] = None
225
+ username: Optional[str] = None
226
+ slack_id: Optional[str] = None
227
+ msg: str
228
+ timestamp: Optional[str] = Field(None, description="ISO8601; if absent, now()")
229
+ tags: Optional[List[str]] = []
230
+ valid_checkin: Optional[bool] = True
231
+
232
+ class SearchFilters(BaseModel):
233
+ phrase: Optional[str] = None
234
+ start: Optional[str] = None
235
+ end: Optional[str] = None
236
+ sender: Optional[str] = None
237
+ valid_only: Optional[bool] = None
238
+
239
+ class SearchBody(BaseModel):
240
+ query: str
241
+ k: int = 20
242
+ filters: Optional[SearchFilters] = None
243
+ return_fields: List[str] = ["id","ts","sender","username","msg","score"]
244
+
245
+ # /interpret request schema
246
+ class InterpretDefaults(BaseModel):
247
+ timezone: Optional[str] = None
248
+ week_start: Optional[str] = None
249
+ fallback_range: Optional[str] = None
250
+
251
+ class InterpretOptions(BaseModel):
252
+ return_suggestions: bool = True
253
+ infer_sender: Optional[str] = None
254
+ k: int = 20
255
+ return_fields: List[str] = ["id","ts","sender","username","msg","score"]
256
+ run_search: bool = True
257
+
258
+ class InterpretBody(BaseModel):
259
+ text: str
260
+ defaults: Optional[InterpretDefaults] = None
261
+ options: Optional[InterpretOptions] = None
262
+
263
+ # === Routes ===
264
+ @app.get("/")
265
+ async def root():
266
+ return {
267
+ "ok": True,
268
+ "hint": "Use /healthz, /ingest, /search, /phrases/resolve, /interpret, /stats",
269
+ "week_start": DEFAULT_WEEK_START
270
+ }
271
+
272
+ @app.get("/healthz")
273
+ async def health():
274
+ return {"ok": True, "model": MODEL_NAME}
275
+
276
+ @app.get("/phrases/resolve")
277
+ async def resolve_phrase(phrase: str = Query(..., min_length=1), _: None = Depends(require_key)):
278
+ r = parse_phrase_to_range(phrase)
279
+ return {"phrase": phrase, "timezone": str(LOCAL_TZ), "range": r}
280
+
281
+ @app.post("/ingest")
282
+ async def ingest(body: IngestBody, _: None = Depends(require_key)):
283
+ ts_utc = (
284
+ datetime.fromisoformat(body.timestamp).astimezone(pytz.UTC).isoformat()
285
+ if body.timestamp else datetime.now(pytz.UTC).isoformat()
286
+ )
287
+ vec = embed_text([body.msg])[0]
288
+ payload = {
289
+ "_id": body.id,
290
+ "_sender": body.sender,
291
+ "_username": body.username,
292
+ "_slack_id": body.slack_id,
293
+ "_msg": body.msg,
294
+ "_ts": ts_utc,
295
+ "_tags": body.tags or [],
296
+ "_valid": True if body.valid_checkin is not False else False,
297
+ "_embedding": vec,
298
+ }
299
+ r = await app.state.http.post("/rpc/upsert_checkin", json=payload)
300
+ if r.status_code >= 300:
301
+ raise HTTPException(r.status_code, detail=f"Supabase RPC error: {r.text[:300]}")
302
+ return {"ok": True, "id": body.id}
303
+
304
+ @app.post("/search")
305
+ async def search(body: SearchBody, _: None = Depends(require_key)):
306
+ q_vec = embed_text([body.query])[0]
307
+ start_utc = end_utc = None
308
+ if body.filters:
309
+ if body.filters.phrase:
310
+ rng = parse_phrase_to_range(body.filters.phrase)
311
+ start_utc, end_utc = to_utc_iso(rng["start"]), to_utc_iso(rng["end"])
312
+ if body.filters.start:
313
+ start_utc = to_utc_iso(body.filters.start) if "T" in body.filters.start else to_utc_iso(LOCAL_TZ.localize(datetime.fromisoformat(body.filters.start)).isoformat())
314
+ if body.filters.end:
315
+ end_utc = to_utc_iso(body.filters.end) if "T" in body.filters.end else to_utc_iso(LOCAL_TZ.localize(datetime.fromisoformat(body.filters.end)).isoformat())
316
+ rpc_payload = {
317
+ "q_embedding": q_vec,
318
+ "k": max(1, min(body.k, 100)),
319
+ "start_ts": start_utc,
320
+ "end_ts": end_utc,
321
+ "sender_eq": body.filters.sender if body.filters and body.filters.sender else None,
322
+ "valid_only": body.filters.valid_only if body.filters else None
323
+ }
324
+ r = await app.state.http.post("/rpc/search_checkins", json=rpc_payload)
325
+ if r.status_code >= 300:
326
+ raise HTTPException(r.status_code, detail=f"Supabase RPC error: {r.text[:300]}")
327
+ rows = r.json()
328
+ out = []
329
+ for row in rows:
330
+ item = {f: row.get(f) for f in body.return_fields if f in row or f == "score"}
331
+ if "score" in item and item["score"] is not None:
332
+ item["score"] = float(item["score"])
333
+ out.append(item)
334
+ return {"results": out, "used": {"semantic": True}}
335
+
336
+ @app.get("/stats")
337
+ async def stats(
338
+ phrase: Optional[str] = None,
339
+ bucket: Literal["weekly","monthly"] = "weekly",
340
+ _: None = Depends(require_key)
341
+ ):
342
+ if phrase:
343
+ rng = parse_phrase_to_range(phrase)
344
+ start_utc, end_utc = to_utc_iso(rng["start"]), to_utc_iso(rng["end"])
345
+ else:
346
+ end = datetime.now(pytz.UTC)
347
+ start = end - timedelta(days=30)
348
+ start_utc, end_utc = start.isoformat(), end.isoformat()
349
+ payload = {"bucket": bucket, "start_ts": start_utc, "end_ts": end_utc}
350
+ r = await app.state.http.post("/rpc/stats_range", json=payload)
351
+ if r.status_code >= 300:
352
+ raise HTTPException(r.status_code, detail=f"Supabase RPC error: {r.text[:300]}")
353
+ return {"bucket": bucket, "range": {"start": start_utc, "end": end_utc}, **r.json()}
354
+
355
+ # === /interpret ===
356
+ class InterpretResponse(BaseModel):
357
+ ok: bool
358
+
359
+ @app.post("/interpret")
360
+ async def interpret(body: InterpretBody, _: None = Depends(require_key)):
361
+ """
362
+ Free-form input -> (query, time window) + (optionally) return matching rows.
363
+ """
364
+ text = (body.text or "").strip()
365
+ if not text:
366
+ raise HTTPException(400, detail="Missing 'text'")
367
+
368
+ tz = LOCAL_TZ
369
+ week_start = DEFAULT_WEEK_START
370
+ if body.defaults:
371
+ if body.defaults.timezone:
372
+ try: tz = pytz.timezone(body.defaults.timezone)
373
+ except Exception: pass
374
+ if body.defaults.week_start and body.defaults.week_start.lower() in ("monday","sunday"):
375
+ week_start = body.defaults.week_start.lower()
376
+
377
+ sub = extract_time_subphrase(text, tz)
378
+ rng = None
379
+ time_source = None
380
+ extracted = None
381
+ suggestions: List[str] = []
382
+
383
+ if sub:
384
+ extracted = sub
385
+ parsed = parse_phrase_to_range(sub, tz=tz, week_start=week_start)
386
+ rng = {"start": parsed["start"], "end": parsed["end"], "tz": str(tz)}
387
+ time_source = parsed.get("source", "detected")
388
+ m = re.fullmatch(rf"({'|'.join(MONTHS)})", sub.strip().lower())
389
+ if m and (not body.options or body.options.return_suggestions):
390
+ now = datetime.now(tz)
391
+ mon = m.group(1).capitalize()
392
+ suggestions = [f"{mon} {now.year}", f"{mon} {now.year-1}"]
393
+
394
+ query = text
395
+ if extracted:
396
+ pattern = re.compile(re.escape(extracted), re.IGNORECASE)
397
+ query = pattern.sub("", query, count=1).strip()
398
+ query = re.sub(r"\s+", " ", query).strip()
399
+
400
+ used_fallback = False
401
+ if rng is None:
402
+ if body.defaults and body.defaults.fallback_range:
403
+ parsed = parse_phrase_to_range(body.defaults.fallback_range, tz=tz, week_start=week_start)
404
+ rng = {"start": parsed["start"], "end": parsed["end"], "tz": str(tz), "confidence": 0.2}
405
+ time_source = "fallback"
406
+ used_fallback = True
407
+ else:
408
+ return {
409
+ "ok": False,
410
+ "error": { "code": "NO_TIME_FOUND", "message": "No time phrase detected and no fallback_range provided." },
411
+ "hints": ["Add 'last week', 'August', 'past 30 days'", "Or pass defaults.fallback_range"],
412
+ "query_guess": query or text
413
+ }
414
+
415
+ opt = body.options or InterpretOptions()
416
+ search_payload = {
417
+ "query": query or text,
418
+ "k": max(1, min(opt.k, 100)),
419
+ "filters": { "start": rng["start"], "end": rng["end"], "sender": opt.infer_sender, "valid_only": None },
420
+ "return_fields": opt.return_fields
421
+ }
422
+
423
+ results = None
424
+ if opt.run_search:
425
+ q_vec = embed_text([search_payload["query"]])[0]
426
+ start_utc = to_utc_iso(search_payload["filters"]["start"])
427
+ end_utc = to_utc_iso(search_payload["filters"]["end"])
428
+ rpc_payload = {
429
+ "q_embedding": q_vec,
430
+ "k": search_payload["k"],
431
+ "start_ts": start_utc,
432
+ "end_ts": end_utc,
433
+ "sender_eq": search_payload["filters"]["sender"],
434
+ "valid_only": search_payload["filters"].get("valid_only")
435
+ }
436
+ r2 = await app.state.http.post("/rpc/search_checkins", json=rpc_payload)
437
+ if r2.status_code >= 300:
438
+ raise HTTPException(r2.status_code, detail=f"Supabase RPC error: {r2.text[:300]}")
439
+ rows = r2.json()
440
+ results = []
441
+ for row in rows:
442
+ item = {f: row.get(f) for f in opt.return_fields if f in row or f == "score"}
443
+ if "score" in item and item["score"] is not None:
444
+ item["score"] = float(item["score"])
445
+ results.append(item)
446
+
447
+ resp: Dict[str, Any] = {
448
+ "ok": True,
449
+ "input": { "text": body.text, "timezone": str(tz), "week_start": week_start },
450
+ "query": query or text,
451
+ "time": {
452
+ "phrase_raw": body.text, "phrase_extracted": extracted, "source": time_source,
453
+ "start": rng["start"], "end": rng["end"], "tz": rng["tz"]
454
+ },
455
+ "search_payload": search_payload
456
+ }
457
+ if suggestions and (not used_fallback):
458
+ resp["suggestions"] = suggestions
459
+ if results is not None:
460
+ resp["results"] = results
461
+ return resp