"""DINOv2 hybrid visualization service.

Loads a custom ViT-L/14 DINOv2 checkpoint and, for an input image, produces:
  * a PCA "rainbow" visualization of the patch features, and
  * one attention heatmap per head of the last self-attention block.
Both are returned as base64-encoded PNGs sized to the original image.
"""

import base64
import io
import os
import sys
from io import BytesIO

import matplotlib.cm as cm  # Needed for attention head colormaps
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as pth_transforms
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# --- Make sure Python can find dinov2 properly ---
sys.path.append(os.path.dirname(__file__))
from dinov2.models.vision_transformer import vit_large

# --- Constants ---
CKPT_PATH = hf_hub_download(
    repo_id="Arew99/dinov2-costum",
    filename="model.safetensors",
)
PATCH_SIZE = 14
# We use a fixed height for consistency
INFERENCE_HEIGHT = 616


# -------------------------------------------------------
# Load model
# -------------------------------------------------------
def load_model():
    """Build the ViT-L/14 backbone, load the checkpoint, and freeze it.

    Returns:
        (model, device): the model in eval mode on ``device``, with all
        parameters frozen (``requires_grad=False``).
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Loading DINOv2 (PCA + Heads) on device: {device}")

    model = vit_large(
        patch_size=PATCH_SIZE,
        img_size=518,
        init_values=1.0,
        ffn_layer="swiglufused",
        block_chunks=0,
    )

    # Load weights. Checkpoints exported from a wrapper module prefix every
    # key with "model."; strip that *leading* prefix only. (str.replace would
    # also mangle keys that merely contain "model." somewhere inside.)
    state_dict = load_file(CKPT_PATH)
    keys_list = list(state_dict.keys())
    if keys_list and keys_list[0].startswith("model."):
        state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
    # strict=False: tolerate missing/unexpected keys (e.g. discarded heads)
    model.load_state_dict(state_dict, strict=False)

    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    return model, device


# -------------------------------------------------------
# HYBRID PREDICTION (PCA + HEADS)
# -------------------------------------------------------
def predict_from_bytes(model_device_tuple, image_bytes):
    """Run the hybrid DINOv2 visualization on raw image bytes.

    Args:
        model_device_tuple: ``(model, device)`` pair from :func:`load_model`.
        image_bytes: encoded image file contents (any PIL-readable format).

    Returns:
        dict with:
            "pca_image": base64 PNG — 3-component PCA of the patch features
                mapped to RGB, resized to the original image size.
            "head_attention_maps": list of base64 PNGs — one viridis heatmap
                per head of the last block's CLS-token attention.
    """
    model, device = model_device_tuple

    # 1. Load & Preprocess
    original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    orig_w, orig_h = original_image.size

    # Resize keeping aspect ratio, but ensuring dimensions are divisible
    # by PATCH_SIZE so the image maps cleanly onto the patch grid.
    scale = INFERENCE_HEIGHT / orig_h
    new_w = int(orig_w * scale)
    new_w = new_w - (new_w % PATCH_SIZE)
    new_h = INFERENCE_HEIGHT - (INFERENCE_HEIGHT % PATCH_SIZE)

    transform = pth_transforms.Compose([
        pth_transforms.Resize((new_h, new_w)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(original_image).unsqueeze(0).to(device)

    # Grid dimensions (patch tokens per axis)
    w_featmap = new_w // PATCH_SIZE
    h_featmap = new_h // PATCH_SIZE

    # ─────────────────────────────────────────────────────────────
    # PART A: PCA VISUALIZATION (Main Result)
    # ─────────────────────────────────────────────────────────────
    with torch.no_grad():
        out = model.forward_features(img_tensor)
        features = out["x_norm_patchtokens"][0].cpu().numpy()

    # Apply PCA (3 components -> RGB channels)
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)
    pca = PCA(n_components=3)
    pca_features = pca.fit_transform(features_scaled)

    # Reshape to the patch grid and min-max normalize each channel to [0, 1]
    pca_img = pca_features.reshape(h_featmap, w_featmap, 3)
    pca_min, pca_max = pca_img.min(axis=(0, 1)), pca_img.max(axis=(0, 1))
    pca_img = (pca_img - pca_min) / (pca_max - pca_min + 1e-8)
    pca_uint8 = (pca_img * 255).astype(np.uint8)

    # Resize to the original resolution; NEAREST keeps patch boundaries crisp
    pca_pil = Image.fromarray(pca_uint8).resize((orig_w, orig_h), resample=Image.NEAREST)
    buf = BytesIO()
    pca_pil.save(buf, format="PNG")
    buf.seek(0)
    pca_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    # ─────────────────────────────────────────────────────────────
    # PART B: ATTENTION HEADS (The 16 small plots)
    # ─────────────────────────────────────────────────────────────
    # Get raw attention weights from the last layer
    with torch.no_grad():
        attentions = model.get_last_self_attention(img_tensor)
    nh = attentions.shape[1]  # number of heads (usually 16 for ViT-Large)

    # Shape: [1, heads, tokens, tokens] -> CLS-token attention over patches.
    # NOTE(review): token 0 is assumed to be CLS and patch tokens to start at
    # index 1 — confirm this checkpoint has no extra register tokens.
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, h_featmap, w_featmap)

    # Upsample to pixel space: one patch -> PATCH_SIZE x PATCH_SIZE block
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=PATCH_SIZE, mode="nearest"
    )[0].cpu().numpy()

    all_heads_base64 = []
    for i in range(nh):
        head_attn = attentions[i]
        # Normalize per head for better contrast
        head_norm = (head_attn - head_attn.min()) / (head_attn.max() - head_attn.min() + 1e-8)
        # Apply colormap (viridis); drop the alpha channel
        heatmap = (cm.viridis(head_norm)[:, :, :3] * 255).astype(np.uint8)
        heatmap_img = Image.fromarray(heatmap).resize((orig_w, orig_h), Image.BILINEAR)

        buf = BytesIO()
        heatmap_img.save(buf, format="PNG")
        buf.seek(0)
        all_heads_base64.append(base64.b64encode(buf.getvalue()).decode("utf-8"))

    # ─────────────────────────────────────────────────────────────
    # RETURN BOTH
    # ─────────────────────────────────────────────────────────────
    return {
        "pca_image": pca_b64,                     # The Main "Rainbow" Image
        "head_attention_maps": all_heads_base64,  # The 16 Head Plots
    }