|
|
import gradio as gr |
|
|
from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
|
|
|
model_id = "replyquickflorida/tooth-agenesis-model"

# The preprocessing pipeline and the classifier are both fetched from the
# same Hugging Face Hub repository id, once, at import time.
processor, model = (
    AutoImageProcessor.from_pretrained(model_id),
    AutoModelForImageClassification.from_pretrained(model_id),
)
|
|
|
|
|
|
|
|
# Model output index -> class name displayed in the UI.
# NOTE(review): this order is assumed to match the model's output logits;
# confirm against model.config.id2label of the Hub checkpoint.
id2label = dict(
    enumerate(
        (
            "Calculus",
            "Caries",
            "Gingivitis",
            "Mouth Ulcer",
            "Tooth Discoloration",
            "Hypodontia",
        )
    )
)
|
|
|
|
|
def predict(image):
    """Run inference on an uploaded image and return per-class probabilities.

    Args:
        image: The image from the ``gr.Image(type="pil")`` input (a PIL
            image), or ``None`` when the user submits without uploading.
            Other inputs accepted by ``processor`` are passed through as-is.

    Returns:
        A dict mapping each class name in ``id2label`` to its softmax
        probability as a Python float, the shape ``gr.Label`` displays.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio hands us None for an empty upload box; report that clearly
    # instead of letting the processor fail with an opaque error.
    if image is None:
        raise gr.Error("Please upload an image.")

    # Uploads can be RGBA (PNG with alpha), greyscale ("L") or palette ("P").
    # The processor presumably normalises three channels, so convert to RGB
    # to keep such inputs from failing or being mis-normalised.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # Single-image batch: take row 0 and convert to Python floats in one go.
    probs = torch.softmax(outputs.logits, dim=-1)[0].tolist()

    # NOTE(review): assumes the model's output index order matches id2label;
    # a model with fewer outputs than id2label entries raises IndexError here.
    return {label: probs[idx] for idx, label in id2label.items()}
|
|
|
|
|
|
|
|
# Web UI: a single PIL image upload feeds `predict`, whose label -> probability
# dict is rendered by a gr.Label output.
# NOTE(review): the title mentions tooth agenesis, but id2label covers six
# different conditions — confirm the wording is intended.
demo = gr.Interface(
    fn=predict,
    # type="pil" matches what `predict` expects to receive.
    inputs=gr.Image(type="pil", label="Upload Tooth X-ray"),
    # Show every class; derived from id2label so the count cannot drift.
    outputs=gr.Label(num_top_classes=len(id2label), label="Diagnosis"),
    title="Tooth Agenesis Diagnosis",
    description="Upload a dental X-ray image to get diagnosis predictions",
    examples=None,
)
|
|
|
|
|
# Start the Gradio server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()