File size: 4,138 Bytes
b293971
 
 
 
 
 
14c852d
b293971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14c852d
b293971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14c852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b293971
 
 
 
 
 
 
 
 
 
 
 
 
14c852d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import json
import traceback

# Static class metadata for the classifier: each entry pairs a label with the
# short explanation shown to the user next to the prediction.
# ORDER MATTERS: predict() uses the model's arg-max output index directly as an
# index into CLASS_NAMES.
# NOTE(review): this order presumably must match the label order used when
# pytorch_model.pth was trained -- confirm against the training code.
_CLASS_INFO = (
    ("Crossbite", "A misalignment where upper teeth bite inside lower teeth."),
    ("Crowding", "Insufficient space causing teeth to overlap or twist."),
    ("Deepbite", "Upper front teeth excessively overlap lower front teeth."),
    ("No Treatment Needed", "Teeth appear to be properly aligned."),
    ("Open Bite", "Upper and lower teeth don't touch when mouth is closed."),
    ("Overbite", "Upper front teeth protrude significantly over lower teeth."),
    ("Spacing", "Gaps or spaces between teeth."),
    ("Underbite", "Lower teeth protrude beyond upper teeth."),
)

# Output index -> label.
CLASS_NAMES = [label for label, _ in _CLASS_INFO]

# Label -> human-readable description of the condition.
DESCRIPTIONS = {label: text for label, text in _CLASS_INFO}

# Build the ResNet-18 architecture and restore the fine-tuned weights from disk.
# The classification head is sized from CLASS_NAMES (instead of a hard-coded 8)
# so the number of logits cannot drift out of sync with the label list that
# predict() indexes into. load_state_dict() is strict by default, so a
# checkpoint trained with a different number of classes fails loudly here.
print("Loading model...")
model = models.resnet18(weights=None)  # architecture only; weights come from the checkpoint
model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_NAMES))
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
state_dict = torch.load("pytorch_model.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference behaviour (e.g. BatchNorm uses its running statistics)
print("Model loaded!")

# Preprocessing applied to every uploaded image before inference:
#   1. Resize to exactly 512x512. A (h, w) tuple is given, so the aspect ratio
#      is NOT preserved.
#   2. ToTensor: PIL image -> C x H x W float tensor scaled to [0, 1].
#   3. Normalize each RGB channel with the standard ImageNet mean/std.
# NOTE(review): this presumably mirrors the training-time pipeline; any
# divergence degrades accuracy silently -- confirm against the training code.
transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Minimum confidence (percent) the top class must reach before the screen
# reports a definite recommendation in EITHER direction; anything less is
# referred for professional evaluation.
CONFIDENCE_THRESHOLD_PCT = 70.0

def predict(image):
    """Predict orthodontic condition from image.

    Args:
        image: A ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` delivers)
            or an array accepted by ``Image.fromarray``. ``None`` is reported
            as an error rather than raised.

    Returns:
        On success, a dict with the predicted condition, its confidence
        (percent, 2 d.p.), per-class probabilities (percent, 2 d.p.), a
        recommendation code (``"candidate"``, ``"not_candidate"`` or
        ``"requires_evaluation"``) with explanatory text, the condition
        description, and static model metadata.
        On failure, ``{"error": message}``. The full traceback is printed on
        the server only and is never returned to the caller.
    """
    if image is None:
        return {"error": "No image provided"}

    try:
        # Gradio already supplies a PIL image; other callers may pass an array.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Force 3 channels (drops alpha, expands greyscale/palette images) to
        # match the 3-channel Normalize used by `transform`.
        image = image.convert("RGB")

        # Add a batch dimension of 1.
        img_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            logits = model(img_tensor)
            probabilities = F.softmax(logits, dim=1)[0]
            confidence, predicted_idx = torch.max(probabilities, 0)

        predicted_condition = CLASS_NAMES[predicted_idx.item()]
        confidence_pct = confidence.item() * 100

        # Per-class probabilities in percent, keyed by label.
        all_probs = {
            name: round(prob * 100, 2)
            for name, prob in zip(CLASS_NAMES, probabilities.tolist())
        }

        if confidence_pct < CONFIDENCE_THRESHOLD_PCT:
            # Too uncertain to conclude either way. This deliberately includes
            # a low-confidence "No Treatment Needed": with several classes the
            # top one can win with well under half of the probability mass,
            # and telling the user they need no treatment on that basis is the
            # riskier screening error.
            recommendation = "requires_evaluation"
            recommendation_text = "We recommend scheduling a consultation with an orthodontist for a thorough evaluation."
        elif predicted_condition == "No Treatment Needed":
            recommendation = "not_candidate"
            recommendation_text = "Based on the AI analysis, you may not need orthodontic treatment at this time."
        else:
            recommendation = "candidate"
            recommendation_text = f"You appear to be a good candidate for orthodontic treatment to address {predicted_condition.lower()}."

        return {
            "predicted_condition": predicted_condition,
            "confidence": round(confidence_pct, 2),
            "all_probabilities": all_probs,
            "recommendation": recommendation,
            "recommendation_text": recommendation_text,
            "condition_description": DESCRIPTIONS.get(predicted_condition, ""),
            # Static metadata literals; not read from the loaded checkpoint.
            "model_version": "ResNet18_512x512",
            "training_accuracy": 72.73
        }
    except Exception as e:
        # Top-level boundary of a public endpoint: keep the full traceback in
        # the server log, but return only the message so stack frames and file
        # paths are not disclosed to the (untrusted) caller.
        traceback.print_exc()
        return {"error": str(e)}

# Expose predict() as a single-input / single-output Gradio app. API clients
# reach the endpoint under the name "predict".
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Dental Photo"),
    outputs=gr.JSON(label="Prediction Results"),
    title="Orthodontic Condition Classifier",
    description=(
        "Upload a photo of teeth to get an AI-powered orthodontic screening. "
        "This is for informational purposes only and should not replace "
        "professional evaluation."
    ),
    examples=None,  # no bundled example images
    api_name="predict",
)

if __name__ == "__main__":
    # show_error=True surfaces errors raised inside the app in the browser UI.
    demo.launch(show_error=True)