Update dino_processor.py
dino_processor.py  +8 -22
CHANGED
@@ -12,7 +12,6 @@ from sklearn.cluster import KMeans
 from scipy.spatial.distance import cdist
 import matplotlib.pyplot as plt
 import shutil  # For cleaning up temporary directories
-from datetime import datetime
 
 # This will import the ViT model definitions from the other file
 import vision_transformer as vits
@@ -21,7 +20,7 @@ import vision_transformer as vits
 # (extract_frames, compute_embeddings, select_representative_frames, generate_attention_maps)
 # I will copy them here for completeness, but you can just leave them as they are.
 
-def extract_frames(video_path, output_dir, fps=
+def extract_frames(video_path, output_dir, fps=5):
     frames_dir = os.path.join(output_dir, "frames")
     os.makedirs(frames_dir, exist_ok=True)
     cap = cv2.VideoCapture(video_path)
@@ -63,7 +62,7 @@ def compute_embeddings(frame_paths, model, device, batch_size=32):
         embeddings.append(batch_embeddings.cpu().numpy())
     return np.concatenate(embeddings, axis=0), frame_names
 
-def select_representative_frames(embeddings, frame_names, n_clusters=
+def select_representative_frames(embeddings, frame_names, n_clusters=3, pca_dim=32):
     pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
     pca_results = pca.fit_transform(embeddings)
     kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
@@ -112,28 +111,21 @@ def generate_attention_maps(frame_path, model, device, output_dir, frame_name):
     return overlay_path, attn_path
 
 # --- Main orchestrator function ---
-def process_video_with_dino(video_path):
+def process_video_with_dino(video_path, output_dir="dino_output"):
     """
     Main function to process a video and generate DINO attention maps.
-    Saves all outputs to a permanent, timestamped folder.
 
     Args:
         video_path (str): Path to the input video.
+        output_dir (str): Directory to save all intermediate and final files.
 
     Returns:
         list: A list of tuples, where each tuple contains (overlay_path, attention_map_path).
     """
-    #
-
-
-    os.makedirs(archive_dir, exist_ok=True)
-
-    # 2. Create a unique, timestamped directory for this specific run.
-    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-    output_dir = os.path.join(archive_dir, timestamp)
+    # Clean up previous runs and create output directory
+    if os.path.exists(output_dir):
+        shutil.rmtree(output_dir)
     os.makedirs(output_dir, exist_ok=True)
-    print(f"Results for this run will be saved in: {output_dir}")
-    # --- MODIFICATION END ---
 
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -151,7 +143,7 @@ def process_video_with_dino(video_path):
     model.load_state_dict(state_dict, strict=True)
     print("DINO weights loaded successfully from torch.hub.")
 
-    # Step 1: Extract frames
+    # Step 1: Extract frames
     frame_paths = extract_frames(video_path, output_dir)
     if not frame_paths:
         raise ValueError("No frames were extracted from the video.")
@@ -168,13 +160,7 @@ def process_video_with_dino(video_path):
     for frame_name in selected_frames:
         frame_path = os.path.join(frames_dir, frame_name)
         frame_name_no_ext = os.path.splitext(frame_name)[0]
-        # The generated images will now be saved inside the unique timestamped folder
         overlay_path, attn_path = generate_attention_maps(frame_path, model, device, output_dir, frame_name_no_ext)
         results.append((overlay_path, attn_path))
 
-    # We no longer need the temporary frames, so we can clean them up to save space.
-    # The final images (overlays and heatmaps) will remain.
-    shutil.rmtree(frames_dir)
-    print(f"Cleaned up temporary frames directory: {frames_dir}")
-
     return results
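
The diff pins the sampling default to fps=5, but only the first three lines of extract_frames are visible above. As a rough illustration of how such a sampler typically works, here is a minimal sketch; the step computation, the fallback when the video reports no FPS, and the frame-naming scheme are assumptions, not the file's actual code.

import os
import cv2

def extract_frames_sketch(video_path, output_dir, fps=5):
    # Hypothetical reconstruction: keep roughly `fps` frames per second of video.
    frames_dir = os.path.join(output_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    video_fps = cap.get(cv2.CAP_PROP_FPS) or fps   # assumed fallback when metadata is missing
    step = max(1, round(video_fps / fps))          # keep every `step`-th decoded frame
    frame_paths, idx = [], 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % step == 0:
            path = os.path.join(frames_dir, f"frame_{idx:06d}.jpg")  # naming scheme is assumed
            cv2.imwrite(path, frame)
            frame_paths.append(path)
        idx += 1
    cap.release()
    return frame_paths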
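The new signature fixes n_clusters=3 and pca_dim=32, and the visible body sets up PCA and KMeans over the frame embeddings. The file's cdist import suggests the usual centroid-nearest selection; below is a sketch under that assumption, where the first three statements mirror the diff and the selection step and return format are guesses.

from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

def select_representative_frames_sketch(embeddings, frame_names, n_clusters=3, pca_dim=32):
    # First three statements mirror the diff; the rest is an assumed reconstruction.
    pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
    pca_results = pca.fit_transform(embeddings)
    kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
    kmeans.fit(pca_results)
    # Distances from every frame to every centroid: shape (n_frames, n_clusters)
    dists = cdist(pca_results, kmeans.cluster_centers_)
    nearest = dists.argmin(axis=0)  # index of the closest frame per cluster
    return [frame_names[i] for i in nearest]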
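Net effect of the commit: instead of importing datetime to archive each run under a permanent timestamped folder and then deleting the extracted frames, process_video_with_dino now takes an output_dir argument (default "dino_output") that is wiped and recreated on every call, and the frames directory is kept. A minimal usage sketch against the new signature; "input.mp4" is a placeholder path.

# "input.mp4" is a placeholder; any readable video path works.
from dino_processor import process_video_with_dino

results = process_video_with_dino("input.mp4", output_dir="dino_output")
for overlay_path, attn_path in results:
    print(f"overlay: {overlay_path}  attention map: {attn_path}")

# Re-running wipes dino_output first (shutil.rmtree), so copy results
# elsewhere if you need to keep more than the latest run.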