File size: 2,691 Bytes
07c2cff
 
 
 
 
3c1a8b8
 
 
 
1d10091
3c1a8b8
07c2cff
3c1a8b8
 
07c2cff
3c1a8b8
 
 
 
 
1d10091
07c2cff
 
 
3c1a8b8
 
 
 
 
1d10091
3c1a8b8
1d10091
 
3c1a8b8
1d10091
3c1a8b8
1d10091
07c2cff
 
 
 
 
3c1a8b8
07c2cff
 
1d10091
07c2cff
 
 
1d10091
07c2cff
 
 
1d10091
07c2cff
1d10091
 
 
07c2cff
 
 
 
1d10091
07c2cff
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import gradio as gr
from ultralytics import YOLO
from PIL import Image

# Repository-relative locations of the model weights and the demo images.
# Edit these if the files live somewhere else in your Hugging Face repo.
MODEL_PATH = "best.pt"  # trained weights; at repo root unless moved into a folder
TEST_IMAGES_FOLDER = "test_images"  # directory whose images populate the dropdown

# Load the YOLOv8 weights once at import time; every request reuses this model.
model = YOLO(MODEL_PATH)

# Extensions offered in the test-image dropdown (matched case-insensitively).
IMAGE_EXTENSIONS = ('.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff', '.webp')


def _list_test_images(folder):
    """Return the sorted names of image files directly inside *folder*.

    Non-image entries (``.gitkeep``, READMEs, sub-directories, ...) are
    skipped, so choosing any dropdown entry does not hand a non-image file to
    the model. A missing *folder*, or a path that exists but is not a
    directory, yields an empty list instead of raising.
    """
    if not os.path.isdir(folder):
        return []
    return sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )


test_images = _list_test_images(TEST_IMAGES_FOLDER)

# Class labels in model class-index order: predict_image maps each predicted
# class index i to tooth_classes[i]. The labels look like two-digit FDI tooth
# numbers (first digit 1-4, second digit 1-8), generated here in that order.
# NOTE(review): confirm this order matches the training dataset's class list.
tooth_classes = [f"{first}{second}" for first in range(1, 5) for second in range(1, 9)]

def predict_image(image_path):
    """Run the YOLO model on one image and report which tooth classes it found.

    Args:
        image_path: Path of the image to analyse (a bundled test image, or the
            file Gradio hands over for an upload with ``type="filepath"``).

    Returns:
        ``(annotated, info_text)``: a PIL image with the predictions drawn on
        it, and a text listing the detected and the missing tooth classes
        (each list reads ``"None"`` when empty).
    """
    result = model(image_path)[0]

    # Results.plot() returns the annotated image as a BGR numpy array (OpenCV
    # channel order) while PIL assumes RGB; reverse the channel axis so the
    # drawn annotation colours are not red/blue swapped.
    annotated_bgr = result.plot(conf=False, labels=True, boxes=True)
    annotated = Image.fromarray(annotated_bgr[..., ::-1])

    # Class index of every prediction; empty when nothing was detected.
    class_indices = result.boxes.cls.cpu().numpy().astype(int) if len(result.boxes) > 0 else []

    # Indices are looked up positionally in tooth_classes, so that list must
    # follow the model's class order. NOTE(review): verify against model.names.
    detected_classes = sorted({tooth_classes[i] for i in class_indices})
    missing_classes = sorted(set(tooth_classes) - set(detected_classes))

    detected_str = ", ".join(detected_classes) if detected_classes else "None"
    missing_str = ", ".join(missing_classes) if missing_classes else "None"

    info_text = f"Detected tooth classes:\n{detected_str}\n\nMissing tooth classes:\n{missing_str}"

    return annotated, info_text

def run_prediction(uploaded_image, selected_image):
    """Predict on whichever image source the user supplied.

    An uploaded file path wins over the dropdown selection; a dropdown choice
    is resolved relative to ``TEST_IMAGES_FOLDER``. With neither provided,
    returns ``(None, "")`` so both outputs are cleared.
    """
    source = uploaded_image
    if source is None and selected_image is not None:
        source = os.path.join(TEST_IMAGES_FOLDER, selected_image)

    if source is None:
        return None, ""
    return predict_image(source)

app_theme = gr.themes.Soft()

with gr.Blocks(theme=app_theme) as demo:
    # Page header.
    gr.Markdown("## 🦷 Dental Segmentation with YOLOv8")
    gr.Markdown("Upload your own image or choose a test image from the list below.")

    # Inputs: an optional upload (delivered as a file path) and a dropdown of
    # bundled test images; run_prediction uses the upload when both are set.
    with gr.Column():
        uploaded_image = gr.Image(label="Upload your image (optional)", type="filepath")
        selected_image = gr.Dropdown(choices=test_images, label="...or select a test image")

    # Outputs: the annotated image and the detected/missing class report.
    gr.Markdown("### Prediction Result")
    output_image = gr.Image(label="Predicted Image")
    output_text = gr.Textbox(label="Detected & Missing Tooth Classes", lines=5)

    run_button = gr.Button("Run prediction")
    run_button.click(
        fn=run_prediction,
        inputs=[uploaded_image, selected_image],
        outputs=[output_image, output_text],
    )

demo.launch()