Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -27,28 +27,31 @@ def get_models():
|
|
| 27 |
return model_swelling, model_redness, model_bleeding
|
| 28 |
|
| 29 |
|
| 30 |
-
#
|
|
|
|
|
|
|
| 31 |
def preprocess(image):
|
|
|
|
| 32 |
if isinstance(image, np.ndarray):
|
| 33 |
image = Image.fromarray(image)
|
| 34 |
|
| 35 |
image = ImageOps.exif_transpose(image).convert("RGB")
|
| 36 |
|
| 37 |
-
# Resize if
|
| 38 |
w, h = image.size
|
| 39 |
max_dim = max(w, h)
|
| 40 |
if max_dim > 1024:
|
| 41 |
scale = 1024 / max_dim
|
| 42 |
image = image.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
|
| 43 |
|
| 44 |
-
#
|
| 45 |
image = ImageEnhance.Contrast(image).enhance(1.05)
|
| 46 |
|
| 47 |
return image
|
| 48 |
|
| 49 |
|
| 50 |
def np_to_base64(img_np, format="JPEG"):
|
| 51 |
-
"""Convert numpy RGB image to Base64
|
| 52 |
pil_img = Image.fromarray(img_np)
|
| 53 |
buffer = io.BytesIO()
|
| 54 |
pil_img.save(buffer, format=format)
|
|
@@ -56,11 +59,13 @@ def np_to_base64(img_np, format="JPEG"):
|
|
| 56 |
|
| 57 |
|
| 58 |
def base64_to_pil(b64_str):
|
| 59 |
-
"""Convert Base64 string
|
| 60 |
return Image.open(io.BytesIO(base64.b64decode(b64_str)))
|
| 61 |
|
| 62 |
|
| 63 |
-
#
|
|
|
|
|
|
|
| 64 |
def detect_gingivitis(image, conf=0.25, iou=0.5):
|
| 65 |
try:
|
| 66 |
if image is None:
|
|
@@ -69,7 +74,7 @@ def detect_gingivitis(image, conf=0.25, iou=0.5):
|
|
| 69 |
# Load models (only once)
|
| 70 |
sw_model, rd_model, bl_model = get_models()
|
| 71 |
|
| 72 |
-
# Preprocess
|
| 73 |
image = preprocess(image)
|
| 74 |
|
| 75 |
# Run detections
|
|
@@ -77,7 +82,7 @@ def detect_gingivitis(image, conf=0.25, iou=0.5):
|
|
| 77 |
rd_res = rd_model.predict(image, conf=conf, iou=iou)
|
| 78 |
bl_res = bl_model.predict(image, conf=conf, iou=iou)
|
| 79 |
|
| 80 |
-
# Convert
|
| 81 |
img_sw = sw_res[0].plot(labels=False)[:, :, ::-1]
|
| 82 |
img_rd = rd_res[0].plot(labels=False)[:, :, ::-1]
|
| 83 |
img_bl = bl_res[0].plot(labels=False)[:, :, ::-1]
|
|
@@ -86,7 +91,7 @@ def detect_gingivitis(image, conf=0.25, iou=0.5):
|
|
| 86 |
rd_pil = base64_to_pil(np_to_base64(img_rd))
|
| 87 |
bl_pil = base64_to_pil(np_to_base64(img_bl))
|
| 88 |
|
| 89 |
-
#
|
| 90 |
has_sw = len(sw_res[0].boxes) > 0
|
| 91 |
has_rd = len(rd_res[0].boxes) > 0
|
| 92 |
has_bl = len(bl_res[0].boxes) > 0
|
|
@@ -103,12 +108,12 @@ def detect_gingivitis(image, conf=0.25, iou=0.5):
|
|
| 103 |
return [sw_pil, rd_pil, bl_pil, diagnosis]
|
| 104 |
|
| 105 |
except Exception as e:
|
| 106 |
-
# Catch all errors and return a friendly message
|
| 107 |
return [None, None, None, f"❌ Error during processing: {str(e)}"]
|
| 108 |
|
| 109 |
|
| 110 |
-
|
| 111 |
-
#
|
|
|
|
| 112 |
interface = gr.Interface(
|
| 113 |
fn=detect_gingivitis,
|
| 114 |
inputs=[
|
|
@@ -125,6 +130,7 @@ interface = gr.Interface(
|
|
| 125 |
title="Gingivitis Detection"
|
| 126 |
)
|
| 127 |
|
|
|
|
| 128 |
# =========================================================
|
| 129 |
# Warm-start: preload models on startup
|
| 130 |
# =========================================================
|
|
@@ -132,6 +138,9 @@ print("🔥 Preloading models to reduce Render cold start...")
|
|
| 132 |
get_models()
|
| 133 |
print("✅ Gingivitis models ready")
|
| 134 |
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
| 136 |
if __name__ == "__main__":
|
| 137 |
-
interface.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
|
|
|
| 27 |
return model_swelling, model_redness, model_bleeding
|
| 28 |
|
| 29 |
|
| 30 |
+
# =========================================================
|
| 31 |
+
# Helper functions
|
| 32 |
+
# =========================================================
|
| 33 |
def preprocess(image):
|
| 34 |
+
"""Resize, fix orientation, improve contrast."""
|
| 35 |
if isinstance(image, np.ndarray):
|
| 36 |
image = Image.fromarray(image)
|
| 37 |
|
| 38 |
image = ImageOps.exif_transpose(image).convert("RGB")
|
| 39 |
|
| 40 |
+
# Resize if too large
|
| 41 |
w, h = image.size
|
| 42 |
max_dim = max(w, h)
|
| 43 |
if max_dim > 1024:
|
| 44 |
scale = 1024 / max_dim
|
| 45 |
image = image.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
|
| 46 |
|
| 47 |
+
# Slight contrast enhancement
|
| 48 |
image = ImageEnhance.Contrast(image).enhance(1.05)
|
| 49 |
|
| 50 |
return image
|
| 51 |
|
| 52 |
|
| 53 |
def np_to_base64(img_np, format="JPEG"):
|
| 54 |
+
"""Convert numpy RGB image to Base64."""
|
| 55 |
pil_img = Image.fromarray(img_np)
|
| 56 |
buffer = io.BytesIO()
|
| 57 |
pil_img.save(buffer, format=format)
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
def base64_to_pil(b64_str):
|
| 62 |
+
"""Convert Base64 string to PIL image."""
|
| 63 |
return Image.open(io.BytesIO(base64.b64decode(b64_str)))
|
| 64 |
|
| 65 |
|
| 66 |
+
# =========================================================
|
| 67 |
+
# Main detection function
|
| 68 |
+
# =========================================================
|
| 69 |
def detect_gingivitis(image, conf=0.25, iou=0.5):
|
| 70 |
try:
|
| 71 |
if image is None:
|
|
|
|
| 74 |
# Load models (only once)
|
| 75 |
sw_model, rd_model, bl_model = get_models()
|
| 76 |
|
| 77 |
+
# Preprocess
|
| 78 |
image = preprocess(image)
|
| 79 |
|
| 80 |
# Run detections
|
|
|
|
| 82 |
rd_res = rd_model.predict(image, conf=conf, iou=iou)
|
| 83 |
bl_res = bl_model.predict(image, conf=conf, iou=iou)
|
| 84 |
|
| 85 |
+
# Convert YOLO output → numpy → PIL
|
| 86 |
img_sw = sw_res[0].plot(labels=False)[:, :, ::-1]
|
| 87 |
img_rd = rd_res[0].plot(labels=False)[:, :, ::-1]
|
| 88 |
img_bl = bl_res[0].plot(labels=False)[:, :, ::-1]
|
|
|
|
| 91 |
rd_pil = base64_to_pil(np_to_base64(img_rd))
|
| 92 |
bl_pil = base64_to_pil(np_to_base64(img_bl))
|
| 93 |
|
| 94 |
+
# Diagnosis logic
|
| 95 |
has_sw = len(sw_res[0].boxes) > 0
|
| 96 |
has_rd = len(rd_res[0].boxes) > 0
|
| 97 |
has_bl = len(bl_res[0].boxes) > 0
|
|
|
|
| 108 |
return [sw_pil, rd_pil, bl_pil, diagnosis]
|
| 109 |
|
| 110 |
except Exception as e:
|
|
|
|
| 111 |
return [None, None, None, f"❌ Error during processing: {str(e)}"]
|
| 112 |
|
| 113 |
|
| 114 |
+
# =========================================================
|
| 115 |
+
# Gradio Interface
|
| 116 |
+
# =========================================================
|
| 117 |
interface = gr.Interface(
|
| 118 |
fn=detect_gingivitis,
|
| 119 |
inputs=[
|
|
|
|
| 130 |
title="Gingivitis Detection"
|
| 131 |
)
|
| 132 |
|
| 133 |
+
|
| 134 |
# =========================================================
|
| 135 |
# Warm-start: preload models on startup
|
| 136 |
# =========================================================
|
|
|
|
| 138 |
get_models()
|
| 139 |
print("✅ Gingivitis models ready")
|
| 140 |
|
| 141 |
+
|
| 142 |
+
# =========================================================
|
| 143 |
+
# Start server
|
| 144 |
+
# =========================================================
|
| 145 |
if __name__ == "__main__":
|
| 146 |
+
interface.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|