Spaces:
Sleeping
Sleeping
File size: 4,844 Bytes
68709d3 bcbbf99 8b2b7ca ddfb6ec 164d858 ddfb6ec 164d858 8b2b7ca 367cd24 bcbbf99 367cd24 8b2b7ca bcbbf99 8b2b7ca 367cd24 8b2b7ca bcbbf99 8b2b7ca 367cd24 bcbbf99 8b2b7ca dea4987 367cd24 dea4987 663f17c 367cd24 663f17c 367cd24 7d1fe20 a13af90 ddfb6ec 367cd24 a13af90 ddfb6ec a13af90 367cd24 a13af90 367cd24 a13af90 f96dd8f a13af90 68709d3 367cd24 68709d3 ae28530 68709d3 663f17c bcbbf99 68709d3 bcbbf99 68709d3 367cd24 ddfb6ec 367cd24 f7dd166 367cd24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | import gradio as gr
from ultralytics import YOLO
from PIL import Image, ImageOps, ImageEnhance
import numpy as np
import io, base64
# =========================================================
# Global model cache, filled on demand by get_models() (also warmed once at startup below)
# =========================================================
# Module-level cache slots for the three YOLO detectors; None means "not loaded yet".
model_swelling = model_redness = model_bleeding = None
def get_models():
    """Return the (swelling, redness, bleeding) YOLO models, loading each at most once.

    Each module-level slot that is still ``None`` is filled by constructing a
    ``YOLO`` model from its weights file; slots that are already populated are
    reused as-is, so repeated calls are cheap.

    Returns:
        Tuple ``(model_swelling, model_redness, model_bleeding)``.
    """
    global model_swelling, model_redness, model_bleeding

    # Slots are checked independently, in a fixed order, so a partially
    # populated cache is completed rather than reloaded.
    model_swelling = (
        YOLO("models/swelling/swelling_final.pt") if model_swelling is None else model_swelling
    )
    model_redness = (
        YOLO("models/redness/redness_final.pt") if model_redness is None else model_redness
    )
    model_bleeding = (
        YOLO("models/bleeding/bleeding_final.pt") if model_bleeding is None else model_bleeding
    )
    return model_swelling, model_redness, model_bleeding
# =========================================================
# Helper functions
# =========================================================
def preprocess(image, max_dim=1024, contrast=1.05):
    """Normalise an uploaded image before YOLO inference.

    Steps:
      1. Wrap a NumPy array in a PIL image.
      2. Apply the EXIF orientation tag and force RGB.
      3. Downscale so the longer side is at most ``max_dim`` pixels. The
         aspect ratio is preserved and smaller images are never upscaled.
      4. Apply a mild contrast boost.

    Args:
        image: A ``PIL.Image.Image`` or a NumPy array accepted by ``Image.fromarray``.
        max_dim: Upper bound (pixels) for the longer side after resizing.
        contrast: Factor for ``ImageEnhance.Contrast``; 1.0 leaves the image unchanged.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = ImageOps.exif_transpose(image).convert("RGB")

    width, height = image.size
    longest = max(width, height)
    if longest > max_dim:
        scale = max_dim / longest
        # round() rather than int(): float truncation can turn 1023.999... into 1023.
        # max(1, ...) stops very thin images collapsing to a 0-pixel side,
        # which Image.resize rejects.
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
        image = image.resize(new_size, Image.LANCZOS)

    return ImageEnhance.Contrast(image).enhance(contrast)
def np_to_base64(img_np, format="JPEG"):
    """Encode an RGB NumPy image as a Base64 text string.

    The array is wrapped in a PIL image and serialised in ``format`` (JPEG by
    default) into an in-memory buffer. The buffer's bytes are then Base64-encoded.

    Returns:
        The Base64 payload as a ``str`` (no data-URI prefix).
    """
    with io.BytesIO() as buf:
        Image.fromarray(img_np).save(buf, format=format)
        encoded = base64.b64encode(buf.getvalue())
    return encoded.decode("utf-8")
def base64_to_pil(b64_str):
    """Decode a Base64 string into a PIL image.

    ``Image.open`` reads lazily, so the in-memory buffer is intentionally not
    closed here; the returned image holds on to it.
    """
    raw_bytes = base64.b64decode(b64_str)
    stream = io.BytesIO(raw_bytes)
    return Image.open(stream)
# =========================================================
# Main detection function
# =========================================================
def _run_model(model, image, conf, iou):
    """Run one YOLO model on ``image``; return (annotated RGB PIL image, any box found)."""
    result = model.predict(image, conf=conf, iou=iou)[0]
    # Results.plot() returns a BGR array; reversing the channel axis gives RGB.
    # Build the PIL image directly instead of a lossy JPEG/Base64 round trip.
    annotated = Image.fromarray(result.plot(labels=False)[:, :, ::-1])
    return annotated, len(result.boxes) > 0


def detect_gingivitis(image, conf=0.25, iou=0.5):
    """Run the swelling, redness and bleeding detectors and build a diagnosis.

    Args:
        image: Uploaded image (PIL image or NumPy array), or ``None``.
        conf: Confidence threshold forwarded to each model's ``predict``.
        iou: NMS IoU threshold forwarded to each model's ``predict``.

    Returns:
        ``[swelling_img, redness_img, bleeding_img, diagnosis_text]``.
        The three images are RGB PIL images with boxes drawn. On missing input
        or any processing error, the three images are ``None`` and the text
        explains why. Gingivitis is reported only when all three detectors
        find at least one box.
    """
    import logging

    if image is None:
        return [None, None, None, "❌ No image uploaded"]

    try:
        sw_model, rd_model, bl_model = get_models()  # cached after first load
        image = preprocess(image)
        sw_img, has_sw = _run_model(sw_model, image, conf, iou)
        rd_img, has_rd = _run_model(rd_model, image, conf, iou)
        bl_img, has_bl = _run_model(bl_model, image, conf, iou)
    except Exception as e:
        # UI boundary: show the failure in the diagnosis box instead of crashing,
        # but keep the traceback in the server log for debugging.
        logging.getLogger(__name__).exception("Gingivitis detection failed")
        return [None, None, None, f"❌ Error during processing: {e}"]

    if has_sw and has_rd and has_bl:
        diagnosis = (
            "🦷 You have gingivitis.\n\n"
            "For accurate assessment and guidance, we recommend visiting your dentist.\n\n"
            "If you have a periapical X-ray, you may try the Detect Periodontitis tool."
        )
    else:
        diagnosis = "🟢 You don't have gingivitis."
    return [sw_img, rd_img, bl_img, diagnosis]
# =========================================================
# Gradio Interface
# =========================================================
# Single-page Gradio UI: one image plus two thresholds in; three annotated
# images plus a text diagnosis out (same order as detect_gingivitis returns).
# NOTE(review): the confidence slider defaults to 0.5, while
# detect_gingivitis' own default is conf=0.25; confirm which is intended.
interface = gr.Interface(
    title="Gingivitis Detection",
    fn=detect_gingivitis,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="NMS IoU Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Swelling Detection"),
        gr.Image(type="pil", label="Redness Detection"),
        gr.Image(type="pil", label="Bleeding Detection"),
        gr.Textbox(label="Diagnosis"),
    ],
)
# =========================================================
# Warm-start: preload models on startup
# =========================================================
# Load all three models at import time so the first request does not pay the
# model-loading cost. The original second print had its emoji mis-decoded into
# a line break, which split the string literal and was a syntax error.
print("🔥 Preloading models to reduce Render cold start...")
get_models()
print("✅ Gingivitis models ready")
# =========================================================
# Start server
# =========================================================
if __name__ == "__main__":
    # 0.0.0.0 binds every network interface so the app is reachable from other hosts;
    # show_error=True surfaces exceptions in the browser UI.
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )