Upload 4 files
Browse files
Export_to_ONNX_TFLite/Export_model_and_testmodel.ipynb
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"cells":[{"cell_type":"markdown","source":["#PT to ONNX"],"metadata":{"id":"qZcPxnj6n6kl"}},{"cell_type":"code","execution_count":1,"metadata":{"id":"mRy88TxHd7KE","collapsed":true,"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1747040415700,"user_tz":-420,"elapsed":1754,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"983ce468-07ee-45ac-9d66-cab24e71691c"},"outputs":[{"output_type":"stream","name":"stdout","text":["Cloning into 'EfficientAT'...\n","remote: Enumerating objects: 395, done.\u001b[K\n","remote: Counting objects: 100% (152/152), done.\u001b[K\n","remote: Compressing objects: 100% (37/37), done.\u001b[K\n","remote: Total 395 (delta 134), reused 115 (delta 115), pack-reused 243 (from 2)\u001b[K\n","Receiving objects: 100% (395/395), 2.64 MiB | 7.12 MiB/s, done.\n","Resolving deltas: 100% (241/241), done.\n"]}],"source":["!git clone https://github.com/fschmid56/EfficientAT.git"]},{"cell_type":"code","execution_count":2,"metadata":{"id":"_SG1Qcgz4zsy","colab":{"base_uri":"https://localhost:8080/"},"outputId":"9d9a9b8c-cd15-4a32-d9cb-b10c91ce4628","executionInfo":{"status":"ok","timestamp":1747040417249,"user_tz":-420,"elapsed":59,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["/content/EfficientAT\n"]}],"source":["%cd EfficientAT"]},{"cell_type":"code","source":["!pip install -r requirements.txt"],"metadata":{"id":"0yBzntF5sRtg","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1747040438423,"user_tz":-420,"elapsed":19609,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"fa2adbab-4d7a-4cb0-9cfb-9a10e45d2310"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting av==10.0.0 (from -r requirements.txt (line 1))\n"," Downloading av-10.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.5 kB)\n","Collecting h5py==3.7.0 (from -r requirements.txt (line 2))\n"," Downloading h5py-3.7.0.tar.gz (392 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m392.4/392.4 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n"," Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n"," Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n","Collecting librosa==0.9.2 (from -r requirements.txt (line 3))\n"," Downloading librosa-0.9.2-py3-none-any.whl.metadata (8.2 kB)\n","Collecting numpy==1.23.3 (from -r requirements.txt (line 4))\n"," Downloading numpy-1.23.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)\n","Collecting scikit_learn==1.1.3 (from -r requirements.txt (line 5))\n"," Downloading scikit_learn-1.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)\n","Collecting torch==1.13.0 (from -r requirements.txt (line 6))\n"," Downloading torch-1.13.0-cp311-cp311-manylinux1_x86_64.whl.metadata (24 kB)\n","\u001b[31mERROR: Ignored the following yanked versions: 2.0.0\u001b[0m\u001b[31m\n","\u001b[0m\u001b[31mERROR: Ignored the following versions that require a different python version: 1.21.2 Requires-Python >=3.7,<3.11; 1.21.3 Requires-Python >=3.7,<3.11; 1.21.4 Requires-Python >=3.7,<3.11; 1.21.5 Requires-Python >=3.7,<3.11; 1.21.6 Requires-Python >=3.7,<3.11\u001b[0m\u001b[31m\n","\u001b[0m\u001b[31mERROR: Could not find a version that satisfies the requirement torchaudio==0.13.0 (from versions: 2.0.1, 2.0.2, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0, 2.7.0)\u001b[0m\u001b[31m\n","\u001b[0m\u001b[31mERROR: No matching distribution found for torchaudio==0.13.0\u001b[0m\u001b[31m\n","\u001b[0m"]}]},{"cell_type":"code","execution_count":null,"metadata":{"id":"szVfWuNA47pi"},"outputs":[],"source":["!python inference.py --model_name=/content/EfficientAT/resources/dymn04_im04_emotion.pt --audio_path=\"/content/EfficientAT/resources/bea_Angry_anger_1-28_0009.mp3\""]},{"cell_type":"code","execution_count":null,"metadata":{"id":"01JmuQGb64rW","colab":{"base_uri":"https://localhost:8080/"},"outputId":"09ccc631-5fb5-4226-c523-22fe623d4373"},"outputs":[{"output_type":"stream","name":"stdout","text":["python3: can't open file '/content/EfficientAT/inference_numpy.py': [Errno 2] No such file or directory\n"]}],"source":["!python inference_numpy.py --model_name=dymn04_as --audio_path=\"/content/EfficientAT/resources/recording_1741088005371.wav\""]},{"cell_type":"markdown","source":["#Export to onnx"],"metadata":{"id":"1r5BgfkNA5rc"}},{"cell_type":"code","execution_count":4,"metadata":{"id":"e4GFRBPU7j-b","executionInfo":{"status":"ok","timestamp":1747040464648,"user_tz":-420,"elapsed":12639,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[],"source":["import torch\n","from models.dymn.model import get_model as get_dymn\n","from helpers.utils import NAME_TO_WIDTH, labels\n","\n","# model = get_dymn(width_mult=NAME_TO_WIDTH('dymn04_as'), pretrained_name='dymn04_as', strides=[2, 2, 2, 2])\n","# model.to(torch.device('cpu'))\n","model = torch.load('/content/dymn04_im04_acc_0.7408.pt', map_location='cpu', weights_only=False)"]},{"cell_type":"code","source":["!pip install onnx onnxruntime"],"metadata":{"id":"01D0XYhkPfCL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1747040473170,"user_tz":-420,"elapsed":8523,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"fdf1845d-6c4c-4171-f12f-d55505bff4bd"},"execution_count":5,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting onnx\n"," Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)\n","Collecting onnxruntime\n"," Downloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)\n","Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.11/dist-packages (from onnx) (2.0.2)\n","Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.11/dist-packages (from onnx) (5.29.4)\n","Collecting coloredlogs (from onnxruntime)\n"," Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)\n","Requirement already satisfied: flatbuffers in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (25.2.10)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (24.2)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (1.13.1)\n","Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)\n"," Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)\n","Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->onnxruntime) (1.3.0)\n","Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.0 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.0/16.0 MB\u001b[0m \u001b[31m67.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.4 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.4/16.4 MB\u001b[0m \u001b[31m70.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: onnx, humanfriendly, coloredlogs, onnxruntime\n","Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnx-1.17.0 onnxruntime-1.22.0\n"]}]},{"cell_type":"code","execution_count":6,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"z_m2eIxQ3mnx","outputId":"ddc7d438-9d4c-4121-fe23-1e3f3b71a60d","executionInfo":{"status":"ok","timestamp":1747040487470,"user_tz":-420,"elapsed":9851,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stderr","text":["/content/EfficientAT/models/dymn/model.py:188: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n"," if return_fmaps:\n","/content/EfficientAT/models/dymn/dy_block.py:173: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n"," assert x.shape[1] == self.channels\n"]}],"source":["dummy_input = torch.randn(1, 1, 128, 400, device=\"cpu\")\n","\n","# Define dynamic axes for the input\n","# Here, we specify that the first dimension (batch size) is dynamic\n","dynamic_axes = {\n"," \"input\": {0: \"batch_size\", 3: \"audio_length\"}, # Input shape: (B, C, H, W)\n","}\n","\n","# Define input and output names\n","input_names = [\"input\"]\n","output_names = [\"output\"]\n","\n","# Export the model\n","torch.onnx.export(\n"," model,\n"," dummy_input,\n"," \"emotion.onnx\",\n"," opset_version=11\n",")"]},{"cell_type":"code","execution_count":10,"metadata":{"id":"mDyu47stDSbg","colab":{"base_uri":"https://localhost:8080/"},"outputId":"2c006904-0ebf-4cbc-c3f5-c6a3cc234076","executionInfo":{"status":"ok","timestamp":1747040650463,"user_tz":-420,"elapsed":19978,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["x shape before (1, 121728)\n","x shape after (1, 121727)\n","spec_x before is (1, 381, 513, 2)\n","spec_x shape is (1, 381, 513)\n","(1, 1, 128, 400)\n","[array([[ 3.432613 , -0.5243076, -0.6463419, -0.5853195, -0.5354723,\n"," -0.5868242, -0.6281568]], dtype=float32)]\n"]}],"source":["import onnxruntime as ort\n","import numpy as np\n","import librosa\n","# from preprocess import MelSTFT\n","\n","# Load the ONNX model\n","session = ort.InferenceSession(\"emotion.onnx\")\n","\n","# Get input and output details\n","input_name = session.get_inputs()[0].name\n","output_name = session.get_outputs()[0].name\n","\n","# preprocess input\n","mel = MelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)\n","(waveform, _) = librosa.core.load('/content/03-01-05-01-01-02-05_0.mp3', sr=32000, mono=True)\n","waveform = np.stack([waveform])\n","spec = mel(waveform)\n","if spec.shape[-1] > 400:\n"," spec = spec[:, :, :400]\n","else:\n"," spec = np.pad(spec, ((0, 0), (0, 0), (0, 400 - spec.shape[-1])), mode='constant')\n","spec = np.expand_dims(spec, axis=0)\n","spec = spec.astype(np.float32)\n","print(spec.shape)\n","\n","# Run inference\n","output = session.run([output_name], {input_name: spec})\n","\n","# Print the output\n","print(output)"]},{"cell_type":"code","execution_count":12,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"1GnlQkWu5_Fc","outputId":"a13d8c8b-896a-431c-a9bf-790226dea5b5","executionInfo":{"status":"ok","timestamp":1747040692360,"user_tz":-420,"elapsed":12,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["0\n"]}],"source":["len(output[0][0])\n","print(np.argmax(output[0][0]))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Gf__SsN55_yY"},"outputs":[],"source":["import numpy as np\n","\n","def sigmoid(x):\n"," return 1 / (1 + np.exp(-x))\n","\n","def softmax(x):\n"," exp_x = np.exp(x - np.max(x)) # Subtract max(x) for numerical stability\n"," return exp_x / np.sum(exp_x, axis=-1, keepdims=True)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"RVPMJbnc6roX"},"outputs":[],"source":["preds = softmax(output[0][0])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"OfR1epuQ6xPR"},"outputs":[],"source":["sorted_indexes = np.argsort(preds)[::-1]"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"gCEf2tdJ6x6Q","outputId":"1133b892-d5dd-488e-c987-fc82bcba2812"},"outputs":[{"output_type":"stream","name":"stdout","text":["************* Acoustic Event Detected: *****************\n","Angry: 0.914\n","sad: 0.015\n","Fear: 0.015\n","neutral: 0.014\n","Disgust: 0.014\n","Happy: 0.014\n","surprise: 0.014\n","********************************************************\n"]}],"source":["# Print audio tagging top probabilities\n","print(\"************* Acoustic Event Detected: *****************\")\n","for k in range(7):\n"," print('{}: {:.3f}'.format(labels[sorted_indexes[k]], preds[sorted_indexes[k]]))\n","print(\"********************************************************\")"]},{"cell_type":"markdown","source":["#Load sang onnx fix typepo"],"metadata":{"id":"zBJk6xG7A1e4"}},{"cell_type":"code","execution_count":2,"metadata":{"id":"uHKxnQ-w5ng4","executionInfo":{"status":"ok","timestamp":1747045236944,"user_tz":-420,"elapsed":41,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[],"source":["import onnx\n","\n","# Load ONNX model\n","onnx_model = onnx.load(\"/content/emotion.onnx\")"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"u1vlWtAF6Le9","outputId":"1a6c6507-7703-4732-9ac8-d4cfc487f813","executionInfo":{"status":"ok","timestamp":1747045239246,"user_tz":-420,"elapsed":129,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Model is valid.\n"]}],"source":["# Rename inputs\n","for i, input_node in enumerate(onnx_model.graph.input):\n"," old_name = input_node.name\n"," new_name = f\"input_{i}\"\n"," input_node.name = new_name\n"," for node in onnx_model.graph.node:\n"," for j, input_name in enumerate(node.input):\n"," if input_name == old_name:\n"," node.input[j] = new_name\n","onnx.checker.check_model(onnx_model)\n","print(\"Model is valid.\")"]},{"cell_type":"code","execution_count":4,"metadata":{"id":"RyYwyJq2LbOz","executionInfo":{"status":"ok","timestamp":1747045251036,"user_tz":-420,"elapsed":88,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"outputs":[],"source":["onnx.save(onnx_model, \"fixed_emotion.onnx\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"8_7au43LMHns"},"outputs":[],"source":["import onnxruntime as ort\n","import numpy as np\n","import librosa\n","# from preprocess import MelSTFT\n","\n","# Load the ONNX model\n","session = ort.InferenceSession(\"fixed_dymn04_as_4s.onnx\")\n","\n","# Get input and output details\n","input_name = session.get_inputs()[0].name\n","output_name = session.get_outputs()[0].name\n","\n","# preprocess input\n","mel = MelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)\n","(waveform, _) = librosa.core.load('/content/EfficientAT/resources/machine_gun_4_8.wav', sr=32000, mono=True)\n","waveform = np.stack([waveform])\n","spec = mel(waveform)\n","spec = np.expand_dims(spec, axis=0)\n","spec = spec.astype(np.float32)\n","\n","# Run inference\n","output = session.run([output_name], {input_name: spec})"]},{"cell_type":"code","source":["preds = sigmoid(output[0][0])\n","sorted_indexes = np.argsort(preds)[::-1]\n","# Print audio tagging top probabilities\n","print(\"************* Acoustic Event Detected: *****************\")\n","for k in range(10):\n"," print('{}: {:.3f}'.format(labels[sorted_indexes[k]], preds[sorted_indexes[k]]))\n","print(\"********************************************************\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"yI4BVCzZu043","outputId":"2e5ba910-2a6e-42b5-ed89-72c6c6fe5d50"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["************* Acoustic Event Detected: *****************\n","Gunshot, gunfire: 0.646\n","Artillery fire: 0.574\n","Machine gun: 0.296\n","Speech: 0.276\n","Music: 0.058\n","Ding: 0.039\n","Outside, rural or natural: 0.029\n","Clang: 0.021\n","Outside, urban or manmade: 0.014\n","Sound effect: 0.013\n","********************************************************\n"]}]},{"cell_type":"code","source":["import numpy as np\n","import tensorflow as tf\n","\n","# Load the TFLite model\n","interpreter = tf.lite.Interpreter(model_path=\"/content/EfficientAT/fixed_dymn04_as_4s.tflite\")\n","interpreter.allocate_tensors()\n","\n","# Get input and output details\n","input_details = interpreter.get_input_details()\n","output_details = interpreter.get_output_details()"],"metadata":{"id":"tYAA2SXg4rAm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%%time\n","interpreter.set_tensor(input_details[0]['index'], spec)\n","\n","# Run inference\n","interpreter.invoke()\n","output_data = interpreter.get_tensor(output_details[0]['index'])\n","print(output_data)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"lCURHCdb50NM","outputId":"b90fd17d-6435-47cd-bce4-d504667ba3a5"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[[ -0.84756243 -6.042206 -7.31538 -6.587422 -8.137321\n"," -6.8817334 -10.2073965 -7.0277896 -9.497076 -9.076806\n"," -7.9717975 -9.098734 -9.190367 -10.533342 -9.261131\n"," -9.558402 -8.0612755 -9.170078 -9.971562 -9.306829\n"," -10.51197 -9.227371 -8.791254 -7.1896567 -6.929867\n"," -6.8730307 -8.732161 -6.8715105 -7.095218 -9.634821\n"," -7.3266606 -6.496367 -6.867931 -7.2777967 -7.84522\n"," -9.982277 -9.040248 -9.472129 -6.970452 -6.7633314\n"," -8.596198 -8.716958 -11.37014 -9.757639 -9.384382\n"," -11.648908 -9.068704 -9.676451 -9.729941 -10.329362\n"," -11.040093 -7.60864 -9.644373 -7.8703647 -10.20223\n"," -11.044352 -11.488615 -11.821376 -8.239668 -9.515796\n"," -10.249389 -11.395409 -12.753719 -9.719765 -7.815147\n"," -8.667467 -8.783001 -10.783493 -9.077807 -6.502072\n"," -6.641275 -9.622646 -5.2108536 -6.345402 -7.260418\n"," -9.74338 -9.060996 -9.3590765 -9.125456 -9.142018\n"," -8.84309 -7.5067797 -10.2783375 -7.242247 -8.52377\n"," -9.487844 -7.1616087 -8.470291 -7.9428587 -10.482567\n"," -7.625914 -8.034529 -6.576645 -8.566756 -7.922304\n"," -7.995202 -6.649238 -6.5045075 -6.472725 -8.503324\n"," -9.17288 -9.43606 -8.100371 -8.338065 -6.9929743\n"," -6.991773 -7.0988426 -6.6684513 -7.944296 -8.298891\n"," -7.6659217 -6.260497 -7.467764 -7.7779393 -12.417544\n"," -8.141451 -8.614699 -8.081804 -8.368467 -9.225835\n"," -11.658823 -11.246972 -8.921834 -10.21366 -10.768544\n"," -11.065884 -8.270028 -11.009446 -9.79741 -8.229642\n"," -9.637291 -8.985344 -8.959486 -10.116044 -9.153588\n"," -8.426435 -10.440107 -2.8482046 -5.147984 -7.8000393\n"," -7.3953137 -8.8387575 -8.483301 -8.403142 -9.870607\n"," -10.543358 -8.418191 -11.51122 -11.29439 -11.705905\n"," -9.793888 -10.239387 -7.895875 -8.369129 -9.151202\n"," -8.415587 -9.385785 -8.787096 -8.60712 -9.168413\n"," -10.053816 -6.7564883 -6.962956 -8.736223 -6.6655746\n"," -7.655622 -7.969426 -8.463739 -7.021253 -9.278132\n"," -11.818462 -8.614843 -9.003474 -8.835054 -9.851002\n"," -11.832719 -12.2729225 -9.901597 -7.088213 -6.365678\n"," -5.9258265 -6.109941 -9.413572 -10.997083 -8.5274315\n"," -8.2669525 -9.395208 -9.664686 -9.7653475 -8.847117\n"," -10.921641 -8.482299 -8.912869 -9.325035 -9.669159\n"," -8.549017 -9.760146 -10.429775 -10.085978 -10.271475\n"," -6.654699 -8.879931 -7.6676474 -6.643718 -9.249097\n"," -7.136822 -7.749438 -10.257306 -10.486742 -11.123094\n"," -11.035789 -9.469211 -12.157661 -10.533059 -8.657654\n"," -9.568329 -8.964076 -8.243627 -7.680642 -8.45268\n"," -9.282619 -9.683128 -10.821175 -9.947752 -8.673843\n"," -10.237195 -8.564869 -8.619034 -9.909412 -9.446312\n"," -10.431713 -10.449768 -8.876116 -9.539524 -10.091715\n"," -9.048174 -8.702717 -8.101336 -8.736906 -6.4872026\n"," -8.30099 -7.705913 -8.126543 -9.554791 -7.7878036\n"," -9.843629 -8.260341 -9.702665 -9.453328 -11.923137\n"," -10.03031 -8.647557 -9.144465 -8.42806 -6.997603\n"," -9.046401 -9.894837 -10.35937 -8.140752 -8.345273\n"," -9.8691225 -10.458608 -10.322532 -10.374186 -9.572269\n"," -8.6276865 -8.302613 -8.308632 -9.004797 -9.458067\n"," -7.9622707 -6.8402877 -9.689645 -8.055686 -8.758229\n"," -9.216762 -9.273148 -9.217585 -8.875041 -8.127402\n"," -8.534075 -11.847649 -10.116147 -6.0094028 -8.0422325\n"," -5.8051434 -8.198613 -7.914464 -7.742989 -8.479749\n"," -9.025269 -8.967785 -8.4304 -8.63103 -7.5055\n"," -7.8533816 -9.185931 -7.454551 -8.6782465 -7.8421903\n"," -4.6939163 -5.254039 -6.968486 -7.659274 -6.2444606\n"," -7.420964 -8.172086 -6.747325 -8.602886 -10.759201\n"," -9.839126 -10.491294 -9.961649 -10.368932 -8.411766\n"," -8.598401 -8.4753475 -9.222321 -10.77521 -10.591006\n"," -11.217412 -8.678829 -9.774086 -9.107769 -10.6897135\n"," -10.674824 -8.642229 -10.717334 -7.3090887 -6.9824114\n"," -10.106813 -9.793407 -7.4900055 -9.854651 -8.700833\n"," -8.348331 -10.050314 -9.854686 -9.308167 -8.760602\n"," -8.947077 -8.997547 -9.107168 -7.7415004 -10.546346\n"," -12.421447 -10.762346 -9.853303 -8.710389 -9.031233\n"," -11.481606 -9.634319 -8.992513 -7.9649596 -7.861982\n"," -8.120212 -8.570669 -8.901415 -7.302056 -7.901267\n"," -8.765028 -10.421144 -13.3499775 -11.8135605 -8.224922\n"," -9.03608 -11.085783 -9.946224 -9.469637 -11.543834\n"," -9.681482 -10.195737 -9.514964 -12.860728 -11.368202\n"," -14.56341 -12.711758 -11.350553 -13.18879 -9.407719\n"," -9.764156 -10.671074 -11.646937 -12.571701 -9.258234\n"," -10.133987 -9.215451 -11.699546 -9.683449 -8.102212\n"," -8.280151 -7.463168 -10.443597 -9.931871 -9.029537\n"," -10.809522 -9.242908 -11.067041 -8.609023 -9.37059\n"," -9.753563 -9.2785225 -8.860644 -9.524628 -8.6628685\n"," -10.255548 -9.580282 -8.809338 -8.740526 -9.584197\n"," -11.52479 -11.302017 -11.242475 -10.918733 -9.713407\n"," -9.895052 -10.511606 -9.031795 -9.218709 -8.112253\n"," -10.801938 -10.991635 -10.003891 -13.910624 -10.056546\n"," -11.168777 -5.049853 0.36061487 -0.94348156 -4.829131\n"," -0.08386183 -6.3951907 -4.9066935 -6.694729 -5.321249\n"," -6.964137 -5.1618257 -10.232554 -9.469683 -13.131125\n"," -8.580127 -8.248132 -6.7724633 -10.05866 -9.264765\n"," -9.219953 -8.475747 -11.707945 -9.846174 -11.320035\n"," -8.705415 -10.847191 -10.340349 -10.036801 -9.806983\n"," -10.147164 -11.190186 -9.811713 -9.941156 -7.9153924\n"," -6.1716485 -8.66741 -8.987112 -8.954214 -10.06595\n"," -8.577883 -7.141983 -8.485489 -6.889525 -7.2201576\n"," -7.8431225 -10.0173435 -10.403385 -7.2444086 -11.104176\n"," -8.219688 -10.097628 -8.695146 -12.666377 -11.033119\n"," -11.94974 -8.298492 -7.05132 -3.5644822 -4.064981\n"," -11.08213 -12.455801 -7.176717 -10.135718 -7.242347\n"," -10.261881 -9.935315 -8.438743 -8.370187 -10.179698\n"," -7.336482 -8.550757 -12.201432 -9.485798 -11.314145\n"," -7.3142076 -9.465659 -7.8540864 -10.189713 -4.3626466\n"," -12.027302 -6.617275 -5.9098096 -6.6777973 -4.2796392\n"," -3.5423725 -6.6892195 -6.3318806 -11.957001 -9.272463\n"," -9.623184 -11.435066 -9.314948 -10.678869 -9.70005\n"," -10.211389 -10.757169 -8.058382 -9.976052 -8.796942\n"," -10.181307 -9.936882 ]]\n","CPU times: user 62.3 ms, sys: 6.82 ms, total: 69.1 ms\n","Wall time: 74.6 ms\n"]}]},{"cell_type":"code","source":["output_data.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5JbesH4C6E3F","outputId":"56307e6b-e58e-4847-b91a-55a40252e188"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(1, 527)"]},"metadata":{},"execution_count":45}]},{"cell_type":"code","source":["sorted_indexes = np.argsort(preds)[::-1]"],"metadata":{"id":"OAd_yDEN6bK7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["sorted_indexes.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"RGlx7Tir6iQj","outputId":"dc0dd89b-09fd-4b99-f7f3-714e679eac88"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(527,)"]},"metadata":{},"execution_count":47}]},{"cell_type":"code","source":["preds = sigmoid(output_data[0])\n","sorted_indexes = np.argsort(preds)[::-1]\n","# Print audio tagging top probabilities\n","print(\"************* Acoustic Event Detected: *****************\")\n","for k in range(10):\n"," print('{}: {:.3f}'.format(labels[sorted_indexes[k]], preds[sorted_indexes[k]]))\n","print(\"********************************************************\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"PjcLYfAk6Gwj","outputId":"1d42c015-3e73-4f1c-bc83-e03c41175ff7"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["************* Acoustic Event Detected: *****************\n","Gunshot, gunfire: 0.589\n","Artillery fire: 0.479\n","Speech: 0.300\n","Machine gun: 0.280\n","Music: 0.055\n","Outside, rural or natural: 0.028\n","Ding: 0.028\n","Clang: 0.017\n","Outside, urban or manmade: 0.014\n","Sound effect: 0.013\n","********************************************************\n"]}]},{"cell_type":"code","source":["!pip install coremltools"],"metadata":{"id":"FYeXiU1FLKP1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import coremltools as ct\n","\n","# Load the ONNX model\n","model_path = \"/content/EfficientAT/resources/dymn04_as.pt\"\n","\n","# Define dynamic input shape\n","inputs = [ct.TensorType(name=\"input_1\", shape=(1, 1, 128, ct.RangeDim(1, 2000)))]\n","\n","# Convert to Core ML\n","coreml_model = ct.convert(model_path, inputs=inputs, source='pytorch')\n","\n","# Save the Core ML model\n","coreml_model.save(\"dymn04_as.mlmodel\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"yKps_yXKLM90","outputId":"aa6b87fb-4d2d-467e-ba28-f65401510e67"},"execution_count":null,"outputs":[{"output_type":"error","ename":"RuntimeError","evalue":"PytorchStreamReader failed locating file constants.pkl: file not found","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-2-9e995d17096f>\u001b[0m in \u001b[0;36m<cell line: 10>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;31m# Convert to Core ML\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mcoreml_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msource\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'pytorch'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;31m# Save the Core ML model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/_converters_entry.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[0mspecification_version\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_set_default_specification_version\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexact_target\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 443\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 444\u001b[0;31m mlmodel = mil_convert(\n\u001b[0m\u001b[1;32m 445\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 446\u001b[0m \u001b[0mconvert_from\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexact_source\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/converter.py\u001b[0m in \u001b[0;36mmil_convert\u001b[0;34m(model, convert_from, convert_to, compute_units, **kwargs)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0mSee\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mcoremltools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconverters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \"\"\"\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_mil_convert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconvert_from\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconvert_to\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mConverterRegistry\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMLModel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompute_units\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/converter.py\u001b[0m in \u001b[0;36m_mil_convert\u001b[0;34m(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"weights_dir\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mweights_dir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 211\u001b[0;31m proto, mil_program = mil_convert_to_proto(\n\u001b[0m\u001b[1;32m 212\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0mconvert_from\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/converter.py\u001b[0m in \u001b[0;36mmil_convert_to_proto\u001b[0;34m(model, convert_from, convert_to, converter_registry, **kwargs)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[0mfrontend_converter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfrontend_converter_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 281\u001b[0;31m \u001b[0mprog\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfrontend_converter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 282\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 283\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconvert_to\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m\"neuralnetwork\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/converter.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mfrontend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 110\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/frontend/torch/load.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(model_spec, inputs, specification_version, debug, outputs, cut_at_symbols, **kwargs)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0monly\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \"\"\"\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0mtorchscript\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_torchscript_from_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorchscript\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'training'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mtorchscript\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/coremltools/converters/mil/frontend/torch/load.py\u001b[0m in \u001b[0;36m_torchscript_from_model\u001b[0;34m(model_spec)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\".pt\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mmodel_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\".pth\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_os_path\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabspath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_torch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_torch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mScriptModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel_spec\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/jit/_serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, _extra_files)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mcu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCompilationUnit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpathlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0mcpp_module\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_ir_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_extra_files\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m cpp_module = torch._C.import_ir_module_from_buffer(\n","\u001b[0;31mRuntimeError\u001b[0m: PytorchStreamReader failed locating file constants.pkl: file not found"]}]},{"cell_type":"code","source":[],"metadata":{"id":"gSMIyJxwg4D-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["#ONNX to tflite\n","Chạy trên kaggle"],"metadata":{"id":"gBGCjfV8n9aW"}},{"cell_type":"code","source":["!pip install onnx_tf tensorflow==2.14.0"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"JDsmanV6vn4n","executionInfo":{"status":"ok","timestamp":1747046031221,"user_tz":-420,"elapsed":67201,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"d1214964-c566-4766-ef7f-c44850520221"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: onnx_tf in /usr/local/lib/python3.11/dist-packages (1.10.0)\n","Collecting tensorflow==2.14.0\n"," Downloading tensorflow-2.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)\n","Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.4.0)\n","Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.6.3)\n","Requirement already satisfied: flatbuffers>=23.5.26 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (25.2.10)\n","Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (0.4.0)\n","Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (0.2.0)\n","Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (3.13.0)\n","Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (18.1.1)\n","Collecting ml-dtypes==0.2.0 (from tensorflow==2.14.0)\n"," Downloading ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\n","Requirement already satisfied: numpy>=1.23.5 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.23.5)\n","Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (3.4.0)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (24.2)\n","Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (4.25.7)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (75.2.0)\n","Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.17.0)\n","Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (3.1.0)\n","Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (4.13.2)\n","Requirement already satisfied: wrapt<1.15,>=1.11.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.14.1)\n","Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (0.37.1)\n","Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.14.0) (1.71.0)\n","Collecting tensorboard<2.15,>=2.14 (from tensorflow==2.14.0)\n"," Downloading tensorboard-2.14.1-py3-none-any.whl.metadata (1.7 kB)\n","Collecting tensorflow-estimator<2.15,>=2.14.0 (from tensorflow==2.14.0)\n"," Downloading tensorflow_estimator-2.14.0-py2.py3-none-any.whl.metadata (1.3 kB)\n","Collecting keras<2.15,>=2.14.0 (from tensorflow==2.14.0)\n"," Downloading keras-2.14.0-py3-none-any.whl.metadata (2.4 kB)\n","Requirement already satisfied: onnx>=1.10.2 in /usr/local/lib/python3.11/dist-packages (from onnx_tf) (1.17.0)\n","Requirement already satisfied: PyYAML in /usr/local/lib/python3.11/dist-packages (from onnx_tf) (6.0.2)\n","Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.11/dist-packages (from onnx_tf) (0.23.0)\n","Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from astunparse>=1.6.0->tensorflow==2.14.0) (0.45.1)\n","Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2.38.0)\n","Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (1.0.0)\n","Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.8)\n","Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2.32.3)\n","Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (0.7.2)\n","Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.1.3)\n","Requirement already satisfied: typeguard<3.0.0,>=2.7 in /usr/local/lib/python3.11/dist-packages (from tensorflow-addons->onnx_tf) (2.13.3)\n","Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (5.5.2)\n","Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (0.4.2)\n","Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (4.9.1)\n","Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2.0.0)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.4.1)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.10)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2.4.0)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (2025.4.26)\n","Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.11/dist-packages (from werkzeug>=1.0.1->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.0.2)\n","Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.11/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (0.6.1)\n","Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.11/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.15,>=2.14->tensorflow==2.14.0) (3.2.2)\n","Downloading tensorflow-2.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (489.9 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m489.9/489.9 MB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading keras-2.14.0-py3-none-any.whl (1.7 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading tensorboard-2.14.1-py3-none-any.whl (5.5 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m19.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading tensorflow_estimator-2.14.0-py2.py3-none-any.whl (440 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m440.7/440.7 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: tensorflow-estimator, ml-dtypes, keras, tensorboard, tensorflow\n"," Attempting uninstall: tensorflow-estimator\n"," Found existing installation: tensorflow-estimator 2.12.0\n"," Uninstalling tensorflow-estimator-2.12.0:\n"," Successfully uninstalled tensorflow-estimator-2.12.0\n"," Attempting uninstall: ml-dtypes\n"," Found existing installation: ml-dtypes 0.4.1\n"," Uninstalling ml-dtypes-0.4.1:\n"," Successfully uninstalled ml-dtypes-0.4.1\n"," Attempting uninstall: keras\n"," Found existing installation: keras 2.12.0\n"," Uninstalling keras-2.12.0:\n"," Successfully uninstalled keras-2.12.0\n"," Attempting uninstall: tensorboard\n"," Found existing installation: tensorboard 2.12.3\n"," Uninstalling tensorboard-2.12.3:\n"," Successfully uninstalled tensorboard-2.12.3\n"," Attempting uninstall: tensorflow\n"," Found existing installation: tensorflow 2.12.0\n"," Uninstalling tensorflow-2.12.0:\n"," Successfully uninstalled tensorflow-2.12.0\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n","tensorflow-text 2.18.1 requires tensorflow<2.19,>=2.18.0, but you have tensorflow 2.14.0 which is incompatible.\n","tf-keras 2.18.0 requires tensorflow<2.19,>=2.18, but you have tensorflow 2.14.0 which is incompatible.\n","flax 0.10.6 requires jax>=0.5.1, but you have jax 0.4.30 which is incompatible.\n","orbax-checkpoint 0.11.13 requires jax>=0.5.0, but you have jax 0.4.30 which is incompatible.\n","tensorflow-decision-forests 1.11.0 requires tensorflow==2.18.0, but you have tensorflow 2.14.0 which is incompatible.\n","chex 0.1.89 requires numpy>=1.24.1, but you have numpy 1.23.5 which is incompatible.\n","tensorstore 0.1.74 requires ml_dtypes>=0.3.1, but you have ml-dtypes 0.2.0 which is incompatible.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed keras-2.14.0 ml-dtypes-0.2.0 tensorboard-2.14.1 tensorflow-2.14.0 tensorflow-estimator-2.14.0\n"]},{"output_type":"display_data","data":{"application/vnd.colab-display-data+json":{"pip_warning":{"packages":["keras","ml_dtypes","tensorboard","tensorflow"]},"id":"61b023af355f4283aa45d7f6298e8099"}},"metadata":{}}]},{"cell_type":"code","source":["from onnx_tf.backend import prepare\n","import onnx\n","import tensorflow as tf\n","\n","# Load ONNX model\n","onnx_model = onnx.load(\"/content/fixed_emotion.onnx\")\n","\n","# Convert to TensorFlow format\n","tf_rep = prepare(onnx_model)\n","print(tf_rep.inputs)\n","\n","# Export the TensorFlow SavedModel\n","tf_rep.export_graph(\"saved_model_4s_emotion\")\n","\n","# Convert SavedModel to TFLite\n","converter = tf.lite.TFLiteConverter.from_saved_model(\"saved_model_4s_emotion\")\n","converter.optimizations = [tf.lite.Optimize.DEFAULT]\n","tflite_model = converter.convert()\n","\n","# Save the TFLite model\n","with open(\"emotion_model_2025_03_28.tflite\", \"wb\") as f:\n"," f.write(tflite_model)"],"metadata":{"id":"DZDTl5Zqn_Fv","colab":{"base_uri":"https://localhost:8080/","height":547},"executionInfo":{"status":"error","timestamp":1747046056018,"user_tz":-420,"elapsed":11643,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}},"outputId":"509c8a16-d670-4709-c14a-a4c15bb156ee"},"execution_count":1,"outputs":[{"output_type":"error","ename":"ImportError","evalue":"This version of TensorFlow Probability requires TensorFlow version >= 2.18; Detected an installation of version 2.14.0. Please upgrade TensorFlow to proceed.","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-1-3ba768d958af>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mprepare\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0monnx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtensorflow\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;31m# Load ONNX model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnx_tf/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mversion\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mversion\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0m__version__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnx_tf/backend.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_unique_suffix\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msupports_device\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcommon_supports_device\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandler_helper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_all_backend_handlers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpb_wrapper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mOnnxNode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend_tf_module\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBackendTFModule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTFModule\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnx_tf/common/handler_helper.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdefs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;31m# noqa\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend_handler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBackendHandler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcommon\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnx_tf/handlers/backend/bernoulli.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtensorflow\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtensorflow_probability\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdistributions\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtfd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend_handler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBackendHandler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0monnx_tf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0monnx_op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m# from tensorflow_probability.google import staging # DisableOnExport\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;31m# from tensorflow_probability.google import tfp_google # DisableOnExport\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtensorflow_probability\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpython\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;31m# pylint: disable=wildcard-import\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtensorflow_probability\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpython\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mversion\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m__version__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;31m# Non-lazy load of packages that register with tensorflow or keras.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mpkg_name\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_maybe_nonlazy_load\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 152\u001b[0;31m \u001b[0mdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mglobals\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpkg_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Forces loading the package from its lazy loader.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 153\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/internal/lazy_loader.py\u001b[0m in \u001b[0;36m__dir__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__dir__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0mmodule\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/internal/lazy_loader.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\"\"\"Load the module and insert it into the parent's globals.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_on_first_access\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_on_first_access\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_on_first_access\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;31m# Import the target module and insert it into the parent's namespace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/tensorflow_probability/python/__init__.py\u001b[0m in \u001b[0;36m_validate_tf_environment\u001b[0;34m(package)\u001b[0m\n\u001b[1;32m 57\u001b[0m if (distutils.version.LooseVersion(tf.__version__) <\n\u001b[1;32m 58\u001b[0m distutils.version.LooseVersion(required_tensorflow_version)):\n\u001b[0;32m---> 59\u001b[0;31m raise ImportError(\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m'This version of TensorFlow Probability requires TensorFlow '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m'version >= {required}; Detected an installation of version {present}. '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mImportError\u001b[0m: This version of TensorFlow Probability requires TensorFlow version >= 2.18; Detected an installation of version 2.14.0. Please upgrade TensorFlow to proceed.","","\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"],"errorDetails":{"actions":[{"action":"open_url","actionText":"Open Examples","url":"/notebooks/snippets/importing_libraries.ipynb"}]}}]},{"cell_type":"markdown","source":["#test cái tflite"],"metadata":{"id":"nnML9fwUoQWg"}},{"cell_type":"markdown","source":["Read label from csv file"],"metadata":{"id":"64ShAXw5M9lS"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"JTLwJstrEbSj"},"outputs":[],"source":["import csv\n","\n","with open('/content/class_labels_indices.csv', 'r') as f:\n"," reader = csv.reader(f, delimiter=',')\n"," lines = list(reader)\n","\n","labels = []\n","ids = [] # Each label has a unique id such as \"/m/068hy\"\n","for i1 in range(1, len(lines)):\n"," id = lines[i1][1]\n"," label = lines[i1][2]\n"," ids.append(id)\n"," labels.append(label)\n","\n","classes_num = len(labels)"]},{"cell_type":"markdown","source":["Preprocess audio"],"metadata":{"id":"ebLOZ2JXNF0w"}},{"cell_type":"code","source":["import numpy as np\n","from numpy.lib.stride_tricks import as_strided\n","from typing import Tuple, Optional, Union\n","\n","def mel_scale_scalar(freq: float) -> float:\n"," \"\"\"Convert frequency to mel scale\"\"\"\n"," return 1127.0 * np.log(1.0 + freq / 700.0)\n","\n","def mel_scale(freq: np.ndarray) -> np.ndarray:\n"," \"\"\"Vector version of mel scale conversion\"\"\"\n"," return 1127.0 * np.log(1.0 + freq / 700.0)\n","\n","def inverse_mel_scale(mel: np.ndarray) -> np.ndarray:\n"," \"\"\"Convert mel scale to frequency\"\"\"\n"," return 700.0 * (np.exp(mel / 1127.0) - 1.0)\n","\n","def vtln_warp_mel_freq(vtln_low: float,\n"," vtln_high: float,\n"," low_freq: float,\n"," high_freq: float,\n"," vtln_warp_factor: float,\n"," mel_freq: np.ndarray) -> np.ndarray:\n"," \"\"\"\n"," Implements VTLN warping for mel frequencies\n"," \"\"\"\n"," return mel_freq # Placeholder - implement if VTLN warping is needed\n","\n","def get_mel_banks(\n"," num_bins: int,\n"," window_length_padded: int,\n"," sample_freq: float,\n"," low_freq: float,\n"," high_freq: float,\n"," vtln_low: float,\n"," vtln_high: float,\n"," vtln_warp_factor: float,\n",") -> Tuple[np.ndarray, np.ndarray]:\n"," \"\"\"\n"," Get mel filterbank matrices following Kaldi's implementation.\n","\n"," Args:\n"," num_bins: Number of triangular bins\n"," window_length_padded: Padded window length\n"," sample_freq: Sampling frequency\n"," low_freq: Lowest frequency to consider\n"," high_freq: Highest frequency to consider\n"," vtln_low: Lower frequency for VTLN\n"," vtln_high: Higher frequency for VTLN\n"," vtln_warp_factor: Warping factor for VTLN\n","\n"," Returns:\n"," Tuple[np.ndarray, np.ndarray]: (bins, center_freqs)\n"," - bins: melbank matrix of shape (num_bins, num_fft_bins)\n"," - center_freqs: center frequencies of bins of shape (num_bins,)\n"," \"\"\"\n"," assert num_bins > 3, \"Must have at least 3 mel bins\"\n"," assert window_length_padded % 2 == 0\n"," num_fft_bins = window_length_padded // 2\n"," nyquist = 0.5 * sample_freq\n","\n"," if high_freq <= 0.0:\n"," high_freq += nyquist\n","\n"," assert (\n"," (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)\n"," ), f\"Bad values in options: low-freq {low_freq} and high-freq {high_freq} vs. nyquist {nyquist}\"\n","\n"," # fft-bin width [think of it as Nyquist-freq / half-window-length]\n"," fft_bin_width = sample_freq / window_length_padded\n","\n"," mel_low_freq = mel_scale_scalar(low_freq)\n"," mel_high_freq = mel_scale_scalar(high_freq)\n","\n"," # divide by num_bins+1 in next line because of end-effects where the bins\n"," # spread out to the sides.\n"," mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)\n","\n"," if vtln_high < 0.0:\n"," vtln_high += nyquist\n","\n"," assert vtln_warp_factor == 1.0 or (\n"," (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)\n"," ), f\"Bad values in options: vtln-low {vtln_low} and vtln-high {vtln_high}, versus low-freq {low_freq} and high-freq {high_freq}\"\n","\n"," bin = np.arange(num_bins)[:, np.newaxis]\n"," left_mel = mel_low_freq + bin * mel_freq_delta # shape (num_bins, 1)\n"," center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # shape (num_bins, 1)\n"," right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # shape (num_bins, 1)\n","\n"," if vtln_warp_factor != 1.0:\n"," left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)\n"," center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)\n"," right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)\n","\n"," center_freqs = inverse_mel_scale(center_mel).squeeze(-1) # shape (num_bins)\n","\n"," # shape (1, num_fft_bins)\n"," mel = mel_scale(fft_bin_width * np.arange(num_fft_bins))[np.newaxis, :]\n","\n"," # shape (num_bins, num_fft_bins)\n"," up_slope = (mel - left_mel) / (center_mel - left_mel)\n"," down_slope = (right_mel - mel) / (right_mel - center_mel)\n","\n"," if vtln_warp_factor == 1.0:\n"," # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values\n"," bins = np.maximum(0.0, np.minimum(up_slope, down_slope))\n"," else:\n"," # warping can move the order of left_mel, center_mel, right_mel anywhere\n"," bins = np.zeros_like(up_slope)\n"," up_idx = (mel > left_mel) & (mel <= center_mel) # left_mel < mel <= center_mel\n"," down_idx = (mel > center_mel) & (mel < right_mel) # center_mel < mel < right_mel\n"," bins[up_idx] = up_slope[up_idx]\n"," bins[down_idx] = down_slope[down_idx]\n","\n"," return bins, center_freqs\n","\n","def stft(\n"," input: np.ndarray,\n"," n_fft: int,\n"," hop_length: Optional[int] = None,\n"," win_length: Optional[int] = None,\n"," window: Optional[np.ndarray] = None,\n"," center: bool = True,\n"," pad_mode: str = \"reflect\",\n"," normalized: bool = False,\n"," onesided: bool = True,\n"," return_complex: bool = True\n",") -> np.ndarray:\n"," \"\"\"\n"," NumPy implementation of Short-time Fourier transform (STFT).\n","\n"," Args:\n"," input: Input signal (B?, L) where B? is optional batch dimension\n"," n_fft: Size of Fourier transform\n"," hop_length: Distance between neighboring frames (default: n_fft//4)\n"," win_length: Window size (default: n_fft)\n"," window: Window function (1D array)\n"," center: Whether to pad input for centered frames\n"," pad_mode: Padding mode ('reflect', 'constant', 'edge')\n"," normalized: Whether to normalize the STFT\n"," onesided: Whether to return only positive frequencies\n"," return_complex: Whether to return complex array\n"," \"\"\"\n"," # Set default values\n"," if hop_length is None:\n"," hop_length = n_fft // 4\n"," if win_length is None:\n"," win_length = n_fft\n","\n"," # Prepare window\n"," if window is None:\n"," window = np.ones(win_length)\n"," if len(window) < n_fft:\n"," # Pad window to n_fft if needed\n"," pad_width = (n_fft - len(window)) // 2\n"," window = np.pad(window, (pad_width, n_fft - len(window) - pad_width))\n","\n"," # Handle input dimensions\n"," input = np.asarray(input)\n"," if input.ndim == 1:\n"," input = input[np.newaxis, :] # Add batch dimension\n"," squeeze_batch = True\n"," else:\n"," squeeze_batch = False\n","\n"," # Pad signal for centered frames\n"," if center:\n"," pad_width = int(n_fft // 2)\n"," if pad_mode == 'reflect':\n"," input = np.pad(input, ((0, 0), (pad_width, pad_width)), mode='reflect')\n"," elif pad_mode == 'constant':\n"," input = np.pad(input, ((0, 0), (pad_width, pad_width)), mode='constant')\n"," elif pad_mode == 'edge':\n"," input = np.pad(input, ((0, 0), (pad_width, pad_width)), mode='edge')\n","\n"," # Calculate number of frames\n"," n_frames = 1 + (input.shape[-1] - n_fft) // hop_length\n","\n"," # Create frame matrix using stride tricks\n"," frame_length = n_fft\n"," frame_step = hop_length\n"," frame_stride = input.strides[-1]\n","\n"," shape = (input.shape[0], n_frames, frame_length)\n"," strides = (input.strides[0], frame_step * frame_stride, frame_stride)\n","\n"," frames = as_strided(input, shape=shape, strides=strides, writeable=False)\n","\n"," # Apply window\n"," frames = frames * window\n","\n"," # Compute FFT\n"," stft_matrix = np.fft.fft(frames, n=n_fft, axis=-1)\n","\n"," # Normalize if requested\n"," if normalized:\n"," stft_matrix = stft_matrix / np.sqrt(n_fft)\n","\n"," # Handle onesided output\n"," if onesided:\n"," stft_matrix = stft_matrix[..., :(n_fft // 2) + 1]\n","\n"," # Handle return format\n"," if return_complex:\n"," result = stft_matrix\n"," else:\n"," result = np.stack((stft_matrix.real, stft_matrix.imag), axis=-1)\n","\n"," # Remove batch dimension if input was 1D\n"," if squeeze_batch:\n"," result = result[0]\n","\n"," return result\n","\n","class MelSTFT:\n"," def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024,\n"," fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000):\n"," \"\"\"\n"," Initialize MelSTFT for audio feature extraction using only NumPy.\n"," \"\"\"\n"," self.n_mels = n_mels\n"," self.sr = sr\n"," self.win_length = win_length\n"," self.hopsize = hopsize\n"," self.n_fft = n_fft\n"," self.fmin = fmin\n"," if fmax is None:\n"," fmax = sr // 2 - fmax_aug_range // 2\n"," self.fmax = fmax\n","\n"," # Create Hann window\n"," self.window = np.hanning(win_length)\n","\n"," # Create mel filterbank following Kaldi's implementation\n"," self.mel_basis, _ = get_mel_banks(\n"," self.n_mels,\n"," self.n_fft,\n"," self.sr,\n"," self.fmin,\n"," self.fmax,\n"," 100.0,\n"," -500.,\n"," 1.0\n"," )\n"," self.mel_basis = np.pad(self.mel_basis, ((0, 0), (0, 1)), mode='constant', constant_values=0)\n","\n"," # Pre-emphasis filter coefficients\n"," self.preemphasis_coefficient = np.array([-.97, 1]).reshape(1, 1, 2)\n","\n"," def preemphasis(self, x):\n"," \"\"\"Apply pre-emphasis filter using conv1d equivalent\"\"\"\n"," # Reshape input to match conv1d input shape (batch, channels, length)\n"," x = x.reshape(1, 1, -1)\n","\n"," # Implement conv1d manually\n"," output_size = x.shape[2] - self.preemphasis_coefficient.shape[2] + 1\n"," result = np.zeros((1, 1, output_size))\n","\n"," for i in range(output_size):\n"," result[0, 0, i] = np.sum(x[0, 0, i:i+2] * self.preemphasis_coefficient[0, 0])\n","\n"," return result[0]\n","\n"," def __call__(self, x):\n"," \"\"\"Convert audio to log-Mel spectrogram.\"\"\"\n"," # Apply pre-emphasis\n"," print(f'x shape before {x.shape}')\n"," x = self.preemphasis(x)\n"," print(f'x shape after {x.shape}')\n","\n"," # Compute STFT\n"," spec_x = stft(\n"," input=x,\n"," n_fft=self.n_fft,\n"," hop_length=self.hopsize,\n"," win_length=self.win_length,\n"," window=self.window,\n"," center=True,\n"," pad_mode='reflect',\n"," normalized=False,\n"," return_complex=False\n"," )\n","\n"," # Convert to power spectrogram\n"," print(f'spec_x before is {spec_x.shape}')\n"," spec_x = np.sum(spec_x ** 2, axis=-1)\n","\n"," print(f'spec_x shape is {spec_x.shape}')\n","\n"," # Apply Mel filterbank (ensuring shapes match)\n"," melspec = np.dot(self.mel_basis, spec_x.transpose(0,2,1)).transpose(1,0,2)\n","\n"," # Log-scale and normalize\n"," melspec = np.log(melspec + 1e-5)\n"," melspec = (melspec + 4.5) / 5.\n","\n"," return melspec"],"metadata":{"id":"Ris3ajNhGEkQ","executionInfo":{"status":"ok","timestamp":1747040595424,"user_tz":-420,"elapsed":27,"user":{"displayName":"Nguha Duong","userId":"08400527066055899740"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","source":["Load model and handle input and output\n","\n","The input shape of model is `[1,1,128, any]`. Note the last axis of input shape is dynamic. It depends on the length of audio.\n","\n","The output shape of model is `[1, 527]`. Currently, we're using the pretrained model and exporting it. Note: this shape might be changed in the future."],"metadata":{"id":"CqXIPms2NMcv"}},{"cell_type":"code","source":["import numpy as np\n","import tensorflow as tf\n","\n","# Load the TFLite model\n","interpreter = tf.lite.Interpreter(model_path=\"/content/emotion_model_2025_03_28.tflite\")\n","interpreter.allocate_tensors()\n","\n","# Get input and output details, can print to see details of input and output\n","input_details = interpreter.get_input_details() # The shape of input is [1, 1, 128, None] the last axis depends on the length of audio, data type = float32\n","output_details = interpreter.get_output_details() # The shape of output is [1, 527], data type = float32"],"metadata":{"id":"VlCUYth6GT0f"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# preprocess audio\n","import librosa\n","\n","mel = MelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)\n","\n","# convert to waveform\n","(waveform, _) = librosa.core.load('/content/angry_4s.mp3', sr=32000, mono=True)\n","waveform = np.stack([waveform])\n","spec = mel(waveform)\n","if spec.shape[-1] > 400:\n"," spec = spec[:, :, :400]\n","else:\n"," spec = np.pad(spec, ((0, 0), (0, 0), (0, 400 - spec.shape[-1])), mode='constant')\n","spec = np.expand_dims(spec, axis=0)\n","print(f'the shape of spec is {spec.shape}')\n","spec = spec.astype(np.float32)"],"metadata":{"id":"PtexE6_HGlDW","colab":{"base_uri":"https://localhost:8080/"},"outputId":"cde2909e-7b58-4546-8e3a-e6f5779d8d04"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["x shape before (1, 130223)\n","x shape after (1, 130222)\n","spec_x before is (1, 407, 513, 2)\n","spec_x shape is (1, 407, 513)\n","the shape of spec is (1, 1, 128, 400)\n"]}]},{"cell_type":"code","source":["# Run inference\n","interpreter.set_tensor(input_details[0]['index'], spec)\n","interpreter.invoke()\n","output_data = interpreter.get_tensor(output_details[0]['index'])"],"metadata":{"id":"xyy_O-_vHsYy"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Post process output to get the result"],"metadata":{"id":"eeh_JqiyN0zo"}},{"cell_type":"code","source":["# Post processing\n","import numpy as np\n","\n","def sigmoid(x):\n"," return 1 / (1 + np.exp(-x))\n","\n","def softmax(x):\n"," exp_x = np.exp(x - np.max(x)) # Subtract max(x) for numerical stability\n"," return exp_x / np.sum(exp_x, axis=-1, keepdims=True)\n","\n","preds = softmax(output_data[0])\n","sorted_indexes = np.argsort(preds)[::-1]\n","# Print audio tagging top probabilities\n","print(\"* Acoustic Event Detected: *\")\n","for k in range(7):\n"," print('{}: {:.3f}'.format(labels[sorted_indexes[k]], preds[sorted_indexes[k]]))\n","print(\"**\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"pPLXt-FYH1Nz","outputId":"f620d201-8bfb-497d-d583-ef73db24cc46"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["* Acoustic Event Detected: *\n","Angry: 0.913\n","sad: 0.015\n","neutral: 0.015\n","Fear: 0.015\n","Happy: 0.014\n","surprise: 0.014\n","Disgust: 0.014\n","**\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"-MMULJzEH3xe"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}
|
Preprocessing/Convert_to_HDF5.ipynb
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": []
|
| 7 |
+
},
|
| 8 |
+
"kernelspec": {
|
| 9 |
+
"name": "python3",
|
| 10 |
+
"display_name": "Python 3"
|
| 11 |
+
},
|
| 12 |
+
"language_info": {
|
| 13 |
+
"name": "python"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"cells": [
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "markdown",
|
| 19 |
+
"source": [
|
| 20 |
+
"#convert file audio to mp3 32k"
|
| 21 |
+
],
|
| 22 |
+
"metadata": {
|
| 23 |
+
"id": "XE8k_JyY5eBb"
|
| 24 |
+
}
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": null,
|
| 29 |
+
"metadata": {
|
| 30 |
+
"id": "Wyv972sc5Lb4"
|
| 31 |
+
},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"import os\n",
|
| 35 |
+
"from multiprocessing import Pool, cpu_count\n",
|
| 36 |
+
"from tqdm import tqdm\n",
|
| 37 |
+
"import subprocess\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"def process_file_ffmpeg(args):\n",
|
| 40 |
+
" file_path, input_folder, output_folder = args\n",
|
| 41 |
+
" rel_path = os.path.relpath(file_path, input_folder)\n",
|
| 42 |
+
" rel_path = os.path.splitext(rel_path)[0] + \".mp3\" # luôn xuất mp3\n",
|
| 43 |
+
" out_path = os.path.join(output_folder, rel_path)\n",
|
| 44 |
+
" os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
|
| 45 |
+
"\n",
|
| 46 |
+
" cmd = [\n",
|
| 47 |
+
" \"ffmpeg\",\n",
|
| 48 |
+
" \"-y\", # overwrite nếu đã tồn tại\n",
|
| 49 |
+
" \"-i\", file_path,\n",
|
| 50 |
+
" \"-ar\", \"32000\", # sample rate 32kHz\n",
|
| 51 |
+
" \"-ac\", \"1\", # stereo set thành 2 còn mono set thành 1\n",
|
| 52 |
+
" \"-b:a\", \"192k\", # bitrate\n",
|
| 53 |
+
" out_path\n",
|
| 54 |
+
" ]\n",
|
| 55 |
+
"\n",
|
| 56 |
+
" try:\n",
|
| 57 |
+
" subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)\n",
|
| 58 |
+
" return True\n",
|
| 59 |
+
" except subprocess.CalledProcessError:\n",
|
| 60 |
+
" print(f\"❌ Lỗi khi xử lý {file_path}\")\n",
|
| 61 |
+
" return False\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"def convert_audio_ffmpeg_multiprocessing(input_folder, output_folder, num_workers=None):\n",
|
| 64 |
+
" audio_exts = ('.mp3', '.wav', '.flac', '.m4a', '.ogg')\n",
|
| 65 |
+
"\n",
|
| 66 |
+
" # Lấy danh sách tất cả file audio\n",
|
| 67 |
+
" all_files = []\n",
|
| 68 |
+
" for root, _, files in os.walk(input_folder):\n",
|
| 69 |
+
" for f in files:\n",
|
| 70 |
+
" if f.lower().endswith(audio_exts):\n",
|
| 71 |
+
" all_files.append(os.path.join(root, f))\n",
|
| 72 |
+
"\n",
|
| 73 |
+
" if num_workers is None:\n",
|
| 74 |
+
" num_workers = cpu_count()\n",
|
| 75 |
+
"\n",
|
| 76 |
+
" args_list = [(f, input_folder, output_folder) for f in all_files]\n",
|
| 77 |
+
"\n",
|
| 78 |
+
" # Multiprocessing + tqdm\n",
|
| 79 |
+
" with Pool(num_workers) as pool:\n",
|
| 80 |
+
" for _ in tqdm(pool.imap_unordered(process_file_ffmpeg, args_list),\n",
|
| 81 |
+
" total=len(args_list), desc=\"Converting to 32kHz MP3\"):\n",
|
| 82 |
+
" pass\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"# --- Ví dụ sử dụng ---\n",
|
| 85 |
+
"input_dir = \"/content/dataset\"\n",
|
| 86 |
+
"output_dir = \"/content/dataset_process\"\n",
|
| 87 |
+
"convert_audio_ffmpeg_multiprocessing(input_dir, output_dir)\n"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "markdown",
|
| 92 |
+
"source": [
|
| 93 |
+
"#mp3 to hdf5"
|
| 94 |
+
],
|
| 95 |
+
"metadata": {
|
| 96 |
+
"id": "kIoHloKr5ky7"
|
| 97 |
+
}
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"cell_type": "markdown",
|
| 101 |
+
"source": [
|
| 102 |
+
"##Audioset"
|
| 103 |
+
],
|
| 104 |
+
"metadata": {
|
| 105 |
+
"id": "pctubCgR5sli"
|
| 106 |
+
}
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "code",
|
| 110 |
+
"source": [
|
| 111 |
+
"import h5py\n",
|
| 112 |
+
"import pandas as pd\n",
|
| 113 |
+
"import numpy as np\n",
|
| 114 |
+
"import csv\n",
|
| 115 |
+
"import os\n",
|
| 116 |
+
"import io\n",
|
| 117 |
+
"import av\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"def decode_mp3(mp3_arr):\n",
|
| 120 |
+
" \"\"\"\n",
|
| 121 |
+
" Giải mã một mảng uint8 đại diện cho một file MP3.\n",
|
| 122 |
+
" :rtype: np.array\n",
|
| 123 |
+
" \"\"\"\n",
|
| 124 |
+
" container = av.open(io.BytesIO(mp3_arr.tobytes())) # Đọc dữ liệu MP3\n",
|
| 125 |
+
" stream = next(s for s in container.streams if s.type == 'audio') # Lấy stream âm thanh\n",
|
| 126 |
+
" a = []\n",
|
| 127 |
+
" for i, packet in enumerate(container.demux(stream)): # Demux các gói dữ liệu âm thanh\n",
|
| 128 |
+
" for frame in packet.decode(): # Giải mã frame\n",
|
| 129 |
+
" a.append(frame.to_ndarray().reshape(-1)) # Chuyển đổi frame thành mảng numpy\n",
|
| 130 |
+
" waveform = np.concatenate(a) # Kết nối tất cả các frame lại\n",
|
| 131 |
+
" if waveform.dtype != 'float32': # Kiểm tra loại dữ liệu\n",
|
| 132 |
+
" raise RuntimeError(\"Unexpected wave type\")\n",
|
| 133 |
+
" return waveform\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"# %%\n",
|
| 136 |
+
"base_dir = \"/content/output_\"\n",
|
| 137 |
+
"balanced_csv= '/content/new_updated_balanced_train_segments.csv'\n",
|
| 138 |
+
"eval_csv= '/content/new_eval_segments.csv'\n",
|
| 139 |
+
"mp3_path = \"/content/dataset/\"\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"# %%\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"def read_metadata(csv_path, classes_num, id_to_ix):\n",
|
| 146 |
+
" \"\"\"Read metadata of AudioSet from a csv file.\"\"\"\n",
|
| 147 |
+
"\n",
|
| 148 |
+
" audio_names = []\n",
|
| 149 |
+
" targets = []\n",
|
| 150 |
+
"\n",
|
| 151 |
+
" with open(csv_path, 'r') as fr:\n",
|
| 152 |
+
" reader = csv.reader(fr)\n",
|
| 153 |
+
" next(reader) # Skip header line if exists\n",
|
| 154 |
+
" next(reader) # Skip another potential header line\n",
|
| 155 |
+
" next(reader) # Skip another potential header line\n",
|
| 156 |
+
"\n",
|
| 157 |
+
" for line in reader:\n",
|
| 158 |
+
" if len(line) < 4:\n",
|
| 159 |
+
" continue # Skip malformed lines\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" audio_name = 'Y{}.mp3'.format(line[0]) # Assumed naming convention\n",
|
| 162 |
+
" label_ids = line[3].strip('\"').split(',')\n",
|
| 163 |
+
"\n",
|
| 164 |
+
" audio_names.append(audio_name)\n",
|
| 165 |
+
" target = np.zeros(classes_num, dtype=bool)\n",
|
| 166 |
+
"\n",
|
| 167 |
+
" for label_id in label_ids:\n",
|
| 168 |
+
" if label_id in id_to_ix:\n",
|
| 169 |
+
" ix = id_to_ix[label_id]\n",
|
| 170 |
+
" target[ix] = 1\n",
|
| 171 |
+
" else:\n",
|
| 172 |
+
" print(f\"Warning: Label ID {label_id} not found in id_to_ix.\")\n",
|
| 173 |
+
"\n",
|
| 174 |
+
" targets.append(target)\n",
|
| 175 |
+
"\n",
|
| 176 |
+
" meta_dict = {'audio_name': np.array(audio_names), 'target': np.array(targets)}\n",
|
| 177 |
+
" print(meta_dict)\n",
|
| 178 |
+
" return meta_dict\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"# Load label\n",
|
| 181 |
+
"with open('/content/new_class_labels_indices_filter_discard.csv', 'r') as f:\n",
|
| 182 |
+
" reader = csv.reader(f, delimiter=',')\n",
|
| 183 |
+
" lines = list(reader)\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"labels = []\n",
|
| 186 |
+
"ids = [] # Each label has a unique id such as \"/m/068hy\"\n",
|
| 187 |
+
"for i1 in range(1, len(lines)):\n",
|
| 188 |
+
" id = lines[i1][1]\n",
|
| 189 |
+
" label = lines[i1][2]\n",
|
| 190 |
+
" ids.append(id)\n",
|
| 191 |
+
" labels.append(label)\n",
|
| 192 |
+
"\n",
|
| 193 |
+
"classes_num = len(labels)\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"lb_to_ix = {label : i for i, label in enumerate(labels)}\n",
|
| 196 |
+
"ix_to_lb = {i : label for i, label in enumerate(labels)}\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"id_to_ix = {id : i for i, id in enumerate(ids)}\n",
|
| 199 |
+
"ix_to_id = {i : id for i, id in enumerate(ids)}\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"# %%\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"def check_available(balanced_csv,balanced_audio_path,prefix=None):\n",
|
| 204 |
+
" meta_csv = read_metadata(balanced_csv,classes_num,id_to_ix)\n",
|
| 205 |
+
" #print(meta_csv)\n",
|
| 206 |
+
" audios_num = len(meta_csv['audio_name'])\n",
|
| 207 |
+
" found=0\n",
|
| 208 |
+
" notfound=0\n",
|
| 209 |
+
" available_files=[]\n",
|
| 210 |
+
" available_targets=[]\n",
|
| 211 |
+
" if prefix is None:\n",
|
| 212 |
+
" prefix = os.path.basename(balanced_csv)[:-4]\n",
|
| 213 |
+
" for n in range(audios_num):\n",
|
| 214 |
+
" audio_path = meta_csv['audio_name'][n]\n",
|
| 215 |
+
" #print(balanced_audio_path + f\"{prefix}/{audio_path}\")\n",
|
| 216 |
+
" if os.path.isfile(balanced_audio_path + f\"{prefix}/{audio_path}\" ):\n",
|
| 217 |
+
" found+=1\n",
|
| 218 |
+
" available_files.append(meta_csv['audio_name'][n])\n",
|
| 219 |
+
" available_targets.append(meta_csv['target'][n])\n",
|
| 220 |
+
" else:\n",
|
| 221 |
+
" notfound+=1\n",
|
| 222 |
+
" print(f\"Found {found} . not found {notfound}\")\n",
|
| 223 |
+
" return available_files,available_targets\n",
|
| 224 |
+
"# %%\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"# %%\n",
|
| 227 |
+
"\n",
|
| 228 |
+
"# %%\n",
|
| 229 |
+
"\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"os.makedirs(os.path.dirname(base_dir + \"mp3\"), exist_ok=True)\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"for read_file,prefix in [(balanced_csv,\"balanced_train_segments/\"), (eval_csv,\"eval_segments/\"),]:\n",
|
| 234 |
+
" print(\"now working on \",read_file,prefix)\n",
|
| 235 |
+
" #files, y = torch.load(read_file+\".pth\")\n",
|
| 236 |
+
" files, y = check_available(read_file, mp3_path, prefix)\n",
|
| 237 |
+
" y = np.packbits(y, axis=-1)\n",
|
| 238 |
+
" packed_len = y.shape[1]\n",
|
| 239 |
+
" print(files[0], \"classes: \",packed_len, y.dtype)\n",
|
| 240 |
+
" available_size = len(files)\n",
|
| 241 |
+
" f = files[0][:-3]+\"mp3\"\n",
|
| 242 |
+
" a = np.fromfile(mp3_path+prefix + \"/\"+f, dtype='uint8')\n",
|
| 243 |
+
"\n",
|
| 244 |
+
" dt = h5py.vlen_dtype(np.dtype('uint8'))\n",
|
| 245 |
+
" save_file = prefix.split(\"/\")[0]\n",
|
| 246 |
+
" os.makedirs(os.path.dirname(base_dir + \"mp3/\" ), exist_ok=True)\n",
|
| 247 |
+
" with h5py.File(base_dir+ \"mp3/\" + save_file+\"_mp3.hdf\", 'w') as hf:\n",
|
| 248 |
+
" audio_name = hf.create_dataset('audio_name', shape=(0,), maxshape=(None,), dtype='S20')\n",
|
| 249 |
+
" waveform = hf.create_dataset('mp3', shape=(0,), maxshape=(None,), dtype=dt)\n",
|
| 250 |
+
" target = hf.create_dataset('target', shape=(0, packed_len), maxshape=(None, packed_len), dtype=y.dtype)\n",
|
| 251 |
+
" for i,file in enumerate(files):\n",
|
| 252 |
+
" if i%1000==0:\n",
|
| 253 |
+
" print(f\"{i}/{available_size}\")\n",
|
| 254 |
+
" f = file[:-3] + \"mp3\"\n",
|
| 255 |
+
" a = np.fromfile(mp3_path + prefix + f, dtype='uint8')\n",
|
| 256 |
+
" try:\n",
|
| 257 |
+
" # Kiểm tra xem file audio có đọc được không\n",
|
| 258 |
+
" decode_mp3(a) # Dùng hàm decode_mp3 của bạn\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" audio_name.resize((i + 1,))\n",
|
| 261 |
+
" waveform.resize((i + 1,))\n",
|
| 262 |
+
" target.resize((i + 1, packed_len))\n",
|
| 263 |
+
"\n",
|
| 264 |
+
" audio_name[i]=f\n",
|
| 265 |
+
" waveform[i] = a\n",
|
| 266 |
+
" target[i] = y[i]\n",
|
| 267 |
+
" except Exception as e:\n",
|
| 268 |
+
" print(f\"File lỗi tại index {i} với file {file}: {e}\")\n",
|
| 269 |
+
"\n",
|
| 270 |
+
" print(a.shape)\n",
|
| 271 |
+
" print(\"Done!\" , prefix)"
|
| 272 |
+
],
|
| 273 |
+
"metadata": {
|
| 274 |
+
"id": "8oFKEbtb5mzr"
|
| 275 |
+
},
|
| 276 |
+
"execution_count": null,
|
| 277 |
+
"outputs": []
|
| 278 |
+
},
|
| 279 |
+
{
|
| 280 |
+
"cell_type": "markdown",
|
| 281 |
+
"source": [
|
| 282 |
+
"##For this structure folder/train/folder(class)/file"
|
| 283 |
+
],
|
| 284 |
+
"metadata": {
|
| 285 |
+
"id": "mCwcSx8y5v7q"
|
| 286 |
+
}
|
| 287 |
+
},
|
| 288 |
+
{
|
| 289 |
+
"cell_type": "markdown",
|
| 290 |
+
"source": [
|
| 291 |
+
""
|
| 292 |
+
],
|
| 293 |
+
"metadata": {
|
| 294 |
+
"id": "cRAr5tkn566K"
|
| 295 |
+
}
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"cell_type": "code",
|
| 299 |
+
"source": [
|
| 300 |
+
"import h5py\n",
|
| 301 |
+
"import numpy as np\n",
|
| 302 |
+
"import os\n",
|
| 303 |
+
"import io\n",
|
| 304 |
+
"import av\n",
|
| 305 |
+
"from pathlib import Path\n",
|
| 306 |
+
"from tqdm import tqdm\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"def decode_mp3(mp3_arr):\n",
|
| 309 |
+
" \"\"\"\n",
|
| 310 |
+
" Giải mã một mảng uint8 đại diện cho một file MP3.\n",
|
| 311 |
+
" :rtype: np.array\n",
|
| 312 |
+
" \"\"\"\n",
|
| 313 |
+
" try:\n",
|
| 314 |
+
" container = av.open(io.BytesIO(mp3_arr.tobytes()))\n",
|
| 315 |
+
" stream = next(s for s in container.streams if s.type == 'audio')\n",
|
| 316 |
+
" a = []\n",
|
| 317 |
+
" for packet in container.demux(stream):\n",
|
| 318 |
+
" for frame in packet.decode():\n",
|
| 319 |
+
" a.append(frame.to_ndarray().reshape(-1))\n",
|
| 320 |
+
" waveform = np.concatenate(a)\n",
|
| 321 |
+
" if waveform.dtype != 'float32':\n",
|
| 322 |
+
" raise RuntimeError(\"Unexpected wave type\")\n",
|
| 323 |
+
" return waveform\n",
|
| 324 |
+
" except Exception as e:\n",
|
| 325 |
+
" raise RuntimeError(f\"Cannot decode MP3: {e}\")\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"def scan_dataset_structure(dataset_path):\n",
|
| 328 |
+
" \"\"\"\n",
|
| 329 |
+
" Quét cấu trúc thư mục dataset và tạo mapping cho classes\n",
|
| 330 |
+
" Structure: dataset_path/train(or test)/class_name/*.mp3\n",
|
| 331 |
+
" \"\"\"\n",
|
| 332 |
+
" dataset_path = Path(dataset_path)\n",
|
| 333 |
+
"\n",
|
| 334 |
+
" # Lấy tất cả các class names từ thư mục train\n",
|
| 335 |
+
" train_path = dataset_path / \"train\"\n",
|
| 336 |
+
" if not train_path.exists():\n",
|
| 337 |
+
" raise ValueError(f\"Train folder not found: {train_path}\")\n",
|
| 338 |
+
"\n",
|
| 339 |
+
" classes = sorted([d.name for d in train_path.iterdir() if d.is_dir()])\n",
|
| 340 |
+
" classes_num = len(classes)\n",
|
| 341 |
+
"\n",
|
| 342 |
+
" # Tạo mapping\n",
|
| 343 |
+
" lb_to_ix = {label: i for i, label in enumerate(classes)}\n",
|
| 344 |
+
" ix_to_lb = {i: label for i, label in enumerate(classes)}\n",
|
| 345 |
+
"\n",
|
| 346 |
+
" print(f\"Found {classes_num} classes: {classes[:7]}...\" if len(classes) > 7 else f\"Found {classes_num} classes: {classes}\")\n",
|
| 347 |
+
"\n",
|
| 348 |
+
" return classes, classes_num, lb_to_ix, ix_to_lb\n",
|
| 349 |
+
"\n",
|
| 350 |
+
"def collect_audio_files(dataset_path, split='train', shuffle=True, random_seed=42):\n",
|
| 351 |
+
" \"\"\"\n",
|
| 352 |
+
" Thu thập tất cả audio files từ structure thư mục và shuffle để tránh grouping theo class\n",
|
| 353 |
+
" \"\"\"\n",
|
| 354 |
+
" dataset_path = Path(dataset_path)\n",
|
| 355 |
+
" split_path = dataset_path / split\n",
|
| 356 |
+
"\n",
|
| 357 |
+
" if not split_path.exists():\n",
|
| 358 |
+
" raise ValueError(f\"{split} folder not found: {split_path}\")\n",
|
| 359 |
+
"\n",
|
| 360 |
+
" audio_files = []\n",
|
| 361 |
+
" labels = []\n",
|
| 362 |
+
" class_counts = {}\n",
|
| 363 |
+
"\n",
|
| 364 |
+
" class_dirs = [d for d in split_path.iterdir() if d.is_dir()]\n",
|
| 365 |
+
"\n",
|
| 366 |
+
" print(f\"📁 Scanning {split} folder...\")\n",
|
| 367 |
+
" for class_dir in tqdm(class_dirs, desc=f\"Scanning classes\"):\n",
|
| 368 |
+
" class_name = class_dir.name\n",
|
| 369 |
+
" mp3_files = list(class_dir.glob(\"*.mp3\"))\n",
|
| 370 |
+
" class_counts[class_name] = len(mp3_files)\n",
|
| 371 |
+
"\n",
|
| 372 |
+
" for mp3_file in mp3_files:\n",
|
| 373 |
+
" audio_files.append(str(mp3_file))\n",
|
| 374 |
+
" labels.append(class_name)\n",
|
| 375 |
+
"\n",
|
| 376 |
+
" # Shuffle để tránh việc group theo class trong HDF5\n",
|
| 377 |
+
" if shuffle:\n",
|
| 378 |
+
" import random\n",
|
| 379 |
+
" random.seed(random_seed)\n",
|
| 380 |
+
"\n",
|
| 381 |
+
" # Zip files và labels lại, sau đó shuffle\n",
|
| 382 |
+
" combined = list(zip(audio_files, labels))\n",
|
| 383 |
+
" random.shuffle(combined)\n",
|
| 384 |
+
"\n",
|
| 385 |
+
" # Unpack lại\n",
|
| 386 |
+
" audio_files, labels = zip(*combined)\n",
|
| 387 |
+
" audio_files = list(audio_files)\n",
|
| 388 |
+
" labels = list(labels)\n",
|
| 389 |
+
"\n",
|
| 390 |
+
" print(f\"🔀 Files shuffled with seed={random_seed}\")\n",
|
| 391 |
+
"\n",
|
| 392 |
+
" # In class distribution\n",
|
| 393 |
+
" print(f\"✅ Found {len(audio_files)} audio files in {split} set\")\n",
|
| 394 |
+
" print(f\"📊 Class distribution:\")\n",
|
| 395 |
+
" for class_name, count in sorted(class_counts.items()):\n",
|
| 396 |
+
" percentage = count / len(audio_files) * 100\n",
|
| 397 |
+
" print(f\" {class_name}: {count} files ({percentage:.1f}%)\")\n",
|
| 398 |
+
"\n",
|
| 399 |
+
" return audio_files, labels\n",
|
| 400 |
+
"\n",
|
| 401 |
+
"def create_target_array(labels, classes_num, lb_to_ix):\n",
|
| 402 |
+
" \"\"\"\n",
|
| 403 |
+
" Tạo target array từ danh sách labels\n",
|
| 404 |
+
" \"\"\"\n",
|
| 405 |
+
" targets = []\n",
|
| 406 |
+
" for label in labels:\n",
|
| 407 |
+
" target = np.zeros(classes_num, dtype=bool)\n",
|
| 408 |
+
" if label in lb_to_ix:\n",
|
| 409 |
+
" ix = lb_to_ix[label]\n",
|
| 410 |
+
" target[ix] = 1\n",
|
| 411 |
+
" targets.append(target)\n",
|
| 412 |
+
"\n",
|
| 413 |
+
" return np.array(targets)\n",
|
| 414 |
+
"\n",
|
| 415 |
+
"def convert_to_hdf5(dataset_path, output_dir):\n",
|
| 416 |
+
" \"\"\"\n",
|
| 417 |
+
" Convert audio dataset to HDF5 format\n",
|
| 418 |
+
" \"\"\"\n",
|
| 419 |
+
" # Tạo output directory\n",
|
| 420 |
+
" os.makedirs(output_dir, exist_ok=True)\n",
|
| 421 |
+
"\n",
|
| 422 |
+
" # Quét cấu trúc dataset\n",
|
| 423 |
+
" classes, classes_num, lb_to_ix, ix_to_lb = scan_dataset_structure(dataset_path)\n",
|
| 424 |
+
"\n",
|
| 425 |
+
" # Process both train and test splits\n",
|
| 426 |
+
" for split in ['train', 'test']:\n",
|
| 427 |
+
" print(f\"\\n=== Processing {split} set ===\")\n",
|
| 428 |
+
"\n",
|
| 429 |
+
" try:\n",
|
| 430 |
+
" # Thu thập audio files\n",
|
| 431 |
+
" audio_files, labels = collect_audio_files(dataset_path, split)\n",
|
| 432 |
+
"\n",
|
| 433 |
+
" if len(audio_files) == 0:\n",
|
| 434 |
+
" print(f\"No audio files found in {split} set, skipping...\")\n",
|
| 435 |
+
" continue\n",
|
| 436 |
+
"\n",
|
| 437 |
+
" # Tạo target array\n",
|
| 438 |
+
" targets = create_target_array(labels, classes_num, lb_to_ix)\n",
|
| 439 |
+
"\n",
|
| 440 |
+
" # Pack targets để tiết kiệm memory\n",
|
| 441 |
+
" packed_targets = np.packbits(targets, axis=-1)\n",
|
| 442 |
+
" packed_len = packed_targets.shape[1]\n",
|
| 443 |
+
"\n",
|
| 444 |
+
" print(f\"Target shape: {targets.shape} -> Packed: {packed_targets.shape}\")\n",
|
| 445 |
+
"\n",
|
| 446 |
+
" # Tạo HDF5 file\n",
|
| 447 |
+
" dt = h5py.vlen_dtype(np.dtype('uint8'))\n",
|
| 448 |
+
" hdf5_path = os.path.join(output_dir, f\"{split}_mp3.hdf5\")\n",
|
| 449 |
+
"\n",
|
| 450 |
+
" with h5py.File(hdf5_path, 'w') as hf:\n",
|
| 451 |
+
" # Tạo datasets\n",
|
| 452 |
+
" audio_name_ds = hf.create_dataset('audio_name', shape=(0,), maxshape=(None,), dtype='S200')\n",
|
| 453 |
+
" waveform_ds = hf.create_dataset('mp3', shape=(0,), maxshape=(None,), dtype=dt)\n",
|
| 454 |
+
" target_ds = hf.create_dataset('target', shape=(0, packed_len), maxshape=(None, packed_len), dtype=packed_targets.dtype)\n",
|
| 455 |
+
"\n",
|
| 456 |
+
" # Lưu class info\n",
|
| 457 |
+
" hf.attrs['classes'] = [c.encode('utf-8') for c in classes]\n",
|
| 458 |
+
" hf.attrs['classes_num'] = classes_num\n",
|
| 459 |
+
"\n",
|
| 460 |
+
" valid_count = 0\n",
|
| 461 |
+
" error_count = 0\n",
|
| 462 |
+
"\n",
|
| 463 |
+
" # Process từng file với tqdm\n",
|
| 464 |
+
" pbar = tqdm(zip(audio_files, labels),\n",
|
| 465 |
+
" total=len(audio_files),\n",
|
| 466 |
+
" desc=f\"Converting {split}\")\n",
|
| 467 |
+
"\n",
|
| 468 |
+
" for i, (audio_file, label) in enumerate(pbar):\n",
|
| 469 |
+
" try:\n",
|
| 470 |
+
" # Đọc file MP3\n",
|
| 471 |
+
" audio_data = np.fromfile(audio_file, dtype='uint8')\n",
|
| 472 |
+
"\n",
|
| 473 |
+
" # Kiểm tra tính hợp lệ bằng cách decode\n",
|
| 474 |
+
" decode_mp3(audio_data)\n",
|
| 475 |
+
"\n",
|
| 476 |
+
" # Resize datasets\n",
|
| 477 |
+
" audio_name_ds.resize((valid_count + 1,))\n",
|
| 478 |
+
" waveform_ds.resize((valid_count + 1,))\n",
|
| 479 |
+
" target_ds.resize((valid_count + 1, packed_len))\n",
|
| 480 |
+
"\n",
|
| 481 |
+
" # Lưu data\n",
|
| 482 |
+
" filename = os.path.basename(audio_file).encode('utf-8')\n",
|
| 483 |
+
" audio_name_ds[valid_count] = filename\n",
|
| 484 |
+
" waveform_ds[valid_count] = audio_data\n",
|
| 485 |
+
" target_ds[valid_count] = packed_targets[i]\n",
|
| 486 |
+
"\n",
|
| 487 |
+
" valid_count += 1\n",
|
| 488 |
+
"\n",
|
| 489 |
+
" # Update progress bar\n",
|
| 490 |
+
" pbar.set_postfix({\n",
|
| 491 |
+
" 'Valid': valid_count,\n",
|
| 492 |
+
" 'Errors': error_count,\n",
|
| 493 |
+
" 'Success Rate': f\"{valid_count/(i+1)*100:.1f}%\"\n",
|
| 494 |
+
" })\n",
|
| 495 |
+
"\n",
|
| 496 |
+
" except Exception as e:\n",
|
| 497 |
+
" error_count += 1\n",
|
| 498 |
+
" pbar.set_postfix({\n",
|
| 499 |
+
" 'Valid': valid_count,\n",
|
| 500 |
+
" 'Errors': error_count,\n",
|
| 501 |
+
" 'Success Rate': f\"{valid_count/(i+1)*100:.1f}%\"\n",
|
| 502 |
+
" })\n",
|
| 503 |
+
" if error_count <= 5: # Chỉ show 5 error đầu tiên\n",
|
| 504 |
+
" tqdm.write(f\"❌ Error processing {os.path.basename(audio_file)}: {e}\")\n",
|
| 505 |
+
" continue\n",
|
| 506 |
+
"\n",
|
| 507 |
+
" pbar.close()\n",
|
| 508 |
+
"\n",
|
| 509 |
+
" print(f\"Successfully processed {valid_count}/{len(audio_files)} files\")\n",
|
| 510 |
+
" print(f\"Saved to: {hdf5_path}\")\n",
|
| 511 |
+
"\n",
|
| 512 |
+
" except Exception as e:\n",
|
| 513 |
+
" print(f\"Error processing {split} set: {e}\")\n",
|
| 514 |
+
"\n",
|
| 515 |
+
"def main():\n",
|
| 516 |
+
" # Cấu hình paths\n",
|
| 517 |
+
" dataset_path = \"/content/dataset\" # Thay đổi path này\n",
|
| 518 |
+
" output_dir = \"/content/dataset_hdf5\" # Thay đổi path này\n",
|
| 519 |
+
"\n",
|
| 520 |
+
" # Chạy conversion\n",
|
| 521 |
+
" convert_to_hdf5(dataset_path, output_dir)\n",
|
| 522 |
+
" print(\"\\n=== Conversion completed! ===\")\n",
|
| 523 |
+
"\n",
|
| 524 |
+
"if __name__ == \"__main__\":\n",
|
| 525 |
+
" # Example usage:\n",
|
| 526 |
+
" # dataset_path = \"/content/audio_dataset\"\n",
|
| 527 |
+
" # output_dir = \"/content/output_hdf5\"\n",
|
| 528 |
+
" # convert_to_hdf5(dataset_path, output_dir)\n",
|
| 529 |
+
" main()"
|
| 530 |
+
],
|
| 531 |
+
"metadata": {
|
| 532 |
+
"id": "lcdNaKMx59ip"
|
| 533 |
+
},
|
| 534 |
+
"execution_count": null,
|
| 535 |
+
"outputs": []
|
| 536 |
+
}
|
| 537 |
+
]
|
| 538 |
+
}
|
Train_script/EfficientAT_code_train.zip
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:05f83c9fa5de22cc1153c30e0adff06eea1c0a11b6be6cc6ff2a466006b95fe1
|
| 3 |
+
size 22666093
|
Train_script/Train_guide.ipynb
ADDED
|
@@ -0,0 +1,1117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"id": "cd7ca034-54d2-434d-9ba8-9e1877e7a7c9",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"!unzip -q '/workspace/EfficientAT_code_train.zip' -d './train'"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 3,
|
| 16 |
+
"id": "2f674532-597d-4471-aef8-53c6f2b3ce7f",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"source": [
|
| 20 |
+
"!unzip -q '/workspace/HDF5_file.zip' -d '/workspace/train/EfficientAT-main/datasets'"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": null,
|
| 26 |
+
"id": "c7595bbe-a43b-4e08-a7de-1e27206a782b",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"#Please change the dataset_config (filename hdf5) in \"audioset.py\" in \"EfficientAT-main/datasets\" "
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": 4,
|
| 36 |
+
"id": "7dd14b84-3451-423f-bc07-16fddddc2a07",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [
|
| 39 |
+
{
|
| 40 |
+
"name": "stdout",
|
| 41 |
+
"output_type": "stream",
|
| 42 |
+
"text": [
|
| 43 |
+
"Collecting av (from -r /workspace/train/EfficientAT-main/requirements.txt (line 1))\n",
|
| 44 |
+
" Downloading av-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (4.6 kB)\n",
|
| 45 |
+
"Collecting h5py (from -r /workspace/train/EfficientAT-main/requirements.txt (line 2))\n",
|
| 46 |
+
" Downloading h5py-3.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.7 kB)\n",
|
| 47 |
+
"Collecting librosa (from -r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 48 |
+
" Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)\n",
|
| 49 |
+
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from -r /workspace/train/EfficientAT-main/requirements.txt (line 4)) (1.24.1)\n",
|
| 50 |
+
"Collecting scikit_learn (from -r /workspace/train/EfficientAT-main/requirements.txt (line 5))\n",
|
| 51 |
+
" Downloading scikit_learn-1.7.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n",
|
| 52 |
+
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from -r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (2.1.0+cu118)\n",
|
| 53 |
+
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (from -r /workspace/train/EfficientAT-main/requirements.txt (line 7)) (2.1.0+cu118)\n",
|
| 54 |
+
"Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from -r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (0.16.0+cu118)\n",
|
| 55 |
+
"Collecting tqdm (from -r /workspace/train/EfficientAT-main/requirements.txt (line 9))\n",
|
| 56 |
+
" Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)\n",
|
| 57 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.7/57.7 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 58 |
+
"\u001b[?25hCollecting wandb (from -r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 59 |
+
" Downloading wandb-0.21.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)\n",
|
| 60 |
+
"Collecting pandas (from -r /workspace/train/EfficientAT-main/requirements.txt (line 11))\n",
|
| 61 |
+
" Downloading pandas-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)\n",
|
| 62 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 63 |
+
"\u001b[?25hCollecting seaborn (from -r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
|
| 64 |
+
" Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)\n",
|
| 65 |
+
"Collecting audioread>=2.1.9 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 66 |
+
" Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)\n",
|
| 67 |
+
"Collecting numba>=0.51.0 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 68 |
+
" Downloading numba-0.61.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)\n",
|
| 69 |
+
"Collecting scipy>=1.6.0 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 70 |
+
" Downloading scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
|
| 71 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.0/62.0 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 72 |
+
"\u001b[?25hCollecting joblib>=1.0 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 73 |
+
" Downloading joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)\n",
|
| 74 |
+
"Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3)) (5.1.1)\n",
|
| 75 |
+
"Collecting soundfile>=0.12.1 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 76 |
+
" Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)\n",
|
| 77 |
+
"Collecting pooch>=1.1 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 78 |
+
" Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)\n",
|
| 79 |
+
"Collecting soxr>=0.3.2 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 80 |
+
" Downloading soxr-0.5.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)\n",
|
| 81 |
+
"Requirement already satisfied: typing_extensions>=4.1.1 in /usr/local/lib/python3.10/dist-packages (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3)) (4.4.0)\n",
|
| 82 |
+
"Collecting lazy_loader>=0.1 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 83 |
+
" Downloading lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)\n",
|
| 84 |
+
"Collecting msgpack>=1.0 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 85 |
+
" Downloading msgpack-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)\n",
|
| 86 |
+
"Collecting threadpoolctl>=3.1.0 (from scikit_learn->-r /workspace/train/EfficientAT-main/requirements.txt (line 5))\n",
|
| 87 |
+
" Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)\n",
|
| 88 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (3.9.0)\n",
|
| 89 |
+
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (1.12)\n",
|
| 90 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (3.0)\n",
|
| 91 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (3.1.2)\n",
|
| 92 |
+
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (2023.4.0)\n",
|
| 93 |
+
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (2.1.0)\n",
|
| 94 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (2.31.0)\n",
|
| 95 |
+
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (9.3.0)\n",
|
| 96 |
+
"Collecting click!=8.0.0,>=7.1 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 97 |
+
" Downloading click-8.2.1-py3-none-any.whl.metadata (2.5 kB)\n",
|
| 98 |
+
"Collecting gitpython!=3.1.29,>=1.0.0 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 99 |
+
" Downloading gitpython-3.1.45-py3-none-any.whl.metadata (13 kB)\n",
|
| 100 |
+
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10)) (23.2)\n",
|
| 101 |
+
"Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10)) (3.11.0)\n",
|
| 102 |
+
"Collecting protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 103 |
+
" Downloading protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)\n",
|
| 104 |
+
"Collecting pydantic<3 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 105 |
+
" Downloading pydantic-2.11.7-py3-none-any.whl.metadata (67 kB)\n",
|
| 106 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m68.0/68.0 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 107 |
+
"\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10)) (6.0.1)\n",
|
| 108 |
+
"Collecting sentry-sdk>=2.0.0 (from wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 109 |
+
" Downloading sentry_sdk-2.35.0-py2.py3-none-any.whl.metadata (10 kB)\n",
|
| 110 |
+
"Collecting typing_extensions>=4.1.1 (from librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 111 |
+
" Downloading typing_extensions-4.14.1-py3-none-any.whl.metadata (3.0 kB)\n",
|
| 112 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->-r /workspace/train/EfficientAT-main/requirements.txt (line 11)) (2.8.2)\n",
|
| 113 |
+
"Collecting pytz>=2020.1 (from pandas->-r /workspace/train/EfficientAT-main/requirements.txt (line 11))\n",
|
| 114 |
+
" Downloading pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)\n",
|
| 115 |
+
"Collecting tzdata>=2022.7 (from pandas->-r /workspace/train/EfficientAT-main/requirements.txt (line 11))\n",
|
| 116 |
+
" Downloading tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)\n",
|
| 117 |
+
"Collecting matplotlib!=3.6.1,>=3.4 (from seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
|
| 118 |
+
" Downloading matplotlib-3.10.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n",
|
| 119 |
+
"Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 120 |
+
" Downloading gitdb-4.0.12-py3-none-any.whl.metadata (1.2 kB)\n",
|
| 121 |
+
"Collecting contourpy>=1.0.1 (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
|
| 122 |
+
" Downloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)\n",
|
| 123 |
+
"Collecting cycler>=0.10 (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
|
| 124 |
+
" Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)\n",
|
| 125 |
+
"Collecting fonttools>=4.22.0 (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
|
| 126 |
+
" Downloading fonttools-4.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (108 kB)\n",
|
| 127 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m108.9/108.9 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 128 |
+
"\u001b[?25hCollecting kiwisolver>=1.3.1 (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12))\n",
|
| 129 |
+
" Downloading kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (6.3 kB)\n",
|
| 130 |
+
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn->-r /workspace/train/EfficientAT-main/requirements.txt (line 12)) (2.4.7)\n",
|
| 131 |
+
"Collecting llvmlite<0.45,>=0.44.0dev0 (from numba>=0.51.0->librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3))\n",
|
| 132 |
+
" Downloading llvmlite-0.44.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.8 kB)\n",
|
| 133 |
+
"Collecting annotated-types>=0.6.0 (from pydantic<3->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 134 |
+
" Downloading annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)\n",
|
| 135 |
+
"Collecting pydantic-core==2.33.2 (from pydantic<3->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 136 |
+
" Downloading pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)\n",
|
| 137 |
+
"Collecting typing-inspection>=0.4.0 (from pydantic<3->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 138 |
+
" Downloading typing_inspection-0.4.1-py3-none-any.whl.metadata (2.6 kB)\n",
|
| 139 |
+
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->-r /workspace/train/EfficientAT-main/requirements.txt (line 11)) (1.16.0)\n",
|
| 140 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (2.1.1)\n",
|
| 141 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (3.4)\n",
|
| 142 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (1.26.13)\n",
|
| 143 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r /workspace/train/EfficientAT-main/requirements.txt (line 8)) (2022.12.7)\n",
|
| 144 |
+
"Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.10/dist-packages (from soundfile>=0.12.1->librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3)) (1.16.0)\n",
|
| 145 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (2.1.2)\n",
|
| 146 |
+
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->-r /workspace/train/EfficientAT-main/requirements.txt (line 6)) (1.3.0)\n",
|
| 147 |
+
"Requirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi>=1.0->soundfile>=0.12.1->librosa->-r /workspace/train/EfficientAT-main/requirements.txt (line 3)) (2.21)\n",
|
| 148 |
+
"Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb->-r /workspace/train/EfficientAT-main/requirements.txt (line 10))\n",
|
| 149 |
+
" Downloading smmap-5.0.2-py3-none-any.whl.metadata (4.3 kB)\n",
|
| 150 |
+
"Downloading av-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.2 MB)\n",
|
| 151 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.2/39.2 MB\u001b[0m \u001b[31m46.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
|
| 152 |
+
"\u001b[?25hDownloading h5py-3.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)\n",
|
| 153 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m31.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
| 154 |
+
"\u001b[?25hDownloading librosa-0.11.0-py3-none-any.whl (260 kB)\n",
|
| 155 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m260.7/260.7 kB\u001b[0m \u001b[31m60.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 156 |
+
"\u001b[?25hDownloading scikit_learn-1.7.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)\n",
|
| 157 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.7/9.7 MB\u001b[0m \u001b[31m95.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n",
|
| 158 |
+
"\u001b[?25hDownloading tqdm-4.67.1-py3-none-any.whl (78 kB)\n",
|
| 159 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.5/78.5 kB\u001b[0m \u001b[31m23.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 160 |
+
"\u001b[?25hDownloading wandb-0.21.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.4 MB)\n",
|
| 161 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m22.4/22.4 MB\u001b[0m \u001b[31m105.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
| 162 |
+
"\u001b[?25hDownloading pandas-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.3 MB)\n",
|
| 163 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.3/12.3 MB\u001b[0m \u001b[31m106.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m0:01\u001b[0m\n",
|
| 164 |
+
"\u001b[?25hDownloading seaborn-0.13.2-py3-none-any.whl (294 kB)\n",
|
| 165 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m294.9/294.9 kB\u001b[0m \u001b[31m51.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 166 |
+
"\u001b[?25hDownloading audioread-3.0.1-py3-none-any.whl (23 kB)\n",
|
| 167 |
+
"Downloading click-8.2.1-py3-none-any.whl (102 kB)\n",
|
| 168 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.2/102.2 kB\u001b[0m \u001b[31m31.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 169 |
+
"\u001b[?25hDownloading gitpython-3.1.45-py3-none-any.whl (208 kB)\n",
|
| 170 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m208.2/208.2 kB\u001b[0m \u001b[31m56.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 171 |
+
"\u001b[?25hDownloading joblib-1.5.1-py3-none-any.whl (307 kB)\n",
|
| 172 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m307.7/307.7 kB\u001b[0m \u001b[31m68.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 173 |
+
"\u001b[?25hDownloading lazy_loader-0.4-py3-none-any.whl (12 kB)\n",
|
| 174 |
+
"Downloading matplotlib-3.10.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)\n",
|
| 175 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.7/8.7 MB\u001b[0m \u001b[31m91.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
|
| 176 |
+
"\u001b[?25hDownloading msgpack-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (408 kB)\n",
|
| 177 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━���━━━\u001b[0m \u001b[32m408.6/408.6 kB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 178 |
+
"\u001b[?25hDownloading numba-0.61.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.8 MB)\n",
|
| 179 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m138.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 180 |
+
"\u001b[?25hDownloading pooch-1.8.2-py3-none-any.whl (64 kB)\n",
|
| 181 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.6/64.6 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 182 |
+
"\u001b[?25hDownloading protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl (322 kB)\n",
|
| 183 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m322.0/322.0 kB\u001b[0m \u001b[31m49.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 184 |
+
"\u001b[?25hDownloading pydantic-2.11.7-py3-none-any.whl (444 kB)\n",
|
| 185 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m444.8/444.8 kB\u001b[0m \u001b[31m82.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 186 |
+
"\u001b[?25hDownloading pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n",
|
| 187 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m140.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 188 |
+
"\u001b[?25hDownloading pytz-2025.2-py2.py3-none-any.whl (509 kB)\n",
|
| 189 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m509.2/509.2 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n",
|
| 190 |
+
"\u001b[?25hDownloading scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.7 MB)\n",
|
| 191 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m37.7/37.7 MB\u001b[0m \u001b[31m43.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
|
| 192 |
+
"\u001b[?25hDownloading sentry_sdk-2.35.0-py2.py3-none-any.whl (363 kB)\n",
|
| 193 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.8/363.8 kB\u001b[0m \u001b[31m69.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 194 |
+
"\u001b[?25hDownloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl (1.3 MB)\n",
|
| 195 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m128.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 196 |
+
"\u001b[?25hDownloading soxr-0.5.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (252 kB)\n",
|
| 197 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m252.8/252.8 kB\u001b[0m \u001b[31m44.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 198 |
+
"\u001b[?25hDownloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)\n",
|
| 199 |
+
"Downloading typing_extensions-4.14.1-py3-none-any.whl (43 kB)\n",
|
| 200 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.9/43.9 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 201 |
+
"\u001b[?25hDownloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)\n",
|
| 202 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m347.8/347.8 kB\u001b[0m \u001b[31m71.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 203 |
+
"\u001b[?25hDownloading annotated_types-0.7.0-py3-none-any.whl (13 kB)\n",
|
| 204 |
+
"Downloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (325 kB)\n",
|
| 205 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m325.0/325.0 kB\u001b[0m \u001b[31m39.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 206 |
+
"\u001b[?25hDownloading cycler-0.12.1-py3-none-any.whl (8.3 kB)\n",
|
| 207 |
+
"Downloading fonttools-4.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (4.8 MB)\n",
|
| 208 |
+
"\u001b[2K \u001b[90m��━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.8/4.8 MB\u001b[0m \u001b[31m113.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n",
|
| 209 |
+
"\u001b[?25hDownloading gitdb-4.0.12-py3-none-any.whl (62 kB)\n",
|
| 210 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.8/62.8 kB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 211 |
+
"\u001b[?25hDownloading kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.6 MB)\n",
|
| 212 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m72.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 213 |
+
"\u001b[?25hDownloading llvmlite-0.44.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (42.4 MB)\n",
|
| 214 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.4/42.4 MB\u001b[0m \u001b[31m47.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
|
| 215 |
+
"\u001b[?25hDownloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)\n",
|
| 216 |
+
"Downloading smmap-5.0.2-py3-none-any.whl (24 kB)\n",
|
| 217 |
+
"Installing collected packages: pytz, tzdata, typing_extensions, tqdm, threadpoolctl, soxr, smmap, sentry-sdk, scipy, protobuf, msgpack, llvmlite, lazy_loader, kiwisolver, joblib, h5py, fonttools, cycler, contourpy, click, av, audioread, annotated-types, typing-inspection, soundfile, scikit_learn, pydantic-core, pooch, pandas, numba, matplotlib, gitdb, seaborn, pydantic, librosa, gitpython, wandb\n",
|
| 218 |
+
" Attempting uninstall: typing_extensions\n",
|
| 219 |
+
" Found existing installation: typing_extensions 4.4.0\n",
|
| 220 |
+
" Uninstalling typing_extensions-4.4.0:\n",
|
| 221 |
+
" Successfully uninstalled typing_extensions-4.4.0\n",
|
| 222 |
+
"Successfully installed annotated-types-0.7.0 audioread-3.0.1 av-15.0.0 click-8.2.1 contourpy-1.3.2 cycler-0.12.1 fonttools-4.59.1 gitdb-4.0.12 gitpython-3.1.45 h5py-3.14.0 joblib-1.5.1 kiwisolver-1.4.9 lazy_loader-0.4 librosa-0.11.0 llvmlite-0.44.0 matplotlib-3.10.5 msgpack-1.1.1 numba-0.61.2 pandas-2.3.1 pooch-1.8.2 protobuf-6.32.0 pydantic-2.11.7 pydantic-core-2.33.2 pytz-2025.2 scikit_learn-1.7.1 scipy-1.15.3 seaborn-0.13.2 sentry-sdk-2.35.0 smmap-5.0.2 soundfile-0.13.1 soxr-0.5.0.post1 threadpoolctl-3.6.0 tqdm-4.67.1 typing-inspection-0.4.1 typing_extensions-4.14.1 tzdata-2025.2 wandb-0.21.1\n",
|
| 223 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
| 224 |
+
"\u001b[0m\n",
|
| 225 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.2\u001b[0m\n",
|
| 226 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
|
| 227 |
+
]
|
| 228 |
+
}
|
| 229 |
+
],
|
| 230 |
+
"source": [
|
| 231 |
+
"!pip install -r /workspace/train/EfficientAT-main/requirements.txt"
|
| 232 |
+
]
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"cell_type": "code",
|
| 236 |
+
"execution_count": 5,
|
| 237 |
+
"id": "5a7a42e9-3f79-4f63-a6ea-98db2206ca96",
|
| 238 |
+
"metadata": {},
|
| 239 |
+
"outputs": [
|
| 240 |
+
{
|
| 241 |
+
"name": "stdout",
|
| 242 |
+
"output_type": "stream",
|
| 243 |
+
"text": [
|
| 244 |
+
"/workspace/train/EfficientAT-main\n"
|
| 245 |
+
]
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"name": "stderr",
|
| 249 |
+
"output_type": "stream",
|
| 250 |
+
"text": [
|
| 251 |
+
"/usr/local/lib/python3.10/dist-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.\n",
|
| 252 |
+
" self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
|
| 253 |
+
]
|
| 254 |
+
}
|
| 255 |
+
],
|
| 256 |
+
"source": [
|
| 257 |
+
"%cd /workspace/train/EfficientAT-main"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"cell_type": "code",
|
| 262 |
+
"execution_count": null,
|
| 263 |
+
"id": "d0abf14b-46a1-4f92-9f70-e0de40dff46a",
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"outputs": [
|
| 266 |
+
{
|
| 267 |
+
"name": "stdout",
|
| 268 |
+
"output_type": "stream",
|
| 269 |
+
"text": [
|
| 270 |
+
"INFO: FMAX is None, setting to 15000 \n",
|
| 271 |
+
"The number of classes in the loaded state dict (=527) and the current model (=7) is not the same. Dropping final fully-connected layer and loading weights in non-strict mode!\n",
|
| 272 |
+
"DyMN(\n",
|
| 273 |
+
" (layers): ModuleList(\n",
|
| 274 |
+
" (0): DY_Block(\n",
|
| 275 |
+
" (exp_conv): DynamicWrapper(\n",
|
| 276 |
+
" (module): Identity()\n",
|
| 277 |
+
" )\n",
|
| 278 |
+
" (exp_norm): Identity()\n",
|
| 279 |
+
" (exp_act): DynamicWrapper(\n",
|
| 280 |
+
" (module): Identity()\n",
|
| 281 |
+
" )\n",
|
| 282 |
+
" (depth_conv): DynamicConv(\n",
|
| 283 |
+
" (residuals): Sequential(\n",
|
| 284 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 285 |
+
" )\n",
|
| 286 |
+
" )\n",
|
| 287 |
+
" (depth_norm): BatchNorm2d(8, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 288 |
+
" (depth_act): DyReLUB(\n",
|
| 289 |
+
" (coef_net): Sequential(\n",
|
| 290 |
+
" (0): Linear(in_features=16, out_features=32, bias=True)\n",
|
| 291 |
+
" )\n",
|
| 292 |
+
" (sigmoid): Sigmoid()\n",
|
| 293 |
+
" )\n",
|
| 294 |
+
" (ca): CoordAtt()\n",
|
| 295 |
+
" (proj_conv): DynamicConv(\n",
|
| 296 |
+
" (residuals): Sequential(\n",
|
| 297 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 298 |
+
" )\n",
|
| 299 |
+
" )\n",
|
| 300 |
+
" (proj_norm): BatchNorm2d(8, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 301 |
+
" (context_gen): ContextGen(\n",
|
| 302 |
+
" (joint_conv): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 303 |
+
" (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 304 |
+
" (joint_act): Hardswish()\n",
|
| 305 |
+
" (conv_f): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 306 |
+
" (conv_t): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 307 |
+
" (pool_f): Sequential()\n",
|
| 308 |
+
" (pool_t): Sequential()\n",
|
| 309 |
+
" )\n",
|
| 310 |
+
" )\n",
|
| 311 |
+
" (1): DY_Block(\n",
|
| 312 |
+
" (exp_conv): DynamicConv(\n",
|
| 313 |
+
" (residuals): Sequential(\n",
|
| 314 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 315 |
+
" )\n",
|
| 316 |
+
" )\n",
|
| 317 |
+
" (exp_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 318 |
+
" (exp_act): DynamicWrapper(\n",
|
| 319 |
+
" (module): ReLU(inplace=True)\n",
|
| 320 |
+
" )\n",
|
| 321 |
+
" (depth_conv): DynamicConv(\n",
|
| 322 |
+
" (residuals): Sequential(\n",
|
| 323 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 324 |
+
" )\n",
|
| 325 |
+
" )\n",
|
| 326 |
+
" (depth_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 327 |
+
" (depth_act): DyReLUB(\n",
|
| 328 |
+
" (coef_net): Sequential(\n",
|
| 329 |
+
" (0): Linear(in_features=16, out_features=96, bias=True)\n",
|
| 330 |
+
" )\n",
|
| 331 |
+
" (sigmoid): Sigmoid()\n",
|
| 332 |
+
" )\n",
|
| 333 |
+
" (ca): CoordAtt()\n",
|
| 334 |
+
" (proj_conv): DynamicConv(\n",
|
| 335 |
+
" (residuals): Sequential(\n",
|
| 336 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 337 |
+
" )\n",
|
| 338 |
+
" )\n",
|
| 339 |
+
" (proj_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 340 |
+
" (context_gen): ContextGen(\n",
|
| 341 |
+
" (joint_conv): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 342 |
+
" (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 343 |
+
" (joint_act): Hardswish()\n",
|
| 344 |
+
" (conv_f): Conv2d(16, 24, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 345 |
+
" (conv_t): Conv2d(16, 24, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 346 |
+
" (pool_f): AvgPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
|
| 347 |
+
" (pool_t): AvgPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))\n",
|
| 348 |
+
" )\n",
|
| 349 |
+
" )\n",
|
| 350 |
+
" (2): DY_Block(\n",
|
| 351 |
+
" (exp_conv): DynamicConv(\n",
|
| 352 |
+
" (residuals): Sequential(\n",
|
| 353 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 354 |
+
" )\n",
|
| 355 |
+
" )\n",
|
| 356 |
+
" (exp_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 357 |
+
" (exp_act): DynamicWrapper(\n",
|
| 358 |
+
" (module): ReLU(inplace=True)\n",
|
| 359 |
+
" )\n",
|
| 360 |
+
" (depth_conv): DynamicConv(\n",
|
| 361 |
+
" (residuals): Sequential(\n",
|
| 362 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 363 |
+
" )\n",
|
| 364 |
+
" )\n",
|
| 365 |
+
" (depth_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 366 |
+
" (depth_act): DyReLUB(\n",
|
| 367 |
+
" (coef_net): Sequential(\n",
|
| 368 |
+
" (0): Linear(in_features=16, out_features=128, bias=True)\n",
|
| 369 |
+
" )\n",
|
| 370 |
+
" (sigmoid): Sigmoid()\n",
|
| 371 |
+
" )\n",
|
| 372 |
+
" (ca): CoordAtt()\n",
|
| 373 |
+
" (proj_conv): DynamicConv(\n",
|
| 374 |
+
" (residuals): Sequential(\n",
|
| 375 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 376 |
+
" )\n",
|
| 377 |
+
" )\n",
|
| 378 |
+
" (proj_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 379 |
+
" (context_gen): ContextGen(\n",
|
| 380 |
+
" (joint_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 381 |
+
" (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 382 |
+
" (joint_act): Hardswish()\n",
|
| 383 |
+
" (conv_f): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 384 |
+
" (conv_t): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 385 |
+
" (pool_f): Sequential()\n",
|
| 386 |
+
" (pool_t): Sequential()\n",
|
| 387 |
+
" )\n",
|
| 388 |
+
" )\n",
|
| 389 |
+
" (3): DY_Block(\n",
|
| 390 |
+
" (exp_conv): DynamicConv(\n",
|
| 391 |
+
" (residuals): Sequential(\n",
|
| 392 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 393 |
+
" )\n",
|
| 394 |
+
" )\n",
|
| 395 |
+
" (exp_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 396 |
+
" (exp_act): DynamicWrapper(\n",
|
| 397 |
+
" (module): ReLU(inplace=True)\n",
|
| 398 |
+
" )\n",
|
| 399 |
+
" (depth_conv): DynamicConv(\n",
|
| 400 |
+
" (residuals): Sequential(\n",
|
| 401 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 402 |
+
" )\n",
|
| 403 |
+
" )\n",
|
| 404 |
+
" (depth_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 405 |
+
" (depth_act): DyReLUB(\n",
|
| 406 |
+
" (coef_net): Sequential(\n",
|
| 407 |
+
" (0): Linear(in_features=16, out_features=128, bias=True)\n",
|
| 408 |
+
" )\n",
|
| 409 |
+
" (sigmoid): Sigmoid()\n",
|
| 410 |
+
" )\n",
|
| 411 |
+
" (ca): CoordAtt()\n",
|
| 412 |
+
" (proj_conv): DynamicConv(\n",
|
| 413 |
+
" (residuals): Sequential(\n",
|
| 414 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 415 |
+
" )\n",
|
| 416 |
+
" )\n",
|
| 417 |
+
" (proj_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 418 |
+
" (context_gen): ContextGen(\n",
|
| 419 |
+
" (joint_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 420 |
+
" (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 421 |
+
" (joint_act): Hardswish()\n",
|
| 422 |
+
" (conv_f): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 423 |
+
" (conv_t): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 424 |
+
" (pool_f): AvgPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
|
| 425 |
+
" (pool_t): AvgPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))\n",
|
| 426 |
+
" )\n",
|
| 427 |
+
" )\n",
|
| 428 |
+
" (4-5): 2 x DY_Block(\n",
|
| 429 |
+
" (exp_conv): DynamicConv(\n",
|
| 430 |
+
" (residuals): Sequential(\n",
|
| 431 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 432 |
+
" )\n",
|
| 433 |
+
" )\n",
|
| 434 |
+
" (exp_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 435 |
+
" (exp_act): DynamicWrapper(\n",
|
| 436 |
+
" (module): ReLU(inplace=True)\n",
|
| 437 |
+
" )\n",
|
| 438 |
+
" (depth_conv): DynamicConv(\n",
|
| 439 |
+
" (residuals): Sequential(\n",
|
| 440 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 441 |
+
" )\n",
|
| 442 |
+
" )\n",
|
| 443 |
+
" (depth_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 444 |
+
" (depth_act): DyReLUB(\n",
|
| 445 |
+
" (coef_net): Sequential(\n",
|
| 446 |
+
" (0): Linear(in_features=16, out_features=192, bias=True)\n",
|
| 447 |
+
" )\n",
|
| 448 |
+
" (sigmoid): Sigmoid()\n",
|
| 449 |
+
" )\n",
|
| 450 |
+
" (ca): CoordAtt()\n",
|
| 451 |
+
" (proj_conv): DynamicConv(\n",
|
| 452 |
+
" (residuals): Sequential(\n",
|
| 453 |
+
" (0): Linear(in_features=16, out_features=4, bias=True)\n",
|
| 454 |
+
" )\n",
|
| 455 |
+
" )\n",
|
| 456 |
+
" (proj_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 457 |
+
" (context_gen): ContextGen(\n",
|
| 458 |
+
" (joint_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 459 |
+
" (joint_norm): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 460 |
+
" (joint_act): Hardswish()\n",
|
| 461 |
+
" (conv_f): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 462 |
+
" (conv_t): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 463 |
+
" (pool_f): Sequential()\n",
|
| 464 |
+
" (pool_t): Sequential()\n",
|
| 465 |
+
" )\n",
|
| 466 |
+
" )\n",
|
| 467 |
+
" (6): DY_Block(\n",
|
| 468 |
+
" (exp_conv): DynamicConv(\n",
|
| 469 |
+
" (residuals): Sequential(\n",
|
| 470 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 471 |
+
" )\n",
|
| 472 |
+
" )\n",
|
| 473 |
+
" (exp_norm): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 474 |
+
" (exp_act): DynamicWrapper(\n",
|
| 475 |
+
" (module): Hardswish()\n",
|
| 476 |
+
" )\n",
|
| 477 |
+
" (depth_conv): DynamicConv(\n",
|
| 478 |
+
" (residuals): Sequential(\n",
|
| 479 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 480 |
+
" )\n",
|
| 481 |
+
" )\n",
|
| 482 |
+
" (depth_norm): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 483 |
+
" (depth_act): DyReLUB(\n",
|
| 484 |
+
" (coef_net): Sequential(\n",
|
| 485 |
+
" (0): Linear(in_features=24, out_features=384, bias=True)\n",
|
| 486 |
+
" )\n",
|
| 487 |
+
" (sigmoid): Sigmoid()\n",
|
| 488 |
+
" )\n",
|
| 489 |
+
" (ca): CoordAtt()\n",
|
| 490 |
+
" (proj_conv): DynamicConv(\n",
|
| 491 |
+
" (residuals): Sequential(\n",
|
| 492 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 493 |
+
" )\n",
|
| 494 |
+
" )\n",
|
| 495 |
+
" (proj_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 496 |
+
" (context_gen): ContextGen(\n",
|
| 497 |
+
" (joint_conv): Conv2d(16, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 498 |
+
" (joint_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 499 |
+
" (joint_act): Hardswish()\n",
|
| 500 |
+
" (conv_f): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 501 |
+
" (conv_t): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 502 |
+
" (pool_f): AvgPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
|
| 503 |
+
" (pool_t): AvgPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))\n",
|
| 504 |
+
" )\n",
|
| 505 |
+
" )\n",
|
| 506 |
+
" (7): DY_Block(\n",
|
| 507 |
+
" (exp_conv): DynamicConv(\n",
|
| 508 |
+
" (residuals): Sequential(\n",
|
| 509 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 510 |
+
" )\n",
|
| 511 |
+
" )\n",
|
| 512 |
+
" (exp_norm): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 513 |
+
" (exp_act): DynamicWrapper(\n",
|
| 514 |
+
" (module): Hardswish()\n",
|
| 515 |
+
" )\n",
|
| 516 |
+
" (depth_conv): DynamicConv(\n",
|
| 517 |
+
" (residuals): Sequential(\n",
|
| 518 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 519 |
+
" )\n",
|
| 520 |
+
" )\n",
|
| 521 |
+
" (depth_norm): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 522 |
+
" (depth_act): DyReLUB(\n",
|
| 523 |
+
" (coef_net): Sequential(\n",
|
| 524 |
+
" (0): Linear(in_features=24, out_features=320, bias=True)\n",
|
| 525 |
+
" )\n",
|
| 526 |
+
" (sigmoid): Sigmoid()\n",
|
| 527 |
+
" )\n",
|
| 528 |
+
" (ca): CoordAtt()\n",
|
| 529 |
+
" (proj_conv): DynamicConv(\n",
|
| 530 |
+
" (residuals): Sequential(\n",
|
| 531 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 532 |
+
" )\n",
|
| 533 |
+
" )\n",
|
| 534 |
+
" (proj_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 535 |
+
" (context_gen): ContextGen(\n",
|
| 536 |
+
" (joint_conv): Conv2d(32, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 537 |
+
" (joint_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 538 |
+
" (joint_act): Hardswish()\n",
|
| 539 |
+
" (conv_f): Conv2d(24, 80, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 540 |
+
" (conv_t): Conv2d(24, 80, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 541 |
+
" (pool_f): Sequential()\n",
|
| 542 |
+
" (pool_t): Sequential()\n",
|
| 543 |
+
" )\n",
|
| 544 |
+
" )\n",
|
| 545 |
+
" (8-9): 2 x DY_Block(\n",
|
| 546 |
+
" (exp_conv): DynamicConv(\n",
|
| 547 |
+
" (residuals): Sequential(\n",
|
| 548 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 549 |
+
" )\n",
|
| 550 |
+
" )\n",
|
| 551 |
+
" (exp_norm): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 552 |
+
" (exp_act): DynamicWrapper(\n",
|
| 553 |
+
" (module): Hardswish()\n",
|
| 554 |
+
" )\n",
|
| 555 |
+
" (depth_conv): DynamicConv(\n",
|
| 556 |
+
" (residuals): Sequential(\n",
|
| 557 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 558 |
+
" )\n",
|
| 559 |
+
" )\n",
|
| 560 |
+
" (depth_norm): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 561 |
+
" (depth_act): DyReLUB(\n",
|
| 562 |
+
" (coef_net): Sequential(\n",
|
| 563 |
+
" (0): Linear(in_features=24, out_features=288, bias=True)\n",
|
| 564 |
+
" )\n",
|
| 565 |
+
" (sigmoid): Sigmoid()\n",
|
| 566 |
+
" )\n",
|
| 567 |
+
" (ca): CoordAtt()\n",
|
| 568 |
+
" (proj_conv): DynamicConv(\n",
|
| 569 |
+
" (residuals): Sequential(\n",
|
| 570 |
+
" (0): Linear(in_features=24, out_features=4, bias=True)\n",
|
| 571 |
+
" )\n",
|
| 572 |
+
" )\n",
|
| 573 |
+
" (proj_norm): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 574 |
+
" (context_gen): ContextGen(\n",
|
| 575 |
+
" (joint_conv): Conv2d(32, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 576 |
+
" (joint_norm): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 577 |
+
" (joint_act): Hardswish()\n",
|
| 578 |
+
" (conv_f): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 579 |
+
" (conv_t): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 580 |
+
" (pool_f): Sequential()\n",
|
| 581 |
+
" (pool_t): Sequential()\n",
|
| 582 |
+
" )\n",
|
| 583 |
+
" )\n",
|
| 584 |
+
" (10): DY_Block(\n",
|
| 585 |
+
" (exp_conv): DynamicConv(\n",
|
| 586 |
+
" (residuals): Sequential(\n",
|
| 587 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 588 |
+
" )\n",
|
| 589 |
+
" )\n",
|
| 590 |
+
" (exp_norm): BatchNorm2d(192, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 591 |
+
" (exp_act): DynamicWrapper(\n",
|
| 592 |
+
" (module): Hardswish()\n",
|
| 593 |
+
" )\n",
|
| 594 |
+
" (depth_conv): DynamicConv(\n",
|
| 595 |
+
" (residuals): Sequential(\n",
|
| 596 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 597 |
+
" )\n",
|
| 598 |
+
" )\n",
|
| 599 |
+
" (depth_norm): BatchNorm2d(192, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 600 |
+
" (depth_act): DyReLUB(\n",
|
| 601 |
+
" (coef_net): Sequential(\n",
|
| 602 |
+
" (0): Linear(in_features=48, out_features=768, bias=True)\n",
|
| 603 |
+
" )\n",
|
| 604 |
+
" (sigmoid): Sigmoid()\n",
|
| 605 |
+
" )\n",
|
| 606 |
+
" (ca): CoordAtt()\n",
|
| 607 |
+
" (proj_conv): DynamicConv(\n",
|
| 608 |
+
" (residuals): Sequential(\n",
|
| 609 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 610 |
+
" )\n",
|
| 611 |
+
" )\n",
|
| 612 |
+
" (proj_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 613 |
+
" (context_gen): ContextGen(\n",
|
| 614 |
+
" (joint_conv): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 615 |
+
" (joint_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 616 |
+
" (joint_act): Hardswish()\n",
|
| 617 |
+
" (conv_f): Conv2d(48, 192, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 618 |
+
" (conv_t): Conv2d(48, 192, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 619 |
+
" (pool_f): Sequential()\n",
|
| 620 |
+
" (pool_t): Sequential()\n",
|
| 621 |
+
" )\n",
|
| 622 |
+
" )\n",
|
| 623 |
+
" (11): DY_Block(\n",
|
| 624 |
+
" (exp_conv): DynamicConv(\n",
|
| 625 |
+
" (residuals): Sequential(\n",
|
| 626 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 627 |
+
" )\n",
|
| 628 |
+
" )\n",
|
| 629 |
+
" (exp_norm): BatchNorm2d(272, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 630 |
+
" (exp_act): DynamicWrapper(\n",
|
| 631 |
+
" (module): Hardswish()\n",
|
| 632 |
+
" )\n",
|
| 633 |
+
" (depth_conv): DynamicConv(\n",
|
| 634 |
+
" (residuals): Sequential(\n",
|
| 635 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 636 |
+
" )\n",
|
| 637 |
+
" )\n",
|
| 638 |
+
" (depth_norm): BatchNorm2d(272, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 639 |
+
" (depth_act): DyReLUB(\n",
|
| 640 |
+
" (coef_net): Sequential(\n",
|
| 641 |
+
" (0): Linear(in_features=48, out_features=1088, bias=True)\n",
|
| 642 |
+
" )\n",
|
| 643 |
+
" (sigmoid): Sigmoid()\n",
|
| 644 |
+
" )\n",
|
| 645 |
+
" (ca): CoordAtt()\n",
|
| 646 |
+
" (proj_conv): DynamicConv(\n",
|
| 647 |
+
" (residuals): Sequential(\n",
|
| 648 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 649 |
+
" )\n",
|
| 650 |
+
" )\n",
|
| 651 |
+
" (proj_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 652 |
+
" (context_gen): ContextGen(\n",
|
| 653 |
+
" (joint_conv): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 654 |
+
" (joint_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 655 |
+
" (joint_act): Hardswish()\n",
|
| 656 |
+
" (conv_f): Conv2d(48, 272, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 657 |
+
" (conv_t): Conv2d(48, 272, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 658 |
+
" (pool_f): Sequential()\n",
|
| 659 |
+
" (pool_t): Sequential()\n",
|
| 660 |
+
" )\n",
|
| 661 |
+
" )\n",
|
| 662 |
+
" (12): DY_Block(\n",
|
| 663 |
+
" (exp_conv): DynamicConv(\n",
|
| 664 |
+
" (residuals): Sequential(\n",
|
| 665 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 666 |
+
" )\n",
|
| 667 |
+
" )\n",
|
| 668 |
+
" (exp_norm): BatchNorm2d(272, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 669 |
+
" (exp_act): DynamicWrapper(\n",
|
| 670 |
+
" (module): Hardswish()\n",
|
| 671 |
+
" )\n",
|
| 672 |
+
" (depth_conv): DynamicConv(\n",
|
| 673 |
+
" (residuals): Sequential(\n",
|
| 674 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 675 |
+
" )\n",
|
| 676 |
+
" )\n",
|
| 677 |
+
" (depth_norm): BatchNorm2d(272, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 678 |
+
" (depth_act): DyReLUB(\n",
|
| 679 |
+
" (coef_net): Sequential(\n",
|
| 680 |
+
" (0): Linear(in_features=48, out_features=1088, bias=True)\n",
|
| 681 |
+
" )\n",
|
| 682 |
+
" (sigmoid): Sigmoid()\n",
|
| 683 |
+
" )\n",
|
| 684 |
+
" (ca): CoordAtt()\n",
|
| 685 |
+
" (proj_conv): DynamicConv(\n",
|
| 686 |
+
" (residuals): Sequential(\n",
|
| 687 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 688 |
+
" )\n",
|
| 689 |
+
" )\n",
|
| 690 |
+
" (proj_norm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 691 |
+
" (context_gen): ContextGen(\n",
|
| 692 |
+
" (joint_conv): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 693 |
+
" (joint_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 694 |
+
" (joint_act): Hardswish()\n",
|
| 695 |
+
" (conv_f): Conv2d(48, 272, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 696 |
+
" (conv_t): Conv2d(48, 272, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 697 |
+
" (pool_f): AvgPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
|
| 698 |
+
" (pool_t): AvgPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))\n",
|
| 699 |
+
" )\n",
|
| 700 |
+
" )\n",
|
| 701 |
+
" (13-14): 2 x DY_Block(\n",
|
| 702 |
+
" (exp_conv): DynamicConv(\n",
|
| 703 |
+
" (residuals): Sequential(\n",
|
| 704 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 705 |
+
" )\n",
|
| 706 |
+
" )\n",
|
| 707 |
+
" (exp_norm): BatchNorm2d(384, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 708 |
+
" (exp_act): DynamicWrapper(\n",
|
| 709 |
+
" (module): Hardswish()\n",
|
| 710 |
+
" )\n",
|
| 711 |
+
" (depth_conv): DynamicConv(\n",
|
| 712 |
+
" (residuals): Sequential(\n",
|
| 713 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 714 |
+
" )\n",
|
| 715 |
+
" )\n",
|
| 716 |
+
" (depth_norm): BatchNorm2d(384, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 717 |
+
" (depth_act): DyReLUB(\n",
|
| 718 |
+
" (coef_net): Sequential(\n",
|
| 719 |
+
" (0): Linear(in_features=48, out_features=1536, bias=True)\n",
|
| 720 |
+
" )\n",
|
| 721 |
+
" (sigmoid): Sigmoid()\n",
|
| 722 |
+
" )\n",
|
| 723 |
+
" (ca): CoordAtt()\n",
|
| 724 |
+
" (proj_conv): DynamicConv(\n",
|
| 725 |
+
" (residuals): Sequential(\n",
|
| 726 |
+
" (0): Linear(in_features=48, out_features=4, bias=True)\n",
|
| 727 |
+
" )\n",
|
| 728 |
+
" )\n",
|
| 729 |
+
" (proj_norm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 730 |
+
" (context_gen): ContextGen(\n",
|
| 731 |
+
" (joint_conv): Conv2d(64, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 732 |
+
" (joint_norm): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 733 |
+
" (joint_act): Hardswish()\n",
|
| 734 |
+
" (conv_f): Conv2d(48, 384, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 735 |
+
" (conv_t): Conv2d(48, 384, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 736 |
+
" (pool_f): Sequential()\n",
|
| 737 |
+
" (pool_t): Sequential()\n",
|
| 738 |
+
" )\n",
|
| 739 |
+
" )\n",
|
| 740 |
+
" )\n",
|
| 741 |
+
" (in_c): Conv2dNormActivation(\n",
|
| 742 |
+
" (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
| 743 |
+
" (1): BatchNorm2d(8, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 744 |
+
" (2): Hardswish()\n",
|
| 745 |
+
" )\n",
|
| 746 |
+
" (out_c): Conv2dNormActivation(\n",
|
| 747 |
+
" (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
| 748 |
+
" (1): BatchNorm2d(384, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
|
| 749 |
+
" (2): Hardswish()\n",
|
| 750 |
+
" )\n",
|
| 751 |
+
" (classifier): Sequential(\n",
|
| 752 |
+
" (0): AdaptiveAvgPool2d(output_size=1)\n",
|
| 753 |
+
" (1): Flatten(start_dim=1, end_dim=-1)\n",
|
| 754 |
+
" (2): Linear(in_features=384, out_features=512, bias=True)\n",
|
| 755 |
+
" (3): Hardswish()\n",
|
| 756 |
+
" (4): Dropout(p=0.2, inplace=True)\n",
|
| 757 |
+
" (5): Linear(in_features=512, out_features=7, bias=True)\n",
|
| 758 |
+
" )\n",
|
| 759 |
+
")\n",
|
| 760 |
+
"Dataset from ./datasets/train_mp3.hdf5 with length 36878.\n",
|
| 761 |
+
"Warning: sample_weight_offset=100 minnow=1355.0\n",
|
| 762 |
+
"Dataset from ./datasets/test_mp3.hdf5 with length 9261.\n",
|
| 763 |
+
"Epoch 1/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 30.0\n",
|
| 764 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 765 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 766 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 767 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 768 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 769 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 770 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 771 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 772 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 773 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 774 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 775 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 776 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 777 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 778 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 779 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 780 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 781 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 782 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 783 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 784 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 785 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 786 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 787 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 788 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 789 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 790 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 791 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 792 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 793 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 794 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 795 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 796 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 797 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 798 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 799 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 800 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 801 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 802 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 803 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 804 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 805 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 806 |
+
"Setting temperature for attention over kernels to 30.0\n",
|
| 807 |
+
"Epoch 1/200: 100%|████████████| 145/145 [01:15<00:00, 1.93it/s, train_loss=1.8]\n",
|
| 808 |
+
"Validating: 100%|███████████████████████████████| 37/37 [00:09<00:00, 3.87it/s]\n",
|
| 809 |
+
"Confusion Matrix:\n",
|
| 810 |
+
"[[ 529 0 0 76 656 19 0]\n",
|
| 811 |
+
" [ 70 2 0 6 799 5 0]\n",
|
| 812 |
+
" [ 99 0 0 1 311 11 0]\n",
|
| 813 |
+
" [ 440 0 0 104 1255 22 0]\n",
|
| 814 |
+
" [ 115 0 1 34 2903 12 0]\n",
|
| 815 |
+
" [ 195 0 2 51 1063 150 0]\n",
|
| 816 |
+
" [ 69 0 0 4 239 18 0]]\n",
|
| 817 |
+
"Epoch 1/200, Train Loss: 1.8010, Validation Loss: 1.7283, Validation Accuracy: 0.3982, LR: 0.000012, Time: 84.82s\n",
|
| 818 |
+
"Epoch 2/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 29.0\n",
|
| 819 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 820 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 821 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 822 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 823 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 824 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 825 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 826 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 827 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 828 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 829 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 830 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 831 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 832 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 833 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 834 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 835 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 836 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 837 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 838 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 839 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 840 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 841 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 842 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 843 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 844 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 845 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 846 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 847 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 848 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 849 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 850 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 851 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 852 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 853 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 854 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 855 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 856 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 857 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 858 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 859 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 860 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 861 |
+
"Setting temperature for attention over kernels to 29.0\n",
|
| 862 |
+
"Epoch 2/200: 100%|███████████| 145/145 [01:14<00:00, 1.95it/s, train_loss=1.52]\n",
|
| 863 |
+
"Validating: 100%|███████████████████████████████| 37/37 [00:09<00:00, 3.87it/s]\n",
|
| 864 |
+
"Confusion Matrix:\n",
|
| 865 |
+
"[[ 438 3 0 324 484 31 0]\n",
|
| 866 |
+
" [ 41 38 0 51 744 8 0]\n",
|
| 867 |
+
" [ 80 7 0 26 294 15 0]\n",
|
| 868 |
+
" [ 158 5 0 519 1084 55 0]\n",
|
| 869 |
+
" [ 30 4 0 54 2936 41 0]\n",
|
| 870 |
+
" [ 52 10 0 140 956 303 0]\n",
|
| 871 |
+
" [ 37 11 0 21 232 29 0]]\n",
|
| 872 |
+
"Epoch 2/200, Train Loss: 1.5245, Validation Loss: 1.6068, Validation Accuracy: 0.4572, LR: 0.000022, Time: 84.03s\n",
|
| 873 |
+
"Epoch 3/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 28.0\n",
|
| 874 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 875 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 876 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 877 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 878 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 879 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 880 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 881 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 882 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 883 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 884 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 885 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 886 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 887 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 888 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 889 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 890 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 891 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 892 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 893 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 894 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 895 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 896 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 897 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 898 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 899 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 900 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 901 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 902 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 903 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 904 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 905 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 906 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 907 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 908 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 909 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 910 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 911 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 912 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 913 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 914 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 915 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 916 |
+
"Setting temperature for attention over kernels to 28.0\n",
|
| 917 |
+
"Epoch 3/200: 100%|███████████| 145/145 [01:14<00:00, 1.95it/s, train_loss=1.36]\n",
|
| 918 |
+
"Validating: 100%|███████████████████████████████| 37/37 [00:09<00:00, 3.88it/s]\n",
|
| 919 |
+
"Confusion Matrix:\n",
|
| 920 |
+
"[[ 112 0 0 723 324 121 0]\n",
|
| 921 |
+
" [ 13 139 0 160 518 52 0]\n",
|
| 922 |
+
" [ 9 0 0 148 199 66 0]\n",
|
| 923 |
+
" [ 5 1 0 950 765 100 0]\n",
|
| 924 |
+
" [ 6 0 0 222 2709 128 0]\n",
|
| 925 |
+
" [ 5 1 0 224 567 664 0]\n",
|
| 926 |
+
" [ 5 14 0 68 149 94 0]]\n",
|
| 927 |
+
"Epoch 3/200, Train Loss: 1.3576, Validation Loss: 1.6087, Validation Accuracy: 0.4939, LR: 0.000060, Time: 83.90s\n",
|
| 928 |
+
"Epoch 4/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 27.0\n",
|
| 929 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 930 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 931 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 932 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 933 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 934 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 935 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 936 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 937 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 938 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 939 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 940 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 941 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 942 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 943 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 944 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 945 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 946 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 947 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 948 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 949 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 950 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 951 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 952 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 953 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 954 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 955 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 956 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 957 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 958 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 959 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 960 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 961 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 962 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 963 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 964 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 965 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 966 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 967 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 968 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 969 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 970 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 971 |
+
"Setting temperature for attention over kernels to 27.0\n",
|
| 972 |
+
"Epoch 4/200: 100%|███████████| 145/145 [01:14<00:00, 1.95it/s, train_loss=1.26]\n",
|
| 973 |
+
"Validating: 100%|███████████████████████████████| 37/37 [00:09<00:00, 3.85it/s]\n",
|
| 974 |
+
"Confusion Matrix:\n",
|
| 975 |
+
"[[ 28 2 0 472 590 188 0]\n",
|
| 976 |
+
" [ 0 198 0 60 603 21 0]\n",
|
| 977 |
+
" [ 0 0 0 37 338 47 0]\n",
|
| 978 |
+
" [ 0 1 0 659 1077 84 0]\n",
|
| 979 |
+
" [ 0 0 0 21 3000 44 0]\n",
|
| 980 |
+
" [ 0 0 0 42 1005 414 0]\n",
|
| 981 |
+
" [ 0 15 0 7 213 95 0]]\n",
|
| 982 |
+
"Epoch 4/200, Train Loss: 1.2573, Validation Loss: 1.8212, Validation Accuracy: 0.4642, LR: 0.000142, Time: 84.06s\n",
|
| 983 |
+
"Epoch 5/200: 0%| | 0/145 [00:00<?, ?it/s]Setting temperature for attention over kernels to 26.0\n",
|
| 984 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 985 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 986 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 987 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 988 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 989 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 990 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 991 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 992 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 993 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 994 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 995 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 996 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 997 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 998 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 999 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1000 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1001 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1002 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1003 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1004 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1005 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1006 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1007 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1008 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1009 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1010 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1011 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1012 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1013 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1014 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1015 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1016 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1017 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1018 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1019 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1020 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1021 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1022 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1023 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1024 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1025 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1026 |
+
"Setting temperature for attention over kernels to 26.0\n",
|
| 1027 |
+
"Epoch 5/200: 73%|████████ | 106/145 [00:55<00:19, 2.01it/s, train_loss=1.27]"
|
| 1028 |
+
]
|
| 1029 |
+
}
|
| 1030 |
+
],
|
| 1031 |
+
"source": [
|
| 1032 |
+
"!python ex_train_audio_classification_mode.py --cuda --train --pretrained --gain_augment=5 --n_epochs=200 --model_name=dymn04_im --batch_size=256 --max_lr=0.001 --pretrain_final_temp=30 --adamw"
|
| 1033 |
+
]
|
| 1034 |
+
},
|
| 1035 |
+
{
|
| 1036 |
+
"cell_type": "code",
|
| 1037 |
+
"execution_count": 15,
|
| 1038 |
+
"id": "d3b0a027-1774-4725-9e29-0112001369fe",
|
| 1039 |
+
"metadata": {},
|
| 1040 |
+
"outputs": [],
|
| 1041 |
+
"source": [
|
| 1042 |
+
"!rm -rf '/workspace/train/EfficientAT-main/training_results_dymn04_im_20250816_031533'"
|
| 1043 |
+
]
|
| 1044 |
+
},
|
| 1045 |
+
{
|
| 1046 |
+
"cell_type": "code",
|
| 1047 |
+
"execution_count": null,
|
| 1048 |
+
"id": "fe6df804-8e48-41fc-9dab-5c409d591379",
|
| 1049 |
+
"metadata": {},
|
| 1050 |
+
"outputs": [],
|
| 1051 |
+
"source": [
|
| 1052 |
+
"!python inference_classification.py --cuda --pre_trained_model_path='/workspace/main/EfficientAT-main/output_model/best_checkpoint/dymn04_im04_acc_0.8923.pt' --audio_path=\"/workspace/main/EfficientAT-main/03-01-01-01-02-01-24.mp3\""
|
| 1053 |
+
]
|
| 1054 |
+
},
|
| 1055 |
+
{
|
| 1056 |
+
"cell_type": "code",
|
| 1057 |
+
"execution_count": null,
|
| 1058 |
+
"id": "3c8074fe-eedf-4cf6-8ab3-1815e733e6b9",
|
| 1059 |
+
"metadata": {},
|
| 1060 |
+
"outputs": [],
|
| 1061 |
+
"source": [
|
| 1062 |
+
"import os\n",
|
| 1063 |
+
"import shutil\n",
|
| 1064 |
+
"\n",
|
| 1065 |
+
"# Mount Google Drive nếu cần\n",
|
| 1066 |
+
"\n",
|
| 1067 |
+
"# Đường dẫn tới thư mục cần zip\n",
|
| 1068 |
+
"folder_to_zip = '/workspace/main/EfficientAT-main'\n",
|
| 1069 |
+
"output_zip = '/workspace/main/EfficientAT-main.zip'\n",
|
| 1070 |
+
"\n",
|
| 1071 |
+
"# Tạo file zip\n",
|
| 1072 |
+
"shutil.make_archive(output_zip.replace('.zip', ''), 'zip', folder_to_zip)\n",
|
| 1073 |
+
"\n",
|
| 1074 |
+
"print(f'Thư mục đã được nén thành công thành file zip: {output_zip}')\n"
|
| 1075 |
+
]
|
| 1076 |
+
},
|
| 1077 |
+
{
|
| 1078 |
+
"cell_type": "code",
|
| 1079 |
+
"execution_count": 12,
|
| 1080 |
+
"id": "3c819525-cde6-4c3b-a24d-0bfb5e5c2b01",
|
| 1081 |
+
"metadata": {},
|
| 1082 |
+
"outputs": [],
|
| 1083 |
+
"source": [
|
| 1084 |
+
"import numpy as np"
|
| 1085 |
+
]
|
| 1086 |
+
},
|
| 1087 |
+
{
|
| 1088 |
+
"cell_type": "code",
|
| 1089 |
+
"execution_count": null,
|
| 1090 |
+
"id": "f3edf2d9-07d7-4472-86ec-81b800dcecbd",
|
| 1091 |
+
"metadata": {},
|
| 1092 |
+
"outputs": [],
|
| 1093 |
+
"source": []
|
| 1094 |
+
}
|
| 1095 |
+
],
|
| 1096 |
+
"metadata": {
|
| 1097 |
+
"kernelspec": {
|
| 1098 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1099 |
+
"language": "python",
|
| 1100 |
+
"name": "python3"
|
| 1101 |
+
},
|
| 1102 |
+
"language_info": {
|
| 1103 |
+
"codemirror_mode": {
|
| 1104 |
+
"name": "ipython",
|
| 1105 |
+
"version": 3
|
| 1106 |
+
},
|
| 1107 |
+
"file_extension": ".py",
|
| 1108 |
+
"mimetype": "text/x-python",
|
| 1109 |
+
"name": "python",
|
| 1110 |
+
"nbconvert_exporter": "python",
|
| 1111 |
+
"pygments_lexer": "ipython3",
|
| 1112 |
+
"version": "3.10.12"
|
| 1113 |
+
}
|
| 1114 |
+
},
|
| 1115 |
+
"nbformat": 4,
|
| 1116 |
+
"nbformat_minor": 5
|
| 1117 |
+
}
|