{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-md",
   "metadata": {},
   "source": [
    "# SAM-Audio: prompted sound separation\n",
    "\n",
    "Separate a single sound event (a horn honk) from an AudioCaps test clip with `facebook/sam-audio-large`, prompted by a text description and a time-span anchor.\n",
    "\n",
    "Outputs are intentionally not stored in this file (embedded audio bloats the notebook); run **Restart Kernel → Run All** to regenerate them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d30c68a-460d-4038-9f14-43279bdaf8bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchaudio\n",
    "from datasets import load_dataset\n",
    "from IPython.display import Audio\n",
    "\n",
    "from sam_audio import SAMAudio, SAMAudioProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config-cell",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might want to change lives here.\n",
    "MODEL_ID = \"facebook/sam-audio-large\"\n",
    "AUDIOCAPS_SHARD = \"hf://datasets/OpenSound/AudioCaps/data/test-00000-of-00041.parquet\"\n",
    "SAMPLE_INDEX = 8  # row within the shard used as the example mixture\n",
    "\n",
    "# Text prompt describing the sound to extract.\n",
    "DESCRIPTION = \"A horn honking\"\n",
    "# Time anchor passed to the processor as [\"+\", start, end]. The \"+\" appears to mark a span\n",
    "# where the target is present (units presumably seconds) - TODO confirm against sam_audio docs.\n",
    "ANCHOR = [\"+\", 6.3, 7.0]\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "load-clip-md",
   "metadata": {},
   "source": [
    "## Load an example clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4db04c0-bd64-47da-9e67-cfa625930a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get an example audio clip from AudioCaps.\n",
    "# The generic \"parquet\" builder puts `data_files` into a single \"train\" split,\n",
    "# even though this shard comes from the AudioCaps test set.\n",
    "dset = load_dataset(\"parquet\", data_files=AUDIOCAPS_SHARD)\n",
    "samples = dset[\"train\"][SAMPLE_INDEX][\"audio\"].get_all_samples()\n",
    "Audio(samples.data, rate=samples.sample_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "load-model-md",
   "metadata": {},
   "source": [
    "## Load the model and processor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bffc401-30ce-4ec2-8e82-beb1dfdc1ac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SAMAudio.from_pretrained(MODEL_ID).to(device).eval()\n",
    "processor = SAMAudioProcessor.from_pretrained(MODEL_ID)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "separate-md",
   "metadata": {},
   "source": [
    "## Separate the target sound\n",
    "\n",
    "The clip is resampled to the processor's sampling rate and downmixed to mono. It is then passed to the processor together with the text description and the time anchor from the config cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b5ead6b-9ab5-4e14-9f00-d055b9092f13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Match the model's sampling rate, then downmix to mono\n",
    "# (assumes samples.data is (channels, time), so we average over dim 0).\n",
    "wav_resampled = torchaudio.functional.resample(\n",
    "    samples.data, samples.sample_rate, processor.audio_sampling_rate\n",
    ")\n",
    "wav_mono = wav_resampled.mean(0, keepdim=True)\n",
    "\n",
    "# Batch of one: one waveform, one description, one list of anchors.\n",
    "inputs = processor(\n",
    "    audios=[wav_mono], descriptions=[DESCRIPTION], anchors=[[ANCHOR]]\n",
    ").to(device)\n",
    "with torch.inference_mode():\n",
    "    result = model.separate(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "listen-md",
   "metadata": {},
   "source": [
    "## Listen to the separated target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7977d1e-5779-475c-a2dd-ea5e07335860",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Play the separated target for the first (only) item in the batch.\n",
    "# .float() is a no-op for float32 output. It guards against reduced-precision tensors\n",
    "# (e.g. bfloat16, which numpy - used by IPython's Audio - cannot convert).\n",
    "Audio(result.target[0].cpu().float(), rate=processor.audio_sampling_rate)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sam-audio",
   "language": "python",
   "name": "sam-audio"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}