# NOTE: removed web-UI scrape residue (file-size line, git-blame SHAs, and a
# line-number gutter) that was accidentally captured with the source and is
# not valid Python.
import os
import io
import sys
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as pth_transforms
from safetensors.torch import load_file
import base64
from io import BytesIO
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.cm as cm # Needed for attention head colormaps
from huggingface_hub import hf_hub_download
# --- Make sure Python can find dinov2 properly ---
sys.path.append(os.path.dirname(__file__))
from dinov2.models.vision_transformer import vit_large
# --- Constants ---
# NOTE(review): hf_hub_download runs at *import time*, so importing this
# module hits the network (or the local HF cache) — confirm this is intended
# for the deployment environment.
CKPT_PATH = hf_hub_download(
    repo_id="Arew99/dinov2-costum",
    filename="model.safetensors"
)
# ViT-L/14 patch size: every image dimension fed to the model must be a
# multiple of this value.
PATCH_SIZE = 14
# We use a fixed height for consistency
INFERENCE_HEIGHT = 616
# -------------------------------------------------------
# Load model
# -------------------------------------------------------
def load_model():
    """Build a DINOv2 ViT-L/14, load the fine-tuned weights, and freeze it.

    Returns:
        tuple: ``(model, device)`` — an eval-mode, gradient-free ``vit_large``
        already moved to ``device`` (CUDA when available, else CPU).
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Loading DINOv2 (PCA + Heads) on device: {device}")

    model = vit_large(
        patch_size=PATCH_SIZE,
        img_size=518,
        init_values=1.0,
        ffn_layer="swiglufused",
        block_chunks=0,
    )

    # Checkpoints saved from a wrapper module prefix every key with "model.".
    # Detect the prefix on *any* key (the old check only looked at the first
    # key) and strip it only when it is an actual prefix — the previous blanket
    # str.replace would also corrupt keys merely containing "model." inside.
    state_dict = load_file(CKPT_PATH)
    if any(k.startswith("model.") for k in state_dict):
        state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}

    # strict=False tolerates auxiliary keys not present in the backbone, but
    # surface what was skipped instead of discarding the report silently.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(
            f"load_state_dict: {len(missing)} missing, "
            f"{len(unexpected)} unexpected keys"
        )

    # Inference only: freeze all parameters.
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    return model, device
# -------------------------------------------------------
# HYBRID PREDICTION (PCA + HEADS)
# -------------------------------------------------------
def _encode_png_b64(pil_img):
    """Serialize a PIL image to a base64-encoded PNG string."""
    buf = BytesIO()
    pil_img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def _pca_visualization(model, img_tensor, h_featmap, w_featmap, out_size):
    """Project patch features onto 3 PCA components and render them as RGB.

    Returns a base64 PNG resized to ``out_size`` (width, height).
    """
    with torch.no_grad():
        out = model.forward_features(img_tensor)
    features = out["x_norm_patchtokens"][0].cpu().numpy()

    # Standardize features, then keep the 3 strongest components as RGB.
    features_scaled = StandardScaler().fit_transform(features)
    pca_features = PCA(n_components=3).fit_transform(features_scaled)

    # Reshape to the patch grid and min/max-normalize each channel to [0, 1];
    # the epsilon guards against a constant channel dividing by zero.
    pca_img = pca_features.reshape(h_featmap, w_featmap, 3)
    lo = pca_img.min(axis=(0, 1))
    hi = pca_img.max(axis=(0, 1))
    pca_img = (pca_img - lo) / (hi - lo + 1e-8)

    pca_pil = Image.fromarray((pca_img * 255).astype(np.uint8))
    # NEAREST keeps the blocky per-patch look when upscaling to the original.
    return _encode_png_b64(pca_pil.resize(out_size, resample=Image.NEAREST))


def _attention_head_maps(model, img_tensor, h_featmap, w_featmap, out_size):
    """Render the last layer's per-head CLS attention as viridis heatmaps.

    Returns a list of base64 PNG strings (one per head) at ``out_size``.
    """
    with torch.no_grad():
        attentions = model.get_last_self_attention(img_tensor)
    nh = attentions.shape[1]  # number of heads (16 for ViT-Large)

    # Shape [1, heads, tokens, tokens]: take the CLS row's attention over the
    # patch tokens. NOTE(review): assumes token 0 is CLS and tokens 1: map
    # exactly onto the patch grid (no register tokens) — confirm for this
    # backbone. (The original did a redundant reshape(nh, -1) first.)
    attentions = attentions[0, :, 0, 1:].reshape(nh, h_featmap, w_featmap)

    # Upsample so each patch becomes a PATCH_SIZE x PATCH_SIZE pixel block.
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=PATCH_SIZE, mode="nearest"
    )[0].cpu().numpy()

    maps = []
    for head_attn in attentions:
        # Per-head min/max normalization for better contrast.
        head_norm = (head_attn - head_attn.min()) / (
            head_attn.max() - head_attn.min() + 1e-8
        )
        heatmap = (cm.viridis(head_norm)[:, :, :3] * 255).astype(np.uint8)
        heatmap_img = Image.fromarray(heatmap).resize(out_size, Image.BILINEAR)
        maps.append(_encode_png_b64(heatmap_img))
    return maps


def predict_from_bytes(model_device_tuple, image_bytes):
    """Run the hybrid PCA + attention-head visualization on raw image bytes.

    Args:
        model_device_tuple: ``(model, device)`` as returned by ``load_model``.
        image_bytes: an encoded image in any format PIL can open.

    Returns:
        dict:
            ``"pca_image"``: base64 PNG of the 3-component PCA "rainbow" map.
            ``"head_attention_maps"``: list of base64 PNGs, one per head.
    """
    model, device = model_device_tuple

    # 1. Load & preprocess: resize to a fixed height, keeping aspect ratio,
    # with both sides snapped down to multiples of PATCH_SIZE.
    original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    orig_w, orig_h = original_image.size
    scale = INFERENCE_HEIGHT / orig_h
    scaled_w = int(orig_w * scale)
    # Guard: an extremely narrow input could otherwise snap to width 0.
    new_w = max(scaled_w - scaled_w % PATCH_SIZE, PATCH_SIZE)
    new_h = INFERENCE_HEIGHT - INFERENCE_HEIGHT % PATCH_SIZE

    transform = pth_transforms.Compose([
        pth_transforms.Resize((new_h, new_w)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(original_image).unsqueeze(0).to(device)

    # Patch-grid dimensions and the output size for rendered images.
    w_featmap = new_w // PATCH_SIZE
    h_featmap = new_h // PATCH_SIZE
    out_size = (orig_w, orig_h)

    return {
        # The main "rainbow" image.
        "pca_image": _pca_visualization(
            model, img_tensor, h_featmap, w_featmap, out_size
        ),
        # The per-head attention plots.
        "head_attention_maps": _attention_head_maps(
            model, img_tensor, h_featmap, w_featmap, out_size
        ),
    }