import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import json
import traceback

# Class labels. CLASS_NAMES[i] is used as the label for output logit i, so
# this order must match the label order the checkpoint was trained with.
CLASS_NAMES = [
    "Crossbite",
    "Crowding",
    "Deepbite",
    "No Treatment Needed",
    "Open Bite",
    "Overbite",
    "Spacing",
    "Underbite",
]

# Label that predict() treats as "no orthodontic treatment indicated".
NO_TREATMENT_LABEL = "No Treatment Needed"

# Minimum top-class confidence (in percent) for a treatable condition to be
# reported as "candidate"; below it, predict() recommends an evaluation.
CANDIDATE_CONFIDENCE_THRESHOLD = 70.0

# Short, user-facing description for each class label.
DESCRIPTIONS = {
    "Crossbite": "A misalignment where upper teeth bite inside lower teeth.",
    "Crowding": "Insufficient space causing teeth to overlap or twist.",
    "Deepbite": "Upper front teeth excessively overlap lower front teeth.",
    "No Treatment Needed": "Teeth appear to be properly aligned.",
    "Open Bite": "Upper and lower teeth don't touch when mouth is closed.",
    "Overbite": "Upper front teeth protrude significantly over lower teeth.",
    "Spacing": "Gaps or spaces between teeth.",
    "Underbite": "Lower teeth protrude beyond upper teeth.",
}

# Load the model once at import time on CPU.
print("Loading model...")
model = models.resnet18(weights=None)
# Size the classification head from CLASS_NAMES so the label list and the
# number of outputs cannot silently drift apart; load_state_dict (strict by
# default) raises if the checkpoint's head has a different size.
model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_NAMES))
state_dict = torch.load("pytorch_model.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print("Model loaded!")

# Preprocessing: fixed 512x512 resize (aspect ratio is not preserved), tensor
# conversion, then per-channel normalization with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image):
    """Predict orthodontic condition from image.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` produces an error payload.

    Returns:
        On success, a JSON-serializable dict with:
          - ``predicted_condition``: top class label;
          - ``confidence``: its probability in percent, rounded to 2 places;
          - ``all_probabilities``: label -> percent (2 places) for every class;
          - ``recommendation``: ``"not_candidate"``, ``"candidate"`` or
            ``"requires_evaluation"``, plus ``recommendation_text``;
          - ``condition_description``: entry from DESCRIPTIONS (or ``""``);
          - static ``model_version`` and ``training_accuracy`` metadata.
        On failure, a dict with an ``"error"`` key (and ``"traceback"`` for
        unexpected exceptions). Errors are returned, never raised.
    """
    try:
        if image is None:
            return {"error": "No image provided"}

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalize grayscale / RGBA / palette inputs to the 3 channels the
        # normalization above (and the model) expect.
        image = image.convert("RGB")

        batch = transform(image).unsqueeze(0)  # add a batch dimension of 1

        with torch.no_grad():
            logits = model(batch)
            probabilities = F.softmax(logits, dim=1)[0]
            confidence, predicted_idx = torch.max(probabilities, 0)

        predicted_condition = CLASS_NAMES[predicted_idx.item()]
        confidence_pct = confidence.item() * 100

        # Per-class probabilities in percent, rounded once for the response.
        all_probabilities = {
            name: round(float(prob) * 100, 2)
            for name, prob in zip(CLASS_NAMES, probabilities.tolist())
        }

        if predicted_condition == NO_TREATMENT_LABEL:
            # NOTE(review): this branch does not look at confidence, so a
            # low-confidence "No Treatment Needed" still yields
            # "not_candidate"; consider routing it to "requires_evaluation".
            recommendation = "not_candidate"
            recommendation_text = "Based on the AI analysis, you may not need orthodontic treatment at this time."
        elif confidence_pct >= CANDIDATE_CONFIDENCE_THRESHOLD:
            recommendation = "candidate"
            recommendation_text = f"You appear to be a good candidate for orthodontic treatment to address {predicted_condition.lower()}."
        else:
            recommendation = "requires_evaluation"
            recommendation_text = "We recommend scheduling a consultation with an orthodontist for a thorough evaluation."

        return {
            "predicted_condition": predicted_condition,
            "confidence": round(confidence_pct, 2),
            "all_probabilities": all_probabilities,
            "recommendation": recommendation,
            "recommendation_text": recommendation_text,
            "condition_description": DESCRIPTIONS.get(predicted_condition, ""),
            "model_version": "ResNet18_512x512",
            "training_accuracy": 72.73,
        }
    except Exception as e:
        # Top-level boundary of the web handler: log server-side and return
        # the failure as JSON instead of failing the request.
        # NOTE(review): the traceback is sent to the client, which exposes
        # server internals on a public deployment; consider omitting it.
        traceback.print_exc()
        return {"error": str(e), "traceback": traceback.format_exc()}


# Gradio UI / API: one image in, the predict() dict rendered as JSON out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Dental Photo"),
    outputs=gr.JSON(label="Prediction Results"),
    title="Orthodontic Condition Classifier",
    description="Upload a photo of teeth to get an AI-powered orthodontic screening. This is for informational purposes only and should not replace professional evaluation.",
    examples=None,
    api_name="predict",
)

if __name__ == "__main__":
    demo.launch(show_error=True)