| import os.path |
|
|
| from models.segment_anything import build_sam_vit_h |
| from models.segment_anything.utils.transforms import ResizeLongestSide |
| import cv2 |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import pandas as pd |
| from configs import args |
| from save_audio_feats import data_dir |
|
|
|
|
def preprocess(x: torch.Tensor, device: str = 'cuda') -> torch.Tensor:
    """Normalize pixel values and zero-pad an image tensor to a square input.

    Args:
        x: Image tensor of shape (..., 3, H, W) in RGB channel order with
           H, W <= 1024 (callers resize with ``ResizeLongestSide(1024)``
           first — TODO confirm no caller passes larger frames, since
           ``F.pad`` with negative padding would silently crop).
        device: Device on which normalization is performed.

    Returns:
        Tensor of shape (..., 3, 1024, 1024), per-channel normalized and
        zero-padded on the bottom/right edges.
    """
    x = x.to(device)

    # Per-channel mean/std; presumably dataset-specific statistics (they
    # differ from SAM's defaults) — confirm against how they were computed.
    # Built directly on the target device with torch.tensor(...) instead of
    # the legacy torch.Tensor(...).to(device) round-trip through CPU.
    pixel_mean = torch.tensor([113.263, 99.370, 92.492], device=device).view(-1, 1, 1)
    pixel_std = torch.tensor([64.274, 61.068, 58.626], device=device).view(-1, 1, 1)
    img_size = 1024

    # Normalize each channel.
    x = (x - pixel_mean) / pixel_std

    # Pad bottom/right so the spatial size becomes img_size x img_size
    # (F.pad's last-dim pair comes first: (left, right, top, bottom)).
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# Metadata CSV lives at the root of the data directory; keep only the
# rows belonging to a known split.
data_dir = args.data_dir
metapath = os.path.join(data_dir, 'metadata.csv')
metadata = pd.read_csv(metapath, header=0)
metadata = metadata[metadata['split'].isin(['train', 'val', 'test_s', 'test_u', 'test_n'])]


# Recover the video id by stripping the last two '_'-separated fields of
# each uid (presumably uid = '<video-id>_<a>_<b>' — verify against the
# metadata producer), then deduplicate.
vids = metadata['uid'].apply(lambda x: x.rsplit('_', 2)[0]).unique()


# SAM ViT-H used purely as a frozen feature extractor.
sam_model = build_sam_vit_h(args.vision_pretrained)
sam_model.to(device)
# Fix: switch to inference mode; freezing gradients alone does not change
# the behavior of dropout/batchnorm submodules, which would otherwise run
# in training mode during feature extraction.
sam_model.eval()


for param in sam_model.parameters():
    param.requires_grad = False


save_dir = os.path.join(data_dir, 'image_embed')
os.makedirs(save_dir, exist_ok=True)


# Free any cached GPU memory before the extraction loop starts.
torch.cuda.empty_cache()
|
|
# Extract and persist per-frame SAM encoder features, one file per video.
for vid in vids:
    embeds = []

    # Each video directory is expected to hold ten frames, 0.jpg .. 9.jpg.
    for frame_idx in range(10):
        path_frame = f'{data_dir}/media/{vid}/frames/{frame_idx}.jpg'
        bgr = cv2.imread(path_frame)
        if bgr is None:
            print(f"Warning: Could not read image {path_frame}")
            continue

        # OpenCV loads BGR; convert to RGB, resize the longest side to
        # 1024, and rearrange HWC -> CHW for the encoder.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        resized = ResizeLongestSide(1024).apply_image(rgb)
        chw = torch.from_numpy(resized).permute(2, 0, 1).contiguous()

        # Normalize/pad and add a batch dimension of one.
        batch = preprocess(chw, device).unsqueeze(0)

        with torch.no_grad():
            feat = sam_model.image_encoder(batch).squeeze(0).cpu()

        embeds.append(feat)

    # Release cached GPU memory between videos.
    torch.cuda.empty_cache()

    if not embeds:
        print(f"Error: No images loaded for video {vid}")
        continue

    # Stack to (num_frames, ...) and write one tensor file per video.
    stacked = torch.stack(embeds, dim=0)
    torch.save(stacked, f'{save_dir}/{vid}.pt')
    print(f"Processed video {vid}, features shape: {stacked.shape}")


print("Processing completed!")