# Teeth-Numbering / app.py
# Commit 3c1a8b8 ("Update app.py") by Noursine, verified.
import os
import gradio as gr
from ultralytics import YOLO
from PIL import Image
# Adjust these paths based on where you uploaded the files in your HF repo.
# Both are relative to the current working directory the app is started from.
MODEL_PATH = 'best.pt'  # Put your best.pt at repo root or adjust if in a folder
TEST_IMAGES_FOLDER = 'test_images'  # Folder inside your repo with test images

# File extensions offered in the test-image dropdown (matched case-insensitively).
# Anything else in the folder (README, .gitkeep, subdirectories) is skipped so it
# cannot be selected and then fail inside the model call.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')

# Load the YOLOv8 model once at import time; every request reuses this instance.
model = YOLO(MODEL_PATH)

# Sorted file names (not full paths) of the sample images shown in the dropdown.
# Empty when the folder is missing, so the app still starts without test images.
test_images = []
if os.path.isdir(TEST_IMAGES_FOLDER):
    test_images = sorted(
        name
        for name in os.listdir(TEST_IMAGES_FOLDER)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(TEST_IMAGES_FOLDER, name))
    )

# List of all tooth classes (your dataset classes); these appear to be FDI
# two-digit tooth numbers (quadrant digit + position) — confirm with the dataset.
# predict_image indexes this list by the model's predicted class id, so the order
# here must match the class order the model was trained with.
tooth_classes = ['11', '12', '13', '14', '15', '16', '17', '18',
                 '21', '22', '23', '24', '25', '26', '27', '28',
                 '31', '32', '33', '34', '35', '36', '37', '38',
                 '41', '42', '43', '44', '45', '46', '47', '48']
def predict_image(image_path):
    """Run the YOLO model on one image and summarise which teeth were found.

    Args:
        image_path: Path of the image file to run the model on (the UI passes
            file paths, see run_prediction).

    Returns:
        A ``(PIL.Image.Image, str)`` tuple: the image annotated with the
        predicted boxes and class labels (confidence scores hidden), and a
        text summary listing the detected tooth classes and the entries of
        ``tooth_classes`` that were not detected ("None" when a list is empty).
    """
    result = model(image_path)[0]

    # Results.plot() returns the annotated image as a BGR array (OpenCV channel
    # order). PIL interprets arrays as RGB, so reverse the channel axis first;
    # otherwise red and blue are swapped in the displayed image.
    annotated_bgr = result.plot(conf=False, labels=True, boxes=True)
    annotated_image = Image.fromarray(annotated_bgr[..., ::-1])

    # Class index of every detected box. The same tooth can be detected more
    # than once, hence the set comprehension below.
    class_ids = result.boxes.cls.cpu().numpy().astype(int) if len(result.boxes) > 0 else []
    # NOTE(review): assumes tooth_classes is in the same order as the model's
    # class ids — result.names holds the model's own mapping if this differs.
    detected_classes = sorted({tooth_classes[i] for i in class_ids})
    missing_classes = sorted(set(tooth_classes) - set(detected_classes))

    detected_str = ", ".join(detected_classes) if detected_classes else "None"
    missing_str = ", ".join(missing_classes) if missing_classes else "None"
    info_text = f"Detected tooth classes:\n{detected_str}\n\nMissing tooth classes:\n{missing_str}"
    return annotated_image, info_text
def run_prediction(uploaded_image, selected_image):
    """Route a UI request to ``predict_image``.

    An uploaded file takes priority over the dropdown. A dropdown choice is a
    bare file name, so it is joined onto ``TEST_IMAGES_FOLDER`` first.

    Args:
        uploaded_image: File path of the user's upload, or ``None``.
        selected_image: File name picked in the test-image dropdown, or ``None``.

    Returns:
        ``predict_image``'s ``(image, text)`` tuple, or ``(None, "")`` when
        neither input was provided.
    """
    source = uploaded_image
    if source is None and selected_image is not None:
        source = os.path.join(TEST_IMAGES_FOLDER, selected_image)

    if source is None:
        # Nothing to run on: clear both output widgets.
        return None, ""

    return predict_image(source)
# Gradio interface: an input section (upload or test-image picker), the output
# section (annotated image + text summary), and a button that runs the model.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🦷 Dental Segmentation with YOLOv8")
    gr.Markdown("Upload your own image or choose a test image from the list below.")

    # Inputs. type="filepath" makes the upload arrive in run_prediction as a path string.
    with gr.Column():
        uploaded_image = gr.Image(label="Upload your image (optional)", type="filepath")
        selected_image = gr.Dropdown(choices=test_images, label="...or select a test image")

    # Outputs, filled positionally from run_prediction's (image, text) return value.
    gr.Markdown("### Prediction Result")
    output_image = gr.Image(label="Predicted Image")
    output_text = gr.Textbox(label="Detected & Missing Tooth Classes", lines=5)

    run_button = gr.Button("Run prediction")
    run_button.click(
        fn=run_prediction,
        inputs=[uploaded_image, selected_image],
        outputs=[output_image, output_text],
    )

demo.launch()