# dino_processor.py (OPTIMIZED VERSION)
import os
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms as pth_transforms
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
import shutil
from datetime import datetime
import vision_transformer as vits
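# NOTE: `vision_transformer` (imported as `vits`) is assumed to be the module from the
# official DINO repo (https://github.com/facebookresearch/dino); copy vision_transformer.py
# next to this file. It provides vit_small() and get_last_selfattention(), which are
# used below.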
# --- Helper functions ---
def extract_frames(video_path, output_dir, fps=4):  # OPTIMIZATION: Reduced FPS
    frames_dir = os.path.join(output_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    # Guard against a zero interval when the source FPS is unknown or below the target FPS.
    frame_interval = max(1, int(video_fps / fps)) if video_fps > 0 else 1
    frame_paths = []
    frame_count = 0
    extracted_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            frame_filename = f"frame_{extracted_count:06d}.jpg"
            frame_path = os.path.join(frames_dir, frame_filename)
            cv2.imwrite(frame_path, frame)
            frame_paths.append(frame_path)
            extracted_count += 1
        frame_count += 1
    cap.release()
    print(f"Extracted {len(frame_paths)} frames at {fps} FPS.")
    return frame_paths
def compute_embeddings(frame_paths, model, device, batch_size=32):
    transform = pth_transforms.Compose([
        pth_transforms.Resize((224, 224)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    embeddings, frame_names = [], []
    for i in range(0, len(frame_paths), batch_size):
        batch_paths = frame_paths[i:i + batch_size]
        batch_images = []
        for frame_path in batch_paths:
            img = Image.open(frame_path).convert('RGB')
            batch_images.append(transform(img))
            frame_names.append(os.path.basename(frame_path))
        batch_tensor = torch.stack(batch_images).to(device)
        with torch.no_grad():
            batch_embeddings = model(batch_tensor)
        embeddings.append(batch_embeddings.cpu().numpy())
    return np.concatenate(embeddings, axis=0), frame_names
def select_representative_frames(embeddings, frame_names, n_clusters=3, pca_dim=12):  # OPTIMIZATION: Reduced clusters
    n_clusters = min(n_clusters, len(frame_names))
    if n_clusters == 0:
        return []
    # PCA components cannot exceed the number of samples; clamp for short videos.
    pca_dim = min(pca_dim, len(frame_names))
    pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
    pca_results = pca.fit_transform(embeddings)
    kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
    kmeans.fit(pca_results)
    # For each cluster, pick the frame closest to its centroid.
    distances = cdist(kmeans.cluster_centers_, pca_results, 'euclidean')
    selected_frames_indices = np.argmin(distances, axis=1)
    selected_frames = [frame_names[i] for i in selected_frames_indices]
    print(f"Selected {len(selected_frames)} representative frames.")
    return selected_frames
def generate_attention_overlay(frame_path, model, device, output_dir, frame_name):  # OPTIMIZATION: Renamed function
    img = Image.open(frame_path).convert('RGB')
    original_img = np.array(img)
    original_height, original_width = img.height, img.width
    transform = pth_transforms.Compose([
        pth_transforms.Resize((224, 224)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(img).unsqueeze(0)
    patch_size = model.patch_embed.patch_size
    w_featmap = img_tensor.shape[-2] // patch_size
    h_featmap = img_tensor.shape[-1] // patch_size
    with torch.no_grad():
        attentions = model.get_last_selfattention(img_tensor.to(device))
    nh = attentions.shape[1]  # number of attention heads
    # Attention of the [CLS] token to every image patch, per head.
    attention = attentions[0, :, 0, 1:].reshape(nh, -1)
    attention = attention.reshape(nh, w_featmap, h_featmap)
    attention = nn.functional.interpolate(attention.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()
    # Create and save ONLY the overlay
    overlay_path = os.path.join(output_dir, f"{frame_name}_overlay.png")
    attention_map = np.sum(attention, axis=0)
    # Normalize to [0, 1]; the epsilon guards against division by zero on a constant map.
    attention_map = (attention_map - np.min(attention_map)) / (np.max(attention_map) - np.min(attention_map) + 1e-8)
    attention_colored = np.uint8(255 * attention_map)
    attention_colored = cv2.applyColorMap(attention_colored, cv2.COLORMAP_JET)
    attention_colored = cv2.cvtColor(attention_colored, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(original_img, 0.5, cv2.resize(attention_colored, (original_width, original_height)), 0.5, 0)
    Image.fromarray(overlay).save(overlay_path)
    return overlay_path  # OPTIMIZATION: Return only the overlay path
# --- Function to load the model (no changes) ---
def load_dino_model():
    print("--- Loading DINO model into memory (this happens only once) ---")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    patch_size = 8
    model = vits.vit_small(patch_size=patch_size, num_classes=0)
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
    state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
    model.load_state_dict(state_dict, strict=True)
    print("--- DINO model loaded successfully ---")
    return model, device
# --- Main function (modified for simplified output) ---
def process_video_with_dino(video_path, model, device):
    archive_dir = "dino_archive"
    os.makedirs(archive_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = os.path.join(archive_dir, timestamp)
    os.makedirs(output_dir, exist_ok=True)
    frame_paths = extract_frames(video_path, output_dir)
    if not frame_paths:
        raise ValueError("No frames were extracted from the video.")
    embeddings, frame_names = compute_embeddings(frame_paths, model, device)
    selected_frames = select_representative_frames(embeddings, frame_names)
    # OPTIMIZATION: Results is now a simple list of overlay paths
    overlay_paths = []
    frames_dir = os.path.join(output_dir, "frames")
    for frame_name in selected_frames:
        frame_path = os.path.join(frames_dir, frame_name)
        frame_name_no_ext = os.path.splitext(frame_name)[0]
        overlay_path = generate_attention_overlay(frame_path, model, device, output_dir, frame_name_no_ext)
        overlay_paths.append(overlay_path)
    shutil.rmtree(frames_dir)
    return overlay_paths
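
# A minimal sketch of how this module is meant to be driven, assuming a local test
# clip named "sample.mp4" (a hypothetical path, not part of the original file). The
# model is loaded once and reused across calls, which is the point of load_dino_model().
if __name__ == "__main__":
    model, device = load_dino_model()
    overlays = process_video_with_dino("sample.mp4", model, device)
    for path in overlays:
        print(f"Attention overlay written to: {path}")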