File size: 2,173 Bytes
0df0d16
cdc5131
 
4cc223c
 
 
 
 
 
 
 
 
 
3bc7e2d
 
 
4cc223c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from huggingface_hub import login

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoImageProcessor
import matplotlib.pyplot as plt

from weights import W, b

# Load DINOv3 backbone (gated model: authenticate first if a token is provided).
# DINO_TOKEN is an HF access token; without it, from_pretrained will fail for
# gated repos but may still work if the model is cached locally.
token = os.environ.get("DINO_TOKEN")
if token:
    login(token)
MODEL_ID = "facebook/dinov3-vitl16-pretrain-lvd1689m"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)
model.eval()  # inference only — disables dropout etc.

# Check for register tokens
# DINOv3 prepends register tokens after CLS; `analyze` must skip them when
# slicing out patch features. Default to 0 if the config lacks the field.
num_reg = getattr(model.config, "num_register_tokens", 0)

def analyze(image: Image.Image):
    """Run EFIQA quality assessment on a single fundus photograph.

    Extracts DINOv3 patch features, applies the per-patch linear adapter
    (W, b) followed by a sigmoid, and renders the input image next to the
    resulting spatial quality map.

    Args:
        image: Input fundus photograph (any PIL mode; converted to RGB).

    Returns:
        A matplotlib Figure with the input image (left) and the quality
        map (right), where bright regions indicate degraded quality.

    Raises:
        ValueError: if the backbone yields a non-square patch grid.
    """
    # Resize to the working resolution; ViT-L/16 at 224x224 yields a 14x14 grid.
    image = image.convert("RGB").resize((224, 224), Image.BICUBIC)

    # Extract DINO features
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Get patch features: skip the CLS token and any register tokens.
    features = outputs.last_hidden_state[0, 1 + num_reg:, :].numpy()  # (P, D)

    # Derive the grid size from the token count instead of hard-coding 14,
    # so other input resolutions / patch sizes keep working.
    grid = int(round(features.shape[0] ** 0.5))
    if grid * grid != features.shape[0]:
        raise ValueError(
            f"Expected a square patch grid, got {features.shape[0]} patch tokens"
        )
    features = features.reshape(grid, grid, -1)  # (grid, grid, D)

    # EFIQA adapter: per-patch linear probe (equivalent to a 1x1 conv).
    logits = features @ W + b                        # (grid, grid)
    probs = 1 / (1 + np.exp(-logits))                # sigmoid -> degradation prob

    # Visualize input and quality map side by side.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].imshow(image)
    axes[0].set_title("Input Fundus Image")
    axes[0].axis("off")

    im = axes[1].imshow(probs, cmap="magma", vmin=0, vmax=1)
    axes[1].set_title("Quality Map (bright = degraded)")
    axes[1].axis("off")
    plt.colorbar(im, ax=axes[1], fraction=0.046)

    plt.tight_layout()
    return fig

# Gradio front-end: one image in, one matplotlib figure out.
_image_input = gr.Image(type="pil", label="Upload Fundus Image")
_plot_output = gr.Plot(label="Quality Assessment")

demo = gr.Interface(
    fn=analyze,
    inputs=_image_input,
    outputs=_plot_output,
    title="EFIQA: Explainable Fundus Image Quality Assessment",
    description=(
        "Upload a color fundus photograph to get a spatial quality map. "
        "Bright regions indicate areas with degraded quality "
        "(missing anatomical structures)."
    ),
    article=(
        "Paper: *EFIQA: Explainable Fundus Image Quality Assessment "
        "via Anatomical Priors* (MIDL 2026)"
    ),
)

if __name__ == "__main__":
    demo.launch()