# mahmoudelsheemy's picture
# Add FastAPI teeth detection API with advanced recommendations
# e544e72
from fastapi import FastAPI, File, UploadFile, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import uvicorn
import numpy as np
from io import BytesIO
from PIL import Image
import tensorflow as tf
from tensorflow.keras.applications.efficientnet import preprocess_input # type: ignore
import pyheif
from transformers import pipeline
import torch
import uuid
import json
import os
from pathlib import Path
from datetime import datetime
import gdown
import requests
import time
# ============================================================
# CONFIGURATION
# ============================================================
IMAGE_SIZE = 224  # side length (pixels) of the square input fed to both Keras models
# Directory containing this module. Bundled assets (models and knowledge base)
# are resolved against it so startup does not depend on the process's current
# working directory.
BASE_DIR = Path(__file__).parent
BINARY_MODEL_PATH = str(BASE_DIR / "model_2_fixed.h5")
DISEASE_MODEL_PATH = str(BASE_DIR / "LAST_model_fixed_v2.keras")
HF_TEETH_HEALTH_MODEL = "steven123/Check_GoodBad_Teeth"  # Hugging Face Hub model id
# transformers.pipeline device: GPU 0 when CUDA is available, otherwise CPU (-1).
DEVICE = 0 if torch.cuda.is_available() else -1
# Labels are indexed by model output position. The (misspelled) binary labels
# are returned verbatim in API responses, so renaming them is an API change.
BINARY_CLASSES = ["not_teath", "teath"]
DISEASE_CLASSES = ["Calculus", "Data caries", "Gingivitis", "Mouth Ulcer", "Tooth Discoloration", "hypodontia"]
KNOWLEDGE_BASE_PATH = BASE_DIR / "knowledge_base" / "clinical_rules.json"
# ============================================================
# LOAD KNOWLEDGE BASE
# ============================================================
def load_knowledge_base(path=None):
    """Load the clinical knowledge base from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``KNOWLEDGE_BASE_PATH``.

    Returns:
        The parsed JSON object. If the file is missing, cannot be parsed, or
        its top level is not a JSON object, a warning is printed and the empty
        skeleton ``{"diseases": {}, "general_rules": {}}`` is returned. This
        lets the API still start, just without recommendations data.
    """
    if path is None:
        path = KNOWLEDGE_BASE_PATH
    empty = {"diseases": {}, "general_rules": {}}
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"⚠️ Warning: Knowledge base file not found at {path}")
        return empty
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors;
        # previously either one crashed the whole module at import time.
        print(f"⚠️ Warning: Knowledge base at {path} could not be parsed: {e}")
        return empty
    if not isinstance(data, dict):
        # Callers use .get() on the result, so a list/scalar would crash later.
        print(f"⚠️ Warning: Knowledge base at {path} must contain a JSON object")
        return empty
    return data
# Parsed once at import time; each section falls back to an empty dict when absent.
knowledge_base_data = load_knowledge_base()
diseases_db, general_rules = (
    knowledge_base_data.get(section, {}) for section in ("diseases", "general_rules")
)
# ============================================================
# FASTAPI INIT
# ============================================================
app = FastAPI(title="Teeth Detection API")

# Fully open CORS policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive setup; confirm it is intended before adding cookie/session auth.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# ============================================================
# LOAD MODELS
# ============================================================
# Model handles; each stays None until (and unless) it loads successfully below.
BINARY_MODEL = DISEASE_MODEL = TEETH_HEALTH_MODEL = None
def load_keras_model(path: str):
    """Load an uncompiled Keras model from ``path``.

    Returns the model, or None (after printing an error) when the file is
    missing or loading raises.
    """
    if not os.path.exists(path):
        print(f"[ERROR] Model not found at {path}")
        return None
    try:
        loaded = tf.keras.models.load_model(path, compile=False)
        print(f"[SUCCESS] Model loaded from {path}")
        return loaded
    except Exception as exc:
        print(f"[ERROR] Failed to load model {path}: {exc}")
        return None
print("\n[INFO] Loading models...")
# Loaded in order: binary model first, then the disease model.
BINARY_MODEL, DISEASE_MODEL = map(load_keras_model, (BINARY_MODEL_PATH, DISEASE_MODEL_PATH))

# The HF health classifier is optional at startup: if it fails to load,
# TEETH_HEALTH_MODEL is not reassigned and predict_teeth_health() answers 503.
try:
    TEETH_HEALTH_MODEL = pipeline(
        "image-classification",
        model=HF_TEETH_HEALTH_MODEL,
        device=DEVICE,
    )
    print("[SUCCESS] HuggingFace model loaded")
except Exception as exc:
    print(f"[ERROR] HuggingFace model failed: {exc}")
# ============================================================
# IMAGE PROCESSING
# ============================================================
def _decode_heif(image_bytes: bytes) -> Image.Image:
    """Decode HEIC/HEIF bytes with pyheif; raises when the data is not HEIF."""
    heif = pyheif.read_heif(image_bytes)
    return Image.frombytes(heif.mode, heif.size, heif.data, "raw", heif.mode, heif.stride)


def load_image(image_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    pyheif is tried first (HEIC/HEIF). Anything it cannot handle is passed to
    Pillow (JPEG, PNG, ...).

    Raises:
        HTTPException: 422 when neither decoder can read the bytes.
    """
    try:
        return _decode_heif(image_bytes).convert("RGB")
    except Exception:
        # Not HEIF (or pyheif could not decode it): fall back to Pillow.
        try:
            return Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as err:
            raise HTTPException(
                status_code=422,
                detail=f"Invalid or corrupted image: {str(err)}",
            )
def _load_resized_array(image_bytes: bytes) -> np.ndarray:
    """Decode, resize to IMAGE_SIZE x IMAGE_SIZE and return a float32 RGB array."""
    resized = load_image(image_bytes).resize((IMAGE_SIZE, IMAGE_SIZE))
    return np.array(resized).astype(np.float32)


def preprocess_for_binary(image_bytes: bytes) -> np.ndarray:
    """Prepare an upload for the binary teeth/not-teeth model (unscaled pixels)."""
    return _load_resized_array(image_bytes)


def preprocess_for_disease(image_bytes: bytes) -> np.ndarray:
    """Prepare an upload for the disease model via EfficientNet's preprocess_input.

    NOTE(review): Keras documents efficientnet.preprocess_input as a
    pass-through placeholder. Confirm whether this input really differs from
    preprocess_for_binary's.
    """
    return preprocess_input(_load_resized_array(image_bytes))
def assess_urgency(result):
    """Extract the urgency level/message pair from a recommendations dict.

    Missing keys default to level "low" and an empty message.
    """
    level = result.get("urgency_level", "low")
    message = result.get("urgency_message", "")
    return {"level": level, "message": message}
def combine_advice(result, limit=4, categories=("essential", "recommended")):
    """Flatten a recommendations dict's home-care advice into one short list.

    get_weighted_recommendations builds ``personalized_home_care`` as a dict
    of lists keyed by category. Iterating that dict directly (as this function
    used to) yields the category *names* instead of the advice. Items are now
    pulled category by category in ``categories`` order. "avoid" is left out
    by default because a flat list loses its "do not do this" meaning. A plain
    list is still accepted for backward compatibility.

    Args:
        result: recommendations dict containing ``personalized_home_care``.
        limit: maximum number of items returned (default 4).
        categories: dict keys to take advice from, in priority order.

    Returns:
        Up to ``limit`` unique, non-empty advice items in first-seen order.
    """
    home_care = result.get("personalized_home_care") or []
    if isinstance(home_care, dict):
        candidates = [
            item
            for category in categories
            for item in (home_care.get(category) or [])
        ]
    else:
        candidates = list(home_care)
    combined = []
    for advice in candidates:
        # List membership (not a set) so unhashable advice entries still work.
        if advice and advice not in combined:
            combined.append(advice)
    return combined[:limit]
#===========================================================
# RECOMMENDATION ENGINE
# ============================================================
def add_unique_advice(advice_list, target_list):
    """Append each truthy item of ``advice_list`` to ``target_list`` if not already present.

    Mutates ``target_list`` in place and returns None. Exact duplicates (also
    within ``advice_list`` itself) are skipped. A falsy ``advice_list``
    (None or empty) is a no-op.
    """
    for item in advice_list or ():
        if item and item not in target_list:
            target_list.append(item)
#============================================================
# Advanced weighted recommendation
#============================================================
# Numeric weights for the knowledge-base "severity" and "urgency" labels.
_SEVERITY_SCALE = {"high": 3, "medium": 2, "moderate": 2, "mild": 1, "low": 1, "structural": 2}
_URGENCY_SCALE = {"high": 3, "medium": 2, "low": 1}
# Severity labels handled at the "moderate" treatment level. "moderate" is
# scored like "medium" in _SEVERITY_SCALE, so it belongs here too (it used to
# fall through silently to "conservative").
_MODERATE_SEVERITIES = ("medium", "moderate", "structural")
# Disease categories whose scores react strongly / mildly to reported symptoms.
_STRONG_SYMPTOM_CATEGORIES = ("tooth_decay", "inflammatory")
_MILD_SYMPTOM_CATEGORIES = ("soft_tissue", "mineral_deposit", "developmental", "aesthetic")


def _symptom_adjusted_scores(severity_value, urgency_value, category, age, pain_level, bleeding):
    """Raise base severity/urgency scores for reported pain (0-10), bleeding and age."""
    bleeding_factor = 1 if bleeding else 0
    if category in _STRONG_SYMPTOM_CATEGORIES:
        severity_value += pain_level * 0.3
        urgency_value += pain_level * 0.3
        urgency_value += bleeding_factor * 1.5
    elif category in _MILD_SYMPTOM_CATEGORIES:
        severity_value += pain_level * 0.1
        urgency_value += pain_level * 0.1
        urgency_value += bleeding_factor * 0.4
    # Children and older adults are escalated slightly.
    if age < 12 or age > 65:
        urgency_value += 0.5
    return severity_value, urgency_value


def _confidence_factor(confidence, confidence_rules):
    """Weight a prediction by its confidence band, using general_rules["confidence_weighting"].

    All three bands read from the rules (the high band was hard-coded to 1.0
    before, inconsistent with medium/low). Defaults are 1.0 / 0.5 / 0.2.
    """
    if confidence >= 0.8:
        return confidence_rules.get("high", 1.0)
    if confidence >= 0.5:
        return confidence_rules.get("medium", 0.5)
    return confidence_rules.get("low", 0.2)


def _treatment_level(severity_key):
    """Map a lower-cased severity label to a treatment intensity."""
    if severity_key == "high":
        return "aggressive"
    if severity_key in _MODERATE_SEVERITIES:
        return "moderate"
    return "conservative"


def _risk_category(normalized_risk):
    """Bucket the capped (0-5) risk score into a stage label."""
    if normalized_risk >= 4:
        return "Critical"
    if normalized_risk >= 3:
        return "Advanced"
    if normalized_risk >= 2:
        return "Progressive"
    return "Early Stage"


def _urgency(normalized_risk):
    """Return (urgency_level, urgency_message) for the capped risk score."""
    if normalized_risk >= 3.5:
        return "high", "Immediate dental consultation required (within 24-48 hours)."
    if normalized_risk >= 2.0:
        return "medium", "Dental appointment recommended within 1-4 weeks."
    return "low", "Maintain oral hygiene and monitor symptoms."


def get_weighted_recommendations(top_predictions, age: int, pain_level: int, bleeding: bool):
    """Turn the disease classifier's top predictions into a recommendation report.

    Each prediction with confidence above 0.05 whose class exists in
    ``diseases_db`` contributes a confidence-weighted risk score. The score is
    built from its knowledge-base severity/urgency labels, adjusted for the
    patient's pain level, bleeding and age. Treatments, home-care advice,
    follow-ups and the dentist flag are gathered from the matching entries.

    Args:
        top_predictions: list of {"class": str, "confidence": float} dicts
            (as produced by predict_disease).
        age: patient age in years.
        pain_level: self-reported pain, 0-10.
        bleeding: whether the patient reports bleeding.

    Returns:
        The report dict initialised below. It is returned with its default
        values when there are no usable predictions.
    """
    result = {
        "timestamp": datetime.now().isoformat(),
        "primary_condition": None,
        "overall_risk_score": 0.0,
        "risk_category": "Early Stage",
        "clinical_overview": [],
        "priority_treatment_plan": [],
        "supportive_treatments": [],
        "personalized_home_care": {"essential": [], "recommended": [], "avoid": []},
        "follow_up_recommendation": [],
        "requires_dentist": False,
        "urgency_level": "low",
        "urgency_message": ""
    }
    if not top_predictions:
        return result

    confidence_rules = general_rules.get("confidence_weighting", {})
    filtered_predictions = [p for p in top_predictions if p.get("confidence", 0) > 0.05]
    if not filtered_predictions:
        return result

    total_conf = sum(p["confidence"] for p in filtered_predictions)
    total_risk_score = 0
    detected_conditions = []
    home_care = result["personalized_home_care"]

    for pred in filtered_predictions:
        disease = pred["class"]
        confidence = pred["confidence"]
        if disease not in diseases_db:
            continue
        # Share of this prediction among all kept predictions (known or not).
        weight = confidence / total_conf if total_conf > 0 else 0
        detected_conditions.append(disease)
        disease_info = diseases_db[disease]
        base = disease_info.get("base_info", {})
        treatments = disease_info.get("treatment_options", {}).get("primary", [])
        home_advice = disease_info.get("home_advice", {})

        # Labels are matched case-insensitively; the original label is reported.
        severity = base.get("severity", "low")
        urgency = base.get("urgency", "low")
        severity_key = str(severity).lower()
        urgency_key = str(urgency).lower()

        severity_value, urgency_value = _symptom_adjusted_scores(
            _SEVERITY_SCALE.get(severity_key, 1),
            _URGENCY_SCALE.get(urgency_key, 1),
            base.get("category", ""),
            age,
            pain_level,
            bleeding,
        )
        total_risk_score += (
            (severity_value * 0.6 + urgency_value * 0.4)
            * weight
            * _confidence_factor(confidence, confidence_rules)
        )

        treatment_level = _treatment_level(severity_key)
        result["clinical_overview"].append({
            "condition": disease,
            "confidence_percent": round(confidence * 100, 2),
            "impact_weight": round(weight, 3),
            "severity": severity,
            "urgency": urgency,
            "treatment_level": treatment_level
        })

        # Aggressive cases get every primary treatment; moderate ones only the first.
        if treatment_level == "aggressive":
            add_unique_advice(treatments, result["priority_treatment_plan"])
        elif treatment_level == "moderate":
            add_unique_advice(treatments[:1], result["supportive_treatments"])

        # Keep home care compact: at most two items per category per disease.
        for bucket in ("essential", "recommended", "avoid"):
            add_unique_advice(home_advice.get(bucket, [])[:2], home_care[bucket])

        # NOTE(review): build-up is a chairside procedure but is listed under
        # home-care "recommended"; confirm this placement is intended.
        build_up = disease_info.get("build_up_recommendation", {})
        if build_up.get("applicable", False):
            for cond in build_up.get("conditions", {}).values():
                if confidence >= cond.get("confidence_threshold", 0.3):
                    materials = ", ".join(cond.get("materials", []))
                    reason = cond.get("reason", "")
                    advice_text = f"Consider build-up using {materials} ({reason})"
                    add_unique_advice([advice_text], home_care["recommended"])

        if base.get("requires_dentist", False):
            result["requires_dentist"] = True

        follow_up = disease_info.get("follow_up")
        if follow_up:
            add_unique_advice([follow_up], result["follow_up_recommendation"])

    normalized_risk = min(total_risk_score, 5)
    result["overall_risk_score"] = round(normalized_risk, 2)
    result["risk_category"] = _risk_category(normalized_risk)
    result["urgency_level"], result["urgency_message"] = _urgency(normalized_risk)
    # The "high" message says a consultation is *required*; keep the flag consistent.
    if result["urgency_level"] == "high":
        result["requires_dentist"] = True

    if "Calculus" in detected_conditions and "Gingivitis" in detected_conditions:
        result["clinical_overview"].append({
            "condition": "Clinical Interaction",
            "note": "Dental calculus may be contributing to gingival inflammation.",
            "impact_weight": 0
        })

    # Highest-impact condition first; the interaction note (weight 0) sorts last.
    result["clinical_overview"] = sorted(
        result["clinical_overview"], key=lambda x: x.get("impact_weight", 0), reverse=True
    )
    if result["clinical_overview"]:
        result["primary_condition"] = result["clinical_overview"][0]["condition"]
    return result
# ============================================================
# PREDICTION FUNCTIONS
# ============================================================
def predict_teeth(image: np.ndarray, threshold: float = 0.5) -> dict:
    """Decide whether a preprocessed image shows teeth.

    Args:
        image: one image from preprocess_for_binary; a batch axis is added here.
        threshold: minimum model score to classify the image as "teath".

    Returns:
        dict with ``is_teeth``, ``class`` (a BINARY_CLASSES label),
        ``confidence`` (score of the chosen class), ``raw_score`` and
        ``threshold``.

    Raises:
        HTTPException: 503 when the binary model failed to load.
    """
    if BINARY_MODEL is None:
        # Same contract as predict_teeth_health: a missing model is a 503
        # instead of an opaque AttributeError on None.predict.
        raise HTTPException(status_code=503, detail="Binary model not loaded")
    batch = np.expand_dims(image, axis=0)
    score = BINARY_MODEL.predict(batch, verbose=0)[0][0]
    is_teeth = score >= threshold
    # Assumes a single sigmoid-style score in [0, 1]; report the chosen class's share.
    confidence = score if is_teeth else 1 - score
    return {
        "is_teeth": bool(is_teeth),
        "class": BINARY_CLASSES[1] if is_teeth else BINARY_CLASSES[0],
        "confidence": float(confidence),
        "raw_score": float(score),
        "threshold": threshold
    }
def predict_teeth_health(image_bytes):
    """Classify an upload with the Hugging Face good/bad-teeth pipeline.

    Returns the first pipeline result as ``predicted_class``/``confidence``
    (pipeline results are assumed to be ordered best-first), plus the raw
    ``all_predictions`` list.

    Raises:
        HTTPException: 503 when the pipeline failed to load; 422 (from
            load_image) for undecodable bytes.
    """
    if TEETH_HEALTH_MODEL is None:
        raise HTTPException(status_code=503, detail="Health model not loaded")
    predictions = TEETH_HEALTH_MODEL(load_image(image_bytes))
    best = predictions[0]
    return {
        "predicted_class": best["label"],
        "confidence": float(best["score"]),
        "all_predictions": predictions,
    }
def predict_disease(image: np.ndarray, top_k: int = 3) -> dict:
    """Classify a preprocessed image into one of DISEASE_CLASSES.

    Args:
        image: one image from preprocess_for_disease; a batch axis is added here.
        top_k: number of ranked predictions to return (default 3).

    Returns:
        dict with ``predicted_class``, ``confidence`` and ``top_predictions``
        (a list of {"class", "confidence"} dicts in descending confidence;
        ties keep DISEASE_CLASSES order).

    Raises:
        HTTPException: 503 when the disease model failed to load.
        RuntimeError: when the model's output size does not match DISEASE_CLASSES.
    """
    if DISEASE_MODEL is None:
        raise HTTPException(status_code=503, detail="Disease model not loaded")
    scores = DISEASE_MODEL.predict(np.expand_dims(image, axis=0), verbose=0)[0]
    if len(scores) != len(DISEASE_CLASSES):
        # Fail loudly instead of an IndexError or silently mislabelled classes.
        raise RuntimeError(
            f"Disease model returned {len(scores)} scores for {len(DISEASE_CLASSES)} classes"
        )
    top_index = int(np.argmax(scores))
    ranked = sorted(
        (
            {"class": label, "confidence": float(score)}
            for label, score in zip(DISEASE_CLASSES, scores)
        ),
        key=lambda entry: entry["confidence"],
        reverse=True,
    )
    return {
        "predicted_class": DISEASE_CLASSES[top_index],
        "confidence": float(scores[top_index]),
        "top_predictions": ranked[:top_k],
    }
# ============================================================
# MAIN PIPELINE
# ============================================================
def teeth_diagnosis_pipeline(image_bytes: bytes, threshold: float = 0.5) -> dict:
    """Run the three-stage diagnosis on one uploaded image.

    1. The binary model rejects images that do not show teeth.
    2. The health classifier screens for healthy teeth; a "good teeth" label
       with confidence >= 0.84 skips stage 3.
    3. Otherwise the disease classifier runs.

    NOTE(review): each stage calls load_image separately, so the upload is
    decoded up to three times per request.
    """
    binary_result = predict_teeth(preprocess_for_binary(image_bytes), threshold)
    if not binary_result["is_teeth"]:
        return {
            "status": "rejected",
            "binary_result": binary_result,
            "message": "Image does not contain teeth",
        }

    health_result = predict_teeth_health(image_bytes)
    looks_healthy = (
        str(health_result.get("predicted_class", "")).lower() == "good teeth"
        and health_result.get("confidence", 0) >= 0.84
    )
    if looks_healthy:
        disease_result = {
            "message": "Teeth are healthy and free of diseases",
            "predicted_class": None,
            "top_predictions": [],
        }
    else:
        disease_result = predict_disease(preprocess_for_disease(image_bytes))

    return {
        "status": "success",
        "binary_result": binary_result,
        "teeth_health_result": health_result,
        "disease_result": disease_result,
    }
# ============================================================
# ENDPOINTS
# ============================================================
@app.get("/")
def root():
    """Describe the service and list its processing stages in order."""
    stages = [
        "Teeth Detection",
        "Teeth Health Classification",
        "Disease Classification",
    ]
    return {"system": "Integrated Teeth Detection & Diagnosis API", "pipeline": stages}
@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    threshold: float = Query(0.5, ge=0.0, le=1.0),
):
    """Run the full diagnosis pipeline on an uploaded image.

    ``threshold`` is the binary teeth-detection cut-off. It is now validated
    to [0, 1] like in /advanced-recommendations, so out-of-range values get a
    422 instead of silently accepting or rejecting every image.
    """
    image_bytes = await file.read()
    result = teeth_diagnosis_pipeline(image_bytes, threshold)
    result["filename"] = file.filename
    return result
@app.post("/detect-teeth")
async def detect_teeth(
    file: UploadFile = File(...),
    threshold: float = Query(0.5, ge=0.0, le=1.0),
):
    """
    Detect whether the image contains teeth or not.

    HTTPExceptions raised downstream (422 undecodable image, 503 model not
    loaded) are passed through unchanged. Any other failure is logged with
    the request id and reported as a 500.
    """
    request_id = str(uuid.uuid4())
    try:
        image_bytes = await file.read()
        image = preprocess_for_binary(image_bytes)
        binary_result = predict_teeth(image, threshold)
        return {
            "status": "success",
            "request_id": request_id,
            "filename": file.filename,
            "is_teeth": binary_result["is_teeth"],
            "predicted_class": binary_result["class"],
            "confidence": binary_result["confidence"],
            "raw_score": binary_result["raw_score"],
        }
    except HTTPException:
        # Previously swallowed by the generic handler and turned into a 500.
        raise
    except Exception as e:
        print(f"[ERROR] /detect-teeth failed | request_id: {request_id} | {e!r}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error | request_id: {request_id}"
        ) from e
@app.post("/check-teeth-health")
async def check_teeth_health(file: UploadFile = File(...)):
    """Classify an uploaded image as good/bad teeth with the health model."""
    return predict_teeth_health(await file.read())
def _confidence_level(confidence: float) -> str:
    """Bucket the primary prediction's confidence into a coarse label for the summary."""
    if confidence >= 0.9:
        return "very_high"
    if confidence >= 0.7:
        return "high"
    if confidence >= 0.5:
        return "medium"
    return "low"


@app.post("/advanced-recommendations")
async def advanced_recommendations(
    file: UploadFile = File(...),
    threshold: float = Query(0.5, ge=0.0, le=1.0),
    age: int = Query(18, ge=0, le=120),
    pain_level: int = Query(0, ge=0, le=10),
    bleeding: bool = False,
):
    """Diagnose an uploaded image and return weighted clinical recommendations.

    Responses:
        200 "no_disease_detected" when the health screen finds healthy teeth.
        200 "success" with summary, diagnosis and recommendations otherwise.
        422 when the image is not recognised as teeth (or cannot be decoded).
        503 when a required model is not loaded.
        500 for any other failure (logged with the request id).
    """
    request_id = str(uuid.uuid4())
    try:
        image_bytes = await file.read()
        diagnosis = teeth_diagnosis_pipeline(image_bytes, threshold)
        if diagnosis.get("status") != "success":
            # e.g. "Image does not contain teeth" for a rejected image.
            raise HTTPException(
                status_code=422,
                detail=diagnosis.get("message", "Diagnosis failed."),
            )
        disease_result = diagnosis["disease_result"]
        top_predictions = disease_result.get("top_predictions", [])
        if not top_predictions:
            return JSONResponse(
                status_code=200,
                content={
                    "status": "no_disease_detected",
                    "request_id": request_id,
                    "filename": file.filename,
                    "summary": {
                        "primary_condition": None,
                        "confidence": 0,
                        "confidence_level": "none",
                        "overall_risk_score": 0,
                        "risk_category": "Low",
                        "urgency_level": "low",
                        "requires_dentist": False
                    },
                    "general_advice": [
                        "Continue regular dental checkups",
                        "Maintain good oral hygiene"
                    ]
                }
            )
        advanced_recs = get_weighted_recommendations(
            top_predictions, age=age, pain_level=pain_level, bleeding=bleeding
        )
        primary_conf = disease_result["confidence"]
        urgency_level = advanced_recs.get("urgency_level")
        return {
            "status": "success",
            "request_id": request_id,
            "filename": file.filename,
            "summary": {
                "primary_condition": disease_result["predicted_class"],
                "confidence": primary_conf,
                "confidence_level": _confidence_level(primary_conf),
                "overall_risk_score": advanced_recs.get("overall_risk_score"),
                "risk_category": advanced_recs.get("risk_category"),
                "urgency_level": urgency_level,
                "requires_dentist": advanced_recs.get("requires_dentist"),
                "show_emergency_banner": urgency_level == "high"
            },
            "diagnosis": {"top_predictions": top_predictions},
            "recommendations": {
                "clinical_overview": advanced_recs.get("clinical_overview"),
                "priority_treatment": advanced_recs.get("priority_treatment_plan"),
                "supportive_treatment": advanced_recs.get("supportive_treatments"),
                "home_care": advanced_recs.get("personalized_home_care"),
                "follow_up": advanced_recs.get("follow_up_recommendation"),
                "urgency_message": advanced_recs.get("urgency_message")
            }
        }
    except HTTPException:
        # Keep intended 4xx/503 statuses; they used to be masked as 500.
        raise
    except Exception as e:
        print(f"[ERROR] /advanced-recommendations failed | request_id: {request_id} | {e!r}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error | request_id: {request_id}"
        ) from e
# ============================================================
# SERVER START
# ============================================================
if __name__ == "__main__":
    # Startup banner followed by a blocking uvicorn server on all interfaces.
    rule = "=" * 70
    banner_lines = (
        rule,
        "🚀 Teeth Detection API",
        "📝 Server starting at http://localhost:7860",
        "📚 API Docs: http://localhost:7860/docs",
        rule,
    )
    for line in banner_lines:
        print(line)
    uvicorn.run(app, host="0.0.0.0", port=7860)