vivek9chavan committed on
Commit
65f6d12
·
verified ·
1 Parent(s): dc0b0f1

Update dino_processor.py

Browse files
Files changed (1) hide show
  1. dino_processor.py +8 -22
dino_processor.py CHANGED
@@ -12,7 +12,6 @@ from sklearn.cluster import KMeans
12
  from scipy.spatial.distance import cdist
13
  import matplotlib.pyplot as plt
14
  import shutil # For cleaning up temporary directories
15
- from datetime import datetime
16
 
17
  # This will import the ViT model definitions from the other file
18
  import vision_transformer as vits
@@ -21,7 +20,7 @@ import vision_transformer as vits
21
  # (extract_frames, compute_embeddings, select_representative_frames, generate_attention_maps)
22
  # I will copy them here for completeness, but you can just leave them as they are.
23
 
24
- def extract_frames(video_path, output_dir, fps=10):
25
  frames_dir = os.path.join(output_dir, "frames")
26
  os.makedirs(frames_dir, exist_ok=True)
27
  cap = cv2.VideoCapture(video_path)
@@ -63,7 +62,7 @@ def compute_embeddings(frame_paths, model, device, batch_size=32):
63
  embeddings.append(batch_embeddings.cpu().numpy())
64
  return np.concatenate(embeddings, axis=0), frame_names
65
 
66
- def select_representative_frames(embeddings, frame_names, n_clusters=5, pca_dim=32):
67
  pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
68
  pca_results = pca.fit_transform(embeddings)
69
  kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
@@ -112,28 +111,21 @@ def generate_attention_maps(frame_path, model, device, output_dir, frame_name):
112
  return overlay_path, attn_path
113
 
114
  # --- Main orchestrator function ---
115
- def process_video_with_dino(video_path):
116
  """
117
  Main function to process a video and generate DINO attention maps.
118
- Saves all outputs to a permanent, timestamped folder.
119
 
120
  Args:
121
  video_path (str): Path to the input video.
 
122
 
123
  Returns:
124
  list: A list of tuples, where each tuple contains (overlay_path, attention_map_path).
125
  """
126
- # --- MODIFICATION START ---
127
- # 1. Define a permanent archive directory.
128
- archive_dir = "dino_archive"
129
- os.makedirs(archive_dir, exist_ok=True)
130
-
131
- # 2. Create a unique, timestamped directory for this specific run.
132
- timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
133
- output_dir = os.path.join(archive_dir, timestamp)
134
  os.makedirs(output_dir, exist_ok=True)
135
- print(f"Results for this run will be saved in: {output_dir}")
136
- # --- MODIFICATION END ---
137
 
138
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
139
 
@@ -151,7 +143,7 @@ def process_video_with_dino(video_path):
151
  model.load_state_dict(state_dict, strict=True)
152
  print("DINO weights loaded successfully from torch.hub.")
153
 
154
- # Step 1: Extract frames (saves into the new unique output_dir)
155
  frame_paths = extract_frames(video_path, output_dir)
156
  if not frame_paths:
157
  raise ValueError("No frames were extracted from the video.")
@@ -168,13 +160,7 @@ def process_video_with_dino(video_path):
168
  for frame_name in selected_frames:
169
  frame_path = os.path.join(frames_dir, frame_name)
170
  frame_name_no_ext = os.path.splitext(frame_name)[0]
171
- # The generated images will now be saved inside the unique timestamped folder
172
  overlay_path, attn_path = generate_attention_maps(frame_path, model, device, output_dir, frame_name_no_ext)
173
  results.append((overlay_path, attn_path))
174
 
175
- # We no longer need the temporary frames, so we can clean them up to save space.
176
- # The final images (overlays and heatmaps) will remain.
177
- shutil.rmtree(frames_dir)
178
- print(f"Cleaned up temporary frames directory: {frames_dir}")
179
-
180
  return results
 
12
  from scipy.spatial.distance import cdist
13
  import matplotlib.pyplot as plt
14
  import shutil # For cleaning up temporary directories
 
15
 
16
  # This will import the ViT model definitions from the other file
17
  import vision_transformer as vits
 
20
  # (extract_frames, compute_embeddings, select_representative_frames, generate_attention_maps)
21
  # I will copy them here for completeness, but you can just leave them as they are.
22
 
23
+ def extract_frames(video_path, output_dir, fps=5):
24
  frames_dir = os.path.join(output_dir, "frames")
25
  os.makedirs(frames_dir, exist_ok=True)
26
  cap = cv2.VideoCapture(video_path)
 
62
  embeddings.append(batch_embeddings.cpu().numpy())
63
  return np.concatenate(embeddings, axis=0), frame_names
64
 
65
+ def select_representative_frames(embeddings, frame_names, n_clusters=3, pca_dim=32):
66
  pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
67
  pca_results = pca.fit_transform(embeddings)
68
  kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
 
111
  return overlay_path, attn_path
112
 
113
  # --- Main orchestrator function ---
114
+ def process_video_with_dino(video_path, output_dir="dino_output"):
115
  """
116
  Main function to process a video and generate DINO attention maps.
 
117
 
118
  Args:
119
  video_path (str): Path to the input video.
120
+ output_dir (str): Directory to save all intermediate and final files.
121
 
122
  Returns:
123
  list: A list of tuples, where each tuple contains (overlay_path, attention_map_path).
124
  """
125
+ # Clean up previous runs and create output directory
126
+ if os.path.exists(output_dir):
127
+ shutil.rmtree(output_dir)
 
 
 
 
 
128
  os.makedirs(output_dir, exist_ok=True)
 
 
129
 
130
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
131
 
 
143
  model.load_state_dict(state_dict, strict=True)
144
  print("DINO weights loaded successfully from torch.hub.")
145
 
146
+ # Step 1: Extract frames
147
  frame_paths = extract_frames(video_path, output_dir)
148
  if not frame_paths:
149
  raise ValueError("No frames were extracted from the video.")
 
160
  for frame_name in selected_frames:
161
  frame_path = os.path.join(frames_dir, frame_name)
162
  frame_name_no_ext = os.path.splitext(frame_name)[0]
 
163
  overlay_path, attn_path = generate_attention_maps(frame_path, model, device, output_dir, frame_name_no_ext)
164
  results.append((overlay_path, attn_path))
165
 
 
 
 
 
 
166
  return results