# model.py — NEMOtools app: DINOv2 backend producing PCA feature visualizations
# and per-head attention maps from uploaded images.
import os
import io
import sys
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as pth_transforms
from safetensors.torch import load_file
import base64
from io import BytesIO
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.cm as cm # Needed for attention head colormaps
from huggingface_hub import hf_hub_download
# --- Make sure Python can find dinov2 properly ---
sys.path.append(os.path.dirname(__file__))
from dinov2.models.vision_transformer import vit_large
# --- Constants ---
# NOTE: this download runs at import time (hf_hub_download caches locally,
# so subsequent imports reuse the cached file).
CKPT_PATH = hf_hub_download(
    repo_id="Arew99/dinov2-costum",
    filename="model.safetensors"
)
# ViT patch size; input width/height are trimmed to multiples of this.
PATCH_SIZE = 14
# We use a fixed height for consistency
INFERENCE_HEIGHT = 616
# -------------------------------------------------------
# Load model
# -------------------------------------------------------
def load_model():
    """Build the DINOv2 ViT-Large model and load the custom checkpoint.

    Returns:
        (model, device): the frozen, eval-mode model and the torch.device
        it was moved to (CUDA when available, else CPU).
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Loading DINOv2 (PCA + Heads) on device: {device}")

    model = vit_large(
        patch_size=PATCH_SIZE,
        img_size=518,
        init_values=1.0,
        ffn_layer="swiglufused",
        block_chunks=0,
    )

    # Load weights. Checkpoints saved from a wrapper module prefix every key
    # with "model."; strip only that leading prefix (the previous
    # replace("model.", "") would also mangle a "model." occurring mid-key).
    state_dict = load_file(CKPT_PATH)
    keys_list = list(state_dict.keys())
    if keys_list and keys_list[0].startswith("model."):
        state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
    # strict=False tolerates missing/unexpected keys (e.g. discarded heads).
    model.load_state_dict(state_dict, strict=False)

    # Inference only: freeze all parameters.
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    return model, device
# -------------------------------------------------------
# HYBRID PREDICTION (PCA + HEADS)
# -------------------------------------------------------
def _pil_to_base64_png(img):
    """Encode a PIL image as a base64 PNG string (no data-URI prefix)."""
    buf = BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def predict_from_bytes(model_device_tuple, image_bytes):
    """Run DINOv2 on raw image bytes and build two visualizations.

    Args:
        model_device_tuple: (model, device) as returned by load_model().
        image_bytes: encoded image bytes in any format PIL can open.

    Returns:
        dict with:
            "pca_image": base64 PNG — 3-component PCA of the patch features,
                mapped to RGB and resized back to the original image size.
            "head_attention_maps": list of base64 PNGs, one viridis heatmap
                per attention head of the last layer.
    """
    model, device = model_device_tuple

    # 1. Load & preprocess: fixed inference height, aspect ratio preserved,
    # both dimensions trimmed to multiples of PATCH_SIZE.
    original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    orig_w, orig_h = original_image.size
    scale = INFERENCE_HEIGHT / orig_h
    new_w = int(orig_w * scale)
    new_w -= new_w % PATCH_SIZE
    new_h = INFERENCE_HEIGHT - (INFERENCE_HEIGHT % PATCH_SIZE)

    transform = pth_transforms.Compose([
        pth_transforms.Resize((new_h, new_w)),
        pth_transforms.ToTensor(),
        # ImageNet mean/std normalization (DINOv2 default).
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(original_image).unsqueeze(0).to(device)

    # Patch-grid dimensions (one token per PATCH_SIZE x PATCH_SIZE patch).
    w_featmap = new_w // PATCH_SIZE
    h_featmap = new_h // PATCH_SIZE

    # ─────────────────────────────────────────────────────────────
    # PART A: PCA VISUALIZATION (Main Result)
    # ─────────────────────────────────────────────────────────────
    with torch.no_grad():
        out = model.forward_features(img_tensor)
        features = out["x_norm_patchtokens"][0].cpu().numpy()

    # Standardize features, then project to 3 components -> RGB channels.
    features_scaled = StandardScaler().fit_transform(features)
    pca_features = PCA(n_components=3).fit_transform(features_scaled)

    # Reshape to the patch grid and min-max normalize each channel to [0, 1].
    pca_img = pca_features.reshape(h_featmap, w_featmap, 3)
    pca_min = pca_img.min(axis=(0, 1))
    pca_max = pca_img.max(axis=(0, 1))
    pca_img = (pca_img - pca_min) / (pca_max - pca_min + 1e-8)
    pca_uint8 = (pca_img * 255).astype(np.uint8)

    # NEAREST keeps the blocky per-patch look when resizing back up.
    pca_pil = Image.fromarray(pca_uint8).resize((orig_w, orig_h), resample=Image.NEAREST)
    pca_b64 = _pil_to_base64_png(pca_pil)

    # ─────────────────────────────────────────────────────────────
    # PART B: ATTENTION HEADS (one small plot per head)
    # ─────────────────────────────────────────────────────────────
    with torch.no_grad():
        attentions = model.get_last_self_attention(img_tensor)

    nh = attentions.shape[1]  # number of heads (16 for ViT-Large)

    # [1, heads, tokens, tokens] -> CLS-token attention over patch tokens,
    # reshaped straight to the patch grid (the intermediate reshape(nh, -1)
    # in the original was redundant).
    attentions = attentions[0, :, 0, 1:].reshape(nh, h_featmap, w_featmap)

    # Upsample by PATCH_SIZE so each map matches the resized input visually.
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0),
        scale_factor=PATCH_SIZE,
        mode="nearest",
    )[0].cpu().numpy()

    all_heads_base64 = []
    for i in range(nh):
        head_attn = attentions[i]
        # Normalize per head for better contrast.
        head_norm = (head_attn - head_attn.min()) / (head_attn.max() - head_attn.min() + 1e-8)
        # Apply the viridis colormap and drop the alpha channel.
        heatmap = (cm.viridis(head_norm)[:, :, :3] * 255).astype(np.uint8)
        heatmap_img = Image.fromarray(heatmap).resize((orig_w, orig_h), Image.BILINEAR)
        all_heads_base64.append(_pil_to_base64_png(heatmap_img))

    # ─────────────────────────────────────────────────────────────
    # RETURN BOTH
    # ─────────────────────────────────────────────────────────────
    return {
        "pca_image": pca_b64,                      # the main "rainbow" image
        "head_attention_maps": all_heads_base64,   # the per-head plots
    }