PRMSChallengeOct / backend.py
vineelagampa's picture
Update backend.py (#30)
cf070a2 verified
raw
history blame
6.52 kB
import os
import base64
import json
import re
import asyncio
import functools
from typing import Any, Optional
import google.generativeai as genai
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
class AnalyzeRequest(BaseModel):
    """JSON request body for ``POST /analyze_json``.

    Attributes are documented inline below.
    """

    # Base64-encoded image payload; decoded server-side before analysis.
    image_base64: str
    # Optional prompt text; when omitted (None) the default system prompt is used.
    prompt: Optional[str] = None
# Resolve the Gemini API key. A local ``api_key.py`` module takes precedence;
# the GEMINI_API_KEY environment variable is used when that module is absent,
# lacks the name, or leaves the key empty (e.g. a placeholder "").
try:
    from api_key import GEMINI_API_KEY as API_KEY  # <-- match the name in api_key.py
except ImportError:
    # Raised both when api_key.py is missing and when it lacks GEMINI_API_KEY.
    API_KEY = None
if not API_KEY:
    API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
    raise RuntimeError(
        "No Google API key found. Put it in api_key.py as `GEMINI_API_KEY = '...'` or set env var GEMINI_API_KEY."
    )
genai.configure(api_key=API_KEY)
# Decoding settings sent with every generate_content call. The MIME type
# requests a JSON-formatted reply, which analyze_image then parses.
generation_config = dict(
    temperature=0.2,
    top_p=0.95,
    top_k=40,
    max_output_tokens=2048,  # upper bound on the reply length
    response_mime_type="application/json",
)
# Every harm category below uses the same blocking threshold.
_SAFETY_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
safety_settings = [
    {"category": category, "threshold": _SAFETY_THRESHOLD}
    for category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
# Default instruction text. analyze_image sends it (whitespace-stripped) as the
# text part of the request whenever the caller supplies no prompt of its own.
# It travels as ordinary content next to the image; the model is not created
# with a separate system instruction.
# NOTE(review): the prompt asks for a JSON array *and* a free-text disclaimer,
# while generation_config requests response_mime_type="application/json", so
# the disclaimer can only appear inside the JSON (or be dropped) — confirm the
# intended output shape before relying on it.
system_prompt = """ As a highly skilled medical practitioner specializing in image analysis, you are tasked with examining medical images for a renowned hospital. Your expertise is crucial in identifying any anomalies, diseases, or health issues that may be present in the images. Your responsibilities include:
1. Detailed Analysis: Thoroughly analyze each image, focusing on identifying any abnormal findings that may indicate underlying medical conditions.
2. Finding Report: Document all observed anomalies or signs of disease. Clearly articulate these findings in a structured report format, ensuring accuracy and clarity.
3. Recommendations and Next Steps: Provide detailed recommendations based on your findings. Outline the necessary follow-up actions or additional tests required to confirm diagnoses or assess treatment options.
4. Treatment Suggestions: Offer preliminary treatment suggestions or interventions based on the identified conditions, collaborating with the healthcare team to develop comprehensive patient care plans.
5. Output Format: Your output should be a JSON array (list) of objects, each describing one disease or medical finding using the structure below:
[{"findings": "Description of the first disease or condition.", "severity": "MILD/SEVERE/CRITICAL", "recommendations": ["Follow-up test 1", "Follow-up test 2"], "treatment_suggestions": ["Treatment 1", "Treatment 2"], "home_care_guidance": ["Care tip 1", "Care tip 2"] }, { "findings": "Description of the second disease or condition.", "severity": "MILD/SEVERE/CRITICAL", "recommendations": ["Follow-up test A", "Follow-up test B"], "treatment_suggestions": ["Treatment A", "Treatment B"], "home_care_guidance": ["Care tip A", "Care tip B"] } ]
Important Notes: 1. Scope of Response: Only respond if the image pertains to a human health issue. 2. Clarity of Image: Ensure the image is clear and suitable for accurate analysis. 3. Disclaimer: Accompany your analysis with the disclaimer: “Consult with a doctor before making any decisions.” 4. Your Insights are Invaluable: Your insights play a crucial role in guiding clinical decisions. Please proceed with your analysis, adhering to the structured approach outlined above. """
# Gemini model shared by every analysis request.
_MODEL_NAME = "gemini-2.5-flash-lite"
model = genai.GenerativeModel(model_name=_MODEL_NAME)

# FastAPI application; the /analyze and /analyze_json routes are registered on it.
app = FastAPI()
async def _call_model_blocking(request_inputs, generation_cfg, safety_cfg):
    """Run the blocking ``model.generate_content`` call in a separate thread.

    Calling the synchronous SDK method directly inside a coroutine would stall
    uvicorn's event loop for the whole model round-trip; ``asyncio.to_thread``
    offloads it and awaits the result instead.

    Args:
        request_inputs: content parts passed positionally to ``generate_content``.
        generation_cfg: value for the ``generation_config`` keyword.
        safety_cfg: value for the ``safety_settings`` keyword.

    Returns:
        Whatever ``model.generate_content`` returns.

    Raises:
        Any exception raised by ``model.generate_content`` propagates unchanged.
    """
    # to_thread (3.9+) replaces get_event_loop()+run_in_executor: it always uses
    # the *running* loop, which is what a coroutine should target.
    return await asyncio.to_thread(
        model.generate_content,
        request_inputs,
        generation_config=generation_cfg,
        safety_settings=safety_cfg,
    )
# Greedy, DOTALL: captures from an opening "[" to the last "]" (or "{" to the
# last "}"), starting at the earliest position where such a span exists.
_JSON_BLOCK_RE = re.compile(r"(\[.*\]|\{.*\})", re.DOTALL)


def _extract_response_text(response: Any) -> str:
    """Return the model's reply text, falling back to ``str(response)``.

    Understands the SDK response object (via its ``.text`` accessor) and plain
    dict responses of the ``{"candidates": [...]}`` form, including the
    ``{"content": {"parts": [{"text": ...}]}}`` candidate shape.
    """
    try:
        text = response.text
    except (AttributeError, ValueError):
        # AttributeError: no ``.text`` at all (e.g. a dict).
        # ValueError: the google-generativeai ``.text`` accessor raises this
        # when the reply has no usable text part (e.g. blocked by safety
        # settings); getattr(..., None) would not have caught it.
        text = None

    if not text and isinstance(response, dict):
        candidates = response.get("candidates") or []
        first = candidates[0] if isinstance(candidates, list) and candidates else None
        if isinstance(first, dict):
            content = first.get("content") or first.get("text")
            if isinstance(content, dict):
                parts = content.get("parts") or []
                content = "".join(
                    part["text"]
                    for part in parts
                    if isinstance(part, dict) and isinstance(part.get("text"), str)
                )
            text = content

    # Anything that is not a non-empty string (None, "", dict, list, ...) falls
    # back to the response's string form so the parser always gets text.
    if not isinstance(text, str) or not text:
        text = str(response)
    return text


def _parse_model_json(text: str) -> Any:
    """Parse JSON from model output, tolerating code fences and surrounding prose.

    Returns the parsed value; ``{"raw_found_json": span}`` when a JSON-looking
    span was found but is not valid JSON; otherwise ``{"raw_output": cleaned}``.
    """
    # Drop ``` / ```json fence markers the model may wrap its answer in.
    clean = re.sub(r"```(?:json)?", "", text).strip()
    try:
        return json.loads(clean)
    except json.JSONDecodeError:
        pass

    match = _JSON_BLOCK_RE.search(clean)
    if match is None:
        return {"raw_output": clean}
    candidate = match.group(1)
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return {"raw_found_json": candidate}


async def analyze_image(image_bytes: bytes, mime_type: str, prompt: Optional[str] = None) -> Any:
    """Send an image plus a text prompt to the model and parse its JSON reply.

    Args:
        image_bytes: raw image bytes; sent base64-encoded as inline data.
        mime_type: MIME type reported to the model for the image.
        prompt: optional prompt text. ``None``, empty, or whitespace-only values
            fall back to the module-level ``system_prompt``.

    Returns:
        The parsed JSON value on success; ``{"raw_found_json": ...}`` when a
        JSON-looking span was found but failed to parse; otherwise
        ``{"raw_output": ...}`` with the cleaned reply text.

    Raises:
        RuntimeError: the model call itself failed (original error chained).
    """
    base64_img = base64.b64encode(image_bytes).decode("utf-8")
    # A blank prompt is treated like a missing one so the model never receives
    # an empty text part.
    text_prompt = (prompt or "").strip() or system_prompt.strip()

    # Two parts: the inline image, then the text instruction.
    request_inputs = [
        {"inline_data": {"mime_type": mime_type, "data": base64_img}},
        {"text": text_prompt},
    ]
    try:
        response = await _call_model_blocking(request_inputs, generation_config, safety_settings)
    except Exception as e:
        raise RuntimeError(f"Model call failed: {e}") from e

    return _parse_model_json(_extract_response_text(response))
@app.post("/analyze")
async def analyze_endpoint(file: UploadFile = File(...), prompt: Optional[str] = Form(None)):
    """
    Upload an image file (field name `file`) and optional text `prompt`.
    Returns parsed JSON (or raw model output if JSON couldn't be parsed).

    Responds with HTTP 400 for an empty upload and HTTP 500 (error message as
    ``detail``) when the analysis fails.
    """
    contents = await file.read()  # raw bytes of the uploaded file
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    # Fall back to PNG when the client sent no Content-Type for the part.
    mime = file.content_type or "image/png"

    # Exactly one model call, inside the handler so any failure is reported
    # as a 500 with its message.
    try:
        result = await analyze_image(contents, mime, prompt)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    # The "Detected_Anomolies" key (sic) is part of the response contract.
    return JSONResponse(content={"Detected_Anomolies": result})
@app.post("/analyze_json")
async def analyze_json(req: AnalyzeRequest):
    """Analyze a base64-encoded image sent as a JSON body.

    ``req.image_base64`` may be plain base64 or a ``data:<mime>;base64,<data>``
    URL. For a data URL the declared MIME type is used; otherwise
    ``image/png`` is assumed.

    Returns:
        ``{"result": <analysis>}``, where the analysis is whatever
        ``analyze_image`` returns.

    Raises:
        HTTPException(400): the payload is not decodable base64 or decodes to
            nothing.
        HTTPException(500): the analysis failed (error message as ``detail``).
    """
    payload = req.image_base64.strip()
    mime = "image/png"
    if payload.startswith("data:"):
        # Split off the data-URL header. Plain base64 can never start with
        # "data:" (":" is outside the alphabet), so this cannot misfire. Without
        # it, b64decode's default mode drops only non-alphabet characters and
        # would fold the header's letters into the image bytes.
        header, sep, data = payload.partition(",")
        if sep:
            declared = header[len("data:"):].split(";", 1)[0].strip()
            if declared:
                mime = declared
            payload = data

    try:
        image_bytes = base64.b64decode(payload)
    except ValueError as e:  # binascii.Error (bad padding) subclasses ValueError
        raise HTTPException(status_code=400, detail=f"Invalid base64 image data: {e}") from e
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Decoded image is empty.")

    # Same error contract as /analyze: model failures become a 500 with detail.
    try:
        result = await analyze_image(image_bytes, mime, req.prompt)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"result": result}