{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting torchaudio\n",
" Downloading torchaudio-2.5.1-cp312-cp312-win_amd64.whl.metadata (6.5 kB)\n",
"Requirement already satisfied: torch==2.5.1 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torchaudio) (2.5.1)\n",
"Requirement already satisfied: filelock in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.13.1)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (4.11.0)\n",
"Requirement already satisfied: networkx in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.3)\n",
"Requirement already satisfied: jinja2 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (3.1.4)\n",
"Requirement already satisfied: fsspec in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (2024.6.1)\n",
"Requirement already satisfied: setuptools in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (75.1.0)\n",
"Requirement already satisfied: sympy==1.13.1 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from torch==2.5.1->torchaudio) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from sympy==1.13.1->torch==2.5.1->torchaudio) (1.3.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\asus\\anaconda3\\lib\\site-packages (from jinja2->torch==2.5.1->torchaudio) (2.1.3)\n",
"Downloading torchaudio-2.5.1-cp312-cp312-win_amd64.whl (2.4 MB)\n",
" ---------------------------------------- 0.0/2.4 MB ? eta -:--:--\n",
" ---------------------------------------- 2.4/2.4 MB 11.6 MB/s eta 0:00:00\n",
"Installing collected packages: torchaudio\n",
"Successfully installed torchaudio-2.5.1\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# Use %pip (not bare `pip`) so the install targets the active kernel's environment\n",
"%pip install torchaudio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'datasets'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[3], line 5\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorchaudio\u001b[39;00m\n\u001b[1;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DatasetDict, load_dataset\n\u001b[0;32m 7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprepare_dataset\u001b[39m(directory):\n\u001b[0;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpath\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabel\u001b[39m\u001b[38;5;124m\"\u001b[39m: []}\n",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'datasets'"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick here for more info. \n",
"\u001b[1;31mView Jupyter log for further details."
]
}
],
"source": [
"## The data is not pushed to repo, only model and training logs etc are uploaded\n",
"\n",
"import os\n",
"import torchaudio\n",
"from datasets import DatasetDict, load_dataset\n",
"\n",
"def prepare_dataset(directory):\n",
"    \"\"\"Collect .wav file paths and integer labels from `directory`.\n",
"\n",
"    Expects two sub-folders, `fake/` and `real/`, and returns a dict of\n",
"    parallel lists: {\"path\": [...], \"label\": [...]} with fake -> 0, real -> 1.\n",
"    \"\"\"\n",
"    data = {\"path\": [], \"label\": []}\n",
"    labels = {\"fake\": 0, \"real\": 1}  # Map fake to 0 and real to 1\n",
"\n",
"    for label, label_id in labels.items():\n",
"        folder_path = os.path.join(directory, label)\n",
"        for file in os.listdir(folder_path):\n",
"            if file.endswith(\".wav\"):\n",
"                data[\"path\"].append(os.path.join(folder_path, file))\n",
"                data[\"label\"].append(label_id)\n",
"    return data\n",
"\n",
"# BUG FIX: train/val/test previously ALL pointed at the `testing` split, so the\n",
"# model was trained and evaluated on the same files (which explains the\n",
"# near-zero eval loss). Point each split at its own directory.\n",
"train_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\training\")\n",
"val_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\validation\")\n",
"test_data = prepare_dataset(r\"dataset\\for-norm\\for-norm\\testing\")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from datasets import Dataset\n",
"\n",
"# Turn each split's {\"path\": [...], \"label\": [...]} dict into a Dataset.\n",
"train_dataset = Dataset.from_dict(train_data)\n",
"val_dataset = Dataset.from_dict(val_data)\n",
"test_dataset = Dataset.from_dict(test_data)\n",
"\n",
"# Bundle the splits so they can be mapped / formatted together downstream.\n",
"dataset = DatasetDict(\n",
"    {\"train\": train_dataset, \"validation\": val_dataset, \"test\": test_dataset}\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import Model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\60165\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\configuration_utils.py:302: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n",
" warnings.warn(\n",
"Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForAudioClassification, AutoProcessor\n",
"\n",
"# Backbone checkpoint; the classification head on top is freshly initialised\n",
"# (hence the \"newly initialized\" warning) and must be fine-tuned before use.\n",
"model_name = \"facebook/wav2vec2-base\"\n",
"model = AutoModelForAudioClassification.from_pretrained(model_name, num_labels=2)  # binary: fake vs real\n",
"processor = AutoProcessor.from_pretrained(model_name)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocess Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "af2c5f31f0db43ee9975023b28e2c57c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/4634 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "051c3226b30145f3adbc708cf8afd6a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/4634 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "86de10f37b7441e3a9c9f0ee88bf5149",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/4634 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0) \n",
"torch.int64\n"
]
}
],
"source": [
"import torch\n",
"\n",
"\n",
"def preprocess_function(batch):\n",
"    \"\"\"Load one audio file and convert it into model inputs.\n",
"\n",
"    Adds `input_values` (padded/truncated to `max_length` samples) and casts\n",
"    `label` to a LongTensor, as the classification loss expects.\n",
"    \"\"\"\n",
"    # NOTE(review): the file's native sample rate is discarded here; this\n",
"    # assumes every clip is already 16 kHz — confirm for the source dataset.\n",
"    audio = torchaudio.load(batch[\"path\"])[0].squeeze().numpy()\n",
"    inputs = processor(\n",
"        audio,\n",
"        sampling_rate=16000,\n",
"        padding=True,\n",
"        truncation=True,\n",
"        max_length=32000, \n",
"        return_tensors=\"pt\"\n",
"    )\n",
"    # Drop the leading batch dim added by return_tensors=\"pt\"\n",
"    batch[\"input_values\"] = inputs.input_values[0]\n",
"    # Ensure labels are converted to LongTensor\n",
"    batch[\"label\"] = torch.tensor(batch[\"label\"], dtype=torch.long) # Convert label to LongTensor\n",
"    return batch\n",
"\n",
"# batched=False: one example per call, so batch[\"path\"] is a single string.\n",
"processed_dataset = dataset.map(preprocess_function, remove_columns=[\"path\"], batched=False)\n",
"# Set format to torch tensors for compatibility with PyTorch\n",
"processed_dataset.set_format(type=\"torch\", columns=[\"input_values\", \"label\"])\n",
"\n",
"# Double-check the label type again\n",
"print(processed_dataset[\"train\"][0][\"label\"], type(processed_dataset[\"train\"][0][\"label\"]))\n",
"print(processed_dataset[\"train\"][0][\"label\"].dtype) # Should print torch.int64\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Map Training Labels"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Labels: {0: 'Fake', 1: 'Real'}\n",
"Labels: {0: 'Fake', 1: 'Real'}\n"
]
}
],
"source": [
"# Human-readable label mapping; keys match the integer ids used when the\n",
"# dataset was built (fake -> 0, real -> 1).\n",
"id2label = {0: \"Fake\", 1: \"Real\"}\n",
"label2id = {name: idx for idx, name in id2label.items()}\n",
"\n",
"\n",
"print(\"Labels:\", id2label)\n",
"\n",
"# Store the mapping on the model config so saved checkpoints carry it along\n",
"model.config.id2label = id2label\n",
"model.config.label2id = label2id\n",
"\n",
"print(\"Labels:\", model.config.id2label) # Verify\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import DataCollatorWithPadding\n",
"\n",
"# Use the processor's tokenizer for padding\n",
"# NOTE(review): DataCollatorWithPadding is tokenizer-oriented; it appears to\n",
"# work here only because inputs were already padded/truncated to a fixed\n",
"# length in preprocessing — confirm batches collate for variable lengths.\n",
"data_collator = DataCollatorWithPadding(tokenizer=processor, padding=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize Training Arguments"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TrainingArguments initialized successfully!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\60165\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
" warnings.warn(\n"
]
}
],
"source": [
"from transformers import TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\n",
"    output_dir=\"./results\",\n",
"    eval_strategy=\"epoch\",  # was `evaluation_strategy` — deprecated, removed in transformers v4.46\n",
"    save_strategy=\"epoch\",  # checkpoint once per epoch, matching eval cadence\n",
"    learning_rate=5e-5,\n",
"    per_device_train_batch_size=8,\n",
"    per_device_eval_batch_size=8,\n",
"    num_train_epochs=3,\n",
"    weight_decay=0.01,\n",
"    logging_dir=\"./logs\",\n",
"    logging_steps=10,\n",
"    save_total_limit=2,  # keep only the two most recent checkpoints\n",
"    fp16=True,  # NOTE(review): requires a CUDA GPU; set False for CPU-only runs\n",
"    push_to_hub=False,\n",
")\n",
"print(\"TrainingArguments initialized successfully!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from transformers import Trainer\n",
"\n",
"# Wire model, data, and collator together; evaluation runs on the\n",
"# \"validation\" split per the epoch-level strategy in TrainingArguments.\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=processed_dataset[\"train\"],\n",
"    eval_dataset=processed_dataset[\"validation\"],\n",
"    tokenizer=processor, # Required for the data collator\n",
"    # NOTE(review): `tokenizer=` is deprecated in newer transformers in favor\n",
"    # of `processing_class=` — confirm against the pinned version.\n",
"    data_collator=data_collator,\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Training"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "53044eac53174227a4eb19becbeb0bbf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1740 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'loss': 0.6607, 'grad_norm': 2.5936832427978516, 'learning_rate': 4.971264367816092e-05, 'epoch': 0.02}\n",
"{'loss': 0.4545, 'grad_norm': 5.04218864440918, 'learning_rate': 4.9425287356321845e-05, 'epoch': 0.03}\n",
"{'loss': 0.1779, 'grad_norm': 0.8874927163124084, 'learning_rate': 4.913793103448276e-05, 'epoch': 0.05}\n",
"{'loss': 0.0833, 'grad_norm': 0.40262681245803833, 'learning_rate': 4.885057471264368e-05, 'epoch': 0.07}\n",
"{'loss': 0.0948, 'grad_norm': 0.5579108595848083, 'learning_rate': 4.85632183908046e-05, 'epoch': 0.09}\n",
"{'loss': 0.0128, 'grad_norm': 0.1585635393857956, 'learning_rate': 4.827586206896552e-05, 'epoch': 0.1}\n",
"{'loss': 0.0696, 'grad_norm': 0.12149885296821594, 'learning_rate': 4.798850574712644e-05, 'epoch': 0.12}\n",
"{'loss': 0.0065, 'grad_norm': 0.09655608981847763, 'learning_rate': 4.770114942528736e-05, 'epoch': 0.14}\n",
"{'loss': 0.0052, 'grad_norm': 0.08148041367530823, 'learning_rate': 4.741379310344828e-05, 'epoch': 0.16}\n",
"{'loss': 0.0041, 'grad_norm': 0.07030971348285675, 'learning_rate': 4.7126436781609195e-05, 'epoch': 0.17}\n",
"{'loss': 0.0034, 'grad_norm': 0.05784648284316063, 'learning_rate': 4.6839080459770116e-05, 'epoch': 0.19}\n",
"{'loss': 0.0029, 'grad_norm': 0.04742524400353432, 'learning_rate': 4.655172413793104e-05, 'epoch': 0.21}\n",
"{'loss': 0.0024, 'grad_norm': 0.04222293198108673, 'learning_rate': 4.626436781609196e-05, 'epoch': 0.22}\n",
"{'loss': 0.0021, 'grad_norm': 0.039586298167705536, 'learning_rate': 4.597701149425287e-05, 'epoch': 0.24}\n",
"{'loss': 0.0018, 'grad_norm': 0.034121569246053696, 'learning_rate': 4.5689655172413794e-05, 'epoch': 0.26}\n",
"{'loss': 0.0016, 'grad_norm': 0.031423550099134445, 'learning_rate': 4.5402298850574716e-05, 'epoch': 0.28}\n",
"{'loss': 0.0418, 'grad_norm': 0.02912888117134571, 'learning_rate': 4.511494252873563e-05, 'epoch': 0.29}\n",
"{'loss': 0.0831, 'grad_norm': 0.027829233556985855, 'learning_rate': 4.482758620689655e-05, 'epoch': 0.31}\n",
"{'loss': 0.2431, 'grad_norm': 0.10419639199972153, 'learning_rate': 4.454022988505747e-05, 'epoch': 0.33}\n",
"{'loss': 0.0024, 'grad_norm': 0.046179670840501785, 'learning_rate': 4.4252873563218394e-05, 'epoch': 0.34}\n",
"{'loss': 0.0729, 'grad_norm': 0.03943876922130585, 'learning_rate': 4.396551724137931e-05, 'epoch': 0.36}\n",
"{'loss': 0.0806, 'grad_norm': 0.05223412811756134, 'learning_rate': 4.367816091954024e-05, 'epoch': 0.38}\n",
"{'loss': 0.0023, 'grad_norm': 0.041366685181856155, 'learning_rate': 4.339080459770115e-05, 'epoch': 0.4}\n",
"{'loss': 0.0019, 'grad_norm': 0.03711611405014992, 'learning_rate': 4.3103448275862066e-05, 'epoch': 0.41}\n",
"{'loss': 0.0016, 'grad_norm': 0.030888166278600693, 'learning_rate': 4.2816091954022994e-05, 'epoch': 0.43}\n",
"{'loss': 0.0015, 'grad_norm': 0.027964089065790176, 'learning_rate': 4.252873563218391e-05, 'epoch': 0.45}\n",
"{'loss': 0.0013, 'grad_norm': 0.025312568992376328, 'learning_rate': 4.224137931034483e-05, 'epoch': 0.47}\n",
"{'loss': 0.0012, 'grad_norm': 0.02383635751903057, 'learning_rate': 4.195402298850575e-05, 'epoch': 0.48}\n",
"{'loss': 0.0011, 'grad_norm': 0.021753674373030663, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.5}\n",
"{'loss': 0.001, 'grad_norm': 0.019886162132024765, 'learning_rate': 4.1379310344827587e-05, 'epoch': 0.52}\n",
"{'loss': 0.0009, 'grad_norm': 0.018876325339078903, 'learning_rate': 4.109195402298851e-05, 'epoch': 0.53}\n",
"{'loss': 0.0008, 'grad_norm': 0.017580321058630943, 'learning_rate': 4.080459770114943e-05, 'epoch': 0.55}\n",
"{'loss': 0.0008, 'grad_norm': 0.015544956550002098, 'learning_rate': 4.0517241379310344e-05, 'epoch': 0.57}\n",
"{'loss': 0.0007, 'grad_norm': 0.015221121720969677, 'learning_rate': 4.0229885057471265e-05, 'epoch': 0.59}\n",
"{'loss': 0.0007, 'grad_norm': 0.014960471540689468, 'learning_rate': 3.9942528735632186e-05, 'epoch': 0.6}\n",
"{'loss': 0.0006, 'grad_norm': 0.013560828752815723, 'learning_rate': 3.965517241379311e-05, 'epoch': 0.62}\n",
"{'loss': 0.0006, 'grad_norm': 0.013700570911169052, 'learning_rate': 3.936781609195402e-05, 'epoch': 0.64}\n",
"{'loss': 0.0006, 'grad_norm': 0.011968374252319336, 'learning_rate': 3.908045977011495e-05, 'epoch': 0.66}\n",
"{'loss': 0.0005, 'grad_norm': 0.011644795536994934, 'learning_rate': 3.8793103448275865e-05, 'epoch': 0.67}\n",
"{'loss': 0.0005, 'grad_norm': 0.011345883831381798, 'learning_rate': 3.850574712643678e-05, 'epoch': 0.69}\n",
"{'loss': 0.0005, 'grad_norm': 0.010393058881163597, 'learning_rate': 3.82183908045977e-05, 'epoch': 0.71}\n",
"{'loss': 0.0005, 'grad_norm': 0.010386484675109386, 'learning_rate': 3.793103448275862e-05, 'epoch': 0.72}\n",
"{'loss': 0.0004, 'grad_norm': 0.009744665585458279, 'learning_rate': 3.764367816091954e-05, 'epoch': 0.74}\n",
"{'loss': 0.0004, 'grad_norm': 0.009590468369424343, 'learning_rate': 3.735632183908046e-05, 'epoch': 0.76}\n",
"{'loss': 0.0004, 'grad_norm': 0.009154150262475014, 'learning_rate': 3.7068965517241385e-05, 'epoch': 0.78}\n",
"{'loss': 0.0004, 'grad_norm': 0.008997919037938118, 'learning_rate': 3.67816091954023e-05, 'epoch': 0.79}\n",
"{'loss': 0.0004, 'grad_norm': 0.008509515784680843, 'learning_rate': 3.649425287356322e-05, 'epoch': 0.81}\n",
"{'loss': 0.0004, 'grad_norm': 0.008223678916692734, 'learning_rate': 3.620689655172414e-05, 'epoch': 0.83}\n",
"{'loss': 0.0003, 'grad_norm': 0.00758435670286417, 'learning_rate': 3.591954022988506e-05, 'epoch': 0.84}\n",
"{'loss': 0.0003, 'grad_norm': 0.0074744271114468575, 'learning_rate': 3.563218390804598e-05, 'epoch': 0.86}\n",
"{'loss': 0.0003, 'grad_norm': 0.007454875390976667, 'learning_rate': 3.53448275862069e-05, 'epoch': 0.88}\n",
"{'loss': 0.0003, 'grad_norm': 0.007157924585044384, 'learning_rate': 3.505747126436782e-05, 'epoch': 0.9}\n",
"{'loss': 0.0003, 'grad_norm': 0.006946589332073927, 'learning_rate': 3.4770114942528735e-05, 'epoch': 0.91}\n",
"{'loss': 0.0003, 'grad_norm': 0.0067284563556313515, 'learning_rate': 3.4482758620689657e-05, 'epoch': 0.93}\n",
"{'loss': 0.0003, 'grad_norm': 0.00652291439473629, 'learning_rate': 3.419540229885058e-05, 'epoch': 0.95}\n",
"{'loss': 0.0003, 'grad_norm': 0.006468599662184715, 'learning_rate': 3.390804597701149e-05, 'epoch': 0.97}\n",
"{'loss': 0.0003, 'grad_norm': 0.0061700050719082355, 'learning_rate': 3.3620689655172414e-05, 'epoch': 0.98}\n",
"{'loss': 0.0002, 'grad_norm': 0.005980886053293943, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9fe6f453d3dc4644b7e7adf483e8064b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/580 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.00017967642634175718, 'eval_runtime': 566.6055, 'eval_samples_per_second': 8.179, 'eval_steps_per_second': 1.024, 'epoch': 1.0}\n",
"{'loss': 0.0002, 'grad_norm': 0.005443067755550146, 'learning_rate': 3.3045977011494256e-05, 'epoch': 1.02}\n",
"{'loss': 0.0002, 'grad_norm': 0.005919346585869789, 'learning_rate': 3.275862068965517e-05, 'epoch': 1.03}\n",
"{'loss': 0.0002, 'grad_norm': 0.00538916140794754, 'learning_rate': 3.24712643678161e-05, 'epoch': 1.05}\n",
"{'loss': 0.0002, 'grad_norm': 0.00514333276078105, 'learning_rate': 3.218390804597701e-05, 'epoch': 1.07}\n",
"{'loss': 0.0002, 'grad_norm': 0.005011783912777901, 'learning_rate': 3.1896551724137935e-05, 'epoch': 1.09}\n",
"{'loss': 0.0002, 'grad_norm': 0.005112846381962299, 'learning_rate': 3.160919540229885e-05, 'epoch': 1.1}\n",
"{'loss': 0.0002, 'grad_norm': 0.004895139951258898, 'learning_rate': 3.132183908045977e-05, 'epoch': 1.12}\n",
"{'loss': 0.0002, 'grad_norm': 0.004565018694847822, 'learning_rate': 3.103448275862069e-05, 'epoch': 1.14}\n",
"{'loss': 0.0002, 'grad_norm': 0.00477330107241869, 'learning_rate': 3.0747126436781606e-05, 'epoch': 1.16}\n",
"{'loss': 0.0002, 'grad_norm': 0.004563583992421627, 'learning_rate': 3.045977011494253e-05, 'epoch': 1.17}\n",
"{'loss': 0.0002, 'grad_norm': 0.004568100906908512, 'learning_rate': 3.017241379310345e-05, 'epoch': 1.19}\n",
"{'loss': 0.0002, 'grad_norm': 0.0046887630596756935, 'learning_rate': 2.988505747126437e-05, 'epoch': 1.21}\n",
"{'loss': 0.0002, 'grad_norm': 0.004261981230229139, 'learning_rate': 2.9597701149425288e-05, 'epoch': 1.22}\n",
"{'loss': 0.0002, 'grad_norm': 0.004290647804737091, 'learning_rate': 2.9310344827586206e-05, 'epoch': 1.24}\n",
"{'loss': 0.0002, 'grad_norm': 0.004188802093267441, 'learning_rate': 2.9022988505747127e-05, 'epoch': 1.26}\n",
"{'loss': 0.0002, 'grad_norm': 0.003917949739843607, 'learning_rate': 2.8735632183908045e-05, 'epoch': 1.28}\n",
"{'loss': 0.0002, 'grad_norm': 0.003940439783036709, 'learning_rate': 2.844827586206897e-05, 'epoch': 1.29}\n",
"{'loss': 0.0002, 'grad_norm': 0.004128696396946907, 'learning_rate': 2.8160919540229884e-05, 'epoch': 1.31}\n",
"{'loss': 0.0002, 'grad_norm': 0.004070794675499201, 'learning_rate': 2.787356321839081e-05, 'epoch': 1.33}\n",
"{'loss': 0.0002, 'grad_norm': 0.0037206218112260103, 'learning_rate': 2.7586206896551727e-05, 'epoch': 1.34}\n",
"{'loss': 0.0001, 'grad_norm': 0.00400934973731637, 'learning_rate': 2.7298850574712648e-05, 'epoch': 1.36}\n",
"{'loss': 0.0001, 'grad_norm': 0.0037558332551270723, 'learning_rate': 2.7011494252873566e-05, 'epoch': 1.38}\n",
"{'loss': 0.0001, 'grad_norm': 0.0035916264168918133, 'learning_rate': 2.672413793103448e-05, 'epoch': 1.4}\n",
"{'loss': 0.0001, 'grad_norm': 0.003707454539835453, 'learning_rate': 2.6436781609195405e-05, 'epoch': 1.41}\n",
"{'loss': 0.0001, 'grad_norm': 0.0034801277797669172, 'learning_rate': 2.6149425287356323e-05, 'epoch': 1.43}\n",
"{'loss': 0.0001, 'grad_norm': 0.003501839004456997, 'learning_rate': 2.5862068965517244e-05, 'epoch': 1.45}\n",
"{'loss': 0.0001, 'grad_norm': 0.003458078484982252, 'learning_rate': 2.5574712643678162e-05, 'epoch': 1.47}\n",
"{'loss': 0.0001, 'grad_norm': 0.0031666585709899664, 'learning_rate': 2.5287356321839083e-05, 'epoch': 1.48}\n",
"{'loss': 0.0001, 'grad_norm': 0.0033736126497387886, 'learning_rate': 2.5e-05, 'epoch': 1.5}\n",
"{'loss': 0.0001, 'grad_norm': 0.003289664164185524, 'learning_rate': 2.4712643678160922e-05, 'epoch': 1.52}\n",
"{'loss': 0.0001, 'grad_norm': 0.0031710772309452295, 'learning_rate': 2.442528735632184e-05, 'epoch': 1.53}\n",
"{'loss': 0.0001, 'grad_norm': 0.0029777924064546824, 'learning_rate': 2.413793103448276e-05, 'epoch': 1.55}\n",
"{'loss': 0.0001, 'grad_norm': 0.003144340356811881, 'learning_rate': 2.385057471264368e-05, 'epoch': 1.57}\n",
"{'loss': 0.0001, 'grad_norm': 0.0029524925630539656, 'learning_rate': 2.3563218390804597e-05, 'epoch': 1.59}\n",
"{'loss': 0.0001, 'grad_norm': 0.002912016585469246, 'learning_rate': 2.327586206896552e-05, 'epoch': 1.6}\n",
"{'loss': 0.0001, 'grad_norm': 0.002869528019800782, 'learning_rate': 2.2988505747126437e-05, 'epoch': 1.62}\n",
"{'loss': 0.0001, 'grad_norm': 0.002966447500512004, 'learning_rate': 2.2701149425287358e-05, 'epoch': 1.64}\n",
"{'loss': 0.0001, 'grad_norm': 0.0027959959115833044, 'learning_rate': 2.2413793103448276e-05, 'epoch': 1.66}\n",
"{'loss': 0.0001, 'grad_norm': 0.0030403095297515392, 'learning_rate': 2.2126436781609197e-05, 'epoch': 1.67}\n",
"{'loss': 0.0001, 'grad_norm': 0.0026587999891489744, 'learning_rate': 2.183908045977012e-05, 'epoch': 1.69}\n",
"{'loss': 0.0001, 'grad_norm': 0.002689346671104431, 'learning_rate': 2.1551724137931033e-05, 'epoch': 1.71}\n",
"{'loss': 0.0001, 'grad_norm': 0.002710141707211733, 'learning_rate': 2.1264367816091954e-05, 'epoch': 1.72}\n",
"{'loss': 0.0001, 'grad_norm': 0.002674366347491741, 'learning_rate': 2.0977011494252875e-05, 'epoch': 1.74}\n",
"{'loss': 0.0001, 'grad_norm': 0.0026578502729535103, 'learning_rate': 2.0689655172413793e-05, 'epoch': 1.76}\n",
"{'loss': 0.0001, 'grad_norm': 0.00243232655338943, 'learning_rate': 2.0402298850574715e-05, 'epoch': 1.78}\n",
"{'loss': 0.0001, 'grad_norm': 0.0025773164816200733, 'learning_rate': 2.0114942528735632e-05, 'epoch': 1.79}\n",
"{'loss': 0.0001, 'grad_norm': 0.0024439615663141012, 'learning_rate': 1.9827586206896554e-05, 'epoch': 1.81}\n",
"{'loss': 0.0001, 'grad_norm': 0.0024733347818255424, 'learning_rate': 1.9540229885057475e-05, 'epoch': 1.83}\n",
"{'loss': 0.0001, 'grad_norm': 0.002439699834212661, 'learning_rate': 1.925287356321839e-05, 'epoch': 1.84}\n",
"{'loss': 0.0001, 'grad_norm': 0.0025980097707360983, 'learning_rate': 1.896551724137931e-05, 'epoch': 1.86}\n",
"{'loss': 0.0001, 'grad_norm': 0.002387199318036437, 'learning_rate': 1.867816091954023e-05, 'epoch': 1.88}\n",
"{'loss': 0.0001, 'grad_norm': 0.0023106117732822895, 'learning_rate': 1.839080459770115e-05, 'epoch': 1.9}\n",
"{'loss': 0.0001, 'grad_norm': 0.0023344189394265413, 'learning_rate': 1.810344827586207e-05, 'epoch': 1.91}\n",
"{'loss': 0.0001, 'grad_norm': 0.0023740960750728846, 'learning_rate': 1.781609195402299e-05, 'epoch': 1.93}\n",
"{'loss': 0.0001, 'grad_norm': 0.002346088644117117, 'learning_rate': 1.752873563218391e-05, 'epoch': 1.95}\n",
"{'loss': 0.0001, 'grad_norm': 0.002391340211033821, 'learning_rate': 1.7241379310344828e-05, 'epoch': 1.97}\n",
"{'loss': 0.0001, 'grad_norm': 0.002250733319669962, 'learning_rate': 1.6954022988505746e-05, 'epoch': 1.98}\n",
"{'loss': 0.0001, 'grad_norm': 0.002164299599826336, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed9239dd5b4346509d243cbc378bcf44",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/580 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 6.0841484810225666e-05, 'eval_runtime': 581.8939, 'eval_samples_per_second': 7.964, 'eval_steps_per_second': 0.997, 'epoch': 2.0}\n",
"{'loss': 0.0001, 'grad_norm': 0.0022026619408279657, 'learning_rate': 1.6379310344827585e-05, 'epoch': 2.02}\n",
"{'loss': 0.0001, 'grad_norm': 0.002242449903860688, 'learning_rate': 1.6091954022988507e-05, 'epoch': 2.03}\n",
"{'loss': 0.0001, 'grad_norm': 0.002464097458869219, 'learning_rate': 1.5804597701149425e-05, 'epoch': 2.05}\n",
"{'loss': 0.0001, 'grad_norm': 0.0022003022022545338, 'learning_rate': 1.5517241379310346e-05, 'epoch': 2.07}\n",
"{'loss': 0.0001, 'grad_norm': 0.0021785416174679995, 'learning_rate': 1.5229885057471265e-05, 'epoch': 2.09}\n",
"{'loss': 0.0001, 'grad_norm': 0.0021638190373778343, 'learning_rate': 1.4942528735632185e-05, 'epoch': 2.1}\n",
"{'loss': 0.0001, 'grad_norm': 0.0022439502645283937, 'learning_rate': 1.4655172413793103e-05, 'epoch': 2.12}\n",
"{'loss': 0.0001, 'grad_norm': 0.0020717435982078314, 'learning_rate': 1.4367816091954022e-05, 'epoch': 2.14}\n",
"{'loss': 0.0001, 'grad_norm': 0.0020531516056507826, 'learning_rate': 1.4080459770114942e-05, 'epoch': 2.16}\n",
"{'loss': 0.0001, 'grad_norm': 0.0019899189937859774, 'learning_rate': 1.3793103448275863e-05, 'epoch': 2.17}\n",
"{'loss': 0.0001, 'grad_norm': 0.0020303332712501287, 'learning_rate': 1.3505747126436783e-05, 'epoch': 2.19}\n",
"{'loss': 0.0001, 'grad_norm': 0.0021102086175233126, 'learning_rate': 1.3218390804597702e-05, 'epoch': 2.21}\n",
"{'loss': 0.0001, 'grad_norm': 0.0019932142458856106, 'learning_rate': 1.2931034482758622e-05, 'epoch': 2.22}\n",
"{'loss': 0.0001, 'grad_norm': 0.002080111298710108, 'learning_rate': 1.2643678160919542e-05, 'epoch': 2.24}\n",
"{'loss': 0.0001, 'grad_norm': 0.0020179597195237875, 'learning_rate': 1.2356321839080461e-05, 'epoch': 2.26}\n",
"{'loss': 0.0001, 'grad_norm': 0.0019549003336578608, 'learning_rate': 1.206896551724138e-05, 'epoch': 2.28}\n",
"{'loss': 0.0001, 'grad_norm': 0.0020865327678620815, 'learning_rate': 1.1781609195402299e-05, 'epoch': 2.29}\n",
"{'loss': 0.0001, 'grad_norm': 0.0018828624160960317, 'learning_rate': 1.1494252873563218e-05, 'epoch': 2.31}\n",
"{'loss': 0.0001, 'grad_norm': 0.0018662698566913605, 'learning_rate': 1.1206896551724138e-05, 'epoch': 2.33}\n",
"{'loss': 0.0001, 'grad_norm': 0.001857285387814045, 'learning_rate': 1.091954022988506e-05, 'epoch': 2.34}\n",
"{'loss': 0.0001, 'grad_norm': 0.001844724640250206, 'learning_rate': 1.0632183908045977e-05, 'epoch': 2.36}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017886353889480233, 'learning_rate': 1.0344827586206897e-05, 'epoch': 2.38}\n",
"{'loss': 0.0001, 'grad_norm': 0.0019668207969516516, 'learning_rate': 1.0057471264367816e-05, 'epoch': 2.4}\n",
"{'loss': 0.0001, 'grad_norm': 0.0018605877412483096, 'learning_rate': 9.770114942528738e-06, 'epoch': 2.41}\n",
"{'loss': 0.0001, 'grad_norm': 0.0018027386395260692, 'learning_rate': 9.482758620689655e-06, 'epoch': 2.43}\n",
"{'loss': 0.0001, 'grad_norm': 0.0018370413454249501, 'learning_rate': 9.195402298850575e-06, 'epoch': 2.45}\n",
"{'loss': 0.0001, 'grad_norm': 0.0019249517936259508, 'learning_rate': 8.908045977011495e-06, 'epoch': 2.47}\n",
"{'loss': 0.0001, 'grad_norm': 0.0019102703081443906, 'learning_rate': 8.620689655172414e-06, 'epoch': 2.48}\n",
"{'loss': 0.0001, 'grad_norm': 0.0019130830187350512, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.5}\n",
"{'loss': 0.0001, 'grad_norm': 0.0019449306419119239, 'learning_rate': 8.045977011494253e-06, 'epoch': 2.52}\n",
"{'loss': 0.0001, 'grad_norm': 0.001796119031496346, 'learning_rate': 7.758620689655173e-06, 'epoch': 2.53}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017440468072891235, 'learning_rate': 7.4712643678160925e-06, 'epoch': 2.55}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017786856042221189, 'learning_rate': 7.183908045977011e-06, 'epoch': 2.57}\n",
"{'loss': 0.0001, 'grad_norm': 0.0018597355810925364, 'learning_rate': 6.896551724137932e-06, 'epoch': 2.59}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017648187931627035, 'learning_rate': 6.609195402298851e-06, 'epoch': 2.6}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017601278377696872, 'learning_rate': 6.321839080459771e-06, 'epoch': 2.62}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017502185655757785, 'learning_rate': 6.03448275862069e-06, 'epoch': 2.64}\n",
"{'loss': 0.0001, 'grad_norm': 0.001753892400301993, 'learning_rate': 5.747126436781609e-06, 'epoch': 2.66}\n",
"{'loss': 0.0001, 'grad_norm': 0.0016946644755080342, 'learning_rate': 5.45977011494253e-06, 'epoch': 2.67}\n",
"{'loss': 0.0001, 'grad_norm': 0.001783599378541112, 'learning_rate': 5.172413793103448e-06, 'epoch': 2.69}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017759180627763271, 'learning_rate': 4.885057471264369e-06, 'epoch': 2.71}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017218819120898843, 'learning_rate': 4.5977011494252875e-06, 'epoch': 2.72}\n",
"{'loss': 0.0001, 'grad_norm': 0.0016811942914500833, 'learning_rate': 4.310344827586207e-06, 'epoch': 2.74}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017582618165761232, 'learning_rate': 4.022988505747127e-06, 'epoch': 2.76}\n",
"{'loss': 0.0001, 'grad_norm': 0.001848816522397101, 'learning_rate': 3.7356321839080462e-06, 'epoch': 2.78}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017523870337754488, 'learning_rate': 3.448275862068966e-06, 'epoch': 2.79}\n",
"{'loss': 0.0001, 'grad_norm': 0.001707645715214312, 'learning_rate': 3.1609195402298854e-06, 'epoch': 2.81}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017925987485796213, 'learning_rate': 2.8735632183908046e-06, 'epoch': 2.83}\n",
"{'loss': 0.0001, 'grad_norm': 0.001785592525266111, 'learning_rate': 2.586206896551724e-06, 'epoch': 2.84}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017369745764881372, 'learning_rate': 2.2988505747126437e-06, 'epoch': 2.86}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017363271908834577, 'learning_rate': 2.0114942528735633e-06, 'epoch': 2.88}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017762900097295642, 'learning_rate': 1.724137931034483e-06, 'epoch': 2.9}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017800360219553113, 'learning_rate': 1.4367816091954023e-06, 'epoch': 2.91}\n",
"{'loss': 0.0001, 'grad_norm': 0.0016894094878807664, 'learning_rate': 1.1494252873563219e-06, 'epoch': 2.93}\n",
"{'loss': 0.0001, 'grad_norm': 0.0016883889911696315, 'learning_rate': 8.620689655172415e-07, 'epoch': 2.95}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017332383431494236, 'learning_rate': 5.747126436781609e-07, 'epoch': 2.97}\n",
"{'loss': 0.0001, 'grad_norm': 0.0018206291133537889, 'learning_rate': 2.8735632183908047e-07, 'epoch': 2.98}\n",
"{'loss': 0.0001, 'grad_norm': 0.0017392894951626658, 'learning_rate': 0.0, 'epoch': 3.0}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "de595459d7b642babd74e21ce354d064",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/580 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 4.472154250834137e-05, 'eval_runtime': 541.083, 'eval_samples_per_second': 8.564, 'eval_steps_per_second': 1.072, 'epoch': 3.0}\n",
"{'train_runtime': 10729.5411, 'train_samples_per_second': 1.296, 'train_steps_per_second': 0.162, 'train_loss': 0.01232232698998762, 'epoch': 3.0}\n"
]
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=1740, training_loss=0.01232232698998762, metrics={'train_runtime': 10729.5411, 'train_samples_per_second': 1.296, 'train_steps_per_second': 0.162, 'total_flos': 2.5228134820702045e+17, 'train_loss': 0.01232232698998762, 'epoch': 3.0})"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
   "source": [
    "# Launch fine-tuning; blocks until all epochs finish and returns a\n",
    "# TrainOutput with the final training loss and runtime metrics.\n",
    "trainer.train()\n"
   ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save The Model"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the trained model and processor\n",
"trainer.save_model(\"./trained_model\") # Saves the model to the specified directory\n",
"processor.save_pretrained(\"./trained_model\") # Saves the processor as well\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ae19dee506db46f4b39d402e43d16276",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/580 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation Metrics:\n",
"eval_loss: 4.472154250834137e-05\n",
"eval_runtime: 527.4045\n",
"eval_samples_per_second: 8.786\n",
"eval_steps_per_second: 1.1\n",
"epoch: 3.0\n"
]
}
],
"source": [
"# Evaluate the model on the validation dataset\n",
"evaluation_metrics = trainer.evaluate()\n",
"\n",
"# Print the evaluation metrics\n",
"print(\"Evaluation Metrics:\")\n",
"for metric, value in evaluation_metrics.items():\n",
" print(f\"{metric}: {value}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
   "source": [
    "from transformers import AutoProcessor, AutoModelForAudioClassification\n",
    "\n",
    "# Reload the checkpoint saved by the training cells above, so inference\n",
    "# can run without re-training (standalone session entry point).\n",
    "model_path = \"./trained_model\"  # Directory written by trainer.save_model\n",
    "model = AutoModelForAudioClassification.from_pretrained(model_path)\n",
    "processor = AutoProcessor.from_pretrained(model_path)\n"
   ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single Audio Testing"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"def prepare_audio(file_path, sampling_rate=16000, duration=10):\n",
" \"\"\"\n",
" Prepares audio by loading, resampling, and returning it in manageable chunks.\n",
" \n",
" Parameters:\n",
" - file_path: Path to the audio file.\n",
" - sampling_rate: Target sampling rate for the audio.\n",
" - duration: Duration in seconds for each chunk.\n",
" \n",
" Returns:\n",
" - A list of audio chunks, each as a numpy array.\n",
" \"\"\"\n",
" # Load and resample the audio file\n",
" waveform, original_sampling_rate = torchaudio.load(file_path)\n",
" \n",
" # Convert stereo to mono if necessary\n",
" if waveform.shape[0] > 1: # More than 1 channel\n",
" waveform = torch.mean(waveform, dim=0, keepdim=True)\n",
" \n",
" # Resample if needed\n",
" if original_sampling_rate != sampling_rate:\n",
" resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=sampling_rate)\n",
" waveform = resampler(waveform)\n",
" \n",
" # Calculate chunk size in samples\n",
" chunk_size = sampling_rate * duration\n",
" audio_chunks = []\n",
"\n",
" # Split the audio into chunks\n",
" for start in range(0, waveform.shape[1], chunk_size):\n",
" chunk = waveform[:, start:start + chunk_size]\n",
" \n",
" # Pad the last chunk if it's shorter than the chunk size\n",
" if chunk.shape[1] < chunk_size:\n",
" padding = chunk_size - chunk.shape[1]\n",
" chunk = torch.nn.functional.pad(chunk, (0, padding))\n",
" \n",
" audio_chunks.append(chunk.squeeze().numpy())\n",
" \n",
" return audio_chunks\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[ 4.6742, -5.1778]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 2: tensor([[ 4.7219, -5.2332]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 3: tensor([[ 4.7545, -5.2641]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 4: tensor([[ 4.6714, -5.1740]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 5: tensor([[ 4.7660, -5.2743]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 6: tensor([[ 4.7724, -5.2836]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 7: tensor([[ 4.7268, -5.2362]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 8: tensor([[ 4.6898, -5.1898]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 9: tensor([[ 4.6646, -5.1708]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 10: tensor([[ 4.5948, -5.0867]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 11: tensor([[ 4.7512, -5.2579]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 12: tensor([[-4.5599, 5.0363]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 13: tensor([[-0.4980, 0.5546]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 14: tensor([[ 4.7295, -5.2358]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 15: tensor([[ 4.7426, -5.2534]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 16: tensor([[ 1.9405, -2.1493]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 17: tensor([[ 4.7168, -5.2235]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 18: tensor([[ 4.6801, -5.1907]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 19: tensor([[ 4.7454, -5.2568]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 20: tensor([[ 4.7642, -5.2723]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 21: tensor([[ 4.7868, -5.2969]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 22: tensor([[ 4.7600, -5.2690]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 23: tensor([[ 4.7337, -5.2411]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 24: tensor([[ 4.7835, -5.2943]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 25: tensor([[ 4.7572, -5.2647]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 26: tensor([[ 4.7485, -5.2581]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 27: tensor([[ 4.6874, -5.2023]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 28: tensor([[ 4.6877, -5.1922]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 29: tensor([[ 4.7474, -5.2561]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 30: tensor([[-4.3064, 4.7629]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 31: tensor([[-3.8067, 4.2312]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 32: tensor([[ 4.7217, -5.2325]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 33: tensor([[ 4.7798, -5.2913]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 34: tensor([[ 4.7214, -5.2355]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 35: tensor([[ 4.7116, -5.2192]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 36: tensor([[ 4.6687, -5.1812]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 37: tensor([[-0.8128, 0.9402]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 38: tensor([[ 4.7259, -5.2333]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 39: tensor([[ 4.5698, -5.0731]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 40: tensor([[ 4.7467, -5.2544]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 41: tensor([[ 4.7781, -5.2884]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 42: tensor([[ 4.7243, -5.2365]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 43: tensor([[ 3.9325, -4.3570]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 44: tensor([[-3.8786, 4.3105]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 45: tensor([[ 3.3633, -3.6958]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 46: tensor([[ 4.7127, -5.2213]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 47: tensor([[ 0.0519, -0.0359]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 48: tensor([[ 4.7457, -5.2535]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 49: tensor([[ 3.4856, -3.8528]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 50: tensor([[ 4.6485, -5.1538]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 51: tensor([[ 4.6274, -5.1355]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 52: tensor([[ 4.6852, -5.1872]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 53: tensor([[ 4.7341, -5.2452]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 54: tensor([[-4.5378, 5.0152]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 55: tensor([[ 4.6822, -5.1887]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 56: tensor([[ 4.7186, -5.2252]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 57: tensor([[ 4.7688, -5.2787]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 58: tensor([[ 4.7285, -5.2342]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 59: tensor([[ 4.7447, -5.2550]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 60: tensor([[ 4.5292, -5.0253]])\n",
"Predicted Class: Fake\n"
]
}
],
"source": [
"def predict_audio(file_path):\n",
" \"\"\"\n",
" Predicts the class of an audio file by aggregating predictions from chunks.\n",
" \n",
" Args:\n",
" file_path (str): Path to the audio file.\n",
"\n",
" Returns:\n",
" str: Predicted class label.\n",
" \"\"\"\n",
" # Prepare audio chunks\n",
" audio_chunks = prepare_audio(file_path)\n",
" predictions = []\n",
"\n",
" for i, chunk in enumerate(audio_chunks):\n",
" # Prepare input for the model\n",
" print(f\"Chunk shape: {chunk.shape}\")\n",
" inputs = processor(\n",
" chunk, sampling_rate=16000, return_tensors=\"pt\", padding=True\n",
" )\n",
" \n",
" # Perform inference\n",
" with torch.no_grad():\n",
" outputs = model(**inputs)\n",
" logits = outputs.logits\n",
" print(f\"Logits for chunk {i + 1}: {logits}\") # Print the logits\n",
" predicted_class = torch.argmax(logits, dim=1).item()\n",
" predictions.append(predicted_class)\n",
" \n",
" # Aggregate predictions (e.g., majority voting)\n",
" aggregated_prediction = max(set(predictions), key=predictions.count)\n",
" \n",
" # Convert class ID to label\n",
" return model.config.id2label[aggregated_prediction]\n",
"\n",
"# Example: Test a single audio file\n",
"file_path = r\"D:\\Year 3 Sem 2\\Godamlah\\Deepfake\\deepfake model ver3\\data\\KAGGLE\\AUDIO\\FAKE\\biden-to-linus.wav\" # Replace with your audio file path\n",
"predicted_class = predict_audio(file_path)\n",
"print(f\"Predicted Class: {predicted_class}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Batch Testing"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[-3.3933, 3.7590]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[-3.3933, 3.7590]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[-1.5531, 1.7190]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[-1.5917, 1.7620]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[ 4.7569, -5.2631]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[ 4.7569, -5.2630]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[-4.5033, 4.9768]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[-4.5029, 4.9765]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[ 4.7639, -5.2653]])\n",
"Chunk shape: (160000,)\n",
"Logits for chunk 1: tensor([[ 4.7639, -5.2653]])\n",
"{'file': 'human voice 1 to mr beast.mp3', 'predicted_class': 'Real'}\n",
"{'file': 'human voice 1 to mr beast.wav', 'predicted_class': 'Real'}\n",
"{'file': 'human voice 1.mp3', 'predicted_class': 'Real'}\n",
"{'file': 'human voice 1.wav', 'predicted_class': 'Real'}\n",
"{'file': 'human voice 2 to Jett.mp3', 'predicted_class': 'Fake'}\n",
"{'file': 'human voice 2 to Jett.wav', 'predicted_class': 'Fake'}\n",
"{'file': 'human voice 2.mp3', 'predicted_class': 'Real'}\n",
"{'file': 'human voice 2.wav', 'predicted_class': 'Real'}\n",
"{'file': 'text to audio jett.mp3', 'predicted_class': 'Fake'}\n",
"{'file': 'text to audio jett.wav', 'predicted_class': 'Fake'}\n"
]
}
],
"source": [
"import os\n",
"\n",
"def batch_predict(test_folder, limit=10):\n",
" \"\"\"\n",
" Batch processes audio files for predictions.\n",
"\n",
" Args:\n",
" test_folder (str): Path to the folder containing audio files.\n",
" limit (int): Maximum number of files to process. Set to None for all files.\n",
"\n",
" Returns:\n",
" list: A list of dictionaries containing file names and predicted classes.\n",
" \"\"\"\n",
" results = []\n",
" files = os.listdir(test_folder)\n",
"\n",
" # Limit the number of files processed if a limit is provided\n",
" if limit is not None:\n",
" files = files[:limit]\n",
"\n",
" # Process each file in the folder\n",
" for file_name in files:\n",
" file_path = os.path.join(test_folder, file_name)\n",
" try:\n",
" predicted_class = predict_audio(file_path) # Use the predict_audio function\n",
" results.append({\"file\": file_name, \"predicted_class\": predicted_class})\n",
" except Exception as e:\n",
" print(f\"Error processing {file_name}: {e}\")\n",
" \n",
" return results\n",
"\n",
"# Specify the folder path and limit\n",
"test_folder = r\"D:\\Year 3 Sem 2\\Godamlah\\Deepfake\\deepfake model ver3\\data\\real life test audio\" # Replace with your test folder path\n",
"results = batch_predict(test_folder, limit=10)\n",
"\n",
"# Print results\n",
"for result in results:\n",
" print(result)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "479899631f95453e9f82355f7511cff3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/580 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 4.472154250834137e-05, 'eval_model_preparation_time': 0.0032, 'eval_accuracy': 1.0, 'eval_runtime': 1168.3696, 'eval_samples_per_second': 3.966, 'eval_steps_per_second': 0.496}\n"
]
}
],
   "source": [
    "import evaluate\n",
    "\n",
    "# Load the accuracy metric from the Hugging Face `evaluate` hub\n",
    "accuracy_metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Return {'accuracy': ...} for a (logits, labels) evaluation pair.\"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    predictions = logits.argmax(axis=-1)  # Predicted class id per example\n",
    "    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)\n",
    "    return {\"accuracy\": accuracy[\"accuracy\"]}\n",
    "\n",
    "from transformers import Trainer\n",
    "\n",
    "# NOTE(review): this cell rebuilds the Trainer so compute_metrics is attached;\n",
    "# it relies on training_args, processed_dataset and data_collator defined by\n",
    "# earlier cells -- on a fresh kernel, run those cells first.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=processed_dataset[\"train\"],\n",
    "    eval_dataset=processed_dataset[\"validation\"],\n",
    "    tokenizer=processor,  # Required for padding\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,  # Adds eval_accuracy to evaluate() output\n",
    ")\n",
    "\n",
    "# Evaluate the model (now reports accuracy alongside loss)\n",
    "metrics = trainer.evaluate()\n",
    "print(metrics)\n"
   ]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}