|
|
import os |
|
|
from huggingface_hub import login |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoModel, AutoImageProcessor |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
from weights import W, b |
|
|
|
|
|
|
|
|
# Log in to the Hugging Face Hub only when DINO_TOKEN is set to a non-empty
# value; otherwise the login step is skipped entirely.
token = os.getenv("DINO_TOKEN")
if token:
    login(token=token)

# Checkpoint identifier shared by the image processor and the backbone.
MODEL_ID = "facebook/dinov3-vitl16-pretrain-lvd1689m"

# Load the preprocessing pipeline, then the backbone switched to eval mode
# (``eval()`` returns the module itself, so the chained form is equivalent).
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID).eval()

# Number of register tokens declared by the model config; falls back to 0
# when the config has no ``num_register_tokens`` attribute.
num_reg = getattr(model.config, "num_register_tokens", 0)
|
|
|
|
|
def analyze(image: Image.Image):
    """Build a figure showing a fundus image next to its patch-level quality map.

    The image is converted to RGB, resized to 224x224 and passed through the
    module-level ``processor`` and ``model``. The first token and the
    ``num_reg`` register tokens of the last hidden state are dropped, the
    remaining tokens are laid out on a 14x14 grid, and each grid cell is
    scored with ``features @ W + b`` followed by a sigmoid.

    Args:
        image: The uploaded picture as a PIL image. Gradio passes ``None``
            when the form is submitted without an upload.

    Returns:
        A matplotlib ``Figure`` with the resized input on the left and the
        probability map (colour scale fixed to [0, 1]) plus a colorbar on
        the right.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Local import keeps this function self-contained; matplotlib is already
    # a dependency of the file.
    from matplotlib.figure import Figure

    # Surface a readable message in the UI instead of an AttributeError on
    # ``None.convert`` when nothing was uploaded.
    if image is None:
        raise gr.Error("Please upload a fundus image first.")

    image = image.convert("RGB").resize((224, 224), Image.BICUBIC)

    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Drop the first token (presumably the class token) and the register
    # tokens so that only patch tokens remain.
    features = outputs.last_hidden_state[0, 1 + num_reg:, :].numpy()
    # NOTE(review): the 14x14 grid assumes the processor feeds the model a
    # 224x224 input with 16-pixel patches - confirm against the processor
    # config.
    features = features.reshape(14, 14, -1)

    logits = features @ W + b
    probs = 1 / (1 + np.exp(-logits))

    # Create the figure without pyplot: figures made via ``plt.subplots`` stay
    # registered in pyplot's global figure manager until explicitly closed,
    # which leaks one figure per request in a long-running server.
    fig = Figure(figsize=(10, 5))
    axes = fig.subplots(1, 2)

    axes[0].imshow(image)
    axes[0].set_title("Input Fundus Image")
    axes[0].axis("off")

    im = axes[1].imshow(probs, cmap="magma", vmin=0, vmax=1)
    axes[1].set_title("Quality Map (bright = degraded)")
    axes[1].axis("off")
    fig.colorbar(im, ax=axes[1], fraction=0.046)

    fig.tight_layout()
    return fig
|
|
|
|
|
# Gradio UI: a single PIL image upload wired to ``analyze``, whose returned
# figure is shown in a plot component.
demo = gr.Interface(
    title="EFIQA: Explainable Fundus Image Quality Assessment",
    description=(
        "Upload a color fundus photograph to get a spatial quality map. "
        "Bright regions indicate areas with degraded quality "
        "(missing anatomical structures)."
    ),
    article=(
        "Paper: *EFIQA: Explainable Fundus Image Quality Assessment "
        "via Anatomical Priors* (MIDL 2026)"
    ),
    fn=analyze,
    inputs=gr.Image(type="pil", label="Upload Fundus Image"),
    outputs=gr.Plot(label="Quality Assessment"),
)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|
|