vivek9chavan commited on
Commit
7159bc5
·
verified ·
1 Parent(s): 7135735

Create dino_processor.py

Browse files
Files changed (1) hide show
  1. dino_processor.py +166 -0
dino_processor.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dino_processor.py
2
+
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchvision import transforms as pth_transforms
10
+ from sklearn.decomposition import PCA
11
+ from sklearn.cluster import KMeans
12
+ from scipy.spatial.distance import cdist
13
+ import matplotlib.pyplot as plt
14
+ import shutil # For cleaning up temporary directories
15
+
16
+ # This will import the ViT model definitions from the other file
17
+ import vision_transformer as vits
18
+
19
+ # --- Helper functions from your script (no changes needed) ---
20
+ # (extract_frames, compute_embeddings, select_representative_frames, generate_attention_maps)
21
+ # I will copy them here for completeness, but you can just leave them as they are.
22
+
23
+ def extract_frames(video_path, output_dir, fps=10):
24
+ frames_dir = os.path.join(output_dir, "frames")
25
+ os.makedirs(frames_dir, exist_ok=True)
26
+ cap = cv2.VideoCapture(video_path)
27
+ video_fps = cap.get(cv2.CAP_PROP_FPS)
28
+ frame_interval = int(video_fps / fps)
29
+ frame_paths = []
30
+ frame_count = 0
31
+ extracted_count = 0
32
+ while True:
33
+ ret, frame = cap.read()
34
+ if not ret: break
35
+ if frame_count % frame_interval == 0:
36
+ frame_filename = f"frame_{extracted_count:06d}.jpg"
37
+ frame_path = os.path.join(frames_dir, frame_filename)
38
+ cv2.imwrite(frame_path, frame)
39
+ frame_paths.append(frame_path)
40
+ extracted_count += 1
41
+ frame_count += 1
42
+ cap.release()
43
+ print(f"Extracted {len(frame_paths)} frames.")
44
+ return frame_paths
45
+
46
+ def compute_embeddings(frame_paths, model, device, batch_size=32):
47
+ transform = pth_transforms.Compose([
48
+ pth_transforms.Resize((224, 224)), pth_transforms.ToTensor(),
49
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
50
+ ])
51
+ embeddings, frame_names = [], []
52
+ for i in range(0, len(frame_paths), batch_size):
53
+ batch_paths = frame_paths[i:i + batch_size]
54
+ batch_images = []
55
+ for frame_path in batch_paths:
56
+ img = Image.open(frame_path).convert('RGB')
57
+ batch_images.append(transform(img))
58
+ frame_names.append(os.path.basename(frame_path))
59
+ batch_tensor = torch.stack(batch_images).to(device)
60
+ with torch.no_grad():
61
+ batch_embeddings = model(batch_tensor)
62
+ embeddings.append(batch_embeddings.cpu().numpy())
63
+ return np.concatenate(embeddings, axis=0), frame_names
64
+
65
+ def select_representative_frames(embeddings, frame_names, n_clusters=5, pca_dim=32):
66
+ pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
67
+ pca_results = pca.fit_transform(embeddings)
68
+ kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
69
+ kmeans.fit(pca_results)
70
+ distances = cdist(kmeans.cluster_centers_, pca_results, 'euclidean')
71
+ selected_frames = []
72
+ for i in range(n_clusters):
73
+ closest_point_idx = np.argmin(distances[i])
74
+ selected_frames.append(frame_names[closest_point_idx])
75
+ print(f"Selected frames: {selected_frames}")
76
+ return selected_frames
77
+
78
+ def generate_attention_maps(frame_path, model, device, output_dir, frame_name):
79
+ img = Image.open(frame_path).convert('RGB')
80
+ original_img = np.array(img)
81
+ original_height, original_width = img.height, img.width
82
+ transform = pth_transforms.Compose([
83
+ pth_transforms.Resize((224, 224)), pth_transforms.ToTensor(),
84
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
85
+ ])
86
+ img_tensor = transform(img).unsqueeze(0)
87
+ patch_size = model.patch_embed.patch_size
88
+ w_featmap = img_tensor.shape[-2] // patch_size
89
+ h_featmap = img_tensor.shape[-1] // patch_size
90
+ with torch.no_grad():
91
+ attentions = model.get_last_selfattention(img_tensor.to(device))
92
+ nh = attentions.shape[1]
93
+ attention = attentions[0, :, 0, 1:].reshape(nh, -1)
94
+ attention = attention.reshape(nh, w_featmap, h_featmap)
95
+ attention = nn.functional.interpolate(attention.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()
96
+
97
+ # Save attention map
98
+ attn_path = os.path.join(output_dir, f"{frame_name}_attn.png")
99
+ plt.imsave(attn_path, np.sum(attention, axis=0), cmap='inferno', format='png')
100
+
101
+ # Save overlay
102
+ overlay_path = os.path.join(output_dir, f"{frame_name}_overlay.png")
103
+ attention_map = np.sum(attention, axis=0)
104
+ attention_map = (attention_map - np.min(attention_map)) / (np.max(attention_map) - np.min(attention_map))
105
+ attention_colored = np.uint8(255 * attention_map)
106
+ attention_colored = cv2.applyColorMap(attention_colored, cv2.COLORMAP_JET)
107
+ attention_colored = cv2.cvtColor(attention_colored, cv2.COLOR_BGR2RGB)
108
+ overlay = cv2.addWeighted(original_img, 0.5, cv2.resize(attention_colored, (original_width, original_height)), 0.5, 0)
109
+ Image.fromarray(overlay).save(overlay_path)
110
+
111
+ return overlay_path, attn_path
112
+
113
+ # --- Main orchestrator function ---
114
+ def process_video_with_dino(video_path, output_dir="dino_output"):
115
+ """
116
+ Main function to process a video and generate DINO attention maps.
117
+
118
+ Args:
119
+ video_path (str): Path to the input video.
120
+ output_dir (str): Directory to save all intermediate and final files.
121
+
122
+ Returns:
123
+ list: A list of tuples, where each tuple contains (overlay_path, attention_map_path).
124
+ """
125
+ # Clean up previous runs and create output directory
126
+ if os.path.exists(output_dir):
127
+ shutil.rmtree(output_dir)
128
+ os.makedirs(output_dir, exist_ok=True)
129
+
130
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
131
+
132
+ # Build model (using vit_small with patch size 8 as a default)
133
+ patch_size = 8
134
+ model = vits.vit_small(patch_size=patch_size, num_classes=0)
135
+ for p in model.parameters():
136
+ p.requires_grad = False
137
+ model.eval()
138
+ model.to(device)
139
+
140
+ # Load pretrained weights from torch.hub
141
+ url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
142
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
143
+ model.load_state_dict(state_dict, strict=True)
144
+ print("DINO weights loaded successfully from torch.hub.")
145
+
146
+ # Step 1: Extract frames
147
+ frame_paths = extract_frames(video_path, output_dir)
148
+ if not frame_paths:
149
+ raise ValueError("No frames were extracted from the video.")
150
+
151
+ # Step 2: Compute embeddings
152
+ embeddings, frame_names = compute_embeddings(frame_paths, model, device)
153
+
154
+ # Step 3: Select representative frames
155
+ selected_frames = select_representative_frames(embeddings, frame_names)
156
+
157
+ # Step 4: Generate attention maps for selected frames
158
+ results = []
159
+ frames_dir = os.path.join(output_dir, "frames")
160
+ for frame_name in selected_frames:
161
+ frame_path = os.path.join(frames_dir, frame_name)
162
+ frame_name_no_ext = os.path.splitext(frame_name)[0]
163
+ overlay_path, attn_path = generate_attention_maps(frame_path, model, device, output_dir, frame_name_no_ext)
164
+ results.append((overlay_path, attn_path))
165
+
166
+ return results