Spaces:
Sleeping
Sleeping
File size: 3,808 Bytes
0514a9d 57b42a5 0514a9d 1d3a8a8 57b42a5 0514a9d 57b42a5 1d3a8a8 57b42a5 0514a9d 57b42a5 0514a9d 57b42a5 0514a9d 57b42a5 1d3a8a8 8b6ff1c 1d3a8a8 8b6ff1c 1d3a8a8 51d1e00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import ViTImageProcessor
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
# Repository configuration
# Private Hugging Face Hub repo hosting both the model artifact and the
# image-processor config (see load_model_and_processor below).
REPO_ID = "IFMedTech/Dental_Q"
# TorchScript (torch.jit) quantized model file inside that repo.
MODEL_FILENAME = "quantized_model.ptl"
def download_model_from_hub():
    """Download the model artifact from the private Hugging Face repository.

    Reads the access token from the HUGGINGFACE_TOKEN environment variable.

    Returns:
        str: Local filesystem path of the downloaded model file.

    Raises:
        ValueError: If HUGGINGFACE_TOKEN is not set.
        RuntimeError: If the download from the Hub fails for any reason.
    """
    token = os.environ.get("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN environment variable is required for private repo access. "
            "Please set it in your Space settings under 'Repository secrets'."
        )
    try:
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=token,
        )
    except Exception as e:
        # Chain the original exception (`from e`) so the underlying cause and
        # traceback are preserved instead of being flattened into a string.
        raise RuntimeError(f"Failed to download model from {REPO_ID}: {str(e)}") from e
def load_model_and_processor():
    """Fetch the TorchScript model and its image processor from the Hub.

    Returns:
        tuple: (eval-mode torch.jit module, ViTImageProcessor).
    """
    hub_token = os.environ.get("HUGGINGFACE_TOKEN")

    # Download the weights file, then deserialize the scripted model on CPU.
    scripted_model = torch.jit.load(download_model_from_hub(), map_location="cpu")
    scripted_model.eval()

    # The processor configuration lives in the same private repo.
    image_processor = ViTImageProcessor.from_pretrained(REPO_ID, token=hub_token)
    return scripted_model, image_processor
# Initialize model and processor
# NOTE: runs at import time — requires network access and HUGGINGFACE_TOKEN.
quantized_model, processor = load_model_and_processor()
# Define Inference Preprocessing
# Use the processor's configured input resolution so the torchvision pipeline
# matches what the model was trained/exported with.
size = processor.size['height']
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
inference_transform = Compose([
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize
])
# Multi-label class names
try:
    label_names = [quantized_model.config.id2label[i] for i in range(len(quantized_model.config.id2label))]
except AttributeError:
    # TorchScript modules generally don't expose a `config` attribute, so fall
    # back to the hard-coded label set for this model. Order must match the
    # model's output logits — TODO confirm against the training config.
    label_names = ["Background", "Caries", "Normal Teeth", "Plaque"]
def preprocess_image(image):
    """Coerce the input into an RGB PIL image and return a 1-image batch tensor.

    Args:
        image: A PIL image or a numpy array (as delivered by Gradio).

    Returns:
        torch.Tensor: Preprocessed pixels with a leading batch dimension.
    """
    pil_img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    rgb = pil_img.convert("RGB")
    return inference_transform(rgb).unsqueeze(0)
# Sigmoid probability cut-offs for reporting a condition.
DETECTED_THRESHOLD = 0.5
POSSIBLE_THRESHOLD = 0.3

def predict_image(image):
    """Run inference on an image and return a multi-label prediction summary.

    Args:
        image: A PIL image or numpy array from the Gradio input component.

    Returns:
        str: "Detected: ..." listing each condition whose sigmoid probability
        exceeds DETECTED_THRESHOLD (plus a "Possible Caries" entry for
        borderline caries scores), or "No dental issues detected".
    """
    pixel_values = preprocess_image(image)
    with torch.no_grad():
        logits = quantized_model(pixel_values)
    # Multi-label head: an independent sigmoid per class, not a softmax.
    probs = torch.sigmoid(logits).squeeze(0)

    # Single named threshold instead of the duplicated magic 0.5, and zip the
    # labels with their probabilities directly (no redundant index/int-list).
    detected_conditions = [
        f"{label} (confidence: {prob.item():.2%})"
        for label, prob in zip(label_names, probs)
        if prob.item() > DETECTED_THRESHOLD
    ]

    # Flag caries scores that fall just under the detection cut-off.
    try:
        caries_prob = probs[label_names.index("Caries")].item()
        if POSSIBLE_THRESHOLD <= caries_prob < DETECTED_THRESHOLD:
            detected_conditions.append(f"Possible Caries (confidence: {caries_prob:.2%})")
    except ValueError:
        # "Caries" is not among the labels; nothing to flag.
        pass

    if detected_conditions:
        return "Detected: " + ", ".join(detected_conditions)
    return "No dental issues detected"
# Example images
# Sample inputs shown below the upload widget; paths are relative to the app
# root and must ship alongside this file.
examples = [
    ["example_image1.jfif"],
    ["example_image2.jfif"],
    ["example_image3.jfif"]
]
# Gradio interface
# Single image in, plain-text prediction summary out.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Dental Image Multi-Label Classification",
    description="Upload an image or select from the examples below to predict dental conditions. The model can detect multiple dental issues in a single image.",
    examples=examples
)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (Spaces runs this).
    iface.launch()
|