"""
Data Manager — CRUD operations backed by Google Sheets.

Refactored for FastAPI: module-level caching instead of streamlit.session_state.
"""

import base64
import copy
import json
import os
import threading
import time
from datetime import datetime

import gspread
from google.oauth2 import service_account


def _get_secret(key: str) -> str | None:
    """Return the environment variable *key*, or None when it is unset."""
    return os.environ.get(key)


def _require_secret(key: str) -> str:
    """Return the environment variable *key*.

    Raises:
        RuntimeError: if the variable is unset or empty. This replaces the
            obscure TypeError a ``None`` would otherwise cause further down
            (e.g. inside ``base64.b64decode``).
    """
    value = _get_secret(key)
    if not value:
        raise RuntimeError(f"Required environment variable {key!r} is not set")
    return value


# --- Module-level caching ---
# RLock (reentrant): some helpers acquire the lock while calling another helper
# that also acquires it (e.g. _get_worksheet -> _get_spreadsheet). A plain Lock
# would deadlock.
_lock = threading.RLock()

# Serializes read-locate-write sequences (ID generation, row lookup followed by
# update/delete, header extension) within this process. Row indices shift when
# a row is deleted, so two interleaved writers could otherwise update or delete
# the wrong row, or hand out the same ID twice. Reentrant because write helpers
# nest (e.g. update_provider -> _ensure_columns). Lock order is always
# _write_lock -> _lock. This does NOT protect against other processes or manual
# edits of the sheet; the fresh reads used for row lookups narrow that window.
_write_lock = threading.RLock()

_state: dict = {
    "spreadsheet": None,
    "worksheets": {},
    "headers": {},
    "read_cache": {},  # name -> (timestamp, list[dict])
    "cache_generation": {},  # name -> int, bumped by every _invalidate_cache
}

_READ_TTL_SECONDS = 60


def _get_spreadsheet() -> gspread.Spreadsheet:
    """Return the shared Spreadsheet handle, authorizing on first use.

    Credentials are the base64-encoded service-account JSON stored in
    GOOGLE_CALENDAR_CREDENTIALS; the spreadsheet is opened by GOOGLE_SHEET_ID.

    Raises:
        RuntimeError: if either environment variable is missing.
    """
    with _lock:
        if _state["spreadsheet"] is None:
            # Read both secrets up front so a missing one fails fast.
            creds_b64 = _require_secret("GOOGLE_CALENDAR_CREDENTIALS")
            sheet_id = _require_secret("GOOGLE_SHEET_ID")
            creds_json = json.loads(base64.b64decode(creds_b64).decode("utf-8"))
            credentials = service_account.Credentials.from_service_account_info(
                creds_json,
                scopes=[
                    "https://www.googleapis.com/auth/spreadsheets",
                    "https://www.googleapis.com/auth/drive",
                ],
            )
            client = gspread.authorize(credentials)
            _state["spreadsheet"] = client.open_by_key(sheet_id)
        return _state["spreadsheet"]


def _get_worksheet(name: str) -> gspread.Worksheet:
    """Return the worksheet (tab) called *name*, caching the handle."""
    with _lock:
        if name not in _state["worksheets"]:
            _state["worksheets"][name] = _get_spreadsheet().worksheet(name)
        return _state["worksheets"][name]


# Columns whose cells hold JSON text rather than plain scalars.
_JSON_FIELDS = {
    "work_hours",
    "lunch_break",
    "days_off",
    "recurring_days_off",
    "service_levels",
    "on_call",
}


def _serialize_row(data: dict) -> dict:
    """Convert a record into the string cell values written to the sheet.

    JSON fields are JSON-encoded, None becomes an empty cell, and every other
    value goes through ``str()`` (so booleans become "True"/"False").
    """
    row = {}
    for k, v in data.items():
        if k in _JSON_FIELDS:
            row[k] = json.dumps(v)
        elif v is None:
            row[k] = ""
        else:
            row[k] = str(v)
    return row


def _deserialize_row(row: dict) -> dict:
    """Convert raw sheet values back into Python values.

    - JSON fields are decoded; blank or undecodable cells become ``[]``.
    - ``lat``/``lng`` become floats, or None when blank or non-numeric.
    - ``active`` becomes a bool; string cells are compared case-insensitively
      to "true".
    - ``visit_level`` becomes an int when possible, otherwise it is kept as-is.
    - All other values pass through unchanged.
    """
    data = {}
    for k, v in row.items():
        if k in _JSON_FIELDS:
            try:
                data[k] = json.loads(v) if v else []
            except (json.JSONDecodeError, TypeError):
                data[k] = []
        elif k in ("lat", "lng"):
            if v == "" or v is None:
                data[k] = None
            else:
                try:
                    data[k] = float(v)
                except (ValueError, TypeError):
                    data[k] = None
        elif k == "active":
            data[k] = v.lower() == "true" if isinstance(v, str) else bool(v)
        elif k == "visit_level":
            try:
                data[k] = int(v)
            except (ValueError, TypeError):
                data[k] = v
        else:
            data[k] = v
    return data


def _read_all(worksheet_name: str, *, fresh: bool = False) -> list[dict]:
    """Return every record of *worksheet_name* as deserialized dicts.

    Results are cached for ``_READ_TTL_SECONDS``. Callers always get a deep
    copy, so mutating a returned record (including its JSON list/dict fields)
    can never alter the cache.

    Args:
        worksheet_name: Name of the worksheet (tab) to read.
        fresh: Bypass the cache and fetch from the API. The fetched data
            still refreshes the cache.
    """
    with _lock:
        if not fresh:
            cached = _state["read_cache"].get(worksheet_name)
            if cached and (time.time() - cached[0]) < _READ_TTL_SECONDS:
                return copy.deepcopy(cached[1])
        generation = _state["cache_generation"].get(worksheet_name, 0)

    ws = _get_worksheet(worksheet_name)
    # NOTE(review): gspread's get_all_records() is believed to numericise
    # numeric-looking cells by default (e.g. "02134" -> 2134). Verify against
    # the installed gspread version if text-like columns (phone numbers,
    # postal codes) must keep leading zeros.
    records = ws.get_all_records()
    data = [_deserialize_row(r) for r in records]

    with _lock:
        # A write may have invalidated this sheet while we were fetching.
        # Caching our (pre-write) snapshot then would hide that write for a
        # full TTL, so only cache when no invalidation happened meanwhile.
        if _state["cache_generation"].get(worksheet_name, 0) == generation:
            _state["read_cache"][worksheet_name] = (time.time(), data)
    return copy.deepcopy(data)


def _invalidate_cache(worksheet_name: str) -> None:
    """Drop cached records for *worksheet_name* and bump its generation."""
    with _lock:
        _state["read_cache"].pop(worksheet_name, None)
        generations = _state["cache_generation"]
        generations[worksheet_name] = generations.get(worksheet_name, 0) + 1


def _get_headers(worksheet_name: str) -> list[str]:
    """Return a copy of the header row (row 1) of *worksheet_name*, cached."""
    with _lock:
        if worksheet_name in _state["headers"]:
            return list(_state["headers"][worksheet_name])
    ws = _get_worksheet(worksheet_name)
    headers = ws.row_values(1)
    with _lock:
        _state["headers"][worksheet_name] = headers
    return list(headers)


def _col_letter(n: int) -> str:
    """Convert a 1-based column number to A1 letters (1 -> "A", 27 -> "AA")."""
    result = ""
    while n > 0:
        n, remainder = divmod(n - 1, 26)
        result = chr(65 + remainder) + result
    return result


def _ensure_columns(worksheet_name: str, required: list[str]) -> list[str]:
    """Append any *required* column names missing from the header row.

    Existing columns keep their positions, so stored rows stay aligned.
    Returns a copy of the (possibly extended) header list.
    """
    with _write_lock:
        headers = _get_headers(worksheet_name)
        missing = [c for c in required if c not in headers]
        if not missing:
            return headers
        ws = _get_worksheet(worksheet_name)
        new_headers = headers + missing
        end_col = _col_letter(len(new_headers))
        ws.update(values=[new_headers], range_name=f"A1:{end_col}1", value_input_option="RAW")
        with _lock:
            _state["headers"][worksheet_name] = new_headers
        _invalidate_cache(worksheet_name)
        # Return a copy so callers cannot mutate the cached header list.
        return list(new_headers)


def _find_record(worksheet_name: str, id_column: str, id_value: str) -> tuple[int, dict] | None:
    """Locate a record by ID using a single fresh (uncached) read.

    Returns ``(sheet_row_index, record)`` or None when no record matches.
    The index is 1-based; row 1 is the header row, so the first record maps
    to row 2 (assumes get_all_records() yields one record per sheet row with
    none skipped). The read is fresh because the index is about to be written
    to, and a cached snapshot may predate inserted or deleted rows.
    """
    records = _read_all(worksheet_name, fresh=True)
    for row_index, record in enumerate(records, start=2):
        if record.get(id_column) == id_value:
            return row_index, record
    return None


def _find_row_index(worksheet_name: str, id_column: str, id_value: str) -> int | None:
    """Return the 1-based sheet row holding *id_value*, or None (fresh read)."""
    found = _find_record(worksheet_name, id_column, id_value)
    return found[0] if found is not None else None


def _append_row(worksheet_name: str, data: dict) -> None:
    """Append *data* as a new row, ordered by the current header row.

    Keys without a matching header are dropped; headers absent from *data*
    are written as empty cells.
    """
    ws = _get_worksheet(worksheet_name)
    headers = _get_headers(worksheet_name)
    row_values = [_serialize_row(data).get(h, "") for h in headers]
    ws.append_row(row_values, value_input_option="RAW")
    _invalidate_cache(worksheet_name)


def _update_row(worksheet_name: str, row_index: int, data: dict) -> None:
    """Overwrite every header column of sheet row *row_index* with *data*."""
    ws = _get_worksheet(worksheet_name)
    headers = _get_headers(worksheet_name)
    serialized = _serialize_row(data)
    row_values = [serialized.get(h, "") for h in headers]
    end_col = _col_letter(len(headers))
    ws.update(values=[row_values], range_name=f"A{row_index}:{end_col}{row_index}", value_input_option="RAW")
    _invalidate_cache(worksheet_name)


def _update_record(worksheet_name: str, record_id: str, updates: dict) -> dict | None:
    """Merge *updates* into the record whose ``id`` is *record_id* and save it.

    Returns the updated record, or None when no record has that id.
    """
    with _write_lock:
        found = _find_record(worksheet_name, "id", record_id)
        if found is None:
            return None
        row_idx, record = found
        # `record` is a private copy, so a failed write leaves the cache intact.
        record.update(updates)
        _update_row(worksheet_name, row_idx, record)
        return record


def _delete_record(worksheet_name: str, record_id: str) -> bool:
    """Delete the row whose ``id`` is *record_id*; return whether one existed."""
    with _write_lock:
        row_idx = _find_row_index(worksheet_name, "id", record_id)
        if row_idx is None:
            return False
        _get_worksheet(worksheet_name).delete_rows(row_idx)
        _invalidate_cache(worksheet_name)
        return True


def _generate_id(prefix: str, existing_ids: list[str]) -> str:
    """Return ``f"{prefix}_NNN"`` with NNN one past the largest numeric suffix.

    Only IDs whose text after the last "_" parses as an int are considered;
    with none, the result is ``f"{prefix}_001"``.
    NOTE(review): because numbering derives from the IDs currently present,
    deleting the highest-numbered record lets its ID be issued again.
    """
    nums = []
    for eid in existing_ids:
        # str(): a cell may have been read back as a number rather than text.
        parts = str(eid).split("_")
        if len(parts) >= 2:
            try:
                nums.append(int(parts[-1]))
            except ValueError:
                pass
    return f"{prefix}_{max(nums, default=0) + 1:03d}"


# --- Providers ---

def get_all_providers(active_only: bool = True) -> list[dict]:
    """Return providers; when *active_only*, skip those whose ``active`` is falsy."""
    providers = _read_all("Providers")
    if active_only:
        return [p for p in providers if p.get("active", True)]
    return providers


def get_provider_by_id(provider_id: str) -> dict | None:
    """Return the provider with *provider_id*, or None."""
    for p in _read_all("Providers"):
        if p.get("id") == provider_id:
            return p
    return None


def get_provider_by_name(name: str) -> dict | None:
    """Return the first provider whose name matches *name* case-insensitively."""
    target = name.lower()
    for p in _read_all("Providers"):
        # str(): tolerate non-string cells (e.g. blank or numeric names).
        if str(p.get("name", "")).lower() == target:
            return p
    return None


def add_provider(provider_data: dict) -> dict:
    """Assign a new ``prov_NNN`` id, fill defaults, append, and return the data.

    *provider_data* is modified in place and returned.
    """
    with _write_lock:
        _ensure_columns("Providers", ["on_call"])
        # Fresh read so the new ID accounts for rows added outside this process.
        providers = _read_all("Providers", fresh=True)
        existing_ids = [p["id"] for p in providers if "id" in p]
        provider_data["id"] = _generate_id("prov", existing_ids)
        provider_data.setdefault("active", True)
        provider_data.setdefault("days_off", [])
        provider_data.setdefault("recurring_days_off", [])
        provider_data.setdefault("on_call", [])
        provider_data.setdefault("lat", None)
        provider_data.setdefault("lng", None)
        _append_row("Providers", provider_data)
    return provider_data


def update_provider(provider_id: str, updates: dict) -> dict | None:
    """Apply *updates* to a provider; return the updated record or None."""
    with _write_lock:
        if "on_call" in updates:
            _ensure_columns("Providers", ["on_call"])
        return _update_record("Providers", provider_id, updates)


def delete_provider(provider_id: str) -> bool:
    """Delete a provider row; return False when the id is not found."""
    return _delete_record("Providers", provider_id)


# --- Patients ---

def get_all_patients() -> list[dict]:
    """Return all patients."""
    return _read_all("Patients")


def get_patient_by_id(patient_id: str) -> dict | None:
    """Return the patient with *patient_id*, or None."""
    for p in _read_all("Patients"):
        if p.get("id") == patient_id:
            return p
    return None


def add_patient(patient_data: dict) -> dict:
    """Assign a new ``pat_NNN`` id, fill defaults, append, and return the data.

    *patient_data* is modified in place and returned.
    """
    with _write_lock:
        patients = _read_all("Patients", fresh=True)
        existing_ids = [p["id"] for p in patients if "id" in p]
        patient_data["id"] = _generate_id("pat", existing_ids)
        patient_data.setdefault("lat", None)
        patient_data.setdefault("lng", None)
        patient_data.setdefault("notes", "")
        _append_row("Patients", patient_data)
    return patient_data


def update_patient(patient_id: str, updates: dict) -> dict | None:
    """Apply *updates* to a patient; return the updated record or None."""
    return _update_record("Patients", patient_id, updates)


def delete_patient(patient_id: str) -> bool:
    """Delete a patient row; return False when the id is not found."""
    return _delete_record("Patients", patient_id)


# --- Appointments ---

def get_all_appointments(status: str | None = None) -> list[dict]:
    """Return all appointments, optionally only those with *status*."""
    appointments = _read_all("Appointments")
    if status:
        return [a for a in appointments if a.get("status") == status]
    return appointments


def get_appointments_for_provider(provider_id: str, date: str | None = None) -> list[dict]:
    """Return a provider's "scheduled" appointments, optionally on *date*."""
    appointments = _read_all("Appointments")
    result = [
        a for a in appointments
        if a.get("provider_id") == provider_id and a.get("status") == "scheduled"
    ]
    if date:
        result = [a for a in result if a.get("date") == date]
    return result


def get_appointments_for_patient(patient_id: str) -> list[dict]:
    """Return a patient's "scheduled" appointments."""
    appointments = _read_all("Appointments")
    return [
        a for a in appointments
        if a.get("patient_id") == patient_id and a.get("status") == "scheduled"
    ]


def add_appointment(appointment_data: dict) -> dict:
    """Assign a new ``appt_NNN`` id and ``created_at``, append, and return it.

    ``status`` defaults to "scheduled". *appointment_data* is modified in
    place and returned.
    """
    with _write_lock:
        appointments = _read_all("Appointments", fresh=True)
        existing_ids = [a["id"] for a in appointments if "id" in a]
        appointment_data["id"] = _generate_id("appt", existing_ids)
        appointment_data["created_at"] = datetime.now().isoformat()
        appointment_data.setdefault("status", "scheduled")
        _append_row("Appointments", appointment_data)
    return appointment_data


def update_appointment(appointment_id: str, updates: dict) -> dict | None:
    """Apply *updates* to an appointment; return the updated record or None."""
    return _update_record("Appointments", appointment_id, updates)


def cancel_appointment(appointment_id: str) -> dict | None:
    """Mark an appointment "cancelled" and stamp ``cancelled_at``."""
    return update_appointment(appointment_id, {"status": "cancelled", "cancelled_at": datetime.now().isoformat()})


def complete_appointment(appointment_id: str) -> dict | None:
    """Mark an appointment "completed" and stamp ``completed_at``."""
    return update_appointment(appointment_id, {"status": "completed", "completed_at": datetime.now().isoformat()})


# --- Audit Log ---

def log_action(user: str, action: str, details: str, appointment_id: str | None = None) -> None:
    """Append an entry to the AuditLog sheet, timestamped with local time."""
    entry = {
        "timestamp": datetime.now().isoformat(),
        "user": user,
        "action": action,
        "details": details,
        "appointment_id": appointment_id or "",
    }
    _append_row("AuditLog", entry)


def get_audit_log(limit: int = 100) -> list[dict]:
    """Return up to *limit* audit entries, newest first by ISO timestamp."""
    log = _read_all("AuditLog")
    log.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
    return log[:limit]