# NOTE: removed non-code scrape artifact ("Spaces: / Sleeping / Sleeping" — Hugging Face Spaces page chrome).
| import os | |
| from ast import List | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import io | |
| #import fitz | |
| import traceback | |
| import pandas as pd | |
| import base64 | |
| import json | |
| import re | |
| import asyncio | |
| import functools | |
| from typing import Any, Optional | |
| import google.generativeai as genai | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, APIRouter, Request | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import firebase_admin | |
| from firebase_admin import credentials, firestore | |
| from google.generativeai import generative_models | |
| from pydantic import BaseModel | |
| from past_reports import router as reports_router, db_fetch_reports | |
| from api_key import GEMINI_API_KEY | |
# FastAPI application wiring: API router, CORS, static web UI, root redirect.
app = FastAPI()

# NOTE(review): this router is included while still empty — any routes added to
# `api` after this point are NOT registered. Kept as-is for compatibility.
api = APIRouter(prefix="/api")
app.include_router(api)

# Module-level cache of the most recently extracted OCR text; written by the
# analyze endpoint and read by the chat endpoint.
EXTRACTED_TEXT_CACHE = ""

# Serve the bundled frontend (directory "web") at /app.
app.mount("/app", StaticFiles(directory="web", html=True), name="web")
app.include_router(reports_router)

# Wide-open CORS — fine for a demo; tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# BUG FIX: the function had no route decorator, so the bare domain returned 404
# instead of redirecting. TODO(review): confirm "/" was the intended path.
@app.get("/")
def root():
    """Redirect the bare domain to the static web UI."""
    return RedirectResponse(url="/app/")
class AnalyzeRequest(BaseModel):
    """Body for the JSON-based analyze endpoint: a base64 image plus an optional prompt override."""
    # Base64-encoded image bytes (raw base64, no "data:" URI prefix expected).
    image_base64: str
    # Optional custom prompt; when None the module-level system_prompt is used.
    prompt: Optional[str] = None
# Prefer the GEMINI_API_KEY environment variable; fall back to the value
# imported from api_key.py. Fail fast if neither is set.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", GEMINI_API_KEY)
if not GEMINI_API_KEY:
    raise RuntimeError(
        "No Gemini API key found. Put it in api_key.py as `GEMINI_API_KEY = '...'` or set env var GEMINI_API_KEY."
    )
genai.configure(api_key=GEMINI_API_KEY)
# Low-temperature decoding for consistent, JSON-shaped model output.
generation_config = {
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 2048,
}

# Block medium-and-above harmful content in every harm category.
safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]
class ChatRequest(BaseModel):
    """Body for the chat endpoint."""
    # Identity used to look up past reports; defaults to "anonymous".
    user_id: Optional[str] = "anonymous"
    # The user's free-text question about their report.
    question: str


class ChatResponse(BaseModel):
    """Reply wrapper for the chat endpoint."""
    answer: str


class TextRequest(BaseModel):
    """Generic plain-text request body (no visible references in this chunk)."""
    text: str
# System prompt for image/PDF analysis: demands verbatim OCR text plus a
# structured findings list, returned as a single JSON object.
# Runtime string — any change here changes model behavior.
system_prompt = """ You are a highly skilled medical practitioner specializing in medical image and document analysis. You will be given either a medical image or a PDF.
Your responsibilities are:
1. **Extract Text**: If the input is a PDF or image, first extract all the text content (lab values, notes, measurements, etc.). Do not summarize — keep the extracted text verbatim.
2. **Detailed Analysis**: Use both the extracted text and the visual features of the image to identify any anomalies, diseases, or health issues.
3. **Finding Report**: Document all observed anomalies or signs of disease.
- Include any measurements (e.g., triglycerides, HBa1c, HDL) in the format:
`{"findings": "Condition only if risky: measurement type -- value with unit(current range)"}`
- Simplify the finding in **3 words** at the beginning when helpful.
4. **Checking for Past**: If a disease is family history or previously recovered, mark severity as:
`"severity": "severity of anomaly (Past Anomaly but Still Under Risk)"`
5. **Recommendations and Next Steps**: Provide detailed recommendations (tests, follow-ups, consultations).
6. **Treatment Suggestions**: Offer preliminary treatments or interventions.
7. **Output Format**: Always return a JSON object containing both the raw extracted text and the structured analysis, like this:
```json
{
"ocr_text": "<<<FULL VERBATIM TEXT FROM THE PDF/IMAGE>>>",
"analysis": [
{
"findings": "UPPERCASE MAIN CONCERN. Description of the first disease or condition.",
"severity": "MILD/SEVERE/CRITICAL",
"recommendations": ["Follow-up test 1", "Follow-up test 2"],
"treatment_suggestions": ["Treatment 1", "Treatment 2"],
"home_care_guidance": ["Care tip 1", "Care tip 2"]
},
{
"findings": "UPPERCASE MAIN CONCERN. Description of the second disease or condition.",
"severity": "MILD/SEVERE/CRITICAL",
"recommendations": ["Follow-up test A", "Follow-up test B"],
"treatment_suggestions": ["Treatment A", "Treatment B"],
"home_care_guidance": ["Care tip A", "Care tip B"]
}
]
}
Important Notes:
1. Scope of Response: Only respond if the image pertains to a human health issue.
2. Clarity of Image: Ensure the image is clear and suitable for accurate analysis.
3. As an AI, don't produce any output than doesn't follow this format and also remember that we need your input on the medical report for a team of renowned doctors to recheck your readings and determine AI's role in the medical industry.
4. Your Insights are Invaluable: Your insights play a crucial role in guiding clinical decisions. Please proceed with your analysis, adhering to the structured approach outlined above.
5. Completely UPPERCASE the main concern in the finding """
# System prompt for the follow-up chat: short, report-grounded answers only.
# Placeholders {document_text} / {user_question} are filled via str.format()
# in chat_endpoint. Runtime string — do not edit casually.
system_prompt_chat = """
*** Role: Medical Chat Assistant ***
You are a concise and empathetic medical chatbot. Your job is to give clear, short answers (max 3-4 sentences) based only on the provided medical report text.
Rules:
- Avoid repeating the entire report; focus only on what is directly relevant to the user’s question.
- Give top 2 actionable steps if needed.
- If condition is serious, suggest consulting a doctor immediately.
- Always end with: "Check with your physician before acting."
Input:
Report Text: {document_text}
User Question: {user_question}
Response:
"""
| model = genai.GenerativeModel(model_name="gemini-2.5-flash-lite") | |
async def _call_model_blocking(request_inputs, generation_cfg, safety_cfg):
    """Run the blocking Gemini SDK call in a worker thread.

    Keeps the event loop responsive while the synchronous
    ``model.generate_content`` call is in flight.

    Args:
        request_inputs: List of content parts (inline_data / text dicts).
        generation_cfg: Generation parameters forwarded to the SDK.
        safety_cfg: Safety settings forwarded to the SDK.

    Returns:
        The SDK response object from ``model.generate_content``.
    """
    fn = functools.partial(
        model.generate_content,
        request_inputs,
        generation_config=generation_cfg,
        safety_settings=safety_cfg,
    )
    # FIX: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here and may create a second loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, fn)
async def analyze_image(image_bytes: bytes, mime_type: str, prompt: Optional[str] = None) -> Any:
    """Send an image/PDF to Gemini and parse its JSON reply.

    Args:
        image_bytes: Raw file bytes of the image or PDF.
        mime_type: MIME type forwarded to the model (e.g. "image/png").
        prompt: Optional prompt override; defaults to the module system_prompt.

    Returns:
        A 2-tuple ``(analysis, ocr_text)``. ``ocr_text`` is None whenever the
        model output could not be parsed into the expected JSON object.

    Raises:
        RuntimeError: If the model call itself fails.
    """
    base64_img = base64.b64encode(image_bytes).decode("utf-8")
    text_prompt = (prompt or system_prompt).strip()
    request_inputs = [
        {"inline_data": {"mime_type": mime_type, "data": base64_img}},
        {"text": text_prompt},
    ]
    try:
        response = await _call_model_blocking(request_inputs, generation_config, safety_settings)
    except Exception as e:
        raise RuntimeError(f"Model call failed: {e}")

    # The SDK response exposes .text; dict fallbacks cover alternative shapes.
    text = getattr(response, "text", None)
    if not text and isinstance(response, dict):
        candidates = response.get("candidates") or []
        if candidates:
            text = candidates[0].get("content") or candidates[0].get("text")
    if not text:
        text = str(response)

    # Strip the markdown code fences the model habitually wraps JSON in.
    clean = re.sub(r"```(?:json)?", "", text).strip()
    print(f"Cleaned text: {clean}")
    try:
        parsed = json.loads(clean)
        # FIX: KeyError/TypeError are now caught too — the original only caught
        # JSONDecodeError, so a well-formed JSON reply missing "ocr_text"
        # crashed the endpoint instead of falling through.
        ocr_text = parsed["ocr_text"]
        analysis = parsed["analysis"]
        print(f"Parsed JSON: {parsed}")
        return analysis, ocr_text
    except (json.JSONDecodeError, KeyError, TypeError):
        # Fallback: hunt for the first JSON-looking span in the output.
        match = re.search(r"(\[.*\]|\{.*\})", clean, re.DOTALL)
        if match:
            try:
                # BUG FIX: original line was `parsed = json.loads(...), None`,
                # which bound a *tuple* to `parsed`; the subsequent
                # parsed["ocr_text"] raised an uncaught TypeError.
                parsed = json.loads(match.group(1))
                if isinstance(parsed, dict) and "analysis" in parsed:
                    return parsed["analysis"], parsed.get("ocr_text")
                return parsed, None
            except json.JSONDecodeError:
                return {"raw_found_json": match.group(1)}, None
        return {"raw_output": clean}, None
def get_past_reports_from_sqllite(user_id: str):
    """Fetch up to 10 past reports for *user_id* and flatten them into prompt text.

    Returns a human-readable fallback message when the lookup fails or when the
    user has no stored reports (the original returned "" for an empty list).
    """
    try:
        reports = db_fetch_reports(user_id=user_id, limit=10, offset=0)
        # join() instead of += in a loop: linear instead of quadratic.
        history_text = "".join(
            f"Report from {report.get('report_date', 'N/A')}:\n"
            f"{report.get('ocr_text', 'No OCR text found')}\n\n"
            for report in reports
        )
        # BUG FIX: an empty result list previously yielded "" rather than the
        # fallback message used on the error path.
        if not history_text:
            history_text = "No past reports found for this user."
    except Exception:
        # Deliberate best-effort: any storage failure degrades to "no history"
        # rather than surfacing a 500 from the chat endpoint.
        history_text = "No past reports found for this user."
    return history_text
# BUG FIX: the handler had no route decorator, so it was unreachable over HTTP.
# TODO(review): confirm "/chat" matches the path the frontend calls.
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Answer a user question grounded in cached OCR text plus stored past reports.

    Raises:
        HTTPException 400: when there is no current or historical data at all.
        HTTPException 500: when prompt formatting or the model call fails.
    """
    print(f"Received chat request for user: {request.user_id}")
    history_text = get_past_reports_from_sqllite(request.user_id.strip())
    # Guard against a None cache (analyze may store None when OCR parsing fails).
    current_text = EXTRACTED_TEXT_CACHE or ""
    # BUG FIX: validate BEFORE prepending headers — the original concatenated
    # "PAST REPORTS:\n" first, so `if not full_document_text` could never fire.
    if not (current_text or history_text):
        raise HTTPException(status_code=400, detail="No past reports or current data exists for this user")
    full_document_text = current_text + "\n\n" + "PAST REPORTS:\n" + history_text
    print(f"Full document text: {full_document_text}")
    try:
        full_prompt = system_prompt_chat.format(
            document_text=full_document_text,
            user_question=request.question,
        )
        print(f"Full prompt: {full_prompt}")
        response = model.generate_content(full_prompt)
        return ChatResponse(answer=response.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat error: {e}")
# BUG FIX: the handler had no route decorator, so it was unreachable over HTTP.
# TODO(review): confirm "/analyze" matches the path the frontend calls.
@app.post("/analyze")
async def analyze_endpoint(file: UploadFile = File(...), prompt: str = Form(None)):
    """Analyze an uploaded medical image/PDF and return OCR text plus findings.

    Side effect: refreshes the module-level EXTRACTED_TEXT_CACHE used by chat.
    """
    global result, EXTRACTED_TEXT_CACHE
    # FIX: UploadFile.filename can be None; guard before .lower().
    filename = file.filename.lower() if file.filename else ""
    # FIX: log the actual filename — the original printed a literal "(unknown)"
    # while leaving the computed `filename` unused.
    print(f"Received analyze request for file {filename!r}")
    contents = await file.read()
    mime = file.content_type or "image/png"
    try:
        result, ocr_text = await analyze_image(contents, mime, prompt)
        # FIX: never cache None — chat_endpoint concatenates this with str.
        EXTRACTED_TEXT_CACHE = ocr_text or ""
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    # "Detected_Anomolies" key kept verbatim (typo included) — it is the
    # public API contract consumed by the frontend.
    return JSONResponse(content={
        "ocr_text": ocr_text,
        "Detected_Anomolies": result
    })
# NOTE(review): no route decorator in the original; restored as a POST route —
# TODO confirm the intended path.
@app.post("/analyze_json")
async def analyze_json(req: AnalyzeRequest):
    """JSON variant of analyze: base64-encoded PNG in, structured analysis out."""
    # Removed redundant local `import base64` — it is imported at module level.
    image_bytes = base64.b64decode(req.image_base64)
    # BUG FIX: analyze_image returns (analysis, ocr_text); the original stuffed
    # the raw 2-tuple into "result", serializing as a 2-element list.
    analysis, ocr_text = await analyze_image(image_bytes, "image/png", req.prompt)
    return {"result": analysis, "ocr_text": ocr_text}
def health():
    """Liveness probe: constant payload, no external dependencies."""
    status_payload = {"response": "ok"}
    return status_payload
def _log_routes():
    """Print every mounted APIRoute (path and HTTP methods) for startup debugging."""
    from fastapi.routing import APIRoute

    print("Mounted routes:")
    api_routes = (route for route in app.routes if isinstance(route, APIRoute))
    for route in api_routes:
        print(" ", route.path, route.methods)
def main():
    """Start the uvicorn server on localhost:8000.

    Debug mode (auto-reload + verbose logging) is controlled by the DEBUG
    environment variable (defaults to enabled, matching the original's
    hard-coded "Debug mode: true" log line).

    Raises:
        Exception: re-raised after logging if the server fails to start.
    """
    # BUG FIX: the original referenced `logger`, `Config`, and `uvicorn`
    # without defining or importing any of them, so main() always raised
    # NameError. Local imports keep the module importable without uvicorn.
    import logging
    import uvicorn

    logger = logging.getLogger(__name__)
    debug = os.getenv("DEBUG", "true").lower() in ("1", "true", "yes")
    try:
        logger.info("Starting server on 8000")
        logger.info("Debug mode: %s", debug)
        if debug:
            uvicorn.run(
                "main:app",
                host="localhost",
                port=8000,  # BUG FIX: uvicorn expects an int, not the string "8000"
                reload=True,
                log_level="debug",
            )
        else:
            uvicorn.run(
                app,
                host="localhost",
                port=8000,
                reload=False,
                log_level="info",
            )
    except Exception as e:
        logger.error("Failed to start server: %s", e)
        raise


if __name__ == "__main__":
    main()