In [1]:
import tempfile
from io import BytesIO

import cv2
import numpy as np
import sam3.visualization_utils as utils
import torch
import torchvision
from IPython.display import Audio, Video

# NOTE: requires installing sam3: `pip install git+https://github.com/facebookresearch/sam3.git`
from sam3.model_builder import build_sam3_video_predictor
from torchcodec.decoders import VideoDecoder
from tqdm import trange

from sam_audio import SAMAudio, SAMAudioProcessor

In [19]:
# Instantiate the SAM 3 video predictor. Later cells drive it through
# `video_predictor.handle_request(...)` to track the prompted object.
video_predictor = build_sam3_video_predictor()

In [3]:
# Source clip: previewed here and reused by later cells for mask tracking
# and as the audio input for separation.
video_file = "assets/office.mp4"
Video(video_file, height=360, width=640, embed=True)

In [21]:
# Open the clip to read its frame count and frame size.
decoder = VideoDecoder(video_file)
height, width = decoder.metadata.height, decoder.metadata.width

# Start a predictor session bound to this video.
response = video_predictor.handle_request(
    request={
        "type": "start_session",
        "resource_path": video_file,
    }
)
session_id = response["session_id"]

# One entry per frame; each entry is kept at a leading dimension of 1 so that
# concatenating `outputs` later yields exactly one mask row per frame.
outputs = []
try:
    for frame_index in trange(len(decoder)):
        response = video_predictor.handle_request(
            request={
                "type": "add_prompt",
                "session_id": session_id,
                "frame_index": frame_index,
                "text": "The person on the left",
            }
        )
        output = response["outputs"]
        mask = output["out_binary_masks"]
        num_objects = mask.shape[0]
        if num_objects == 0:
            # Nothing detected on this frame: reuse the previous frame's mask,
            # or an all-False mask when this is the very first frame.
            if frame_index > 0:
                mask = outputs[-1]
            else:
                mask = np.zeros((1, height, width), dtype=bool)
        elif num_objects > 1:
            # Several objects matched the prompt. Merge them into a single mask
            # so this frame still contributes exactly one row; otherwise the
            # concatenated masks would no longer line up with the frame indices.
            mask = np.asarray(mask).any(axis=0, keepdims=True)
        outputs.append(mask)
finally:
    # Close the session so re-running this cell does not accumulate open
    # sessions (and whatever per-video state the predictor holds for them).
    video_predictor.handle_request(
        request={
            "type": "close_session",
            "session_id": session_id,
        }
    )

frame loading (OpenCV) [rank=0]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [00:00<00:00, 678.23it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [00:44<00:00, 8.02it/s]


In [5]:
# Show the video with mask overlaid


def draw_masks_to_frame(
 frame: np.ndarray, masks: np.ndarray, colors: np.ndarray
) -> np.ndarray:
 masked_frame = frame
 for mask, color in zip(masks, colors, strict=False):
 curr_masked_frame = np.where(mask[..., None], color, masked_frame)
 masked_frame = cv2.addWeighted(masked_frame, 0.75, curr_masked_frame, 0.25, 0)
 contours, _ = cv2.findContours(
 np.array(mask, dtype=np.uint8).copy(),
 cv2.RETR_TREE,
 cv2.CHAIN_APPROX_NONE,
 )
 cv2.drawContours(masked_frame, contours, -1, (255, 255, 255), 1)
 cv2.drawContours(masked_frame, contours, -1, (0, 0, 0), 1)
 cv2.drawContours(masked_frame, contours, -1, color.tolist(), 1)
 return masked_frame


# Decode the whole clip in one go by slicing the decoder.
frames = decoder[:]
# Stack the per-frame masks and insert a singleton axis at position 1.
mask = torch.from_numpy(np.concatenate(outputs)).unsqueeze(1)
# Drop the first palette entry; only the first remaining color is used below.
COLORS = utils.pascal_color_map()[1:]
masked_frames = frames.clone()
for i, frame in enumerate(frames):
    # The drawing helper works on channel-last numpy images, so move the
    # leading axis to the end before drawing and back again afterwards.
    frame_hwc = frame.permute(1, 2, 0).numpy()
    overlay_hwc = draw_masks_to_frame(frame_hwc, mask[i], COLORS[[0]])
    masked_frames[i] = torch.from_numpy(overlay_hwc).permute(2, 0, 1)

# Encode the overlaid frames to a temporary MP4 and show it inline.
# The temporary file is removed when the `with` block exits, so the video is
# displayed with embed=True while the file still exists.
with tempfile.NamedTemporaryFile(suffix=".mp4") as tfile:
    torchvision.io.write_video(
        tfile.name,
        # Move axis 1 to the end; write_video is documented to take
        # channel-last frames -- TODO confirm for the installed torchvision.
        masked_frames.permute(0, 2, 3, 1),
        fps=decoder.metadata.average_fps_from_header,
        video_codec="h264",
    )
    display(
        Video(
            tfile.name,
            embed=True,
            height=decoder.metadata.height,
            width=decoder.metadata.width,
        )
    )



In [18]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the pretrained separation model, move it to the chosen device and put
# it in eval mode; the processor is loaded from the same checkpoint name.
model = SAMAudio.from_pretrained("facebook/sam-audio-large")
model = model.to(device).eval()
processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")

In [16]:
# Build a single-item batch: the clip's audio, an empty text description, and
# the video masked with the tracked object's masks as the visual prompt.
inputs = processor(
    masked_videos=processor.mask_videos([frames], [mask]),
    descriptions=[""],
    audios=[video_file],
)
inputs = inputs.to(device)

# Run separation with autograd disabled.
with torch.inference_mode():
    result = model.separate(inputs)

In [17]:
# First (and only) batch item of `result.target` -- presumably the isolated
# sound of the masked object -- moved to the CPU as float for playback.
target_waveform = result.target[0].cpu().float()
Audio(target_waveform, rate=processor.audio_sampling_rate)