{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# BarVox API - Kid Vocabulary Test\n", "\n", "Tests the BarVox API using a real kid's vocabulary:\n", "- **Bank-12**: 12 words (+ `_unknown`), each with a few audio samples from the kid (2-9 per word in the current bank)\n", "- **test-12**: Test recordings to predict which word the kid is saying\n", "\n", "**Flow:**\n", "1. Extract embeddings from every sample in Bank-12 → build a dictionary\n", "2. For each test file → extract embeddings → compare against the dictionary\n", "3. See if the top-ranked match is the correct word\n", "\n", "**Start the server first:**\n", "```bash\n", "py -3.10 -m uvicorn app:app --host 0.0.0.0 --port 8000\n", "```" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# %pip (unlike !pip) installs into the environment of the running kernel.\n", "%pip install requests\n" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import requests\n", "import json\n", "import os\n", "from pathlib import Path\n", "from collections import defaultdict\n", "import time\n", "\n", "# Each location below can be overridden with an environment variable, so the notebook\n", "# runs on another machine without editing this cell; the defaults are the original values.\n", "BASE_URL = os.environ.get(\"BARVOX_BASE_URL\", \"http://localhost:8000\")\n", "\n", "# === PATHS - adjust if needed (or set BARVOX_BANK_DIR / BARVOX_TEST_DIR) ===\n", "BANK_DIR = os.environ.get(\"BARVOX_BANK_DIR\", r\"C:\\Users\\97254\\Desktop\\RONEN\\Jobs\\Dror\\Bank-12\")\n", "TEST_DIR = os.environ.get(\"BARVOX_TEST_DIR\", r\"C:\\Users\\97254\\Desktop\\RONEN\\Jobs\\Dror\\test-12\")\n", "\n", "# Which embedding model to use for similarity matching\n", "EMBEDDING_MODEL = \"hubert_embedding\"  # options: hubert_embedding, wav2vec2_embedding, trill" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Health Check" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Server status: 200\n", "{\n", "  \"status\": \"online\",\n", "  \"models_loaded\": {\n", "    \"hubert_ctc\": true,\n", "    \"hubert_base\": true,\n", "    \"wav2vec2_ctc\": true,\n", "    \"wav2vec2_base\": true,\n", "    \"trill\": false,\n", "    \"silero_vad\": true,\n", "    \"allosaurus\": false\n", "  }\n", "}\n" ] } ], "source": [ "# Time out instead of hanging, and stop here (not at the first /extract call) if the server returns an error status.\n", "resp = requests.get(f\"{BASE_URL}/status\", timeout=10)\n", "print(f\"Server status: {resp.status_code}\")\n", "resp.raise_for_status()\n", "print(json.dumps(resp.json(), indent=2))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Scan the Dataset\n", "\n", "Discover all words in Bank-12 and map test files to their expected word."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " _unknown -> 9 samples\n", " ארוחת צהריים -> 5 samples\n", " בוכה -> 3 samples\n", " ופל -> 3 samples\n", " חולצה -> 3 samples\n", " טיול -> 2 samples\n", " לגרד -> 3 samples\n", " מקלחת -> 4 samples\n", " סבתא -> 4 samples\n", " עיתון -> 4 samples\n", " פצע -> 4 samples\n", " קרקר -> 5 samples\n", " תשע -> 9 samples\n", "\n", "Total words in bank: 13\n", "Total samples: 58\n" ] } ], "source": [ "# Scan Bank-12: each subfolder = one word\n", "bank_words = {}\n", "for word_folder in sorted(Path(BANK_DIR).iterdir()):\n", " if word_folder.is_dir():\n", " wav_files = sorted(word_folder.glob(\"*.wav\"))\n", " bank_words[word_folder.name] = wav_files\n", " print(f\" {word_folder.name:20s} -> {len(wav_files)} samples\")\n", "\n", "print(f\"\\nTotal words in bank: {len(bank_words)}\")\n", "print(f\"Total samples: {sum(len(v) for v in bank_words.values())}\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " beautiful.wav -> expected: _unknown\n", " black.wav -> expected: _unknown\n", " ארוחת צהריים 6.wav -> expected: ארוחת צהריים\n", " בוכה 4.wav -> expected: בוכה\n", " ופל 3.wav -> expected: ופל\n", " חולצה 4.wav -> expected: חולצה\n", " טיול 3.wav -> expected: טיול\n", " לגרד 4.wav -> expected: לגרד\n", " מקלחת 5.wav -> expected: מקלחת\n", " סבתא 5.wav -> expected: סבתא\n", " עיתון 5.wav -> expected: עיתון\n", " עשר 1.wav -> expected: _unknown\n", " פצע 5.wav -> expected: פצע\n", " קרקר 6.wav -> expected: קרקר\n", " קרקר 7.wav -> expected: קרקר\n", " תשע 11.wav -> expected: תשע\n", " תשע 12.wav -> expected: תשע\n", " תשע 13.wav -> expected: תשע\n", "\n", "Total test files: 18\n" ] } ], "source": [ "# Scan test-12 and figure out expected word for each test file\n", "# The test filename starts with the Hebrew word name (e.g. 
\"סבתא 5.wav\" -> expected = \"סבתא\")\n", "test_files = sorted(Path(TEST_DIR).glob(\"*.wav\"))\n", "\n", "\n",
"def expected_word_for(stem, words):\n",
"    \"\"\"Return the bank word a test-file stem is named after, or \"_unknown\".\n",
"\n",
"    The stem must start with the FULL word name followed by the end of the stem\n",
"    or a non-letter (space, digit, ...), so \"cat 1\" matches \"cat\" but \"cats 1\"\n",
"    does not, and a multi-word name is not matched on its first word alone.\n",
"    When several words match, the longest (most specific) one wins.\n",
"    \"\"\"\n",
"    matches = [\n",
"        w for w in words\n",
"        if w != \"_unknown\"\n",
"        and stem.startswith(w)\n",
"        and (len(stem) == len(w) or not stem[len(w)].isalpha())\n",
"    ]\n",
"    return max(matches, key=len) if matches else \"_unknown\"\n",
"\n",
"\n",
"# Map each test file to its expected word (or \"_unknown\" if no match)\n",
"test_entries = []\n",
"for tf in test_files:\n",
"    expected = expected_word_for(tf.stem, bank_words)\n",
"    test_entries.append({\"path\": tf, \"expected\": expected})\n",
"    print(f\" {tf.name:30s} -> expected: {expected}\")\n",
"\n",
"print(f\"\\nTotal test files: {len(test_entries)}\")\n",
"\n",
"# Leakage check: a test file that is also a bank sample gets compared against itself\n",
"# (in the stored run, beautiful.wav and black.wav scored exactly 1.0000), inflating accuracy.\n",
"bank_file_names = {p.name for files in bank_words.values() for p in files}\n",
"leaked = [e[\"path\"].name for e in test_entries if e[\"path\"].name in bank_file_names]\n",
"if leaked:\n",
"    print(f\"\\nWARNING: {len(leaked)} test file(s) share a name with a bank sample: {leaked}\")\n",
"    print(\"If they are the same recordings, their results do not measure generalization.\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Extract Features from Bank (Dictionary)\n", "\n", "Send each audio sample in Bank-12 to `/extract` and collect embeddings." ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Helper function ready.\n" ] } ], "source": [
"def extract_features(audio_path, use_vad=True, timeout=120):\n",
"    \"\"\"Extract features from one audio file via the API.\n",
"\n",
"    Args:\n",
"        audio_path: path to the .wav file to upload.\n",
"        use_vad: sent to the server as the \"use_silero_vad\" parameter.\n",
"        timeout: seconds to wait for /extract. requests has no default timeout,\n",
"            so without one a stuck server would hang the whole loop.\n",
"\n",
"    Returns:\n",
"        The feature dict with the embedding keys renamed to \"embedding\" /\n",
"        \"embedding_sequence\", or None if the request did not succeed.\n",
"    \"\"\"\n",
"    params = {\n",
"        \"use_silero_vad\": use_vad,\n",
"        \"threshold\": 0.35,\n",
"        \"min_speech_ms\": 60,\n",
"        \"min_silence_ms\": 650,\n",
"        \"pad_ms\": 250,\n",
"        \"pad_ms_after\": 250,\n",
"        \"chunk_selection\": \"longest\",\n",
"        \"use_noise_reduction\": False,\n",
"        \"use_normalization\": False,\n",
"        \"selected_transcription_models\": [\"hubert\"],\n",
"        \"selected_embedding_models\": [EMBEDDING_MODEL],\n",
"        \"selected_acoustic_models\": [],\n",
"        \"use_beam_search\": False,\n",
"        \"use_n_top_ctc\": False,\n",
"        \"hubert_layer\": 12,\n",
"        \"wav2vec2_layer\": 12\n",
"    }\n",
"    with open(audio_path, \"rb\") as f:\n",
"        resp = requests.post(\n",
"            f\"{BASE_URL}/extract\",\n",
"            files={\"file\": (Path(audio_path).name, f, \"audio/wav\")},\n",
"            data={\"silero_params\": json.dumps(params)},\n",
"            timeout=timeout,\n",
"        )\n",
"    try:\n",
"        result = resp.json()\n",
"    except ValueError:\n",
"        # Non-JSON body (e.g. an HTML error page): report it like any other failed extraction\n",
"        print(f\" ERROR extracting {audio_path}: HTTP {resp.status_code}, non-JSON response\")\n",
"        return None\n",
"    if not result.get(\"success\"):\n",
"        print(f\" ERROR extracting {audio_path}: {result.get('error')}\")\n",
"        return None\n",
"\n",
"    # Remap keys to match what compute_similarity() expects:\n",
"    #   \"embedding\"          (mean vector)\n",
"    #   \"embedding_sequence\" (frame sequence)\n",
"    # The API returns e.g. \"hubert_embedding_mean\" / \"hubert_embedding_sequence\".\n",
"    # Only one embedding model is requested above; with several, later keys would\n",
"    # silently overwrite earlier ones in this loop.\n",
"    remapped = {}\n",
"    for key, val in result[\"features\"].items():\n",
"        if key.endswith(\"_embedding_mean\"):\n",
"            remapped[\"embedding\"] = val\n",
"        elif key.endswith(\"_embedding_sequence\"):\n",
"            remapped[\"embedding_sequence\"] = val\n",
"        else:\n",
"            remapped[key] = val\n",
"    return remapped\n",
"\n",
"print(\"Helper function ready.\")" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Processing word: _unknown (9 samples)\n", " + beautiful.wav\n", " + black.wav\n", " + blessing.wav\n", " + lion.wav\n", " + mama.wav\n", " + mouse.wav\n", " + orange.wav\n", " + papa.wav\n", " + pink.wav\n", "\n", "Processing word: ארוחת צהריים (5 samples)\n", " + ארוחת צהריים 1.wav\n", " + ארוחת צהריים 2.wav\n", " + ארוחת צהריים 3.wav\n", " + ארוחת צהריים 4.wav\n", " + ארוחת צהריים 5.wav\n", "\n", "Processing word: בוכה (3 samples)\n", " + בוכה 1.wav\n", " + בוכה 2.wav\n", " + בוכה 3.wav\n", "\n", "Processing word: ופל (3 samples)\n", " + ופל 1(1).wav\n", " + ופל 1.wav\n", " + ופל 2.wav\n", "\n", "Processing word: חולצה (3 samples)\n", " + חולצה 1.wav\n", " + חולצה 2.wav\n", " + חולצה 3.wav\n", "\n", "Processing word: טיול (2 samples)\n", " + טיול 1.wav\n", " + טיול 2.wav\n", "\n", "Processing word: לגרד (3 samples)\n", " + לגרד 1.wav\n", " + לגרד 2.wav\n", " + לגרד 3.wav\n", "\n", "Processing word: מקלחת (4 samples)\n", " + מקלחת 1.wav\n", " + מקלחת 2 .wav\n", " + מקלחת 3.wav\n", " + מקלחת 4.wav\n", "\n", "Processing word: סבתא (4 samples)\n", " + סבתא 1.wav\n", " + סבתא 2.wav\n", " + סבתא 3.wav\n", " + סבתא 4.wav\n", "\n", 
"Processing word: עיתון (4 samples)\n", " + עיתון 1.wav\n", " + עיתון 2.wav\n", " + עיתון 3.wav\n", " + עיתון 4.wav\n", "\n", "Processing word: פצע (4 samples)\n", " + פצע 1.wav\n", " + פצע 2.wav\n", " + פצע 3.wav\n", " + פצע 4.wav\n", "\n", "Processing word: קרקר (5 samples)\n", " + קרקר 1.wav\n", " + קרקר 2.wav\n", " + קרקר 3.wav\n", " + קרקר 4.wav\n", " + קרקר 5.wav\n", "\n", "Processing word: תשע (9 samples)\n", " + תשע 1.wav\n", " + תשע 10.wav\n", " + תשע 2.wav\n", " + תשע 3.wav\n", " + תשע 4.wav\n", " + תשע 5.wav\n", " + תשע 6.wav\n", " + תשע 7.wav\n", " + תשע 8.wav\n", "\n", "Done! Built dictionary with 13 words, 58 total recordings in 232.1s\n" ] } ], "source": [ "# Build the dictionary: extract features for every sample in Bank-12\n", "dictionary = {} # word_name -> list of {\"features\": {...}}\n", "\n", "start = time.time()\n", "for word_name, wav_files in bank_words.items():\n", " print(f\"\\nProcessing word: {word_name} ({len(wav_files)} samples)\")\n", " recordings = []\n", " for wav_path in wav_files:\n", " features = extract_features(wav_path)\n", " if features:\n", " recordings.append({\"features\": features})\n", " print(f\" + {wav_path.name}\")\n", " else:\n", " print(f\" x {wav_path.name} (failed)\")\n", " dictionary[word_name] = recordings\n", "\n", "elapsed = time.time() - start\n", "total_recordings = sum(len(v) for v in dictionary.values())\n", "print(f\"\\nDone! Built dictionary with {len(dictionary)} words, {total_recordings} total recordings in {elapsed:.1f}s\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Run Predictions on Test Files\n", "\n", "For each test file:\n", "1. Extract features\n", "2. Send to `/compute_similarities` against the full dictionary\n", "3. 
Check if top-1 prediction matches expected word" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dictionary entries for similarity: 13 words\n", " _unknown -> 9 recordings\n", " ארוחת צהריים -> 5 recordings\n", " בוכה -> 3 recordings\n", " ופל -> 3 recordings\n", " חולצה -> 3 recordings\n", " טיול -> 2 recordings\n", " לגרד -> 3 recordings\n", " מקלחת -> 4 recordings\n", " סבתא -> 4 recordings\n", " עיתון -> 4 recordings\n", " פצע -> 4 recordings\n", " קרקר -> 5 recordings\n", " תשע -> 9 recordings\n" ] } ], "source": [ "# Build the dictionary_entries payload (same for all test files)\n", "dictionary_entries = []\n", "for word_name, recordings in dictionary.items():\n", " if recordings: # skip empty\n", " dictionary_entries.append({\n", " \"id\": word_name,\n", " \"label\": word_name,\n", " \"recordings\": recordings\n", " })\n", "\n", "print(f\"Dictionary entries for similarity: {len(dictionary_entries)} words\")\n", "for entry in dictionary_entries:\n", " print(f\" {entry['label']:20s} -> {len(entry['recordings'])} recordings\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "[1/18] Testing: beautiful.wav (expected: _unknown)\n", " Top 3 predictions:\n", " 1. _unknown score: 1.0000 <<<\n", " 2. ארוחת צהריים score: 0.9381 \n", " 3. תשע score: 0.9308 \n", " => CORRECT (predicted: _unknown, expected: _unknown)\n", "\n", "[2/18] Testing: black.wav (expected: _unknown)\n", " Top 3 predictions:\n", " 1. _unknown score: 1.0000 <<<\n", " 2. מקלחת score: 0.9565 \n", " 3. תשע score: 0.9560 \n", " => CORRECT (predicted: _unknown, expected: _unknown)\n", "\n", "[3/18] Testing: ארוחת צהריים 6.wav (expected: ארוחת צהריים)\n", " Top 3 predictions:\n", " 1. ארוחת צהריים score: 0.9814 <<<\n", " 2. פצע score: 0.9784 \n", " 3. 
בוכה score: 0.9739 \n", " => CORRECT (predicted: ארוחת צהריים, expected: ארוחת צהריים)\n", "\n", "[4/18] Testing: בוכה 4.wav (expected: בוכה)\n", " Top 3 predictions:\n", " 1. בוכה score: 0.9794 <<<\n", " 2. ופל score: 0.9782 \n", " 3. ארוחת צהריים score: 0.9712 \n", " => CORRECT (predicted: בוכה, expected: בוכה)\n", "\n", "[5/18] Testing: ופל 3.wav (expected: ופל)\n", " Top 3 predictions:\n", " 1. ופל score: 0.9819 <<<\n", " 2. בוכה score: 0.9776 \n", " 3. לגרד score: 0.9765 \n", " => CORRECT (predicted: ופל, expected: ופל)\n", "\n", "[6/18] Testing: חולצה 4.wav (expected: חולצה)\n", " Top 3 predictions:\n", " 1. חולצה score: 0.9836 <<<\n", " 2. עיתון score: 0.9675 \n", " 3. סבתא score: 0.9670 \n", " => CORRECT (predicted: חולצה, expected: חולצה)\n", "\n", "[7/18] Testing: טיול 3.wav (expected: טיול)\n", " Top 3 predictions:\n", " 1. טיול score: 0.9795 <<<\n", " 2. קרקר score: 0.9742 \n", " 3. פצע score: 0.9729 \n", " => CORRECT (predicted: טיול, expected: טיול)\n", "\n", "[8/18] Testing: לגרד 4.wav (expected: לגרד)\n", " Top 3 predictions:\n", " 1. ארוחת צהריים score: 0.9912 \n", " 2. לגרד score: 0.9838 <<<\n", " 3. סבתא score: 0.9720 \n", " => WRONG (predicted: ארוחת צהריים, expected: לגרד)\n", "\n", "[9/18] Testing: מקלחת 5.wav (expected: מקלחת)\n", " Top 3 predictions:\n", " 1. מקלחת score: 0.9847 <<<\n", " 2. ארוחת צהריים score: 0.9743 \n", " 3. עיתון score: 0.9728 \n", " => CORRECT (predicted: מקלחת, expected: מקלחת)\n", "\n", "[10/18] Testing: סבתא 5.wav (expected: סבתא)\n", " Top 3 predictions:\n", " 1. עיתון score: 0.9815 \n", " 2. פצע score: 0.9750 \n", " 3. ארוחת צהריים score: 0.9748 \n", " => WRONG (predicted: עיתון, expected: סבתא)\n", "\n", "[11/18] Testing: עיתון 5.wav (expected: עיתון)\n", " Top 3 predictions:\n", " 1. עיתון score: 0.9811 <<<\n", " 2. ארוחת צהריים score: 0.9789 \n", " 3. 
קרקר score: 0.9772 \n", " => CORRECT (predicted: עיתון, expected: עיתון)\n", "\n", "[12/18] Testing: עשר 1.wav (expected: _unknown)\n", " Top 3 predictions:\n", " 1. פצע score: 0.9712 \n", " 2. תשע score: 0.9710 \n", " 3. עיתון score: 0.9692 \n", " => WRONG (predicted: פצע, expected: _unknown)\n", "\n", "[13/18] Testing: פצע 5.wav (expected: פצע)\n", " Top 3 predictions:\n", " 1. ארוחת צהריים score: 0.9832 \n", " 2. פצע score: 0.9806 <<<\n", " 3. קרקר score: 0.9804 \n", " => WRONG (predicted: ארוחת צהריים, expected: פצע)\n", "\n", "[14/18] Testing: קרקר 6.wav (expected: קרקר)\n", " Top 3 predictions:\n", " 1. קרקר score: 0.9951 <<<\n", " 2. סבתא score: 0.9852 \n", " 3. ארוחת צהריים score: 0.9817 \n", " => CORRECT (predicted: קרקר, expected: קרקר)\n", "\n", "[15/18] Testing: קרקר 7.wav (expected: קרקר)\n", " Top 3 predictions:\n", " 1. קרקר score: 0.9850 <<<\n", " 2. תשע score: 0.9760 \n", " 3. סבתא score: 0.9755 \n", " => CORRECT (predicted: קרקר, expected: קרקר)\n", "\n", "[16/18] Testing: תשע 11.wav (expected: תשע)\n", " Top 3 predictions:\n", " 1. תשע score: 0.9831 <<<\n", " 2. עיתון score: 0.9723 \n", " 3. קרקר score: 0.9705 \n", " => CORRECT (predicted: תשע, expected: תשע)\n", "\n", "[17/18] Testing: תשע 12.wav (expected: תשע)\n", " Top 3 predictions:\n", " 1. תשע score: 0.9779 <<<\n", " 2. עיתון score: 0.9742 \n", " 3. 
ופל score: 0.9710 \n", " => CORRECT (predicted: תשע, expected: תשע)\n", "\n", "[18/18] Testing: תשע 13.wav (expected: תשע)\n" ] } ], "source": [ "# Run predictions\n", "results = []\n", "start = time.time()\n", "\n", "for i, entry in enumerate(test_entries):\n", " test_path = entry[\"path\"]\n", " expected = entry[\"expected\"]\n", " \n", " print(f\"\\n[{i+1}/{len(test_entries)}] Testing: {test_path.name} (expected: {expected})\")\n", " \n", " # Step 1: extract features from test file\n", " test_features = extract_features(test_path)\n", " if not test_features:\n", " results.append({\"file\": test_path.name, \"expected\": expected, \"predicted\": \"ERROR\", \"score\": 0, \"correct\": False})\n", " continue\n", " \n", " # Step 2: compute similarity against dictionary\n", " sim_request = {\n", " \"test_features\": test_features,\n", " \"dictionary_entries\": dictionary_entries,\n", " \"dtw_params\": {\n", " \"distance_metric\": \"cosine\",\n", " \"step_pattern\": \"symmetric2\",\n", " \"window_type\": \"sakoe_chiba\",\n", " \"window_size\": 10,\n", " \"normalization\": \"path_length\"\n", " }\n", " }\n", " \n", " resp = requests.post(f\"{BASE_URL}/compute_similarities\", json=sim_request)\n", " sim_result = resp.json()\n", " \n", " if sim_result.get(\"success\") and sim_result[\"results\"]:\n", " top = sim_result[\"results\"][0]\n", " predicted = top[\"label\"]\n", " score = top[\"score\"]\n", " correct = (predicted == expected)\n", " \n", " # Show top 3\n", " print(f\" Top 3 predictions:\")\n", " for rank, r in enumerate(sim_result[\"results\"][:3], 1):\n", " marker = \"<<<\" if r[\"label\"] == expected else \"\"\n", " print(f\" {rank}. 
{r['label']:20s} score: {r['score']:.4f} {marker}\")\n", "\n",
"        status = \"CORRECT\" if correct else \"WRONG\"\n",
"        print(f\" => {status} (predicted: {predicted}, expected: {expected})\")\n",
"\n",
"        results.append({\n",
"            \"file\": test_path.name,\n",
"            \"expected\": expected,\n",
"            \"predicted\": predicted,\n",
"            \"score\": score,\n",
"            \"correct\": correct,\n",
"            \"all_scores\": sim_result[\"results\"]\n",
"        })\n",
"    else:\n",
"        print(f\" ERROR: {sim_result.get('error')}\")\n",
"        results.append({\"file\": test_path.name, \"expected\": expected, \"predicted\": \"ERROR\", \"score\": 0, \"correct\": False})\n",
"\n",
"elapsed = time.time() - start\n",
"print(f\"\\nDone! Tested {len(results)} files in {elapsed:.1f}s\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Results Summary" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Overall accuracy\n",
"correct_count = sum(1 for r in results if r[\"correct\"])\n",
"total = len(results)\n",
"known_results = [r for r in results if r[\"expected\"] != \"_unknown\"]\n",
"known_correct = sum(1 for r in known_results if r[\"correct\"])\n",
"\n",
"print(\"=\" * 70)\n",
"# Guard the division: results is empty when the test folder has no .wav files.\n",
"overall_pct = f\"{100*correct_count/total:.1f}%\" if total else \"n/a\"\n",
"print(f\"OVERALL ACCURACY: {correct_count}/{total} ({overall_pct})\")\n",
"if known_results:  # skip the line entirely (instead of printing a blank one) when there is nothing to report\n",
"    print(f\"KNOWN WORDS ONLY: {known_correct}/{len(known_results)} ({100*known_correct/len(known_results):.1f}%)\")\n",
"print(\"=\" * 70)\n",
"\n",
"# Per-file breakdown\n",
"print(f\"\\n{'File':<30s} {'Expected':<15s} {'Predicted':<15s} {'Score':>8s} {'Result'}\")\n",
"print(\"-\" * 80)\n",
"for r in results:\n",
"    status = \"OK\" if r[\"correct\"] else \"MISS\"\n",
"    print(f\"{r['file']:<30s} {r['expected']:<15s} {r['predicted']:<15s} {r['score']:>8.4f} {status}\")\n",
"\n",
"# Per-word accuracy\n",
"print(f\"\\n{'Word':<20s} {'Correct':<10s} {'Total':<10s} {'Accuracy'}\")\n",
"print(\"-\" * 55)\n",
"word_stats = defaultdict(lambda: {\"correct\": 0, \"total\": 0})\n", 
"for r in results:\n", " word_stats[r[\"expected\"]][\"total\"] += 1\n", " if r[\"correct\"]:\n", " word_stats[r[\"expected\"]][\"correct\"] += 1\n", "\n", "for word, stats in sorted(word_stats.items()):\n", " acc = 100 * stats[\"correct\"] / stats[\"total\"] if stats[\"total\"] > 0 else 0\n", " print(f\"{word:<20s} {stats['correct']:<10d} {stats['total']:<10d} {acc:.0f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Confusion Analysis\n", "\n", "For wrong predictions, see what the model confused the word with." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "misses = [r for r in results if not r[\"correct\"]]\n", "\n", "if not misses:\n", " print(\"No misses! Perfect accuracy.\")\n", "else:\n", " print(f\"{len(misses)} misclassified files:\\n\")\n", " for r in misses:\n", " print(f\"File: {r['file']}\")\n", " print(f\" Expected: {r['expected']}\")\n", " print(f\" Predicted: {r['predicted']} (score: {r['score']:.4f})\")\n", " if \"all_scores\" in r:\n", " print(f\" Full ranking:\")\n", " for rank, s in enumerate(r[\"all_scores\"][:5], 1):\n", " marker = \"<<<\" if s[\"label\"] == r[\"expected\"] else \"\"\n", " print(f\" {rank}. {s['label']:<20s} {s['score']:.4f} {marker}\")\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Overfitting Check\n", "\n", "**How do we know it's not overfitting?**\n", "\n", "This system does NOT train on your data — it uses pre-trained models (HuBERT) that were trained on general English speech by Meta. The \"dictionary\" is just stored embeddings, not learned weights. So traditional overfitting (memorizing training data) doesn't apply here.\n", "\n", "However, there are still things to watch for:\n", "1. **Score gap** — If the margin between #1 and #2 is tiny, the model is guessing, not confident\n", "2. **Leave-one-out** — Use one bank sample as \"test\" and the rest as dictionary. 
If bank-on-bank accuracy is much higher than real test accuracy, the bank samples might be too similar to each other\n", "3. **Cross-dictionary** — Test against a completely different dictionary (Section 8 below)\n", "\n", "The cells below run the score-gap analysis (7a) and a leave-one-out cross-validation on the bank itself (7b)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# 7a. Score gap analysis — are the correct predictions confident?\n",
"# Margin between the #1 and #2 scores: above STRONG_GAP is \"YES\", above WEAK_GAP is \"WEAK\", else \"NO\".\n",
"STRONG_GAP = 0.005\n",
"WEAK_GAP = 0.002\n",
"\n",
"print(\"=== SCORE GAP ANALYSIS ===\")\n",
"print(f\"{'File':<25s} {'#1 Score':>10s} {'#2 Score':>10s} {'Gap':>8s} {'Confident?'}\")\n",
"print(\"-\" * 70)\n",
"\n",
"gaps = []  # margins of files with at least two ranked words, for the summary below\n",
"for r in results:\n",
"    if \"all_scores\" not in r:\n",
"        continue  # extraction or similarity failed for this file\n",
"    scores = r[\"all_scores\"]\n",
"    s1 = scores[0][\"score\"]\n",
"    s2 = scores[1][\"score\"] if len(scores) > 1 else 0\n",
"    gap = s1 - s2\n",
"    if len(scores) > 1:\n",
"        gaps.append(gap)\n",
"    confident = \"YES\" if gap > STRONG_GAP else \"WEAK\" if gap > WEAK_GAP else \"NO\"\n",
"    marker = \"\" if r[\"correct\"] else \" <-- MISS\"\n",
"    print(f\"{r['file']:<25s} {s1:>10.4f} {s2:>10.4f} {gap:>8.4f} {confident}{marker}\")\n",
"\n",
"# Summary (guarded: gaps is empty if every file failed or only one word was ranked)\n",
"if gaps:\n",
"    print(f\"\\nAverage gap: {sum(gaps)/len(gaps):.4f}\")\n",
"    print(f\"Min gap: {min(gaps):.4f}\")\n",
"    print(f\"Max gap: {max(gaps):.4f}\")\n",
"else:\n",
"    print(\"\\nNo files with at least two ranked words — nothing to summarize.\")\n",
"print(f\"\\nIf gaps are very small (<{WEAK_GAP}), the model is uncertain and scores are close — not a sign of overfitting, but of low discriminability.\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 7b. 
Leave-One-Out cross-validation on the bank\n", "# For each sample in the bank, hold it out and use the rest as dictionary.\n", "# This tells us how well the bank samples match *each other*.\n", "# If LOO accuracy >> test accuracy, the bank samples may be too homogeneous.\n", "\n", "print(\"=== LEAVE-ONE-OUT CROSS-VALIDATION (on Bank) ===\\n\")\n", "print(\"Hold out one bank sample, use the rest as dictionary, predict the held-out sample.\\n\")\n", "\n", "loo_results = []\n", "loo_start = time.time()\n", "\n", "for word_name, recordings in dictionary.items():\n", " if word_name == \"_unknown\":\n", " continue # skip unknown for LOO\n", " \n", " for i, held_out in enumerate(recordings):\n", " # Build dictionary WITHOUT the held-out sample\n", " loo_entries = []\n", " for w, recs in dictionary.items():\n", " if w == word_name:\n", " # Same word, exclude held-out sample\n", " remaining = [r for j, r in enumerate(recs) if j != i]\n", " if remaining:\n", " loo_entries.append({\"id\": w, \"label\": w, \"recordings\": remaining})\n", " else:\n", " if recs:\n", " loo_entries.append({\"id\": w, \"label\": w, \"recordings\": recs})\n", " \n", " # Compute similarity\n", " sim_request = {\n", " \"test_features\": held_out[\"features\"],\n", " \"dictionary_entries\": loo_entries,\n", " \"dtw_params\": {\n", " \"distance_metric\": \"cosine\",\n", " \"step_pattern\": \"symmetric2\",\n", " \"window_type\": \"sakoe_chiba\",\n", " \"window_size\": 10,\n", " \"normalization\": \"path_length\"\n", " }\n", " }\n", " resp = requests.post(f\"{BASE_URL}/compute_similarities\", json=sim_request)\n", " sim = resp.json()\n", " \n", " if sim.get(\"success\") and sim[\"results\"]:\n", " predicted = sim[\"results\"][0][\"label\"]\n", " correct = (predicted == word_name)\n", " loo_results.append({\"word\": word_name, \"predicted\": predicted, \"correct\": correct})\n", " else:\n", " loo_results.append({\"word\": word_name, \"predicted\": \"ERROR\", \"correct\": False})\n", "\n", "loo_elapsed = 
time.time() - loo_start\n", "\n",
"loo_correct = sum(1 for r in loo_results if r[\"correct\"])\n",
"loo_total = len(loo_results)\n",
"known_total = len(known_results)\n",
"# Either count can be 0 (e.g. the bank holds only _unknown, or every test file is\n",
"# _unknown); report n/a instead of raising ZeroDivisionError.\n",
"loo_acc = 100 * loo_correct / loo_total if loo_total else None\n",
"test_acc = 100 * known_correct / known_total if known_total else None\n",
"\n",
"def fmt_acc(correct, total, acc):\n",
"    \"\"\"Format 'correct/total (acc%)', or 'correct/total (n/a)' when acc is None.\"\"\"\n",
"    return f\"{correct}/{total} ({acc:.1f}%)\" if acc is not None else f\"{correct}/{total} (n/a)\"\n",
"\n",
"print(f\"LOO Accuracy: {fmt_acc(loo_correct, loo_total, loo_acc)}\")\n",
"print(f\"Test Accuracy (known words): {fmt_acc(known_correct, known_total, test_acc)}\")\n",
"print(f\"Time: {loo_elapsed:.1f}s\")\n",
"\n",
"if loo_acc is None or test_acc is None:\n",
"    print(\"\\nCannot compare LOO and test accuracy: one of them has no samples.\")\n",
"else:\n",
"    # Difference in percentage points, not a relative percentage.\n",
"    gap = loo_acc - test_acc\n",
"    if gap > 15:\n",
"        print(f\"\\nWARNING: LOO is {gap:.0f} percentage points higher than test — bank samples may be too similar to each other.\")\n",
"        print(\"Consider adding more diverse recordings to the bank.\")\n",
"    elif gap > 5:\n",
"        print(f\"\\nNote: LOO is {gap:.0f} percentage points higher than test — slight gap, fairly normal.\")\n",
"    elif gap < -5:\n",
"        # Previously fell through to 'close', which is wrong for a large negative gap.\n",
"        print(f\"\\nNote: test is {-gap:.0f} percentage points higher than LOO — bank samples match each other worse than they match the test files (small sets make this noisy).\")\n",
"    else:\n",
"        print(\"\\nGood: LOO and test accuracy are close — no sign of overfitting.\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "## 8. Test a New Bank + Test Directory\n", "\n", "Point to a different bank and test directory to run the same evaluation.  \n", "The bank directory should have subfolders (one per word), each containing `.wav` samples.  \n", "The test directory should have `.wav` files whose names start with the bank word they contain (e.g. `סבתא 5.wav`); files that match no bank word are expected to be `_unknown`."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# === CHANGE THESE PATHS to your new bank and test directories ===\n", "NEW_BANK_DIR = \"Bank_New\" # <-- change this\n", "NEW_TEST_DIR = \"Test_New\" # <-- change this\n", "\n", "# Scan new bank\n", "new_bank_words = {}\n", "for word_folder in sorted(Path(NEW_BANK_DIR).iterdir()):\n", " if word_folder.is_dir():\n", " wav_files = sorted(word_folder.glob(\"*.wav\"))\n", " new_bank_words[word_folder.name] = wav_files\n", " print(f\" {word_folder.name:20s} -> {len(wav_files)} samples\")\n", "\n", "print(f\"\\nTotal words in new bank: {len(new_bank_words)}\")\n", "print(f\"Total samples: {sum(len(v) for v in new_bank_words.values())}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract features for the new bank\n", "new_dictionary = {}\n", "\n", "start = time.time()\n", "for word_name, wav_files in new_bank_words.items():\n", " print(f\"\\nProcessing word: {word_name} ({len(wav_files)} samples)\")\n", " recordings = []\n", " for wav_path in wav_files:\n", " features = extract_features(wav_path)\n", " if features:\n", " recordings.append({\"features\": features})\n", " print(f\" + {wav_path.name}\")\n", " else:\n", " print(f\" x {wav_path.name} (failed)\")\n", " new_dictionary[word_name] = recordings\n", "\n", "elapsed = time.time() - start\n", "total_recordings = sum(len(v) for v in new_dictionary.values())\n", "print(f\"\\nDone! 
Built new dictionary with {len(new_dictionary)} words, {total_recordings} total recordings in {elapsed:.1f}s\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Scan new test files and map to expected words\n", "new_test_files = sorted(Path(NEW_TEST_DIR).glob(\"*.wav\"))\n", "\n", "new_test_entries = []\n", "for tf in new_test_files:\n", " expected = \"_unknown\"\n", " for word_name in new_bank_words:\n", " if word_name != \"_unknown\" and tf.stem.startswith(word_name.split()[0]):\n", " expected = word_name\n", " break\n", " new_test_entries.append({\"path\": tf, \"expected\": expected})\n", " print(f\" {tf.name:30s} -> expected: {expected}\")\n", "\n", "print(f\"\\nTotal new test files: {len(new_test_entries)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run predictions on new test set\n", "new_dict_entries = []\n", "for word_name, recordings in new_dictionary.items():\n", " if recordings:\n", " new_dict_entries.append({\"id\": word_name, \"label\": word_name, \"recordings\": recordings})\n", "\n", "new_results = []\n", "start = time.time()\n", "\n", "for i, entry in enumerate(new_test_entries):\n", " test_path = entry[\"path\"]\n", " expected = entry[\"expected\"]\n", " \n", " print(f\"\\n[{i+1}/{len(new_test_entries)}] Testing: {test_path.name} (expected: {expected})\")\n", " \n", " test_features = extract_features(test_path)\n", " if not test_features:\n", " new_results.append({\"file\": test_path.name, \"expected\": expected, \"predicted\": \"ERROR\", \"score\": 0, \"correct\": False})\n", " continue\n", " \n", " sim_request = {\n", " \"test_features\": test_features,\n", " \"dictionary_entries\": new_dict_entries,\n", " \"dtw_params\": {\n", " \"distance_metric\": \"cosine\",\n", " \"step_pattern\": \"symmetric2\",\n", " \"window_type\": \"sakoe_chiba\",\n", " \"window_size\": 10,\n", " \"normalization\": \"path_length\"\n", " }\n", " }\n", " \n", " 
resp = requests.post(f\"{BASE_URL}/compute_similarities\", json=sim_request)\n", " sim_result = resp.json()\n", " \n", " if sim_result.get(\"success\") and sim_result[\"results\"]:\n", " top = sim_result[\"results\"][0]\n", " predicted = top[\"label\"]\n", " score = top[\"score\"]\n", " correct = (predicted == expected)\n", " \n", " print(f\" Top 3:\")\n", " for rank, r in enumerate(sim_result[\"results\"][:3], 1):\n", " marker = \"<<<\" if r[\"label\"] == expected else \"\"\n", " print(f\" {rank}. {r['label']:20s} score: {r['score']:.4f} {marker}\")\n", " \n", " status = \"CORRECT\" if correct else \"WRONG\"\n", " print(f\" => {status}\")\n", " \n", " new_results.append({\n", " \"file\": test_path.name, \"expected\": expected, \"predicted\": predicted,\n", " \"score\": score, \"correct\": correct, \"all_scores\": sim_result[\"results\"]\n", " })\n", " else:\n", " new_results.append({\"file\": test_path.name, \"expected\": expected, \"predicted\": \"ERROR\", \"score\": 0, \"correct\": False})\n", "\n", "elapsed = time.time() - start\n", "print(f\"\\nDone! 
Tested {len(new_results)} files in {elapsed:.1f}s\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# New dictionary results summary\n",
"new_correct = sum(1 for r in new_results if r[\"correct\"])\n",
"new_total = len(new_results)\n",
"new_known = [r for r in new_results if r[\"expected\"] != \"_unknown\"]\n",
"new_known_correct = sum(1 for r in new_known if r[\"correct\"])\n",
"\n",
"print(\"=\" * 70)\n",
"# Guard the divisions: the new test folder may be empty or contain only _unknown files.\n",
"new_overall_pct = f\"{100*new_correct/new_total:.1f}%\" if new_total else \"n/a\"\n",
"print(f\"NEW BANK — OVERALL ACCURACY: {new_correct}/{new_total} ({new_overall_pct})\")\n",
"if new_known:\n",
"    print(f\"NEW BANK — KNOWN WORDS ONLY: {new_known_correct}/{len(new_known)} ({100*new_known_correct/len(new_known):.1f}%)\")\n",
"print(\"=\" * 70)\n",
"\n",
"# Compare with Bank-12 results (known_results / known_correct come from the Section 5 summary)\n",
"print(f\"\\n--- Comparison ---\")\n",
"if known_results:\n",
"    print(f\"Bank-12 accuracy (known): {100*known_correct/len(known_results):.1f}%\")\n",
"else:\n",
"    print(\"Bank-12 accuracy (known): n/a\")\n",
"if new_known:\n",
"    print(f\"New bank accuracy (known): {100*new_known_correct/len(new_known):.1f}%\")\n",
"\n",
"# Per-file breakdown\n",
"print(f\"\\n{'File':<30s} {'Expected':<15s} {'Predicted':<15s} {'Score':>8s} {'Result'}\")\n",
"print(\"-\" * 80)\n",
"for r in new_results:\n",
"    status = \"OK\" if r[\"correct\"] else \"MISS\"\n",
"    print(f\"{r['file']:<30s} {r['expected']:<15s} {r['predicted']:<15s} {r['score']:>8.4f} {status}\")\n",
"\n",
"# Per-word accuracy\n",
"print(f\"\\n{'Word':<20s} {'Correct':<10s} {'Total':<10s} {'Accuracy'}\")\n",
"print(\"-\" * 55)\n",
"new_word_stats = defaultdict(lambda: {\"correct\": 0, \"total\": 0})\n",
"for r in new_results:\n",
"    new_word_stats[r[\"expected\"]][\"total\"] += 1\n",
"    if r[\"correct\"]:\n",
"        new_word_stats[r[\"expected\"]][\"correct\"] += 1\n",
"\n",
"for word, stats in sorted(new_word_stats.items()):\n",
"    acc = 100 * stats[\"correct\"] / stats[\"total\"] if stats[\"total\"] > 0 else 0\n",
"    print(f\"{word:<20s} {stats['correct']:<10d} {stats['total']:<10d} {acc:.0f}%\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Second Try with Bank_New and `_unknown`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# === CHANGE THESE PATHS to your new bank and test directories ===\n",
"NEW_BANK_DIR = \"Bank_New\"  # <-- change this\n",
"NEW_TEST_DIR = \"Test_New\"  # <-- change this\n",
"\n",
"# Scan new bank\n",
"new_bank_words = {}\n",
"for word_folder in sorted(Path(NEW_BANK_DIR).iterdir()):\n",
"    if word_folder.is_dir():\n",
"        wav_files = sorted(word_folder.glob(\"*.wav\"))\n",
"        new_bank_words[word_folder.name] = wav_files\n",
"        print(f\" {word_folder.name:20s} -> {len(wav_files)} samples\")\n",
"\n",
"print(f\"\\nTotal words in new bank: {len(new_bank_words)}\")\n",
"print(f\"Total samples: {sum(len(v) for v in new_bank_words.values())}\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Extract features for the new bank\n",
"new_dictionary = {}\n",
"\n",
"start = time.time()\n",
"for word_name, wav_files in new_bank_words.items():\n",
"    print(f\"\\nProcessing word: {word_name} ({len(wav_files)} samples)\")\n",
"    recordings = []\n",
"    for wav_path in wav_files:\n",
"        features = extract_features(wav_path)\n",
"        if features:\n",
"            recordings.append({\"features\": features})\n",
"            print(f\" + {wav_path.name}\")\n",
"        else:\n",
"            print(f\" x {wav_path.name} (failed)\")\n",
"    new_dictionary[word_name] = recordings\n",
"\n",
"elapsed = time.time() - start\n",
"total_recordings = sum(len(v) for v in new_dictionary.values())\n",
"print(f\"\\nDone! 
Built new dictionary with {len(new_dictionary)} words, {total_recordings} total recordings in {elapsed:.1f}s\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Scan new test files and map to expected words\n", "new_test_files = sorted(Path(NEW_TEST_DIR).glob(\"*.wav\"))\n", "\n", "new_test_entries = []\n", "for tf in new_test_files:\n", " expected = \"_unknown\"\n", " for word_name in new_bank_words:\n", " if word_name != \"_unknown\" and tf.stem.startswith(word_name.split()[0]):\n", " expected = word_name\n", " break\n", " new_test_entries.append({\"path\": tf, \"expected\": expected})\n", " print(f\" {tf.name:30s} -> expected: {expected}\")\n", "\n", "print(f\"\\nTotal new test files: {len(new_test_entries)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run predictions on new test set\n", "new_dict_entries = []\n", "for word_name, recordings in new_dictionary.items():\n", " if recordings:\n", " new_dict_entries.append({\"id\": word_name, \"label\": word_name, \"recordings\": recordings})\n", "\n", "new_results = []\n", "start = time.time()\n", "\n", "for i, entry in enumerate(new_test_entries):\n", " test_path = entry[\"path\"]\n", " expected = entry[\"expected\"]\n", " \n", " print(f\"\\n[{i+1}/{len(new_test_entries)}] Testing: {test_path.name} (expected: {expected})\")\n", " \n", " test_features = extract_features(test_path)\n", " if not test_features:\n", " new_results.append({\"file\": test_path.name, \"expected\": expected, \"predicted\": \"ERROR\", \"score\": 0, \"correct\": False})\n", " continue\n", " \n", " sim_request = {\n", " \"test_features\": test_features,\n", " \"dictionary_entries\": new_dict_entries,\n", " \"dtw_params\": {\n", " \"distance_metric\": \"cosine\",\n", " \"step_pattern\": \"symmetric2\",\n", " \"window_type\": \"sakoe_chiba\",\n", " \"window_size\": 10,\n", " \"normalization\": \"path_length\"\n", " }\n", " }\n", " \n", " 
resp = requests.post(f\"{BASE_URL}/compute_similarities\", json=sim_request)\n", " sim_result = resp.json()\n", " \n", " if sim_result.get(\"success\") and sim_result[\"results\"]:\n", " top = sim_result[\"results\"][0]\n", " predicted = top[\"label\"]\n", " score = top[\"score\"]\n", " correct = (predicted == expected)\n", " \n", " print(f\" Top 3:\")\n", " for rank, r in enumerate(sim_result[\"results\"][:3], 1):\n", " marker = \"<<<\" if r[\"label\"] == expected else \"\"\n", " print(f\" {rank}. {r['label']:20s} score: {r['score']:.4f} {marker}\")\n", " \n", " status = \"CORRECT\" if correct else \"WRONG\"\n", " print(f\" => {status}\")\n", " \n", " new_results.append({\n", " \"file\": test_path.name, \"expected\": expected, \"predicted\": predicted,\n", " \"score\": score, \"correct\": correct, \"all_scores\": sim_result[\"results\"]\n", " })\n", " else:\n", " new_results.append({\"file\": test_path.name, \"expected\": expected, \"predicted\": \"ERROR\", \"score\": 0, \"correct\": False})\n", "\n", "elapsed = time.time() - start\n", "print(f\"\\nDone! Tested {len(new_results)} files in {elapsed:.1f}s\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. Third Try — Real Speech _unknown + Score-Gap Rejection\n", "\n", "**Fixes applied:**\n", "1. Replaced synthetic sounds (white noise, claps, etc.) with **real spoken words** in `Bank_New/_unknown/`\n", "2. 
Added **threshold-based rejection** — if the score gap between the #1 and #2 matches is below `UNKNOWN_MIN_GAP`, predict `_unknown`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# === Re-scan Bank_New (now with real speech _unknown samples) ===\n", "NEW_BANK_DIR = \"Bank_New\"\n", "NEW_TEST_DIR = \"Test_New\"\n", "\n", "new_bank_words = {}\n", "for word_folder in sorted(Path(NEW_BANK_DIR).iterdir()):\n", "    if word_folder.is_dir():\n", "        wav_files = sorted(word_folder.glob(\"*.wav\"))\n", "        new_bank_words[word_folder.name] = wav_files\n", "        print(f\" {word_folder.name:20s} -> {len(wav_files)} samples\")\n", "\n", "print(f\"\\nTotal words: {len(new_bank_words)}\")\n", "print(f\"Total samples: {sum(len(v) for v in new_bank_words.values())}\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract features for the updated bank\n", "new_dictionary = {}\n", "\n", "start = time.time()\n", "for word_name, wav_files in new_bank_words.items():\n", "    print(f\"\\nProcessing word: {word_name} ({len(wav_files)} samples)\")\n", "    recordings = []\n", "    for wav_path in wav_files:\n", "        features = extract_features(wav_path)\n", "        if features:\n", "            recordings.append({\"features\": features})\n", "            print(f\" + {wav_path.name}\")\n", "        else:\n", "            print(f\" x {wav_path.name} (failed)\")\n", "    new_dictionary[word_name] = recordings\n", "\n", "elapsed = time.time() - start\n", "total_recordings = sum(len(v) for v in new_dictionary.values())\n", "print(f\"\\nDone! Built dictionary with {len(new_dictionary)} words, {total_recordings} recordings in {elapsed:.1f}s\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Scan test files\n", "new_test_files = sorted(Path(NEW_TEST_DIR).glob(\"*.wav\"))\n", "\n", "new_test_entries = []\n", "for tf in new_test_files:\n", "    # A test file belongs to a word when its stem starts with the word's first token.\n", "    # Pick the LONGEST matching prefix: with first-match-wins over the sorted names,\n", "    # 'balloon_1.wav' would be labelled 'ball' whenever both words are in the bank.\n", "    # Ties keep the first (sorted) name, as before.\n", "    candidates = [\n", "        w for w in new_bank_words\n", "        if w != \"_unknown\" and w.split() and tf.stem.startswith(w.split()[0])\n", "    ]\n", "    expected = max(candidates, key=lambda w: len(w.split()[0]), default=\"_unknown\")\n", "    new_test_entries.append({\"path\": tf, \"expected\": expected})\n", "    print(f\" {tf.name:30s} -> expected: {expected}\")\n", "\n", "print(f\"\\nTotal test files: {len(new_test_entries)}\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run predictions WITH score-gap rejection\n", "# Tune this threshold based on the score-gap analysis from section 7a.\n", "# UNKNOWN_MIN_GAP is an absolute difference in `score` units between ranks #1 and #2\n", "# (not a relative percentage). It is sent to the API as 'unknown_min_gap'; the API\n", "# reports its decision back in the 'rejected_to_unknown' response field.\n", "UNKNOWN_MIN_GAP = 0.005\n", "\n", "new_dict_entries = []\n", "for word_name, recordings in new_dictionary.items():\n", "    if recordings:\n", "        new_dict_entries.append({\"id\": word_name, \"label\": word_name, \"recordings\": recordings})\n", "\n", "new_results = []\n", "start = time.time()\n", "\n", "for i, entry in enumerate(new_test_entries):\n", "    test_path = entry[\"path\"]\n", "    expected = entry[\"expected\"]\n", "\n", "    print(f\"\\n[{i+1}/{len(new_test_entries)}] Testing: {test_path.name} (expected: {expected})\")\n", "\n", "    test_features = extract_features(test_path)\n", "    if not test_features:\n", "        new_results.append({\"file\": test_path.name, \"expected\": expected, \"predicted\": \"ERROR\", \"score\": 0, \"correct\": False})\n", "        continue\n", "\n", "    sim_request = {\n", "        \"test_features\": test_features,\n", "        \"dictionary_entries\": new_dict_entries,\n", "        \"dtw_params\": {\n", "            \"distance_metric\": \"cosine\",\n", "            \"step_pattern\": \"symmetric2\",\n", "            \"window_type\": \"sakoe_chiba\",\n", "            \"window_size\": 10,\n", "            \"normalization\": \"path_length\"\n", "        },\n", "        \"unknown_min_gap\": UNKNOWN_MIN_GAP\n", "    }\n", "\n", "    resp = requests.post(f\"{BASE_URL}/compute_similarities\", json=sim_request)\n", "    try:\n", "        sim_result = resp.json()\n", "    except ValueError:\n", "        # Non-JSON error body (e.g. an HTTP 500 page): fall through to the ERROR branch\n", "        # below instead of aborting the whole run.\n", "        sim_result = {}\n", "\n", "    if sim_result.get(\"success\") and sim_result[\"results\"]:\n", "        top = sim_result[\"results\"][0]\n", "        score = top[\"score\"]\n", "        rejected = sim_result.get(\"rejected_to_unknown\", False)\n", "\n", "        # If the API flagged it as rejected AND top prediction isn't already _unknown\n", "        if rejected and top[\"label\"] != \"_unknown\":\n", "            predicted = \"_unknown\"\n", "            reason = \"(rejected: small gap)\"\n", "        else:\n", "            predicted = top[\"label\"]\n", "            reason = \"\"\n", "\n", "        correct = (predicted == expected)\n", "\n", "        top_scores = sim_result[\"results\"][:3]\n", "        # 999 is only a display sentinel for 'no runner-up' (a single result was returned).\n", "        gap = top_scores[0][\"score\"] - top_scores[1][\"score\"] if len(top_scores) > 1 else 999\n", "        print(f\" Top 3 (gap={gap:.4f}):\")\n", "        for rank, r in enumerate(top_scores, 1):\n", "            marker = \"<<<\" if r[\"label\"] == expected else \"\"\n", "            print(f\" {rank}. {r['label']:20s} score: {r['score']:.4f} {marker}\")\n", "\n", "        status = \"CORRECT\" if correct else \"WRONG\"\n", "        print(f\" => {status} (predicted: {predicted}) {reason}\")\n", "\n", "        new_results.append({\n", "            \"file\": test_path.name, \"expected\": expected, \"predicted\": predicted,\n", "            \"score\": score, \"correct\": correct, \"all_scores\": sim_result[\"results\"],\n", "            \"rejected\": rejected\n", "        })\n", "    else:\n", "        new_results.append({\"file\": test_path.name, \"expected\": expected, \"predicted\": \"ERROR\", \"score\": 0, \"correct\": False})\n", "\n", "elapsed = time.time() - start\n", "print(f\"\\nDone! Tested {len(new_results)} files in {elapsed:.1f}s\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Results summary\n", "# Known-word results use third_* names so this cell does not overwrite known_results /\n", "# known_correct, which the new-bank comparison cell reads as the Bank-12 baseline.\n", "correct_count = sum(1 for r in new_results if r[\"correct\"])\n", "total = len(new_results)\n", "third_known_results = [r for r in new_results if r[\"expected\"] != \"_unknown\"]\n", "third_known_correct = sum(1 for r in third_known_results if r[\"correct\"])\n", "unknown_results = [r for r in new_results if r[\"expected\"] == \"_unknown\"]\n", "unknown_correct = sum(1 for r in unknown_results if r[\"correct\"])\n", "# Guard the overall denominator so an empty run prints 0.0% instead of raising ZeroDivisionError.\n", "overall_acc = 100 * correct_count / total if total else 0.0\n", "\n", "print(\"=\" * 70)\n", "print(f\"OVERALL ACCURACY: {correct_count}/{total} ({overall_acc:.1f}%)\")\n", "print(f\"KNOWN WORDS ONLY: {third_known_correct}/{len(third_known_results)} ({100*third_known_correct/len(third_known_results):.1f}%)\" if third_known_results else \"\")\n", "print(f\"UNKNOWN DETECTION: {unknown_correct}/{len(unknown_results)} ({100*unknown_correct/len(unknown_results):.1f}%)\" if unknown_results else \"\")\n", "rejected_count = sum(1 for r in new_results if r.get(\"rejected\"))\n", "print(f\"Rejected to _unknown: {rejected_count} files\")\n", "print(\"=\" * 70)\n", "\n", "print(f\"\\n{'File':<30s} {'Expected':<18s} {'Predicted':<18s} {'Score':>8s} {'Result'}\")\n", "print(\"-\" * 90)\n", "for r in new_results:\n", "    status = \"OK\" if r[\"correct\"] else \"MISS\"\n", "    rej = \" [REJ]\" if r.get(\"rejected\") else \"\"\n", "    print(f\"{r['file']:<30s} {r['expected']:<18s} {r['predicted']:<18s} {r['score']:8.4f} {status}{rej}\")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv (3.12.1)", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.12.1" } }, "nbformat": 4, "nbformat_minor": 4 }