Spaces:
Sleeping
Sleeping
| import os | |
| from ast import List | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import io | |
| import traceback | |
| import pandas as pd | |
| import logging | |
| import base64 | |
| import json | |
| import re | |
| import asyncio | |
| import functools | |
| from typing import Any, Optional | |
| from datetime import datetime | |
| import uvicorn | |
| import google.generativeai as genai | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, APIRouter, Request | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import firebase_admin | |
| from firebase_admin import credentials, firestore | |
| from google.generativeai import generative_models | |
| from pydantic import BaseModel | |
| from past_reports import router as reports_router, db_fetch_reports, db_insert_report, db_get_report | |
# SECURITY: an API key committed to source control must be treated as leaked.
# The GEMINI_API_KEY environment variable takes precedence so a rotated key can
# be deployed without a code change; the literal is kept only as a
# backward-compatible fallback for deployments that do not set the variable.
# TODO(security): rotate this key and delete the literal fallback.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or "AIzaSyAK0HJWN-WLuG5BxkHawu6_qFpcXU71cT0"

# Module-level logger; basicConfig sets the root logger level to INFO.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class Config:
    """Static runtime configuration for the server."""

    # Read by main() to choose between the auto-reload launch and the plain one.
    DEBUG = True
# ASGI application object; main() refers to it as "main:app" when launching
# uvicorn in debug mode.
app = FastAPI()
# Router for endpoints under the /api prefix.
# NOTE(review): the router is included immediately, before any route is
# attached to it in the visible code. Routes added to a router after
# include_router() are presumably not registered on the app - confirm against
# the FastAPI documentation and the complete file.
api = APIRouter(prefix="/api")
app.include_router(api)
# OCR text of the most recently analyzed upload: analyze_endpoint overwrites
# it and chat_endpoint prepends it to the chat context.
# NOTE(review): this is one process-wide string, not keyed by user, so text
# from one user's report can be placed in another user's chat prompt.
EXTRACTED_TEXT_CACHE = ""
# Serve the static front end from the local "web" directory under /app.
app.mount("/app", StaticFiles(directory="web", html=True), name="web")
# Routes defined in the past_reports module.
app.include_router(reports_router)
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive CORS policy - confirm it is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def root():
    """Redirect the caller to the static front end mounted at ``/app/``."""
    target = "/app/"
    return RedirectResponse(url=target)
class AnalyzeRequest(BaseModel):
    """JSON request body consumed by analyze_json."""

    # Base64-encoded image bytes; analyze_json decodes them with b64decode.
    image_base64: str
    # Optional prompt; analyze_image falls back to system_prompt when falsy.
    prompt: Optional[str] = None
# Configure the google.generativeai client with the module-level API key.
genai.configure(api_key=GEMINI_API_KEY)
# Sampling settings that analyze_image forwards to the model call.
generation_config = {
    "temperature": 0.1,
    "top_p": 0.8,
    "top_k": 20,
    "max_output_tokens": 4096,
}
# Safety thresholds that analyze_image forwards to the model call; every
# listed category uses the same BLOCK_MEDIUM_AND_ABOVE threshold.
safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]
class ChatRequest(BaseModel):
    """Request body consumed by chat_endpoint."""

    # Used by chat_endpoint to look up the user's past reports.
    user_id: Optional[str] = "anonymous"
    # Free-text question inserted into the chat prompt template.
    question: str
class ChatResponse(BaseModel):
    """Response body returned by chat_endpoint."""

    # Text of the model's reply.
    answer: str
class TextRequest(BaseModel):
    """Request body carrying a single text field.

    NOTE(review): not referenced anywhere in the visible part of this file.
    """

    text: str
# Default instruction prompt for analyze_image (used when the caller supplies
# no prompt). It asks the model for a JSON object with the keys "ocr_text",
# "measurements" and "analysis", which analyze_image then parses.
system_prompt = """You are a highly skilled medical practitioner specializing in medical image and document analysis. You will be given either a medical image or a PDF.
Your responsibilities are:
1. **Extract Text**: If the input is a PDF or image, first extract all the text content (lab values, notes, measurements, etc.). Do not summarize — keep the extracted text verbatim.
2. **Detailed Analysis**: Use both the extracted text and the visual features of the image to identify any anomalies, diseases, or health issues.
3. **Output Format**: You MUST return ONLY a valid JSON object with this EXACT structure (no additional text, no markdown, no code blocks):
{
"ocr_text": "<<<FULL VERBATIM TEXT FROM THE PDF/IMAGE>>>",
"measurements": [
{
"type": "HbA1c",
"value": 8.5,
"unit": "%",
"min": "4.0",
"max": "5.6",
"status": "HIGH",
"severity": "SEVERE"
},
{
"type": "Total Cholesterol",
"value": 280,
"unit": "mg/dL",
"min": "0",
"max": "200",
"status": "HIGH",
"severity": "SEVERE"
}
],
"analysis": [
{
"findings": "DIABETES. Elevated HbA1c indicates poor glucose control over past 2-3 months.",
"severity": "SEVERE",
"recommendations": ["Consult endocrinologist immediately", "Review medication regimen"],
"treatment_suggestions": ["Adjust insulin dosage", "Consider metformin"],
"home_care_guidance": ["Monitor blood sugar 4x daily", "Follow diabetic diet"]
}
]
}
4. **Measurement Extraction Rules**:
- Extract EVERY numerical health measurement found in the document
- Include lab values, vital signs, body measurements, test results
- For each measurement provide: type, value, unit, min, max, status, severity
- To provide the min and max, first check the document for a provided min or max, if not just use your AI knowledge to provide the min and max for that specific measurement type
- Status should be LOW, NORMAL, BORDER-LINE HIGH, and HIGH based on min and max.
5. **Finding Analysis**:
- Document all observed anomalies or diseases in the analysis section
- UPPERCASE the main concern in each finding
- Link findings to relevant measurements when applicable
- If a disease is family history or previously recovered, mark severity as: "severity of anomaly (Past Anomaly but Still Under Risk)"
- Provide actionable recommendations and treatment suggestions
CRITICAL: Return ONLY the JSON object. No explanatory text, no markdown formatting, no code blocks. Also make sure to check all your information twice before sending.
"""
# Prompt template for chat_endpoint. It is filled with str.format(), so the
# only brace placeholders are {document_text} and {user_question}.
system_prompt_chat = """
*** Role: Medical Chat Assistant ***
You are a concise and empathetic medical chatbot. Your job is to give clear, short answers (max 3-4 sentences) based only on the provided medical report text.
Rules:
- Avoid repeating the entire report; focus only on what is directly relevant to the user’s question.
- Give top 2 actionable steps if needed.
- If condition is serious, suggest consulting a doctor immediately.
- Always end with: "Check with your physician before acting."
Input:
Report Text: {document_text}
User Question: {user_question}
Response:
"""
# Shared model handle; _call_model_blocking and chat_endpoint both call
# model.generate_content on it.
model = genai.GenerativeModel(model_name="gemini-2.5-flash-lite")
async def _call_model_blocking(request_inputs, generation_cfg, safety_cfg):
    """Run the blocking ``model.generate_content`` call in the default executor.

    Args:
        request_inputs: Content parts passed positionally to
            ``generate_content``.
        generation_cfg: Forwarded as the ``generation_config`` keyword.
        safety_cfg: Forwarded as the ``safety_settings`` keyword.

    Returns:
        Whatever ``model.generate_content`` returns.
    """
    call = functools.partial(
        model.generate_content,
        request_inputs,
        generation_config=generation_cfg,
        safety_settings=safety_cfg,
    )
    # Inside a coroutine, get_running_loop() is the preferred way to obtain
    # the loop; it never creates one implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, call)
def extract_measurements_from_gemini_structured(measurements_data):
    """Normalize the model's raw ``measurements`` list into storage records.

    Args:
        measurements_data: Iterable of dict-like measurement entries. The
            measurement name is read from ``type`` (or ``measurement_type``).

    Returns:
        A list of dicts with the keys ``measurement_type``, ``value`` (always
        a float; 0.0 when the raw value is missing or falsy), ``unit``,
        ``min``, ``max``, ``status`` and ``severity``. Entries that cannot be
        processed (not dict-like, non-numeric value) are logged and skipped.
        An empty or missing input yields an empty list.
    """
    if not measurements_data:
        logger.warning("No measurements data provided")
        return []

    measurements = []
    for measurement in measurements_data:
        try:
            raw_value = measurement.get("value", 0)
            measurements.append({
                "measurement_type": measurement.get("type") or measurement.get("measurement_type", "Unknown"),
                "value": float(raw_value) if raw_value else 0.0,
                "unit": measurement.get("unit", ""),
                "min": measurement.get("min"),
                "max": measurement.get("max"),
                "status": measurement.get("status", "UNKNOWN"),
                "severity": measurement.get("severity", "UNKNOWN"),
            })
        except (AttributeError, TypeError, ValueError, OverflowError) as e:
            # AttributeError: entry is not dict-like; the others come from
            # float() on a non-numeric or out-of-range value.
            logger.error(f"Error processing measurement: {measurement}, error: {e}")
            continue
    return measurements
def _structured_fields(payload, *, strict):
    """Return ``(analysis, ocr_text, measurements)`` from a parsed payload.

    Returns None when ``payload`` is not a dict, or when ``strict`` is true
    and any of the three expected keys is missing.
    """
    if not isinstance(payload, dict):
        return None
    if strict and not all(key in payload for key in ("ocr_text", "measurements", "analysis")):
        return None
    # "or" also maps an explicit JSON null to the empty default.
    return (
        payload.get("analysis") or [],
        payload.get("ocr_text") or "",
        payload.get("measurements") or [],
    )


async def analyze_image(image_bytes: bytes, mime_type: str, prompt: Optional[str] = None) -> tuple:
    """Send an image or document to the model and parse its structured reply.

    Args:
        image_bytes: Raw file content; sent base64-encoded as inline data.
        mime_type: MIME type reported to the model for the inline data.
        prompt: Optional prompt; ``system_prompt`` is used when falsy.

    Returns:
        A tuple ``(analysis, ocr_text, measurements)``. When no structured
        JSON can be recovered, ``analysis`` is a one-element list describing
        the parse failure and the other two items are empty.

    Raises:
        RuntimeError: If the model call itself fails.
    """
    base64_img = base64.b64encode(image_bytes).decode("utf-8")
    text_prompt = (prompt or system_prompt).strip()
    request_inputs = [
        {"inline_data": {"mime_type": mime_type, "data": base64_img}},
        {"text": text_prompt},
    ]
    try:
        response = await _call_model_blocking(request_inputs, generation_config, safety_settings)
    except Exception as e:
        logger.error(f"Model call failed: {e}")
        raise RuntimeError(f"Model call failed: {e}") from e

    # Pull the reply text out of whichever response shape was returned.
    text = getattr(response, "text", None)
    if not text and isinstance(response, dict):
        candidates = response.get("candidates") or []
        if candidates:
            text = candidates[0].get("content") or candidates[0].get("text")
    if not text:
        text = str(response)
    if not isinstance(text, str):
        # Slicing and regex cleanup below require a string.
        text = str(text)
    logger.info(f"Raw Gemini response: {text[:500]}...")

    # Strip markdown code fences the model may have added despite the prompt.
    clean = re.sub(r'```(?:json)?\s*', '', text).strip()
    clean = re.sub(r'```\s*$', '', clean).strip()
    logger.info(f"Cleaned response: {clean[:500]}...")

    try:
        parsed = json.loads(clean)
    except json.JSONDecodeError as e:
        logger.error(f"Initial JSON decode error: {e}")
        # The reply is not pure JSON; look for an embedded object that
        # mentions the three expected keys.
        json_match = re.search(r'\{[\s\S]*"ocr_text"[\s\S]*"measurements"[\s\S]*"analysis"[\s\S]*\}', clean)
        if json_match:
            try:
                logger.info("Found structured JSON in response, attempting to parse...")
                fields = _structured_fields(json.loads(json_match.group(0)), strict=False)
            except json.JSONDecodeError as extract_err:
                logger.error(f"Failed to parse extracted JSON: {extract_err}")
            else:
                if fields is not None:
                    analysis, ocr_text, measurements = fields
                    logger.info(f"Successfully extracted structured data with {len(measurements)} measurements and {len(analysis)} analyses")
                    return analysis, ocr_text, measurements
    else:
        # Valid JSON: it may still be a number, string or list, so every
        # access below is guarded by the dict check in _structured_fields.
        fields = _structured_fields(parsed, strict=True)
        if fields is not None:
            analysis, ocr_text, measurements = fields
            logger.info(f"Successfully parsed structured response with {len(measurements)} measurements and {len(analysis)} analyses")
            return analysis, ocr_text, measurements
        logger.warning("Response not in expected format, attempting to extract...")
        if isinstance(parsed, dict) and "raw_found_json" in parsed:
            inner = parsed["raw_found_json"]
            try:
                if isinstance(inner, str):
                    inner = json.loads(inner)
            except json.JSONDecodeError as unwrap_err:
                logger.error(f"Failed to unwrap raw_found_json: {unwrap_err}")
            else:
                fields = _structured_fields(inner, strict=False)
                if fields is not None:
                    analysis, ocr_text, measurements = fields
                    logger.info(f"Successfully unwrapped raw_found_json with {len(measurements)} measurements")
                    return analysis, ocr_text, measurements

    logger.warning("Using fallback parsing - structured data extraction failed")
    return [{"findings": "Failed to parse structured response", "raw_output": clean[:1000]}], "", []
def save_analysis_with_measurements(user_id, ocr_text, analysis_data, measurements_data, report_date=None):
    """Normalize the measurements and persist one report row.

    Args:
        user_id: Owner of the report.
        ocr_text: Extracted document text stored with the report.
        analysis_data: Analysis findings; stored as JSON, or None when falsy.
        measurements_data: Raw measurement entries from the model.
        report_date: Optional date string; defaults to today (``%Y-%m-%d``).

    Returns:
        ``(report_id, measurements)``; ``report_id`` is None when the insert
        fails. ``measurements`` is the normalized list in both cases.
    """
    normalized = extract_measurements_from_gemini_structured(measurements_data)
    effective_date = report_date or datetime.now().strftime("%Y-%m-%d")
    anomalies_json = json.dumps(analysis_data) if analysis_data else None
    report_data = {
        "user_id": user_id,
        "report_date": effective_date,
        "ocr_text": ocr_text,
        "anomalies": anomalies_json,
        "measurements": json.dumps(normalized),
    }
    flagged_statuses = ('HIGH', 'LOW', 'CRITICAL')
    try:
        logger.info(f"Saving report for user {user_id} with {len(normalized)} measurements")
        report_id = db_insert_report(report_data)
        logger.info(f"Report saved with ID: {report_id}")
        # One log line per measurement, marking out-of-range statuses.
        for entry in normalized:
            if entry['status'] in flagged_statuses:
                status_indicator = "WARNING"
            else:
                status_indicator = "OK"
            logger.info(f" {status_indicator} {entry['measurement_type']}: {entry['value']} {entry['unit']} ({entry['status']})")
        return report_id, normalized
    except Exception as e:
        logger.error(f"Failed to save report: {e}")
        logger.error(f"Report data: {report_data}")
        return None, normalized
def get_past_reports_from_sqllite(user_id: str):
    """Return the OCR text of up to ten stored reports for ``user_id``.

    Each report becomes a "Report from <date>:" section followed by its OCR
    text. If anything fails, a fixed "no past reports" message is returned
    instead of raising.
    """
    try:
        reports = db_fetch_reports(user_id=user_id, limit=10, offset=0)
        sections = [
            f"Report from {report.get('report_date', 'N/A')}:\n{report.get('ocr_text', 'No OCR text found')}\n\n"
            for report in reports
        ]
        history_text = "".join(sections)
        logger.info(f"Retrieved {len(reports)} past reports for user {user_id}")
        return history_text
    except Exception as e:
        logger.error(f"Error fetching past reports: {e}")
        return "No past reports found for this user."
async def chat_endpoint(request: ChatRequest):
    """Answer a user question using cached and stored report text as context.

    Args:
        request: Chat request carrying ``user_id`` and ``question``.

    Returns:
        ChatResponse with the model's reply text.

    Raises:
        HTTPException: 400 when there is neither cached text nor any stored
            report text; 500 when prompt construction or the model call fails.
    """
    # user_id is Optional on the request model, so an explicit null must not
    # crash the .strip() call.
    user_id = ("anonymous" if request.user_id is None else request.user_id).strip()
    logger.info(f"Received chat request for user: {user_id}")

    past_reports_text = get_past_reports_from_sqllite(user_id)
    # NOTE(review): EXTRACTED_TEXT_CACHE is process-wide, not per user.
    cached_text = EXTRACTED_TEXT_CACHE or ""

    # Test for emptiness BEFORE adding the "PAST REPORTS:" header; once the
    # header is concatenated the combined text can never be empty.
    if not (cached_text.strip() or past_reports_text.strip()):
        raise HTTPException(status_code=400, detail="No past reports or current data exists for this user")

    full_document_text = cached_text + "\n\n" + "PAST REPORTS:\n" + past_reports_text
    logger.info(f"Full document text length: {len(full_document_text)}")

    try:
        full_prompt = system_prompt_chat.format(
            document_text=full_document_text,
            user_question=request.question
        )
        logger.info(f"Generated chat prompt length: {len(full_prompt)}")
        # generate_content blocks; run it in the default executor so the
        # event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, functools.partial(model.generate_content, full_prompt)
        )
        return ChatResponse(answer=response.text)
    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=f"Chat error: {e}") from e
async def analyze_endpoint(
    file: UploadFile = File(...),
    prompt: str = Form(None),
    user_id: str = Form("anonymous")
):
    """Analyze an uploaded file, store the report, and return the result.

    Args:
        file: Uploaded image or document.
        prompt: Optional prompt forwarded to analyze_image.
        user_id: Owner recorded on the stored report.

    Returns:
        JSONResponse with ``report_id``, ``ocr_text``, ``Detected_Anomolies``,
        ``measurements`` and ``measurement_count``.

    Raises:
        HTTPException: 500 when analysis or persistence fails.
    """
    global result, EXTRACTED_TEXT_CACHE
    # The upload may carry no filename; it is used for logging only, so fall
    # back to an empty string instead of failing on None.
    filename = (file.filename or "").lower()
    logger.info(f"Received analyze request for file {filename} from user {user_id}")
    contents = await file.read()
    mime = file.content_type or "image/png"
    try:
        analysis_result, ocr_text, measurements_data = await analyze_image(contents, mime, prompt)
        # Kept at module level: chat_endpoint reads EXTRACTED_TEXT_CACHE.
        EXTRACTED_TEXT_CACHE = ocr_text
        result = analysis_result
        report_id, measurements = save_analysis_with_measurements(
            user_id=user_id,
            ocr_text=ocr_text,
            analysis_data=analysis_result,
            measurements_data=measurements_data
        )
        response_data = {
            "report_id": report_id,
            "ocr_text": ocr_text,
            "Detected_Anomolies": analysis_result,
            "measurements": measurements,
            "measurement_count": len(measurements)
        }
        logger.info(f"Analysis complete. Report ID: {report_id}, Measurements: {len(measurements)}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logger.error(f"Analysis error: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def analyze_json(req: AnalyzeRequest):
    """Analyze a base64-encoded image supplied in a JSON body.

    Args:
        req: Request carrying ``image_base64`` and an optional ``prompt``.

    Returns:
        Dict with ``result``, ``ocr_text`` and ``measurements``.

    Raises:
        HTTPException: 400 when ``image_base64`` cannot be decoded; 500 when
            the analysis itself fails.
    """
    try:
        image_bytes = base64.b64decode(req.image_base64)
    except ValueError as e:
        # binascii.Error (bad padding etc.) is a ValueError subclass; a
        # malformed payload is a client error, not a server fault.
        raise HTTPException(status_code=400, detail=f"Invalid base64 image data: {e}") from e

    try:
        result, ocr_text, measurements = await analyze_image(image_bytes, "image/png", req.prompt)
    except Exception as e:
        # Same error mapping as analyze_endpoint.
        logger.error(f"Analysis error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "result": result,
        "ocr_text": ocr_text,
        "measurements": measurements
    }
async def get_report_measurements(report_id: int):
    """Return the stored measurements of a single report.

    Args:
        report_id: Identifier passed to ``db_get_report``.

    Returns:
        JSONResponse with ``report_id``, ``measurements`` and
        ``measurement_count``.

    Raises:
        HTTPException: 404 when the report does not exist; 500 on any other
            failure.
    """
    try:
        report = db_get_report(report_id)
        if not report:
            raise HTTPException(status_code=404, detail="Report not found")
        # Measurements may be stored as a JSON string or already decoded.
        measurements_json = report.get('measurements', '[]')
        if isinstance(measurements_json, str):
            measurements = json.loads(measurements_json)
        else:
            measurements = measurements_json or []
        logger.info(f"Retrieved {len(measurements)} measurements for report {report_id}")
        return JSONResponse(content={
            "report_id": report_id,
            "measurements": measurements,
            "measurement_count": len(measurements)
        })
    except HTTPException:
        # Let the deliberate 404 through; the generic handler below would
        # otherwise turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error getting measurements for report {report_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_user_measurements(user_id: str):
    """Return every stored measurement of a user, newest report first.

    Reads up to 100 reports, tags each measurement with its report's ``id``,
    ``report_date`` and ``created_at``, and sorts by ``created_at``
    descending. Any failure is reported as HTTP 500.
    """
    try:
        reports = db_fetch_reports(user_id=user_id, limit=100, offset=0)
        collected = []
        for report in reports:
            stored = report.get('measurements', '[]')
            # Measurements may be stored as a JSON string or already decoded.
            entries = json.loads(stored) if isinstance(stored, str) else (stored or [])
            if not entries:
                continue
            for entry in entries:
                entry['report_id'] = report['id']
                entry['report_date'] = report['report_date']
                entry['created_at'] = report['created_at']
                collected.append(entry)
        all_measurements = sorted(
            collected,
            key=lambda item: item.get('created_at', ''),
            reverse=True,
        )
        logger.info(f"Retrieved {len(all_measurements)} total measurements for user {user_id}")
        payload = {
            "user_id": user_id,
            "total_measurements": len(all_measurements),
            "measurements": all_measurements
        }
        return JSONResponse(content=payload)
    except Exception as e:
        logger.error(f"Error getting user measurements for {user_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_measurement_trends(user_id: str, measurement_type: str = None):
    """Group a user's stored measurements by type, ordered by date.

    Args:
        user_id: Whose reports to read (up to 100).
        measurement_type: Optional case-insensitive filter on the type name.

    Returns:
        JSONResponse with ``user_id``, ``measurement_type_filter`` and
        ``trends`` (type name -> list of data points sorted by ``date``).
        Any failure is reported as HTTP 500.
    """
    try:
        reports = db_fetch_reports(user_id=user_id, limit=100, offset=0)
        trends = {}
        for report in reports:
            stored = report.get('measurements', '[]')
            # Measurements may be stored as a JSON string or already decoded.
            entries = json.loads(stored) if isinstance(stored, str) else (stored or [])
            if not entries:
                continue
            for entry in entries:
                m_type = entry['measurement_type']
                if not measurement_type or m_type.lower() == measurement_type.lower():
                    trends.setdefault(m_type, []).append({
                        "date": report['report_date'] or report['created_at'],
                        "value": entry['value'],
                        "unit": entry['unit'],
                        "status": entry['status'],
                        "severity": entry['severity'],
                        "report_id": report['id']
                    })
        for points in trends.values():
            points.sort(key=lambda point: point['date'])
        logger.info(f"Retrieved trends for {len(trends)} measurement types for user {user_id}")
        payload = {
            "user_id": user_id,
            "measurement_type_filter": measurement_type,
            "trends": trends
        }
        return JSONResponse(content=payload)
    except Exception as e:
        logger.error(f"Error getting measurement trends for {user_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def test_database():
    """Exercise the report store with one read and one insert.

    Fetches up to five reports of ``test_user`` and then inserts a fixed test
    report. Returns a JSON status document; on failure the document carries
    the error text and the response status is 500.
    """
    try:
        existing = db_fetch_reports(user_id="test_user", limit=5, offset=0)
        sample_anomalies = json.dumps([{"test": "data"}])
        sample_measurements = json.dumps([{"measurement_type": "Test", "value": 100, "unit": "mg/dL", "status": "NORMAL"}])
        test_data = {
            "user_id": "test_user",
            "report_date": datetime.now().strftime("%Y-%m-%d"),
            "ocr_text": "Test OCR text",
            "anomalies": sample_anomalies,
            "measurements": sample_measurements,
        }
        new_report_id = db_insert_report(test_data)
        success_body = {
            "database_status": "connected",
            "existing_reports": len(existing),
            "test_report_id": new_report_id,
            "test_successful": True
        }
        return JSONResponse(content=success_body)
    except Exception as e:
        logger.error(f"Database test failed: {e}")
        failure_body = {
            "database_status": "error",
            "error": str(e),
            "test_successful": False
        }
        return JSONResponse(content=failure_body, status_code=500)
def health():
    """Return a fixed liveness payload."""
    payload = {"response": "ok"}
    return payload
def _log_routes():
    """Print the path and HTTP methods of every APIRoute registered on the app."""
    from fastapi.routing import APIRoute

    print("Mounted routes:")
    api_routes = (route for route in app.routes if isinstance(route, APIRoute))
    for route in api_routes:
        print(" ", route.path, route.methods)
def main():
    """Start the uvicorn server on localhost:8000.

    With ``Config.DEBUG`` set, the app is passed as the import string
    ``"main:app"`` with auto-reload and debug logging; otherwise the app
    object is served directly with info logging. Start-up errors are logged
    and re-raised.
    """
    port = 8000
    debug = bool(Config.DEBUG)
    try:
        logger.info(f"Starting server on {port}")
        # Report the actual setting rather than a fixed "true".
        logger.info(f"Debug mode: {str(debug).lower()}")
        uvicorn.run(
            # The debug launch passes the import string because reload is on.
            "main:app" if debug else app,
            host="localhost",
            port=port,
            reload=debug,
            log_level="debug" if debug else "info",
        )
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        raise


if __name__ == "__main__":
    main()