""" Assnani Dental Chatbot — FastAPI Backend Serves the frontend, proxies YOLO API calls, runs the symptom analysis + correlation engine, and generates AI-powered reports. """ import os import io import asyncio import base64 import httpx import fitz # PyMuPDF from PIL import Image from fastapi import FastAPI, UploadFile, File, Request from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, field_validator from typing import Optional, List from symptom_analyzer import SymptomAnalyzer from correlation_engine import CorrelationEngine from dental_expert_system import DentalTreatmentRecommender # --- Configuration --- YOLO_API_URL = os.environ.get( "YOLO_API_URL", "https://0xker-dental-x-ray-detection.hf.space/predict" ) TREAT_API_URL = os.environ.get( "TREAT_API_URL", "https://0xker-treat-recommend.hf.space/api/analyze" ) MAX_RETRIES = 3 RETRY_BACKOFF_BASE = 2 # seconds # --- App Setup --- app = FastAPI( title="Assnani Dental AI Chatbot", description="Symptom-to-X-ray Correlation Chatbot", version="1.0.0" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Mount static files app.mount("/static", StaticFiles(directory="static"), name="static") # --- Initialize Engines --- symptom_analyzer = SymptomAnalyzer() correlation_engine = CorrelationEngine() expert_system = DentalTreatmentRecommender(large_cavity_threshold=5000) # --- Retry Helper --- async def _request_with_retry(client, url, max_retries=MAX_RETRIES, **kwargs): """ Send a POST request with exponential backoff retry logic. Retries on timeout, connection, and read errors. """ last_error = None for attempt in range(max_retries): try: response = await client.post(url, **kwargs) return response except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as e: last_error = e if attempt < max_retries - 1: wait = RETRY_BACKOFF_BASE ** attempt # 1s, 2s, 4s print(f"[Retry {attempt + 1}/{max_retries}] {url} — {e}. Waiting {wait}s...") await asyncio.sleep(wait) raise last_error # --- Helpers --- def extract_images_from_pdf(pdf_bytes: bytes) -> list: """ Extract all images from a PDF file. Returns a list of tuples: (filename, image_bytes, content_type) """ images = [] try: doc = fitz.open(stream=pdf_bytes, filetype="pdf") for page_num in range(len(doc)): page = doc[page_num] image_list = page.get_images(full=True) for img_idx, img_info in enumerate(image_list): xref = img_info[0] base_image = doc.extract_image(xref) if base_image: img_bytes = base_image["image"] ext = base_image.get("ext", "png") content_type = f"image/{ext}" if ext != "jpg" else "image/jpeg" filename = f"pdf_page{page_num + 1}_img{img_idx + 1}.{ext}" images.append((filename, img_bytes, content_type)) # If no embedded images found, render the page as an image if not image_list: pix = page.get_pixmap(dpi=200) img_bytes = pix.tobytes("png") filename = f"pdf_page{page_num + 1}.png" images.append((filename, img_bytes, "image/png")) doc.close() except Exception as e: print(f"PDF extraction error: {e}") return images async def process_uploaded_files(files: List[UploadFile]) -> list: """ Process uploaded files — extract images from PDFs and pass through image files. Returns list of tuples: (filename, image_bytes, content_type) """ all_images = [] for f in files: file_bytes = await f.read() content_type = f.content_type or "" if content_type == "application/pdf" or (f.filename and f.filename.lower().endswith(".pdf")): # Extract images from PDF pdf_images = extract_images_from_pdf(file_bytes) if pdf_images: all_images.extend(pdf_images) else: print(f"No images found in PDF: {f.filename}") elif content_type.startswith("image/") or (f.filename and f.filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))): ct = content_type if content_type.startswith("image/") else "image/jpeg" all_images.append((f.filename or "image.jpg", file_bytes, ct)) else: print(f"Skipping unsupported file type: {f.filename} ({content_type})") return all_images # --- Pydantic Models --- class SymptomData(BaseModel): has_pain: bool = False pain_location: str = "" pain_type: str = "" pain_intensity: int = 0 pain_duration: str = "" pain_triggers: list = [] has_swelling: bool = False swelling_severity: str = "" has_fever: bool = False difficulty_opening: bool = False has_trauma: bool = False has_broken_tooth: bool = False previous_root_canal: bool = False last_visit: str = "" recent_extraction: bool = False @field_validator('pain_intensity') @classmethod def clamp_intensity(cls, v): """Ensure pain intensity is within 0-10 range.""" return max(0, min(10, v)) @field_validator('pain_location', 'pain_type', 'pain_duration', 'swelling_severity', 'last_visit') @classmethod def sanitize_string(cls, v): """Strip whitespace and limit string length.""" if isinstance(v, str): return v.strip()[:200] return v class CorrelationRequest(BaseModel): symptoms: dict detections: list image_width: int = 0 image_height: int = 0 class TreatmentRequest(BaseModel): api_response: Optional[dict] = None detections: Optional[list] = None # --- Routes --- @app.get("/", response_class=HTMLResponse) async def serve_index(): """Serve the main chatbot page.""" return FileResponse("static/index.html") @app.get("/health") async def health_check(): return {"status": "healthy", "service": "Assnani Dental Chatbot"} @app.post("/api/analyze-symptoms") async def analyze_symptoms(data: SymptomData): """Analyze patient symptoms and return risk assessment.""" symptoms = data.model_dump() result = symptom_analyzer.analyze(symptoms) return JSONResponse(content=result) @app.post("/api/detect-xray") async def detect_xray(images: List[UploadFile] = File(...)): """ Receive X-ray image(s) or PDF uploads, forward each image to the YOLO API, and return combined detection results. """ try: all_images = await process_uploaded_files(images) if not all_images: return JSONResponse( status_code=400, content={"error": "No valid images found in the uploaded files."} ) all_results = [] annotated_images_b64 = [] async with httpx.AsyncClient(timeout=60.0) as client: for filename, img_bytes, content_type in all_images: # Detection JSON — with retry files = {"image": (filename, img_bytes, content_type)} response = await _request_with_retry(client, YOLO_API_URL, files=files) if response.status_code != 200: all_results.append({ "filename": filename, "detections": [], "total_detections": 0, "error": f"YOLO API returned {response.status_code}" }) continue det_data = response.json() # The annotated image (result_image_b64) is already included # in the YOLO API JSON response — no second request needed. if det_data.get("results"): result_entry = det_data["results"][0] # Normalize key so frontend can always find it if result_entry.get("result_image_b64") and not result_entry.get("annotated_image_b64"): result_entry["annotated_image_b64"] = result_entry["result_image_b64"] all_results.append(result_entry) annotated_images_b64.append(result_entry.get("annotated_image_b64")) # Build combined response total_detections = sum(r.get("total_detections", len(r.get("detections", []))) for r in all_results) combined = { "results": all_results, "success": True, "total_images": len(all_results), "total_detections": total_detections, } return JSONResponse(content=combined) except httpx.TimeoutException: return JSONResponse(status_code=504, content={"error": "YOLO API timed out after retries. Please try again."}) except Exception as e: return JSONResponse(status_code=500, content={"error": f"Error processing X-ray: {str(e)}"}) @app.post("/api/correlate") async def correlate_findings(data: CorrelationRequest): """Correlate patient symptoms with YOLO detection results.""" result = correlation_engine.correlate( symptoms=data.symptoms, detections=data.detections, image_width=data.image_width, image_height=data.image_height, ) return JSONResponse(content=result) @app.post("/api/treatment-plan") async def get_treatment_plan(data: TreatmentRequest): """Generate treatment recommendations from YOLO detections.""" if data.api_response: result = expert_system.analyze_api_response(data.api_response) elif data.detections: result = expert_system.analyze_detections(data.detections) else: return JSONResponse(status_code=400, content={"error": "Provide 'api_response' or 'detections'."}) return JSONResponse(content=result) @app.post("/api/ai-report") async def get_ai_report(images: List[UploadFile] = File(...)): """ Send X-ray image(s) / PDFs to the Gemini-powered treatment recommendation API. Returns an AI-generated clinical report. """ try: all_images = await process_uploaded_files(images) if not all_images: return JSONResponse(status_code=400, content={"error": "No valid images found."}) async with httpx.AsyncClient(timeout=120.0) as client: # Send all images in one request — external API expects key "image" files_list = [("image", (fn, img_bytes, ct)) for fn, img_bytes, ct in all_images] data = {"model": "gemini-2.5-flash"} response = await _request_with_retry(client, TREAT_API_URL, files=files_list, data=data) if response.status_code != 200: return JSONResponse( status_code=502, content={"error": f"Treatment API returned {response.status_code}", "detail": response.text[:500]} ) return JSONResponse(content=response.json()) except httpx.TimeoutException: return JSONResponse(status_code=504, content={"error": "AI report timed out. Gemini may be loading — try again."}) except Exception as e: return JSONResponse(status_code=500, content={"error": f"Error generating AI report: {str(e)}"}) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)