{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "13a3ca30-1380-4789-af4a-05b3015908f2",
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"from io import BytesIO\n",
"\n",
"import cv2\n",
"import numpy as np\n",
"import sam3.visualization_utils as utils\n",
"import torch\n",
"import torchvision\n",
"from IPython.display import Audio, Video\n",
"\n",
"# NOTE: requires installing sam3: `pip install git+https://github.com/facebookresearch/sam3.git`\n",
"from sam3.model_builder import build_sam3_video_predictor\n",
"from torchcodec.decoders import VideoDecoder\n",
"from tqdm import trange\n",
"\n",
"from sam_audio import SAMAudio, SAMAudioProcessor"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "edd9eba3-7d57-46a0-b4f9-2534d4068568",
"metadata": {},
"outputs": [],
"source": [
"# Build the SAM3 video predictor used below for text-prompted video segmentation.\n",
"video_predictor = build_sam3_video_predictor()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "09d51ff7-ff9f-4a4b-a93c-d9d3152f5a54",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the input video inline.\n",
"video_file = \"assets/office.mp4\"\n",
"Video(video_file, embed=True, width=640, height=360)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "c4a12176-4cc8-402d-836c-c0da6f8dc91e",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"frame loading (OpenCV) [rank=0]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [00:00<00:00, 678.23it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [00:44<00:00, 8.02it/s]\n"
]
}
],
"source": [
"# Run SAM3 text-prompted segmentation over every frame of the video,\n",
"# collecting one binary mask array per frame in `outputs`.\n",
"decoder = VideoDecoder(video_file)\n",
"height, width = decoder.metadata.height, decoder.metadata.width\n",
"\n",
"# Open a predictor session on the video file.\n",
"response = video_predictor.handle_request(\n",
"    request={\n",
"        \"type\": \"start_session\",\n",
"        \"resource_path\": video_file,\n",
"    }\n",
")\n",
"session_id = response[\"session_id\"]\n",
"outputs = []\n",
"for frame_index in trange(len(decoder)):\n",
"    # Prompt each frame with the same text query.\n",
"    response = video_predictor.handle_request(\n",
"        request={\n",
"            \"type\": \"add_prompt\",\n",
"            \"session_id\": session_id,\n",
"            \"frame_index\": frame_index,\n",
"            \"text\": \"The person on the left\",\n",
"        }\n",
"    )\n",
"    output = response[\"outputs\"]\n",
"    mask = output[\"out_binary_masks\"]  # assumed (N, H, W) bool — see fallback below\n",
"    if mask.shape[0] == 0:\n",
"        # No detection on this frame: reuse the previous frame's mask, or an\n",
"        # all-false mask if this is the very first frame.\n",
"        if frame_index > 0:\n",
"            mask = outputs[-1]\n",
"        else:\n",
"            mask = np.zeros((1, height, width), dtype=bool)\n",
"    outputs.append(mask)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "613ae7e0-eef2-4799-9d31-16dcfacf23a2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mattle/anaconda3/sam-audio/lib/python3.11/site-packages/torchvision/io/_video_deprecation_warning.py:5: UserWarning: The video decoding and encoding capabilities of torchvision are deprecated from version 0.22 and will be removed in version 0.24. We recommend that you migrate to TorchCodec, where we'll consolidate the future decoding/encoding capabilities of PyTorch: https://github.com/pytorch/torchcodec\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Show the video with the predicted masks overlaid\n",
"\n",
"\n",
"def draw_masks_to_frame(\n",
"    frame: np.ndarray, masks: np.ndarray, colors: np.ndarray\n",
") -> np.ndarray:\n",
"    \"\"\"Alpha-blend each mask onto ``frame`` and outline it in its color.\n",
"\n",
"    Returns a new array; ``frame`` itself is not modified in place.\n",
"    \"\"\"\n",
"    masked_frame = frame\n",
"    for mask, color in zip(masks, colors, strict=False):\n",
"        # Blend a solid-color fill of the mask region at 25% opacity.\n",
"        curr_masked_frame = np.where(mask[..., None], color, masked_frame)\n",
"        masked_frame = cv2.addWeighted(masked_frame, 0.75, curr_masked_frame, 0.25, 0)\n",
"        contours, _ = cv2.findContours(\n",
"            np.array(mask, dtype=np.uint8).copy(),\n",
"            cv2.RETR_TREE,\n",
"            cv2.CHAIN_APPROX_NONE,\n",
"        )\n",
"        # Draw the contour outline once, in the mask color. (Previously the\n",
"        # same 1-px contours were drawn three times — white, black, then\n",
"        # color — so the first two draws were fully overwritten dead work.)\n",
"        cv2.drawContours(masked_frame, contours, -1, color.tolist(), 1)\n",
"    return masked_frame\n",
"\n",
"\n",
"frames = decoder[:]\n",
"# Stack the per-frame masks into a single (T, 1, H, W) tensor.\n",
"mask = torch.from_numpy(np.concatenate(outputs)).unsqueeze(1)\n",
"masked_frames = frames.clone()\n",
"COLORS = utils.pascal_color_map()[1:]\n",
"for i, frame in enumerate(frames):\n",
"    masked_frames[i] = torch.from_numpy(\n",
"        draw_masks_to_frame(frame.permute(1, 2, 0).numpy(), mask[i], COLORS[[0]])\n",
"    ).permute(2, 0, 1)\n",
"\n",
"with tempfile.NamedTemporaryFile(suffix=\".mp4\") as tfile:\n",
"    # NOTE: torchvision's video encoding is deprecated (see the warning in\n",
"    # this cell's output) — migrate to torchcodec encoding when convenient.\n",
"    torchvision.io.write_video(\n",
"        tfile.name,\n",
"        masked_frames.permute(0, 2, 3, 1),\n",
"        fps=decoder.metadata.average_fps_from_header,\n",
"        video_codec=\"h264\",\n",
"    )\n",
"    display(\n",
"        Video(\n",
"            tfile.name,\n",
"            embed=True,\n",
"            height=decoder.metadata.height,\n",
"            width=decoder.metadata.width,\n",
"        )\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "71111558-c5d3-4328-af35-7808f21fe9d5",
"metadata": {},
"outputs": [],
"source": [
"# Select the device and load the pretrained SAM-Audio model and processor.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model = SAMAudio.from_pretrained(\"facebook/sam-audio-large\").to(device).eval()\n",
"processor = SAMAudioProcessor.from_pretrained(\"facebook/sam-audio-large\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b8f98ffa-f37b-4314-96a2-bf701052a7a1",
"metadata": {},
"outputs": [],
"source": [
"# Run source separation. The text description is left empty; presumably the\n",
"# masked video alone specifies the target source — confirm with SAM-Audio docs.\n",
"inputs = processor(\n",
"    audios=[video_file],\n",
"    descriptions=[\"\"],\n",
"    masked_videos=processor.mask_videos([frames], [mask]),\n",
").to(device)\n",
"with torch.inference_mode():\n",
"    result = model.separate(inputs)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d9676cc7-d699-41ea-bd3a-b12d8882cc8d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Play back the separated target audio.\n",
"Audio(result.target[0].cpu().float(), rate=processor.audio_sampling_rate)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sam-audio",
"language": "python",
"name": "sam-audio"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}