Spaces:
Sleeping
Sleeping
| import os | |
| from ast import List | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import io | |
| #import fitz | |
| import traceback | |
| import pandas as pd | |
| import base64 | |
| import json | |
| import re | |
| import asyncio | |
| import functools | |
| from typing import Any, Optional | |
| import google.generativeai as genai | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, APIRouter, Request | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import firebase_admin | |
| from firebase_admin import credentials, firestore | |
| from google.generativeai import generative_models | |
| from pydantic import BaseModel | |
| from past_reports import router as reports_router, db_fetch_reports | |
| from api_key import GEMINI_API_KEY | |
# --- FastAPI application setup ---
app = FastAPI()

# NOTE(review): this router is included while still empty and no route is
# ever attached to it in this file; presumably include_router() copies routes
# at call time, so later additions would not be served — confirm intent.
api = APIRouter(prefix="/api")
app.include_router(api)

# Module-level cache of the most recently extracted OCR text; written by the
# analyze endpoint and read by the chat endpoint.
EXTRACTED_TEXT_CACHE = ""

# Static front-end from ./web served under /app.
app.mount("/app", StaticFiles(directory="web", html=True), name="web")
# Past-reports endpoints defined in past_reports.py.
app.include_router(reports_router)

# Wide-open CORS: any origin, method, and header allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def root():
    """Send visitors of the bare root URL to the static front-end.

    NOTE(review): no route decorator is visible here (likely lost in a
    paste); presumably this was registered as the "/" handler — confirm.
    """
    target = "/app/"
    return RedirectResponse(url=target)
class AnalyzeRequest(BaseModel):
    # Request body for the JSON analyze endpoint.
    image_base64: str  # base64-encoded image bytes (decoded with base64.b64decode)
    prompt: Optional[str] = None  # optional override for the default system prompt
# Prefer the GEMINI_API_KEY environment variable; fall back to the value
# imported from api_key.py. Fail fast at import time when neither is set.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", GEMINI_API_KEY)
if not GEMINI_API_KEY:
    raise RuntimeError(
        "No Gemini API key found. Put it in api_key.py as `GEMINI_API_KEY = '...'` or set env var GEMINI_API_KEY."
    )
genai.configure(api_key=GEMINI_API_KEY)
# Decoding parameters shared by every model call.
generation_config = {
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 2048,
}

# Apply the same medium-and-above block threshold to each harm category.
_HARM_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)
safety_settings = [
    {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for category in _HARM_CATEGORIES
]
# --- Pydantic Models for API Endpoints ---
class ChatRequest(BaseModel):
    # Body of a chat request; user_id selects whose past reports are loaded.
    user_id: Optional[str] = "anonymous"
    question: str

class ChatResponse(BaseModel):
    # Chat reply returned to the front-end.
    answer: str

class TextRequest(BaseModel):
    # Generic request carrying a single block of text.
    text: str
| system_prompt = """ You are a highly skilled medical practitioner specializing in medical image and document analysis. You will be given either a medical image or a PDF. | |
| Your responsibilities are: | |
| 1. **Extract Text**: If the input is a PDF or image, first extract all the text content (lab values, notes, measurements, etc.). Do not summarize — keep the extracted text verbatim. | |
| 2. **Detailed Analysis**: Use both the extracted text and the visual features of the image to identify any anomalies, diseases, or health issues. | |
| 3. **Finding Report**: Document all observed anomalies or signs of disease. | |
| - Include any measurements (e.g., triglycerides, HBa1c, HDL) in the format: | |
| `{"findings": "Condition only if risky: measurement type -- value with unit(current range)"}` | |
| - Simplify the finding in **3 words** at the beginning when helpful. | |
| 4. **Checking for Past**: If a disease is family history or previously recovered, mark severity as: | |
| `"severity": "severity of anomaly (Past Anomaly but Still Under Risk)"` | |
| 5. **Recommendations and Next Steps**: Provide detailed recommendations (tests, follow-ups, consultations). | |
| 6. **Treatment Suggestions**: Offer preliminary treatments or interventions. | |
| 7. **Output Format**: Always return a JSON object containing both the raw extracted text and the structured analysis, like this: | |
| ```json | |
| { | |
| "ocr_text": "<<<FULL VERBATIM TEXT FROM THE PDF/IMAGE>>>", | |
| "analysis": [ | |
| { | |
| "findings": "UPPERCASE MAIN CONCERN. Description of the first disease or condition.", | |
| "severity": "MILD/SEVERE/CRITICAL", | |
| "recommendations": ["Follow-up test 1", "Follow-up test 2"], | |
| "treatment_suggestions": ["Treatment 1", "Treatment 2"], | |
| "home_care_guidance": ["Care tip 1", "Care tip 2"] | |
| }, | |
| { | |
| "findings": "UPPERCASE MAIN CONCERN. Description of the second disease or condition.", | |
| "severity": "MILD/SEVERE/CRITICAL", | |
| "recommendations": ["Follow-up test A", "Follow-up test B"], | |
| "treatment_suggestions": ["Treatment A", "Treatment B"], | |
| "home_care_guidance": ["Care tip A", "Care tip B"] | |
| } | |
| ] | |
| } | |
| Important Notes: | |
| 1. Scope of Response: Only respond if the image pertains to a human health issue. | |
| 2. Clarity of Image: Ensure the image is clear and suitable for accurate analysis. | |
| 3. As an AI, don't produce any output than doesn't follow this format and also remember that we need your input on the medical report for a team of renowned doctors to recheck your readings and determine AI's role in the medical industry. | |
| 4. Your Insights are Invaluable: Your insights play a crucial role in guiding clinical decisions. Please proceed with your analysis, adhering to the structured approach outlined above. | |
| 5. Completely UPPERCASE the main concern in the finding """ | |
| system_prompt_chat = """ | |
| *** Role: Medical Guidance Facilitator | |
| *** Objective: | |
| Analyze medical data, provide concise, evidence-based insights, and recommend actionable next steps for patient care. This includes suggesting local physicians or specialists within a user-specified mile radius, prioritizing in-network options when insurance information is available, and maintaining strict safety compliance with appropriate disclaimers. | |
| *** Capabilities: | |
| 1. Report Analysis – Review and interpret findings in uploaded medical reports. | |
| 2. Historical Context – Compare current findings with any available previous reports. | |
| 3. Medical Q&A – Answer specific questions about the report using trusted medical sources. | |
| 4. Specialist Matching – Recommend relevant physician specialties for identified conditions. | |
| 5. Local Physician Recommendations – List at least two real physician or clinic options within the user-specified mile radius (include name, specialty, address, distance from user, and contact info) based on the patient’s location and clinical need. | |
| 6. Insurance Guidance – If insurance/network information is provided, prioritize in-network physicians. | |
| 7. Safety Protocols – Include a brief disclaimer encouraging users to verify information, confirm insurance coverage, and consult providers directly. | |
| *** Response Structure: | |
| Start with a direct answer to the user’s primary question (maximum 4 concise sentences, each on a new line). | |
| If a physician/specialist is needed, recommend at least two local providers within the requested radius (include name, specialty, address, distance, and contact info). | |
| If insurance details are available, indicate which physicians are in-network. | |
| End with a short safety disclaimer. | |
| ***Input Fields: | |
| Provided Document Text: {document_text} | |
| User Question: {user_question} | |
| Assistant Answer: | |
| """ | |
# Initialize model
# Single shared Gemini model instance used by every endpoint in this module.
model = genai.GenerativeModel(model_name="gemini-2.5-flash-lite")
async def _call_model_blocking(request_inputs, generation_cfg, safety_cfg):
    """Run the blocking Gemini call in the default threadpool.

    ``model.generate_content`` is synchronous; dispatching it to an executor
    keeps uvicorn's event loop responsive while the request is in flight.

    Args:
        request_inputs: parts list passed straight to ``generate_content``.
        generation_cfg: decoding parameters (temperature, top_p, ...).
        safety_cfg: safety-threshold settings for the call.

    Returns:
        The raw response object produced by the Gemini SDK.
    """
    call = functools.partial(
        model.generate_content,
        request_inputs,
        generation_config=generation_cfg,
        safety_settings=safety_cfg,
    )
    # get_running_loop() replaces the deprecated get_event_loop(): this
    # coroutine only ever executes inside an already-running loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, call)
def _parse_analysis_json(clean: str):
    """Parse cleaned model output into an ``(analysis, ocr_text)`` pair.

    Falls back progressively: full JSON parse, then the first JSON-looking
    substring, then the raw text. Returns ``(payload, None)`` when the
    expected {"ocr_text": ..., "analysis": ...} shape cannot be recovered.
    """
    try:
        parsed = json.loads(clean)
        return parsed["analysis"], parsed["ocr_text"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # KeyError/TypeError cover valid JSON that lacks the expected keys;
        # the original only caught JSONDecodeError and crashed on those.
        pass
    match = re.search(r"(\[.*\]|\{.*\})", clean, re.DOTALL)
    if match:
        try:
            # Bug fix: the original did `json.loads(...), None`, building a
            # tuple and then subscripting it with a string (TypeError).
            parsed = json.loads(match.group(1))
            return parsed["analysis"], parsed["ocr_text"]
        except (json.JSONDecodeError, KeyError, TypeError):
            return {"raw_found_json": match.group(1)}, None
    return {"raw_output": clean}, None


async def analyze_image(image_bytes: bytes, mime_type: str, prompt: Optional[str] = None) -> Any:
    """Send an image/PDF to Gemini and return ``(analysis, ocr_text)``.

    Args:
        image_bytes: raw bytes of the uploaded image or PDF.
        mime_type: MIME type forwarded to the model (e.g. "image/png").
        prompt: optional prompt overriding the default ``system_prompt``.

    Raises:
        RuntimeError: when the underlying model call fails.
    """
    base64_img = base64.b64encode(image_bytes).decode("utf-8")
    text_prompt = (prompt or system_prompt).strip()
    request_inputs = [
        {"inline_data": {"mime_type": mime_type, "data": base64_img}},
        {"text": text_prompt},
    ]
    try:
        response = await _call_model_blocking(request_inputs, generation_config, safety_settings)
    except Exception as e:
        raise RuntimeError(f"Model call failed: {e}")
    # The SDK normally exposes `.text`; the dict branches cover stubbed or
    # alternative response shapes.
    text = getattr(response, "text", None)
    if not text and isinstance(response, dict):
        candidates = response.get("candidates") or []
        if candidates:
            text = candidates[0].get("content") or candidates[0].get("text")
    if not text:
        text = str(response)
    # Strip the markdown code fences the model wraps around its JSON.
    clean = re.sub(r"```(?:json)?", "", text).strip()
    print(f"Cleaned text: {clean}")
    return _parse_analysis_json(clean)
def get_past_reports_from_sqllite(user_id: str) -> str:
    """Return up to 10 of the user's past reports as one text blob.

    Each report contributes its date and OCR text. Returns the fallback
    message both when the lookup fails AND when it yields nothing — the
    original returned "" for an empty result but a message on error, which
    was inconsistent for the chat prompt that consumes this text.
    """
    try:
        reports = db_fetch_reports(user_id=user_id, limit=10, offset=0)
        # join() instead of += avoids quadratic string building.
        history_text = "".join(
            f"Report from {report.get('report_date', 'N/A')}:\n"
            f"{report.get('ocr_text', 'No OCR text found')}\n\n"
            for report in reports
        )
    except Exception:
        # Best-effort: any DB failure degrades to "no history" rather than
        # breaking the chat endpoint.
        history_text = ""
    return history_text or "No past reports found for this user."
async def chat_endpoint(request: ChatRequest):
    """Answer a question using the last analyzed document plus past reports.

    NOTE(review): no route decorator is visible (likely lost in a paste);
    presumably this was registered on the app/router — confirm.

    Raises:
        HTTPException 400 when no document text is available at all.
        HTTPException 500 when prompt formatting or the model call fails.
    """
    print(f"Received chat request for user: {request.user_id}")
    # user_id is Optional; the original called .strip() on it unconditionally
    # and would crash on an explicit null.
    user_id = (request.user_id or "anonymous").strip()
    history_text = get_past_reports_from_sqllite(user_id)
    current_text = EXTRACTED_TEXT_CACHE or ""
    # Bug fix: the original concatenated the "PAST REPORTS:" header first and
    # then tested truthiness, so the 400 branch could never fire.
    if not current_text.strip() and not history_text.strip():
        raise HTTPException(status_code=400, detail="No past reports or current data exists for this user")
    full_document_text = current_text + "\n\n" + "PAST REPORTS:\n" + history_text
    print(f"Full document text: {full_document_text}")
    try:
        full_prompt = system_prompt_chat.format(
            document_text=full_document_text,
            user_question=request.question,
        )
        print(f"Full prompt: {full_prompt}")
        # generate_content is synchronous; run it off the event loop so the
        # server stays responsive (the original blocked the loop here).
        response = await asyncio.to_thread(model.generate_content, full_prompt)
        return ChatResponse(answer=response.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat error: {e}")
async def analyze_endpoint(file: UploadFile = File(...), prompt: str = Form(None)):
    """
    Upload an image file (field name `file`) and optional text `prompt`.
    Returns parsed JSON (or raw model output if JSON couldn't be parsed).

    NOTE(review): no route decorator is visible (likely lost in a paste);
    presumably this was registered on the app/router — confirm.
    """
    global EXTRACTED_TEXT_CACHE
    # filename may be None on some clients; the original crashed on .lower()
    # and then logged "(unknown)" instead of the name it computed.
    filename = (file.filename or "").lower()
    print(f"Received analyze request for file {filename or '(unknown)'}")
    contents = await file.read()  # uploaded file bytes
    mime = file.content_type or "image/png"
    try:
        result, ocr_text = await analyze_image(contents, mime, prompt)
        # Cache the OCR text so the chat endpoint can reference the most
        # recent upload; analyze_image may return None for unparseable output.
        EXTRACTED_TEXT_CACHE = ocr_text or ""
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    # "Detected_Anomolies" (sic) is the key the front-end expects; kept as-is.
    return JSONResponse(content={
        "ocr_text": ocr_text,
        "Detected_Anomolies": result,
    })
async def analyze_json(req: AnalyzeRequest):
    """JSON-body variant of the analyze endpoint (base64 PNG in the request).

    NOTE(review): no route decorator is visible (likely lost in a paste);
    presumably this was registered on the app/router — confirm.
    """
    image_bytes = base64.b64decode(req.image_base64)
    # Bug fix: analyze_image returns (analysis, ocr_text); the original put
    # the whole tuple under "result". Unpack it and expose both fields
    # (the extra "ocr_text" key is additive, so callers keep working).
    analysis, ocr_text = await analyze_image(image_bytes, "image/png", req.prompt)
    return {"result": analysis, "ocr_text": ocr_text}
def health():
    """Liveness probe; always reports the service as up."""
    status = {"response": "ok"}
    return status
def _log_routes():
    """Print each mounted APIRoute path with its HTTP methods (debug aid)."""
    from fastapi.routing import APIRoute
    print("Mounted routes:")
    api_routes = (route for route in app.routes if isinstance(route, APIRoute))
    for route in api_routes:
        print(" ", route.path, route.methods)
def get_user_results(user_id: str):
    """Return all of a user's reports plus the latest anomaly per measurement.

    NOTE(review): `conn` is not defined anywhere in this module, so this
    function raises NameError as written. Presumably a module-level sqlite3
    connection was intended — confirm where it should come from.
    """
    cursor = conn.cursor()
    cursor.execute(
        "SELECT ocr_text, anomalies, created_at FROM reports WHERE user_id = ? ORDER BY created_at ASC",
        (user_id,)
    )
    rows = cursor.fetchall()
    # On-the-fly "latest anomaly per measurement"
    latest_anomalies = {}  # measurement -> anomaly dict
    combined_reports = []  # optional: keep full report data if needed
    for ocr_text, anomalies_json, created_at in rows:
        # anomalies_json may be NULL in the table; treat it as an empty list.
        anomalies = json.loads(anomalies_json or "[]")
        for a in anomalies:
            meas = a.get("measurement")
            # Always keep the latest anomaly (rows are ordered oldest → newest)
            latest_anomalies[meas] = a
        combined_reports.append({
            "ocr_text": ocr_text,
            "created_at": created_at,
            "anomalies": anomalies
        })
    return {
        "user_id": user_id,
        "latest_anomalies": list(latest_anomalies.values()),
        "all_reports": combined_reports
    }
def main():
    """Run the application under uvicorn on localhost:8000.

    Fixes in this block: `uvicorn` and `logger` were used but never defined
    (NameError); `Config.DEBUG` referenced a class that does not exist in
    this module, so debug mode is now driven by the DEBUG environment
    variable; `port` was passed as the string "8000" instead of an int.
    """
    import logging
    import uvicorn

    logger = logging.getLogger(__name__)
    debug = os.getenv("DEBUG", "").lower() in {"1", "true", "yes"}
    try:
        logger.info("Starting server on 8000")
        logger.info("Debug mode: %s", debug)
        if debug:
            # Reload mode requires the import-string form of the app.
            uvicorn.run(
                "main:app",
                host="localhost",
                port=8000,
                reload=True,
                log_level="debug",
            )
        else:
            # Use the app instance directly for production.
            uvicorn.run(
                app,
                host="localhost",
                port=8000,
                reload=False,
                log_level="info",
            )
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        raise

if __name__ == "__main__":
    main()