Spaces:
Sleeping
Sleeping
| """ | |
| Assnani Dental Chatbot — FastAPI Backend | |
| Serves the frontend, proxies YOLO API calls, runs the symptom | |
| analysis + correlation engine, and generates AI-powered reports. | |
| """ | |
| import os | |
| import io | |
| import asyncio | |
| import base64 | |
| import httpx | |
| import fitz # PyMuPDF | |
| from PIL import Image | |
| from fastapi import FastAPI, UploadFile, File, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse, FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, field_validator | |
| from typing import Optional, List | |
| from symptom_analyzer import SymptomAnalyzer | |
| from correlation_engine import CorrelationEngine | |
| from dental_expert_system import DentalTreatmentRecommender | |
# --- Configuration ---
# Upstream service endpoints. Each can be overridden by an environment
# variable with the same name as the constant.
_DEFAULT_YOLO_API_URL = "https://0xker-dental-x-ray-detection.hf.space/predict"
_DEFAULT_TREAT_API_URL = "https://0xker-treat-recommend.hf.space/api/analyze"

YOLO_API_URL = os.environ.get("YOLO_API_URL", _DEFAULT_YOLO_API_URL)
TREAT_API_URL = os.environ.get("TREAT_API_URL", _DEFAULT_TREAT_API_URL)

# Retry policy for upstream POSTs: the number of attempts, and the base of the
# exponential back-off (the wait after failed attempt k, 0-based, is
# RETRY_BACKOFF_BASE ** k seconds).
MAX_RETRIES = 3
RETRY_BACKOFF_BASE = 2
# --- App Setup ---
app = FastAPI(
    title="Assnani Dental AI Chatbot",
    description="Symptom-to-X-ray Correlation Chatbot",
    version="1.0.0",
)

# Open CORS policy: any origin may call the API.
# Credentialed CORS stays disabled for two reasons:
#   - The CORS spec does not allow credentials together with a wildcard origin.
#     With credentials on, Starlette echoes the caller's Origin instead, which
#     would let any site send cookie-bearing requests.
#   - Nothing in this module relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve files from the local ./static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# --- Initialize Engines ---
# The analysis engines are built once at import time and shared by the
# request handlers in this module.
_LARGE_CAVITY_THRESHOLD = 5000  # NOTE(review): unit presumably pixel area — confirm in dental_expert_system
symptom_analyzer, correlation_engine = SymptomAnalyzer(), CorrelationEngine()
expert_system = DentalTreatmentRecommender(
    large_cavity_threshold=_LARGE_CAVITY_THRESHOLD,
)
# --- Retry Helper ---
async def _request_with_retry(client, url, max_retries=MAX_RETRIES, **kwargs):
    """
    POST to ``url`` through ``client``, retrying transient network failures.

    Timeout, connection, and read errors are retried with exponential back-off.
    After failed attempt ``k`` (0-based) the helper sleeps
    ``RETRY_BACKOFF_BASE ** k`` seconds. With the defaults (3 attempts, base 2)
    that is 1s, then 2s.

    HTTP error statuses are not retried; whatever response arrives is returned.

    Args:
        client: An ``httpx.AsyncClient`` (anything with an async ``post``).
        url: Target URL.
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so the request is always tried at least once.
        **kwargs: Forwarded unchanged to ``client.post`` (``files``, ``data``, ...).

    Returns:
        The response from the first attempt that did not raise.

    Raises:
        httpx.TimeoutException, httpx.ConnectError or httpx.ReadError: the error
        from the final attempt, once all attempts are exhausted.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return await client.post(url, **kwargs)
        except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as e:
            if attempt == attempts - 1:
                # Out of attempts: propagate with the original traceback intact.
                raise
            wait = RETRY_BACKOFF_BASE ** attempt
            print(f"[Retry {attempt + 1}/{attempts}] {url} — {e}. Waiting {wait}s...")
            await asyncio.sleep(wait)
# --- Helpers ---
def extract_images_from_pdf(pdf_bytes: bytes, *, dpi: int = 200) -> list:
    """
    Extract the images contained in a PDF.

    Each embedded image on a page is returned as stored in the PDF. If nothing
    could be extracted from a page (for example, a page with only text or
    vector drawings), the whole page is rendered to PNG instead.

    Args:
        pdf_bytes: Raw bytes of the PDF file.
        dpi: Resolution used when rendering a page as a fallback.

    Returns:
        A list of ``(filename, image_bytes, content_type)`` tuples.

        Extraction is best-effort. On any error the problem is printed and the
        images collected so far are returned, which may be an empty list.
    """
    images = []
    try:
        # The context manager closes the document even if a page raises.
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            for page_num, page in enumerate(doc, start=1):
                extracted = 0
                for img_idx, img_info in enumerate(page.get_images(full=True), start=1):
                    base_image = doc.extract_image(img_info[0])  # img_info[0] is the image xref
                    if not base_image:
                        continue
                    ext = base_image.get("ext", "png")
                    content_type = "image/jpeg" if ext == "jpg" else f"image/{ext}"
                    filename = f"pdf_page{page_num}_img{img_idx}.{ext}"
                    images.append((filename, base_image["image"], content_type))
                    extracted += 1
                # Nothing usable was embedded on this page: rasterize the page itself.
                if not extracted:
                    pix = page.get_pixmap(dpi=dpi)
                    images.append((f"pdf_page{page_num}.png", pix.tobytes("png"), "image/png"))
    except Exception as e:
        print(f"PDF extraction error: {e}")
    return images
async def process_uploaded_files(files: List[UploadFile]) -> list:
    """
    Normalize uploads into a flat list of images ready to forward upstream.

    - PDFs (``application/pdf`` content type or a ``.pdf`` filename) are
      expanded with ``extract_images_from_pdf``.
    - Images (an ``image/*`` content type, or a .png/.jpg/.jpeg/.webp filename)
      are passed through. If the client sent no ``image/*`` type, the MIME type
      is inferred from the file extension.
    - Anything else is skipped with a printed notice.

    Returns:
        List of ``(filename, image_bytes, content_type)`` tuples.
    """
    # Suffix -> MIME type, for images uploaded without an image/* content type.
    mime_by_suffix = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
    }
    all_images = []
    for f in files:
        file_bytes = await f.read()
        content_type = f.content_type or ""
        lower_name = (f.filename or "").lower()
        suffix_mime = next(
            (mime for suffix, mime in mime_by_suffix.items() if lower_name.endswith(suffix)),
            None,
        )
        if content_type == "application/pdf" or lower_name.endswith(".pdf"):
            pdf_images = extract_images_from_pdf(file_bytes)
            if pdf_images:
                all_images.extend(pdf_images)
            else:
                print(f"No images found in PDF: {f.filename}")
        elif content_type.startswith("image/"):
            all_images.append((f.filename or "image.jpg", file_bytes, content_type))
        elif suffix_mime is not None:
            # Generic or missing content type (e.g. application/octet-stream):
            # trust the extension instead of assuming JPEG.
            all_images.append((f.filename or "image.jpg", file_bytes, suffix_mime))
        else:
            print(f"Skipping unsupported file type: {f.filename} ({content_type})")
    return all_images
# --- Pydantic Models ---
class SymptomData(BaseModel):
    """
    Patient-reported symptoms submitted for analysis.

    The validators below strip whitespace from the free-text fields and cap
    them at 200 characters. They also clamp ``pain_intensity`` into 0-10.
    """

    has_pain: bool = False
    pain_location: str = ""
    pain_type: str = ""
    pain_intensity: int = 0  # clamped to 0-10 by clamp_intensity
    pain_duration: str = ""
    pain_triggers: list = []
    has_swelling: bool = False
    swelling_severity: str = ""
    has_fever: bool = False
    difficulty_opening: bool = False
    has_trauma: bool = False
    has_broken_tooth: bool = False
    previous_root_canal: bool = False
    last_visit: str = ""
    recent_extraction: bool = False

    # Without these decorators pydantic never calls the validators below.
    @field_validator("pain_intensity")
    @classmethod
    def clamp_intensity(cls, v):
        """Ensure pain intensity is within 0-10 range."""
        return max(0, min(10, v))

    @field_validator(
        "pain_location", "pain_type", "pain_duration", "swelling_severity", "last_visit"
    )
    @classmethod
    def sanitize_string(cls, v):
        """Strip whitespace and limit string length."""
        if isinstance(v, str):
            return v.strip()[:200]
        return v
class CorrelationRequest(BaseModel):
    """Input for ``correlate_findings``: symptoms plus the detections to match them against."""

    # NOTE(review): presumably the SymptomData fields as a plain dict — confirm with the frontend.
    symptoms: dict
    # NOTE(review): presumably the detection list produced by detect_xray — confirm schema.
    detections: list
    # Source image dimensions, forwarded to the correlation engine; default 0.
    image_width: int = 0
    image_height: int = 0
class TreatmentRequest(BaseModel):
    """
    Input for ``get_treatment_plan``; supply at least one field.

    A truthy ``api_response`` takes precedence over ``detections``.
    """

    # NOTE(review): presumably the raw YOLO API JSON — confirm.
    api_response: Optional[dict] = None
    detections: Optional[list] = None
# --- Routes ---
async def serve_index():
    """Return the chatbot front page (static/index.html) as a file response."""
    index_page = "static/index.html"
    return FileResponse(index_page)
async def health_check():
    """Report service status with a fixed payload; no dependencies are probed."""
    status = {"status": "healthy", "service": "Assnani Dental Chatbot"}
    return status
async def analyze_symptoms(data: SymptomData):
    """Run the validated symptoms through the shared SymptomAnalyzer and return its risk assessment as JSON."""
    assessment = symptom_analyzer.analyze(data.model_dump())
    return JSONResponse(content=assessment)
async def detect_xray(images: List[UploadFile] = File(...)):
    """
    Run YOLO detection on uploaded X-ray images and/or PDFs.

    Each extracted image is posted to ``YOLO_API_URL`` (with retries), and
    every image yields exactly one entry in ``results``:
    - normally, the first YOLO result for that image;
    - if the API answered with a non-200 status or an empty ``results`` list,
      a placeholder with no detections and an ``error`` message.

    Returns:
        - 200 with JSON ``{"results", "success", "total_images", "total_detections"}``.
        - 400 if no usable image was uploaded.
        - 504 if the YOLO API still timed out after retries.
        - 500 for any other failure.
    """
    try:
        all_images = await process_uploaded_files(images)
        if not all_images:
            return JSONResponse(
                status_code=400,
                content={"error": "No valid images found in the uploaded files."},
            )

        all_results = []
        async with httpx.AsyncClient(timeout=60.0) as client:
            for filename, img_bytes, content_type in all_images:
                all_results.append(
                    await _detect_single_image(client, filename, img_bytes, content_type)
                )

        # Prefer the API's own per-image count; fall back to counting detections.
        total_detections = sum(
            r.get("total_detections", len(r.get("detections", []))) for r in all_results
        )
        combined = {
            "results": all_results,
            "success": True,
            "total_images": len(all_results),
            "total_detections": total_detections,
        }
        return JSONResponse(content=combined)
    except httpx.TimeoutException:
        return JSONResponse(status_code=504, content={"error": "YOLO API timed out after retries. Please try again."})
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Error processing X-ray: {str(e)}"})


async def _detect_single_image(client, filename, img_bytes, content_type) -> dict:
    """Post one image to the YOLO API and return its result entry (or an error placeholder)."""
    files = {"image": (filename, img_bytes, content_type)}
    response = await _request_with_retry(client, YOLO_API_URL, files=files)
    if response.status_code != 200:
        return _empty_detection_entry(filename, f"YOLO API returned {response.status_code}")

    results = response.json().get("results")
    if not results:
        # Keep the image visible in the combined response instead of dropping it.
        return _empty_detection_entry(filename, "YOLO API returned no results")

    entry = results[0]
    # The annotated image already arrives in the JSON as result_image_b64, so no
    # second request is needed. Mirror it under annotated_image_b64 so the
    # frontend has a single key to read.
    if entry.get("result_image_b64") and not entry.get("annotated_image_b64"):
        entry["annotated_image_b64"] = entry["result_image_b64"]
    return entry


def _empty_detection_entry(filename, error):
    """Build a per-image result with no detections and an error message."""
    return {
        "filename": filename,
        "detections": [],
        "total_detections": 0,
        "error": error,
    }
async def correlate_findings(data: CorrelationRequest):
    """Cross-reference the patient's symptoms with YOLO detections via the correlation engine."""
    engine_args = {
        "symptoms": data.symptoms,
        "detections": data.detections,
        "image_width": data.image_width,
        "image_height": data.image_height,
    }
    return JSONResponse(content=correlation_engine.correlate(**engine_args))
async def get_treatment_plan(data: TreatmentRequest):
    """
    Build treatment recommendations with the expert system.

    A truthy ``api_response`` is analyzed as a full API payload. Otherwise a
    truthy ``detections`` list is analyzed. If neither is usable, a 400 is
    returned.
    """
    # Truthiness checks: an empty dict or list counts as "not provided".
    # NOTE(review): so `detections: []` (no findings) is rejected with 400 — confirm intended.
    if not data.api_response and not data.detections:
        return JSONResponse(status_code=400, content={"error": "Provide 'api_response' or 'detections'."})

    if data.api_response:
        plan = expert_system.analyze_api_response(data.api_response)
    else:
        plan = expert_system.analyze_detections(data.detections)
    return JSONResponse(content=plan)
async def get_ai_report(images: List[UploadFile] = File(...)):
    """
    Forward X-ray image(s)/PDFs to the treatment-recommendation API and relay its report.

    Returns:
        - The upstream JSON on success.
        - 400 if no usable image was uploaded.
        - 502 with a truncated upstream body on a non-200 reply.
        - 504 on timeout after retries.
        - 500 for any other failure.
    """
    try:
        prepared = await process_uploaded_files(images)
        if not prepared:
            return JSONResponse(status_code=400, content={"error": "No valid images found."})

        async with httpx.AsyncClient(timeout=120.0) as client:
            # A single multipart request; every image goes under the repeated field "image".
            multipart = [("image", (name, payload, mime)) for name, payload, mime in prepared]
            form_fields = {"model": "gemini-2.5-flash"}
            upstream = await _request_with_retry(
                client, TREAT_API_URL, files=multipart, data=form_fields
            )
            if upstream.status_code == 200:
                return JSONResponse(content=upstream.json())
            return JSONResponse(
                status_code=502,
                content={"error": f"Treatment API returned {upstream.status_code}", "detail": upstream.text[:500]},
            )
    except httpx.TimeoutException:
        return JSONResponse(status_code=504, content={"error": "AI report timed out. Gemini may be loading — try again."})
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Error generating AI report: {str(e)}"})
if __name__ == "__main__":
    import uvicorn

    # Serve the ASGI app directly, listening on all interfaces at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")