Spaces:
Running
Running
| """ | |
| SNF OT Daily Progress Note Generator — FastAPI Backend | |
| Handles Groq AI integration for converting structured selections into | |
| Med A-compliant clinical paragraphs. | |
| """ | |
| import os | |
| import re | |
| import json | |
| import logging | |
| import hashlib | |
| import hmac | |
| import time | |
| import secrets | |
| from typing import Any, Dict, List, Optional | |
| from fastapi import FastAPI, HTTPException, Request, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.middleware.trustedhost import TrustedHostMiddleware | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from pydantic import BaseModel, Field, validator | |
| from dotenv import load_dotenv | |
| from groq import Groq | |
| import httpx | |
| from slowapi import Limiter | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| from fastapi.responses import JSONResponse | |
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Groq API keys - several may be supplied (comma-separated) so a request can
# fall back to the next key when one is rate-limited. GROQ_API_KEY is read
# only when GROQ_API_KEYS is not set.
_raw_keys = os.getenv("GROQ_API_KEYS", os.getenv("GROQ_API_KEY", ""))
GROQ_API_KEYS = [k.strip() for k in _raw_keys.split(",") if k.strip()]

# Origins accepted by the CORS middleware. Entries are stripped and blank
# entries dropped so a value such as "https://a, https://b" still matches
# the Origin header exactly.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv(
        "ALLOWED_ORIGINS",
        "http://localhost:5173,http://localhost:3000,https://jashdoshi77-ot.hf.space",
    ).split(",")
    if origin.strip()
]

GROQ_MODEL = "llama-3.3-70b-versatile"

# JWT secret. An unset OR empty JWT_SECRET falls back to a random per-process
# secret, so tokens are never signed with an empty key. Tokens signed with the
# fallback stop verifying after a restart, hence the warning.
JWT_SECRET = os.getenv("JWT_SECRET") or secrets.token_hex(32)
if not os.getenv("JWT_SECRET"):
    logger.warning(
        "JWT_SECRET is not set - using a random per-process secret; "
        "issued tokens will not survive a restart"
    )
JWT_EXPIRY_HOURS = 24

logger.info(f"Loaded {len(GROQ_API_KEYS)} Groq API key(s)")
| # --------------------------------------------------------------------------- | |
| # Authorized users — loaded from env (format: user1:pass1,user2:pass2) | |
| # --------------------------------------------------------------------------- | |
| def _hash_pw(pw: str) -> str: | |
| return hashlib.sha256(pw.encode()).hexdigest() | |
| def _load_users() -> dict: | |
| raw = os.getenv("AUTH_USERS", "") | |
| users = {} | |
| for pair in raw.split(","): | |
| pair = pair.strip() | |
| if ":" in pair: | |
| uname, pw = pair.split(":", 1) | |
| users[uname.strip().lower()] = _hash_pw(pw.strip()) | |
| return users | |
# Username -> password-hash table, built once at import time.
AUTHORIZED_USERS = _load_users()
logger.info("Loaded %d authorized user(s)", len(AUTHORIZED_USERS))
| # --------------------------------------------------------------------------- | |
| # JWT helpers (lightweight, no external dependency) | |
| # --------------------------------------------------------------------------- | |
| import base64, json as _json | |
| def _b64url_encode(data: bytes) -> str: | |
| return base64.urlsafe_b64encode(data).rstrip(b"=").decode() | |
| def _b64url_decode(s: str) -> bytes: | |
| s += "=" * (4 - len(s) % 4) | |
| return base64.urlsafe_b64decode(s) | |
def create_jwt(username: str) -> str:
    """Create a signed HS256 JWT for *username*.

    The payload carries ``sub`` (the username), ``iat`` (issue time) and
    ``exp`` (issue time plus ``JWT_EXPIRY_HOURS``). The clock is read once so
    ``exp - iat`` always equals the configured lifetime.

    Returns:
        The compact ``header.payload.signature`` token string.
    """
    issued_at = int(time.time())
    header = _b64url_encode(_json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = _b64url_encode(_json.dumps({
        "sub": username,
        "iat": issued_at,
        "exp": issued_at + JWT_EXPIRY_HOURS * 3600,
    }).encode())
    signing_input = f"{header}.{payload}"
    signature = hmac.new(JWT_SECRET.encode(), signing_input.encode(), hashlib.sha256).digest()
    return f"{signing_input}.{_b64url_encode(signature)}"
def verify_jwt(token: str) -> dict:
    """Verify an HS256 token produced by ``create_jwt`` and return its payload.

    The signature is checked (constant-time comparison) before the payload is
    decoded, and a payload whose ``exp`` is missing or in the past is rejected.

    Raises:
        ValueError: for any malformed, tampered or expired token. The message
            is prefixed with "Token verification failed:" and the underlying
            exception is attached as ``__cause__``.
    """
    try:
        parts = token.split(".")
        if len(parts) != 3:
            raise ValueError("Invalid token")
        header_b64, payload_b64, sig_b64 = parts
        expected_sig = hmac.new(
            JWT_SECRET.encode(),
            f"{header_b64}.{payload_b64}".encode(),
            hashlib.sha256,
        ).digest()
        actual_sig = _b64url_decode(sig_b64)
        if not hmac.compare_digest(expected_sig, actual_sig):
            raise ValueError("Invalid signature")
        payload = _json.loads(_b64url_decode(payload_b64))
        # A missing "exp" defaults to 0 and is therefore treated as expired.
        if payload.get("exp", 0) < time.time():
            raise ValueError("Token expired")
        return payload
    except Exception as e:
        # Collapse every failure mode (bad base64, bad JSON, wrong types, ...)
        # into one ValueError so callers only have a single type to handle.
        raise ValueError(f"Token verification failed: {e}") from e
| # --------------------------------------------------------------------------- | |
| # Auth dependency | |
| # --------------------------------------------------------------------------- | |
bearer_scheme = HTTPBearer()


async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Resolve the bearer token of the current request to an authorized username.

    Returns:
        The token's ``sub`` claim when it names a key of ``AUTHORIZED_USERS``.

    Raises:
        HTTPException: 401 when the token fails verification or its subject
            is not an authorized user.
    """
    try:
        claims = verify_jwt(credentials.credentials)
    except ValueError as e:
        raise HTTPException(status_code=401, detail=str(e))

    subject = claims.get("sub")
    if subject in AUTHORIZED_USERS:
        return subject
    raise HTTPException(status_code=401, detail="User not authorized")
| # --------------------------------------------------------------------------- | |
| # Rate limiting | |
| # --------------------------------------------------------------------------- | |
# slowapi limiter; request keys come from slowapi's get_remote_address helper.
# NOTE(review): no @limiter.limit decorators are visible in this view - confirm
# that the limits are actually applied to the routes.
limiter = Limiter(key_func=get_remote_address)
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(
    title="OT Note Generator API",
    version="1.0.0",
    docs_url="/api/docs",  # interactive docs live under /api/docs
    redoc_url=None,  # ReDoc page is turned off
)
# Expose the limiter on application state.
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Turn slowapi's RateLimitExceeded into a JSON 429 response.

    Registered on ``app`` so the handler is actually invoked; without the
    registration the function is defined but never used.
    """
    return JSONResponse(
        status_code=429,
        content={"detail": "Rate limit exceeded. Please wait before trying again."},
    )
# Security middleware
# CORS: only the configured origins may call the API from a browser, and only
# with GET/POST and the Content-Type / Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "Authorization"],
    max_age=600,  # seconds a preflight response may be cached
)
# Host header allow-list: local development hosts plus Hugging Face domains.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["localhost", "127.0.0.1", "*.localhost", "*.hf.space", "*.huggingface.co"],
)
| # --------------------------------------------------------------------------- | |
| # Pydantic models | |
| # --------------------------------------------------------------------------- | |
# Upper bounds used by the request models below.
MAX_SECTIONS = 20  # referenced by GenerateRequest.validate_sections_count
MAX_FIELDS_PER_SECTION = 200  # referenced by SectionData.validate_fields_count
MAX_TEXT_LENGTH = 2000  # max_length of FieldSelection.customText
class FieldSelection(BaseModel):
    """One template field together with the values chosen for it."""

    fieldId: str = Field(..., max_length=200)
    fieldLabel: str = Field(..., max_length=500)
    selectedOptions: List[str] = Field(default_factory=list)
    customText: Optional[str] = Field(None, max_length=MAX_TEXT_LENGTH)
    numericValue: Optional[str] = Field(None, max_length=50)

    # Registered per item of selectedOptions; without the decorator the
    # method is a plain function and the length limit is never enforced.
    @validator("selectedOptions", each_item=True)
    def validate_option_length(cls, v):
        """Reject any single selected option longer than 500 characters."""
        if len(v) > 500:
            raise ValueError("Option text too long")
        return v
class SectionData(BaseModel):
    """One note section: its identity, enabled flag and field selections."""

    sectionId: str = Field(..., max_length=200)
    sectionTitle: str = Field(..., max_length=300)
    enabled: bool = True
    fields: List[FieldSelection] = Field(default_factory=list)

    # Registered on "fields"; without the decorator the cap on the number of
    # fields per section is never applied.
    @validator("fields")
    def validate_fields_count(cls, v):
        """Reject sections carrying more than MAX_FIELDS_PER_SECTION fields."""
        if len(v) > MAX_FIELDS_PER_SECTION:
            raise ValueError(f"Too many fields (max {MAX_FIELDS_PER_SECTION})")
        return v
class GenerateRequest(BaseModel):
    """Request body for note generation: patient info plus section selections."""

    patientInfo: Dict[str, str] = Field(default_factory=dict)
    sections: List[SectionData] = Field(default_factory=list)
    customSections: List[SectionData] = Field(default_factory=list)

    # Registered on both section lists so neither can exceed MAX_SECTIONS;
    # without the decorator the cap is never applied.
    # NOTE(review): applying the cap to customSections as well is the
    # conservative reading - confirm against the intended limits.
    @validator("sections", "customSections")
    def validate_sections_count(cls, v):
        """Reject a section list longer than MAX_SECTIONS."""
        if len(v) > MAX_SECTIONS:
            raise ValueError(f"Too many sections (max {MAX_SECTIONS})")
        return v

    @validator("patientInfo")
    def sanitize_patient_info(cls, v):
        """Return patient info with cleaned, length-limited keys and values.

        Keys keep only word characters, whitespace, '-' and '/', cut to 100
        characters; values lose '<' and '>' and are cut to 500 characters.
        """
        sanitized = {}
        for key, value in v.items():
            clean_key = re.sub(r"[^\w\s\-/]", "", key)[:100]
            clean_val = re.sub(r"[<>]", "", str(value))[:500]
            sanitized[clean_key] = clean_val
        return sanitized
| # --------------------------------------------------------------------------- | |
| # System prompt for Groq | |
| # --------------------------------------------------------------------------- | |
| SYSTEM_PROMPT = """You are an expert Occupational Therapy clinical documentation specialist. Your job is to convert structured selections from an OT daily progress note template into polished, professional, Med A-compliant clinical narrative paragraphs. | |
| RULES: | |
| 1. Write in the EXACT style of skilled nursing facility (SNF) occupational therapy daily treatment encounter notes. | |
| 2. Use proper OT abbreviations: pt (patient), BUE/BLE (bilateral upper/lower extremity), AROM (active ROM), ADL, IADL, LBD/UBD (lower/upper body dressing), FM (fine motor), GM (gross motor), vc (verbal cue), tc (tactile cue), hoh (hand-over-hand), SBA (standby assist), CGA (contact guard assist), Min/Mod/Max A (assist levels), ec (energy conservation), JP (joint protection), STS (sit to stand), w/c (wheelchair), EOB (edge of bed), WFL (within functional limits), 2 degrees (secondary to), d/c (discharge), HEP (home exercise program), etc. | |
| 3. IMPORTANT: You MUST write a SEPARATE section with its own header for EVERY section provided in the input. This includes non-CPT sections like Patient Presentation, Education Provided, Response to Treatment, Assessment, Therapist Signature, and any other section. Do NOT merge or absorb non-CPT sections into CPT sections. | |
| 4. DO NOT use asterisks (*), stars, bold markers, or any markdown formatting whatsoever. Output must be clean plain text only. | |
| 5. DO NOT use em dashes, en dashes, or any special unicode characters. Use regular hyphens (-) or commas instead. | |
| 6. Write flowing narrative paragraphs, NOT bullet points. | |
| 7. Include clinical justification language that supports medical necessity. | |
| 8. Maintain a professional, skilled, clinical tone throughout. | |
| 9. If exercise details (sets/reps) are provided, weave them naturally into sentences. | |
| 10. Connect deficits to functional impacts and goals. | |
| 11. Include cueing types and assist levels naturally in the narrative. | |
| 12. Do NOT invent information not provided in the selections. | |
| 13. If a section has no selections, skip it entirely. | |
| 14. Format each section header as plain text followed by a colon. For CPT sections include the code, e.g. "Therapeutic Exercise (97110):". For non-CPT sections just use the name, e.g. "Patient Presentation:" or "Assessment:". | |
| 15. Use regular quotation marks and standard ASCII punctuation only. | |
| 16. The order of sections in the output should match the order they appear in the input. | |
| EXAMPLE OUTPUT STYLE: | |
| Patient Presentation: | |
| Pt was seen seated in w/c upon arrival, presenting with an initially reluctant demeanor; however, with skilled encouragement and verbal cueing, pt became engaged and receptive to treatment. Pt c/o pain at 6/10 localized to the right shoulder, which was managed with positioning and activity modification. Pt was alert and oriented x3 (person, place, time). | |
| Therapeutic Exercise (97110): | |
| Pt participated in skilled therapeutic exercise targeting BUE strength and AROM in seated position with back support to improve functional performance with ADLs, transfers, and mobility. Exercises were graded by therapist to address grip strength and core stability that directly impact pt's safe and independent functional task performance. Exercises performed today included shoulder flexion (3 x 10), shoulder abduction (3 x 10), and grip strengthening via putty squeeze. Pt required vc/tc for proper joint alignment and controlled movement speed throughout exercise program. Pt demonstrated fair tolerance to therapeutic exercise, completing 20 min of activity with intermittent rest breaks (1-2) secondary to fatigue and decreased endurance. | |
| Therapeutic Activities (97530): | |
| Pt engaged in skilled therapeutic activities addressing static and dynamic standing balance in alignment with established OT goals and POC. Pt participated in dynamic standing ball toss to targets and ring toss on cone (3 x 10 each) to improve dynamic balance and postural reactions. Pt required CGA for balance and safety, as well as vc for safety awareness and fall prevention. Pt demonstrated fair tolerance to therapeutic activities, with performance impacted by decreased endurance and impaired postural control. | |
| Self-Care / ADL Training (97535): | |
| Pt engaged in skilled ADL retraining focused on UBD and LBD and toileting to support goal of increased functional independence. Pt performed LBD donning with Min A, doffing with SBA. AE utilized: reacher and sock aid. Pt was able to thread bilateral LE through garment independently and pull clothing to knee level. | |
| Education Provided: | |
| Pt and/or caregiver educated on ec strategies for ADL participation, safe transfer techniques, and HEP compliance. Pt demonstrated fair understanding via return demonstration with minimal cueing required. | |
| Response to Treatment: | |
| Pt demonstrated fair participation with fatigue and decreased endurance, requiring frequent rest breaks and cueing for task completion, but was able to complete all graded activities with assistance. Pt demonstrated emerging carryover of compensatory strategies with minimal verbal cueing. | |
| Assessment: | |
| Pt presents with deficits in bed mobility, strength, activity tolerance, and ADL sequencing, impacting independence in self-care tasks. Skilled OT remains necessary for functional mobility training, ADL retraining, balance, and fine motor coordination to improve safety and independence. | |
| """ | |
| # --------------------------------------------------------------------------- | |
| # Helper: build user message from selections | |
| # --------------------------------------------------------------------------- | |
def build_user_message(data: GenerateRequest) -> str:
    """Render the request's selections as the plain-text user prompt.

    Output is an optional "PATIENT INFO:" list of non-empty entries followed
    by one "--- title ---" chunk per enabled section (regular sections first,
    then custom ones) that has at least one populated field. Each populated
    field becomes "label: v1 / v2 / ...", listing selected options, then
    custom text, then the numeric value. Chunks are joined with newlines; an
    empty string is returned when nothing is populated.
    """

    def _values_of(field):
        # Gather the populated values of one field, in output order.
        collected = list(field.selectedOptions) if field.selectedOptions else []
        if field.customText:
            collected.append(field.customText)
        if field.numericValue:
            collected.append(field.numericValue)
        return collected

    chunks: List[str] = []

    # Patient info
    if data.patientInfo:
        patient_lines = [f"- {k}: {v}" for k, v in data.patientInfo.items() if v]
        if patient_lines:
            chunks.append("PATIENT INFO:\n" + "\n".join(patient_lines))

    # Sections
    for section in (*data.sections, *data.customSections):
        if not section.enabled:
            continue
        field_lines = []
        for field in section.fields:
            values = _values_of(field)
            if values:
                field_lines.append(f" {field.fieldLabel}: " + " / ".join(values))
        if field_lines:
            header = f"\n--- {section.sectionTitle} ---"
            chunks.append("\n".join([header, *field_lines]))

    return "\n".join(chunks)
| # --------------------------------------------------------------------------- | |
| # Routes | |
| # --------------------------------------------------------------------------- | |
| # --------------------------------------------------------------------------- | |
| # Auth models & login route | |
| # --------------------------------------------------------------------------- | |
class LoginRequest(BaseModel):
    # Credentials submitted to the login handler; both are length-capped.
    username: str = Field(..., max_length=100)
    password: str = Field(..., max_length=200)
# NOTE(review): no route decorator is visible on this function in this view -
# confirm that it is registered on `app`.
async def login(request: Request, data: LoginRequest):
    """Check a username/password pair and issue a JWT on success.

    The username is stripped and lower-cased before lookup, matching how
    ``AUTHORIZED_USERS`` keys are built. Hashes are compared in constant time.

    Returns:
        ``{"token": <jwt>, "username": <normalised username>}``.

    Raises:
        HTTPException: 401 when the user is unknown or the password is wrong.
    """
    username = data.username.strip().lower()
    password_hash = _hash_pw(data.password)
    stored_hash = AUTHORIZED_USERS.get(username)
    if not stored_hash or not hmac.compare_digest(stored_hash, password_hash):
        # %r escapes control characters, so a crafted username cannot forge
        # extra log lines.
        logger.warning("Failed login attempt for user: %r", username)
        raise HTTPException(status_code=401, detail="Invalid username or password")
    token = create_jwt(username)
    # Here the username is a known AUTHORIZED_USERS key, so it is safe to log.
    logger.info("User '%s' logged in successfully", username)
    return {"token": token, "username": username}
# NOTE(review): no route decorator is visible on this function in this view -
# confirm that it is registered on `app`.
async def verify_token(user: str = Depends(get_current_user)):
    """Report that the caller's token is valid and echo the resolved username.

    Invalid tokens never reach this body: the get_current_user dependency
    raises a 401 first.
    """
    return {"valid": True, "username": user}
# Load template data - encrypted at rest, decrypted at startup using TEMPLATE_KEY
_TEMPLATE_DATA = None
_base_dir = os.path.dirname(__file__)
_enc_path = os.path.join(_base_dir, "template_data.enc")
_json_path = os.path.join(_base_dir, "template_data.json")
try:
    if os.path.exists(_enc_path):
        # Production: decrypt the encrypted template
        from cryptography.fernet import Fernet

        _template_key = os.getenv("TEMPLATE_KEY")
        if not _template_key:
            raise RuntimeError("TEMPLATE_KEY env variable is required to decrypt template data")
        _fernet = Fernet(_template_key.encode())
        with open(_enc_path, "rb") as f:
            _encrypted = f.read()
        _decrypted = _fernet.decrypt(_encrypted)
        _TEMPLATE_DATA = json.loads(_decrypted.decode("utf-8"))
        logger.info(f"Loaded encrypted template data ({len(_decrypted) // 1024} KB)")
    elif os.path.exists(_json_path):
        # Local dev fallback: load plain JSON
        with open(_json_path, "r", encoding="utf-8") as f:
            _TEMPLATE_DATA = json.load(f)
        logger.info(f"Loaded template data from JSON ({os.path.getsize(_json_path) // 1024} KB)")
    else:
        logger.warning("No template data file found — /api/template will be unavailable")
except Exception as e:
    # Best-effort load: the app still starts without template data, but the
    # traceback is logged so a bad key or corrupt file can be diagnosed.
    logger.exception("Failed to load template data: %s", e)
    _TEMPLATE_DATA = None
# NOTE(review): no route decorator is visible on this function in this view -
# confirm that it is registered on `app`.
async def get_template(user: str = Depends(get_current_user)):
    """Serves the template structure. Requires authentication.

    Returns the template data loaded at import time.

    Raises:
        HTTPException: 500 when no template data could be loaded.
    """
    if _TEMPLATE_DATA is None:
        raise HTTPException(status_code=500, detail="Template data not available")
    return _TEMPLATE_DATA
# NOTE(review): no route decorator is visible on this function in this view -
# confirm that it is registered on `app`.
async def health_check():
    """Return a static status payload that includes the configured model name."""
    return {"status": "ok", "model": GROQ_MODEL}
# NOTE(review): no route or rate-limit decorator is visible on this function in
# this view - confirm that it is registered on `app`.
async def generate_note(request: Request, data: GenerateRequest, user: str = Depends(get_current_user)):
    """Generate the clinical narrative for the submitted selections via Groq.

    The configured API keys are tried in order. An error that looks like a
    rate-limit or authentication failure moves on to the next key; any other
    error aborts immediately.

    Returns:
        ``{"success": True, "note": <text>, "model": <model>, "usage": {...}}``.

    Raises:
        HTTPException: 500 when no API key is configured, 400 when the
            request has no populated selections, 502 when generation fails
            or every key is exhausted.
    """
    # Local import so the module's import block does not need to change.
    from fastapi.concurrency import run_in_threadpool

    if not GROQ_API_KEYS:
        raise HTTPException(
            status_code=500,
            detail="No Groq API keys configured. Please set GROQ_API_KEYS in .env file.",
        )
    user_message = build_user_message(data)
    if not user_message.strip():
        raise HTTPException(
            status_code=400,
            detail="No selections provided. Please fill in at least one section.",
        )
    logger.info(f"Generating note for '{user}' - payload size: {len(user_message)} chars")
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Convert the following OT daily note selections into a polished, professional Med A-compliant clinical narrative:\n\n{user_message}",
        },
    ]
    last_error = None
    total_keys = len(GROQ_API_KEYS)
    # Try each API key in order - rotate on rate limit or auth errors
    for i, api_key in enumerate(GROQ_API_KEYS):
        key_label = f"key #{i + 1}/{total_keys}"
        try:
            logger.info(f"Attempting generation with {key_label}")
            # The context manager closes the HTTP client (and its connection
            # pool) on every path, including failures.
            with httpx.Client() as http_client:
                client = Groq(api_key=api_key, http_client=http_client)
                # The Groq client is synchronous; run it in a worker thread so
                # the event loop keeps serving other requests meanwhile.
                chat_completion = await run_in_threadpool(
                    client.chat.completions.create,
                    messages=messages,
                    model=GROQ_MODEL,
                    temperature=0.7,
                    max_tokens=4096,
                    top_p=0.9,
                )
            generated_text = chat_completion.choices[0].message.content
            logger.info(f"Generation succeeded with {key_label}")
            return {
                "success": True,
                "note": generated_text,
                "model": GROQ_MODEL,
                "usage": {
                    "prompt_tokens": chat_completion.usage.prompt_tokens,
                    "completion_tokens": chat_completion.usage.completion_tokens,
                },
            }
        except Exception as e:
            last_error = str(e)
            error_lower = last_error.lower()
            # If it's a rate limit or auth error, try the next key
            if "rate_limit" in error_lower or "429" in error_lower or "quota" in error_lower or "limit" in error_lower:
                logger.warning(f"{key_label} hit rate limit, trying next key...")
                continue
            elif "401" in error_lower or "invalid" in error_lower or "auth" in error_lower:
                logger.warning(f"{key_label} auth error, trying next key...")
                continue
            else:
                # Non-recoverable error - don't try other keys
                logger.error(f"Groq API error (non-recoverable): {last_error}")
                raise HTTPException(status_code=502, detail=f"AI generation failed: {last_error}") from e
    # All keys exhausted
    logger.error(f"All {len(GROQ_API_KEYS)} API keys exhausted. Last error: {last_error}")
    raise HTTPException(
        status_code=502,
        detail="All API keys have been rate-limited. Please wait a moment and try again.",
    )
| # --------------------------------------------------------------------------- | |
| # Static file serving (production — serves built React app) | |
| # --------------------------------------------------------------------------- | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
_static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.isdir(_static_dir):
    # Serve static assets (JS, CSS, images)
    app.mount("/assets", StaticFiles(directory=os.path.join(_static_dir, "assets")), name="assets")

    # SPA fallback - serve index.html for all non-API routes
    # NOTE(review): no route decorator is visible on this function in this
    # view - confirm that it is registered on `app`.
    async def serve_spa(path: str):
        """Serve a file from the static directory, else the SPA's index.html.

        The requested path is resolved and must stay inside the static
        directory. Anything that escapes it ("..", an absolute path, a
        symlink pointing outside) falls back to index.html instead of being
        read from disk.
        """
        static_root = os.path.realpath(_static_dir)
        candidate = os.path.realpath(os.path.join(static_root, path))
        inside_root = candidate == static_root or candidate.startswith(static_root + os.sep)
        # If the file exists in static dir, serve it
        if inside_root and os.path.isfile(candidate):
            return FileResponse(candidate)
        # Otherwise serve index.html (SPA routing)
        return FileResponse(os.path.join(_static_dir, "index.html"))

    logger.info(f"Serving frontend from {_static_dir}")
else:
    logger.info("No static directory found — running in API-only mode")
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    import uvicorn

    # Listening port comes from PORT, defaulting to 7860.
    port = int(os.getenv("PORT", "7860"))
    # The app is referenced by import string ("main:app"), so this file is
    # expected to be importable as the module `main`.
    # NOTE(review): reload=True turns on uvicorn's auto-reloader, which is
    # normally a development-only setting - confirm it is intended here.
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)