File size: 4,844 Bytes
68709d3
 
bcbbf99
8b2b7ca
ddfb6ec
164d858
ddfb6ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164d858
8b2b7ca
367cd24
 
 
bcbbf99
367cd24
8b2b7ca
 
 
bcbbf99
8b2b7ca
367cd24
8b2b7ca
bcbbf99
 
 
 
8b2b7ca
367cd24
bcbbf99
8b2b7ca
 
 
dea4987
 
367cd24
dea4987
 
 
 
 
 
663f17c
367cd24
663f17c
 
 
367cd24
 
 
7d1fe20
a13af90
 
 
 
ddfb6ec
 
 
367cd24
a13af90
 
ddfb6ec
 
 
 
a13af90
367cd24
a13af90
 
 
 
 
 
 
 
367cd24
a13af90
 
 
 
 
f96dd8f
 
 
 
 
a13af90
 
 
 
 
 
 
 
68709d3
367cd24
 
 
68709d3
 
 
 
ae28530
68709d3
 
 
663f17c
 
 
bcbbf99
68709d3
bcbbf99
68709d3
 
367cd24
ddfb6ec
 
 
 
 
 
 
367cd24
 
 
 
f7dd166
367cd24
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import gradio as gr
from ultralytics import YOLO
from PIL import Image, ImageOps, ImageEnhance
import numpy as np
import io, base64

# =========================================================
# Lazily loaded global models (each is loaded at most once; the warm-start
# section further down triggers the first load at import time)
# =========================================================
# Module-level cache slots; None means "not loaded yet".
model_swelling = model_redness = model_bleeding = None


def get_models():
    """Return the ``(swelling, redness, bleeding)`` YOLO models.

    Each model is constructed from its weights file the first time it is
    needed and cached in its module-level global; subsequent calls return
    the cached instances without touching the disk again.
    """
    global model_swelling, model_redness, model_bleeding

    def _load_if_missing(cached, weights_path):
        # Reuse an already-loaded model; construct it only when the slot is empty.
        return YOLO(weights_path) if cached is None else cached

    # Loaded in a fixed order: swelling, then redness, then bleeding.
    model_swelling = _load_if_missing(model_swelling, "models/swelling/swelling_final.pt")
    model_redness = _load_if_missing(model_redness, "models/redness/redness_final.pt")
    model_bleeding = _load_if_missing(model_bleeding, "models/bleeding/bleeding_final.pt")

    return model_swelling, model_redness, model_bleeding


# =========================================================
# Helper functions
# =========================================================
def preprocess(image, max_dim=1024, contrast=1.05):
    """Normalize an uploaded image before running the detectors.

    Steps: wrap a numpy array as a PIL image, apply the EXIF orientation tag
    and force RGB, downscale so the longer side is at most ``max_dim`` pixels
    (aspect ratio preserved, never upscaled), then apply a mild contrast boost.

    Args:
        image: A PIL image, or a numpy array accepted by ``Image.fromarray``.
        max_dim: Upper bound for the longer side, in pixels.
        contrast: Factor passed to ``ImageEnhance.Contrast``; 1.0 leaves the
            image unchanged.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = ImageOps.exif_transpose(image).convert("RGB")

    width, height = image.size
    longest = max(width, height)
    if longest > max_dim:
        scale = max_dim / longest
        # round() rather than int() so float error can't shave a pixel off the
        # long side, and max(1, ...) so very thin images (e.g. 2000x1) don't
        # collapse to a zero-sized dimension, which Pillow refuses to resize to.
        new_size = (
            max(1, round(width * scale)),
            max(1, round(height * scale)),
        )
        image = image.resize(new_size, Image.LANCZOS)

    return ImageEnhance.Contrast(image).enhance(contrast)


def np_to_base64(img_np, format="JPEG"):
    """Encode a numpy image as a Base64 text string.

    Args:
        img_np: Array accepted by ``Image.fromarray`` (an RGB image here).
        format: Pillow output format name, e.g. ``"JPEG"`` or ``"PNG"``.

    Returns:
        The encoded image file bytes, Base64-encoded and decoded to ``str``.
    """
    with io.BytesIO() as stream:
        Image.fromarray(img_np).save(stream, format=format)
        encoded = base64.b64encode(stream.getvalue())
    return encoded.decode("utf-8")


def base64_to_pil(b64_str):
    """Decode a Base64-encoded image file into a PIL image.

    Args:
        b64_str: Base64 text (or bytes) of an encoded image file.

    Returns:
        The image returned by ``Image.open`` on an in-memory buffer.
    """
    raw_bytes = base64.b64decode(b64_str)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)


# =========================================================
# Main detection function
# =========================================================
def _result_to_pil(result):
    """Draw a YOLO result's boxes (without labels) and return an RGB PIL image."""
    # plot() yields a BGR numpy array; reversing the last axis gives RGB. The
    # reversed view has a negative stride, so make it contiguous for Pillow.
    # Converting directly (instead of a JPEG/Base64 round-trip) avoids adding
    # lossy compression artifacts to the displayed result.
    rgb = np.ascontiguousarray(result.plot(labels=False)[:, :, ::-1])
    return Image.fromarray(rgb)


def _has_detections(result):
    """Return True if the YOLO result contains at least one bounding box."""
    return len(result.boxes) > 0


def detect_gingivitis(image, conf=0.25, iou=0.5):
    """Run the swelling, redness and bleeding detectors on one image.

    Args:
        image: Uploaded image (PIL image or numpy array), or None when the
            user submitted nothing.
        conf: Confidence threshold forwarded to each model's ``predict``.
        iou: NMS IoU threshold forwarded to each model's ``predict``.

    Returns:
        A 4-item list ``[swelling_img, redness_img, bleeding_img, diagnosis]``
        matching the Gradio outputs. The images are RGB PIL images with the
        detections drawn. On missing input or any processing error the images
        are None and the text carries an error message (nothing is raised).
    """
    if image is None:
        return [None, None, None, "❌ No image uploaded"]

    try:
        models = get_models()  # (swelling, redness, bleeding); cached after first load
        image = preprocess(image)

        # One result per model, kept in (swelling, redness, bleeding) order.
        results = [
            model.predict(image, conf=conf, iou=iou)[0] for model in models
        ]

        sw_pil, rd_pil, bl_pil = (_result_to_pil(r) for r in results)

        # Gingivitis is reported only when all three signs are detected.
        if all(_has_detections(r) for r in results):
            diagnosis = (
                "🦷 You have gingivitis.\n\n"
                "For accurate assessment and guidance, we recommend visiting your dentist.\n\n"
                "If you have a periapical X-ray, you may try the Detect Periodontitis tool."
            )
        else:
            diagnosis = "🟒 You don't have gingivitis."

        return [sw_pil, rd_pil, bl_pil, diagnosis]

    except Exception as e:
        # UI boundary: show a readable message to the user, but keep the full
        # traceback in the server log instead of discarding it.
        import traceback

        traceback.print_exc()
        return [None, None, None, f"❌ Error during processing: {str(e)}"]


# =========================================================
# Gradio Interface
# =========================================================
# Input components, in the positional order detect_gingivitis takes them:
# (image, conf, iou). The UI's confidence default (0.5) differs from the
# function's own keyword default (0.25); the UI value is what gets passed.
_gingivitis_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="NMS IoU Threshold"),
]

# Output components, one per item of the 4-element list detect_gingivitis returns.
_gingivitis_outputs = [
    gr.Image(label="Swelling Detection", type="pil"),
    gr.Image(label="Redness Detection", type="pil"),
    gr.Image(label="Bleeding Detection", type="pil"),
    gr.Textbox(label="Diagnosis"),
]

interface = gr.Interface(
    fn=detect_gingivitis,
    inputs=_gingivitis_inputs,
    outputs=_gingivitis_outputs,
    title="Gingivitis Detection",
)


# =========================================================
# Warm-start: preload models on startup
# =========================================================
# Import-time side effect: load all three models now so the first user
# request does not pay the model-loading cost. flush=True makes the messages
# show up immediately even when stdout is a buffered pipe (e.g. a container log).
print("🔥 Preloading models to reduce Render cold start...", flush=True)
get_models()
print("✅ Gingivitis models ready", flush=True)


# =========================================================
# Start server
# =========================================================
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and surface errors in the UI.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, show_error=True)
    interface.launch(**launch_options)