# dino_processor.py
import os
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms as pth_transforms
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import shutil  # For cleaning up temporary directories

# This will import the ViT model definitions from the other file
import vision_transformer as vits

# --- Helper functions from your script (no changes needed) ---
# (extract_frames, compute_embeddings, select_representative_frames, generate_attention_maps)
# I will copy them here for completeness, but you can just leave them as they are.
def extract_frames(video_path, output_dir, fps=10):
    frames_dir = os.path.join(output_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    # Guard against videos whose reported FPS is 0 or lower than the sampling rate,
    # which would otherwise make the interval 0 and crash the modulo below.
    frame_interval = max(1, int(video_fps / fps))
    frame_paths = []
    frame_count = 0
    extracted_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            frame_filename = f"frame_{extracted_count:06d}.jpg"
            frame_path = os.path.join(frames_dir, frame_filename)
            cv2.imwrite(frame_path, frame)
            frame_paths.append(frame_path)
            extracted_count += 1
        frame_count += 1
    cap.release()
    print(f"Extracted {len(frame_paths)} frames.")
    return frame_paths
def compute_embeddings(frame_paths, model, device, batch_size=32):
    transform = pth_transforms.Compose([
        pth_transforms.Resize((224, 224)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    embeddings, frame_names = [], []
    for i in range(0, len(frame_paths), batch_size):
        batch_paths = frame_paths[i:i + batch_size]
        batch_images = []
        for frame_path in batch_paths:
            img = Image.open(frame_path).convert('RGB')
            batch_images.append(transform(img))
            frame_names.append(os.path.basename(frame_path))
        batch_tensor = torch.stack(batch_images).to(device)
        with torch.no_grad():
            batch_embeddings = model(batch_tensor)
        embeddings.append(batch_embeddings.cpu().numpy())
    return np.concatenate(embeddings, axis=0), frame_names
def select_representative_frames(embeddings, frame_names, n_clusters=5, pca_dim=32):
    pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
    pca_results = pca.fit_transform(embeddings)
    kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
    kmeans.fit(pca_results)
    # For each cluster, pick the frame whose PCA-reduced embedding is closest to the centroid.
    distances = cdist(kmeans.cluster_centers_, pca_results, 'euclidean')
    selected_frames = []
    for i in range(n_clusters):
        closest_point_idx = np.argmin(distances[i])
        selected_frames.append(frame_names[closest_point_idx])
    print(f"Selected frames: {selected_frames}")
    return selected_frames
def generate_attention_maps(frame_path, model, device, output_dir, frame_name):
    img = Image.open(frame_path).convert('RGB')
    original_img = np.array(img)
    original_height, original_width = img.height, img.width
    transform = pth_transforms.Compose([
        pth_transforms.Resize((224, 224)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(img).unsqueeze(0)
    patch_size = model.patch_embed.patch_size
    w_featmap = img_tensor.shape[-2] // patch_size
    h_featmap = img_tensor.shape[-1] // patch_size
    with torch.no_grad():
        attentions = model.get_last_selfattention(img_tensor.to(device))
    nh = attentions.shape[1]  # number of attention heads
    # Keep only the attention of the [CLS] token to the image patches, per head.
    attention = attentions[0, :, 0, 1:].reshape(nh, -1)
    attention = attention.reshape(nh, w_featmap, h_featmap)
    # Upsample the patch-level attention back to the 224x224 input resolution.
    attention = nn.functional.interpolate(
        attention.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0].cpu().numpy()
    # Save attention map
    attn_path = os.path.join(output_dir, f"{frame_name}_attn.png")
    plt.imsave(attn_path, np.sum(attention, axis=0), cmap='inferno', format='png')
    # Save overlay
    overlay_path = os.path.join(output_dir, f"{frame_name}_overlay.png")
    attention_map = np.sum(attention, axis=0)
    # Small epsilon avoids division by zero if the attention map is constant.
    attention_map = (attention_map - np.min(attention_map)) / (np.max(attention_map) - np.min(attention_map) + 1e-8)
    attention_colored = np.uint8(255 * attention_map)
    attention_colored = cv2.applyColorMap(attention_colored, cv2.COLORMAP_JET)
    attention_colored = cv2.cvtColor(attention_colored, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(
        original_img, 0.5,
        cv2.resize(attention_colored, (original_width, original_height)), 0.5, 0
    )
    Image.fromarray(overlay).save(overlay_path)
    return overlay_path, attn_path
# --- Main orchestrator function ---
def process_video_with_dino(video_path, output_dir="dino_output"):
    """
    Main function to process a video and generate DINO attention maps.

    Args:
        video_path (str): Path to the input video.
        output_dir (str): Directory to save all intermediate and final files.

    Returns:
        list: A list of tuples, where each tuple contains (overlay_path, attention_map_path).
    """
    # Clean up previous runs and create the output directory
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Build model (using vit_small with patch size 8 as a default)
    patch_size = 8
    model = vits.vit_small(patch_size=patch_size, num_classes=0)
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)

    # Load pretrained weights via torch.hub
    url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
    state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
    model.load_state_dict(state_dict, strict=True)
    print("DINO weights loaded successfully from torch.hub.")
    # Step 1: Extract frames
    frame_paths = extract_frames(video_path, output_dir)
    if not frame_paths:
        raise ValueError("No frames were extracted from the video.")

    # Step 2: Compute embeddings
    embeddings, frame_names = compute_embeddings(frame_paths, model, device)

    # Step 3: Select representative frames
    selected_frames = select_representative_frames(embeddings, frame_names)

    # Step 4: Generate attention maps for selected frames
    results = []
    frames_dir = os.path.join(output_dir, "frames")
    for frame_name in selected_frames:
        frame_path = os.path.join(frames_dir, frame_name)
        frame_name_no_ext = os.path.splitext(frame_name)[0]
        overlay_path, attn_path = generate_attention_maps(frame_path, model, device, output_dir, frame_name_no_ext)
        results.append((overlay_path, attn_path))
    return results
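

# A minimal usage sketch, not part of the original script: it assumes a local video
# file exists at the hypothetical path "my_video.mp4" and simply prints where the
# overlay and attention-map images were written.
if __name__ == "__main__":
    outputs = process_video_with_dino("my_video.mp4", output_dir="dino_output")
    for overlay_path, attn_path in outputs:
        print(f"Overlay: {overlay_path} | Attention map: {attn_path}")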