"""Gradio app for multi-label dental image classification.

Downloads a quantized TorchScript ViT model from a private Hugging Face
repository, preprocesses uploaded images to match the processor's
configuration, and reports detected dental conditions with confidences.
"""

import os

import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import ViTImageProcessor
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download

# Repository configuration
REPO_ID = "IFMedTech/Dental_Q"
MODEL_FILENAME = "quantized_model.ptl"


def download_model_from_hub():
    """Download model from private Hugging Face repository.

    Returns:
        str: Local filesystem path of the downloaded model file.

    Raises:
        ValueError: If the HUGGINGFACE_TOKEN environment variable is not set.
        RuntimeError: If the download from the Hub fails for any reason.
    """
    token = os.environ.get("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN environment variable is required for private repo access. "
            "Please set it in your Space settings under 'Repository secrets'."
        )
    try:
        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=token
        )
        return model_path
    except Exception as e:
        # Chain the original exception so the underlying download failure
        # (auth error, network error, missing file) stays visible.
        raise RuntimeError(f"Failed to download model from {REPO_ID}: {str(e)}") from e


def load_model_and_processor():
    """Load the TorchScript model and its image processor.

    Returns:
        tuple: (quantized_model, processor) where quantized_model is a
        torch.jit.ScriptModule in eval mode on CPU and processor is the
        ViTImageProcessor loaded from the same private repo.
    """
    token = os.environ.get("HUGGINGFACE_TOKEN")

    # Download and load model
    model_path = download_model_from_hub()
    quantized_model = torch.jit.load(model_path, map_location="cpu")
    quantized_model.eval()

    # Load processor from private repo
    processor = ViTImageProcessor.from_pretrained(REPO_ID, token=token)

    return quantized_model, processor


# Initialize model and processor at import time so the Gradio handler can use them.
quantized_model, processor = load_model_and_processor()

# Define inference preprocessing mirroring the processor's configuration:
# resize shorter side, center-crop to a square, convert to tensor, normalize.
size = processor.size['height']
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
inference_transform = Compose([
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize
])

# Multi-label class names: prefer the model's own id2label mapping when the
# TorchScript module exposes a config; otherwise fall back to the known labels.
try:
    label_names = [
        quantized_model.config.id2label[i]
        for i in range(len(quantized_model.config.id2label))
    ]
except AttributeError:
    label_names = ["Background", "Caries", "Normal Teeth", "Plaque"]


def preprocess_image(image):
    """Load and preprocess a PIL image.

    Accepts either a PIL image or an array-like (as Gradio may supply) and
    returns a batched (1, C, H, W) float tensor ready for the model.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    return inference_transform(image).unsqueeze(0)


def predict_image(image):
    """Run inference on image and return multi-label predictions.

    Applies a sigmoid per class (multi-label), reports every class whose
    probability exceeds 0.5, and additionally flags borderline Caries
    probabilities in [0.3, 0.5) as "Possible Caries".
    """
    pixel_values = preprocess_image(image)
    with torch.no_grad():
        logits = quantized_model(pixel_values)
    probs = torch.sigmoid(logits).squeeze(0)
    preds = (probs > 0.5).int().tolist()

    detected_conditions = []
    for i, (label, pred) in enumerate(zip(label_names, preds)):
        if pred == 1:
            confidence = probs[i].item()
            detected_conditions.append(f"{label} (confidence: {confidence:.2%})")

    # Check for potential Caries below the hard 0.5 decision threshold.
    try:
        caries_index = label_names.index("Caries")
        caries_prob = probs[caries_index].item()
        if 0.3 <= caries_prob < 0.5:
            detected_conditions.append(f"Possible Caries (confidence: {caries_prob:.2%})")
    except ValueError:
        # "Caries" is not among the labels — skip the borderline check.
        pass

    if detected_conditions:
        result = "Detected: " + ", ".join(detected_conditions)
    else:
        result = "No dental issues detected"

    return result


# Example images
examples = [
    ["example_image1.jfif"],
    ["example_image2.jfif"],
    ["example_image3.jfif"]
]

# Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Dental Image Multi-Label Classification",
    description="Upload an image or select from the examples below to predict dental conditions. The model can detect multiple dental issues in a single image.",
    examples=examples
)

if __name__ == "__main__":
    iface.launch()