jayn95 commited on
Commit
071f625
·
verified ·
1 Parent(s): 8f7e84e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -61
app.py CHANGED
@@ -1,84 +1,68 @@
 
1
  import gradio as gr
2
- import numpy as np
3
  import cv2
4
- from ultralytics import YOLO
5
- from tensorflow.keras.models import load_model
6
- from PIL import Image, ImageDraw
7
 
8
  # ==========================
9
- # 1️⃣ Load models
10
  # ==========================
11
- yolo_model = YOLO("yolov8n-seg.pt") # tooth detection
12
- unet_model = load_model("unet.keras", compile=False) # CEJ/ABC segmentation
13
-
14
- # ==========================
15
- # 2️⃣ Helper functions
16
- # ==========================
17
- def preprocess_image(image, target_size=(256, 256)):
18
- """Resize and normalize for UNet."""
19
- img = np.array(image.resize(target_size)) / 255.0
20
- if img.ndim == 2:
21
- img = np.expand_dims(img, axis=-1)
22
- return np.expand_dims(img, axis=0)
23
-
24
- def postprocess_mask(mask, original_size):
25
- """Resize UNet output back to original cropped size."""
26
- mask = (mask[0] > 0.5).astype(np.uint8) * 255
27
- return cv2.resize(mask, original_size[::-1])
28
-
29
- def overlay_mask_on_image(image, mask):
30
- """Blend segmentation mask with tooth image."""
31
- color_mask = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
32
- blended = cv2.addWeighted(np.array(image), 0.7, color_mask, 0.3, 0)
33
- return Image.fromarray(blended)
34
 
35
  # ==========================
36
- # 3️⃣ Main pipeline
37
  # ==========================
38
- def detect_and_segment(xray_image):
39
- # Convert PIL → OpenCV
40
- img_cv = cv2.cvtColor(np.array(xray_image), cv2.COLOR_RGB2BGR)
41
-
42
- # Run YOLOv8 for tooth detection
43
- results = yolo_model.predict(img_cv)
44
- boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
45
-
46
- annotated_image = np.array(xray_image).copy()
47
 
48
- segmented_teeth = []
 
49
 
50
- for i, (x1, y1, x2, y2) in enumerate(boxes):
51
- # Crop individual tooth
52
- crop = xray_image.crop((x1, y1, x2, y2))
 
53
 
54
- # UNet segmentation
55
- preprocessed = preprocess_image(crop)
56
- mask = unet_model.predict(preprocessed)
57
- seg_mask = postprocess_mask(mask, crop.size)
 
 
 
 
 
 
58
 
59
- # Overlay segmentation mask
60
- overlay = overlay_mask_on_image(crop, seg_mask)
61
- segmented_teeth.append(overlay)
62
 
63
- # Draw bounding box on main image
64
- draw = ImageDraw.Draw(Image.fromarray(annotated_image))
65
- draw.rectangle([x1, y1, x2, y2], outline="lime", width=2)
66
- draw.text((x1, y1 - 10), f"Tooth {i+1}", fill="lime")
67
 
68
- return Image.fromarray(annotated_image), segmented_teeth
69
 
70
  # ==========================
71
- # 4️⃣ Gradio Interface
72
  # ==========================
73
  demo = gr.Interface(
74
- fn=detect_and_segment,
75
- inputs=gr.Image(type="pil", label="Upload Dental X-Ray"),
76
  outputs=[
77
- gr.Image(label="Detected Teeth (YOLOv8)"),
78
- gr.Gallery(label="Segmented Teeth (UNet)", show_label=True)
 
 
79
  ],
80
- title="Periodontitis Detection YOLOv8 + UNet",
81
- description="Stage 1: Tooth detection via YOLOv8. Stage 2: CEJ & ABC segmentation via UNet."
 
 
 
82
  )
83
 
84
  if __name__ == "__main__":
 
1
+ # app.py
2
  import gradio as gr
 
3
  import cv2
4
+ from periodontitis_detection import SimpleDentalSegmentationNoEnhance
 
 
5
 
6
  # ==========================
7
+ # 1️⃣ Load models once
8
  # ==========================
9
+ model = SimpleDentalSegmentationNoEnhance(
10
+ unet_model_path="unet(10_22_25).keras", # same filenames as your repo
11
+ yolo_model_path="yolov8n-seg.pt"
12
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # ==========================
15
+ # 2️⃣ Define wrapper for Gradio
16
  # ==========================
17
+ def detect_periodontitis(image_np):
18
+ """
19
+ Gradio sends image as a NumPy RGB array.
20
+ We temporarily save it to a file path since analyze_image() needs a path.
21
+ """
22
+ temp_path = "temp_input.jpg"
23
+ cv2.imwrite(temp_path, cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
 
 
24
 
25
+ # Run full pipeline
26
+ results = model.analyze_image(temp_path)
27
 
28
+ # Convert OpenCV BGR RGB for Gradio display
29
+ combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
30
+ cej_mask_rgb = cv2.cvtColor(results["cej_mask"] * 255, cv2.COLOR_GRAY2RGB)
31
+ abc_mask_rgb = cv2.cvtColor(results["abc_mask"] * 255, cv2.COLOR_GRAY2RGB)
32
 
33
+ # Optional: summarize measurements for text output
34
+ summaries = []
35
+ for tooth in results["distance_analyses"]:
36
+ tooth_id = tooth["tooth_id"]
37
+ analysis = tooth["analysis"]
38
+ if analysis:
39
+ mean_d = analysis["mean_distance"]
40
+ summaries.append(f"Tooth {tooth_id}: mean={mean_d:.2f}px")
41
+ else:
42
+ summaries.append(f"Tooth {tooth_id}: no valid CEJ–ABC measurement")
43
 
44
+ summary_text = "\n".join(summaries) if summaries else "No detections found."
 
 
45
 
46
+ return combined_rgb, cej_mask_rgb, abc_mask_rgb, summary_text
 
 
 
47
 
 
48
 
49
  # ==========================
50
+ # 3️⃣ Build Gradio Interface
51
  # ==========================
52
  demo = gr.Interface(
53
+ fn=detect_periodontitis,
54
+ inputs=gr.Image(type="numpy", label="Upload Dental X-Ray"),
55
  outputs=[
56
+ gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
57
+ gr.Image(label="CEJ Segmentation Mask"),
58
+ gr.Image(label="ABC Segmentation Mask"),
59
+ gr.Textbox(label="Analysis Summary"),
60
  ],
61
+ title="🦷 Periodontitis Detection & Analysis",
62
+ description=(
63
+ "Automatically detects teeth (YOLOv8), segments CEJ/ABC (U-Net), "
64
+ "and measures CEJ–ABC distances per tooth to assess bone loss."
65
+ ),
66
  )
67
 
68
  if __name__ == "__main__":