# Uploaded by Nilay1400 — "Upload 4 files" (commit 3353da4, verified).
import streamlit as st
import cv2
import numpy as np
from ultralytics import YOLO
from PIL import Image
PAGE_TITLE = "OPG Segmentation + Midline + Sinus Detection"

# Wide layout leaves room for the side-by-side original/output columns
# rendered further down.
st.set_page_config(layout="wide")
st.title(PAGE_TITLE)
# -----------------------------
# Load model (cached)
# -----------------------------
@st.cache_resource
def load_model(weights_path="best.pt"):
    """Load the YOLO segmentation model, cached by Streamlit.

    ``st.cache_resource`` memoizes the return value keyed on the arguments,
    so the weights are read once and the same model object is reused across
    script reruns instead of being reloaded on every interaction.

    Args:
        weights_path: Path to the trained YOLO weights file. Defaults to
            ``"best.pt"`` (the file this app was built around).

    Returns:
        The ``ultralytics.YOLO`` model constructed from ``weights_path``.
    """
    return YOLO(weights_path)


model = load_model()
# -----------------------------
# Upload image
# -----------------------------
# File extensions the uploader accepts for the OPG radiograph.
ALLOWED_IMAGE_TYPES = ["jpg", "png", "jpeg"]

# None until the user uploads a file; the processing section below is gated
# on this value.
uploaded_file = st.file_uploader(
    "Upload OPG Image",
    type=ALLOWED_IMAGE_TYPES,
)
# -----------------------------
# Preprocessing Function
# -----------------------------
def preprocess_image(
    image: np.ndarray,
    max_dim: int = 2048,
    clip_limit: float = 2.0,
    tile_grid_size: tuple = (8, 8),
) -> np.ndarray:
    """Contrast-enhance and rescale an image for model input.

    Pipeline: grayscale -> CLAHE -> resize so the longer side equals
    ``max_dim`` (smaller images are upscaled, larger ones downscaled) ->
    replicate the gray channel back to 3 channels, since YOLO expects
    3-channel input.

    Args:
        image: BGR image of shape (H, W, 3) as used by OpenCV. A
            single-channel (H, W) grayscale array is also accepted.
        max_dim: Target length in pixels of the longer image side.
        clip_limit: CLAHE contrast clip limit (``cv2.createCLAHE`` clipLimit).
        tile_grid_size: CLAHE tile grid (``cv2.createCLAHE`` tileGridSize).

    Returns:
        A 3-channel BGR array whose longer side is exactly ``max_dim``.

    Raises:
        ValueError: If the image has zero height or width.
    """
    # Already-gray input would make COLOR_BGR2GRAY fail, so pass it through.
    if image.ndim == 2:
        gray = image
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    gray = clahe.apply(gray)

    h, w = gray.shape
    if h == 0 or w == 0:
        raise ValueError("cannot preprocess an empty image")
    scale = max_dim / max(h, w)
    # round() rather than int(): w * scale can land just below an integer
    # because of float error, and int() would truncate the long side to
    # max_dim - 1. Clamp to 1 so extreme aspect ratios never yield 0 pixels.
    new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
    gray = cv2.resize(gray, new_size)

    # Back to 3 channels because the detector expects 3-channel input.
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
# Class indices emitted by the model (per the original inline comments:
# 0 = tooth, 1 = sinus).
TOOTH_CLASS_ID = 0
SINUS_CLASS_ID = 1
# Minimum detection confidence passed to the model.
CONF_THRESHOLD = 0.25

if uploaded_file is not None:
    # -----------------------------
    # Decode upload -> OpenCV BGR
    # -----------------------------
    try:
        original_rgb = np.array(Image.open(uploaded_file).convert("RGB"))
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode; show a message instead of a traceback.
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    image = cv2.cvtColor(original_rgb, cv2.COLOR_RGB2BGR)
    image = preprocess_image(image)
    h = image.shape[0]

    # -----------------------------
    # Run prediction
    # -----------------------------
    result = model(image, conf=CONF_THRESHOLD)[0]

    # Move all box data to CPU/NumPy once instead of per-element tensor calls.
    boxes = result.boxes.cpu().numpy()
    classes = boxes.cls.astype(int)

    # -----------------------------
    # Collect tooth centers -> midline
    # -----------------------------
    tooth_centers_x = boxes.xywh[classes == TOOTH_CLASS_ID, 0]
    if tooth_centers_x.size == 0:
        st.warning("No teeth detected!")
        st.stop()

    # Midline = halfway between the left-most and right-most tooth centers.
    midline_x = int((tooth_centers_x.min() + tooth_centers_x.max()) / 2)

    # -----------------------------
    # Draw tooth masks
    # -----------------------------
    if result.masks is not None:
        for polygon_xy, cls_id in zip(result.masks.xy, classes):
            # Skip non-tooth masks and empty polygons (nothing to draw).
            if cls_id != TOOTH_CLASS_ID or len(polygon_xy) == 0:
                continue
            polygon = np.asarray(polygon_xy, dtype=np.int32)
            cv2.polylines(image, [polygon], True, (0, 255, 0), 2)

    # -----------------------------
    # Process sinus
    # -----------------------------
    sinus_mask = classes == SINUS_CLASS_ID
    for xywh, xyxy in zip(boxes.xywh[sinus_mask], boxes.xyxy[sinus_mask]):
        x_center = float(xywh[0])
        x1, y1, x2, y2 = (int(v) for v in xyxy)

        # Image-left of the midline is labelled "Right Sinus" (mirror
        # orientation); colors are BGR: blue for right, red for left.
        if x_center < midline_x:
            label, color = "Right Sinus", (255, 0, 0)
        else:
            label, color = "Left Sinus", (0, 0, 255)

        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
        # Keep the label inside the frame when the box touches the top edge.
        text_y = max(y1 - 10, 20)
        cv2.putText(image, label, (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

    # -----------------------------
    # Draw midline
    # -----------------------------
    cv2.line(image, (midline_x, 0), (midline_x, h), (0, 255, 255), 2)

    # Convert back to RGB for display
    display_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    col1, col2 = st.columns(2)
    with col1:
        # Show the already-decoded pixels; the upload stream was consumed above.
        st.image(original_rgb, caption="Original Image")
    with col2:
        st.image(display_image, caption="Segmented Output")

    # -----------------------------
    # Download button
    # -----------------------------
    # Encode a real JPEG from the BGR image. PIL's Image.tobytes() returns raw
    # pixel data, which is not a valid .jpg file.
    encoded_ok, jpeg_buffer = cv2.imencode(".jpg", image)
    if encoded_ok:
        st.download_button(
            label="Download Result Image",
            data=jpeg_buffer.tobytes(),
            file_name="output_with_midline.jpg",
            mime="image/jpeg",
        )
    else:
        st.error("Failed to encode the result image for download.")