Commit d5f8ae0 (parent: 4b061f6): base modelv3

Files changed:
- README.md (+158 -30)
- main.py (+42 -10)
- models/__pycache__/base.cpython-310.pyc (binary)
- models/base.py (+2 -1)
- models/model_v1/__pycache__/wrapper.cpython-310.pyc (binary)
- models/model_v3/config.py (+7 -0, new)
- models/model_v3/model.py (+63 -0, new)
- models/model_v3/processor.py (+172 -0, new)
- models/model_v3/wrapper.py (+182 -0, new)
README.md
CHANGED
# Alzheimer's Detection Backend API

This repository contains a FastAPI-based backend for detecting Alzheimer's disease from linguistic data. It supports multiple machine learning models, including text-only analysis from `.cha` transcripts and a **multimodal model (V3)** that can process both text and audio.

## Features

- **FastAPI Framework**: High-performance, easy-to-use API.
- **Support for .cha Files**: Specialized parsing for CHAT format transcripts.
- **Multimodal Audio Support (V3)**: Process raw audio files with Automatic Speech Recognition (ASR).
- **Multiple AI Models**:
  - `Model V1`: A DeBERTa-based hybrid model focusing on semantic understanding.
  - `Model V2`: An explainable model with rich linguistic features (TTR, fillers, pauses, etc.).
  - `Model V3 (Multimodal)`: A multimodal fusion model combining text, audio spectrograms, and linguistic features.
- **CORS Support**: Configured to allow requests from frontend applications.

## Prerequisites

- **Python 3.8+**
- **pip** package manager
- **FFmpeg**: Required for audio processing in Model V3.

## Installation
The server will start at `http://127.0.0.1:8000`.

---

## API Documentation

### 1. Health Check

Checks if the API is active and lists loaded models.

```json
{
    "status": "active",
    "loaded_models": ["Model V1", "Model V2", "Model V3 (Multimodal)"]
}
```

### 2. List Models

Returns a list of all available model keys that can be used for prediction.

**Response:**
```json
{
    "models": ["Model V1", "Model V2", "Model V3 (Multimodal)"]
}
```
### 3. Predict / Analyze
**Endpoint:** `POST /predict`

Uploads files and processes them using the specified model.

**Request Type:** `multipart/form-data`

**Parameters:**

| Parameter    | Type           | Required | Description                                                                   |
|--------------|----------------|----------|-------------------------------------------------------------------------------|
| `model_name` | `string`       | **Yes**  | The key of the model (e.g., `Model V1`, `Model V2`, `Model V3 (Multimodal)`). |
| `file`       | `file (.cha)`  | Depends  | The CHAT format transcript file.                                              |
| `audio_file` | `file (audio)` | No       | An audio file (e.g., `.wav`, `.mp3`). **Only for Model V3.**                  |

#### **Input Validation Rules**

| Model                   | `file` (.cha) | `audio_file` | Notes                                                 |
|-------------------------|---------------|--------------|-------------------------------------------------------|
| `Model V1`              | **Required**  | Ignored      | Text-only model.                                      |
| `Model V2`              | **Required**  | Ignored      | Text-only model.                                      |
| `Model V3 (Multimodal)` | Optional      | Optional     | At least one file must be provided. Supports 3 modes. |

---
|
## Model V3 (Multimodal) - Deep Dive

Model V3 is a **multimodal fusion model** that combines three branches of information for its predictions:

1. **Text Branch**: Uses a DeBERTa transformer with an LSTM layer to encode textual semantics.
2. **Audio Branch**: Uses a Vision Transformer (ViT) trained on spectrograms derived from the audio.
3. **Linguistic Branch**: A simple feedforward network processing extracted linguistic features (TTR, filler ratio, pause ratio, etc.).

### Processing Modes

Model V3 handles three different input scenarios:

#### **Mode 1: CHA File Only**
- **Input:** A `.cha` transcript file.
- **Process:**
  1. Parses `*PAR:` (participant) lines from the CHA file.
  2. Cleans the text for the DeBERTa model.
  3. Extracts a 6-dimensional linguistic feature vector (TTR, fillers, repetitions, retracing, errors, pauses).
  4. **Audio branch receives a zero-tensor** (no audio input).
- **Use Case:** When you have a pre-existing transcript and no audio.

#### **Mode 2: CHA File + Audio (Segmented)**
- **Input:** A `.cha` transcript file AND an audio file.
- **Process:**
  1. Parses the CHA file for text and linguistic features (same as Mode 1).
  2. Extracts timestamps (e.g., `15123_456`) from the CHA file.
  3. Uses these timestamps to **slice the corresponding audio segments** from the full audio file.
  4. Concatenates the slices and generates a spectrogram.
  5. Passes the spectrogram to the ViT-based audio branch.
- **Use Case:** For maximum accuracy when you have a professionally transcribed CHA file that is time-aligned with its source audio.

#### **Mode 3: Audio Only (ASR)**
- **Input:** An audio file only (no `.cha` file).
- **Process:**
  1. Uses OpenAI's **Whisper** model to transcribe the audio.
  2. Applies CHAT-like formatting rules to the transcript:
     - Detects pauses and inserts `[PAUSE]` tokens.
     - Detects word repetitions and inserts `[/]` markers.
  3. Extracts linguistic features from the generated transcript.
  4. Generates a spectrogram from the **full audio file** (up to 30s).
- **Use Case:** For real-world inference when you only have raw audio (e.g., a voice recording).

---
|
### Model V3 Response Format

```json
{
    "model_version": "v3_multimodal",
    "filename": "sample.cha",
    "predicted_label": "AD",        // "AD" (Alzheimer's Disease) or "Control"
    "confidence": 0.8721,           // Probability score (0.0 - 1.0)
    "modalities_used": ["text", "linguistic", "audio"],
    "generated_transcript": null    // Populated only in Audio-Only mode (Mode 3)
}
```

**Response Fields:**

| Field                  | Type             | Description                                                                                            |
|------------------------|------------------|--------------------------------------------------------------------------------------------------------|
| `model_version`        | `string`         | Always `"v3_multimodal"` for this model.                                                               |
| `filename`             | `string`         | Name of the uploaded file, or `"audio_only_upload"` if no CHA file was provided.                       |
| `predicted_label`      | `string`         | The classification result: `"AD"` or `"Control"`.                                                      |
| `confidence`           | `float`          | The model's confidence score for the predicted label.                                                  |
| `modalities_used`      | `array[string]`  | Lists the modalities used (`"text"`, `"linguistic"`, `"audio"`).                                       |
| `generated_transcript` | `string \| null` | The transcript generated by Whisper. **Only populated in Audio-Only mode (Mode 3)**, otherwise `null`. |

---
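The `//` annotations in the response example are explanatory only; the actual payload is plain JSON. A minimal sketch of consuming it client-side with the standard library (the response body below is a made-up sample, not captured server output):

```python
import json

# Hypothetical response body from POST /predict with Model V3 (Multimodal).
response_text = """{
  "model_version": "v3_multimodal",
  "filename": "sample.cha",
  "predicted_label": "AD",
  "confidence": 0.8721,
  "modalities_used": ["text", "linguistic", "audio"],
  "generated_transcript": null
}"""

result = json.loads(response_text)
# generated_transcript is None here because this was not an audio-only request.
summary = f"{result['predicted_label']} ({result['confidence']:.0%})"
print(summary)  # AD (87%)
```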
## Example API Requests (cURL)

### Model V1 / V2 (CHA File Only)
```bash
curl -X POST "http://127.0.0.1:8000/predict" \
  -F "model_name=Model V1" \
  -F "file=@/path/to/your/transcript.cha"
```

### Model V3: CHA Only
```bash
curl -X POST "http://127.0.0.1:8000/predict" \
  -F "model_name=Model V3 (Multimodal)" \
  -F "file=@/path/to/your/transcript.cha"
```

### Model V3: CHA + Audio
```bash
curl -X POST "http://127.0.0.1:8000/predict" \
  -F "model_name=Model V3 (Multimodal)" \
  -F "file=@/path/to/your/transcript.cha" \
  -F "audio_file=@/path/to/your/audio.wav"
```

### Model V3: Audio Only (ASR)
```bash
curl -X POST "http://127.0.0.1:8000/predict" \
  -F "model_name=Model V3 (Multimodal)" \
  -F "audio_file=@/path/to/your/audio.wav"
```

---
## Older Model Response Formats

### **A. `Model V1` Output**
Focuses on sequence classification and attention scores for sentences.

```json
{
    "filename": "sample.cha",
    "prediction": "DEMENTIA",
    "confidence": 0.85,
    "is_dementia": true,
    "attention_map": [
        {
            "sentence": "I saw the cookie jar.",
            "attention_score": 0.92
        }
    ],
    "model_used": "hybrid_deberta"
}
```

### **B. `Model V2` Output**
Provides a rich set of metadata and linguistic features for explainability.

```json
{
    "filename": "sample.cha",
    "prediction": "Dementia",
    "probability_dementia": 0.78,
    "metadata": {
        "age": 72,
        "gender": "Female",
        "sentence_count": 15
    },
    "linguistic_features": {
        "TTR": 0.45,
        "fillers_ratio": 0.05,
        "repetitions_ratio": 0.02,
        "retracing_ratio": 0.01,
        "incomplete_ratio": 0.03,
        "pauses_ratio": 0.12
    },
    "key_segments": [
        {
            "text": "Um... checking the... the overflowing water.",
            "importance": 0.88
        }
    ]
}
```

---
## Project Structure

```
.
├── main.py                  # Entry point, API routes, and CORS config
├── models/                  # Model definitions and wrappers
│   ├── base.py              # Base class for model wrappers
│   ├── model_v1/            # Logic for 'Model V1' (DeBERTa Hybrid)
│   ├── model_v2/            # Logic for 'Model V2' (Explainable + Linguistic)
│   └── model_v3/            # Logic for 'Model V3 (Multimodal)'
│       ├── config.py        # Model configuration (weights path, model names)
│       ├── model.py         # Neural network architecture (TextBranch, AudioBranch, etc.)
│       ├── processor.py     # Preprocessing (Linguistic features, Spectrograms, ASR)
│       └── wrapper.py       # The main wrapper class integrating all components
└── requirements.txt         # Project dependencies
```
main.py
CHANGED
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Dict, Optional

from models.base import BaseModelWrapper
from models.model_v1.wrapper import HybridDebertaWrapper
from models.model_v2.wrapper import ModelV2Wrapper
from models.model_v3.wrapper import MultimodalWrapper

AVAILABLE_MODELS: Dict[str, BaseModelWrapper] = {
    "Model V1": HybridDebertaWrapper(),
    "Model V2": ModelV2Wrapper(),
    "Model V3 (Multimodal)": MultimodalWrapper()
}

@asynccontextmanager

# ... (lines 18-42 unchanged in this diff) ...

@app.post("/predict")
async def predict(
    model_name: str = Form(...),
    file: Optional[UploadFile] = File(None),       # changed to Optional
    audio_file: Optional[UploadFile] = File(None)  # added audio input
):
    if model_name not in AVAILABLE_MODELS:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

    # --- Validation logic ---
    # Models V1 and V2 REQUIRE a .cha file
    if model_name in ["Model V1", "Model V2"]:
        if not file or not file.filename.endswith('.cha'):
            raise HTTPException(status_code=400, detail=f"{model_name} requires a .cha file.")

    # Model V3 requires AT LEAST one file
    if not file and not audio_file:
        raise HTTPException(status_code=400, detail="Please provide a .cha file, an audio file, or both.")

    # --- Read files ---
    text_content = b""
    filename = "audio_only_upload"  # default if no .cha

    if file:
        if not file.filename.endswith('.cha'):
            raise HTTPException(status_code=400, detail="Text file must be a .cha file.")
        text_content = await file.read()
        filename = file.filename

    audio_content = None
    if audio_file:
        audio_content = await audio_file.read()

    # --- Prediction ---
    try:
        # audio_content is passed to every model; base.py now declares it, so the
        # V1/V2 predict signatures should also accept audio_content=None (or
        # **kwargs) even though they ignore the audio.
        result = AVAILABLE_MODELS[model_name].predict(
            file_content=text_content,
            filename=filename,
            audio_content=audio_content
        )
        return result
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        print(f"Prediction Error: {e}")  # log internal errors
        raise HTTPException(status_code=500, detail="Internal Server Error")

@app.get("/health")
def health_check():
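The endpoint's validation rules are easy to unit-test once separated from FastAPI. A small sketch (`validate_upload` is an illustrative helper, not part of this commit):

```python
def validate_upload(model_name: str, has_cha: bool, has_audio: bool):
    """Mirror of the /predict validation rules: returns an error message or None."""
    if model_name in ("Model V1", "Model V2") and not has_cha:
        return f"{model_name} requires a .cha file."
    if not has_cha and not has_audio:
        return "Please provide a .cha file, an audio file, or both."
    return None

# Model V3 accepts any combination except no files at all.
print(validate_upload("Model V3 (Multimodal)", False, True))   # None
print(validate_upload("Model V1", False, True))                # error: .cha required
print(validate_upload("Model V3 (Multimodal)", False, False))  # error: no files
```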
models/__pycache__/base.cpython-310.pyc
CHANGED

Binary files a/models/__pycache__/base.cpython-310.pyc and b/models/__pycache__/base.cpython-310.pyc differ
models/base.py
CHANGED
from abc import ABC, abstractmethod
from typing import Optional

class BaseModelWrapper(ABC):
    @abstractmethod
    # ... (line 6 unchanged in this diff: the first abstract method's signature) ...
        pass

    @abstractmethod
    def predict(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None) -> dict:
        pass
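Because the new `audio_content` parameter defaults to `None`, existing text-only wrappers keep working without audio. A toy subclass sketch (self-contained: the base class is redeclared here, with its other abstract method omitted for brevity):

```python
from abc import ABC, abstractmethod
from typing import Optional

class BaseModelWrapper(ABC):
    @abstractmethod
    def predict(self, file_content: bytes, filename: str,
                audio_content: Optional[bytes] = None) -> dict:
        ...

class EchoWrapper(BaseModelWrapper):
    """Reports which modalities it received instead of running a model."""
    def predict(self, file_content, filename, audio_content=None):
        modalities = []
        if file_content:
            modalities.append("text")
        if audio_content is not None:
            modalities.append("audio")
        return {"filename": filename, "modalities_used": modalities}

print(EchoWrapper().predict(b"*PAR: hello .", "sample.cha"))
# {'filename': 'sample.cha', 'modalities_used': ['text']}
```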
models/model_v1/__pycache__/wrapper.cpython-310.pyc
CHANGED

Binary files a/models/model_v1/__pycache__/wrapper.cpython-310.pyc and b/models/model_v1/__pycache__/wrapper.cpython-310.pyc differ
models/model_v3/config.py
ADDED
import os

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Keep the weights file relative to this package; joining an absolute path
# (e.g. a local "D:/..." development path) would make os.path.join discard BASE_DIR.
WEIGHTS_PATH = os.path.join(BASE_DIR, "weights", "multimodal_dementia_model.pth")
TEXT_MODEL_NAME = "microsoft/deberta-base"
MAX_LEN = 128
WHISPER_MODEL_SIZE = "base"
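A pitfall to watch when building `WEIGHTS_PATH`: `os.path.join` silently discards earlier components when a later component is absolute (under Windows path rules, any `D:/...` string). A quick demonstration using the platform-specific path modules directly, so it behaves the same on any OS:

```python
import ntpath
import posixpath

# Under Windows path semantics, an absolute drive path resets the join entirely,
# so a BASE_DIR/"weights" prefix would be thrown away.
win_result = ntpath.join("C:/app/weights", "D:/Work/model.pth")
print(win_result)  # D:/Work/model.pth

# A relative filename keeps the path anchored at the first component.
posix_result = posixpath.join("/app/weights", "model.pth")
print(posix_result)  # /app/weights/model.pth
```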
models/model_v3/model.py
ADDED
import torch
import torch.nn as nn
import timm
from transformers import AutoModel

class TextBranch(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.lstm = nn.LSTM(768, 128, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(256, 64)

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        _, (h_n, _) = self.lstm(out.last_hidden_state)
        context = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.fc(context)

class AudioBranch(nn.Module):
    def __init__(self):
        super().__init__()
        # Pretrained ViT
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.fc = nn.Linear(768, 64)  # Project ViT dim to 64

    def forward(self, pixel_values):
        features = self.vit(pixel_values)
        return self.fc(features)

class LinguisticBranch(nn.Module):
    def __init__(self, input_dim=6):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 16)
        )

    def forward(self, x):
        return self.net(x)

class MultimodalFusion(nn.Module):
    def __init__(self, text_model_name='microsoft/deberta-base'):
        super().__init__()
        self.text_branch = TextBranch(text_model_name)
        self.audio_branch = AudioBranch()
        self.ling_branch = LinguisticBranch(input_dim=6)

        # Fusion: 64 (Text) + 64 (Audio) + 16 (Ling) = 144
        self.classifier = nn.Sequential(
            nn.Linear(64 + 64 + 16, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 2)
        )

    def forward(self, input_ids, attention_mask, pixel_values, ling_features):
        text_emb = self.text_branch(input_ids, attention_mask)
        audio_emb = self.audio_branch(pixel_values)
        ling_emb = self.ling_branch(ling_features)

        # Concat
        combined = torch.cat((text_emb, audio_emb, ling_emb), dim=1)
        return self.classifier(combined)
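The classifier's input width follows from the three branch projections (64 + 64 + 16 = 144). A stdlib-only sketch of the concatenation arithmetic, standing in for `torch.cat(..., dim=1)` with dummy per-sample embeddings:

```python
# Branch output widths from MultimodalFusion above.
TEXT_DIM, AUDIO_DIM, LING_DIM = 64, 64, 16

def fuse(text_emb, audio_emb, ling_emb):
    """Concatenate one sample's branch embeddings along the feature axis."""
    assert (len(text_emb), len(audio_emb), len(ling_emb)) == (TEXT_DIM, AUDIO_DIM, LING_DIM)
    return text_emb + audio_emb + ling_emb  # list concatenation

combined = fuse([0.1] * TEXT_DIM, [0.2] * AUDIO_DIM, [0.3] * LING_DIM)
print(len(combined))  # 144, the in_features of the first classifier Linear
```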
models/model_v3/processor.py
ADDED
import re
import os
import numpy as np
import torch
import librosa
import librosa.display
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import whisper

# Force non-interactive backend for server environments
matplotlib.use('Agg')

# ==========================================
# 1. Linguistic Feature Extractor
# ==========================================
class LinguisticFeatureExtractor:
    def __init__(self):
        self.patterns = {
            'fillers': re.compile(r'&-([a-z]+)', re.IGNORECASE),
            'repetition': re.compile(r'\[/+\]'),
            'retracing': re.compile(r'\[//\]'),
            'incomplete': re.compile(r'\+[\./]+'),
            'errors': re.compile(r'\[\*.*?\]'),
            'pauses': re.compile(r'\(\.+\)')
        }

    def clean_for_bert(self, raw_text):
        text = re.sub(r'^\*PAR:\s+', '', raw_text)
        text = re.sub(r'\x15\d+_\d+\x15', '', text)
        text = re.sub(r'<|>', '', text)
        text = re.sub(r'\[.*?\]', '', text)
        text = re.sub(r'\(\.+\)', '[PAUSE]', text)
        text = text.replace('_', ' ')
        text = re.sub(r'\s+', ' ', text).strip()
        if text.endswith('[PAUSE]'):
            text = text[:-7].strip()
        return text

    def get_features(self, raw_text):
        stats = {
            'filler_count': len(self.patterns['fillers'].findall(raw_text)),
            'repetition_count': len(self.patterns['repetition'].findall(raw_text)),
            'retracing_count': len(self.patterns['retracing'].findall(raw_text)),
            'incomplete_count': len(self.patterns['incomplete'].findall(raw_text)),
            'error_count': len(self.patterns['errors'].findall(raw_text)),
            'pause_count': len(self.patterns['pauses'].findall(raw_text))
        }
        clean_for_stats = re.sub(r'\[.*?\]', '', raw_text)
        clean_for_stats = re.sub(r'&-([a-z]+)', '', clean_for_stats)
        clean_for_stats = re.sub(r'[^\w\s]', '', clean_for_stats)
        words = clean_for_stats.lower().split()
        stats['word_count'] = len(words)
        return stats

    def get_feature_vector(self, raw_text):
        stats = self.get_features(raw_text)
        n = stats['word_count'] if stats['word_count'] > 0 else 1

        # Calculate TTR (Type-Token Ratio)
        clean_for_stats = re.sub(r'\[.*?\]', '', raw_text)
        clean_for_stats = re.sub(r'&-([a-z]+)', '', clean_for_stats)
        clean_for_stats = re.sub(r'[^\w\s]', '', clean_for_stats)
        words = clean_for_stats.lower().split()
        ttr = (len(set(words)) / n) if n > 0 else 0.0

        return np.array([
            ttr,
            stats['filler_count'] / n,
            stats['repetition_count'] / n,
            stats['retracing_count'] / n,
            stats['error_count'] / n,
            stats['pause_count'] / n
        ], dtype=np.float32)
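A short, self-contained usage sketch of the counting logic above, applied to a made-up CHAT-style utterance (the regexes are restated so the snippet runs on its own):

```python
import re

# Condensed re-statement of LinguisticFeatureExtractor's counting logic.
raw = "*PAR: the boy &-uh is [/] is taking (.) cookies ."

counts = {
    "fillers": len(re.findall(r"&-([a-z]+)", raw, re.IGNORECASE)),      # &-uh
    "repetitions": len(re.findall(r"\[/+\]", raw)),                     # [/]
    "pauses": len(re.findall(r"\(\.+\)", raw)),                         # (.)
}

# Word count after stripping CHAT markers, as get_features does.
clean = re.sub(r"\[.*?\]", "", raw)
clean = re.sub(r"&-([a-z]+)", "", clean)
clean = re.sub(r"[^\w\s]", "", clean)
words = clean.lower().split()

print(counts, len(words), len(set(words)))
# {'fillers': 1, 'repetitions': 1, 'pauses': 1} 7 6
```

Note that the word list still contains the speaker tier token (`par`), since `get_features` does not apply `clean_for_bert`'s `*PAR:` stripping.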
+
# ==========================================
|
| 79 |
+
# 2. Audio Processor
|
| 80 |
+
# ==========================================
|
| 81 |
+
class AudioProcessor:
|
| 82 |
+
def __init__(self):
|
| 83 |
+
self.vit_transforms = transforms.Compose([
|
| 84 |
+
transforms.Resize((224, 224)),
|
| 85 |
+
transforms.ToTensor(),
|
| 86 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 87 |
+
])
|
| 88 |
+
|
| 89 |
+
def create_spectrogram_tensor(self, audio_path, intervals=None):
|
| 90 |
+
"""
|
| 91 |
+
Generates spectrogram image and transforms it to Tensor.
|
| 92 |
+
"""
|
| 93 |
+
try:
|
| 94 |
+
fig = plt.figure(figsize=(2.24, 2.24), dpi=100)
|
| 95 |
+
ax = fig.add_subplot(1, 1, 1)
|
| 96 |
+
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
| 97 |
+
|
| 98 |
+
if intervals:
|
| 99 |
+
# Load full audio then slice based on timestamps
|
| 100 |
+
y, sr = librosa.load(audio_path, sr=None)
|
| 101 |
+
clips = []
|
| 102 |
+
for start_ms, end_ms in intervals:
|
| 103 |
+
start_sample = int(start_ms * sr / 1000)
|
| 104 |
+
end_sample = int(end_ms * sr / 1000)
|
| 105 |
+
if end_sample > len(y): end_sample = len(y)
|
| 106 |
+
if start_sample < len(y):
|
| 107 |
+
clips.append(y[start_sample:end_sample])
|
| 108 |
+
if clips:
|
| 109 |
+
y = np.concatenate(clips)
|
| 110 |
+
else:
|
| 111 |
+
y = np.zeros(int(sr*30))
|
| 112 |
+
|
| 113 |
+
# Limit to 30s
|
| 114 |
+
if len(y) > 30 * sr:
|
| 115 |
+
y = y[:30 * sr]
|
| 116 |
+
else:
|
| 117 |
+
y, sr = librosa.load(audio_path, duration=30)
|
| 118 |
+
|
| 119 |
+
ms = librosa.feature.melspectrogram(y=y, sr=sr)
|
| 120 |
+
log_ms = librosa.power_to_db(ms, ref=np.max)
|
| 121 |
+
librosa.display.specshow(log_ms, sr=sr, ax=ax)
|
| 122 |
+
|
| 123 |
+
# Save to buffer instead of file
|
| 124 |
+
from io import BytesIO
|
| 125 |
+
buf = BytesIO()
|
| 126 |
+
fig.savefig(buf, format='png')
|
| 127 |
+
plt.close(fig)
|
| 128 |
+
buf.seek(0)
|
| 129 |
+
|
| 130 |
+
image = Image.open(buf).convert('RGB')
|
| 131 |
+
return self.vit_transforms(image).unsqueeze(0)
|
| 132 |
+
|
| 133 |
+
except Exception as e:
|
| 134 |
+
print(f"Spectrogram creation failed: {e}")
|
| 135 |
+
return torch.zeros((1, 3, 224, 224))
|
| 136 |
+
|
| 137 |
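The millisecond-to-sample conversion used when `intervals` is given can be sketched in isolation (values below are illustrative, not from a real CHA file):

```python
def intervals_to_sample_ranges(intervals_ms, sr, n_samples):
    """Convert CHA millisecond intervals into clamped sample index ranges,
    as create_spectrogram_tensor does before slicing the waveform."""
    ranges = []
    for start_ms, end_ms in intervals_ms:
        start = int(start_ms * sr / 1000)
        end = min(int(end_ms * sr / 1000), n_samples)
        if start < n_samples:
            ranges.append((start, end))
    return ranges

# A 20-second clip at 16 kHz, with two time-aligned utterance intervals.
ranges = intervals_to_sample_ranges([(0, 500), (15123, 15456)], sr=16000, n_samples=16000 * 20)
print(ranges)  # [(0, 8000), (241968, 247296)]
```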
+
# ==========================================
|
| 138 |
+
# 3. ASR Helper (Whisper + CHAT Rules)
|
| 139 |
+
# ==========================================
|
| 140 |
+
def apply_chat_rules(transcription_result):
|
| 141 |
+
"""
|
| 142 |
+
Converts Whisper result into CHAT-like format AND inserts [PAUSE] tokens.
|
| 143 |
+
"""
|
| 144 |
+
formatted_text = []
|
| 145 |
+
segments = transcription_result.get('segments', [])
|
| 146 |
+
last_end = 0
|
| 147 |
+
|
| 148 |
+
for seg in segments:
|
| 149 |
+
gap = seg['start'] - last_end
|
| 150 |
+
# Insert [PAUSE] token + CHAT marker
|
| 151 |
+
if gap > 0.8:
|
| 152 |
+
formatted_text.append("[PAUSE] (..)")
|
| 153 |
+
elif gap > 0.3:
|
| 154 |
+
formatted_text.append("[PAUSE] (.)")
|
| 155 |
+
|
| 156 |
+
text = seg['text'].strip()
|
| 157 |
+
|
| 158 |
+
# Repetitions (Basic Detection)
|
| 159 |
+
words = text.split()
|
| 160 |
+
processed_words = []
|
| 161 |
+
for i, w in enumerate(words):
|
| 162 |
+
clean_w = re.sub(r'[^a-zA-Z]', '', w.lower())
|
| 163 |
+
if i > 0:
|
| 164 |
+
prev_clean = re.sub(r'[^a-zA-Z]', '', words[i-1].lower())
|
| 165 |
+
if clean_w == prev_clean and clean_w:
|
| 166 |
+
processed_words[-1] = f"{words[i-1]} [/]"
|
| 167 |
+
processed_words.append(w)
|
| 168 |
+
|
| 169 |
+
formatted_text.append(" ".join(processed_words))
|
| 170 |
+
last_end = seg['end']
|
| 171 |
+
|
| 172 |
+
return " ".join(formatted_text)
|
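As a sanity check on the pause and repetition rules, here is how `apply_chat_rules` behaves on a minimal mocked Whisper-style result (the segment dict shape mirrors Whisper's `transcribe` output; the values are made up, and the helper is inlined in trimmed form so the snippet runs on its own):

```python
import re

def apply_chat_rules(transcription_result):
    # Trimmed, inlined copy of the helper above so this example is self-contained.
    formatted_text = []
    last_end = 0
    for seg in transcription_result.get('segments', []):
        gap = seg['start'] - last_end
        # Inter-segment silences become CHAT pause markers
        if gap > 0.8:
            formatted_text.append("[PAUSE] (..)")
        elif gap > 0.3:
            formatted_text.append("[PAUSE] (.)")
        words = seg['text'].strip().split()
        processed_words = []
        for i, w in enumerate(words):
            clean_w = re.sub(r'[^a-zA-Z]', '', w.lower())
            if i > 0:
                prev_clean = re.sub(r'[^a-zA-Z]', '', words[i - 1].lower())
                # Immediate word repetition gets the CHAT [/] retracing marker
                if clean_w == prev_clean and clean_w:
                    processed_words[-1] = f"{words[i - 1]} [/]"
            processed_words.append(w)
        formatted_text.append(" ".join(processed_words))
        last_end = seg['end']
    return " ".join(formatted_text)

# Mocked result: a repeated word, then a 1.0 s gap between segments
mock = {'segments': [
    {'start': 0.0, 'end': 1.0, 'text': ' the the cat'},
    {'start': 2.0, 'end': 3.0, 'text': ' sat'},
]}
print(apply_chat_rules(mock))  # the [/] the cat [PAUSE] (..) sat
```

The 0.3 s / 0.8 s thresholds map short and long silences to `(.)` and `(..)`, so downstream feature counting can treat ASR output and real CHAT transcripts uniformly.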
models/model_v3/wrapper.py
ADDED
@@ -0,0 +1,182 @@
```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from typing import Optional
import os
import tempfile
import whisper
import re

from models.base import BaseModelWrapper
from .model import MultimodalFusion
from .processor import LinguisticFeatureExtractor, AudioProcessor, apply_chat_rules
from .config import WEIGHTS_PATH, TEXT_MODEL_NAME, MAX_LEN

class MultimodalWrapper(BaseModelWrapper):
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self.asr_model = None
        self.ling_extractor = LinguisticFeatureExtractor()
        self.audio_processor = AudioProcessor()

    def load(self):
        print("Loading Model V3 components...")
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
        self.model = MultimodalFusion(TEXT_MODEL_NAME)

        # Load weights, mapping to CPU when CUDA is unavailable
        if torch.cuda.is_available():
            state_dict = torch.load(WEIGHTS_PATH)
        else:
            state_dict = torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Load Whisper ("base" model, as in the notebook)
        print("Loading Whisper for Audio-Only Inference...")
        self.asr_model = whisper.load_model("base")

    def predict(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None) -> dict:
        """
        Handles 3 scenarios:
        1. CHA only: file_content is the CHA transcript.
        2. CHA + audio: file_content is the CHA transcript, audio_content is the audio.
        3. Audio only: file_content is empty/dummy, audio_content is the audio.
        """

        # Determine scenario
        is_cha_provided = filename.endswith('.cha') and len(file_content) > 0
        has_audio = audio_content is not None and len(audio_content) > 0

        processed_text = ""
        ling_features = None
        audio_tensor = None
        intervals = []

        # --- SCENARIO 3: PURE AUDIO (no transcript, so generate one) ---
        if not is_cha_provided and has_audio:
            print("Processing Mode: Audio Only (ASR)")

            # Save audio to a temp file for Whisper/librosa
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                tmp_audio.write(audio_content)
                tmp_path = tmp_audio.name

            try:
                # 1. Transcribe
                result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                # 2. Apply CHAT rules
                chat_transcript = apply_chat_rules(result)
                processed_text = chat_transcript  # Whisper output needs little or no BERT cleaning

                # 3. Extract features from the generated text.
                # Stats are computed manually here (as in the ASR notebook section)
                # because the ASR output does not share the raw CHA format.
                stats = self.ling_extractor.get_features(chat_transcript)
                pause_count = chat_transcript.count("[PAUSE]")
                repetition_count = chat_transcript.count("[/]")

                # Type-token ratio (TTR)
                clean_t = re.sub(r'\[.*?\]', '', chat_transcript)
                clean_t = re.sub(r'[^\w\s]', '', clean_t)
                words = clean_t.lower().split()
                n = len(words) if len(words) > 0 else 1
                ttr = len(set(words)) / n

                ling_vec = [
                    ttr,
                    stats['filler_count'] / n,
                    repetition_count / n,
                    stats['retracing_count'] / n,
                    stats['error_count'] / n,
                    pause_count / n
                ]
                ling_features = torch.tensor(ling_vec, dtype=torch.float32).unsqueeze(0)

                # 4. Generate spectrogram (whole file, no intervals)
                audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=None)

            finally:
                os.remove(tmp_path)

        # --- SCENARIOS 1 & 2: CHA FILE PROVIDED ---
        else:
            # Parse participant text from the CHA file
            text_str = file_content.decode('utf-8', errors='replace')
            par_lines = []

            # Timestamps look like \x15123_456\x15; this matches the
            # 'load_and_process_data' -> 'process_dir' behaviour
            full_text_for_intervals = ""

            for line in text_str.splitlines():
                if line.startswith('*PAR:'):
                    content = line[5:].strip()
                    par_lines.append(content)
                    full_text_for_intervals += content + " "

            raw_text = " ".join(par_lines)
            processed_text = self.ling_extractor.clean_for_bert(raw_text)

            # Extract linguistic features
            feats = self.ling_extractor.get_feature_vector(raw_text)
            ling_features = torch.tensor(feats, dtype=torch.float32).unsqueeze(0)

            # --- SCENARIO 2: CHA + AUDIO (segmentation) ---
            if has_audio:
                print("Processing Mode: CHA + Audio (Segmentation)")
                # Extract intervals from the raw text (containing the time bullets)
                # Notebook regex: re.findall(r'\x15(\d+)_(\d+)\x15', text_content)
                found_intervals = re.findall(r'\x15(\d+)_(\d+)\x15', full_text_for_intervals)
                intervals = [(int(s), int(e)) for s, e in found_intervals]

                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                    tmp_audio.write(audio_content)
                    tmp_path = tmp_audio.name

                try:
                    # Pass intervals to slice out the participant-only audio
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                finally:
                    os.remove(tmp_path)

            # --- SCENARIO 1: CHA ONLY ---
            else:
                print("Processing Mode: CHA Only")
                audio_tensor = torch.zeros((1, 3, 224, 224))

        # --- COMMON INFERENCE STEPS ---
        encoding = self.tokenizer.encode_plus(
            processed_text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        with torch.no_grad():
            input_ids = encoding['input_ids'].to(self.device)
            mask = encoding['attention_mask'].to(self.device)
            pixel_values = audio_tensor.to(self.device)
            ling_input = ling_features.to(self.device)

            outputs = self.model(input_ids, mask, pixel_values, ling_input)
            probs = F.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_idx].item()

        label_map = {0: 'Control', 1: 'AD'}

        return {
            "model_version": "v3_multimodal",
            "filename": filename if filename else "audio_upload",
            "predicted_label": label_map[pred_idx],
            "confidence": round(confidence, 4),
            "modalities_used": ["text", "linguistic"] + (["audio"] if has_audio else []),
            "generated_transcript": processed_text if not is_cha_provided else None
        }
```
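In the audio-only branch, the linguistic vector is derived from the generated transcript; the type-token-ratio term, for example, strips the CHAT markers before counting. A standalone sketch of that step, using a made-up transcript:

```python
import re

# Hypothetical CHAT-style transcript, as apply_chat_rules might produce
transcript = "[PAUSE] (.) the cat [/] cat sat on the mat"

# Strip [...] markers and punctuation before counting word types/tokens,
# mirroring the TTR computation inside predict()
clean_t = re.sub(r'\[.*?\]', '', transcript)
clean_t = re.sub(r'[^\w\s]', '', clean_t)
words = clean_t.lower().split()
n = len(words) if len(words) > 0 else 1  # guard against division by zero
ttr = len(set(words)) / n
print(round(ttr, 3))  # 5 unique words / 7 tokens
```

Removing the markers first matters: otherwise tokens like `[PAUSE]` and `[/]` would inflate the token count and depress the TTR, which is one of the six normalized features fed to the fusion model.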