# 0xKer's picture
# Upload 7 files
# a1b2be7 verified
"""
Assnani Dental Chatbot — FastAPI Backend
Serves the frontend, proxies YOLO API calls, runs the symptom
analysis + correlation engine, and generates AI-powered reports.
"""
import os
import io
import asyncio
import base64
import httpx
import fitz # PyMuPDF
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, field_validator
from typing import Optional, List
from symptom_analyzer import SymptomAnalyzer
from correlation_engine import CorrelationEngine
from dental_expert_system import DentalTreatmentRecommender
# --- Configuration ---
# Upstream service endpoints. Each can be overridden through an environment
# variable of the same name; otherwise the hosted default below is used.
_DEFAULT_YOLO_API_URL = "https://0xker-dental-x-ray-detection.hf.space/predict"
_DEFAULT_TREAT_API_URL = "https://0xker-treat-recommend.hf.space/api/analyze"

YOLO_API_URL = os.environ.get("YOLO_API_URL", _DEFAULT_YOLO_API_URL)
TREAT_API_URL = os.environ.get("TREAT_API_URL", _DEFAULT_TREAT_API_URL)

# Retry policy for outbound POSTs:
#   MAX_RETRIES        - total number of attempts per request.
#   RETRY_BACKOFF_BASE - exponent base; the wait after failed attempt k
#                        (0-based) is RETRY_BACKOFF_BASE ** k seconds.
MAX_RETRIES = 3
RETRY_BACKOFF_BASE = 2
# --- App Setup ---
app = FastAPI(
    title="Assnani Dental AI Chatbot",
    description="Symptom-to-X-ray Correlation Chatbot",
    version="1.0.0",
)

# Wide-open CORS: every origin, method and header is accepted.
# NOTE(review): allow_credentials=True together with a "*" origin list may let
# any site make credentialed requests; nothing visible in this module uses
# cookies or auth, so consider disabling credentials or pinning origins.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Serve everything under ./static at /static (index.html is also read from there).
app.mount("/static", StaticFiles(directory="static"), name="static")

# Analysis engines are built once at import time and reused by the route handlers.
symptom_analyzer = SymptomAnalyzer()
correlation_engine = CorrelationEngine()
expert_system = DentalTreatmentRecommender(large_cavity_threshold=5000)
# --- Retry Helper ---
async def _request_with_retry(client, url, max_retries=MAX_RETRIES, **kwargs):
    """
    Send a POST request with exponential backoff retry logic.

    Only transport failures are retried (timeouts, connection errors and read
    errors). Any HTTP response, whatever its status code, is returned to the
    caller unchanged on the first attempt that produces one.

    Args:
        client: An ``httpx.AsyncClient`` (anything exposing an async ``post``).
        url: Target URL.
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so the request is always sent at least once.
        **kwargs: Forwarded verbatim to ``client.post`` (``files=``, ``data=``...).

    Returns:
        The response from the first attempt that did not raise.

    Raises:
        httpx.TimeoutException / httpx.ConnectError / httpx.ReadError: the error
        from the final attempt, re-raised with its original traceback.
    """
    # Guard: a non-positive count previously skipped the loop and hit `raise None`.
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return await client.post(url, **kwargs)
        except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as e:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the real error
            wait = RETRY_BACKOFF_BASE ** attempt  # base**0, base**1, ... (1s, 2s with defaults)
            print(f"[Retry {attempt + 1}/{attempts}] {url} — {e}. Waiting {wait}s...")
            await asyncio.sleep(wait)
# --- Helpers ---
def extract_images_from_pdf(pdf_bytes: bytes) -> list:
    """
    Extract all images from a PDF file.

    For every page, each embedded image is returned in its original encoding.
    A page from which no embedded image could be obtained (none present, or
    every extraction came back empty) is rendered to PNG at 200 DPI instead,
    so scanned or vector-only pages still yield an image. An image object
    referenced from several pages (same xref, e.g. a repeated letterhead) is
    returned only once, avoiding duplicate downstream analysis.

    Extraction is best-effort: any error is printed and the images collected
    so far are returned; unreadable input yields an empty list.

    Returns a list of tuples: (filename, image_bytes, content_type)
    """
    images = []
    seen_xrefs = set()
    try:
        # The context manager closes the document even if a page raises mid-way.
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            for page_num, page in enumerate(doc):
                page_has_image = False
                for img_idx, img_info in enumerate(page.get_images(full=True)):
                    xref = img_info[0]
                    if xref in seen_xrefs:
                        # Identical image object already returned for an earlier page.
                        page_has_image = True
                        continue
                    base_image = doc.extract_image(xref)
                    if not base_image:
                        continue
                    seen_xrefs.add(xref)
                    page_has_image = True
                    ext = base_image.get("ext", "png")
                    content_type = "image/jpeg" if ext == "jpg" else f"image/{ext}"
                    filename = f"pdf_page{page_num + 1}_img{img_idx + 1}.{ext}"
                    images.append((filename, base_image["image"], content_type))
                # No usable embedded image on this page: rasterize the whole page.
                if not page_has_image:
                    pix = page.get_pixmap(dpi=200)
                    filename = f"pdf_page{page_num + 1}.png"
                    images.append((filename, pix.tobytes("png"), "image/png"))
    except Exception as e:
        print(f"PDF extraction error: {e}")
    return images
# Image filename suffixes accepted when the client does not send an image/*
# content type, mapped to the content type forwarded upstream for them.
_IMAGE_EXTENSION_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".webp": "image/webp",
}


async def process_uploaded_files(files: List[UploadFile]) -> list:
    """
    Process uploaded files — extract images from PDFs and pass through image files.

    A file is treated as a PDF when its content type is ``application/pdf`` or
    its name ends in ``.pdf``, and as an image when its content type starts
    with ``image/`` or its name ends in a suffix from ``_IMAGE_EXTENSION_TYPES``.
    Empty and unsupported files are skipped with a printed notice.

    An image received without an ``image/*`` content type (for example as
    ``application/octet-stream``) gets the content type implied by its
    extension rather than always being labelled JPEG.

    Returns list of tuples: (filename, image_bytes, content_type)
    """
    all_images = []
    for f in files:
        file_bytes = await f.read()
        content_type = f.content_type or ""
        lower_name = (f.filename or "").lower()

        if not file_bytes:
            # An empty payload cannot contain an image; don't forward it upstream.
            print(f"Skipping empty file: {f.filename}")
            continue

        if content_type == "application/pdf" or lower_name.endswith(".pdf"):
            pdf_images = extract_images_from_pdf(file_bytes)
            if pdf_images:
                all_images.extend(pdf_images)
            else:
                print(f"No images found in PDF: {f.filename}")
            continue

        ext_type = next(
            (ct for suffix, ct in _IMAGE_EXTENSION_TYPES.items() if lower_name.endswith(suffix)),
            None,
        )
        if content_type.startswith("image/") or ext_type:
            # Trust an explicit image/* header; otherwise derive it from the suffix.
            ct = content_type if content_type.startswith("image/") else ext_type
            all_images.append((f.filename or "image.jpg", file_bytes, ct))
        else:
            print(f"Skipping unsupported file type: {f.filename} ({content_type})")
    return all_images
# --- Pydantic Models ---
class SymptomData(BaseModel):
    """Patient-reported symptoms accepted by ``/api/analyze-symptoms``.

    Every field has a default, so any subset may be supplied. The free-text
    fields listed in ``sanitize_string`` are whitespace-stripped and cut to
    200 characters; ``pain_intensity`` is clamped into 0..10.
    """

    has_pain: bool = False
    pain_location: str = ""
    pain_type: str = ""
    pain_intensity: int = 0
    pain_duration: str = ""
    pain_triggers: list = []
    has_swelling: bool = False
    swelling_severity: str = ""
    has_fever: bool = False
    difficulty_opening: bool = False
    has_trauma: bool = False
    has_broken_tooth: bool = False
    previous_root_canal: bool = False
    last_visit: str = ""
    recent_extraction: bool = False

    @field_validator('pain_intensity')
    @classmethod
    def clamp_intensity(cls, v):
        """Pin pain intensity to the inclusive range 0..10."""
        if v < 0:
            return 0
        if v > 10:
            return 10
        return v

    @field_validator('pain_location', 'pain_type', 'pain_duration', 'swelling_severity', 'last_visit')
    @classmethod
    def sanitize_string(cls, v):
        """Trim surrounding whitespace and keep at most the first 200 characters."""
        return v.strip()[:200] if isinstance(v, str) else v
class CorrelationRequest(BaseModel):
    """Request body for ``/api/correlate``.

    All four fields are passed as keyword arguments to
    ``correlation_engine.correlate``.
    """
    symptoms: dict  # presumably a SymptomData-shaped dict — confirm against CorrelationEngine
    detections: list  # presumably detection entries from /api/detect-xray — verify with caller
    image_width: int = 0  # defaults to 0 when omitted
    image_height: int = 0  # defaults to 0 when omitted
class TreatmentRequest(BaseModel):
    """Request body for ``/api/treatment-plan``.

    Supply ``api_response`` (used when truthy) or ``detections`` (used
    otherwise, when truthy); the endpoint answers 400 if neither is truthy.
    """
    api_response: Optional[dict] = None  # presumably a raw YOLO API JSON payload — confirm
    detections: Optional[list] = None  # detection list, used only when api_response is falsy
# --- Routes ---
@app.get("/", response_class=HTMLResponse)
async def serve_index():
    """Return the chatbot front-end page from ``static/index.html``."""
    index_path = "static/index.html"
    return FileResponse(index_path)
@app.get("/health")
async def health_check():
    """Lightweight health endpoint returning a fixed status payload."""
    payload = {"status": "healthy", "service": "Assnani Dental Chatbot"}
    return payload
@app.post("/api/analyze-symptoms")
async def analyze_symptoms(data: SymptomData):
    """Analyze patient symptoms and return risk assessment.

    The validated model is dumped to a plain dict, handed to the shared
    ``SymptomAnalyzer``, and its result is returned as JSON unchanged.
    """
    return JSONResponse(content=symptom_analyzer.analyze(data.model_dump()))
@app.post("/api/detect-xray")
async def detect_xray(images: List[UploadFile] = File(...)):
    """
    Receive X-ray image(s) or PDF uploads, forward each image to the YOLO API,
    and return combined detection results.

    Every extracted image contributes exactly one entry to ``results``. The
    entry is either the first YOLO result for that image, or an error entry
    (``filename``, empty ``detections``, ``total_detections`` 0, ``error``)
    when the YOLO call returned a non-200 status, a non-JSON body, or no
    results.

    Responses: 200 with the combined payload; 400 when no usable image was
    uploaded; 504 when the YOLO API still times out after retries; 500 on any
    other unexpected error.
    """
    try:
        all_images = await process_uploaded_files(images)
        if not all_images:
            return JSONResponse(
                status_code=400,
                content={"error": "No valid images found in the uploaded files."}
            )
        all_results = []
        async with httpx.AsyncClient(timeout=60.0) as client:
            # Images are sent sequentially through one shared client.
            for filename, img_bytes, content_type in all_images:
                all_results.append(
                    await _detect_single_image(client, filename, img_bytes, content_type)
                )
        # Prefer the API's own count; fall back to the length of the detections list.
        total_detections = sum(
            r.get("total_detections", len(r.get("detections", []))) for r in all_results
        )
        combined = {
            "results": all_results,
            "success": True,
            "total_images": len(all_results),
            "total_detections": total_detections,
        }
        return JSONResponse(content=combined)
    except httpx.TimeoutException:
        return JSONResponse(status_code=504, content={"error": "YOLO API timed out after retries. Please try again."})
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Error processing X-ray: {str(e)}"})


async def _detect_single_image(client, filename, img_bytes, content_type):
    """Send one image to the YOLO API and return its result entry (never ``None``)."""
    files = {"image": (filename, img_bytes, content_type)}
    response = await _request_with_retry(client, YOLO_API_URL, files=files)
    if response.status_code != 200:
        return _detection_error(filename, f"YOLO API returned {response.status_code}")
    try:
        det_data = response.json()
    except ValueError:
        # A 200 with a non-JSON body fails this image only, not the whole batch.
        return _detection_error(filename, "YOLO API returned a non-JSON response")
    results = det_data.get("results") if isinstance(det_data, dict) else None
    if not results:
        # Report the image explicitly instead of silently dropping it.
        return _detection_error(filename, "YOLO API returned no results")
    # Only the first result is used (one image is sent per request).
    result_entry = results[0]
    # The annotated image (result_image_b64) is already included in the YOLO
    # API JSON response — no second request needed. Normalize the key so the
    # frontend can always find it under annotated_image_b64.
    if result_entry.get("result_image_b64") and not result_entry.get("annotated_image_b64"):
        result_entry["annotated_image_b64"] = result_entry["result_image_b64"]
    # Keep success entries as identifiable as error entries.
    result_entry.setdefault("filename", filename)
    return result_entry


def _detection_error(filename, message):
    """Build the per-image result entry reported when detection fails."""
    return {
        "filename": filename,
        "detections": [],
        "total_detections": 0,
        "error": message,
    }
@app.post("/api/correlate")
async def correlate_findings(data: CorrelationRequest):
    """Correlate patient symptoms with YOLO detection results.

    The request fields are forwarded as keyword arguments to the shared
    ``CorrelationEngine``; its result is returned as JSON.
    """
    engine_kwargs = {
        "symptoms": data.symptoms,
        "detections": data.detections,
        "image_width": data.image_width,
        "image_height": data.image_height,
    }
    return JSONResponse(content=correlation_engine.correlate(**engine_kwargs))
@app.post("/api/treatment-plan")
async def get_treatment_plan(data: TreatmentRequest):
    """Generate treatment recommendations from YOLO detections.

    A truthy ``api_response`` takes precedence; otherwise a truthy
    ``detections`` list is analysed. When neither is truthy the request is
    rejected with 400.
    """
    # NOTE(review): an empty detections list (an X-ray with no findings) is
    # rejected here as well — confirm that is intended.
    if not data.api_response and not data.detections:
        return JSONResponse(status_code=400, content={"error": "Provide 'api_response' or 'detections'."})
    if data.api_response:
        plan = expert_system.analyze_api_response(data.api_response)
    else:
        plan = expert_system.analyze_detections(data.detections)
    return JSONResponse(content=plan)
@app.post("/api/ai-report")
async def get_ai_report(images: List[UploadFile] = File(...)):
    """
    Send X-ray image(s) / PDFs to the Gemini-powered treatment recommendation API.
    Returns an AI-generated clinical report.

    All extracted images go out in one multipart request (repeated ``image``
    fields) together with the model name.

    Responses:
        400: no usable image was uploaded.
        502: the upstream API answered with a non-200 status, or with a 200
            whose body is not JSON; the first 500 characters of the body are
            returned as ``detail``.
        504: the upstream API still timed out after retries.
        500: any other unexpected failure.
    """
    try:
        all_images = await process_uploaded_files(images)
        if not all_images:
            return JSONResponse(status_code=400, content={"error": "No valid images found."})
        async with httpx.AsyncClient(timeout=120.0) as client:
            # Send all images in one request — external API expects key "image"
            files_list = [("image", (fn, img_bytes, ct)) for fn, img_bytes, ct in all_images]
            data = {"model": "gemini-2.5-flash"}
            response = await _request_with_retry(client, TREAT_API_URL, files=files_list, data=data)
            if response.status_code != 200:
                return JSONResponse(
                    status_code=502,
                    content={"error": f"Treatment API returned {response.status_code}", "detail": response.text[:500]}
                )
            try:
                report = response.json()
            except ValueError:
                # An upstream failure page served with 200 is still a bad gateway,
                # not an internal error of this service.
                return JSONResponse(
                    status_code=502,
                    content={"error": "Treatment API returned a non-JSON response", "detail": response.text[:500]}
                )
            return JSONResponse(content=report)
    except httpx.TimeoutException:
        return JSONResponse(status_code=504, content={"error": "AI report timed out. Gemini may be loading — try again."})
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Error generating AI report: {str(e)}"})
if __name__ == "__main__":
    # Direct-run entry point; uvicorn is imported here so it is only required
    # when the module is executed as a script.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")