File size: 4,217 Bytes
3353da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import streamlit as st
import cv2
import numpy as np
from ultralytics import YOLO
from PIL import Image

_PAGE_TITLE = "OPG Segmentation + Midline + Sinus Detection"

# Wide layout gives the side-by-side original/output columns more room.
st.set_page_config(layout="wide")
st.title(_PAGE_TITLE)

# -----------------------------
# Load model (cached)
# -----------------------------
@st.cache_resource
def load_model(weights="best.pt"):
    """Load the YOLO model used for tooth/sinus detection and segmentation.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the cached
    model object instead of reloading the weights on every interaction.

    Args:
        weights: Path to the YOLO weights file. Defaults to ``"best.pt"``
            (relative to the working directory the app is launched from).

    Returns:
        The ``ultralytics.YOLO`` model instance.
    """
    return YOLO(weights)


model = load_model()

# -----------------------------
# Upload image
# -----------------------------
_ACCEPTED_EXTENSIONS = ["jpg", "png", "jpeg"]

# None until the user uploads a file; otherwise a file-like upload object.
uploaded_file = st.file_uploader(
    "Upload OPG Image",
    type=_ACCEPTED_EXTENSIONS,
)

# -----------------------------
# Preprocessing Function
# -----------------------------
def preprocess_image(image, max_dim=2048):
    """Contrast-enhance and rescale a BGR image for YOLO inference.

    The image is converted to grayscale, enhanced with CLAHE
    (clipLimit=2.0, 8x8 tiles), and resized so its longer side is exactly
    ``max_dim`` pixels. Smaller images are enlarged and larger ones shrunk.
    It is then converted back to 3 channels.

    Args:
        image: 3-channel BGR image array (as produced by
            ``cv2.cvtColor(..., cv2.COLOR_RGB2BGR)``).
        max_dim: Target length in pixels of the longer image side.

    Returns:
        The processed 3-channel BGR image.
    """
    # Convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Apply CLAHE
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    gray = clahe.apply(gray)

    # Resize so the longer side == max_dim.
    h, w = gray.shape
    scale = max_dim / max(h, w)
    # round(), not int(): the float product can land just below an integer
    # (e.g. 2048 / 3136 * 3136 == 2047.999...), and truncating would leave the
    # longer side one pixel short. Clamp to 1 so very thin images never ask
    # cv2.resize for a zero-sized output.
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    gray = cv2.resize(gray, (new_w, new_h))

    # Convert back to 3-channel (YOLO expects 3 channels)
    processed = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

    return processed



# Class ids produced by the model, as used below
# (0 = tooth, 1 = sinus, per the original inline comments).
TOOTH_CLASS = 0
SINUS_CLASS = 1

if uploaded_file is not None:

    # Decode with PIL (normalised to RGB), then switch to OpenCV's BGR order
    # for preprocessing, inference and drawing.
    image = Image.open(uploaded_file).convert("RGB")
    image = np.array(image)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    image = preprocess_image(image)

    h = image.shape[0]

    # -----------------------------
    # Run prediction
    # -----------------------------
    results = model(image, conf=0.25)
    result = results[0]

    # Convert once to plain Python values: one class id and one
    # (x_center, y_center, width, height) box per detection, in pixels of
    # the preprocessed image.
    class_ids = [int(c) for c in result.boxes.cls]
    boxes_xywh = result.boxes.xywh.tolist()

    # -----------------------------
    # Collect tooth centers
    # -----------------------------
    tooth_centers_x = [
        box[0] for box, cls in zip(boxes_xywh, class_ids) if cls == TOOTH_CLASS
    ]

    if not tooth_centers_x:
        st.warning("No teeth detected!")
        st.stop()

    # -----------------------------
    # Compute midline
    # -----------------------------
    # Midpoint between the left-most and right-most tooth centers.
    midline_x = int((min(tooth_centers_x) + max(tooth_centers_x)) / 2)

    # -----------------------------
    # Draw tooth masks
    # -----------------------------
    if result.masks is not None:
        for mask_xy, cls in zip(result.masks.xy, class_ids):
            # Skip non-tooth detections, and guard against empty outlines
            # (nothing to draw).
            if cls != TOOTH_CLASS or len(mask_xy) == 0:
                continue
            polygon = np.array(mask_xy, dtype=np.int32)
            cv2.polylines(image, [polygon], True, (0, 255, 0), 2)

    # -----------------------------
    # Process sinus
    # -----------------------------
    for (x, y, bw, bh), cls in zip(boxes_xywh, class_ids):
        if cls != SINUS_CLASS:
            continue

        # A sinus left of the midline in image coordinates is labelled
        # "Right Sinus", and one on the right is labelled "Left Sinus".
        if x < midline_x:
            label = "Right Sinus"
            color = (255, 0, 0)
        else:
            label = "Left Sinus"
            color = (0, 0, 255)

        x1 = int(x - bw / 2)
        y1 = int(y - bh / 2)
        x2 = int(x + bw / 2)
        y2 = int(y + bh / 2)

        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
        # Keep the label inside the image when the box touches the top edge.
        text_y = max(y1 - 10, 20)
        cv2.putText(image, label, (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

    # -----------------------------
    # Draw midline
    # -----------------------------
    cv2.line(image, (midline_x, 0), (midline_x, h - 1),
             (0, 255, 255), 2)

    # Convert back to RGB for display
    display_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    col1, col2 = st.columns(2)

    with col1:
        st.image(uploaded_file, caption="Original Image")

    with col2:
        st.image(display_image, caption="Segmented Output")

    # -----------------------------
    # Download button
    # -----------------------------
    # Encode to real JPEG bytes. PIL's Image.tobytes() returns raw pixel data,
    # which is not a valid .jpg file. cv2.imencode expects BGR, which `image` is.
    ok, jpeg_buffer = cv2.imencode(".jpg", image)
    if ok:
        st.download_button(
            label="Download Result Image",
            data=jpeg_buffer.tobytes(),
            file_name="output_with_midline.jpg",
            mime="image/jpeg",
        )
    else:
        st.error("Could not encode the result image for download.")