File size: 6,284 Bytes
7179b2d
 
 
 
 
 
 
 
 
 
 
5b11294
 
 
 
7179b2d
 
5b11294
7179b2d
 
 
b90b725
5b11294
b90b725
 
c6b075f
7179b2d
5b11294
 
7179b2d
 
 
 
 
 
5b11294
7179b2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b11294
7179b2d
5b11294
 
 
 
 
 
 
 
 
 
 
 
 
7179b2d
5b11294
7179b2d
5b11294
7179b2d
5b11294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7179b2d
5b11294
 
 
7179b2d
 
 
 
 
 
c34dda4
7179b2d
c34dda4
 
5b11294
c34dda4
5b11294
 
c34dda4
5b11294
7179b2d
c34dda4
 
 
 
 
 
5b11294
 
 
c34dda4
5b11294
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import io
import sys
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as pth_transforms
from safetensors.torch import load_file
import base64
from io import BytesIO
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.cm as cm # Needed for attention head colormaps
from huggingface_hub import hf_hub_download

# --- Make sure Python can find dinov2 properly ---
sys.path.append(os.path.dirname(__file__))  
from dinov2.models.vision_transformer import vit_large

# --- Constants ---
# NOTE: hf_hub_download runs at import time — the checkpoint is fetched (or
# resolved from the local HF cache) as soon as this module is imported.
CKPT_PATH = hf_hub_download(
    repo_id="Arew99/dinov2-costum",
    filename="model.safetensors"
)

# ViT patch size; all inference dimensions must be divisible by this.
PATCH_SIZE = 14
# We use a fixed height for consistency
# (616 = 44 * 14, already a multiple of PATCH_SIZE).
INFERENCE_HEIGHT = 616 

# -------------------------------------------------------
# Load model
# -------------------------------------------------------
def load_model():
    """Build the custom DINOv2 ViT-L, load its weights, and freeze it.

    Returns:
        tuple: ``(model, device)`` — an eval-mode, gradient-free ``vit_large``
        moved to ``device`` (CUDA if available, otherwise CPU).
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Loading DINOv2 (PCA + Heads) on device: {device}")

    model = vit_large(
        patch_size=PATCH_SIZE,
        img_size=518,
        init_values=1.0,
        ffn_layer="swiglufused",
        block_chunks=0
    )

    # Load weights. Some checkpoint exports prefix every key with "model.".
    # Strip that *leading* prefix only: a blanket str.replace("model.", "")
    # would also mangle keys that merely contain "model." mid-string, and the
    # original `"model." in key` check matched the substring anywhere.
    state_dict = load_file(CKPT_PATH)
    keys_list = list(state_dict.keys())
    if keys_list and keys_list[0].startswith("model."):
        prefix_len = len("model.")
        state_dict = {
            (k[prefix_len:] if k.startswith("model.") else k): v
            for k, v in state_dict.items()
        }

    # strict=False tolerates benign mismatches (e.g. absent classifier head).
    model.load_state_dict(state_dict, strict=False)
    for p in model.parameters():
        p.requires_grad = False  # inference only — no gradients needed

    model.eval()
    model.to(device)
    return model, device


# -------------------------------------------------------
# HYBRID PREDICTION (PCA + HEADS)
# -------------------------------------------------------
def _encode_png_base64(pil_img):
    """Serialize a PIL image to a base64-encoded PNG string."""
    buf = BytesIO()
    pil_img.save(buf, format="PNG")
    buf.seek(0)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def predict_from_bytes(model_device_tuple, image_bytes):
    """Run the hybrid DINOv2 visualization on a raw image byte string.

    Args:
        model_device_tuple: ``(model, device)`` as returned by ``load_model``.
        image_bytes: Encoded image bytes (any format PIL can open).

    Returns:
        dict with:
            ``pca_image``: base64 PNG of the 3-component PCA "rainbow" map,
                resized back to the original image resolution.
            ``head_attention_maps``: list of base64 PNGs, one viridis heatmap
                per attention head of the last layer.
    """
    model, device = model_device_tuple

    # 1. Load & Preprocess
    original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    orig_w, orig_h = original_image.size

    # Resize keeping aspect ratio, but ensuring dimensions are divisible by
    # PATCH_SIZE (the ViT requires whole patches along both axes).
    scale = INFERENCE_HEIGHT / orig_h
    new_w = int(orig_w * scale)
    new_w = new_w - (new_w % PATCH_SIZE)
    new_h = INFERENCE_HEIGHT - (INFERENCE_HEIGHT % PATCH_SIZE)

    transform = pth_transforms.Compose([
        pth_transforms.Resize((new_h, new_w)),
        pth_transforms.ToTensor(),
        # Standard ImageNet normalization statistics.
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    img_tensor = transform(original_image).unsqueeze(0).to(device)

    # Patch-grid dimensions of the feature map.
    w_featmap = new_w // PATCH_SIZE
    h_featmap = new_h // PATCH_SIZE

    # ─────────────────────────────────────────────────────────────
    # PART A: PCA VISUALIZATION (Main Result)
    # ─────────────────────────────────────────────────────────────
    with torch.no_grad():
        out = model.forward_features(img_tensor)
        features = out["x_norm_patchtokens"][0].cpu().numpy()

    # Project patch features to 3 principal components -> RGB channels.
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)
    pca = PCA(n_components=3)
    pca_features = pca.fit_transform(features_scaled)

    # Reshape to the patch grid and min-max normalize each channel to [0, 1].
    pca_img = pca_features.reshape(h_featmap, w_featmap, 3)
    pca_min, pca_max = pca_img.min(axis=(0, 1)), pca_img.max(axis=(0, 1))
    pca_img = (pca_img - pca_min) / (pca_max - pca_min + 1e-8)  # +eps: avoid /0
    pca_uint8 = (pca_img * 255).astype(np.uint8)

    # Resize back to the original resolution; NEAREST keeps patch boundaries crisp.
    pca_pil = Image.fromarray(pca_uint8).resize((orig_w, orig_h), resample=Image.NEAREST)
    pca_b64 = _encode_png_base64(pca_pil)

    # ─────────────────────────────────────────────────────────────
    # PART B: ATTENTION HEADS (The 16 small plots)
    # ─────────────────────────────────────────────────────────────
    # Get raw attention weights from the last layer.
    with torch.no_grad():
        attentions = model.get_last_self_attention(img_tensor)

    nh = attentions.shape[1]  # number of heads (usually 16 for ViT-Large)

    # Shape [1, heads, tokens, tokens] -> CLS-to-patch attention per head.
    # NOTE(review): assumes token 0 is CLS and tokens 1: are exactly the
    # h_featmap*w_featmap patch tokens (no register tokens) — confirm against
    # the dinov2 variant in use.
    attentions = attentions[0, :, 0, 1:].reshape(nh, h_featmap, w_featmap)

    # Upsample each head's map by PATCH_SIZE so one patch covers a visible block.
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0),
        scale_factor=PATCH_SIZE,
        mode="nearest"
    )[0].cpu().numpy()

    all_heads_base64 = []
    for i in range(nh):
        head_attn = attentions[i]
        # Normalize per head for better contrast.
        head_norm = (head_attn - head_attn.min()) / (head_attn.max() - head_attn.min() + 1e-8)

        # Apply the viridis colormap (drop the alpha channel).
        heatmap = (cm.viridis(head_norm)[:, :, :3] * 255).astype(np.uint8)
        heatmap_img = Image.fromarray(heatmap).resize((orig_w, orig_h), Image.BILINEAR)
        all_heads_base64.append(_encode_png_base64(heatmap_img))

    # ─────────────────────────────────────────────────────────────
    # RETURN BOTH
    # ─────────────────────────────────────────────────────────────
    return {
        "pca_image": pca_b64,           # The Main "Rainbow" Image
        "head_attention_maps": all_heads_base64 # The 16 Head Plots
    }