File size: 3,808 Bytes
0514a9d
57b42a5
0514a9d
1d3a8a8
57b42a5
 
0514a9d
57b42a5
1d3a8a8
 
 
57b42a5
0514a9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57b42a5
0514a9d
 
 
 
 
 
 
 
 
 
 
 
 
57b42a5
0514a9d
 
57b42a5
1d3a8a8
 
 
 
 
 
 
 
 
8b6ff1c
1d3a8a8
 
 
 
 
8b6ff1c
1d3a8a8
51d1e00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import ViTImageProcessor
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download

# Repository configuration
REPO_ID = "IFMedTech/Dental_Q"
MODEL_FILENAME = "quantized_model.ptl"

def download_model_from_hub():
    """Download the model file from the private Hugging Face repository.

    Reads the access token from the ``HUGGINGFACE_TOKEN`` environment
    variable and fetches ``MODEL_FILENAME`` from ``REPO_ID`` through
    ``hf_hub_download``.

    Returns:
        Whatever ``hf_hub_download`` returns for the requested file (used by
        the caller as the local path of the model).

    Raises:
        ValueError: If ``HUGGINGFACE_TOKEN`` is unset or empty.
        RuntimeError: If the download fails for any reason. The original
            exception is attached as ``__cause__``.
    """
    token = os.environ.get("HUGGINGFACE_TOKEN")

    # Fail early with an actionable message instead of an opaque auth error.
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN environment variable is required for private repo access. "
            "Please set it in your Space settings under 'Repository secrets'."
        )

    try:
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=token,
        )
    except Exception as e:
        # Chain explicitly so the underlying hub/network error stays attached
        # to the RuntimeError that callers see.
        raise RuntimeError(f"Failed to download model from {REPO_ID}: {e}") from e

def load_model_and_processor():
    """Fetch the serialized model and its matching image processor.

    The model file is obtained via ``download_model_from_hub``, loaded with
    ``torch.jit.load`` onto the CPU and switched to eval mode. The
    ``ViTImageProcessor`` is then loaded from the same repository
    (``REPO_ID``) using the ``HUGGINGFACE_TOKEN`` environment value.

    Returns:
        tuple: ``(model, processor)``.
    """
    hf_token = os.environ.get("HUGGINGFACE_TOKEN")

    # Model first: the download helper validates the token and raises if it
    # is missing, before the processor is requested.
    model = torch.jit.load(download_model_from_hub(), map_location="cpu")
    model.eval()

    # The processor configuration lives in the same private repo.
    image_processor = ViTImageProcessor.from_pretrained(REPO_ID, token=hf_token)

    return model, image_processor

# Load the model and processor once at import time so every request reuses
# them. This triggers the hub download and requires HUGGINGFACE_TOKEN.
quantized_model, processor = load_model_and_processor()

# Inference preprocessing, built from the processor's own settings:
# resize and center-crop to the processor's configured height, convert to a
# tensor, then normalize with the processor's mean/std.
# NOTE(review): only size['height'] is read — presumably height == width for
# this processor; confirm.
size = processor.size['height']
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
inference_transform = Compose([
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize
])

# Multi-label class names, indexed by output position.
# Prefer the id2label mapping on the model's config when the loaded model
# exposes one; if the attribute lookup fails, fall back to a hard-coded list.
# NOTE(review): the fallback order must match the model's output order —
# verify against the training label mapping.
try:
    label_names = [quantized_model.config.id2label[i] for i in range(len(quantized_model.config.id2label))]
except AttributeError:
    label_names = ["Background", "Caries", "Normal Teeth", "Plaque"]

def preprocess_image(image):
    """Turn an input image into a batched tensor ready for the model.

    Inputs that are not already PIL images are converted with
    ``Image.fromarray``. The image is forced to RGB, passed through
    ``inference_transform`` and given a leading batch dimension.
    """
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    rgb_image = pil_image.convert("RGB")
    transformed = inference_transform(rgb_image)
    # The model is called on a batch, so add a batch axis of size one.
    return transformed.unsqueeze(0)

def predict_image(image):
    """Run inference on an image and return multi-label predictions as text.

    Args:
        image: A PIL image (or array accepted by ``preprocess_image``), or
            ``None`` when no image was supplied.

    Returns:
        str: ``"Detected: ..."`` listing every label whose sigmoid score is
        above 0.5 with its confidence, plus a ``"Possible Caries"`` entry
        when the Caries score is in ``[0.3, 0.5)``. Returns
        ``"No dental issues detected"`` when nothing qualifies, or a prompt
        to upload an image when ``image`` is ``None``.
    """
    # The UI can invoke this with no image (e.g. submit pressed on an empty
    # input); answer with a message instead of crashing in preprocessing.
    if image is None:
        return "No image provided. Please upload an image."

    pixel_values = preprocess_image(image)

    with torch.no_grad():
        logits = quantized_model(pixel_values)

    # Multi-label: an independent sigmoid per class. Drop the batch axis and
    # convert once to plain floats for thresholding and formatting.
    probs = torch.sigmoid(logits).squeeze(0).tolist()

    detected_conditions = [
        f"{label} (confidence: {prob:.2%})"
        for label, prob in zip(label_names, probs)
        if prob > 0.5
    ]

    # A Caries score below the detection threshold but at least 0.3 is still
    # reported, flagged as "possible".
    try:
        caries_index = label_names.index("Caries")
    except ValueError:
        # No "Caries" label configured, so there is nothing extra to report.
        pass
    else:
        caries_prob = probs[caries_index]
        if 0.3 <= caries_prob < 0.5:
            detected_conditions.append(f"Possible Caries (confidence: {caries_prob:.2%})")

    if detected_conditions:
        return "Detected: " + ", ".join(detected_conditions)
    return "No dental issues detected"

# Example inputs offered in the UI. Each inner list holds the value for the
# interface's single image input.
# NOTE(review): these are relative paths — presumably the files ship next to
# this script; confirm they exist in the deployment.
examples = [
    ["example_image1.jfif"],
    ["example_image2.jfif"],
    ["example_image3.jfif"]
]

# Gradio interface: one PIL image in, the text summary from predict_image out.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Dental Image Multi-Label Classification",
    description="Upload an image or select from the examples below to predict dental conditions. The model can detect multiple dental issues in a single image.",
    examples=examples
)

# Launch the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()