# dino_processor.py
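"""Extract frames from a video, embed them with a DINO ViT, select
representative frames via PCA + k-means, and render the model's
self-attention as heatmaps and overlays for the selected frames."""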
import os
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms as pth_transforms
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import shutil # For cleaning up temporary directories
# This will import the ViT model definitions from the other file
import vision_transformer as vits
# --- Helper functions ---
# (extract_frames, compute_embeddings, select_representative_frames,
#  generate_attention_maps)
def extract_frames(video_path, output_dir, fps=10):
    """Sample frames from the video at roughly `fps` frames per second."""
    frames_dir = os.path.join(output_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    # Guard against videos that report no FPS, and never let the sampling
    # interval drop below one frame (which would cause a modulo-by-zero).
    frame_interval = max(1, int(video_fps / fps)) if video_fps > 0 else 1
    frame_paths = []
    frame_count = 0
    extracted_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            frame_filename = f"frame_{extracted_count:06d}.jpg"
            frame_path = os.path.join(frames_dir, frame_filename)
            cv2.imwrite(frame_path, frame)
            frame_paths.append(frame_path)
            extracted_count += 1
        frame_count += 1
    cap.release()
    print(f"Extracted {len(frame_paths)} frames.")
    return frame_paths
def compute_embeddings(frame_paths, model, device, batch_size=32):
    """Run each frame through the model; return (embeddings, frame_names)."""
    transform = pth_transforms.Compose([
        pth_transforms.Resize((224, 224)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    embeddings, frame_names = [], []
    for i in range(0, len(frame_paths), batch_size):
        batch_paths = frame_paths[i:i + batch_size]
        batch_images = []
        for frame_path in batch_paths:
            img = Image.open(frame_path).convert('RGB')
            batch_images.append(transform(img))
            frame_names.append(os.path.basename(frame_path))
        batch_tensor = torch.stack(batch_images).to(device)
        with torch.no_grad():
            batch_embeddings = model(batch_tensor)
        embeddings.append(batch_embeddings.cpu().numpy())
    return np.concatenate(embeddings, axis=0), frame_names
def select_representative_frames(embeddings, frame_names, n_clusters=5, pca_dim=32):
    """Cluster frame embeddings and return the frame closest to each centroid."""
    # Short videos may yield fewer frames than the requested PCA dimension
    # or cluster count, so clamp both to the number of available samples.
    n_samples = embeddings.shape[0]
    pca_dim = min(pca_dim, n_samples, embeddings.shape[1])
    n_clusters = min(n_clusters, n_samples)
    pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
    pca_results = pca.fit_transform(embeddings)
    kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
    kmeans.fit(pca_results)
    # For each cluster, pick the frame whose projection lies closest to the centroid.
    distances = cdist(kmeans.cluster_centers_, pca_results, 'euclidean')
    selected_frames = []
    for i in range(n_clusters):
        closest_point_idx = np.argmin(distances[i])
        selected_frames.append(frame_names[closest_point_idx])
    print(f"Selected frames: {selected_frames}")
    return selected_frames
def generate_attention_maps(frame_path, model, device, output_dir, frame_name):
    """Render the last layer's CLS-token self-attention for one frame."""
    img = Image.open(frame_path).convert('RGB')
    original_img = np.array(img)
    original_height, original_width = img.height, img.width
    transform = pth_transforms.Compose([
        pth_transforms.Resize((224, 224)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(img).unsqueeze(0)
    patch_size = model.patch_embed.patch_size
    h_featmap = img_tensor.shape[-2] // patch_size  # patches along image height
    w_featmap = img_tensor.shape[-1] // patch_size  # patches along image width
    with torch.no_grad():
        attentions = model.get_last_selfattention(img_tensor.to(device))
    nh = attentions.shape[1]  # number of attention heads
    # Keep only the CLS token's attention to the patch tokens.
    attention = attentions[0, :, 0, 1:].reshape(nh, -1)
    attention = attention.reshape(nh, h_featmap, w_featmap)
    attention = nn.functional.interpolate(
        attention.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0].cpu().numpy()
    # Save attention map summed over heads
    attn_path = os.path.join(output_dir, f"{frame_name}_attn.png")
    plt.imsave(attn_path, np.sum(attention, axis=0), cmap='inferno', format='png')
    # Save overlay on the original frame
    overlay_path = os.path.join(output_dir, f"{frame_name}_overlay.png")
    attention_map = np.sum(attention, axis=0)
    # Normalize to [0, 1]; the epsilon avoids division by zero on flat maps.
    attention_map = (attention_map - np.min(attention_map)) / (
        np.max(attention_map) - np.min(attention_map) + 1e-8
    )
    attention_colored = np.uint8(255 * attention_map)
    attention_colored = cv2.applyColorMap(attention_colored, cv2.COLORMAP_JET)
    attention_colored = cv2.cvtColor(attention_colored, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(
        original_img, 0.5,
        cv2.resize(attention_colored, (original_width, original_height)), 0.5, 0,
    )
    Image.fromarray(overlay).save(overlay_path)
    return overlay_path, attn_path
# --- Main orchestrator function ---
def process_video_with_dino(video_path, output_dir="dino_output"):
    """
    Main function to process a video and generate DINO attention maps.
    Args:
        video_path (str): Path to the input video.
        output_dir (str): Directory to save all intermediate and final files.
    Returns:
        list: A list of tuples, where each tuple contains
            (overlay_path, attention_map_path).
    """
    # Clean up previous runs and create the output directory
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Build model (using vit_small with patch size 8 as a default)
    patch_size = 8
    model = vits.vit_small(patch_size=patch_size, num_classes=0)
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    # Download pretrained DINO weights (map_location="cpu" so the load also
    # works on CPU-only machines; the model is moved to `device` above).
    url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
    state_dict = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/dino/" + url, map_location="cpu"
    )
    model.load_state_dict(state_dict, strict=True)
    print("DINO weights loaded successfully.")
    # Step 1: Extract frames
    frame_paths = extract_frames(video_path, output_dir)
    if not frame_paths:
        raise ValueError("No frames were extracted from the video.")
    # Step 2: Compute embeddings
    embeddings, frame_names = compute_embeddings(frame_paths, model, device)
    # Step 3: Select representative frames
    selected_frames = select_representative_frames(embeddings, frame_names)
    # Step 4: Generate attention maps for selected frames
    results = []
    frames_dir = os.path.join(output_dir, "frames")
    for frame_name in selected_frames:
        frame_path = os.path.join(frames_dir, frame_name)
        frame_name_no_ext = os.path.splitext(frame_name)[0]
        overlay_path, attn_path = generate_attention_maps(
            frame_path, model, device, output_dir, frame_name_no_ext
        )
        results.append((overlay_path, attn_path))
    return results
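# Minimal usage sketch. The video path "input.mp4" is a hypothetical
# placeholder; point it at a real file before running.
if __name__ == "__main__":
    # Each result pairs an attention overlay with its raw heatmap.
    for overlay_path, attn_path in process_video_with_dino("input.mp4"):
        print(f"overlay: {overlay_path}  attention: {attn_path}")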