|
|
import os |
|
|
import io |
|
|
import sys |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torchvision.transforms as pth_transforms |
|
|
from safetensors.torch import load_file |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from sklearn.decomposition import PCA |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
import matplotlib.cm as cm |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(__file__)) |
|
|
from dinov2.models.vision_transformer import vit_large |
|
|
|
|
|
|
|
|
# Fine-tuned DINOv2 checkpoint, fetched from the Hugging Face Hub.
# NOTE: this runs at import time — first import hits the network, later
# imports reuse the local HF cache.
CKPT_PATH = hf_hub_download(
    repo_id="Arew99/dinov2-costum",
    filename="model.safetensors"
)

# ViT patch size: both image sides are trimmed to multiples of this
# before the forward pass (see predict_from_bytes).
PATCH_SIZE = 14

# Target image height for inference; the width is scaled to preserve the
# original aspect ratio.
INFERENCE_HEIGHT = 616
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Build the DINOv2 ViT-L/14 backbone, load the checkpoint, and freeze it.

    Returns
    -------
    tuple
        ``(model, device)`` — the model is in eval mode with all gradients
        disabled; ``device`` is CUDA when available, otherwise CPU.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Loading DINOv2 (PCA + Heads) on device: {device}")

    model = vit_large(
        patch_size=PATCH_SIZE,
        img_size=518,
        init_values=1.0,           # layer-scale init, as in DINOv2
        ffn_layer="swiglufused",
        block_chunks=0,
    )

    state_dict = load_file(CKPT_PATH)

    # Checkpoints saved from a wrapper module prefix every key with
    # "model.". Strip it as a *prefix* only — the previous
    # `k.replace("model.", "")` removed the substring anywhere in the key
    # and could corrupt parameter names that contain "model." mid-string.
    keys_list = list(state_dict.keys())
    if keys_list and keys_list[0].startswith("model."):
        prefix = "model."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    # strict=False tolerates head/backbone mismatches, but a silent load
    # failure is hard to debug — log what was skipped.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"load_state_dict: {len(missing)} missing keys (e.g. {missing[0]})")
    if unexpected:
        print(f"load_state_dict: {len(unexpected)} unexpected keys (e.g. {unexpected[0]})")

    # Inference only: freeze every parameter.
    for p in model.parameters():
        p.requires_grad = False

    model.eval()
    model.to(device)
    return model, device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_from_bytes(model_device_tuple, image_bytes):
    """Run DINOv2 on a raw encoded image and return two visualizations.

    Parameters
    ----------
    model_device_tuple : tuple
        ``(model, device)`` as returned by :func:`load_model`.
    image_bytes : bytes
        Encoded image data in any format PIL can open.

    Returns
    -------
    dict
        ``"pca_image"``: base64 PNG of the 3-component PCA of the patch
        features, upscaled to the original image size.
        ``"head_attention_maps"``: list of base64 PNGs, one viridis heatmap
        per attention head of the last block (CLS-token attention).
    """
    model, device = model_device_tuple

    original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    orig_w, orig_h = original_image.size

    # Resize to INFERENCE_HEIGHT preserving aspect ratio, then trim both
    # sides down to multiples of PATCH_SIZE (required by the patch embed).
    scale = INFERENCE_HEIGHT / orig_h
    new_w = int(orig_w * scale)
    new_w = new_w - (new_w % PATCH_SIZE)
    new_h = INFERENCE_HEIGHT - (INFERENCE_HEIGHT % PATCH_SIZE)

    transform = pth_transforms.Compose([
        pth_transforms.Resize((new_h, new_w)),
        pth_transforms.ToTensor(),
        # ImageNet mean/std normalization.
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(original_image).unsqueeze(0).to(device)

    # Feature-map grid dimensions in patch units.
    w_featmap = new_w // PATCH_SIZE
    h_featmap = new_h // PATCH_SIZE

    with torch.no_grad():
        out = model.forward_features(img_tensor)
        features = out["x_norm_patchtokens"][0].cpu().numpy()

    pca_b64 = _encode_png_base64(
        _pca_image(features, h_featmap, w_featmap, orig_w, orig_h)
    )

    with torch.no_grad():
        attentions = model.get_last_self_attention(img_tensor)

    all_heads_base64 = [
        _encode_png_base64(heatmap)
        for heatmap in _attention_heatmaps(
            attentions, h_featmap, w_featmap, orig_w, orig_h
        )
    ]

    return {
        "pca_image": pca_b64,
        "head_attention_maps": all_heads_base64,
    }


def _encode_png_base64(pil_img):
    """Serialize a PIL image to a base64-encoded PNG string (UTF-8)."""
    buf = BytesIO()
    pil_img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def _pca_image(features, h_featmap, w_featmap, orig_w, orig_h):
    """Project patch features onto 3 PCA components and render as an RGB PIL image."""
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)
    pca = PCA(n_components=3)
    pca_features = pca.fit_transform(features_scaled)

    pca_img = pca_features.reshape(h_featmap, w_featmap, 3)
    # Per-channel min/max normalization to [0, 1]; epsilon guards against a
    # flat (zero-variance) channel.
    pca_min, pca_max = pca_img.min(axis=(0, 1)), pca_img.max(axis=(0, 1))
    pca_img = (pca_img - pca_min) / (pca_max - pca_min + 1e-8)
    pca_uint8 = (pca_img * 255).astype(np.uint8)

    # NEAREST keeps the patch grid crisp instead of blurring it.
    return Image.fromarray(pca_uint8).resize((orig_w, orig_h), resample=Image.NEAREST)


def _attention_heatmaps(attentions, h_featmap, w_featmap, orig_w, orig_h):
    """Yield one viridis heatmap (PIL image) per attention head.

    ``attentions`` is the last-block attention tensor of shape
    (1, heads, tokens, tokens); we take the CLS token's attention to every
    patch token (column 0 row, columns 1: — assumes token 0 is CLS and the
    rest are patch tokens; verify if register tokens are ever enabled).
    """
    nh = attentions.shape[1]

    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, h_featmap, w_featmap)

    # Upsample each head's map to the resized-image resolution.
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0),
        scale_factor=PATCH_SIZE,
        mode="nearest",
    )[0].cpu().numpy()

    for i in range(nh):
        head_attn = attentions[i]
        # Min/max normalize per head; epsilon guards a constant map.
        head_norm = (head_attn - head_attn.min()) / (head_attn.max() - head_attn.min() + 1e-8)
        heatmap = (cm.viridis(head_norm)[:, :, :3] * 255).astype(np.uint8)
        yield Image.fromarray(heatmap).resize((orig_w, orig_h), Image.BILINEAR)