import os

from models.segment_anything import build_sam_vit_h
from models.segment_anything.utils.transforms import ResizeLongestSide
import cv2
import torch
import torch.nn.functional as F
import pandas as pd
from configs import args


def preprocess(x: torch.Tensor, device='cuda') -> torch.Tensor:
    """Normalize pixel values and pad to a square 1024x1024 input for the SAM image encoder."""
    # Ensure the input tensor is on the correct device
    x = x.to(device)

    # Normalize colors with per-channel mean/std (shape [3, 1, 1] for broadcasting)
    pixel_mean = torch.Tensor([113.263, 99.370, 92.492]).view(-1, 1, 1).to(device)
    pixel_std = torch.Tensor([64.274, 61.068, 58.626]).view(-1, 1, 1).to(device)
    img_size = 1024

    x = (x - pixel_mean) / pixel_std
    # Pad bottom/right so the spatial size becomes img_size x img_size
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x
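
# Minimal sanity-check sketch (not part of the pipeline; the dummy frame size is an
# assumed example). It only illustrates that preprocess pads the shorter side to 1024:
#   dummy = torch.randint(0, 256, (3, 576, 1024), dtype=torch.uint8)
#   assert preprocess(dummy, device='cpu').shape == (3, 1024, 1024)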



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

data_dir = args.data_dir
metapath = os.path.join(data_dir, 'metadata.csv')
metadata = pd.read_csv(metapath, header=0)
metadata = metadata[metadata['split'].isin(['train', 'val', 'test_s', 'test_u', 'test_n'])]
# metadata = metadata[metadata['split'].isin(['test_s'])]

vids = metadata['uid'].apply(lambda x: x.rsplit('_', 2)[0]).unique()
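# The rsplit above assumes uids of the form '<video_id>_<x>_<y>' (hypothetical example):
#   'abc123_0_10'.rsplit('_', 2)[0] -> 'abc123'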

# Build the pretrained SAM (ViT-H) model; only its image encoder is used below
sam_model = build_sam_vit_h(args.vision_pretrained)
sam_model.to(device)
sam_model.eval()

for param in sam_model.parameters():
    param.requires_grad = False

save_dir = os.path.join(data_dir, 'image_embed')
os.makedirs(save_dir, exist_ok=True)


torch.cuda.empty_cache()

for vid in vids:
    image_embeds = []

    # The script assumes 10 pre-extracted frames per video, stored as {idx}.jpg
    for _idx in range(10):
        path_frame = f'{data_dir}/media/{vid}/frames/{_idx}.jpg'
        frame = cv2.imread(path_frame)
        if frame is None:
            print(f"Warning: Could not read image {path_frame}")
            continue

        # Convert BGR (OpenCV) to RGB and resize so the longest side is 1024 px
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = ResizeLongestSide(1024).apply_image(frame)
        frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).contiguous()  # [3, H, W]

        frame_processed = preprocess(frame_tensor, device)  # [3, 1024, 1024]

        single_image = frame_processed.unsqueeze(0)  # [1, 3, 1024, 1024]

        with torch.no_grad():
            image_embed = sam_model.image_encoder(single_image)  # [1, 256, 64, 64]
            image_embed = image_embed.squeeze(0).cpu()

        image_embeds.append(image_embed)

        torch.cuda.empty_cache()

    if not image_embeds:
        print(f"Error: No images loaded for video {vid}")
        continue

    # Stack the per-frame embeddings along a new time dimension and save one file per video
    image_embeds_stacked = torch.stack(image_embeds, dim=0)  # [T, 256, 64, 64]
    torch.save(image_embeds_stacked, f'{save_dir}/{vid}.pt')

    print(f"Processed video {vid}, features shape: {image_embeds_stacked.shape}")

print("Processing completed!")