# Bhess7's picture
# Upload app.py with huggingface_hub
# 14c852d verified
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import json
import traceback
# Class labels in the exact order of the model's output logits
# (index i of the network output corresponds to CLASS_NAMES[i]).
CLASS_NAMES = [
    "Crossbite",
    "Crowding",
    "Deepbite",
    "No Treatment Needed",
    "Open Bite",
    "Overbite",
    "Spacing",
    "Underbite",
]
# One-sentence, patient-facing description for every label in CLASS_NAMES.
DESCRIPTIONS = {
    "Crossbite": "A misalignment where upper teeth bite inside lower teeth.",
    "Crowding": "Insufficient space causing teeth to overlap or twist.",
    "Deepbite": "Upper front teeth excessively overlap lower front teeth.",
    "No Treatment Needed": "Teeth appear to be properly aligned.",
    "Open Bite": "Upper and lower teeth don't touch when mouth is closed.",
    "Overbite": "Upper front teeth protrude significantly over lower teeth.",
    "Spacing": "Gaps or spaces between teeth.",
    "Underbite": "Lower teeth protrude beyond upper teeth.",
}
# ---- Model setup -----------------------------------------------------------
# ResNet18 backbone with its final fully-connected layer replaced by an
# 8-way classifier head (one logit per entry in CLASS_NAMES).
print("Loading model...")
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 8)
# weights_only=True restricts torch.load to plain tensor data (safer than
# unpickling arbitrary objects); map_location keeps everything on CPU.
checkpoint = torch.load("pytorch_model.pth", map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint)
model.eval()  # inference mode: freezes dropout / batch-norm behavior
print("Model loaded!")
# ---- Preprocessing ---------------------------------------------------------
# Resize every photo to the 512x512 resolution the classifier expects,
# convert to a tensor, and normalize with the ImageNet channel statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
def predict(image):
    """Classify a dental photo into one of eight orthodontic conditions.

    Args:
        image: A PIL image (or numpy array) of teeth, or None when the
            user submitted nothing.

    Returns:
        dict: Predicted condition, confidence percentage, per-class
        probabilities, a coarse recommendation code plus explanatory text,
        and model metadata. On failure, a dict with ``error`` and
        ``traceback`` keys instead.
    """
    try:
        if image is None:
            return {"error": "No image provided"}

        # Normalize the input to a 3-channel PIL image.
        pil_img = image if isinstance(image, Image.Image) else Image.fromarray(image)
        pil_img = pil_img.convert("RGB")

        # Forward pass on a batch of one; softmax turns logits into probabilities.
        batch = transform(pil_img).unsqueeze(0)
        with torch.no_grad():
            probs = F.softmax(model(batch), dim=1)[0]

        top_prob, top_idx = torch.max(probs, 0)
        condition = CLASS_NAMES[top_idx.item()]
        confidence_pct = top_prob.item() * 100

        # Per-class percentages keyed by label, in CLASS_NAMES order.
        all_probs = {name: float(p.item() * 100) for name, p in zip(CLASS_NAMES, probs)}

        # Map the prediction onto a coarse recommendation bucket: healthy,
        # confident positive (>= 70%), or too uncertain to call.
        if condition == "No Treatment Needed":
            recommendation = "not_candidate"
            recommendation_text = "Based on the AI analysis, you may not need orthodontic treatment at this time."
        elif confidence_pct >= 70:
            recommendation = "candidate"
            recommendation_text = f"You appear to be a good candidate for orthodontic treatment to address {condition.lower()}."
        else:
            recommendation = "requires_evaluation"
            recommendation_text = "We recommend scheduling a consultation with an orthodontist for a thorough evaluation."

        return {
            "predicted_condition": condition,
            "confidence": round(confidence_pct, 2),
            "all_probabilities": {k: round(v, 2) for k, v in all_probs.items()},
            "recommendation": recommendation,
            "recommendation_text": recommendation_text,
            "condition_description": DESCRIPTIONS.get(condition, ""),
            "model_version": "ResNet18_512x512",
            "training_accuracy": 72.73,
        }
    except Exception as exc:  # report the failure to the UI instead of crashing
        traceback.print_exc()
        return {"error": str(exc), "traceback": traceback.format_exc()}
# ---- Gradio UI -------------------------------------------------------------
# Single image-in / JSON-out interface; api_name="predict" exposes the
# endpoint for programmatic clients as well as the browser form.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Dental Photo"),
    outputs=gr.JSON(label="Prediction Results"),
    title="Orthodontic Condition Classifier",
    description="Upload a photo of teeth to get an AI-powered orthodontic screening. This is for informational purposes only and should not replace professional evaluation.",
    examples=None,
    api_name="predict",
)

if __name__ == "__main__":
    # show_error surfaces Python exceptions in the browser while debugging.
    demo.launch(show_error=True)