Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import json | |
| import traceback | |
# Class labels in model-output order: index i of the network's final layer
# corresponds to CLASS_NAMES[i]. The list is alphabetical, which presumably
# mirrors the label order used at training time (e.g. torchvision ImageFolder
# sorts class folders) -- confirm against the training code before reordering.
CLASS_NAMES = [
    "Crossbite",
    "Crowding",
    "Deepbite",
    "No Treatment Needed",
    "Open Bite",
    "Overbite",
    "Spacing",
    "Underbite",
]

# Short plain-language explanation of each condition, returned to the user
# together with the prediction.
DESCRIPTIONS = {
    "Crossbite": "A misalignment where upper teeth bite inside lower teeth.",
    "Crowding": "Insufficient space causing teeth to overlap or twist.",
    "Deepbite": "Upper front teeth excessively overlap lower front teeth.",
    "No Treatment Needed": "Teeth appear to be properly aligned.",
    "Open Bite": "Upper and lower teeth don't touch when mouth is closed.",
    "Overbite": "Upper front teeth protrude significantly over lower teeth.",
    "Spacing": "Gaps or spaces between teeth.",
    "Underbite": "Lower teeth protrude beyond upper teeth.",
}

# Fail fast at import time if the two hand-maintained tables drift apart;
# otherwise a typo would only surface as an empty description at runtime.
if set(DESCRIPTIONS) != set(CLASS_NAMES):
    raise RuntimeError(
        "DESCRIPTIONS keys must match CLASS_NAMES; mismatched: "
        f"{sorted(set(DESCRIPTIONS) ^ set(CLASS_NAMES))}"
    )
# Build the ResNet-18 classifier and load the fine-tuned weights from disk.
print("Loading model...")
# Architecture only; all weights come from the checkpoint loaded below.
model = models.resnet18(weights=None)
# Replace the classification head with one output per class label. The width is
# derived from CLASS_NAMES instead of a hard-coded 8 so the two cannot disagree.
model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_NAMES))
# weights_only=True restricts unpickling to tensors and basic containers, so a
# tampered checkpoint cannot run arbitrary code when it is loaded.
state_dict = torch.load("pytorch_model.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Inference mode: BatchNorm layers use their running statistics.
model.eval()
print("Model loaded!")
# Preprocessing applied to every uploaded photo before inference:
# 1. resize to a fixed 512x512 input (the responses report "ResNet18_512x512");
# 2. convert to a float CHW tensor with values in [0, 1];
# 3. normalize each channel with the standard ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def _recommendation(predicted_condition, confidence_pct, threshold=70.0):
    """Map a top-1 prediction to a ``(recommendation code, user text)`` pair.

    Both definitive verdicts -- "candidate" and "not_candidate" -- require the
    top-1 probability to reach ``threshold`` percent. Anything less confident
    is routed to "requires_evaluation", so a weak "No Treatment Needed" guess
    never reassures the user.

    Args:
        predicted_condition: Top-1 class label (an entry of CLASS_NAMES).
        confidence_pct: Top-1 probability as a percentage (0-100).
        threshold: Minimum percentage for a definitive verdict.
    """
    if confidence_pct < threshold:
        return (
            "requires_evaluation",
            "We recommend scheduling a consultation with an orthodontist for a thorough evaluation.",
        )
    if predicted_condition == "No Treatment Needed":
        return (
            "not_candidate",
            "Based on the AI analysis, you may not need orthodontic treatment at this time.",
        )
    return (
        "candidate",
        f"You appear to be a good candidate for orthodontic treatment to address {predicted_condition.lower()}.",
    )


def predict(image):
    """Predict orthodontic condition from image.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``. The
            Gradio input is configured with ``type="pil"``. ``None`` when
            nothing was uploaded.

    Returns:
        On success, a dict with:
        - ``predicted_condition``: the top-1 label.
        - ``confidence``: its probability.
        - ``all_probabilities``: every label's probability.
        - ``recommendation`` and ``recommendation_text``: the verdict code and
          its user-facing text.
        - ``condition_description``: the plain-language description.
        - ``model_version`` and ``training_accuracy``: static model metadata.

        All probabilities are percentages rounded to 2 decimals. On failure,
        the dict has an ``error`` message, plus ``traceback`` when an exception
        was raised.
    """
    if image is None:
        return {"error": "No image provided"}
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        # Add a leading batch dimension of size 1.
        img_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            outputs = model(img_tensor)
            # Single image in the batch -> take row 0 of the softmax.
            probabilities = F.softmax(outputs, dim=1)[0]

        confidence, predicted_idx = torch.max(probabilities, 0)
        predicted_condition = CLASS_NAMES[predicted_idx.item()]
        confidence_pct = confidence.item() * 100
        all_probabilities = {
            name: round(prob * 100, 2)
            for name, prob in zip(CLASS_NAMES, probabilities.tolist())
        }
        recommendation, recommendation_text = _recommendation(
            predicted_condition, confidence_pct
        )

        return {
            "predicted_condition": predicted_condition,
            "confidence": round(confidence_pct, 2),
            "all_probabilities": all_probabilities,
            "recommendation": recommendation,
            "recommendation_text": recommendation_text,
            "condition_description": DESCRIPTIONS.get(predicted_condition, ""),
            "model_version": "ResNet18_512x512",
            "training_accuracy": 72.73,
        }
    except Exception as e:
        # Top-level boundary for the web handler: log server-side and report
        # the failure to the client instead of crashing the request.
        traceback.print_exc()
        # NOTE(review): sending the full server traceback to a public client
        # discloses internals; kept only for backward compatibility of the
        # response shape -- consider dropping the "traceback" key.
        return {"error": str(e), "traceback": traceback.format_exc()}
# Gradio front end: one image upload in, the JSON report from predict() out.
# api_name="predict" names the endpoint under which the app's API exposes it.
_photo_input = gr.Image(type="pil", label="Upload Dental Photo")
_report_output = gr.JSON(label="Prediction Results")

demo = gr.Interface(
    fn=predict,
    inputs=_photo_input,
    outputs=_report_output,
    title="Orthodontic Condition Classifier",
    description=(
        "Upload a photo of teeth to get an AI-powered orthodontic screening. "
        "This is for informational purposes only and should not replace "
        "professional evaluation."
    ),
    examples=None,
    api_name="predict",
)

if __name__ == "__main__":
    # show_error=True surfaces handler errors in the browser UI.
    demo.launch(show_error=True)