# app.py — Dental image multi-label classification Space (IFMedTech).
# (Header cleaned: original lines were Hugging Face web-page residue, not Python.)
import os
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import ViTImageProcessor
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
# Repository configuration
REPO_ID = "IFMedTech/Dental_Q"  # private HF Hub repo holding both model and processor
MODEL_FILENAME = "quantized_model.ptl"  # quantized TorchScript model (loaded via torch.jit.load)
def download_model_from_hub():
    """Download the quantized model file from the private Hugging Face repository.

    Authentication uses the HUGGINGFACE_TOKEN environment variable (a Space
    "Repository secret"), since the repo is private.

    Returns:
        str: Local filesystem path of the downloaded model file.

    Raises:
        ValueError: If the HUGGINGFACE_TOKEN environment variable is not set.
        RuntimeError: If the download from the Hub fails for any reason.
    """
    token = os.environ.get("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN environment variable is required for private repo access. "
            "Please set it in your Space settings under 'Repository secrets'."
        )
    try:
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=token,
        )
    except Exception as e:
        # Chain the original exception so the root cause survives in tracebacks.
        raise RuntimeError(f"Failed to download model from {REPO_ID}: {str(e)}") from e
def load_model_and_processor():
    """Fetch and prepare the quantized TorchScript model and its image processor.

    Returns:
        tuple: (TorchScript module in eval mode on CPU, ViTImageProcessor).
    """
    hub_token = os.environ.get("HUGGINGFACE_TOKEN")

    # Download the TorchScript artifact, then load it on CPU for inference.
    model = torch.jit.load(download_model_from_hub(), map_location="cpu")
    model.eval()

    # The processor config lives in the same private repo, so reuse the token.
    image_processor = ViTImageProcessor.from_pretrained(REPO_ID, token=hub_token)
    return model, image_processor
# Initialize model and processor
# NOTE: runs at import time; a missing token or failed download aborts app startup.
quantized_model, processor = load_model_and_processor()
# Define Inference Preprocessing
# Mirror the processor's configuration (target size, normalization stats) with
# torchvision transforms so inference inputs match training preprocessing.
size = processor.size['height']
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
inference_transform = Compose([
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize
])
# Multi-label class names
# TorchScript modules generally do not expose a `.config` attribute, so the
# AttributeError fallback to the hard-coded label list is the expected path.
try:
    label_names = [quantized_model.config.id2label[i] for i in range(len(quantized_model.config.id2label))]
except AttributeError:
    label_names = ["Background", "Caries", "Normal Teeth", "Plaque"]
def preprocess_image(image):
    """Convert an input image (PIL image or numpy array) to a model-ready tensor.

    Applies the module-level inference transform (resize, center-crop,
    to-tensor, normalize) and adds a leading batch dimension.

    Returns:
        torch.Tensor of shape (1, C, H, W).
    """
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    tensor = inference_transform(pil_image.convert("RGB"))
    return tensor.unsqueeze(0)
def predict_image(image):
    """Run multi-label inference on an image and format a readable summary.

    Sigmoid probabilities above 0.5 count as detections; Caries scores in the
    borderline 0.3–0.5 band are additionally reported as "Possible Caries".

    Returns:
        str: "Detected: ..." listing conditions with confidences, or a
        no-issues message when nothing crosses the thresholds.
    """
    pixel_values = preprocess_image(image)
    with torch.no_grad():
        logits = quantized_model(pixel_values)
    probs = torch.sigmoid(logits).squeeze(0)

    findings = []
    for idx, label in enumerate(label_names):
        score = probs[idx].item()
        if score > 0.5:
            findings.append(f"{label} (confidence: {score:.2%})")

    # Surface borderline caries scores that fall just below the hard threshold.
    if "Caries" in label_names:
        caries_score = probs[label_names.index("Caries")].item()
        if 0.3 <= caries_score < 0.5:
            findings.append(f"Possible Caries (confidence: {caries_score:.2%})")

    if findings:
        return "Detected: " + ", ".join(findings)
    return "No dental issues detected"
# Example images
# Bundled sample files shown as clickable examples under the upload widget.
examples = [
    ["example_image1.jfif"],
    ["example_image2.jfif"],
    ["example_image3.jfif"]
]
# Gradio interface
# Single image in (as PIL), plain-text prediction summary out.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Dental Image Multi-Label Classification",
    description="Upload an image or select from the examples below to predict dental conditions. The model can detect multiple dental issues in a single image.",
    examples=examples
)
if __name__ == "__main__":
    iface.launch()