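# scheduler/core/data_manager.py
# NOTE: the following file-viewer page residue was pasted above the source and
# is kept here only as a comment so the module can be parsed:
#   "umangchaudhry's picture" / "Upload 31 files" / "0d04b76 verified" /
#   "Raw" / "History Blame Contribute Delete" / "11.9 kB"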
"""
Data Manager — CRUD operations backed by Google Sheets.
Refactored for FastAPI: module-level caching instead of streamlit.session_state.
"""
import json
import base64
import os
import time
import threading
from datetime import datetime
import gspread
from google.oauth2 import service_account
def _get_secret(key: str) -> str:
return os.environ.get(key)
# --- Module-level caching ---
# Shared, process-wide state. Every access to _state in this module happens
# while holding _lock.
# The lock has to be reentrant (RLock, not Lock): _get_worksheet calls
# _get_spreadsheet while it already holds the lock, and a plain Lock would
# deadlock on that nested acquire.
_READ_TTL_SECONDS = 60  # maximum age, in seconds, of a read_cache entry
_lock = threading.RLock()
_state: dict = {
    # Spreadsheet handle; stays None until _get_spreadsheet opens it.
    "spreadsheet": None,
    # worksheet name -> worksheet handle
    "worksheets": {},
    # worksheet name -> header row values
    "headers": {},
    # worksheet name -> (timestamp, list[dict])
    "read_cache": {},
}
def _get_spreadsheet() -> gspread.Spreadsheet:
    """Return the shared spreadsheet handle, opening it on first use.

    The handle is stored in ``_state["spreadsheet"]`` and reused by later
    calls. The first call decodes the base64-encoded service-account JSON held
    in the ``GOOGLE_CALENDAR_CREDENTIALS`` environment variable, authorizes a
    gspread client with the Sheets and Drive scopes, and opens the spreadsheet
    whose key is in ``GOOGLE_SHEET_ID``.

    Returns:
        The opened ``gspread.Spreadsheet``.

    Raises:
        RuntimeError: If either environment variable is missing or empty.
    """
    # The whole initialisation runs under the lock so that concurrent first
    # callers do not each open their own client.
    with _lock:
        if _state["spreadsheet"] is not None:
            return _state["spreadsheet"]

        creds_b64 = _get_secret("GOOGLE_CALENDAR_CREDENTIALS")
        sheet_id = _get_secret("GOOGLE_SHEET_ID")
        # Fail with an explicit message instead of an opaque TypeError from
        # base64.b64decode(None) or a lookup of a non-existent sheet key.
        missing = [
            env_name
            for env_name, value in (
                ("GOOGLE_CALENDAR_CREDENTIALS", creds_b64),
                ("GOOGLE_SHEET_ID", sheet_id),
            )
            if not value
        ]
        if missing:
            raise RuntimeError(
                "Missing required environment variable(s): " + ", ".join(missing)
            )

        creds_json = json.loads(base64.b64decode(creds_b64).decode("utf-8"))
        credentials = service_account.Credentials.from_service_account_info(
            creds_json,
            scopes=[
                "https://www.googleapis.com/auth/spreadsheets",
                "https://www.googleapis.com/auth/drive",
            ],
        )
        client = gspread.authorize(credentials)
        _state["spreadsheet"] = client.open_by_key(sheet_id)
        return _state["spreadsheet"]
def _get_worksheet(name: str) -> gspread.Worksheet:
    """Return the worksheet called *name*, remembering the handle for reuse.

    Handles are kept in ``_state["worksheets"]``; the spreadsheet is only asked
    for a worksheet the first time a given name is requested.
    """
    with _lock:
        handles = _state["worksheets"]
        if name in handles:
            return handles[name]
        # Nested lock acquisition inside _get_spreadsheet is fine: _lock is an RLock.
        handles[name] = _get_spreadsheet().worksheet(name)
        return handles[name]
# Column names whose cell content is stored as a JSON document: they are
# json.dumps-ed by _serialize_row and json.loads-ed by _deserialize_row.
_JSON_FIELDS = {
    "work_hours",
    "lunch_break",
    "days_off",
    "recurring_days_off",
    "service_levels",
    "on_call",
}
def _serialize_row(data: dict) -> dict:
    """Convert a record into the all-string form that is written to a sheet.

    Args:
        data: Mapping of column name to Python value.

    Returns:
        A new dict with the same keys where columns listed in ``_JSON_FIELDS``
        are JSON-encoded, ``None`` becomes ``""`` and every other value
        (booleans included) is passed through ``str()``.
    """

    def _encode(column, value):
        # JSON columns take precedence, so a None in a JSON column becomes "null".
        if column in _JSON_FIELDS:
            return json.dumps(value)
        return "" if value is None else str(value)

    return {column: _encode(column, value) for column, value in data.items()}
def _deserialize_row(row: dict) -> dict:
    """Convert a raw sheet record into typed Python values.

    Conversion rules, checked in this order for each column:

    * columns in ``_JSON_FIELDS``: parsed with ``json.loads``; an empty cell or
      unparsable content yields ``[]``;
    * ``lat`` / ``lng``: ``float``; empty, ``None`` or unparsable yields ``None``;
    * ``active``: a string is true only when it equals ``"true"``
      case-insensitively; any other type goes through ``bool()``;
    * ``visit_level``: ``int`` when convertible, otherwise the raw value;
    * anything else is passed through unchanged.

    Args:
        row: Mapping of column name to the value read from the sheet.

    Returns:
        A new dict with the same keys and converted values.
    """

    def _as_json(raw):
        try:
            return json.loads(raw) if raw else []
        except (json.JSONDecodeError, TypeError):
            return []

    def _as_coordinate(raw):
        if raw == "" or raw is None:
            return None
        try:
            return float(raw)
        except (ValueError, TypeError):
            return None

    def _as_flag(raw):
        if isinstance(raw, str):
            return raw.lower() == "true"
        return bool(raw)

    def _as_level(raw):
        try:
            return int(raw)
        except (ValueError, TypeError):
            # Keep whatever was in the cell when it is not a whole number.
            return raw

    def _convert(column, raw):
        if column in _JSON_FIELDS:
            return _as_json(raw)
        if column in ("lat", "lng"):
            return _as_coordinate(raw)
        if column == "active":
            return _as_flag(raw)
        if column == "visit_level":
            return _as_level(raw)
        return raw

    return {column: _convert(column, raw) for column, raw in row.items()}
def _read_all(worksheet_name: str) -> list[dict]:
    """Return every record of a worksheet as deserialized dicts.

    Results are cached per worksheet in ``_state["read_cache"]`` and reused
    while the entry is younger than ``_READ_TTL_SECONDS``.

    The returned rows are deep copies: callers may freely mutate them (as the
    ``update_*`` functions do) without altering what later reads get from the
    cache. Previously only the outer list was copied, so a caller-side mutation
    followed by a failed sheet write left unsaved data in the cache.

    Args:
        worksheet_name: Title of the worksheet to read.

    Returns:
        A fresh list of fresh dicts, one per record.
    """
    import copy  # function-scope import keeps this block self-contained

    with _lock:
        cached = _state["read_cache"].get(worksheet_name)
        if cached and (time.time() - cached[0]) < _READ_TTL_SECONDS:
            return copy.deepcopy(cached[1])

    # The sheet is fetched outside the lock so other threads are not blocked
    # for the duration of the request.
    ws = _get_worksheet(worksheet_name)
    data = [_deserialize_row(record) for record in ws.get_all_records()]
    with _lock:
        _state["read_cache"][worksheet_name] = (time.time(), data)
    return copy.deepcopy(data)
def _invalidate_cache(worksheet_name: str) -> None:
    """Forget the cached records of *worksheet_name*, if any are cached."""
    with _lock:
        cache = _state["read_cache"]
        if worksheet_name in cache:
            del cache[worksheet_name]
def _get_headers(worksheet_name: str) -> list[str]:
    """Return the header row (row 1) of a worksheet.

    The values are fetched once per worksheet and then served from
    ``_state["headers"]``. A copy is returned so callers cannot alter the
    remembered list.
    """
    with _lock:
        known = _state["headers"]
        if worksheet_name in known:
            return [*known[worksheet_name]]

    # Not remembered yet: read row 1 without holding the lock.
    first_row = _get_worksheet(worksheet_name).row_values(1)
    with _lock:
        _state["headers"][worksheet_name] = first_row
    return [*first_row]
def _col_letter(n: int) -> str:
    """Convert a 1-based column number to A1-notation letters.

    Examples: ``1 -> "A"``, ``26 -> "Z"``, ``27 -> "AA"``. Values of zero or
    less produce an empty string.
    """
    letters = []
    remaining = n
    while remaining > 0:
        # Bijective base-26: shift by one so that 26 maps to "Z", not "A0".
        remaining, offset = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + offset))
    # Digits were produced least-significant first.
    return "".join(reversed(letters))
def _ensure_columns(worksheet_name: str, required: list[str]) -> list[str]:
    """Make sure the worksheet's header row contains every required column.

    Missing names are appended to the right of the existing headers by
    rewriting row 1; the remembered headers are replaced and the record cache
    for the worksheet is dropped.

    Args:
        worksheet_name: Title of the worksheet to check.
        required: Column names that must be present.

    Returns:
        The header list after the check (unchanged when nothing was missing).
    """
    current = _get_headers(worksheet_name)
    to_add = [column for column in required if column not in current]
    if to_add:
        sheet = _get_worksheet(worksheet_name)
        extended = current + to_add
        last_col = _col_letter(len(extended))
        sheet.update(values=[extended], range_name=f"A1:{last_col}1", value_input_option="RAW")
        with _lock:
            _state["headers"][worksheet_name] = extended
        # Cached records were read under the old header set.
        _invalidate_cache(worksheet_name)
        return extended
    return current
def _find_row_index(worksheet_name: str, id_column: str, id_value: str) -> int | None:
    """Return the sheet row number of the first record whose *id_column* equals *id_value*.

    Row numbers are offset by two from the record's list position: numbering
    is 1-based and row 1 is the header row. Returns ``None`` when no record
    matches.
    """
    numbered = enumerate(_read_all(worksheet_name), start=2)
    return next(
        (row_number for row_number, rec in numbered if rec.get(id_column) == id_value),
        None,
    )
def _append_row(worksheet_name: str, data: dict) -> None:
    """Append one record to a worksheet and drop its cached records.

    Values are laid out in header order. Keys of *data* that are not header
    names are not written; headers without a value in *data* get ``""``.

    Args:
        worksheet_name: Title of the worksheet to append to.
        data: Mapping of column name to Python value.
    """
    ws = _get_worksheet(worksheet_name)
    headers = _get_headers(worksheet_name)
    # Serialize once, not once per header as the comprehension used to do.
    serialized = _serialize_row(data)
    row_values = [serialized.get(h, "") for h in headers]
    ws.append_row(row_values, value_input_option="RAW")
    _invalidate_cache(worksheet_name)
def _update_row(worksheet_name: str, row_index: int, data: dict) -> None:
    """Overwrite one sheet row with *data* and drop the worksheet's cached records.

    The whole row, from column A to the last header column, is rewritten in
    header order; headers missing from *data* are written as ``""``.

    Args:
        worksheet_name: Title of the worksheet to modify.
        row_index: 1-based sheet row number to overwrite.
        data: Mapping of column name to Python value.
    """
    sheet = _get_worksheet(worksheet_name)
    columns = _get_headers(worksheet_name)
    encoded = _serialize_row(data)
    cells = [encoded.get(column, "") for column in columns]
    last_col = _col_letter(len(columns))
    sheet.update(
        values=[cells],
        range_name=f"A{row_index}:{last_col}{row_index}",
        value_input_option="RAW",
    )
    _invalidate_cache(worksheet_name)
def _generate_id(prefix: str, existing_ids: list[str]) -> str:
    """Build the next sequential id of the form ``"<prefix>_NNN"``.

    The numeric part is one more than the largest integer found after the last
    underscore of any existing id (whatever its prefix), zero-padded to at
    least three digits. With no usable ids the result is ``"<prefix>_001"``.

    Args:
        prefix: Text placed before the underscore.
        existing_ids: Ids already in use. Entries that are not strings, have
            no underscore, or do not end in an integer are ignored.

    Returns:
        The new id string.
    """
    suffixes = []
    for eid in existing_ids:
        # A non-string entry (for instance a cell that was read back as a
        # number or None) cannot hold an id; skip it instead of crashing on
        # the string methods below.
        if not isinstance(eid, str):
            continue
        _, separator, tail = eid.rpartition("_")
        if not separator:
            continue
        try:
            suffixes.append(int(tail))
        except ValueError:
            continue
    next_num = max(suffixes, default=0) + 1
    return f"{prefix}_{next_num:03d}"
# --- Providers ---
def get_all_providers(active_only: bool = True) -> list[dict]:
    """Return the records of the "Providers" worksheet.

    Args:
        active_only: When true (the default), keep only providers whose
            ``active`` value is truthy; a record without an ``active`` key
            counts as active.
    """
    rows = _read_all("Providers")
    if not active_only:
        return rows
    return [row for row in rows if row.get("active", True)]
def get_provider_by_id(provider_id: str) -> dict | None:
    """Return the first provider whose ``id`` equals *provider_id*, or ``None``."""
    matches = (row for row in _read_all("Providers") if row.get("id") == provider_id)
    return next(matches, None)
def get_provider_by_name(name: str) -> dict | None:
    """Return the first provider whose name matches *name*, ignoring case.

    Args:
        name: Provider name to look for.

    Returns:
        The matching provider record, or ``None`` when there is none.
    """
    # Lower-case the search term once instead of once per row.
    wanted = name.lower()
    for p in _read_all("Providers"):
        candidate = p.get("name", "")
        if candidate is None:
            continue
        # A name cell may not be a string (e.g. a purely numeric name read
        # back as a number); compare its text form rather than failing on
        # a missing .lower() method.
        if str(candidate).lower() == wanted:
            return p
    return None
def add_provider(provider_data: dict) -> dict:
    """Assign an id and defaults to a provider record and append it to "Providers".

    The ``on_call`` column is created first if the sheet lacks it. The passed
    dict is modified in place: ``id`` is always overwritten with a freshly
    generated ``prov_NNN`` value, and the remaining defaults are applied only
    for keys that are absent.

    Args:
        provider_data: Provider fields supplied by the caller.

    Returns:
        The same dict object, now including ``id`` and the defaults.
    """
    _ensure_columns("Providers", ["on_call"])
    taken = [row["id"] for row in _read_all("Providers") if "id" in row]
    provider_data["id"] = _generate_id("prov", taken)

    # Built per call so the list defaults are never shared between records.
    defaults = {
        "active": True,
        "days_off": [],
        "recurring_days_off": [],
        "on_call": [],
        "lat": None,
        "lng": None,
    }
    for field, fallback in defaults.items():
        provider_data.setdefault(field, fallback)

    _append_row("Providers", provider_data)
    return provider_data
def update_provider(provider_id: str, updates: dict) -> dict | None:
    """Apply *updates* to the provider with the given id and write the row back.

    The ``on_call`` column is created first when *updates* touches it and the
    sheet lacks it.

    The row number and the record are taken from a single read. Looking them
    up in two separate reads (as before) allowed them to disagree if the cache
    was refreshed in between, which could write the record to the wrong row.

    Args:
        provider_id: Value of the ``id`` column identifying the provider.
        updates: Fields to overwrite or add.

    Returns:
        The merged record that was written, or ``None`` if no provider has
        that id.
    """
    if "on_call" in updates:
        _ensure_columns("Providers", ["on_call"])
    # Sheet rows are 1-based and row 1 is the header, so records start at row 2.
    for row_idx, record in enumerate(_read_all("Providers"), start=2):
        if record.get("id") == provider_id:
            # Merge into a copy so the row obtained from _read_all is left
            # untouched if the sheet write below fails.
            merged = dict(record)
            merged.update(updates)
            _update_row("Providers", row_idx, merged)
            return merged
    return None
def delete_provider(provider_id: str) -> bool:
    """Delete the sheet row of the provider with the given id.

    Returns:
        ``True`` when a row was found and deleted, ``False`` when no provider
        has that id.
    """
    target_row = _find_row_index("Providers", "id", provider_id)
    found = target_row is not None
    if found:
        _get_worksheet("Providers").delete_rows(target_row)
        _invalidate_cache("Providers")
    return found
# --- Patients ---
def get_all_patients() -> list[dict]:
    """Return every record of the "Patients" worksheet."""
    patients = _read_all("Patients")
    return patients
def get_patient_by_id(patient_id: str) -> dict | None:
    """Return the first patient whose ``id`` equals *patient_id*, or ``None``."""
    matches = (row for row in _read_all("Patients") if row.get("id") == patient_id)
    return next(matches, None)
def add_patient(patient_data: dict) -> dict:
    """Assign an id and defaults to a patient record and append it to "Patients".

    The passed dict is modified in place: ``id`` is always overwritten with a
    freshly generated ``pat_NNN`` value; ``lat``, ``lng`` and ``notes`` are
    filled in only when absent.

    Args:
        patient_data: Patient fields supplied by the caller.

    Returns:
        The same dict object, now including ``id`` and the defaults.
    """
    taken = [row["id"] for row in _read_all("Patients") if "id" in row]
    patient_data["id"] = _generate_id("pat", taken)
    for field, fallback in (("lat", None), ("lng", None), ("notes", "")):
        patient_data.setdefault(field, fallback)
    _append_row("Patients", patient_data)
    return patient_data
def update_patient(patient_id: str, updates: dict) -> dict | None:
    """Apply *updates* to the patient with the given id and write the row back.

    The row number and the record are taken from a single read. Looking them
    up in two separate reads (as before) allowed them to disagree if the cache
    was refreshed in between, which could write the record to the wrong row.

    Args:
        patient_id: Value of the ``id`` column identifying the patient.
        updates: Fields to overwrite or add.

    Returns:
        The merged record that was written, or ``None`` if no patient has
        that id.
    """
    # Sheet rows are 1-based and row 1 is the header, so records start at row 2.
    for row_idx, record in enumerate(_read_all("Patients"), start=2):
        if record.get("id") == patient_id:
            # Merge into a copy so the row obtained from _read_all is left
            # untouched if the sheet write below fails.
            merged = dict(record)
            merged.update(updates)
            _update_row("Patients", row_idx, merged)
            return merged
    return None
def delete_patient(patient_id: str) -> bool:
    """Delete the sheet row of the patient with the given id.

    Returns:
        ``True`` when a row was found and deleted, ``False`` when no patient
        has that id.
    """
    target_row = _find_row_index("Patients", "id", patient_id)
    found = target_row is not None
    if found:
        _get_worksheet("Patients").delete_rows(target_row)
        _invalidate_cache("Patients")
    return found
# --- Appointments ---
def get_all_appointments(status: str | None = None) -> list[dict]:
    """Return appointment records, optionally restricted to one status.

    Args:
        status: When given (and non-empty), only appointments whose ``status``
            equals it are returned; otherwise all appointments are returned.
    """
    rows = _read_all("Appointments")
    if not status:
        return rows
    return [row for row in rows if row.get("status") == status]
def get_appointments_for_provider(provider_id: str, date: str | None = None) -> list[dict]:
    """Return a provider's appointments that have status ``"scheduled"``.

    Args:
        provider_id: Value the ``provider_id`` field must equal.
        date: When given (and non-empty), additionally require the ``date``
            field to equal it.
    """

    def _wanted(appt):
        is_scheduled_for_provider = (
            appt.get("provider_id") == provider_id and appt.get("status") == "scheduled"
        )
        if not is_scheduled_for_provider:
            return False
        # No date filter requested, or the appointment falls on that date.
        return not date or appt.get("date") == date

    return [appt for appt in _read_all("Appointments") if _wanted(appt)]
def get_appointments_for_patient(patient_id: str) -> list[dict]:
    """Return a patient's appointments that have status ``"scheduled"``."""
    selected = []
    for appt in _read_all("Appointments"):
        if appt.get("patient_id") == patient_id and appt.get("status") == "scheduled":
            selected.append(appt)
    return selected
def add_appointment(appointment_data: dict) -> dict:
    """Stamp an appointment record and append it to "Appointments".

    The passed dict is modified in place: ``id`` is overwritten with a freshly
    generated ``appt_NNN`` value, ``created_at`` is overwritten with the
    current local time in ISO format, and ``status`` defaults to
    ``"scheduled"`` when absent.

    Args:
        appointment_data: Appointment fields supplied by the caller.

    Returns:
        The same dict object after the additions.
    """
    taken = [row["id"] for row in _read_all("Appointments") if "id" in row]
    new_id = _generate_id("appt", taken)
    appointment_data["id"] = new_id
    appointment_data["created_at"] = datetime.now().isoformat()
    if "status" not in appointment_data:
        appointment_data["status"] = "scheduled"
    _append_row("Appointments", appointment_data)
    return appointment_data
def update_appointment(appointment_id: str, updates: dict) -> dict | None:
    """Apply *updates* to the appointment with the given id and write the row back.

    The row number and the record are taken from a single read. Looking them
    up in two separate reads (as before) allowed them to disagree if the cache
    was refreshed in between, which could write the record to the wrong row.

    Args:
        appointment_id: Value of the ``id`` column identifying the appointment.
        updates: Fields to overwrite or add.

    Returns:
        The merged record that was written, or ``None`` if no appointment has
        that id.
    """
    # Sheet rows are 1-based and row 1 is the header, so records start at row 2.
    for row_idx, record in enumerate(_read_all("Appointments"), start=2):
        if record.get("id") == appointment_id:
            # Merge into a copy so the row obtained from _read_all is left
            # untouched if the sheet write below fails.
            merged = dict(record)
            merged.update(updates)
            _update_row("Appointments", row_idx, merged)
            return merged
    return None
def cancel_appointment(appointment_id: str) -> dict | None:
    """Mark an appointment as cancelled and record when that happened.

    Returns:
        Whatever ``update_appointment`` returns: the updated record, or
        ``None`` when the id is unknown.
    """
    changes = {
        "status": "cancelled",
        "cancelled_at": datetime.now().isoformat(),
    }
    return update_appointment(appointment_id, changes)
def complete_appointment(appointment_id: str) -> dict | None:
    """Mark an appointment as completed and record when that happened.

    Returns:
        Whatever ``update_appointment`` returns: the updated record, or
        ``None`` when the id is unknown.
    """
    changes = {
        "status": "completed",
        "completed_at": datetime.now().isoformat(),
    }
    return update_appointment(appointment_id, changes)
# --- Audit Log ---
def log_action(user: str, action: str, details: str, appointment_id: str | None = None) -> None:
    """Append one entry to the "AuditLog" worksheet.

    Args:
        user: Stored in the ``user`` field.
        action: Stored in the ``action`` field.
        details: Stored in the ``details`` field.
        appointment_id: Stored in the ``appointment_id`` field; a missing or
            empty value is stored as ``""``.
    """
    stamped_at = datetime.now().isoformat()
    related_appointment = appointment_id if appointment_id else ""
    _append_row(
        "AuditLog",
        {
            "timestamp": stamped_at,
            "user": user,
            "action": action,
            "details": details,
            "appointment_id": related_appointment,
        },
    )
def get_audit_log(limit: int = 100) -> list[dict]:
    """Return audit-log entries ordered by ``timestamp``, newest first.

    Args:
        limit: Slice bound applied to the sorted list (``entries[:limit]``).
    """

    def _stamp(entry):
        # Entries without a timestamp sort as the empty string.
        return entry.get("timestamp", "")

    newest_first = sorted(_read_all("AuditLog"), key=_stamp, reverse=True)
    return newest_first[:limit]