Spaces:
Runtime error
Runtime error
File size: 4,217 Bytes
3353da4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | import streamlit as st
import cv2
import numpy as np
from ultralytics import YOLO
from PIL import Image
# -----------------------------
# Page setup
# -----------------------------
st.set_page_config(layout="wide")
st.title("OPG Segmentation + Midline + Sinus Detection")

# -----------------------------
# Model (loaded once, reused across reruns)
# -----------------------------
@st.cache_resource
def load_model():
    """Return the YOLO model built from the local ``best.pt`` weights file.

    ``st.cache_resource`` memoises the return value, so Streamlit reruns of
    this script reuse a single model object instead of reloading weights.
    """
    weights_file = "best.pt"
    yolo_model = YOLO(weights_file)
    return yolo_model


model = load_model()

# -----------------------------
# Upload widget: holds the selected file, or None until one is chosen
# -----------------------------
uploaded_file = st.file_uploader(
    "Upload OPG Image",
    type=["jpg", "png", "jpeg"],
)
# -----------------------------
# Preprocessing Function
# -----------------------------
def preprocess_image(image, max_dim=2048, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Normalise an uploaded OPG image before it is fed to the YOLO model.

    Pipeline: grayscale -> CLAHE contrast enhancement -> rescale so the
    longer side equals ``max_dim`` (smaller images are upscaled, larger ones
    downscaled, aspect ratio kept) -> replicate back to 3 channels, because
    the model is fed a 3-channel BGR array.

    Args:
        image: 3-channel BGR image array (as built by the upload handler).
        max_dim: target length, in pixels, of the image's longer side.
        clip_limit: CLAHE ``clipLimit``.
        tile_grid_size: CLAHE ``tileGridSize``.

    Returns:
        The processed 3-channel BGR image.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    gray = clahe.apply(gray)

    h, w = gray.shape
    scale = max_dim / max(h, w)
    # Round instead of truncating so the long side lands exactly on max_dim
    # despite float error (int() could yield max_dim - 1), and never let a
    # side collapse to 0 px, which cv2.resize rejects.
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    gray = cv2.resize(gray, (new_w, new_h))  # dsize is (width, height)

    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
# Class ids emitted by best.pt, per the original inline comments
# ("tooth class" / "sinus class").
_TOOTH_CLASS = 0
_SINUS_CLASS = 1


def _draw_tooth_outlines(image, result, classes):
    """Draw the segmentation polygon of every tooth detection onto *image*."""
    if result.masks is None:
        return
    for polygon, cls in zip(result.masks.xy, classes):
        # Skip non-tooth masks and empty polygons (nothing to draw).
        if cls != _TOOTH_CLASS or len(polygon) == 0:
            continue
        points = np.asarray(polygon, dtype=np.int32)
        cv2.polylines(image, [points], True, (0, 255, 0), 2)


def _draw_sinuses(image, boxes, classes, midline_x):
    """Box and label each sinus detection as Right/Left relative to the midline."""
    for (x, y, bw, bh), cls in zip(boxes, classes):
        if cls != _SINUS_CLASS:
            continue
        # NOTE(review): image-left is labelled the patient's *right*, i.e. the
        # usual radiographic viewing convention is assumed — confirm.
        if x < midline_x:
            label, color = "Right Sinus", (255, 0, 0)
        else:
            label, color = "Left Sinus", (0, 0, 255)
        x1, y1 = int(x - bw / 2), int(y - bh / 2)
        x2, y2 = int(x + bw / 2), int(y + bh / 2)
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
        # Keep the label on-canvas when the box touches the top edge.
        cv2.putText(image, label, (x1, max(y1 - 10, 20)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)


if uploaded_file is not None:
    # Decode the upload to OpenCV's BGR layout, then normalise it.
    rgb = np.array(Image.open(uploaded_file).convert("RGB"))
    image = preprocess_image(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    h = image.shape[0]

    # -----------------------------
    # Run prediction
    # -----------------------------
    result = model(image, conf=0.25)[0]
    # Convert detections to plain Python lists once, instead of calling
    # .item() per element inside every loop.
    boxes = result.boxes.xywh.tolist()  # [x_center, y_center, w, h] per box
    classes = [int(c) for c in result.boxes.cls.tolist()]

    tooth_centers_x = [
        box[0] for box, cls in zip(boxes, classes) if cls == _TOOTH_CLASS
    ]
    if not tooth_centers_x:
        st.warning("No teeth detected!")
        st.stop()

    # Midline: halfway between the left-most and right-most tooth centres.
    midline_x = int((min(tooth_centers_x) + max(tooth_centers_x)) / 2)

    # -----------------------------
    # Annotate (in place, BGR)
    # -----------------------------
    _draw_tooth_outlines(image, result, classes)
    _draw_sinuses(image, boxes, classes, midline_x)
    cv2.line(image, (midline_x, 0), (midline_x, h), (0, 255, 255), 2)

    # Convert back to RGB for display
    display_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    col1, col2 = st.columns(2)
    with col1:
        st.image(uploaded_file, caption="Original Image")
    with col2:
        st.image(display_image, caption="Segmented Output")

    # -----------------------------
    # Download button
    # -----------------------------
    # Encode to real JPEG bytes: PIL's Image.tobytes() returns raw pixel
    # data, which is not a valid .jpg file despite the image/jpeg MIME type.
    ok, jpeg = cv2.imencode(".jpg", image)  # image is still BGR, as imencode expects
    if ok:
        st.download_button(
            label="Download Result Image",
            data=jpeg.tobytes(),
            file_name="output_with_midline.jpg",
            mime="image/jpeg",
        )
    else:
        st.error("Could not encode the result image for download.")
|