{
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"id": "1d30c68a-460d-4038-9f14-43279bdaf8bf",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchaudio\n",
"from datasets import load_dataset\n",
"from IPython.display import Audio\n",
"\n",
"from sam_audio import SAMAudio, SAMAudioProcessor"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b4db04c0-bd64-47da-9e67-cfa625930a55",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get an example clip from AudioCaps (first of the 41 test-split parquet shards).\n",
"AUDIOCAPS_SHARD = \"hf://datasets/OpenSound/AudioCaps/data/test-00000-of-00041.parquet\"\n",
"SAMPLE_INDEX = 8  # row of the shard used as the example clip\n",
"\n",
"audiocaps = load_dataset(\"parquet\", data_files=AUDIOCAPS_SHARD)\n",
"# A bare `data_files` argument makes `datasets` expose the rows under its default\n",
"# \"train\" split name, even though this shard belongs to AudioCaps' test split.\n",
"samples = audiocaps[\"train\"][SAMPLE_INDEX][\"audio\"].get_all_samples()\n",
"Audio(samples.data, rate=samples.sample_rate)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "3bffc401-30ce-4ec2-8e82-beb1dfdc1ac1",
"metadata": {},
"outputs": [],
"source": [
"# Load the model and its processor from one checkpoint id so they can never mismatch.\n",
"MODEL_ID = \"facebook/sam-audio-large\"\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model = SAMAudio.from_pretrained(MODEL_ID).to(device).eval()\n",
"processor = SAMAudioProcessor.from_pretrained(MODEL_ID)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "3b5ead6b-9ab5-4e14-9f00-d055b9092f13",
"metadata": {},
"outputs": [],
"source": [
"# Text prompt describing the sound to extract, plus a time anchor for it.\n",
"DESCRIPTION = \"A horn honking\"\n",
"# NOTE(review): presumably [\"+\", start_s, end_s] marks a span (in seconds) where the\n",
"# described sound is present -- confirm against the SAM-Audio processor docs.\n",
"ANCHOR = [\"+\", 6.3, 7.0]\n",
"\n",
"# Downmix to mono *before* resampling: resampling is linear, so this yields the same\n",
"# signal (up to float rounding) as resample-then-average, but resamples one channel only.\n",
"wav_mono = samples.data.mean(0, keepdim=True)\n",
"wav = torchaudio.functional.resample(\n",
"    wav_mono, samples.sample_rate, processor.audio_sampling_rate\n",
")\n",
"\n",
"inputs = processor(\n",
"    audios=[wav], descriptions=[DESCRIPTION], anchors=[[ANCHOR]]\n",
").to(device)\n",
"with torch.inference_mode():\n",
"    result = model.separate(inputs)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "a7977d1e-5779-475c-a2dd-ea5e07335860",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Listen to the separated sound; index 0 is the first (here presumably only) item in the batch.\n",
"separated = result.target[0].cpu()\n",
"Audio(separated, rate=processor.audio_sampling_rate)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sam-audio",
"language": "python",
"name": "sam-audio"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}