duongve commited on
Commit
04040e2
·
verified ·
1 Parent(s): 59de709

Upload 4 files

Browse files
Export_to_ONNX_TFLite/Export_model_and_testmodel.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"cells":[{"cell_type":"markdown","source":["#PT to ONNX"],"metadata":{"id":"qZcPxnj6n6kl"}},{"cell_type":"code","execution_count":1,"metadata":{"id":"mRy88TxHd7KE","collapsed":true,"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1747040415700,"user_tz":-420,"elapsed":1754,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"983ce468-07ee-45ac-9d66-cab24e71691c"},"outputs":[{"output_type":"stream","name":"stdout","text":["Cloning into 'EfficientAT'...\n","remote: Enumerating objects: 395, done.\u001b[K\n","remote: Counting objects: 100% (152/152), done.\u001b[K\n","remote: Compressing objects: 100% (37/37), done.\u001b[K\n","remote: Total 395 (delta 134), reused 115 (delta 115), pack-reused 243 (from 2)\u001b[K\n","Receiving objects: 100% (395/395), 2.64 MiB | 7.12 MiB/s, done.\n","Resolving deltas: 100% (241/241), done.\n"]}],"source":["!git clone https://github.com/fschmid56/EfficientAT.git"]},{"cell_type":"code","execution_count":2,"metadata":{"id":"_SG1Qcgz4zsy","colab":{"base_uri":"https://localhost:8080/"},"outputId":"9d9a9b8c-cd15-4a32-d9cb-b10c91ce4628","executionInfo":{"status":"ok","timestamp":1747040417249,"user_tz":-420,"elapsed":59,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["/content/EfficientAT\n"]}],"source":["%cd EfficientAT"]},{"cell_type":"code","source":["!pip install -r requirements.txt"],"metadata":{"id":"0yBzntF5sRtg","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1747040438423,"user_tz":-420,"elapsed":19609,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"fa2adbab-4d7a-4cb0-9cfb-9a10e45d2310"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting av==10.0.0 (from -r requirements.txt (line 1))\n"," Downloading av-10.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.5 kB)\n","Collecting h5py==3.7.0 (from -r requirements.txt (line 2))\n"," Downloading h5py-3.7.0.tar.gz (392 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m392.4/392.4 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n"," Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n"," Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n","Collecting librosa==0.9.2 (from -r requirements.txt (line 3))\n"," Downloading librosa-0.9.2-py3-none-any.whl.metadata (8.2 kB)\n","Collecting numpy==1.23.3 (from -r requirements.txt (line 4))\n"," Downloading numpy-1.23.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)\n","Collecting scikit_learn==1.1.3 (from -r requirements.txt (line 5))\n"," Downloading scikit_learn-1.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)\n","Collecting torch==1.13.0 (from -r requirements.txt (line 6))\n"," Downloading torch-1.13.0-cp311-cp311-manylinux1_x86_64.whl.metadata (24 kB)\n","\u001b[31mERROR: Ignored the following yanked versions: 2.0.0\u001b[0m\u001b[31m\n","\u001b[0m\u001b[31mERROR: Ignored the following versions that require a different python version: 1.21.2 Requires-Python >=3.7,<3.11; 1.21.3 Requires-Python >=3.7,<3.11; 1.21.4 Requires-Python >=3.7,<3.11; 1.21.5 Requires-Python >=3.7,<3.11; 1.21.6 Requires-Python >=3.7,<3.11\u001b[0m\u001b[31m\n","\u001b[0m\u001b[31mERROR: Could not find a version that satisfies the requirement torchaudio==0.13.0 (from versions: 2.0.1, 2.0.2, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0, 2.7.0)\u001b[0m\u001b[31m\n","\u001b[0m\u001b[31mERROR: No matching distribution found for torchaudio==0.13.0\u001b[0m\u001b[31m\n","\u001b[0m"]}]},{"cell_type":"code","execution_count":null,"metadata":{"id":"szVfWuNA47pi"},"outputs":[],"source":["!python inference.py --model_name=/content/EfficientAT/resources/dymn04_im04_emotion.pt --audio_path=\"/content/EfficientAT/resources/bea_Angry_anger_1-28_0009.mp3\""]},{"cell_type":"code","execution_count":null,"metadata":{"id":"01JmuQGb64rW","colab":{"base_uri":"https://localhost:8080/"},"outputId":"09ccc631-5fb5-4226-c523-22fe623d4373"},"outputs":[{"output_type":"stream","name":"stdout","text":["python3: can't open file '/content/EfficientAT/inference_numpy.py': [Errno 2] No such file or directory\n"]}],"source":["!python inference_numpy.py --model_name=dymn04_as --audio_path=\"/content/EfficientAT/resources/recording_1741088005371.wav\""]},{"cell_type":"markdown","source":["#Export to onnx"],"metadata":{"id":"1r5BgfkNA5rc"}},{"cell_type":"code","execution_count":4,"metadata":{"id":"e4GFRBPU7j-b","executionInfo":{"status":"ok","timestamp":1747040464648,"user_tz":-420,"elapsed":12639,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[],"source":["import torch\n","from models.dymn.model import get_model as get_dymn\n","from helpers.utils import NAME_TO_WIDTH, labels\n","\n","# model = get_dymn(width_mult=NAME_TO_WIDTH('dymn04_as'), pretrained_name='dymn04_as', strides=[2, 2, 2, 2])\n","# model.to(torch.device('cpu'))\n","model = torch.load('/content/dymn04_im04_acc_0.7408.pt', map_location='cpu', weights_only=False)"]},{"cell_type":"code","source":["!pip install onnx onnxruntime"],"metadata":{"id":"01D0XYhkPfCL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1747040473170,"user_tz":-420,"elapsed":8523,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"fdf1845d-6c4c-4171-f12f-d55505bff4bd"},"execution_count":5,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting onnx\n"," Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)\n","Collecting onnxruntime\n"," Downloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)\n","Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.11/dist-packages (from onnx) (2.0.2)\n","Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.11/dist-packages (from onnx) (5.29.4)\n","Collecting coloredlogs (from onnxruntime)\n"," Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)\n","Requirement already satisfied: flatbuffers in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (25.2.10)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (24.2)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (1.13.1)\n","Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)\n"," Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)\n","Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->onnxruntime) (1.3.0)\n","Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.0 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.0/16.0 MB\u001b[0m \u001b[31m67.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.4 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.4/16.4 MB\u001b[0m \u001b[31m70.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: onnx, humanfriendly, coloredlogs, onnxruntime\n","Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnx-1.17.0 onnxruntime-1.22.0\n"]}]},{"cell_type":"code","execution_count":6,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"z_m2eIxQ3mnx","outputId":"ddc7d438-9d4c-4121-fe23-1e3f3b71a60d","executionInfo":{"status":"ok","timestamp":1747040487470,"user_tz":-420,"elapsed":9851,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["/content/EfficientAT/models/dymn/model.py:188: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n"," if return_fmaps:\n","/content/EfficientAT/models/dymn/dy_block.py:173: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n"," assert x.shape[1] == self.channels\n"]}],"source":["dummy_input = torch.randn(1, 1, 128, 400, device=\"cpu\")\n","\n","# Define dynamic axes for the input\n","# Here, we specify that the first dimension (batch size) is dynamic\n","dynamic_axes = {\n"," \"input\": {0: \"batch_size\", 3: \"audio_length\"}, # Input shape: (B, C, H, W)\n","}\n","\n","# Define input and output names\n","input_names = [\"input\"]\n","output_names = [\"output\"]\n","\n","# Export the model\n","torch.onnx.export(\n"," model,\n"," dummy_input,\n"," \"emotion.onnx\",\n"," opset_version=11\n",")"]},{"cell_type":"code","execution_count":10,"metadata":{"id":"mDyu47stDSbg","colab":{"base_uri":"https://localhost:8080/"},"outputId":"2c006904-0ebf-4cbc-c3f5-c6a3cc234076","executionInfo":{"status":"ok","timestamp":1747040650463,"user_tz":-420,"elapsed":19978,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["x shape before (1, 121728)\n","x shape after (1, 121727)\n","spec_x before is (1, 381, 513, 2)\n","spec_x shape is (1, 381, 513)\n","(1, 1, 128, 400)\n","[array([[ 3.432613 , -0.5243076, -0.6463419, -0.5853195, -0.5354723,\n"," -0.5868242, -0.6281568]], dtype=float32)]\n"]}],"source":["import onnxruntime as ort\n","import numpy as np\n","import librosa\n","# from preprocess import MelSTFT\n","\n","# Load the ONNX model\n","session = ort.InferenceSession(\"emotion.onnx\")\n","\n","# Get input and output details\n","input_name = session.get_inputs()[0].name\n","output_name = session.get_outputs()[0].name\n","\n","# preprocess input\n","mel = MelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)\n","(waveform, _) = librosa.core.load('/content/03-01-05-01-01-02-05_0.mp3', sr=32000, mono=True)\n","waveform = np.stack([waveform])\n","spec = mel(waveform)\n","if spec.shape[-1] > 400:\n"," spec = spec[:, :, :400]\n","else:\n"," spec = np.pad(spec, ((0, 0), (0, 0), (0, 400 - spec.shape[-1])), mode='constant')\n","spec = np.expand_dims(spec, axis=0)\n","spec = spec.astype(np.float32)\n","print(spec.shape)\n","\n","# Run inference\n","output = session.run([output_name], {input_name: spec})\n","\n","# Print the output\n","print(output)"]},{"cell_type":"code","execution_count":12,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"1GnlQkWu5_Fc","outputId":"a13d8c8b-896a-431c-a9bf-790226dea5b5","executionInfo":{"status":"ok","timestamp":1747040692360,"user_tz":-420,"elapsed":12,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["0\n"]}],"source":["len(output[0][0])\n","print(np.argmax(output[0][0]))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Gf__SsN55_yY"},"outputs":[],"source":["import numpy as np\n","\n","def sigmoid(x):\n"," return 1 / (1 + np.exp(-x))\n","\n","def softmax(x):\n"," exp_x = np.exp(x - np.max(x)) # Subtract max(x) for numerical stability\n"," return exp_x / np.sum(exp_x, axis=-1, keepdims=True)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"RVPMJbnc6roX"},"outputs":[],"source":["preds = softmax(output[0][0])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"OfR1epuQ6xPR"},"outputs":[],"source":["sorted_indexes = np.argsort(preds)[::-1]"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"gCEf2tdJ6x6Q","outputId":"1133b892-d5dd-488e-c987-fc82bcba2812"},"outputs":[{"output_type":"stream","name":"stdout","text":["************* Acoustic Event Detected: *****************\n","Angry: 0.914\n","sad: 0.015\n","Fear: 0.015\n","neutral: 0.014\n","Disgust: 0.014\n","Happy: 0.014\n","surprise: 0.014\n","********************************************************\n"]}],"source":["# Print audio tagging top probabilities\n","print(\"************* Acoustic Event Detected: *****************\")\n","for k in range(7):\n"," print('{}: {:.3f}'.format(labels[sorted_indexes[k]], preds[sorted_indexes[k]]))\n","print(\"********************************************************\")"]},{"cell_type":"markdown","source":["#Load sang onnx fix typepo"],"metadata":{"id":"zBJk6xG7A1e4"}},{"cell_type":"code","execution_count":2,"metadata":{"id":"uHKxnQ-w5ng4","executionInfo":{"status":"ok","timestamp":1747045236944,"user_tz":-420,"elapsed":41,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[],"source":["import onnx\n","\n","# Load ONNX model\n","onnx_model = onnx.load(\"/content/emotion.onnx\")"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"u1vlWtAF6Le9","outputId":"1a6c6507-7703-4732-9ac8-d4cfc487f813","executionInfo":{"status":"ok","timestamp":1747045239246,"user_tz":-420,"elapsed":129,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Model is valid.\n"]}],"source":["# Rename inputs\n","for i, input_node in enumerate(onnx_model.graph.input):\n"," old_name = input_node.name\n"," new_name = f\"input_{i}\"\n"," input_node.name = new_name\n"," for node in onnx_model.graph.node:\n"," for j, input_name in enumerate(node.input):\n"," if input_name == old_name:\n"," node.input[j] = new_name\n","onnx.checker.check_model(onnx_model)\n","print(\"Model is valid.\")"]},{"cell_type":"code","execution_count":4,"metadata":{"id":"RyYwyJq2LbOz","executionInfo":{"status":"ok","timestamp":1747045251036,"user_tz":-420,"elapsed":88,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[],"source":["onnx.save(onnx_model, \"fixed_emotion.onnx\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"8_7au43LMHns"},"outputs":[],"source":["import onnxruntime as ort\n","import numpy as np\n","import librosa\n","# from preprocess import MelSTFT\n","\n","# Load the ONNX model\n","session = ort.InferenceSession(\"fixed_dymn04_as_4s.onnx\")\n","\n","# Get input and output details\n","input_name = session.get_inputs()[0].name\n","output_name = session.get_outputs()[0].name\n","\n","# preprocess input\n","mel = MelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)\n","(waveform, _) = librosa.core.load('/content/EfficientAT/resources/machine_gun_4_8.wav', sr=32000, mono=True)\n","waveform = np.stack([waveform])\n","spec = mel(waveform)\n","spec = np.expand_dims(spec, axis=0)\n","spec = spec.astype(np.float32)\n","\n","# Run inference\n","output = session.run([output_name], {input_name: spec})"]},{"cell_type":"code","source":["preds = sigmoid(output[0][0])\n","sorted_indexes = np.argsort(preds)[::-1]\n","# Print audio tagging top probabilities\n","print(\"************* Acoustic Event Detected: *****************\")\n","for k in range(10):\n"," print('{}: {:.3f}'.format(labels[sorted_indexes[k]], preds[sorted_indexes[k]]))\n","print(\"********************************************************\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"yI4BVCzZu043","outputId":"2e5ba910-2a6e-42b5-ed89-72c6c6fe5d50"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["************* Acoustic Event Detected: *****************\n","Gunshot, gunfire: 0.646\n","Artillery fire: 0.574\n","Machine gun: 0.296\n","Speech: 0.276\n","Music: 0.058\n","Ding: 0.039\n","Outside, rural or natural: 0.029\n","Clang: 0.021\n","Outside, urban or manmade: 0.014\n","Sound effect: 0.013\n","********************************************************\n"]}]},{"cell_type":"code","source":["import numpy as np\n","import tensorflow as tf\n","\n","# Load the TFLite model\n","interpreter = tf.lite.Interpreter(model_path=\"/content/EfficientAT/fixed_dymn04_as_4s.tflite\")\n","interpreter.allocate_tensors()\n","\n","# Get input and output details\n","input_details = interpreter.get_input_details()\n","output_details = interpreter.get_output_details()"],"metadata":{"id":"tYAA2SXg4rAm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%%time\n","interpreter.set_tensor(input_details[0]['index'], spec)\n","\n","# Run inference\n","interpreter.invoke()\n","output_data = interpreter.get_tensor(output_details[0]['index'])\n","print(output_data)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"lCURHCdb50NM","outputId":"b90fd17d-6435-47cd-bce4-d504667ba3a5"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[[ -0.84756243 -6.042206 -7.31538 -6.587422 -8.137321\n"," -6.8817334 -10.2073965 -7.0277896 -9.497076 -9.076806\n"," -7.9717975 -9.098734 -9.190367 -10.533342 -9.261131\n"," -9.558402 -8.0612755 -9.170078 -9.971562 -9.306829\n"," -10.51197 -9.227371 -8.791254 -7.1896567 -6.929867\n"," -6.8730307 -8.732161 -6.8715105 -7.095218 -9.634821\n"," -7.3266606 -6.496367 -6.867931 -7.2777967 -7.84522\n"," -9.982277 -9.040248 -9.472129 -6.970452 -6.7633314\n"," -8.596198 -8.716958 -11.37014 -9.757639 -9.384382\n"," -11.648908 -9.068704 -9.676451 -9.729941 -10.329362\n"," -11.040093 -7.60864 -9.644373 -7.8703647 -10.20223\n"," -11.044352 -11.488615 -11.821376 -8.239668 -9.515796\n"," -10.249389 -11.395409 -12.753719 -9.719765 -7.815147\n"," -8.667467 -8.783001 -10.783493 -9.077807 -6.502072\n"," -6.641275 -9.622646 -5.2108536 -6.345402 -7.260418\n"," -9.74338 -9.060996 -9.3590765 -9.125456 -9.142018\n"," -8.84309 -7.5067797 -10.2783375 -7.242247 -8.52377\n"," -9.487844 -7.1616087 -8.470291 -7.9428587 -10.482567\n"," -7.625914 -8.034529 -6.576645 -8.566756 -7.922304\n"," -7.995202 -6.649238 -6.5045075 -6.472725 -8.503324\n"," -9.17288 -9.43606 -8.100371 -8.338065 -6.9929743\n"," -6.991773 -7.0988426 -6.6684513 -7.944296 -8.298891\n"," -7.6659217 -6.260497 -7.467764 -7.7779393 -12.417544\n"," -8.141451 -8.614699 -8.081804 -8.368467 -9.225835\n"," -11.658823 -11.246972 -8.921834 -10.21366 -10.768544\n"," -11.065884 -8.270028 -11.009446 -9.79741 -8.229642\n"," -9.637291 -8.985344 -8.959486 -10.116044 -9.153588\n"," -8.426435 -10.440107 -2.8482046 -5.147984 -7.8000393\n"," -7.3953137 -8.8387575 -8.483301 -8.403142 -9.870607\n"," -10.543358 -8.418191 -11.51122 -11.29439 -11.705905\n"," -9.793888 -10.239387 -7.895875 -8.369129 -9.151202\n"," -8.415587 -9.385785 -8.787096 -8.60712 -9.168413\n"," -10.053816 -6.7564883 -6.962956 -8.736223 -6.6655746\n"," -7.655622 -7.969426 -8.463739 -7.021253 -9.278132\n"," -11.818462 -8.614843 -9.003474 -8.835054 -9.851002\n"," -11.832719 -12.2729225 -9.901597 -7.088213 -6.365678\n"," -5.9258265 -6.109941 -9.413572 -10.997083 -8.5274315\n"," -8.2669525 -9.395208 -9.664686 -9.7653475 -8.847117\n"," -10.921641 -8.482299 -8.912869 -9.325035 -9.669159\n"," -8.549017 -9.760146 -10.429775 -10.085978 -10.271475\n"," -6.654699 -8.879931 -7.6676474 -6.643718 -9.249097\n"," -7.136822 -7.749438 -10.257306 -10.486742 -11.123094\n"," -11.035789 -9.469211 -12.157661 -10.533059 -8.657654\n"," -9.568329 -8.964076 -8.243627 -7.680642 -8.45268\n"," -9.282619 -9.683128 -10.821175 -9.947752 -8.673843\n"," -10.237195 -8.564869 -8.619034 -9.909412 -9.446312\n"," -10.431713 -10.449768 -8.876116 -9.539524 -10.091715\n"," -9.048174 -8.702717 -8.101336 -8.736906 -6.4872026\n"," -8.30099 -7.705913 -8.126543 -9.554791 -7.7878036\n"," -9.843629 -8.260341 -9.702665 -9.453328 -11.923137\n"," -10.03031 -8.647557 -9.144465 -8.42806 -6.997603\n"," -9.046401 -9.894837 -10.35937 -8.140752 -8.345273\n"," -9.8691225 -10.458608 -10.322532 -10.374186 -9.572269\n"," -8.6276865 -8.302613 -8.308632 -9.004797 -9.458067\n"," -7.9622707 -6.8402877 -9.689645 -8.055686 -8.758229\n"," -9.216762 -9.273148 -9.217585 -8.875041 -8.127402\n"," -8.534075 -11.847649 -10.116147 -6.0094028 -8.0422325\n"," -5.8051434 -8.198613 -7.914464 -7.742989 -8.479749\n"," -9.025269 -8.967785 -8.4304 -8.63103 -7.5055\n"," -7.8533816 -9.185931 -7.454551 -8.6782465 -7.8421903\n"," -4.6939163 -5.254039 -6.968486 -7.659274 -6.2444606\n"," -7.420964 -8.172086 -6.747325 -8.602886 -10.759201\n"," -9.839126 -10.491294 -9.961649 -10.368932 -8.411766\n"," -8.598401 -8.4753475 -9.222321 -10.77521 -10.591006\n"," -11.217412 -8.678829 -9.774086 -9.107769 -10.6897135\n"," -10.674824 -8.642229 -10.717334 -7.3090887 -6.9824114\n"," -10.106813 -9.793407 -7.4900055 -9.854651 -8.700833\n"," -8.348331 -10.050314 -9.854686 -9.308167 -8.760602\n"," -8.947077 -8.997547 -9.107168 -7.7415004 -10.546346\n"," -12.421447 -10.762346 -9.853303 -8.710389 -9.031233\n"," -11.481606 -9.634319 -8.992513 -7.9649596 -7.861982\n"," -8.120212 -8.570669 -8.901415 -7.302056 -7.901267\n"," -8.765028 -10.421144 -13.3499775 -11.8135605 -8.224922\n"," -9.03608 -11.085783 -9.946224 -9.469637 -11.543834\n"," -9.681482 -10.195737 -9.514964 -12.860728 -11.368202\n"," -14.56341 -12.711758 -11.350553 -13.18879 -9.407719\n"," -9.764156 -10.671074 -11.646937 -12.571701 -9.258234\n"," -10.133987 -9.215451 -11.699546 -9.683449 -8.102212\n"," -8.280151 -7.463168 -10.443597 -9.931871 -9.029537\n"," -10.809522 -9.242908 -11.067041 -8.609023 -9.37059\n"," -9.753563 -9.2785225 -8.860644 -9.524628 -8.6628685\n"," -10.255548 -9.580282 -8.809338 -8.740526 -9.584197\n"," -11.52479 -11.302017 -11.242475 -10.918733 -9.713407\n"," -9.895052 -10.511606 -9.031795 -9.218709 -8.112253\n"," -10.801938 -10.991635 -10.003891 -13.910624 -10.056546\n"," -11.168777 -5.049853 0.36061487 -0.94348156 -4.829131\n"," -0.08386183 -6.3951907 -4.9066935 -6.694729 -5.321249\n"," -6.964137 -5.1618257 -10.232554 -9.469683 -13.131125\n"," -8.580127 -8.248132 -6.7724633 -10.05866 -9.264765\n"," -9.219953 -8.475747 -11.707945 -9.846174 -11.320035\n"," -8.705415 -10.847191 -10.340349 -10.036801 -9.806983\n"," -10.147164 -11.190186 -9.811713 -9.941156 -7.9153924\n"," -6.1716485 -8.66741 -8.987112 -8.954214 -10.06595\n"," -8.577883 -7.141983 -8.485489 -6.889525 -7.2201576\n"," -7.8431225 -10.0173435 -10.403385 -7.2444086 -11.104176\n"," -8.219688 -10.097628 -8.695146 -12.666377 -11.033119\n"," -11.94974 -8.298492 -7.05132 -3.5644822 -4.064981\n"," -11.08213 -12.455801 -7.176717 -10.135718 -7.242347\n"," -10.261881 -9.935315 -8.438743 -8.370187 -10.179698\n"," -7.336482 -8.550757 -12.201432 -9.485798 -11.314145\n"," -7.3142076 -9.465659 -7.8540864 -10.189713 -4.3626466\n"," -12.027302 -6.617275 -5.9098096 -6.6777973 -4.2796392\n"," -3.5423725 -6.6892195 -6.3318806 -11.957001 -9.272463\n"," -9.623184 -11.435066 -9.314948 -10.678869 -9.70005\n"," -10.211389 -10.757169 -8.058382 -9.976052 -8.796942\n"," -10.181307 -9.936882 ]]\n","CPU times: user 62.3 ms, sys: 6.82 ms, total: 69.1 ms\n","Wall time: 74.6 ms\n"]}]},{"cell_type":"code","source":["output_data.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5JbesH4C6E3F","outputId":"56307e6b-e58e-4847-b91a-55a40252e188"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(1, 527)"]},"metadata":{},"execution_count":45}]},{"cell_type":"code","source":["sorted_indexes = np.argsort(preds)[::-1]"],"metadata":{"id":"OAd_yDEN6bK7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["sorted_indexes.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"RGlx7Tir6iQj","outputId":"dc0dd89b-09fd-4b99-f7f3-714e679eac88"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(527,)"]},"metadata":{},"execution_count":47}]},{"cell_type":"code","source":["preds = sigmoid(output_data[0])\n","sorted_indexes = np.argsort(preds)[::-1]\n","# Print audio tagging top probabilities\n","print(\"************* Acoustic Event Detected: *****************\")\n","for k in range(10):\n"," print('{}: {:.3f}'.format(labels[sorted_indexes[k]], preds[sorted_indexes[k]]))\n","print(\"********************************************************\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"PjcLYfAk6Gwj","outputId":"1d42c015-3e73-4f1c-bc83-e03c41175ff7"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["************* Acoustic Event Detected: *****************\n","Gunshot, gunfire: 0.589\n","Artillery fire: 0.479\n","Speech: 0.300\n","Machine gun: 0.280\n","Music: 0.055\n","Outside, rural or natural: 0.028\n","Ding: 0.028\n","Clang: 0.017\n","Outside, urban or manmade: 0.014\n","Sound effect: 0.013\n","********************************************************\n"]}]},{"cell_type":"code","source":["!pip install coremltools"],"metadata":{"id":"FYeXiU1FLKP1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import coremltools as ct\n","\n","# Load the ONNX model\n","model_path = \"/content/EfficientAT/resources/dymn04_as.pt\"\n","\n","# Define dynamic input shape\n","inputs = [ct.TensorType(name=\"input_1\", shape=(1, 1, 128, ct.RangeDim(1, 2000)))]\n","\n","# Convert to Core ML\n","coreml_model = ct.convert(model_path, inputs=inputs, source='pytorch')\n","\n","# Save the Core ML model\n","coreml_model.save(\"dymn04_as.mlmodel\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"yKps_yXKLM90","outputId":"aa6b87fb-4d2d-467e-ba28-f65401510e67"},"execution_count":null,"outputs":[{"output_type":"error","ename":"RuntimeError","evalue":"PytorchStreamReader failed locating file constants.pkl: file not found","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-2-9e995d17096f>\u001b[0m in \u001b[0;36m<cell line: 10>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;31m# Convert to Core ML\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mcoreml_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msource\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'pytorch'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;31m# Save the Core ML model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/_converters_entry.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[0mspecification_version\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_set_default_specification_version\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexact_target\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 443\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 444\u001b[0;31m mlmodel = mil_convert(\n\u001b[0m\u001b[1;32m 445\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 446\u001b[0m \u001b[0mconvert_from\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexact_source\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/converter.py\u001b[0m in \u001b[0;36mmil_convert\u001b[0;34m(model, convert_from, convert_to, compute_units, **kwargs)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0mSee\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mcoremltools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconverters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \"\"\"\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_mil_convert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconvert_from\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconvert_to\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mConverterRegistry\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMLModel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompute_units\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/converter.py\u001b[0m in \u001b[0;36m_mil_convert\u001b[0;34m(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"weights_dir\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mweights_dir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 211\u001b[0;31m proto, mil_program = mil_convert_to_proto(\n\u001b[0m\u001b[1;32m 212\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0mconvert_from\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/converter.py\u001b[0m in \u001b[0;36mmil_convert_to_proto\u001b[0;34m(model, convert_from, convert_to, converter_registry, **kwargs)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[0mfrontend_converter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfrontend_converter_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 281\u001b[0;31m \u001b[0mprog\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfrontend_converter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 282\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 283\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconvert_to\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m\"neuralnetwork\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/converter.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mfrontend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 110\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/frontend/torch/load.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(model_spec, inputs, specification_version, debug, outputs, cut_at_symbols, **kwargs)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0monly\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \"\"\"\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0mtorchscript\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_torchscript_from_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorchscript\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'training'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mtorchscript\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/frontend/torch/load.py\u001b[0m in \u001b[0;36m_torchscript_from_model\u001b[0;34m(model_spec)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\".pt\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mmodel_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\".pth\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_os_path\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabspath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_torch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_torch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mScriptModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel_spec\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/jit/_serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, _extra_files)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mcu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCompilationUnit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpathlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0mcpp_module\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_ir_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_extra_files\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m cpp_module = torch._C.import_ir_module_from_buffer(\n","\u001b[0;31mRuntimeError\u001b[0m: PytorchStreamReader failed locating file constants.pkl: file not found"]}]},{"cell_type":"code","source":[],"metadata":{"id":"gSMIyJxwg4D-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["#ONNX to tflite\n","Chạy trên kaggle"],"metadata":{"id":"gBGCjfV8n9aW"}},{"cell_type":"code","source":["!pip install onnx_tf tensorflow==2.14.0"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"JDsmanV6vn4n","executionInfo":{"status":"ok","timestamp":1747046031221,"user_tz":-420,"elapsed":67201,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"d1214964-c566-4766-ef7f-c44850520221"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: onnx_tf in /usr/local/lib/python3.11/dist-packages (1.10.0)\n","Collecting tensorflow==2.14.0\n"," Downloading tensorflow-2.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)\n","Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.4.0)\n","Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.6.3)\n","Requirement already satisfied: flatbuffers>=23.5.26 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (25.2.10)\n","Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (0.4.0)\n","Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (0.2.0)\n","Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (3.13.0)\n","Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (18.1.1)\n","Collecting ml-dtypes==0.2.0 (from tensorflow==2.14.0)\n"," Downloading ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\n","Requirement already satisfied: numpy>=1.23.5 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.23.5)\n","Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (3.4.0)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (24.2)\n","Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (4.25.7)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (75.2.0)\n","Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.17.0)\n","Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (3.1.0)\n","Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (4.13.2)\n","Requirement already satisfied: wrapt<1.15,>=1.11.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.14.1)\n","Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (0.37.1)\n","Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.71.0)\n","Collecting tensorboard<2.15,>=2.14 (from tensorflow==2.14.0)\n"," Downloading tensorboard-2.14.1-py3-none-any.whl.metadata (1.7 kB)\n","Collecting tensorflow-estimator<2.15,>=2.14.0 (from tensorflow==2.14.0)\n"," Downloading tensorflow_estimator-2.14.0-py2.py3-none-any.whl.metadata (1.3 kB)\n","Collecting keras<2.15,>=2.14.0 (from tensorflow==2.14.0)\n"," Downloading keras-2.14.0-py3-none-any.whl.metadata (2.4 kB)\n","Requirement already satisfied: onnx>=1.10.2 in /usr/local/lib/python3.11/dist-packages (from onnx_tf) (1.17.0)\n","Requirement already satisfied: PyYAML in /usr/local/lib/python3.11/dist-packages (from onnx_tf) (6.0.2)\n","Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.11/dist-packages (from onnx_tf) (0.23.0)\n","Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from astunparse>=1.6.0->tensorflow==2.14.0) (0.45.1)\n","Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2.38.0)\n","Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (1.0.0)\n","Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.8)\n","Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2.32.3)\n","Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (0.7.2)\n","Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.1.3)\n","Requirement already satisfied: typeguard<3.0.0,>=2.7 in /usr/local/lib/python3.11/dist-packages (from tensorflow-addons->onnx_tf) (2.13.3)\n","Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (5.5.2)\n","Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (0.4.2)\n","Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (4.9.1)\n","Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2.0.0)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.4.1)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.10)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2.4.0)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2025.4.26)\n","Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.11/dist-packages (from werkzeug>=1.0.1->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.0.2)\n","Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.11/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (0.6.1)\n","Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.11/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.2.2)\n","Downloading tensorflow-2.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (489.9 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m489.9/489.9 MB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading keras-2.14.0-py3-none-any.whl (1.7 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading tensorboard-2.14.1-py3-none-any.whl (5.5 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m19.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading tensorflow_estimator-2.14.0-py2.py3-none-any.whl (440 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m440.7/440.7 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: tensorflow-estimator, ml-dtypes, keras, tensorboard, tensorflow\n"," Attempting uninstall: tensorflow-estimator\n"," Found existing installation: tensorflow-estimator 2.12.0\n"," Uninstalling tensorflow-estimator-2.12.0:\n"," Successfully uninstalled tensorflow-estimator-2.12.0\n"," Attempting uninstall: ml-dtypes\n"," Found existing installation: ml-dtypes 0.4.1\n"," Uninstalling ml-dtypes-0.4.1:\n"," Successfully uninstalled ml-dtypes-0.4.1\n"," Attempting uninstall: keras\n"," Found existing installation: keras 2.12.0\n"," Uninstalling keras-2.12.0:\n"," Successfully uninstalled keras-2.12.0\n"," Attempting uninstall: tensorboard\n"," Found existing installation: tensorboard 2.12.3\n"," Uninstalling tensorboard-2.12.3:\n"," Successfully uninstalled tensorboard-2.12.3\n"," Attempting uninstall: tensorflow\n"," Found existing installation: tensorflow 2.12.0\n"," Uninstalling tensorflow-2.12.0:\n"," Successfully uninstalled tensorflow-2.12.0\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n","tensorflow-text 2.18.1 requires tensorflow<2.19,>=2.18.0, but you have tensorflow 2.14.0 which is incompatible.\n","tf-keras 2.18.0 requires tensorflow<2.19,>=2.18, but you have tensorflow 2.14.0 which is incompatible.\n","flax 0.10.6 requires jax>=0.5.1, but you have jax 0.4.30 which is incompatible.\n","orbax-checkpoint 0.11.13 requires jax>=0.5.0, but you have jax 0.4.30 which is incompatible.\n","tensorflow-decision-forests 1.11.0 requires tensorflow==2.18.0, but you have tensorflow 2.14.0 which is incompatible.\n","chex 0.1.89 requires numpy>=1.24.1, but you have numpy 1.23.5 which is incompatible.\n","tensorstore 0.1.74 requires ml_dtypes>=0.3.1, but you have ml-dtypes 0.2.0 which is incompatible.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed keras-2.14.0 ml-dtypes-0.2.0 tensorboard-2.14.1 tensorflow-2.14.0 tensorflow-estimator-2.14.0\n"]},{"output_type":"display_data","data":{"application/vnd.colab-display-data+json":{"pip_warning":{"packages":["keras","ml_dtypes","tensorboard","tensorflow"]},"id":"61b023af355f4283aa45d7f6298e8099"}},"metadata":{}}]},{"cell_type":"code","source":["from onnx_tf.backend import prepare\n","import onnx\n","import tensorflow as tf\n","\n","# Load ONNX model\n","onnx_model = onnx.load(\"/content/fixed_emotion.onnx\")\n","\n","# Convert to TensorFlow format\n","tf_rep = prepare(onnx_model)\n","print(tf_rep.inputs)\n","\n","# Export the TensorFlow SavedModel\n","tf_rep.export_graph(\"saved_model_4s_emotion\")\n","\n","# Convert SavedModel to TFLite\n","converter = tf.lite.TFLiteConverter.from_saved_model(\"saved_model_4s_emotion\")\n","converter.optimizations = [tf.lite.Optimize.DEFAULT]\n","tflite_model = converter.convert()\n","\n","# Save the TFLite model\n","with open(\"emotion_model_2025_03_28.tflite\", \"wb\") as f:\n"," f.write(tflite_model)"],"metadata":{"id":"DZDTl5Zqn_Fv","colab":{"base_uri":"https://localhost:8080/","height":547},"executionInfo":{"status":"error","timestamp":1747046056018,"user_tz":-420,"elapsed":11643,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"509c8a16-d670-4709-c14a-a4c15bb156ee"},"execution_count":1,"outputs":[{"output_type":"error","ename":"ImportError","evalue":"This version of TensorFlow Probability requires TensorFlow version >= 2.18; Detected an installation of version 2.14.0. Please upgrade TensorFlow to proceed.","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-1-3ba768d958af>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mprepare\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0monnx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtensorflow\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;31m# Load ONNX model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnx_tf/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mversion\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mversion\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0m__version__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnx_tf/backend.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_unique_suffix\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msupports_device\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcommon_supports_device\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandler_helper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_all_backend_handlers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpb_wrapper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mOnnxNode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend_tf_module\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBackendTFModule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTFModule\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnx_tf/common/handler_helper.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdefs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;31m# noqa\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend_handler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBackendHandler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcommon\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnx_tf/handlers/backend/bernoulli.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtensorflow\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtensorflow_probability\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdistributions\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtfd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend_handler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBackendHandler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0monnx_op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m# from tensorflow_probability.google import staging # DisableOnExport\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;31m# from tensorflow_probability.google import tfp_google # DisableOnExport\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtensorflow_probability\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpython\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;31m# pylint: disable=wildcard-import\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtensorflow_probability\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpython\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mversion\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m__version__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;31m# Non-lazy load of packages that register with tensorflow or keras.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mpkg_name\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_maybe_nonlazy_load\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 152\u001b[0;31m \u001b[0mdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mglobals\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpkg_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Forces loading the package from its lazy loader.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 153\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/internal/lazy_loader.py\u001b[0m in \u001b[0;36m__dir__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__dir__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0mmodule\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/internal/lazy_loader.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\"\"\"Load the module and insert it into the parent's globals.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_on_first_access\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_on_first_access\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_on_first_access\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;31m# Import the target module and insert it into the parent's namespace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/__init__.py\u001b[0m in \u001b[0;36m_validate_tf_environment\u001b[0;34m(package)\u001b[0m\n\u001b[1;32m 57\u001b[0m if (distutils.version.LooseVersion(tf.__version__) <\n\u001b[1;32m 58\u001b[0m distutils.version.LooseVersion(required_tensorflow_version)):\n\u001b[0;32m---> 59\u001b[0;31m raise ImportError(\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m'This version of TensorFlow Probability requires TensorFlow '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m'version >= {required}; Detected an installation of version {present}. '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mImportError\u001b[0m: This version of TensorFlow Probability requires TensorFlow version >= 2.18; Detected an installation of version 2.14.0. Please upgrade TensorFlow to proceed.","","\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"],"errorDetails":{"actions":[{"action":"open_url","actionText":"Open Examples","url":"/notebooks/snippets/importing_libraries.ipynb"}]}}]},{"cell_type":"markdown","source":["#test cái tflite"],"metadata":{"id":"nnML9fwUoQWg"}},{"cell_type":"markdown","source":["Read label from csv file"],"metadata":{"id":"64ShAXw5M9lS"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"JTLwJstrEbSj"},"outputs":[],"source":["import csv\n","\n","with open('/content/class_labels_indices.csv', 'r') as f:\n"," reader = csv.reader(f, delimiter=',')\n"," lines = list(reader)\n","\n","labels = []\n","ids = [] # Each label has a unique id such as \"/m/068hy\"\n","for i1 in range(1, len(lines)):\n"," id = lines[i1][1]\n"," label = lines[i1][2]\n"," ids.append(id)\n"," labels.append(label)\n","\n","classes_num = len(labels)"]},{"cell_type":"markdown","source":["Preprocess audio"],"metadata":{"id":"ebLOZ2JXNF0w"}},{"cell_type":"code","source":["import numpy as np\n","from numpy.lib.stride_tricks import as_strided\n","from typing import Tuple, Optional, Union\n","\n","def mel_scale_scalar(freq: float) -> float:\n"," \"\"\"Convert frequency to mel scale\"\"\"\n"," return 1127.0 * np.log(1.0 + freq / 700.0)\n","\n","def mel_scale(freq: np.ndarray) -> np.ndarray:\n"," \"\"\"Vector version of mel scale conversion\"\"\"\n"," return 1127.0 * np.log(1.0 + freq / 700.0)\n","\n","def inverse_mel_scale(mel: np.ndarray) -> np.ndarray:\n"," \"\"\"Convert mel scale to frequency\"\"\"\n"," return 700.0 * (np.exp(mel / 1127.0) - 1.0)\n","\n","def vtln_warp_mel_freq(vtln_low: float,\n"," vtln_high: float,\n"," low_freq: float,\n"," high_freq: float,\n"," vtln_warp_factor: float,\n"," mel_freq: np.ndarray) -> np.ndarray:\n"," \"\"\"\n"," Implements VTLN warping for mel frequencies\n"," \"\"\"\n"," return mel_freq # Placeholder - implement if VTLN warping is needed\n","\n","def get_mel_banks(\n"," num_bins: int,\n"," window_length_padded: int,\n"," sample_freq: float,\n"," low_freq: float,\n"," high_freq: float,\n"," vtln_low: float,\n"," vtln_high: float,\n"," vtln_warp_factor: float,\n",") -> Tuple[np.ndarray, np.ndarray]:\n"," \"\"\"\n"," Get mel filterbank matrices following Kaldi's implementation.\n","\n"," Args:\n"," num_bins: Number of triangular bins\n"," window_length_padded: Padded window length\n"," sample_freq: Sampling frequency\n"," low_freq: Lowest frequency to consider\n"," high_freq: Highest frequency to consider\n"," vtln_low: Lower frequency for VTLN\n"," vtln_high: Higher frequency for VTLN\n"," vtln_warp_factor: Warping factor for VTLN\n","\n"," Returns:\n"," Tuple[np.ndarray, np.ndarray]: (bins, center_freqs)\n"," - bins: melbank matrix of shape (num_bins, num_fft_bins)\n"," - center_freqs: center frequencies of bins of shape (num_bins,)\n"," \"\"\"\n"," assert num_bins > 3, \"Must have at least 3 mel bins\"\n"," assert window_length_padded % 2 == 0\n"," num_fft_bins = window_length_padded // 2\n"," nyquist = 0.5 * sample_freq\n","\n"," if high_freq <= 0.0:\n"," high_freq += nyquist\n","\n"," assert (\n"," (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)\n"," ), f\"Bad values in options: low-freq {low_freq} and high-freq {high_freq} vs. nyquist {nyquist}\"\n","\n"," # fft-bin width [think of it as Nyquist-freq / half-window-length]\n"," fft_bin_width = sample_freq / window_length_padded\n","\n"," mel_low_freq = mel_scale_scalar(low_freq)\n"," mel_high_freq = mel_scale_scalar(high_freq)\n","\n"," # divide by num_bins+1 in next line because of end-effects where the bins\n"," # spread out to the sides.\n"," mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)\n","\n"," if vtln_high < 0.0:\n"," vtln_high += nyquist\n","\n"," assert vtln_warp_factor == 1.0 or (\n"," (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)\n"," ), f\"Bad values in options: vtln-low {vtln_low} and vtln-high {vtln_high}, versus low-freq {low_freq} and high-freq {high_freq}\"\n","\n"," bin = np.arange(num_bins)[:, np.newaxis]\n"," left_mel = mel_low_freq + bin * mel_freq_delta # shape (num_bins, 1)\n"," center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # shape (num_bins, 1)\n"," right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # shape (num_bins, 1)\n","\n"," if vtln_warp_factor != 1.0:\n"," left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)\n"," center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)\n"," right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)\n","\n"," center_freqs = inverse_mel_scale(center_mel).squeeze(-1) # shape (num_bins)\n","\n"," # shape (1, num_fft_bins)\n"," mel = mel_scale(fft_bin_width * np.arange(num_fft_bins))[np.newaxis, :]\n","\n"," # shape (num_bins, num_fft_bins)\n"," up_slope = (mel - left_mel) / (center_mel - left_mel)\n"," down_slope = (right_mel - mel) / (right_mel - center_mel)\n","\n"," if vtln_warp_factor == 1.0:\n"," # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values\n"," bins = np.maximum(0.0, np.minimum(up_slope, down_slope))\n"," else:\n"," # warping can move the order of left_mel, center_mel, right_mel anywhere\n"," bins = np.zeros_like(up_slope)\n"," up_idx = (mel > left_mel) & (mel <= center_mel) # left_mel < mel <= center_mel\n"," down_idx = (mel > center_mel) & (mel < right_mel) # center_mel < mel < right_mel\n"," bins[up_idx] = up_slope[up_idx]\n"," bins[down_idx] = down_slope[down_idx]\n","\n"," return bins, center_freqs\n","\n","def stft(\n"," input: np.ndarray,\n"," n_fft: int,\n"," hop_length: Optional[int] = None,\n"," win_length: Optional[int] = None,\n"," window: Optional[np.ndarray] = None,\n"," center: bool = True,\n"," pad_mode: str = \"reflect\",\n"," normalized: bool = False,\n"," onesided: bool = True,\n"," return_complex: bool = True\n",") -> np.ndarray:\n"," \"\"\"\n"," NumPy implementation of Short-time Fourier transform (STFT).\n","\n"," Args:\n"," input: Input signal (B?, L) where B? is optional batch dimension\n"," n_fft: Size of Fourier transform\n"," hop_length: Distance between neighboring frames (default: n_fft//4)\n"," win_length: Window size (default: n_fft)\n"," window: Window function (1D array)\n"," center: Whether to pad input for centered frames\n"," pad_mode: Padding mode ('reflect', 'constant', 'edge')\n"," normalized: Whether to normalize the STFT\n"," onesided: Whether to return only positive frequencies\n"," return_complex: Whether to return complex array\n"," \"\"\"\n"," # Set default values\n"," if hop_length is None:\n"," hop_length = n_fft // 4\n"," if win_length is None:\n"," win_length = n_fft\n","\n"," # Prepare window\n"," if window is None:\n"," window = np.ones(win_length)\n"," if len(window) < n_fft:\n"," # Pad window to n_fft if needed\n"," pad_width = (n_fft - len(window)) // 2\n"," window = np.pad(window, (pad_width, n_fft - len(window) - pad_width))\n","\n"," # Handle input dimensions\n"," input = np.asarray(input)\n"," if input.ndim == 1:\n"," input = input[np.newaxis, :] # Add batch dimension\n"," squeeze_batch = True\n"," else:\n"," squeeze_batch = False\n","\n"," # Pad signal for centered frames\n"," if center:\n"," pad_width = int(n_fft // 2)\n"," if pad_mode == 'reflect':\n"," input = np.pad(input, ((0, 0), (pad_width, pad_width)), mode='reflect')\n"," elif pad_mode == 'constant':\n"," input = np.pad(input, ((0, 0), (pad_width, pad_width)), mode='constant')\n"," elif pad_mode == 'edge':\n"," input = np.pad(input, ((0, 0), (pad_width, pad_width)), mode='edge')\n","\n"," # Calculate number of frames\n"," n_frames = 1 + (input.shape[-1] - n_fft) // hop_length\n","\n"," # Create frame matrix using stride tricks\n"," frame_length = n_fft\n"," frame_step = hop_length\n"," frame_stride = input.strides[-1]\n","\n"," shape = (input.shape[0], n_frames, frame_length)\n"," strides = (input.strides[0], frame_step * frame_stride, frame_stride)\n","\n"," frames = as_strided(input, shape=shape, strides=strides, writeable=False)\n","\n"," # Apply window\n"," frames = frames * window\n","\n"," # Compute FFT\n"," stft_matrix = np.fft.fft(frames, n=n_fft, axis=-1)\n","\n"," # Normalize if requested\n"," if normalized:\n"," stft_matrix = stft_matrix / np.sqrt(n_fft)\n","\n"," # Handle onesided output\n"," if onesided:\n"," stft_matrix = stft_matrix[..., :(n_fft // 2) + 1]\n","\n"," # Handle return format\n"," if return_complex:\n"," result = stft_matrix\n"," else:\n"," result = np.stack((stft_matrix.real, stft_matrix.imag), axis=-1)\n","\n"," # Remove batch dimension if input was 1D\n"," if squeeze_batch:\n"," result = result[0]\n","\n"," return result\n","\n","class MelSTFT:\n"," def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024,\n"," fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000):\n"," \"\"\"\n"," Initialize MelSTFT for audio feature extraction using only NumPy.\n"," \"\"\"\n"," self.n_mels = n_mels\n"," self.sr = sr\n"," self.win_length = win_length\n"," self.hopsize = hopsize\n"," self.n_fft = n_fft\n"," self.fmin = fmin\n"," if fmax is None:\n"," fmax = sr // 2 - fmax_aug_range // 2\n"," self.fmax = fmax\n","\n"," # Create Hann window\n"," self.window = np.hanning(win_length)\n","\n"," # Create mel filterbank following Kaldi's implementation\n"," self.mel_basis, _ = get_mel_banks(\n"," self.n_mels,\n"," self.n_fft,\n"," self.sr,\n"," self.fmin,\n"," self.fmax,\n"," 100.0,\n"," -500.,\n"," 1.0\n"," )\n"," self.mel_basis = np.pad(self.mel_basis, ((0, 0), (0, 1)), mode='constant', constant_values=0)\n","\n"," # Pre-emphasis filter coefficients\n"," self.preemphasis_coefficient = np.array([-.97, 1]).reshape(1, 1, 2)\n","\n"," def preemphasis(self, x):\n"," \"\"\"Apply pre-emphasis filter using conv1d equivalent\"\"\"\n"," # Reshape input to match conv1d input shape (batch, channels, length)\n"," x = x.reshape(1, 1, -1)\n","\n"," # Implement conv1d manually\n"," output_size = x.shape[2] - self.preemphasis_coefficient.shape[2] + 1\n"," result = np.zeros((1, 1, output_size))\n","\n"," for i in range(output_size):\n"," result[0, 0, i] = np.sum(x[0, 0, i:i+2] * self.preemphasis_coefficient[0, 0])\n","\n"," return result[0]\n","\n"," def __call__(self, x):\n"," \"\"\"Convert audio to log-Mel spectrogram.\"\"\"\n"," # Apply pre-emphasis\n"," print(f'x shape before {x.shape}')\n"," x = self.preemphasis(x)\n"," print(f'x shape after {x.shape}')\n","\n"," # Compute STFT\n"," spec_x = stft(\n"," input=x,\n"," n_fft=self.n_fft,\n"," hop_length=self.hopsize,\n"," win_length=self.win_length,\n"," window=self.window,\n"," center=True,\n"," pad_mode='reflect',\n"," normalized=False,\n"," return_complex=False\n"," )\n","\n"," # Convert to power spectrogram\n"," print(f'spec_x before is {spec_x.shape}')\n"," spec_x = np.sum(spec_x ** 2, axis=-1)\n","\n"," print(f'spec_x shape is {spec_x.shape}')\n","\n"," # Apply Mel filterbank (ensuring shapes match)\n"," melspec = np.dot(self.mel_basis, spec_x.transpose(0,2,1)).transpose(1,0,2)\n","\n"," # Log-scale and normalize\n"," melspec = np.log(melspec + 1e-5)\n"," melspec = (melspec + 4.5) / 5.\n","\n"," return melspec"],"metadata":{"id":"Ris3ajNhGEkQ","executionInfo":{"status":"ok","timestamp":1747040595424,"user_tz":-420,"elapsed":27,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","source":["Load model and handle input and output\n","\n","The input shape of model is `[1,1,128, any]`. Note the last axis of input shape is dynamic. It depends on the length of audio.\n","\n","The output shape of model is `[1, 527]`. Currently, we're using the pretrained model and exporting it. Note: this shape might be changed in the future."],"metadata":{"id":"CqXIPms2NMcv"}},{"cell_type":"code","source":["import numpy as np\n","import tensorflow as tf\n","\n","# Load the TFLite model\n","interpreter = tf.lite.Interpreter(model_path=\"/content/emotion_model_2025_03_28.tflite\")\n","interpreter.allocate_tensors()\n","\n","# Get input and output details, can print to see details of input and output\n","input_details = interpreter.get_input_details() # The shape of input is [1, 1, 128, None] the last axis depends on the length of audio, data type = float32\n","output_details = interpreter.get_output_details() # The shape of output is [1, 527], data type = float32"],"metadata":{"id":"VlCUYth6GT0f"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# preprocess audio\n","import librosa\n","\n","mel = MelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)\n","\n","# convert to waveform\n","(waveform, _) = librosa.core.load('/content/angry_4s.mp3', sr=32000, mono=True)\n","waveform = np.stack([waveform])\n","spec = mel(waveform)\n","if spec.shape[-1] > 400:\n"," spec = spec[:, :, :400]\n","else:\n"," spec = np.pad(spec, ((0, 0), (0, 0), (0, 400 - spec.shape[-1])), mode='constant')\n","spec = np.expand_dims(spec, axis=0)\n","print(f'the shape of spec is {spec.shape}')\n","spec = spec.astype(np.float32)"],"metadata":{"id":"PtexE6_HGlDW","colab":{"base_uri":"https://localhost:8080/"},"outputId":"cde2909e-7b58-4546-8e3a-e6f5779d8d04"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["x shape before (1, 130223)\n","x shape after (1, 130222)\n","spec_x before is (1, 407, 513, 2)\n","spec_x shape is (1, 407, 513)\n","the shape of spec is (1, 1, 128, 400)\n"]}]},{"cell_type":"code","source":["# Run inference\n","interpreter.set_tensor(input_details[0]['index'], spec)\n","interpreter.invoke()\n","output_data = interpreter.get_tensor(output_details[0]['index'])"],"metadata":{"id":"xyy_O-_vHsYy"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Post process output to get the result"],"metadata":{"id":"eeh_JqiyN0zo"}},{"cell_type":"code","source":["# Post processing\n","import numpy as np\n","\n","def sigmoid(x):\n"," return 1 / (1 + np.exp(-x))\n","\n","def softmax(x):\n"," exp_x = np.exp(x - np.max(x)) # Subtract max(x) for numerical stability\n"," return exp_x / np.sum(exp_x, axis=-1, keepdims=True)\n","\n","preds = softmax(output_data[0])\n","sorted_indexes = np.argsort(preds)[::-1]\n","# Print audio tagging top probabilities\n","print(\"* Acoustic Event Detected: *\")\n","for k in range(7):\n"," print('{}: {:.3f}'.format(labels[sorted_indexes[k]], preds[sorted_indexes[k]]))\n","print(\"**\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"pPLXt-FYH1Nz","outputId":"f620d201-8bfb-497d-d583-ef73db24cc46"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["* Acoustic Event Detected: *\n","Angry: 0.913\n","sad: 0.015\n","neutral: 0.015\n","Fear: 0.015\n","Happy: 0.014\n","surprise: 0.014\n","Disgust: 0.014\n","**\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"-MMULJzEH3xe"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}
Preprocessing/Convert_to_HDF5.ipynb ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "#convert file audio to mp3 32k"
21
+ ],
22
+ "metadata": {
23
+ "id": "XE8k_JyY5eBb"
24
+ }
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {
30
+ "id": "Wyv972sc5Lb4"
31
+ },
32
+ "outputs": [],
33
+ "source": [
34
+ "import os\n",
35
+ "from multiprocessing import Pool, cpu_count\n",
36
+ "from tqdm import tqdm\n",
37
+ "import subprocess\n",
38
+ "\n",
39
+ "def process_file_ffmpeg(args):\n",
40
+ " file_path, input_folder, output_folder = args\n",
41
+ " rel_path = os.path.relpath(file_path, input_folder)\n",
42
+ " rel_path = os.path.splitext(rel_path)[0] + \".mp3\" # luôn xuất mp3\n",
43
+ " out_path = os.path.join(output_folder, rel_path)\n",
44
+ " os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
45
+ "\n",
46
+ " cmd = [\n",
47
+ " \"ffmpeg\",\n",
48
+ " \"-y\", # overwrite nếu đã tồn tại\n",
49
+ " \"-i\", file_path,\n",
50
+ " \"-ar\", \"32000\", # sample rate 32kHz\n",
51
+ " \"-ac\", \"1\", # stereo set thành 2 còn mono set thành 1\n",
52
+ " \"-b:a\", \"192k\", # bitrate\n",
53
+ " out_path\n",
54
+ " ]\n",
55
+ "\n",
56
+ " try:\n",
57
+ " subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)\n",
58
+ " return True\n",
59
+ " except subprocess.CalledProcessError:\n",
60
+ " print(f\"❌ Lỗi khi xử lý {file_path}\")\n",
61
+ " return False\n",
62
+ "\n",
63
+ "def convert_audio_ffmpeg_multiprocessing(input_folder, output_folder, num_workers=None):\n",
64
+ " audio_exts = ('.mp3', '.wav', '.flac', '.m4a', '.ogg')\n",
65
+ "\n",
66
+ " # Lấy danh sách tất cả file audio\n",
67
+ " all_files = []\n",
68
+ " for root, _, files in os.walk(input_folder):\n",
69
+ " for f in files:\n",
70
+ " if f.lower().endswith(audio_exts):\n",
71
+ " all_files.append(os.path.join(root, f))\n",
72
+ "\n",
73
+ " if num_workers is None:\n",
74
+ " num_workers = cpu_count()\n",
75
+ "\n",
76
+ " args_list = [(f, input_folder, output_folder) for f in all_files]\n",
77
+ "\n",
78
+ " # Multiprocessing + tqdm\n",
79
+ " with Pool(num_workers) as pool:\n",
80
+ " for _ in tqdm(pool.imap_unordered(process_file_ffmpeg, args_list),\n",
81
+ " total=len(args_list), desc=\"Converting to 32kHz MP3\"):\n",
82
+ " pass\n",
83
+ "\n",
84
+ "# --- Ví dụ sử dụng ---\n",
85
+ "input_dir = \"/content/dataset\"\n",
86
+ "output_dir = \"/content/dataset_process\"\n",
87
+ "convert_audio_ffmpeg_multiprocessing(input_dir, output_dir)\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "source": [
93
+ "#mp3 to hdf5"
94
+ ],
95
+ "metadata": {
96
+ "id": "kIoHloKr5ky7"
97
+ }
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "source": [
102
+ "##Audioset"
103
+ ],
104
+ "metadata": {
105
+ "id": "pctubCgR5sli"
106
+ }
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "source": [
111
+ "import h5py\n",
112
+ "import pandas as pd\n",
113
+ "import numpy as np\n",
114
+ "import csv\n",
115
+ "import os\n",
116
+ "import io\n",
117
+ "import av\n",
118
+ "\n",
119
+ "def decode_mp3(mp3_arr):\n",
120
+ " \"\"\"\n",
121
+ " Giải mã một mảng uint8 đại diện cho một file MP3.\n",
122
+ " :rtype: np.array\n",
123
+ " \"\"\"\n",
124
+ " container = av.open(io.BytesIO(mp3_arr.tobytes())) # Đọc dữ liệu MP3\n",
125
+ " stream = next(s for s in container.streams if s.type == 'audio') # Lấy stream âm thanh\n",
126
+ " a = []\n",
127
+ " for i, packet in enumerate(container.demux(stream)): # Demux các gói dữ liệu âm thanh\n",
128
+ " for frame in packet.decode(): # Giải mã frame\n",
129
+ " a.append(frame.to_ndarray().reshape(-1)) # Chuyển đổi frame thành mảng numpy\n",
130
+ " waveform = np.concatenate(a) # Kết nối tất cả các frame lại\n",
131
+ " if waveform.dtype != 'float32': # Kiểm tra loại dữ liệu\n",
132
+ " raise RuntimeError(\"Unexpected wave type\")\n",
133
+ " return waveform\n",
134
+ "\n",
135
+ "# %%\n",
136
+ "base_dir = \"/content/output_\"\n",
137
+ "balanced_csv= '/content/new_updated_balanced_train_segments.csv'\n",
138
+ "eval_csv= '/content/new_eval_segments.csv'\n",
139
+ "mp3_path = \"/content/dataset/\"\n",
140
+ "\n",
141
+ "\n",
142
+ "# %%\n",
143
+ "\n",
144
+ "\n",
145
+ "def read_metadata(csv_path, classes_num, id_to_ix):\n",
146
+ " \"\"\"Read metadata of AudioSet from a csv file.\"\"\"\n",
147
+ "\n",
148
+ " audio_names = []\n",
149
+ " targets = []\n",
150
+ "\n",
151
+ " with open(csv_path, 'r') as fr:\n",
152
+ " reader = csv.reader(fr)\n",
153
+ " next(reader) # Skip header line if exists\n",
154
+ " next(reader) # Skip another potential header line\n",
155
+ " next(reader) # Skip another potential header line\n",
156
+ "\n",
157
+ " for line in reader:\n",
158
+ " if len(line) < 4:\n",
159
+ " continue # Skip malformed lines\n",
160
+ "\n",
161
+ " audio_name = 'Y{}.mp3'.format(line[0]) # Assumed naming convention\n",
162
+ " label_ids = line[3].strip('\"').split(',')\n",
163
+ "\n",
164
+ " audio_names.append(audio_name)\n",
165
+ " target = np.zeros(classes_num, dtype=bool)\n",
166
+ "\n",
167
+ " for label_id in label_ids:\n",
168
+ " if label_id in id_to_ix:\n",
169
+ " ix = id_to_ix[label_id]\n",
170
+ " target[ix] = 1\n",
171
+ " else:\n",
172
+ " print(f\"Warning: Label ID {label_id} not found in id_to_ix.\")\n",
173
+ "\n",
174
+ " targets.append(target)\n",
175
+ "\n",
176
+ " meta_dict = {'audio_name': np.array(audio_names), 'target': np.array(targets)}\n",
177
+ " print(meta_dict)\n",
178
+ " return meta_dict\n",
179
+ "\n",
180
+ "# Load label\n",
181
+ "with open('/content/new_class_labels_indices_filter_discard.csv', 'r') as f:\n",
182
+ " reader = csv.reader(f, delimiter=',')\n",
183
+ " lines = list(reader)\n",
184
+ "\n",
185
+ "labels = []\n",
186
+ "ids = [] # Each label has a unique id such as \"/m/068hy\"\n",
187
+ "for i1 in range(1, len(lines)):\n",
188
+ " id = lines[i1][1]\n",
189
+ " label = lines[i1][2]\n",
190
+ " ids.append(id)\n",
191
+ " labels.append(label)\n",
192
+ "\n",
193
+ "classes_num = len(labels)\n",
194
+ "\n",
195
+ "lb_to_ix = {label : i for i, label in enumerate(labels)}\n",
196
+ "ix_to_lb = {i : label for i, label in enumerate(labels)}\n",
197
+ "\n",
198
+ "id_to_ix = {id : i for i, id in enumerate(ids)}\n",
199
+ "ix_to_id = {i : id for i, id in enumerate(ids)}\n",
200
+ "\n",
201
+ "# %%\n",
202
+ "\n",
203
+ "def check_available(balanced_csv,balanced_audio_path,prefix=None):\n",
204
+ " meta_csv = read_metadata(balanced_csv,classes_num,id_to_ix)\n",
205
+ " #print(meta_csv)\n",
206
+ " audios_num = len(meta_csv['audio_name'])\n",
207
+ " found=0\n",
208
+ " notfound=0\n",
209
+ " available_files=[]\n",
210
+ " available_targets=[]\n",
211
+ " if prefix is None:\n",
212
+ " prefix = os.path.basename(balanced_csv)[:-4]\n",
213
+ " for n in range(audios_num):\n",
214
+ " audio_path = meta_csv['audio_name'][n]\n",
215
+ " #print(balanced_audio_path + f\"{prefix}/{audio_path}\")\n",
216
+ " if os.path.isfile(balanced_audio_path + f\"{prefix}/{audio_path}\" ):\n",
217
+ " found+=1\n",
218
+ " available_files.append(meta_csv['audio_name'][n])\n",
219
+ " available_targets.append(meta_csv['target'][n])\n",
220
+ " else:\n",
221
+ " notfound+=1\n",
222
+ " print(f\"Found {found} . not found {notfound}\")\n",
223
+ " return available_files,available_targets\n",
224
+ "# %%\n",
225
+ "\n",
226
+ "# %%\n",
227
+ "\n",
228
+ "# %%\n",
229
+ "\n",
230
+ "\n",
231
+ "os.makedirs(os.path.dirname(base_dir + \"mp3\"), exist_ok=True)\n",
232
+ "\n",
233
+ "for read_file,prefix in [(balanced_csv,\"balanced_train_segments/\"), (eval_csv,\"eval_segments/\"),]:\n",
234
+ " print(\"now working on \",read_file,prefix)\n",
235
+ " #files, y = torch.load(read_file+\".pth\")\n",
236
+ " files, y = check_available(read_file, mp3_path, prefix)\n",
237
+ " y = np.packbits(y, axis=-1)\n",
238
+ " packed_len = y.shape[1]\n",
239
+ " print(files[0], \"classes: \",packed_len, y.dtype)\n",
240
+ " available_size = len(files)\n",
241
+ " f = files[0][:-3]+\"mp3\"\n",
242
+ " a = np.fromfile(mp3_path+prefix + \"/\"+f, dtype='uint8')\n",
243
+ "\n",
244
+ " dt = h5py.vlen_dtype(np.dtype('uint8'))\n",
245
+ " save_file = prefix.split(\"/\")[0]\n",
246
+ " os.makedirs(os.path.dirname(base_dir + \"mp3/\" ), exist_ok=True)\n",
247
+ " with h5py.File(base_dir+ \"mp3/\" + save_file+\"_mp3.hdf\", 'w') as hf:\n",
248
+ " audio_name = hf.create_dataset('audio_name', shape=(0,), maxshape=(None,), dtype='S20')\n",
249
+ " waveform = hf.create_dataset('mp3', shape=(0,), maxshape=(None,), dtype=dt)\n",
250
+ " target = hf.create_dataset('target', shape=(0, packed_len), maxshape=(None, packed_len), dtype=y.dtype)\n",
251
+ " for i,file in enumerate(files):\n",
252
+ " if i%1000==0:\n",
253
+ " print(f\"{i}/{available_size}\")\n",
254
+ " f = file[:-3] + \"mp3\"\n",
255
+ " a = np.fromfile(mp3_path + prefix + f, dtype='uint8')\n",
256
+ " try:\n",
257
+ " # Kiểm tra xem file audio có đọc được không\n",
258
+ " decode_mp3(a) # Dùng hàm decode_mp3 của bạn\n",
259
+ "\n",
260
+ " audio_name.resize((i + 1,))\n",
261
+ " waveform.resize((i + 1,))\n",
262
+ " target.resize((i + 1, packed_len))\n",
263
+ "\n",
264
+ " audio_name[i]=f\n",
265
+ " waveform[i] = a\n",
266
+ " target[i] = y[i]\n",
267
+ " except Exception as e:\n",
268
+ " print(f\"File lỗi tại index {i} với file {file}: {e}\")\n",
269
+ "\n",
270
+ " print(a.shape)\n",
271
+ " print(\"Done!\" , prefix)"
272
+ ],
273
+ "metadata": {
274
+ "id": "8oFKEbtb5mzr"
275
+ },
276
+ "execution_count": null,
277
+ "outputs": []
278
+ },
279
+ {
280
+ "cell_type": "markdown",
281
+ "source": [
282
+ "##For this structure folder/train/folder(class)/file"
283
+ ],
284
+ "metadata": {
285
+ "id": "mCwcSx8y5v7q"
286
+ }
287
+ },
288
+ {
289
+ "cell_type": "markdown",
290
+ "source": [
291
+ "![image.png]()"
292
+ ],
293
+ "metadata": {
294
+ "id": "cRAr5tkn566K"
295
+ }
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "source": [
300
+ "import h5py\n",
301
+ "import numpy as np\n",
302
+ "import os\n",
303
+ "import io\n",
304
+ "import av\n",
305
+ "from pathlib import Path\n",
306
+ "from tqdm import tqdm\n",
307
+ "\n",
308
+ "def decode_mp3(mp3_arr):\n",
309
+ " \"\"\"\n",
310
+ " Giải mã một mảng uint8 đại diện cho một file MP3.\n",
311
+ " :rtype: np.array\n",
312
+ " \"\"\"\n",
313
+ " try:\n",
314
+ " container = av.open(io.BytesIO(mp3_arr.tobytes()))\n",
315
+ " stream = next(s for s in container.streams if s.type == 'audio')\n",
316
+ " a = []\n",
317
+ " for packet in container.demux(stream):\n",
318
+ " for frame in packet.decode():\n",
319
+ " a.append(frame.to_ndarray().reshape(-1))\n",
320
+ " waveform = np.concatenate(a)\n",
321
+ " if waveform.dtype != 'float32':\n",
322
+ " raise RuntimeError(\"Unexpected wave type\")\n",
323
+ " return waveform\n",
324
+ " except Exception as e:\n",
325
+ " raise RuntimeError(f\"Cannot decode MP3: {e}\")\n",
326
+ "\n",
327
+ "def scan_dataset_structure(dataset_path):\n",
328
+ " \"\"\"\n",
329
+ " Quét cấu trúc thư mục dataset và tạo mapping cho classes\n",
330
+ " Structure: dataset_path/train(or test)/class_name/*.mp3\n",
331
+ " \"\"\"\n",
332
+ " dataset_path = Path(dataset_path)\n",
333
+ "\n",
334
+ " # Lấy tất cả các class names từ thư mục train\n",
335
+ " train_path = dataset_path / \"train\"\n",
336
+ " if not train_path.exists():\n",
337
+ " raise ValueError(f\"Train folder not found: {train_path}\")\n",
338
+ "\n",
339
+ " classes = sorted([d.name for d in train_path.iterdir() if d.is_dir()])\n",
340
+ " classes_num = len(classes)\n",
341
+ "\n",
342
+ " # Tạo mapping\n",
343
+ " lb_to_ix = {label: i for i, label in enumerate(classes)}\n",
344
+ " ix_to_lb = {i: label for i, label in enumerate(classes)}\n",
345
+ "\n",
346
+ " print(f\"Found {classes_num} classes: {classes[:7]}...\" if len(classes) > 7 else f\"Found {classes_num} classes: {classes}\")\n",
347
+ "\n",
348
+ " return classes, classes_num, lb_to_ix, ix_to_lb\n",
349
+ "\n",
350
+ "def collect_audio_files(dataset_path, split='train', shuffle=True, random_seed=42):\n",
351
+ " \"\"\"\n",
352
+ " Thu thập tất cả audio files từ structure thư mục và shuffle để tránh grouping theo class\n",
353
+ " \"\"\"\n",
354
+ " dataset_path = Path(dataset_path)\n",
355
+ " split_path = dataset_path / split\n",
356
+ "\n",
357
+ " if not split_path.exists():\n",
358
+ " raise ValueError(f\"{split} folder not found: {split_path}\")\n",
359
+ "\n",
360
+ " audio_files = []\n",
361
+ " labels = []\n",
362
+ " class_counts = {}\n",
363
+ "\n",
364
+ " class_dirs = [d for d in split_path.iterdir() if d.is_dir()]\n",
365
+ "\n",
366
+ " print(f\"📁 Scanning {split} folder...\")\n",
367
+ " for class_dir in tqdm(class_dirs, desc=f\"Scanning classes\"):\n",
368
+ " class_name = class_dir.name\n",
369
+ " mp3_files = list(class_dir.glob(\"*.mp3\"))\n",
370
+ " class_counts[class_name] = len(mp3_files)\n",
371
+ "\n",
372
+ " for mp3_file in mp3_files:\n",
373
+ " audio_files.append(str(mp3_file))\n",
374
+ " labels.append(class_name)\n",
375
+ "\n",
376
+ " # Shuffle để tránh việc group theo class trong HDF5\n",
377
+ " if shuffle:\n",
378
+ " import random\n",
379
+ " random.seed(random_seed)\n",
380
+ "\n",
381
+ " # Zip files và labels lại, sau đó shuffle\n",
382
+ " combined = list(zip(audio_files, labels))\n",
383
+ " random.shuffle(combined)\n",
384
+ "\n",
385
+ " # Unpack lại\n",
386
+ " audio_files, labels = zip(*combined)\n",
387
+ " audio_files = list(audio_files)\n",
388
+ " labels = list(labels)\n",
389
+ "\n",
390
+ " print(f\"🔀 Files shuffled with seed={random_seed}\")\n",
391
+ "\n",
392
+ " # In class distribution\n",
393
+ " print(f\"✅ Found {len(audio_files)} audio files in {split} set\")\n",
394
+ " print(f\"📊 Class distribution:\")\n",
395
+ " for class_name, count in sorted(class_counts.items()):\n",
396
+ " percentage = count / len(audio_files) * 100\n",
397
+ " print(f\" {class_name}: {count} files ({percentage:.1f}%)\")\n",
398
+ "\n",
399
+ " return audio_files, labels\n",
400
+ "\n",
401
+ "def create_target_array(labels, classes_num, lb_to_ix):\n",
402
+ " \"\"\"\n",
403
+ " Tạo target array từ danh sách labels\n",
404
+ " \"\"\"\n",
405
+ " targets = []\n",
406
+ " for label in labels:\n",
407
+ " target = np.zeros(classes_num, dtype=bool)\n",
408
+ " if label in lb_to_ix:\n",
409
+ " ix = lb_to_ix[label]\n",
410
+ " target[ix] = 1\n",
411
+ " targets.append(target)\n",
412
+ "\n",
413
+ " return np.array(targets)\n",
414
+ "\n",
415
+ "def convert_to_hdf5(dataset_path, output_dir):\n",
416
+ " \"\"\"\n",
417
+ " Convert audio dataset to HDF5 format\n",
418
+ " \"\"\"\n",
419
+ " # Tạo output directory\n",
420
+ " os.makedirs(output_dir, exist_ok=True)\n",
421
+ "\n",
422
+ " # Quét cấu trúc dataset\n",
423
+ " classes, classes_num, lb_to_ix, ix_to_lb = scan_dataset_structure(dataset_path)\n",
424
+ "\n",
425
+ " # Process both train and test splits\n",
426
+ " for split in ['train', 'test']:\n",
427
+ " print(f\"\\n=== Processing {split} set ===\")\n",
428
+ "\n",
429
+ " try:\n",
430
+ " # Thu thập audio files\n",
431
+ " audio_files, labels = collect_audio_files(dataset_path, split)\n",
432
+ "\n",
433
+ " if len(audio_files) == 0:\n",
434
+ " print(f\"No audio files found in {split} set, skipping...\")\n",
435
+ " continue\n",
436
+ "\n",
437
+ " # Tạo target array\n",
438
+ " targets = create_target_array(labels, classes_num, lb_to_ix)\n",
439
+ "\n",
440
+ " # Pack targets để tiết kiệm memory\n",
441
+ " packed_targets = np.packbits(targets, axis=-1)\n",
442
+ " packed_len = packed_targets.shape[1]\n",
443
+ "\n",
444
+ " print(f\"Target shape: {targets.shape} -> Packed: {packed_targets.shape}\")\n",
445
+ "\n",
446
+ " # Tạo HDF5 file\n",
447
+ " dt = h5py.vlen_dtype(np.dtype('uint8'))\n",
448
+ " hdf5_path = os.path.join(output_dir, f\"{split}_mp3.hdf5\")\n",
449
+ "\n",
450
+ " with h5py.File(hdf5_path, 'w') as hf:\n",
451
+ " # Tạo datasets\n",
452
+ " audio_name_ds = hf.create_dataset('audio_name', shape=(0,), maxshape=(None,), dtype='S200')\n",
453
+ " waveform_ds = hf.create_dataset('mp3', shape=(0,), maxshape=(None,), dtype=dt)\n",
454
+ " target_ds = hf.create_dataset('target', shape=(0, packed_len), maxshape=(None, packed_len), dtype=packed_targets.dtype)\n",
455
+ "\n",
456
+ " # Lưu class info\n",
457
+ " hf.attrs['classes'] = [c.encode('utf-8') for c in classes]\n",
458
+ " hf.attrs['classes_num'] = classes_num\n",
459
+ "\n",
460
+ " valid_count = 0\n",
461
+ " error_count = 0\n",
462
+ "\n",
463
+ " # Process từng file với tqdm\n",
464
+ " pbar = tqdm(zip(audio_files, labels),\n",
465
+ " total=len(audio_files),\n",
466
+ " desc=f\"Converting {split}\")\n",
467
+ "\n",
468
+ " for i, (audio_file, label) in enumerate(pbar):\n",
469
+ " try:\n",
470
+ " # Đọc file MP3\n",
471
+ " audio_data = np.fromfile(audio_file, dtype='uint8')\n",
472
+ "\n",
473
+ " # Kiểm tra tính hợp lệ bằng cách decode\n",
474
+ " decode_mp3(audio_data)\n",
475
+ "\n",
476
+ " # Resize datasets\n",
477
+ " audio_name_ds.resize((valid_count + 1,))\n",
478
+ " waveform_ds.resize((valid_count + 1,))\n",
479
+ " target_ds.resize((valid_count + 1, packed_len))\n",
480
+ "\n",
481
+ " # Lưu data\n",
482
+ " filename = os.path.basename(audio_file).encode('utf-8')\n",
483
+ " audio_name_ds[valid_count] = filename\n",
484
+ " waveform_ds[valid_count] = audio_data\n",
485
+ " target_ds[valid_count] = packed_targets[i]\n",
486
+ "\n",
487
+ " valid_count += 1\n",
488
+ "\n",
489
+ " # Update progress bar\n",
490
+ " pbar.set_postfix({\n",
491
+ " 'Valid': valid_count,\n",
492
+ " 'Errors': error_count,\n",
493
+ " 'Success Rate': f\"{valid_count/(i+1)*100:.1f}%\"\n",
494
+ " })\n",
495
+ "\n",
496
+ " except Exception as e:\n",
497
+ " error_count += 1\n",
498
+ " pbar.set_postfix({\n",
499
+ " 'Valid': valid_count,\n",
500
+ " 'Errors': error_count,\n",
501
+ " 'Success Rate': f\"{valid_count/(i+1)*100:.1f}%\"\n",
502
+ " })\n",
503
+ " if error_count <= 5: # Chỉ show 5 error đầu tiên\n",
504
+ " tqdm.write(f\"❌ Error processing {os.path.basename(audio_file)}: {e}\")\n",
505
+ " continue\n",
506
+ "\n",
507
+ " pbar.close()\n",
508
+ "\n",
509
+ " print(f\"Successfully processed {valid_count}/{len(audio_files)} files\")\n",
510
+ " print(f\"Saved to: {hdf5_path}\")\n",
511
+ "\n",
512
+ " except Exception as e:\n",
513
+ " print(f\"Error processing {split} set: {e}\")\n",
514
+ "\n",
515
+ "def main():\n",
516
+ " # Cấu hình paths\n",
517
+ " dataset_path = \"/content/dataset\" # Thay đổi path này\n",
518
+ " output_dir = \"/content/dataset_hdf5\" # Thay đổi path này\n",
519
+ "\n",
520
+ " # Chạy conversion\n",
521
+ " convert_to_hdf5(dataset_path, output_dir)\n",
522
+ " print(\"\\n=== Conversion completed! ===\")\n",
523
+ "\n",
524
+ "if __name__ == \"__main__\":\n",
525
+ " # Example usage:\n",
526
+ " # dataset_path = \"/content/audio_dataset\"\n",
527
+ " # output_dir = \"/content/output_hdf5\"\n",
528
+ " # convert_to_hdf5(dataset_path, output_dir)\n",
529
+ " main()"
530
+ ],
531
+ "metadata": {
532
+ "id": "lcdNaKMx59ip"
533
+ },
534
+ "execution_count": null,
535
+ "outputs": []
536
+ }
537
+ ]
538
+ }
Train_script/EfficientAT_code_train.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ed1ec9ad0e90da763c5abc4a27f99fad97d073e5f694960c8981deae04705177
3
- size 22674058
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05f83c9fa5de22cc1153c30e0adff06eea1c0a11b6be6cc6ff2a466006b95fe1
3
+ size 22666093
Train_script/Train_guide.ipynb ADDED
@@ -0,0 +1,1117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "cd7ca034-54d2-434d-9ba8-9e1877e7a7c9",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "!unzip -q '/workspace/EfficientAT_code_train.zip' -d './train'"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 3,
16
+ "id": "2f674532-597d-4471-aef8-53c6f2b3ce7f",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "!unzip -q '/workspace/HDF5_file.zip' -d '/workspace/train/EfficientAT-main/datasets'"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "c7595bbe-a43b-4e08-a7de-1e27206a782b",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "#Please change the dataset_config (filename hdf5) in \"audioset.py\" in \"EfficientAT-main/datasets\" "
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 4,
36
+ "id": "7dd14b84-3451-423f-bc07-16fddddc2a07",
37
+ "metadata": {},
38
+ "outputs": [
39
+ {
40
+ "name": "stdout",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "Collecting av (from -r /workspace/train/EfficientAT-main/requirements.txt (line 1))\n",
44
+ " Downloading av-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (4.6 kB)\n",
45
+ "Collecting h5py (from -r /workspace/train/EfficientAT-main/requirements.txt (line 2))\n",
46
+ " Downloading h5py-3.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.7 kB)\n",
47
+ "Collecting librosa (from -r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
48
+ " Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)\n",
49
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from -r /workspace/train/EfficientAT-main/requirements.txt (line 4)) (1.24.1)\n",
50
+ "Collecting scikit_learn (from -r /workspace/train/EfficientAT-main/requirements.txt (line 5))\n",
51
+ " Downloading scikit_learn-1.7.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n",
52
+ "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from -r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (2.1.0+cu118)\n",
53
+ "Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (from -r /workspace/train/EfficientAT-main/requirements.txt (line 7)) (2.1.0+cu118)\n",
54
+ "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from -r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (0.16.0+cu118)\n",
55
+ "Collecting tqdm (from -r /workspace/train/EfficientAT-main/requirements.txt (line 9))\n",
56
+ " Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)\n",
57
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.7/57.7 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
58
+ "\u001b[?25hCollecting wandb (from -r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
59
+ " Downloading wandb-0.21.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)\n",
60
+ "Collecting pandas (from -r /workspace/train/EfficientAT-main/requirements.txt (line 11))\n",
61
+ " Downloading pandas-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)\n",
62
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
63
+ "\u001b[?25hCollecting seaborn (from -r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
64
+ " Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)\n",
65
+ "Collecting audioread>=2.1.9 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
66
+ " Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)\n",
67
+ "Collecting numba>=0.51.0 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
68
+ " Downloading numba-0.61.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)\n",
69
+ "Collecting scipy>=1.6.0 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
70
+ " Downloading scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
71
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.0/62.0 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
72
+ "\u001b[?25hCollecting joblib>=1.0 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
73
+ " Downloading joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)\n",
74
+ "Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3)) (5.1.1)\n",
75
+ "Collecting soundfile>=0.12.1 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
76
+ " Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)\n",
77
+ "Collecting pooch>=1.1 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
78
+ " Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)\n",
79
+ "Collecting soxr>=0.3.2 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
80
+ " Downloading soxr-0.5.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)\n",
81
+ "Requirement already satisfied: typing_extensions>=4.1.1 in /usr/local/lib/python3.10/dist-packages (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3)) (4.4.0)\n",
82
+ "Collecting lazy_loader>=0.1 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
83
+ " Downloading lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)\n",
84
+ "Collecting msgpack>=1.0 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
85
+ " Downloading msgpack-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)\n",
86
+ "Collecting threadpoolctl>=3.1.0 (from scikit_learn->-r /workspace/train/EfficientAT-main/requirements.txt (line 5))\n",
87
+ " Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)\n",
88
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (3.9.0)\n",
89
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (1.12)\n",
90
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (3.0)\n",
91
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (3.1.2)\n",
92
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (2023.4.0)\n",
93
+ "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (2.1.0)\n",
94
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (2.31.0)\n",
95
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (9.3.0)\n",
96
+ "Collecting click!=8.0.0,>=7.1 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
97
+ " Downloading click-8.2.1-py3-none-any.whl.metadata (2.5 kB)\n",
98
+ "Collecting gitpython!=3.1.29,>=1.0.0 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
99
+ " Downloading gitpython-3.1.45-py3-none-any.whl.metadata (13 kB)\n",
100
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10)) (23.2)\n",
101
+ "Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10)) (3.11.0)\n",
102
+ "Collecting protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
103
+ " Downloading protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)\n",
104
+ "Collecting pydantic<3 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
105
+ " Downloading pydantic-2.11.7-py3-none-any.whl.metadata (67 kB)\n",
106
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m68.0/68.0 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
107
+ "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10)) (6.0.1)\n",
108
+ "Collecting sentry-sdk>=2.0.0 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
109
+ " Downloading sentry_sdk-2.35.0-py2.py3-none-any.whl.metadata (10 kB)\n",
110
+ "Collecting typing_extensions>=4.1.1 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
111
+ " Downloading typing_extensions-4.14.1-py3-none-any.whl.metadata (3.0 kB)\n",
112
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->-r /workspace/train/EfficientAT-main/requirements.txt (line 11)) (2.8.2)\n",
113
+ "Collecting pytz>=2020.1 (from pandas->-r /workspace/train/EfficientAT-main/requirements.txt (line 11))\n",
114
+ " Downloading pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)\n",
115
+ "Collecting tzdata>=2022.7 (from pandas->-r /workspace/train/EfficientAT-main/requirements.txt (line 11))\n",
116
+ " Downloading tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)\n",
117
+ "Collecting matplotlib!=3.6.1,>=3.4 (from seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
118
+ " Downloading matplotlib-3.10.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n",
119
+ "Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
120
+ " Downloading gitdb-4.0.12-py3-none-any.whl.metadata (1.2 kB)\n",
121
+ "Collecting contourpy>=1.0.1 (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
122
+ " Downloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)\n",
123
+ "Collecting cycler>=0.10 (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
124
+ " Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)\n",
125
+ "Collecting fonttools>=4.22.0 (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
126
+ " Downloading fonttools-4.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (108 kB)\n",
127
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m108.9/108.9 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
128
+ "\u001b[?25hCollecting kiwisolver>=1.3.1 (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
129
+ " Downloading kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (6.3 kB)\n",
130
+ "Requirement already satisfied: pyparsing>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12)) (2.4.7)\n",
131
+ "Collecting llvmlite<0.45,>=0.44.0dev0 (from numba>=0.51.0->librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
132
+ " Downloading llvmlite-0.44.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.8 kB)\n",
133
+ "Collecting annotated-types>=0.6.0 (from pydantic<3->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
134
+ " Downloading annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)\n",
135
+ "Collecting pydantic-core==2.33.2 (from pydantic<3->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
136
+ " Downloading pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)\n",
137
+ "Collecting typing-inspection>=0.4.0 (from pydantic<3->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
138
+ " Downloading typing_inspection-0.4.1-py3-none-any.whl.metadata (2.6 kB)\n",
139
+ "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->-r /workspace/train/EfficientAT-main/requirements.txt (line 11)) (1.16.0)\n",
140
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (2.1.1)\n",
141
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (3.4)\n",
142
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (1.26.13)\n",
143
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (2022.12.7)\n",
144
+ "Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.10/dist-packages (from soundfile>=0.12.1->librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3)) (1.16.0)\n",
145
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (2.1.2)\n",
146
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (1.3.0)\n",
147
+ "Requirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi>=1.0->soundfile>=0.12.1->librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3)) (2.21)\n",
148
+ "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
149
+ " Downloading smmap-5.0.2-py3-none-any.whl.metadata (4.3 kB)\n",
150
+ "Downloading av-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.2 MB)\n",
151
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.2/39.2 MB\u001b[0m \u001b[31m46.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
152
+ "\u001b[?25hDownloading h5py-3.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)\n",
153
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m31.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
154
+ "\u001b[?25hDownloading librosa-0.11.0-py3-none-any.whl (260 kB)\n",
155
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m260.7/260.7 kB\u001b[0m \u001b[31m60.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
156
+ "\u001b[?25hDownloading scikit_learn-1.7.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)\n",
157
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.7/9.7 MB\u001b[0m \u001b[31m95.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n",
158
+ "\u001b[?25hDownloading tqdm-4.67.1-py3-none-any.whl (78 kB)\n",
159
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.5/78.5 kB\u001b[0m \u001b[31m23.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
160
+ "\u001b[?25hDownloading wandb-0.21.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.4 MB)\n",
161
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m22.4/22.4 MB\u001b[0m \u001b[31m105.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
162
+ "\u001b[?25hDownloading pandas-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.3 MB)\n",
163
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.3/12.3 MB\u001b[0m \u001b[31m106.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m0:01\u001b[0m\n",
164
+ "\u001b[?25hDownloading seaborn-0.13.2-py3-none-any.whl (294 kB)\n",
165
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m294.9/294.9 kB\u001b[0m \u001b[31m51.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
166
+ "\u001b[?25hDownloading audioread-3.0.1-py3-none-any.whl (23 kB)\n",
167
+ "Downloading click-8.2.1-py3-none-any.whl (102 kB)\n",
168
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.2/102.2 kB\u001b[0m \u001b[31m31.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
169
+ "\u001b[?25hDownloading gitpython-3.1.45-py3-none-any.whl (208 kB)\n",
170
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m208.2/208.2 kB\u001b[0m \u001b[31m56.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
171
+ "\u001b[?25hDownloading joblib-1.5.1-py3-none-any.whl (307 kB)\n",
172
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m307.7/307.7 kB\u001b[0m \u001b[31m68.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
173
+ "\u001b[?25hDownloading lazy_loader-0.4-py3-none-any.whl (12 kB)\n",
174
+ "Downloading matplotlib-3.10.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)\n",
175
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.7/8.7 MB\u001b[0m \u001b[31m91.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
176
+ "\u001b[?25hDownloading msgpack-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (408 kB)\n",
177
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━���━━━\u001b[0m \u001b[32m408.6/408.6 kB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
178
+ "\u001b[?25hDownloading numba-0.61.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.8 MB)\n",
179
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m138.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
180
+ "\u001b[?25hDownloading pooch-1.8.2-py3-none-any.whl (64 kB)\n",
181
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.6/64.6 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
182
+ "\u001b[?25hDownloading protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl (322 kB)\n",
183
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m322.0/322.0 kB\u001b[0m \u001b[31m49.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
184
+ "\u001b[?25hDownloading pydantic-2.11.7-py3-none-any.whl (444 kB)\n",
185
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m444.8/444.8 kB\u001b[0m \u001b[31m82.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
186
+ "\u001b[?25hDownloading pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n",
187
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m140.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
188
+ "\u001b[?25hDownloading pytz-2025.2-py2.py3-none-any.whl (509 kB)\n",
189
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m509.2/509.2 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n",
190
+ "\u001b[?25hDownloading scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.7 MB)\n",
191
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m37.7/37.7 MB\u001b[0m \u001b[31m43.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
192
+ "\u001b[?25hDownloading sentry_sdk-2.35.0-py2.py3-none-any.whl (363 kB)\n",
193
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.8/363.8 kB\u001b[0m \u001b[31m69.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
194
+ "\u001b[?25hDownloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl (1.3 MB)\n",
195
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m128.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
196
+ "\u001b[?25hDownloading soxr-0.5.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (252 kB)\n",
197
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m252.8/252.8 kB\u001b[0m \u001b[31m44.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
198
+ "\u001b[?25hDownloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)\n",
199
+ "Downloading typing_extensions-4.14.1-py3-none-any.whl (43 kB)\n",
200
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.9/43.9 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
201
+ "\u001b[?25hDownloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)\n",
202
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m347.8/347.8 kB\u001b[0m \u001b[31m71.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
203
+ "\u001b[?25hDownloading annotated_types-0.7.0-py3-none-any.whl (13 kB)\n",
204
+ "Downloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (325 kB)\n",
205
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m325.0/325.0 kB\u001b[0m \u001b[31m39.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
206
+ "\u001b[?25hDownloading cycler-0.12.1-py3-none-any.whl (8.3 kB)\n",
207
+ "Downloading fonttools-4.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (4.8 MB)\n",
208
+ "\u001b[2K \u001b[90m��━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.8/4.8 MB\u001b[0m \u001b[31m113.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n",
209
+ "\u001b[?25hDownloading gitdb-4.0.12-py3-none-any.whl (62 kB)\n",
210
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.8/62.8 kB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
211
+ "\u001b[?25hDownloading kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.6 MB)\n",
212
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m72.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
213
+ "\u001b[?25hDownloading llvmlite-0.44.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (42.4 MB)\n",
214
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.4/42.4 MB\u001b[0m \u001b[31m47.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
215
+ "\u001b[?25hDownloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)\n",
216
+ "Downloading smmap-5.0.2-py3-none-any.whl (24 kB)\n",
217
+ "Installing collected packages: pytz, tzdata, typing_extensions, tqdm, threadpoolctl, soxr, smmap, sentry-sdk, scipy, protobuf, msgpack, llvmlite, lazy_loader, kiwisolver, joblib, h5py, fonttools, cycler, contourpy, click, av, audioread, annotated-types, typing-inspection, soundfile, scikit_learn, pydantic-core, pooch, pandas, numba, matplotlib, gitdb, seaborn, pydantic, librosa, gitpython, wandb\n",
218
+ " Attempting uninstall: typing_extensions\n",
219
+ " Found existing installation: typing_extensions 4.4.0\n",
220
+ " Uninstalling typing_extensions-4.4.0:\n",
221
+ " Successfully uninstalled typing_extensions-4.4.0\n",
222
+ "Successfully installed annotated-types-0.7.0 audioread-3.0.1 av-15.0.0 click-8.2.1 contourpy-1.3.2 cycler-0.12.1 fonttools-4.59.1 gitdb-4.0.12 gitpython-3.1.45 h5py-3.14.0 joblib-1.5.1 kiwisolver-1.4.9 lazy_loader-0.4 librosa-0.11.0 llvmlite-0.44.0 matplotlib-3.10.5 msgpack-1.1.1 numba-0.61.2 pandas-2.3.1 pooch-1.8.2 protobuf-6.32.0 pydantic-2.11.7 pydantic-core-2.33.2 pytz-2025.2 scikit_learn-1.7.1 scipy-1.15.3 seaborn-0.13.2 sentry-sdk-2.35.0 smmap-5.0.2 soundfile-0.13.1 soxr-0.5.0.post1 threadpoolctl-3.6.0 tqdm-4.67.1 typing-inspection-0.4.1 typing_extensions-4.14.1 tzdata-2025.2 wandb-0.21.1\n",
223
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
224
+ "\u001b[0m\n",
225
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.2\u001b[0m\n",
226
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
227
+ ]
228
+ }
229
+ ],
230
+ "source": [
231
+ "!pip install -r /workspace/train/EfficientAT-main/requirements.txt"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": 5,
237
+ "id": "5a7a42e9-3f79-4f63-a6ea-98db2206ca96",
238
+ "metadata": {},
239
+ "outputs": [
240
+ {
241
+ "name": "stdout",
242
+ "output_type": "stream",
243
+ "text": [
244
+ "/workspace/train/EfficientAT-main\n"
245
+ ]
246
+ },
247
+ {
248
+ "name": "stderr",
249
+ "output_type": "stream",
250
+ "text": [
251
+ "/usr/local/lib/python3.10/dist-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.\n",
252
+ " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
253
+ ]
254
+ }
255
+ ],
256
+ "source": [
257
+ "%cd /workspace/train/EfficientAT-main"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "id": "d0abf14b-46a1-4f92-9f70-e0de40dff46a",
264
+ "metadata": {},
265
+ "outputs": [
266
+ {
267
+ "name": "stdout",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "INFO: FMAX is None, setting to 15000 \n",
271
+ "The number of classes in the loaded state dict (=527) and the current model (=7) is not the same. Dropping final fully-connected layer and loading weights in non-strict mode!\n",
272
+ "DyMN(\n",
273
+ " (layers): ModuleList(\n",
274
+ " (0): DY_Block(\n",
275
+ " (exp_conv): DynamicWrapper(\n",
276
+ " (module): Identity()\n",
277
+ " )\n",
278
+ " (exp_norm): Identity()\n",
279
+ " (exp_act): DynamicWrapper(\n",
280
+ " (module): Identity()\n",
281
+ " )\n",
282
+ " (depth_conv): DynamicConv(\n",
283
+ " (residuals): Sequential(\n",
284
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
285
+ " )\n",
286
+ " )\n",
287
+ " (depth_norm): BatchNorm2d(8, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
288
+ " (depth_act): DyReLUB(\n",
289
+ " (coef_net): Sequential(\n",
290
+ " (0): Linear(in_features=16, out_features=32, bias=True)\n",
291
+ " )\n",
292
+ " (sigmoid): Sigmoid()\n",
293
+ " )\n",
294
+ " (ca): CoordAtt()\n",
295
+ " (proj_conv): DynamicConv(\n",
296
+ " (residuals): Sequential(\n",
297
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
298
+ " )\n",
299
+ " )\n",
300
+ " (proj_norm): BatchNorm2d(8, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
301
+ " (context_gen): ContextGen(\n",
302
+ " (joint_conv): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
303
+ " (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
304
+ " (joint_act): Hardswish()\n",
305
+ " (conv_f): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))\n",
306
+ " (conv_t): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))\n",
307
+ " (pool_f): Sequential()\n",
308
+ " (pool_t): Sequential()\n",
309
+ " )\n",
310
+ " )\n",
311
+ " (1): DY_Block(\n",
312
+ " (exp_conv): DynamicConv(\n",
313
+ " (residuals): Sequential(\n",
314
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
315
+ " )\n",
316
+ " )\n",
317
+ " (exp_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
318
+ " (exp_act): DynamicWrapper(\n",
319
+ " (module): ReLU(inplace=True)\n",
320
+ " )\n",
321
+ " (depth_conv): DynamicConv(\n",
322
+ " (residuals): Sequential(\n",
323
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
324
+ " )\n",
325
+ " )\n",
326
+ " (depth_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
327
+ " (depth_act): DyReLUB(\n",
328
+ " (coef_net): Sequential(\n",
329
+ " (0): Linear(in_features=16, out_features=96, bias=True)\n",
330
+ " )\n",
331
+ " (sigmoid): Sigmoid()\n",
332
+ " )\n",
333
+ " (ca): CoordAtt()\n",
334
+ " (proj_conv): DynamicConv(\n",
335
+ " (residuals): Sequential(\n",
336
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
337
+ " )\n",
338
+ " )\n",
339
+ " (proj_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
340
+ " (context_gen): ContextGen(\n",
341
+ " (joint_conv): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
342
+ " (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
343
+ " (joint_act): Hardswish()\n",
344
+ " (conv_f): Conv2d(16, 24, kernel_size=(1, 1), stride=(1, 1))\n",
345
+ " (conv_t): Conv2d(16, 24, kernel_size=(1, 1), stride=(1, 1))\n",
346
+ " (pool_f): AvgPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
347
+ " (pool_t): AvgPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))\n",
348
+ " )\n",
349
+ " )\n",
350
+ " (2): DY_Block(\n",
351
+ " (exp_conv): DynamicConv(\n",
352
+ " (residuals): Sequential(\n",
353
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
354
+ " )\n",
355
+ " )\n",
356
+ " (exp_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
357
+ " (exp_act): DynamicWrapper(\n",
358
+ " (module): ReLU(inplace=True)\n",
359
+ " )\n",
360
+ " (depth_conv): DynamicConv(\n",
361
+ " (residuals): Sequential(\n",
362
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
363
+ " )\n",
364
+ " )\n",
365
+ " (depth_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
366
+ " (depth_act): DyReLUB(\n",
367
+ " (coef_net): Sequential(\n",
368
+ " (0): Linear(in_features=16, out_features=128, bias=True)\n",
369
+ " )\n",
370
+ " (sigmoid): Sigmoid()\n",
371
+ " )\n",
372
+ " (ca): CoordAtt()\n",
373
+ " (proj_conv): DynamicConv(\n",
374
+ " (residuals): Sequential(\n",
375
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
376
+ " )\n",
377
+ " )\n",
378
+ " (proj_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
379
+ " (context_gen): ContextGen(\n",
380
+ " (joint_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
381
+ " (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
382
+ " (joint_act): Hardswish()\n",
383
+ " (conv_f): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
384
+ " (conv_t): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
385
+ " (pool_f): Sequential()\n",
386
+ " (pool_t): Sequential()\n",
387
+ " )\n",
388
+ " )\n",
389
+ " (3): DY_Block(\n",
390
+ " (exp_conv): DynamicConv(\n",
391
+ " (residuals): Sequential(\n",
392
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
393
+ " )\n",
394
+ " )\n",
395
+ " (exp_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
396
+ " (exp_act): DynamicWrapper(\n",
397
+ " (module): ReLU(inplace=True)\n",
398
+ " )\n",
399
+ " (depth_conv): DynamicConv(\n",
400
+ " (residuals): Sequential(\n",
401
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
402
+ " )\n",
403
+ " )\n",
404
+ " (depth_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
405
+ " (depth_act): DyReLUB(\n",
406
+ " (coef_net): Sequential(\n",
407
+ " (0): Linear(in_features=16, out_features=128, bias=True)\n",
408
+ " )\n",
409
+ " (sigmoid): Sigmoid()\n",
410
+ " )\n",
411
+ " (ca): CoordAtt()\n",
412
+ " (proj_conv): DynamicConv(\n",
413
+ " (residuals): Sequential(\n",
414
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
415
+ " )\n",
416
+ " )\n",
417
+ " (proj_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
418
+ " (context_gen): ContextGen(\n",
419
+ " (joint_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
420
+ " (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
421
+ " (joint_act): Hardswish()\n",
422
+ " (conv_f): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
423
+ " (conv_t): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
424
+ " (pool_f): AvgPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
425
+ " (pool_t): AvgPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))\n",
426
+ " )\n",
427
+ " )\n",
428
+ " (4-5): 2 x DY_Block(\n",
429
+ " (exp_conv): DynamicConv(\n",
430
+ " (residuals): Sequential(\n",
431
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
432
+ " )\n",
433
+ " )\n",
434
+ " (exp_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
435
+ " (exp_act): DynamicWrapper(\n",
436
+ " (module): ReLU(inplace=True)\n",
437
+ " )\n",
438
+ " (depth_conv): DynamicConv(\n",
439
+ " (residuals): Sequential(\n",
440
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
441
+ " )\n",
442
+ " )\n",
443
+ " (depth_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
444
+ " (depth_act): DyReLUB(\n",
445
+ " (coef_net): Sequential(\n",
446
+ " (0): Linear(in_features=16, out_features=192, bias=True)\n",
447
+ " )\n",
448
+ " (sigmoid): Sigmoid()\n",
449
+ " )\n",
450
+ " (ca): CoordAtt()\n",
451
+ " (proj_conv): DynamicConv(\n",
452
+ " (residuals): Sequential(\n",
453
+ " (0): Linear(in_features=16, out_features=4, bias=True)\n",
454
+ " )\n",
455
+ " )\n",
456
+ " (proj_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
457
+ " (context_gen): ContextGen(\n",
458
+ " (joint_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
459
+ " (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
460
+ " (joint_act): Hardswish()\n",
461
+ " (conv_f): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1))\n",
462
+ " (conv_t): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1))\n",
463
+ " (pool_f): Sequential()\n",
464
+ " (pool_t): Sequential()\n",
465
+ " )\n",
466
+ " )\n",
467
+ " (6): DY_Block(\n",
468
+ " (exp_conv): DynamicConv(\n",
469
+ " (residuals): Sequential(\n",
470
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
471
+ " )\n",
472
+ " )\n",
473
+ " (exp_norm): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
474
+ " (exp_act): DynamicWrapper(\n",
475
+ " (module): Hardswish()\n",
476
+ " )\n",
477
+ " (depth_conv): DynamicConv(\n",
478
+ " (residuals): Sequential(\n",
479
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
480
+ " )\n",
481
+ " )\n",
482
+ " (depth_norm): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
483
+ " (depth_act): DyReLUB(\n",
484
+ " (coef_net): Sequential(\n",
485
+ " (0): Linear(in_features=24, out_features=384, bias=True)\n",
486
+ " )\n",
487
+ " (sigmoid): Sigmoid()\n",
488
+ " )\n",
489
+ " (ca): CoordAtt()\n",
490
+ " (proj_conv): DynamicConv(\n",
491
+ " (residuals): Sequential(\n",
492
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
493
+ " )\n",
494
+ " )\n",
495
+ " (proj_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
496
+ " (context_gen): ContextGen(\n",
497
+ " (joint_conv): Conv2d(16, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
498
+ " (joint_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
499
+ " (joint_act): Hardswish()\n",
500
+ " (conv_f): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n",
501
+ " (conv_t): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n",
502
+ " (pool_f): AvgPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
503
+ " (pool_t): AvgPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))\n",
504
+ " )\n",
505
+ " )\n",
506
+ " (7): DY_Block(\n",
507
+ " (exp_conv): DynamicConv(\n",
508
+ " (residuals): Sequential(\n",
509
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
510
+ " )\n",
511
+ " )\n",
512
+ " (exp_norm): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
513
+ " (exp_act): DynamicWrapper(\n",
514
+ " (module): Hardswish()\n",
515
+ " )\n",
516
+ " (depth_conv): DynamicConv(\n",
517
+ " (residuals): Sequential(\n",
518
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
519
+ " )\n",
520
+ " )\n",
521
+ " (depth_norm): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
522
+ " (depth_act): DyReLUB(\n",
523
+ " (coef_net): Sequential(\n",
524
+ " (0): Linear(in_features=24, out_features=320, bias=True)\n",
525
+ " )\n",
526
+ " (sigmoid): Sigmoid()\n",
527
+ " )\n",
528
+ " (ca): CoordAtt()\n",
529
+ " (proj_conv): DynamicConv(\n",
530
+ " (residuals): Sequential(\n",
531
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
532
+ " )\n",
533
+ " )\n",
534
+ " (proj_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
535
+ " (context_gen): ContextGen(\n",
536
+ " (joint_conv): Conv2d(32, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
537
+ " (joint_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
538
+ " (joint_act): Hardswish()\n",
539
+ " (conv_f): Conv2d(24, 80, kernel_size=(1, 1), stride=(1, 1))\n",
540
+ " (conv_t): Conv2d(24, 80, kernel_size=(1, 1), stride=(1, 1))\n",
541
+ " (pool_f): Sequential()\n",
542
+ " (pool_t): Sequential()\n",
543
+ " )\n",
544
+ " )\n",
545
+ " (8-9): 2 x DY_Block(\n",
546
+ " (exp_conv): DynamicConv(\n",
547
+ " (residuals): Sequential(\n",
548
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
549
+ " )\n",
550
+ " )\n",
551
+ " (exp_norm): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
552
+ " (exp_act): DynamicWrapper(\n",
553
+ " (module): Hardswish()\n",
554
+ " )\n",
555
+ " (depth_conv): DynamicConv(\n",
556
+ " (residuals): Sequential(\n",
557
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
558
+ " )\n",
559
+ " )\n",
560
+ " (depth_norm): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
561
+ " (depth_act): DyReLUB(\n",
562
+ " (coef_net): Sequential(\n",
563
+ " (0): Linear(in_features=24, out_features=288, bias=True)\n",
564
+ " )\n",
565
+ " (sigmoid): Sigmoid()\n",
566
+ " )\n",
567
+ " (ca): CoordAtt()\n",
568
+ " (proj_conv): DynamicConv(\n",
569
+ " (residuals): Sequential(\n",
570
+ " (0): Linear(in_features=24, out_features=4, bias=True)\n",
571
+ " )\n",
572
+ " )\n",
573
+ " (proj_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
574
+ " (context_gen): ContextGen(\n",
575
+ " (joint_conv): Conv2d(32, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
576
+ " (joint_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
577
+ " (joint_act): Hardswish()\n",
578
+ " (conv_f): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))\n",
579
+ " (conv_t): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))\n",
580
+ " (pool_f): Sequential()\n",
581
+ " (pool_t): Sequential()\n",
582
+ " )\n",
583
+ " )\n",
584
+ " (10): DY_Block(\n",
585
+ " (exp_conv): DynamicConv(\n",
586
+ " (residuals): Sequential(\n",
587
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
588
+ " )\n",
589
+ " )\n",
590
+ " (exp_norm): BatchNorm2d(192, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
591
+ " (exp_act): DynamicWrapper(\n",
592
+ " (module): Hardswish()\n",
593
+ " )\n",
594
+ " (depth_conv): DynamicConv(\n",
595
+ " (residuals): Sequential(\n",
596
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
597
+ " )\n",
598
+ " )\n",
599
+ " (depth_norm): BatchNorm2d(192, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
600
+ " (depth_act): DyReLUB(\n",
601
+ " (coef_net): Sequential(\n",
602
+ " (0): Linear(in_features=48, out_features=768, bias=True)\n",
603
+ " )\n",
604
+ " (sigmoid): Sigmoid()\n",
605
+ " )\n",
606
+ " (ca): CoordAtt()\n",
607
+ " (proj_conv): DynamicConv(\n",
608
+ " (residuals): Sequential(\n",
609
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
610
+ " )\n",
611
+ " )\n",
612
+ " (proj_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
613
+ " (context_gen): ContextGen(\n",
614
+ " (joint_conv): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
615
+ " (joint_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
616
+ " (joint_act): Hardswish()\n",
617
+ " (conv_f): Conv2d(48, 192, kernel_size=(1, 1), stride=(1, 1))\n",
618
+ " (conv_t): Conv2d(48, 192, kernel_size=(1, 1), stride=(1, 1))\n",
619
+ " (pool_f): Sequential()\n",
620
+ " (pool_t): Sequential()\n",
621
+ " )\n",
622
+ " )\n",
623
+ " (11): DY_Block(\n",
624
+ " (exp_conv): DynamicConv(\n",
625
+ " (residuals): Sequential(\n",
626
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
627
+ " )\n",
628
+ " )\n",
629
+ " (exp_norm): BatchNorm2d(272, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
630
+ " (exp_act): DynamicWrapper(\n",
631
+ " (module): Hardswish()\n",
632
+ " )\n",
633
+ " (depth_conv): DynamicConv(\n",
634
+ " (residuals): Sequential(\n",
635
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
636
+ " )\n",
637
+ " )\n",
638
+ " (depth_norm): BatchNorm2d(272, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
639
+ " (depth_act): DyReLUB(\n",
640
+ " (coef_net): Sequential(\n",
641
+ " (0): Linear(in_features=48, out_features=1088, bias=True)\n",
642
+ " )\n",
643
+ " (sigmoid): Sigmoid()\n",
644
+ " )\n",
645
+ " (ca): CoordAtt()\n",
646
+ " (proj_conv): DynamicConv(\n",
647
+ " (residuals): Sequential(\n",
648
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
649
+ " )\n",
650
+ " )\n",
651
+ " (proj_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
652
+ " (context_gen): ContextGen(\n",
653
+ " (joint_conv): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
654
+ " (joint_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
655
+ " (joint_act): Hardswish()\n",
656
+ " (conv_f): Conv2d(48, 272, kernel_size=(1, 1), stride=(1, 1))\n",
657
+ " (conv_t): Conv2d(48, 272, kernel_size=(1, 1), stride=(1, 1))\n",
658
+ " (pool_f): Sequential()\n",
659
+ " (pool_t): Sequential()\n",
660
+ " )\n",
661
+ " )\n",
662
+ " (12): DY_Block(\n",
663
+ " (exp_conv): DynamicConv(\n",
664
+ " (residuals): Sequential(\n",
665
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
666
+ " )\n",
667
+ " )\n",
668
+ " (exp_norm): BatchNorm2d(272, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
669
+ " (exp_act): DynamicWrapper(\n",
670
+ " (module): Hardswish()\n",
671
+ " )\n",
672
+ " (depth_conv): DynamicConv(\n",
673
+ " (residuals): Sequential(\n",
674
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
675
+ " )\n",
676
+ " )\n",
677
+ " (depth_norm): BatchNorm2d(272, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
678
+ " (depth_act): DyReLUB(\n",
679
+ " (coef_net): Sequential(\n",
680
+ " (0): Linear(in_features=48, out_features=1088, bias=True)\n",
681
+ " )\n",
682
+ " (sigmoid): Sigmoid()\n",
683
+ " )\n",
684
+ " (ca): CoordAtt()\n",
685
+ " (proj_conv): DynamicConv(\n",
686
+ " (residuals): Sequential(\n",
687
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
688
+ " )\n",
689
+ " )\n",
690
+ " (proj_norm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
691
+ " (context_gen): ContextGen(\n",
692
+ " (joint_conv): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
693
+ " (joint_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
694
+ " (joint_act): Hardswish()\n",
695
+ " (conv_f): Conv2d(48, 272, kernel_size=(1, 1), stride=(1, 1))\n",
696
+ " (conv_t): Conv2d(48, 272, kernel_size=(1, 1), stride=(1, 1))\n",
697
+ " (pool_f): AvgPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
698
+ " (pool_t): AvgPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))\n",
699
+ " )\n",
700
+ " )\n",
701
+ " (13-14): 2 x DY_Block(\n",
702
+ " (exp_conv): DynamicConv(\n",
703
+ " (residuals): Sequential(\n",
704
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
705
+ " )\n",
706
+ " )\n",
707
+ " (exp_norm): BatchNorm2d(384, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
708
+ " (exp_act): DynamicWrapper(\n",
709
+ " (module): Hardswish()\n",
710
+ " )\n",
711
+ " (depth_conv): DynamicConv(\n",
712
+ " (residuals): Sequential(\n",
713
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
714
+ " )\n",
715
+ " )\n",
716
+ " (depth_norm): BatchNorm2d(384, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
717
+ " (depth_act): DyReLUB(\n",
718
+ " (coef_net): Sequential(\n",
719
+ " (0): Linear(in_features=48, out_features=1536, bias=True)\n",
720
+ " )\n",
721
+ " (sigmoid): Sigmoid()\n",
722
+ " )\n",
723
+ " (ca): CoordAtt()\n",
724
+ " (proj_conv): DynamicConv(\n",
725
+ " (residuals): Sequential(\n",
726
+ " (0): Linear(in_features=48, out_features=4, bias=True)\n",
727
+ " )\n",
728
+ " )\n",
729
+ " (proj_norm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
730
+ " (context_gen): ContextGen(\n",
731
+ " (joint_conv): Conv2d(64, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
732
+ " (joint_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
733
+ " (joint_act): Hardswish()\n",
734
+ " (conv_f): Conv2d(48, 384, kernel_size=(1, 1), stride=(1, 1))\n",
735
+ " (conv_t): Conv2d(48, 384, kernel_size=(1, 1), stride=(1, 1))\n",
736
+ " (pool_f): Sequential()\n",
737
+ " (pool_t): Sequential()\n",
738
+ " )\n",
739
+ " )\n",
740
+ " )\n",
741
+ " (in_c): Conv2dNormActivation(\n",
742
+ " (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
743
+ " (1): BatchNorm2d(8, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
744
+ " (2): Hardswish()\n",
745
+ " )\n",
746
+ " (out_c): Conv2dNormActivation(\n",
747
+ " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
748
+ " (1): BatchNorm2d(384, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
749
+ " (2): Hardswish()\n",
750
+ " )\n",
751
+ " (classifier): Sequential(\n",
752
+ " (0): AdaptiveAvgPool2d(output_size=1)\n",
753
+ " (1): Flatten(start_dim=1, end_dim=-1)\n",
754
+ " (2): Linear(in_features=384, out_features=512, bias=True)\n",
755
+ " (3): Hardswish()\n",
756
+ " (4): Dropout(p=0.2, inplace=True)\n",
757
+ " (5): Linear(in_features=512, out_features=7, bias=True)\n",
758
+ " )\n",
759
+ ")\n",
760
+ "Dataset from ./datasets/train_mp3.hdf5 with length 36878.\n",
761
+ "Warning: sample_weight_offset=100 minnow=1355.0\n",
762
+ "Dataset from ./datasets/test_mp3.hdf5 with length 9261.\n",
763
+ "Epoch 1/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 30.0\n",
764
+ "Setting temperature for attention over kernels to 30.0\n",
765
+ "Setting temperature for attention over kernels to 30.0\n",
766
+ "Setting temperature for attention over kernels to 30.0\n",
767
+ "Setting temperature for attention over kernels to 30.0\n",
768
+ "Setting temperature for attention over kernels to 30.0\n",
769
+ "Setting temperature for attention over kernels to 30.0\n",
770
+ "Setting temperature for attention over kernels to 30.0\n",
771
+ "Setting temperature for attention over kernels to 30.0\n",
772
+ "Setting temperature for attention over kernels to 30.0\n",
773
+ "Setting temperature for attention over kernels to 30.0\n",
774
+ "Setting temperature for attention over kernels to 30.0\n",
775
+ "Setting temperature for attention over kernels to 30.0\n",
776
+ "Setting temperature for attention over kernels to 30.0\n",
777
+ "Setting temperature for attention over kernels to 30.0\n",
778
+ "Setting temperature for attention over kernels to 30.0\n",
779
+ "Setting temperature for attention over kernels to 30.0\n",
780
+ "Setting temperature for attention over kernels to 30.0\n",
781
+ "Setting temperature for attention over kernels to 30.0\n",
782
+ "Setting temperature for attention over kernels to 30.0\n",
783
+ "Setting temperature for attention over kernels to 30.0\n",
784
+ "Setting temperature for attention over kernels to 30.0\n",
785
+ "Setting temperature for attention over kernels to 30.0\n",
786
+ "Setting temperature for attention over kernels to 30.0\n",
787
+ "Setting temperature for attention over kernels to 30.0\n",
788
+ "Setting temperature for attention over kernels to 30.0\n",
789
+ "Setting temperature for attention over kernels to 30.0\n",
790
+ "Setting temperature for attention over kernels to 30.0\n",
791
+ "Setting temperature for attention over kernels to 30.0\n",
792
+ "Setting temperature for attention over kernels to 30.0\n",
793
+ "Setting temperature for attention over kernels to 30.0\n",
794
+ "Setting temperature for attention over kernels to 30.0\n",
795
+ "Setting temperature for attention over kernels to 30.0\n",
796
+ "Setting temperature for attention over kernels to 30.0\n",
797
+ "Setting temperature for attention over kernels to 30.0\n",
798
+ "Setting temperature for attention over kernels to 30.0\n",
799
+ "Setting temperature for attention over kernels to 30.0\n",
800
+ "Setting temperature for attention over kernels to 30.0\n",
801
+ "Setting temperature for attention over kernels to 30.0\n",
802
+ "Setting temperature for attention over kernels to 30.0\n",
803
+ "Setting temperature for attention over kernels to 30.0\n",
804
+ "Setting temperature for attention over kernels to 30.0\n",
805
+ "Setting temperature for attention over kernels to 30.0\n",
806
+ "Setting temperature for attention over kernels to 30.0\n",
807
+ "Epoch 1/200: 100%|████████████| 145/145 [01:15<00:00, 1.93it/s, train_loss=1.8]\n",
808
+ "Validating: 100%|███████████████████████████████| 37/37 [00:09<00:00, 3.87it/s]\n",
809
+ "Confusion Matrix:\n",
810
+ "[[ 529 0 0 76 656 19 0]\n",
811
+ " [ 70 2 0 6 799 5 0]\n",
812
+ " [ 99 0 0 1 311 11 0]\n",
813
+ " [ 440 0 0 104 1255 22 0]\n",
814
+ " [ 115 0 1 34 2903 12 0]\n",
815
+ " [ 195 0 2 51 1063 150 0]\n",
816
+ " [ 69 0 0 4 239 18 0]]\n",
817
+ "Epoch 1/200, Train Loss: 1.8010, Validation Loss: 1.7283, Validation Accuracy: 0.3982, LR: 0.000012, Time: 84.82s\n",
818
+ "Epoch 2/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 29.0\n",
819
+ "Setting temperature for attention over kernels to 29.0\n",
820
+ "Setting temperature for attention over kernels to 29.0\n",
821
+ "Setting temperature for attention over kernels to 29.0\n",
822
+ "Setting temperature for attention over kernels to 29.0\n",
823
+ "Setting temperature for attention over kernels to 29.0\n",
824
+ "Setting temperature for attention over kernels to 29.0\n",
825
+ "Setting temperature for attention over kernels to 29.0\n",
826
+ "Setting temperature for attention over kernels to 29.0\n",
827
+ "Setting temperature for attention over kernels to 29.0\n",
828
+ "Setting temperature for attention over kernels to 29.0\n",
829
+ "Setting temperature for attention over kernels to 29.0\n",
830
+ "Setting temperature for attention over kernels to 29.0\n",
831
+ "Setting temperature for attention over kernels to 29.0\n",
832
+ "Setting temperature for attention over kernels to 29.0\n",
833
+ "Setting temperature for attention over kernels to 29.0\n",
834
+ "Setting temperature for attention over kernels to 29.0\n",
835
+ "Setting temperature for attention over kernels to 29.0\n",
836
+ "Setting temperature for attention over kernels to 29.0\n",
837
+ "Setting temperature for attention over kernels to 29.0\n",
838
+ "Setting temperature for attention over kernels to 29.0\n",
839
+ "Setting temperature for attention over kernels to 29.0\n",
840
+ "Setting temperature for attention over kernels to 29.0\n",
841
+ "Setting temperature for attention over kernels to 29.0\n",
842
+ "Setting temperature for attention over kernels to 29.0\n",
843
+ "Setting temperature for attention over kernels to 29.0\n",
844
+ "Setting temperature for attention over kernels to 29.0\n",
845
+ "Setting temperature for attention over kernels to 29.0\n",
846
+ "Setting temperature for attention over kernels to 29.0\n",
847
+ "Setting temperature for attention over kernels to 29.0\n",
848
+ "Setting temperature for attention over kernels to 29.0\n",
849
+ "Setting temperature for attention over kernels to 29.0\n",
850
+ "Setting temperature for attention over kernels to 29.0\n",
851
+ "Setting temperature for attention over kernels to 29.0\n",
852
+ "Setting temperature for attention over kernels to 29.0\n",
853
+ "Setting temperature for attention over kernels to 29.0\n",
854
+ "Setting temperature for attention over kernels to 29.0\n",
855
+ "Setting temperature for attention over kernels to 29.0\n",
856
+ "Setting temperature for attention over kernels to 29.0\n",
857
+ "Setting temperature for attention over kernels to 29.0\n",
858
+ "Setting temperature for attention over kernels to 29.0\n",
859
+ "Setting temperature for attention over kernels to 29.0\n",
860
+ "Setting temperature for attention over kernels to 29.0\n",
861
+ "Setting temperature for attention over kernels to 29.0\n",
862
+ "Epoch 2/200: 100%|███████████| 145/145 [01:14<00:00, 1.95it/s, train_loss=1.52]\n",
863
+ "Validating: 100%|███████████████████████████████| 37/37 [00:09<00:00, 3.87it/s]\n",
864
+ "Confusion Matrix:\n",
865
+ "[[ 438 3 0 324 484 31 0]\n",
866
+ " [ 41 38 0 51 744 8 0]\n",
867
+ " [ 80 7 0 26 294 15 0]\n",
868
+ " [ 158 5 0 519 1084 55 0]\n",
869
+ " [ 30 4 0 54 2936 41 0]\n",
870
+ " [ 52 10 0 140 956 303 0]\n",
871
+ " [ 37 11 0 21 232 29 0]]\n",
872
+ "Epoch 2/200, Train Loss: 1.5245, Validation Loss: 1.6068, Validation Accuracy: 0.4572, LR: 0.000022, Time: 84.03s\n",
873
+ "Epoch 3/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 28.0\n",
874
+ "Setting temperature for attention over kernels to 28.0\n",
875
+ "Setting temperature for attention over kernels to 28.0\n",
876
+ "Setting temperature for attention over kernels to 28.0\n",
877
+ "Setting temperature for attention over kernels to 28.0\n",
878
+ "Setting temperature for attention over kernels to 28.0\n",
879
+ "Setting temperature for attention over kernels to 28.0\n",
880
+ "Setting temperature for attention over kernels to 28.0\n",
881
+ "Setting temperature for attention over kernels to 28.0\n",
882
+ "Setting temperature for attention over kernels to 28.0\n",
883
+ "Setting temperature for attention over kernels to 28.0\n",
884
+ "Setting temperature for attention over kernels to 28.0\n",
885
+ "Setting temperature for attention over kernels to 28.0\n",
886
+ "Setting temperature for attention over kernels to 28.0\n",
887
+ "Setting temperature for attention over kernels to 28.0\n",
888
+ "Setting temperature for attention over kernels to 28.0\n",
889
+ "Setting temperature for attention over kernels to 28.0\n",
890
+ "Setting temperature for attention over kernels to 28.0\n",
891
+ "Setting temperature for attention over kernels to 28.0\n",
892
+ "Setting temperature for attention over kernels to 28.0\n",
893
+ "Setting temperature for attention over kernels to 28.0\n",
894
+ "Setting temperature for attention over kernels to 28.0\n",
895
+ "Setting temperature for attention over kernels to 28.0\n",
896
+ "Setting temperature for attention over kernels to 28.0\n",
897
+ "Setting temperature for attention over kernels to 28.0\n",
898
+ "Setting temperature for attention over kernels to 28.0\n",
899
+ "Setting temperature for attention over kernels to 28.0\n",
900
+ "Setting temperature for attention over kernels to 28.0\n",
901
+ "Setting temperature for attention over kernels to 28.0\n",
902
+ "Setting temperature for attention over kernels to 28.0\n",
903
+ "Setting temperature for attention over kernels to 28.0\n",
904
+ "Setting temperature for attention over kernels to 28.0\n",
905
+ "Setting temperature for attention over kernels to 28.0\n",
906
+ "Setting temperature for attention over kernels to 28.0\n",
907
+ "Setting temperature for attention over kernels to 28.0\n",
908
+ "Setting temperature for attention over kernels to 28.0\n",
909
+ "Setting temperature for attention over kernels to 28.0\n",
910
+ "Setting temperature for attention over kernels to 28.0\n",
911
+ "Setting temperature for attention over kernels to 28.0\n",
912
+ "Setting temperature for attention over kernels to 28.0\n",
913
+ "Setting temperature for attention over kernels to 28.0\n",
914
+ "Setting temperature for attention over kernels to 28.0\n",
915
+ "Setting temperature for attention over kernels to 28.0\n",
916
+ "Setting temperature for attention over kernels to 28.0\n",
917
+ "Epoch 3/200: 100%|███████████| 145/145 [01:14<00:00, 1.95it/s, train_loss=1.36]\n",
918
+ "Validating: 100%|███████████████████████████████| 37/37 [00:09<00:00, 3.88it/s]\n",
919
+ "Confusion Matrix:\n",
920
+ "[[ 112 0 0 723 324 121 0]\n",
921
+ " [ 13 139 0 160 518 52 0]\n",
922
+ " [ 9 0 0 148 199 66 0]\n",
923
+ " [ 5 1 0 950 765 100 0]\n",
924
+ " [ 6 0 0 222 2709 128 0]\n",
925
+ " [ 5 1 0 224 567 664 0]\n",
926
+ " [ 5 14 0 68 149 94 0]]\n",
927
+ "Epoch 3/200, Train Loss: 1.3576, Validation Loss: 1.6087, Validation Accuracy: 0.4939, LR: 0.000060, Time: 83.90s\n",
928
+ "Epoch 4/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 27.0\n",
929
+ "Setting temperature for attention over kernels to 27.0\n",
930
+ "Setting temperature for attention over kernels to 27.0\n",
931
+ "Setting temperature for attention over kernels to 27.0\n",
932
+ "Setting temperature for attention over kernels to 27.0\n",
933
+ "Setting temperature for attention over kernels to 27.0\n",
934
+ "Setting temperature for attention over kernels to 27.0\n",
935
+ "Setting temperature for attention over kernels to 27.0\n",
936
+ "Setting temperature for attention over kernels to 27.0\n",
937
+ "Setting temperature for attention over kernels to 27.0\n",
938
+ "Setting temperature for attention over kernels to 27.0\n",
939
+ "Setting temperature for attention over kernels to 27.0\n",
940
+ "Setting temperature for attention over kernels to 27.0\n",
941
+ "Setting temperature for attention over kernels to 27.0\n",
942
+ "Setting temperature for attention over kernels to 27.0\n",
943
+ "Setting temperature for attention over kernels to 27.0\n",
944
+ "Setting temperature for attention over kernels to 27.0\n",
945
+ "Setting temperature for attention over kernels to 27.0\n",
946
+ "Setting temperature for attention over kernels to 27.0\n",
947
+ "Setting temperature for attention over kernels to 27.0\n",
948
+ "Setting temperature for attention over kernels to 27.0\n",
949
+ "Setting temperature for attention over kernels to 27.0\n",
950
+ "Setting temperature for attention over kernels to 27.0\n",
951
+ "Setting temperature for attention over kernels to 27.0\n",
952
+ "Setting temperature for attention over kernels to 27.0\n",
953
+ "Setting temperature for attention over kernels to 27.0\n",
954
+ "Setting temperature for attention over kernels to 27.0\n",
955
+ "Setting temperature for attention over kernels to 27.0\n",
956
+ "Setting temperature for attention over kernels to 27.0\n",
957
+ "Setting temperature for attention over kernels to 27.0\n",
958
+ "Setting temperature for attention over kernels to 27.0\n",
959
+ "Setting temperature for attention over kernels to 27.0\n",
960
+ "Setting temperature for attention over kernels to 27.0\n",
961
+ "Setting temperature for attention over kernels to 27.0\n",
962
+ "Setting temperature for attention over kernels to 27.0\n",
963
+ "Setting temperature for attention over kernels to 27.0\n",
964
+ "Setting temperature for attention over kernels to 27.0\n",
965
+ "Setting temperature for attention over kernels to 27.0\n",
966
+ "Setting temperature for attention over kernels to 27.0\n",
967
+ "Setting temperature for attention over kernels to 27.0\n",
968
+ "Setting temperature for attention over kernels to 27.0\n",
969
+ "Setting temperature for attention over kernels to 27.0\n",
970
+ "Setting temperature for attention over kernels to 27.0\n",
971
+ "Setting temperature for attention over kernels to 27.0\n",
972
+ "Epoch 4/200: 100%|███████████| 145/145 [01:14<00:00, 1.95it/s, train_loss=1.26]\n",
973
+ "Validating: 100%|███████████████████████████████| 37/37 [00:09<00:00, 3.85it/s]\n",
974
+ "Confusion Matrix:\n",
975
+ "[[ 28 2 0 472 590 188 0]\n",
976
+ " [ 0 198 0 60 603 21 0]\n",
977
+ " [ 0 0 0 37 338 47 0]\n",
978
+ " [ 0 1 0 659 1077 84 0]\n",
979
+ " [ 0 0 0 21 3000 44 0]\n",
980
+ " [ 0 0 0 42 1005 414 0]\n",
981
+ " [ 0 15 0 7 213 95 0]]\n",
982
+ "Epoch 4/200, Train Loss: 1.2573, Validation Loss: 1.8212, Validation Accuracy: 0.4642, LR: 0.000142, Time: 84.06s\n",
983
+ "Epoch 5/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 26.0\n",
984
+ "Setting temperature for attention over kernels to 26.0\n",
985
+ "Setting temperature for attention over kernels to 26.0\n",
986
+ "Setting temperature for attention over kernels to 26.0\n",
987
+ "Setting temperature for attention over kernels to 26.0\n",
988
+ "Setting temperature for attention over kernels to 26.0\n",
989
+ "Setting temperature for attention over kernels to 26.0\n",
990
+ "Setting temperature for attention over kernels to 26.0\n",
991
+ "Setting temperature for attention over kernels to 26.0\n",
992
+ "Setting temperature for attention over kernels to 26.0\n",
993
+ "Setting temperature for attention over kernels to 26.0\n",
994
+ "Setting temperature for attention over kernels to 26.0\n",
995
+ "Setting temperature for attention over kernels to 26.0\n",
996
+ "Setting temperature for attention over kernels to 26.0\n",
997
+ "Setting temperature for attention over kernels to 26.0\n",
998
+ "Setting temperature for attention over kernels to 26.0\n",
999
+ "Setting temperature for attention over kernels to 26.0\n",
1000
+ "Setting temperature for attention over kernels to 26.0\n",
1001
+ "Setting temperature for attention over kernels to 26.0\n",
1002
+ "Setting temperature for attention over kernels to 26.0\n",
1003
+ "Setting temperature for attention over kernels to 26.0\n",
1004
+ "Setting temperature for attention over kernels to 26.0\n",
1005
+ "Setting temperature for attention over kernels to 26.0\n",
1006
+ "Setting temperature for attention over kernels to 26.0\n",
1007
+ "Setting temperature for attention over kernels to 26.0\n",
1008
+ "Setting temperature for attention over kernels to 26.0\n",
1009
+ "Setting temperature for attention over kernels to 26.0\n",
1010
+ "Setting temperature for attention over kernels to 26.0\n",
1011
+ "Setting temperature for attention over kernels to 26.0\n",
1012
+ "Setting temperature for attention over kernels to 26.0\n",
1013
+ "Setting temperature for attention over kernels to 26.0\n",
1014
+ "Setting temperature for attention over kernels to 26.0\n",
1015
+ "Setting temperature for attention over kernels to 26.0\n",
1016
+ "Setting temperature for attention over kernels to 26.0\n",
1017
+ "Setting temperature for attention over kernels to 26.0\n",
1018
+ "Setting temperature for attention over kernels to 26.0\n",
1019
+ "Setting temperature for attention over kernels to 26.0\n",
1020
+ "Setting temperature for attention over kernels to 26.0\n",
1021
+ "Setting temperature for attention over kernels to 26.0\n",
1022
+ "Setting temperature for attention over kernels to 26.0\n",
1023
+ "Setting temperature for attention over kernels to 26.0\n",
1024
+ "Setting temperature for attention over kernels to 26.0\n",
1025
+ "Setting temperature for attention over kernels to 26.0\n",
1026
+ "Setting temperature for attention over kernels to 26.0\n",
1027
+ "Epoch 5/200: 73%|████████ | 106/145 [00:55<00:19, 2.01it/s, train_loss=1.27]"
1028
+ ]
1029
+ }
1030
+ ],
1031
+ "source": [
1032
+ "!python ex_train_audio_classification_mode.py --cuda --train --pretrained --gain_augment=5 --n_epochs=200 --model_name=dymn04_im --batch_size=256 --max_lr=0.001 --pretrain_final_temp=30 --adamw"
1033
+ ]
1034
+ },
1035
+ {
1036
+ "cell_type": "code",
1037
+ "execution_count": 15,
1038
+ "id": "d3b0a027-1774-4725-9e29-0112001369fe",
1039
+ "metadata": {},
1040
+ "outputs": [],
1041
+ "source": [
1042
+ "!rm -rf '/workspace/train/EfficientAT-main/training_results_dymn04_im_20250816_031533'"
1043
+ ]
1044
+ },
1045
+ {
1046
+ "cell_type": "code",
1047
+ "execution_count": null,
1048
+ "id": "fe6df804-8e48-41fc-9dab-5c409d591379",
1049
+ "metadata": {},
1050
+ "outputs": [],
1051
+ "source": [
1052
+ "!python inference_classification.py --cuda --pre_trained_model_path='/workspace/main/EfficientAT-main/output_model/best_checkpoint/dymn04_im04_acc_0.8923.pt' --audio_path=\"/workspace/main/EfficientAT-main/03-01-01-01-02-01-24.mp3\""
1053
+ ]
1054
+ },
1055
+ {
1056
+ "cell_type": "code",
1057
+ "execution_count": null,
1058
+ "id": "3c8074fe-eedf-4cf6-8ab3-1815e733e6b9",
1059
+ "metadata": {},
1060
+ "outputs": [],
1061
+ "source": [
1062
+ "import os\n",
1063
+ "import shutil\n",
1064
+ "\n",
1065
+ "# Mount Google Drive nếu cần\n",
1066
+ "\n",
1067
+ "# Đường dẫn tới thư mục cần zip\n",
1068
+ "folder_to_zip = '/workspace/main/EfficientAT-main'\n",
1069
+ "output_zip = '/workspace/main/EfficientAT-main.zip'\n",
1070
+ "\n",
1071
+ "# Tạo file zip\n",
1072
+ "shutil.make_archive(output_zip.replace('.zip', ''), 'zip', folder_to_zip)\n",
1073
+ "\n",
1074
+ "print(f'Thư mục đã được nén thành công thành file zip: {output_zip}')\n"
1075
+ ]
1076
+ },
1077
+ {
1078
+ "cell_type": "code",
1079
+ "execution_count": 12,
1080
+ "id": "3c819525-cde6-4c3b-a24d-0bfb5e5c2b01",
1081
+ "metadata": {},
1082
+ "outputs": [],
1083
+ "source": [
1084
+ "import numpy as np"
1085
+ ]
1086
+ },
1087
+ {
1088
+ "cell_type": "code",
1089
+ "execution_count": null,
1090
+ "id": "f3edf2d9-07d7-4472-86ec-81b800dcecbd",
1091
+ "metadata": {},
1092
+ "outputs": [],
1093
+ "source": []
1094
+ }
1095
+ ],
1096
+ "metadata": {
1097
+ "kernelspec": {
1098
+ "display_name": "Python 3 (ipykernel)",
1099
+ "language": "python",
1100
+ "name": "python3"
1101
+ },
1102
+ "language_info": {
1103
+ "codemirror_mode": {
1104
+ "name": "ipython",
1105
+ "version": 3
1106
+ },
1107
+ "file_extension": ".py",
1108
+ "mimetype": "text/x-python",
1109
+ "name": "python",
1110
+ "nbconvert_exporter": "python",
1111
+ "pygments_lexer": "ipython3",
1112
+ "version": "3.10.12"
1113
+ }
1114
+ },
1115
+ "nbformat": 4,
1116
+ "nbformat_minor": 5
1117
+ }