import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import json
import traceback
# Model metadata: the eight condition labels, in the index order
# produced by the classifier's output head.
CLASS_NAMES = [
    "Crossbite", "Crowding", "Deepbite", "No Treatment Needed",
    "Open Bite", "Overbite", "Spacing", "Underbite",
]
# One-sentence, patient-facing description for each condition label.
_DESCRIPTION_PAIRS = [
    ("Crossbite", "A misalignment where upper teeth bite inside lower teeth."),
    ("Crowding", "Insufficient space causing teeth to overlap or twist."),
    ("Deepbite", "Upper front teeth excessively overlap lower front teeth."),
    ("No Treatment Needed", "Teeth appear to be properly aligned."),
    ("Open Bite", "Upper and lower teeth don't touch when mouth is closed."),
    ("Overbite", "Upper front teeth protrude significantly over lower teeth."),
    ("Spacing", "Gaps or spaces between teeth."),
    ("Underbite", "Lower teeth protrude beyond upper teeth."),
]
DESCRIPTIONS = dict(_DESCRIPTION_PAIRS)
# Build the classifier: a ResNet-18 backbone with the final fully-connected
# layer replaced by an 8-way head (one output per CLASS_NAMES entry), then
# restore the fine-tuned weights from disk and switch to inference mode.
print("Loading model...")
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 8)
model.load_state_dict(
    torch.load("pytorch_model.pth", map_location="cpu", weights_only=True)
)
model.eval()
print("Model loaded!")

# Preprocessing pipeline: resize to the 512x512 resolution the model was
# trained on, convert to a tensor, and normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(image):
    """Classify a dental photo into one of eight orthodontic conditions.

    Args:
        image: Image from the Gradio widget — a PIL image or numpy array,
            or None when nothing was uploaded.

    Returns:
        dict: On success, the predicted condition, its confidence (percent),
        per-class probabilities, a recommendation bucket with display text,
        a condition description, and model metadata. On failure, a dict
        with "error" (and "traceback") keys.
    """
    try:
        if image is None:
            return {"error": "No image provided"}

        # Normalize the input to an RGB PIL image.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        rgb = image.convert("RGB")

        # Preprocess and add a batch dimension for the model.
        batch = transform(rgb).unsqueeze(0)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            logits = model(batch)
        probs = F.softmax(logits, dim=1)[0]
        top_prob, top_idx = torch.max(probs, 0)

        condition = CLASS_NAMES[top_idx.item()]
        confidence_pct = top_prob.item() * 100

        # Per-class probabilities expressed as percentages.
        all_probs = {
            name: float(p.item() * 100) for name, p in zip(CLASS_NAMES, probs)
        }

        # Map the prediction onto a recommendation bucket: healthy result,
        # confident positive (>= 70%), or too uncertain to call.
        if condition == "No Treatment Needed":
            recommendation = "not_candidate"
            recommendation_text = "Based on the AI analysis, you may not need orthodontic treatment at this time."
        elif confidence_pct >= 70:
            recommendation = "candidate"
            recommendation_text = f"You appear to be a good candidate for orthodontic treatment to address {condition.lower()}."
        else:
            recommendation = "requires_evaluation"
            recommendation_text = "We recommend scheduling a consultation with an orthodontist for a thorough evaluation."

        return {
            "predicted_condition": condition,
            "confidence": round(confidence_pct, 2),
            "all_probabilities": {k: round(v, 2) for k, v in all_probs.items()},
            "recommendation": recommendation,
            "recommendation_text": recommendation_text,
            "condition_description": DESCRIPTIONS.get(condition, ""),
            "model_version": "ResNet18_512x512",
            "training_accuracy": 72.73,
        }
    except Exception as e:
        # Surface the failure to the client instead of crashing the app.
        traceback.print_exc()
        return {"error": str(e), "traceback": traceback.format_exc()}
# Wire the predictor into a Gradio UI: one image in, one JSON blob out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Dental Photo"),
    outputs=gr.JSON(label="Prediction Results"),
    title="Orthodontic Condition Classifier",
    description=(
        "Upload a photo of teeth to get an AI-powered orthodontic screening. "
        "This is for informational purposes only and should not replace "
        "professional evaluation."
    ),
    examples=None,
    api_name="predict",
)

if __name__ == "__main__":
    # show_error surfaces server-side exceptions in the browser UI.
    demo.launch(show_error=True)