File size: 3,036 Bytes
647cdfa
6a2e2f2
647cdfa
 
 
 
 
 
f8de414
647cdfa
 
 
 
 
 
 
 
 
 
d1a31d9
 
 
 
 
 
 
647cdfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44cf9b0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import hashlib

import cv2
import numpy as np
import streamlit as st
from PIL import Image
from ultralytics import YOLO

@st.cache_resource(show_spinner=False)
def load_model(weights_path='/app/src/best.pt'):
    """Load the YOLO detector from *weights_path*.

    The result is cached with ``st.cache_resource``, keyed on the function
    arguments, so each distinct checkpoint path is loaded only once rather
    than on every script rerun.

    Parameters
    ----------
    weights_path : str
        Path to the trained checkpoint. Defaults to the path the app was
        originally hard-coded to use, so existing ``load_model()`` calls
        behave exactly as before.

    Returns
    -------
    ultralytics.YOLO
        The loaded model.
    """
    return YOLO(weights_path)

def predict(image, model):
    """Run *model* on *image* and return the first result it produces.

    The picture is normalised to a 3-channel, C-contiguous array in BGR
    channel order before being handed to the model. The original pipeline
    also fed BGR; Ultralytics documents raw numpy inputs as BGR.

    Parameters
    ----------
    image : PIL.Image.Image or array-like
        Input picture. Accepted layouts are gray ``(H, W)`` / ``(H, W, 1)``,
        gray+alpha ``(H, W, 2)``, RGB ``(H, W, 3)`` and RGBA ``(H, W, 4)``.
    model : callable
        Detector (e.g. a YOLO instance) called with one BGR array. It must
        return an indexable sequence of results.

    Returns
    -------
    The element at index 0 of ``model(...)``'s return value.

    Raises
    ------
    ValueError
        If the image array has an unsupported shape.
    """
    img_np = np.asarray(image)
    if img_np.ndim == 2:
        # Single-band image (e.g. PIL mode "L"): add an explicit channel axis.
        img_np = img_np[:, :, np.newaxis]
    if img_np.ndim != 3 or img_np.shape[2] not in (1, 2, 3, 4):
        raise ValueError(f"Unsupported image array shape: {img_np.shape}")

    if img_np.shape[2] <= 2:
        # Gray (optionally with alpha): replicate the gray band into 3 channels.
        img_rgb = np.repeat(img_np[:, :, :1], 3, axis=2)
    else:
        # RGB or RGBA: keep the colour bands and drop any alpha channel.
        img_rgb = img_np[:, :, :3]

    # Reverse channel order (RGB -> BGR). The reversed view has a negative
    # stride, so copy it into a contiguous buffer for the model.
    img_bgr = np.ascontiguousarray(img_rgb[:, :, ::-1])
    results = model(img_bgr)
    return results[0]

def draw_results(image, results):
    """Return the plotted detections from *results* as a PIL image.

    ``results.plot()`` produces the annotated array. It is converted from BGR
    to RGB with OpenCV before being wrapped in a PIL image. This is a
    best-effort step: if plotting or conversion fails, the error is shown in
    the Streamlit page and *image* is returned unchanged so the app keeps
    working.
    """
    try:
        return Image.fromarray(cv2.cvtColor(results.plot(), cv2.COLOR_BGR2RGB))
    except Exception as exc:  # UI boundary: report, then fall back below
        st.error(f"Error during annotation: {exc}")
    return image

def crop_box(image, box, pad_ratio=0.1):
    """Crop the region of *box* out of *image*, with proportional padding.

    Parameters
    ----------
    image : PIL.Image.Image or array-like
        Source picture. It is converted to a numpy array for slicing.
    box
        Detection whose ``box.xyxy[0]`` unpacks to ``(x1, y1, x2, y2)`` pixel
        coordinates. Each value is truncated with ``int``.
    pad_ratio : float
        Fraction of the box width/height added on each side. Defaults to
        0.1, which matches the original hard-coded 10 % margin.

    Returns
    -------
    PIL.Image.Image
        The padded crop, clamped to the image bounds.
    """
    img_np = np.array(image)
    h, w = img_np.shape[:2]
    x1, y1, x2, y2 = map(int, box.xyxy[0])

    pad_x = int((x2 - x1) * pad_ratio)
    pad_y = int((y2 - y1) * pad_ratio)

    # Expand by the padding, then clamp so the slice stays inside the image.
    left, top = max(0, x1 - pad_x), max(0, y1 - pad_y)
    right, bottom = min(w, x2 + pad_x), min(h, y2 + pad_y)
    return Image.fromarray(img_np[top:bottom, left:right])

def _cached_predict(image, model, digest):
    """Run predict() once per uploaded file and reuse the result on reruns."""
    # Every widget interaction reruns the whole script. Without this, each
    # label edit would re-run model inference on the same X-ray.
    state = st.session_state
    if state.get("_pred_digest") != digest:
        with st.spinner("Detecting teeth..."):
            state["_pred_result"] = predict(image, model)
        state["_pred_digest"] = digest
    return state["_pred_result"]


def _render_comparison(image, annotated_img):
    """Show the uploaded image and its annotated version side by side."""
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Image")
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; kept here for older-version support.
        st.image(image, use_column_width=True)
    with col2:
        st.subheader("Detection & Segmentation")
        st.image(annotated_img, use_column_width=True)


def _edit_labels(image, results, digest):
    """Render an editable label box per detected tooth; return {index: label}.

    Widget keys include *digest*, so text typed for one upload is not reused
    for the same tooth index of a different image.
    """
    edited_labels = {}
    for i, box in enumerate(results.boxes):
        class_id = int(box.cls[0])
        default_label = results.names[class_id]
        confidence = box.conf[0].item()
        cropped = crop_box(image, box)

        st.markdown(f"**Tooth {i+1}** (Confidence: {confidence:.2%})")
        st.image(cropped, width=150)

        edited_labels[i] = st.text_input(
            f"Change tooth number (default: {default_label})",
            value=default_label,
            key=f"label_{digest}_{i}",
        )
    return edited_labels


def main():
    """Streamlit entry point: upload an X-ray, detect teeth, edit their labels."""
    st.title("🦷 Interactive Teeth Segmentation & Annotation")

    st.markdown("""
    Upload a panoramic dental X-ray image.  
    After detection, you can edit the tooth number labels for each detected tooth.
    """)

    model = load_model()
    uploaded_file = st.file_uploader("Upload panoramic dental X-ray", type=["png", "jpg", "jpeg"])

    if not uploaded_file:
        return

    # Content fingerprint of the upload. It identifies cached predictions and
    # scopes the label widgets to this particular image.
    digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    image = Image.open(uploaded_file).convert("RGB")
    results = _cached_predict(image, model, digest)

    annotated_img = draw_results(image, results)
    _render_comparison(image, annotated_img)

    if len(results.boxes) == 0:
        st.info("No teeth detected in this image.")
        return

    st.markdown("### Edit Detected Tooth Numbers")
    edited_labels = _edit_labels(image, results, digest)

    st.markdown("---")
    st.subheader("Final Tooth Labels")
    for i, label in edited_labels.items():
        st.write(f"Tooth {i+1}: {label}")

if __name__ == "__main__":
    main()