# dino_processor.py (OPTIMIZED VERSION)

import os
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms as pth_transforms
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
import shutil
from datetime import datetime

# vision_transformer.py from the official DINO repository (facebookresearch/dino)
import vision_transformer as vits

# --- Helper functions ---

def extract_frames(video_path, output_dir, fps=4): # OPTIMIZATION: Reduced FPS
    frames_dir = os.path.join(output_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    # max(1, ...) guards against a zero interval (and the resulting
    # ZeroDivisionError below) when the source FPS is at or below the target.
    frame_interval = max(1, int(video_fps / fps)) if video_fps > 0 else 1
    frame_paths = []
    frame_count = 0
    extracted_count = 0
    while True:
        ret, frame = cap.read()
        if not ret: break
        if frame_count % frame_interval == 0:
            frame_filename = f"frame_{extracted_count:06d}.jpg"
            frame_path = os.path.join(frames_dir, frame_filename)
            cv2.imwrite(frame_path, frame)
            frame_paths.append(frame_path)
            extracted_count += 1
        frame_count += 1
    cap.release()
    print(f"Extracted {len(frame_paths)} frames at {fps} FPS.")
    return frame_paths

def compute_embeddings(frame_paths, model, device, batch_size=32):
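    # Embed frames in batches; the model's forward pass returns one CLS-token
    # feature vector per frame.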
    transform = pth_transforms.Compose([
        pth_transforms.Resize((224, 224)), pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    embeddings, frame_names = [], []
    for i in range(0, len(frame_paths), batch_size):
        batch_paths = frame_paths[i:i + batch_size]
        batch_images = []
        for frame_path in batch_paths:
            img = Image.open(frame_path).convert('RGB')
            batch_images.append(transform(img))
            frame_names.append(os.path.basename(frame_path))
        batch_tensor = torch.stack(batch_images).to(device)
        with torch.no_grad():
            batch_embeddings = model(batch_tensor)
        embeddings.append(batch_embeddings.cpu().numpy())
    return np.concatenate(embeddings, axis=0), frame_names

def select_representative_frames(embeddings, frame_names, n_clusters=3, pca_dim=12): # OPTIMIZATION: Reduced clusters
    n_clusters = min(n_clusters, len(frame_names))
    if n_clusters == 0: return []

    # PCA components cannot exceed the number of samples or the feature
    # dimension, so clamp pca_dim for very short videos.
    pca_dim = min(pca_dim, embeddings.shape[0], embeddings.shape[1])
    pca = PCA(n_components=pca_dim, svd_solver='full', random_state=404543)
    pca_results = pca.fit_transform(embeddings)
    kmeans = KMeans(n_clusters=n_clusters, random_state=404543, n_init=10)
    kmeans.fit(pca_results)
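    # For each cluster, keep the frame whose PCA embedding lies closest to the centroid.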
    distances = cdist(kmeans.cluster_centers_, pca_results, 'euclidean')
    selected_frames_indices = np.argmin(distances, axis=1)
    selected_frames = [frame_names[i] for i in selected_frames_indices]
    print(f"Selected {len(selected_frames)} representative frames.")
    return selected_frames

def generate_attention_overlay(frame_path, model, device, output_dir, frame_name): # OPTIMIZATION: Renamed function
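    # Blends a heatmap of the DINO CLS self-attention (last layer, all heads
    # summed) over the original frame and saves it as a PNG.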
    img = Image.open(frame_path).convert('RGB')
    original_img = np.array(img)
    original_height, original_width = img.height, img.width
    transform = pth_transforms.Compose([
        pth_transforms.Resize((224, 224)), pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_tensor = transform(img).unsqueeze(0)
    patch_size = model.patch_embed.patch_size
    h_featmap = img_tensor.shape[-2] // patch_size  # shape[-2] is the height
    w_featmap = img_tensor.shape[-1] // patch_size  # shape[-1] is the width
    with torch.no_grad():
        attentions = model.get_last_selfattention(img_tensor.to(device))
    nh = attentions.shape[1]  # number of attention heads
    # CLS token's attention to every patch token, one map per head, upsampled
    # from the patch grid back to the 224x224 input resolution.
    attention = attentions[0, :, 0, 1:].reshape(nh, -1)
    attention = attention.reshape(nh, h_featmap, w_featmap)
    attention = nn.functional.interpolate(attention.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()
    
    # Create and save ONLY the overlay
    overlay_path = os.path.join(output_dir, f"{frame_name}_overlay.png")
    attention_map = np.sum(attention, axis=0)
    # Normalize to [0, 1]; the epsilon avoids division by zero on a flat map.
    attention_map = (attention_map - np.min(attention_map)) / (np.max(attention_map) - np.min(attention_map) + 1e-8)
    attention_colored = np.uint8(255 * attention_map)
    attention_colored = cv2.applyColorMap(attention_colored, cv2.COLORMAP_JET)
    attention_colored = cv2.cvtColor(attention_colored, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(original_img, 0.5, cv2.resize(attention_colored, (original_width, original_height)), 0.5, 0)
    Image.fromarray(overlay).save(overlay_path)
    
    return overlay_path # OPTIMIZATION: Return only the overlay path

# --- Function to load the model (no changes) ---
def load_dino_model():
    print("--- Loading DINO model into memory (this happens only once) ---")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    patch_size = 8
    model = vits.vit_small(patch_size=patch_size, num_classes=0)
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
    state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
    model.load_state_dict(state_dict, strict=True)
    print("--- DINO model loaded successfully ---")
    return model, device

# --- Main function (modified for simplified output) ---
def process_video_with_dino(video_path, model, device):
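    # Pipeline: sample frames -> embed with DINO -> pick representative frames
    # via PCA + k-means -> render attention overlays -> delete the raw frames.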
    archive_dir = "dino_archive"
    os.makedirs(archive_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = os.path.join(archive_dir, timestamp)
    os.makedirs(output_dir, exist_ok=True)

    frame_paths = extract_frames(video_path, output_dir)
    if not frame_paths:
        raise ValueError("No frames were extracted from the video.")

    embeddings, frame_names = compute_embeddings(frame_paths, model, device)
    selected_frames = select_representative_frames(embeddings, frame_names)

    # OPTIMIZATION: Results is now a simple list of overlay paths
    overlay_paths = []
    frames_dir = os.path.join(output_dir, "frames")
    for frame_name in selected_frames:
        frame_path = os.path.join(frames_dir, frame_name)
        frame_name_no_ext = os.path.splitext(frame_name)[0]
        overlay_path = generate_attention_overlay(frame_path, model, device, output_dir, frame_name_no_ext)
        overlay_paths.append(overlay_path)
    
    shutil.rmtree(frames_dir)
    return overlay_paths
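
# --- Example usage (a minimal sketch; "input.mp4" is a placeholder path) ---
if __name__ == "__main__":
    model, device = load_dino_model()
    overlay_paths = process_video_with_dino("input.mp4", model, device)
    for path in overlay_paths:
        print(f"Attention overlay saved to: {path}")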