cracker0935 committed on
Commit
d5f8ae0
·
1 Parent(s): 4b061f6

base modelv3

Browse files
README.md CHANGED
@@ -9,21 +9,24 @@ pinned: false
9
 
10
  # Alzheimer's Detection Backend API
11
 
12
- This repository contains a FastAPI-based backend for detecting Alzheimer's disease from linguistic data (transcribed speech files in `.cha` format). It supports multiple underlying machine learning models, each providing unique insights and output formats.
13
 
14
 ## 🚀 Features
15
 
16
  - **FastAPI Framework**: High-performance, easy-to-use API.
17
  - **Support for .cha Files**: Specialized parsing for CHAT format transcripts.
 
18
  - **Multiple AI Models**:
19
- - `hybrid_deberta`: A DeBERTa-based hybrid model focusing on semantic understanding.
20
- - `model_v2`: An explainable model incorporating linguistic features (TTR, fillers, pauses, etc.) along with deep learning.
21
- - **CORS Support**: Configured to allow requests from the frontend application (`https://adtrack.onrender.com`).
 
22
 
23
  ## ๐Ÿ› ๏ธ Prerequisites
24
 
25
  - **Python 3.8+**
26
  - **pip** package manager
 
27
 
28
 ## 📥 Installation
29
 
@@ -61,6 +64,8 @@ uvicorn main:app --reload
61
 
62
  The server will start at `http://127.0.0.1:8000`.
63
 
 
 
64
 ## 📖 API Documentation
65
 
66
  ### 1. Health Check
@@ -72,7 +77,7 @@ Checks if the API is active and lists loaded models.
72
  ```json
73
  {
74
  "status": "active",
75
- "loaded_models": ["hybrid_deberta", "model_v2"]
76
  }
77
  ```
78
 
@@ -86,7 +91,7 @@ Returns a list of all available model keys that can be used for prediction.
86
  **Response:**
87
  ```json
88
  {
89
- "models": ["hybrid_deberta", "model_v2"]
90
  }
91
  ```
92
 
@@ -95,58 +100,174 @@ Returns a list of all available model keys that can be used for prediction.
95
  ### 3. Predict / Analyze
96
  **Endpoint:** `POST /predict`
97
 
98
- Uploads a `.cha` file and processes it using the specified model.
 
 
99
 
100
  **Parameters:**
101
- - `file`: The `.cha` transcript file (Form Data).
102
- - `model_name`: The key of the model to use (e.g., `hybrid_deberta` or `model_v2`) (Form Data).
103
 
104
- #### **Model Response Formats**
 
105
 
106
- Each model returns results in a simplified structure tailored to its architecture.
107
 
108
- #### **A. `hybrid_deberta` Output**
109
  Focuses on sequence classification and attention scores for sentences.
110
 
111
  ```json
112
  {
113
  "filename": "sample.cha",
114
- "prediction": "DEMENTIA", // "DEMENTIA" or "HEALTHY CONTROL"
115
- "confidence": 0.85, // Probability score (0.0 - 1.0)
116
- "is_dementia": true, // Boolean flag
117
- "attention_map": [ // Token/Sentence level attention weights
118
  {
119
  "sentence": "I saw the cookie jar.",
120
  "attention_score": 0.92
121
- },
122
- ...
123
  ],
124
  "model_used": "hybrid_deberta"
125
  }
126
  ```
127
 
128
- #### **B. `model_v2` Output**
129
  Provides a rich set of metadata and linguistic features for explainability.
130
 
131
  ```json
132
  {
133
  "filename": "sample.cha",
134
- "prediction": "Dementia", // "Dementia" or "Control"
135
- "probability_dementia": 0.78, // Probability score (0.0 - 1.0)
136
  "metadata": {
137
  "age": 72,
138
  "gender": "Female",
139
  "sentence_count": 15
140
  },
141
- "linguistic_features": { // Extracted linguistic metrics
142
- "TTR": 0.45, // Type-Token Ratio
143
  "fillers_ratio": 0.05,
144
  "repetitions_ratio": 0.02,
145
  "retracing_ratio": 0.01,
146
  "incomplete_ratio": 0.03,
147
  "pauses_ratio": 0.12
148
  },
149
- "key_segments": [ // Top sentences contributing to the decision
150
  {
151
  "text": "Um... checking the... the overflowing water.",
152
  "importance": 0.88
@@ -156,14 +277,21 @@ Provides a rich set of metadata and linguistic features for explainability.
156
  }
157
  ```
158
 
 
 
159
 ## 📂 Project Structure
160
 
161
  ```
162
  .
163
- ├── main.py              # Entry point, API routes, and CORS config
- ├── models/              # Model definitions and wrappers
- │   ├── base.py          # Base class for model wrappers
- │   ├── model_v1/        # Logic for 'hybrid_deberta'
- │   └── model_v2/        # Logic for 'model_v2' (Explainable + Linguistic features)
- └── requirements.txt     # Project dependencies
169
  ```
 
9
 
10
  # Alzheimer's Detection Backend API
11
 
12
+ This repository contains a FastAPI-based backend for detecting Alzheimer's disease from linguistic data. It supports multiple machine learning models, including text-only analysis from `.cha` transcripts and a **multimodal model (V3)** that can process both text and audio.
13
 
14
 ## 🚀 Features
15
 
16
  - **FastAPI Framework**: High-performance, easy-to-use API.
17
  - **Support for .cha Files**: Specialized parsing for CHAT format transcripts.
18
+ - **Multimodal Audio Support (V3)**: Process raw audio files with Automatic Speech Recognition (ASR).
19
  - **Multiple AI Models**:
20
+ - `Model V1`: A DeBERTa-based hybrid model focusing on semantic understanding.
21
+ - `Model V2`: An explainable model with rich linguistic features (TTR, fillers, pauses, etc.).
22
+ - `Model V3 (Multimodal)`: A multimodal fusion model combining text, audio spectrograms, and linguistic features.
23
+ - **CORS Support**: Configured to allow requests from frontend applications.
24
 
25
  ## ๐Ÿ› ๏ธ Prerequisites
26
 
27
  - **Python 3.8+**
28
  - **pip** package manager
29
+ - **FFmpeg**: Required for audio processing in Model V3.
30
 
31
 ## 📥 Installation
32
 
 
64
 
65
  The server will start at `http://127.0.0.1:8000`.
66
 
67
+ ---
68
+
69
 ## 📖 API Documentation
70
 
71
  ### 1. Health Check
 
77
  ```json
78
  {
79
  "status": "active",
80
+ "loaded_models": ["Model V1", "Model V2", "Model V3 (Multimodal)"]
81
  }
82
  ```
83
 
 
91
  **Response:**
92
  ```json
93
  {
94
+ "models": ["Model V1", "Model V2", "Model V3 (Multimodal)"]
95
  }
96
  ```
97
 
 
100
  ### 3. Predict / Analyze
101
  **Endpoint:** `POST /predict`
102
 
103
+ Uploads files and processes them using the specified model.
104
+
105
+ **Request Type:** `multipart/form-data`
106
 
107
  **Parameters:**
 
 
108
 
109
+ | Parameter | Type | Required | Description |
110
+ |---------------|----------------|----------|-----------------------------------------------------------------------------|
111
+ | `model_name` | `string` | **Yes** | The key of the model (e.g., `Model V1`, `Model V2`, `Model V3 (Multimodal)`)|
112
+ | `file` | `file (.cha)` | Depends | The CHAT format transcript file. |
113
+ | `audio_file` | `file (audio)` | No | An audio file (e.g., `.wav`, `.mp3`). **Only for Model V3.** |
114
+
115
+ #### **Input Validation Rules**
116
+
117
+ | Model | `file` (.cha) | `audio_file` | Notes |
118
+ |-------------------------|---------------|--------------|-------------------------------------------------------|
119
+ | `Model V1` | **Required** | Ignored | Text-only model. |
120
+ | `Model V2` | **Required** | Ignored | Text-only model. |
121
+ | `Model V3 (Multimodal)` | Optional | Optional | At least one file must be provided. Supports 3 modes. |
122
+
123
+ ---
124
+
125
+ ## 🧠 Model V3 (Multimodal) - Deep Dive
126
+
127
+ Model V3 is a **multimodal fusion model** that combines three branches of information for its predictions:
128
+
129
+ 1. **Text Branch**: Uses a DeBERTa transformer with an LSTM layer to encode textual semantics.
130
+ 2. **Audio Branch**: Uses a Vision Transformer (ViT) trained on spectrograms derived from the audio.
131
+ 3. **Linguistic Branch**: A simple feedforward network processing extracted linguistic features (TTR, filler ratio, pause ratio, etc.).
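A quick sanity check on the fusion input size, using the branch output dimensions defined in `model.py` (64 for text, 64 for audio, 16 for linguistic features):

```python
# Branch output sizes as defined in models/model_v3/model.py
TEXT_DIM, AUDIO_DIM, LING_DIM = 64, 64, 16

# The first Linear layer of the fusion classifier must accept the concatenation
fusion_in = TEXT_DIM + AUDIO_DIM + LING_DIM
print(fusion_in)  # 144
```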
132
+
133
+ ### Processing Modes
134
+
135
+ Model V3 intelligently handles three different input scenarios:
136
+
137
+ #### **Mode 1: CHA File Only**
138
+ - **Input:** A `.cha` transcript file.
139
+ - **Process:**
140
+ 1. Parses `*PAR:` (participant) lines from the CHA file.
141
+ 2. Cleans the text for the DeBERTa model.
142
+ 3. Extracts a 6-dimensional linguistic feature vector (TTR, fillers, repetitions, retracing, errors, pauses).
143
+ 4. **Audio branch receives a zero-tensor** (no audio input).
144
+ - **Use Case:** When you have a pre-existing transcript and no audio.
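A minimal sketch of the `*PAR:`-line extraction described in step 1 (the sample transcript content is made up):

```python
cha = """@Begin
*INV: what is happening in the picture ?
*PAR: um the boy (..) is &-uh reaching [//] taking cookies .
*PAR: the water is overflowing .
@End"""

# Keep only participant utterances, as the Mode 1 pipeline does
par_lines = [line[len('*PAR:'):].strip()
             for line in cha.splitlines()
             if line.startswith('*PAR:')]
raw_text = " ".join(par_lines)
print(raw_text)
```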
145
+
146
+ #### **Mode 2: CHA File + Audio (Segmented)**
147
+ - **Input:** A `.cha` transcript file AND an audio file.
148
+ - **Process:**
149
+ 1. Parses the CHA file for text and linguistic features (same as Mode 1).
150
+ 2. Extracts timestamps (e.g., `15123_456`) from the CHA file.
151
+ 3. Uses these timestamps to **slice the corresponding audio segments** from the full audio file.
152
+ 4. Concatenates the slices and generates a spectrogram.
153
+ 5. Passes the spectrogram to the ViT-based audio branch.
154
+ - **Use Case:** For maximum accuracy when you have a professionally transcribed CHA file that is time-aligned with its source audio.
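The timestamp extraction in steps 2 and 3 can be sketched as follows (the `\x15` bytes are the CHAT "bullet" delimiters around millisecond offsets; the sample line and the 16 kHz rate are assumptions):

```python
import re

# Hypothetical *PAR: line with an embedded CHAT timestamp (milliseconds)
line = "the boy is on the stool \x1515123_18456\x15"

# Same pattern the wrapper uses: \x15(\d+)_(\d+)\x15
intervals = [(int(s), int(e)) for s, e in re.findall(r'\x15(\d+)_(\d+)\x15', line)]

# Convert millisecond intervals to sample indices for slicing the waveform
sr = 16000  # assumed sample rate
samples = [(s * sr // 1000, e * sr // 1000) for s, e in intervals]
print(intervals, samples)
```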
155
+
156
+ #### **Mode 3: Audio Only (ASR)**
157
+ - **Input:** An audio file only (no `.cha` file).
158
+ - **Process:**
159
+ 1. Uses OpenAI's **Whisper** model to transcribe the audio.
160
+ 2. Applies CHAT-like formatting rules to the transcript:
161
+ - Detects pauses and inserts `[PAUSE]` tokens.
162
+ - Detects word repetitions and inserts `[/]` markers.
163
+ 3. Extracts linguistic features from the generated transcript.
164
+ 4. Generates a spectrogram from the **full audio file** (up to 30s).
165
+ - **Use Case:** For real-world inference when you only have raw audio (e.g., a voice recording).
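The formatting rules in step 2 can be illustrated with a simplified re-implementation of `apply_chat_rules` from `processor.py`, fed a made-up Whisper-style result:

```python
import re

def apply_chat_rules(transcription_result):
    # Insert [PAUSE] tokens for inter-segment gaps and [/] markers
    # for immediate word repetitions, mirroring processor.apply_chat_rules.
    formatted_text = []
    last_end = 0
    for seg in transcription_result.get('segments', []):
        gap = seg['start'] - last_end
        if gap > 0.8:
            formatted_text.append("[PAUSE] (..)")
        elif gap > 0.3:
            formatted_text.append("[PAUSE] (.)")
        words = seg['text'].strip().split()
        processed = []
        for i, w in enumerate(words):
            clean = re.sub(r'[^a-zA-Z]', '', w.lower())
            if i > 0 and clean and clean == re.sub(r'[^a-zA-Z]', '', words[i - 1].lower()):
                processed[-1] = f"{words[i - 1]} [/]"  # mark the repeated word
            processed.append(w)
        formatted_text.append(" ".join(processed))
        last_end = seg['end']
    return " ".join(formatted_text)

# Hypothetical Whisper-style output with a long pause and a repeated word
result = {'segments': [
    {'start': 0.0, 'end': 2.0, 'text': ' the the boy'},
    {'start': 3.2, 'end': 5.0, 'text': ' is falling'},
]}
print(apply_chat_rules(result))  # the [/] the boy [PAUSE] (..) is falling
```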
166
+
167
+ ---
168
+
169
+ ### Model V3 Response Format
170
+
171
+ ```json
172
+ {
173
+ "model_version": "v3_multimodal",
174
+ "filename": "sample.cha",
175
+ "predicted_label": "AD", // "AD" (Alzheimer's Disease) or "Control"
176
+ "confidence": 0.8721, // Probability score (0.0 - 1.0)
177
+ "modalities_used": ["text", "linguistic", "audio"],
178
+ "generated_transcript": null // Populated only in Audio-Only mode (Mode 3)
179
+ }
180
+ ```
181
+
182
+ **Response Fields:**
183
 
184
+ | Field | Type | Description |
185
+ |------------------------|-----------------|------------------------------------------------------------------------------------------------------|
186
+ | `model_version` | `string` | Always `"v3_multimodal"` for this model. |
187
+ | `filename` | `string` | Name of the uploaded file, or `"audio_only_upload"` if no CHA file was provided. |
188
+ | `predicted_label` | `string` | The classification result: `"AD"` or `"Control"`. |
189
+ | `confidence` | `float` | The model's confidence score for the predicted label. |
190
+ | `modalities_used` | `array[string]` | Lists the modalities used (`"text"`, `"linguistic"`, `"audio"`). |
191
+ | `generated_transcript` | `string \| null`| The transcript generated by Whisper. **Only populated in Audio-Only mode (Mode 3)**, otherwise `null`.|
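A small client-side sketch of how these fields might be consumed (the response dict mirrors the example above; the summary format is illustrative):

```python
# Hypothetical handling of a Model V3 response on the client side
resp = {
    "model_version": "v3_multimodal",
    "predicted_label": "AD",
    "confidence": 0.8721,
    "modalities_used": ["text", "linguistic", "audio"],
    "generated_transcript": None,
}

pct = round(resp["confidence"] * 100, 1)
summary = f"{resp['predicted_label']} ({pct}% confidence, {len(resp['modalities_used'])} modalities)"
if resp["generated_transcript"] is not None:  # only set in Audio-Only mode
    summary += " [ASR transcript available]"
print(summary)  # AD (87.2% confidence, 3 modalities)
```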
192
+
193
+ ---
194
 
195
+ ## Example API Requests (cURL)
196
+
197
+ ### Model V1 / V2 (CHA File Only)
198
+ ```bash
199
+ curl -X POST "http://127.0.0.1:8000/predict" \
200
+ -F "model_name=Model V1" \
201
+ -F "file=@/path/to/your/transcript.cha"
202
+ ```
203
+
204
+ ### Model V3: CHA Only
205
+ ```bash
206
+ curl -X POST "http://127.0.0.1:8000/predict" \
207
+ -F "model_name=Model V3 (Multimodal)" \
208
+ -F "file=@/path/to/your/transcript.cha"
209
+ ```
210
+
211
+ ### Model V3: CHA + Audio
212
+ ```bash
213
+ curl -X POST "http://127.0.0.1:8000/predict" \
214
+ -F "model_name=Model V3 (Multimodal)" \
215
+ -F "file=@/path/to/your/transcript.cha" \
216
+ -F "audio_file=@/path/to/your/audio.wav"
217
+ ```
218
+
219
+ ### Model V3: Audio Only (ASR)
220
+ ```bash
221
+ curl -X POST "http://127.0.0.1:8000/predict" \
222
+ -F "model_name=Model V3 (Multimodal)" \
223
+ -F "audio_file=@/path/to/your/audio.wav"
224
+ ```
225
+
226
+ ---
227
+
228
+ ## Older Model Response Formats
229
+
230
+ ### **A. `Model V1` Output**
231
  Focuses on sequence classification and attention scores for sentences.
232
 
233
  ```json
234
  {
235
  "filename": "sample.cha",
236
+ "prediction": "DEMENTIA",
237
+ "confidence": 0.85,
238
+ "is_dementia": true,
239
+ "attention_map": [
240
  {
241
  "sentence": "I saw the cookie jar.",
242
  "attention_score": 0.92
243
+ }
 
244
  ],
245
  "model_used": "hybrid_deberta"
246
  }
247
  ```
248
 
249
+ ### **B. `Model V2` Output**
250
  Provides a rich set of metadata and linguistic features for explainability.
251
 
252
  ```json
253
  {
254
  "filename": "sample.cha",
255
+ "prediction": "Dementia",
256
+ "probability_dementia": 0.78,
257
  "metadata": {
258
  "age": 72,
259
  "gender": "Female",
260
  "sentence_count": 15
261
  },
262
+ "linguistic_features": {
263
+ "TTR": 0.45,
264
  "fillers_ratio": 0.05,
265
  "repetitions_ratio": 0.02,
266
  "retracing_ratio": 0.01,
267
  "incomplete_ratio": 0.03,
268
  "pauses_ratio": 0.12
269
  },
270
+ "key_segments": [
271
  {
272
  "text": "Um... checking the... the overflowing water.",
273
  "importance": 0.88
 
277
  }
278
  ```
279
 
280
+ ---
281
+
282
 ## 📂 Project Structure
283
 
284
  ```
285
  .
286
+ ├── main.py              # Entry point, API routes, and CORS config
+ ├── models/              # Model definitions and wrappers
+ │   ├── base.py          # Base class for model wrappers
+ │   ├── model_v1/        # Logic for 'Model V1' (DeBERTa Hybrid)
+ │   ├── model_v2/        # Logic for 'Model V2' (Explainable + Linguistic)
+ │   └── model_v3/        # Logic for 'Model V3 (Multimodal)'
+ │       ├── config.py    # Model configuration (weights path, model names)
+ │       ├── model.py     # Neural network architecture (TextBranch, AudioBranch, etc.)
+ │       ├── processor.py # Preprocessing (linguistic features, spectrograms, ASR)
+ │       └── wrapper.py   # The main wrapper class integrating all components
+ └── requirements.txt     # Project dependencies
297
  ```
main.py CHANGED
@@ -1,15 +1,17 @@
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from contextlib import asynccontextmanager
4
- from typing import Dict
5
 
6
  from models.base import BaseModelWrapper
7
  from models.model_v1.wrapper import HybridDebertaWrapper
8
  from models.model_v2.wrapper import ModelV2Wrapper
 
9
 
10
  AVAILABLE_MODELS: Dict[str, BaseModelWrapper] = {
11
  "Model V1": HybridDebertaWrapper(),
12
- "Model V2":ModelV2Wrapper()
 
13
  }
14
 
15
  @asynccontextmanager
@@ -41,24 +43,54 @@ def list_models():
41
 
42
  @app.post("/predict")
43
  async def predict(
44
- file: UploadFile = File(...),
45
- model_name: str = Form(...)
 
46
  ):
47
  if model_name not in AVAILABLE_MODELS:
48
  raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
 
49
 
50
- if not file.filename.endswith('.cha'):
51
- raise HTTPException(status_code=400, detail="Only .cha files are supported")
52
-
53
- contents = await file.read()
 
 
 
54
 
55
  try:
56
- result = AVAILABLE_MODELS[model_name].predict(contents, file.filename)
57
  return result
58
  except ValueError as e:
59
  raise HTTPException(status_code=400, detail=str(e))
60
  except Exception as e:
61
- raise HTTPException(status_code=500, detail=str(e))
 
62
 
63
  @app.get("/health")
64
  def health_check():
 
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from contextlib import asynccontextmanager
4
+ from typing import Dict, Optional
5
 
6
  from models.base import BaseModelWrapper
7
  from models.model_v1.wrapper import HybridDebertaWrapper
8
  from models.model_v2.wrapper import ModelV2Wrapper
9
+ from models.model_v3.wrapper import MultimodalWrapper
10
 
11
  AVAILABLE_MODELS: Dict[str, BaseModelWrapper] = {
12
  "Model V1": HybridDebertaWrapper(),
13
+ "Model V2": ModelV2Wrapper(),
14
+ "Model V3 (Multimodal)": MultimodalWrapper()
15
  }
16
 
17
  @asynccontextmanager
 
43
 
44
  @app.post("/predict")
45
  async def predict(
46
+ model_name: str = Form(...),
47
+ file: Optional[UploadFile] = File(None), # Changed to Optional
48
+ audio_file: Optional[UploadFile] = File(None) # Added Audio Input
49
  ):
50
  if model_name not in AVAILABLE_MODELS:
51
  raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
52
+
53
+ # --- Validation Logic ---
54
+ # Models V1 and V2 REQUIRE a .cha file
55
+ if model_name in ["Model V1", "Model V2"]:
56
+ if not file or not file.filename.endswith('.cha'):
57
+ raise HTTPException(status_code=400, detail=f"{model_name} requires a .cha file.")
58
 
59
+ # Model V3 requires AT LEAST one file
60
+ if not file and not audio_file:
61
+ raise HTTPException(status_code=400, detail="Please provide a .cha file, an audio file, or both.")
62
+
63
+ # --- Read Files ---
64
+ text_content = b""
65
+ filename = "audio_only_upload" # Default if no .cha
66
 
67
+ if file:
68
+ if not file.filename.endswith('.cha'):
69
+ raise HTTPException(status_code=400, detail="Text file must be a .cha file.")
70
+ text_content = await file.read()
71
+ filename = file.filename
72
+
73
+ audio_content = None
74
+ if audio_file:
75
+ audio_content = await audio_file.read()
76
+
77
+ # --- Prediction ---
78
  try:
79
+ # We pass audio_content to all models: base.py now declares it as an
+ # optional argument, so the V1/V2 wrappers should also accept
+ # audio_content=None (or **kwargs) in their predict signatures.
83
+ result = AVAILABLE_MODELS[model_name].predict(
84
+ file_content=text_content,
85
+ filename=filename,
86
+ audio_content=audio_content
87
+ )
88
  return result
89
  except ValueError as e:
90
  raise HTTPException(status_code=400, detail=str(e))
91
  except Exception as e:
92
+ print(f"Prediction Error: {e}") # Log internal errors
93
+ raise HTTPException(status_code=500, detail="Internal Server Error")
94
 
95
  @app.get("/health")
96
  def health_check():
models/__pycache__/base.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/base.cpython-310.pyc and b/models/__pycache__/base.cpython-310.pyc differ
 
models/base.py CHANGED
@@ -1,4 +1,5 @@
1
  from abc import ABC, abstractmethod
 
2
 
3
  class BaseModelWrapper(ABC):
4
  @abstractmethod
@@ -6,5 +7,5 @@ class BaseModelWrapper(ABC):
6
  pass
7
 
8
  @abstractmethod
9
- def predict(self, file_content: bytes, filename: str) -> dict:
10
  pass
 
1
  from abc import ABC, abstractmethod
2
+ from typing import Optional
3
 
4
  class BaseModelWrapper(ABC):
5
  @abstractmethod
 
7
  pass
8
 
9
  @abstractmethod
10
+ def predict(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None) -> dict:
11
  pass
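With the new signature, a concrete wrapper only needs to accept the optional `audio_content` argument. A hypothetical minimal subclass (the abstract base is restated here so the sketch is self-contained):

```python
from abc import ABC, abstractmethod
from typing import Optional

class BaseModelWrapper(ABC):
    @abstractmethod
    def load(self): ...

    @abstractmethod
    def predict(self, file_content: bytes, filename: str,
                audio_content: Optional[bytes] = None) -> dict: ...

class EchoWrapper(BaseModelWrapper):
    """Hypothetical minimal wrapper illustrating the optional audio argument."""
    def load(self):
        pass  # a real wrapper would load weights/tokenizers here

    def predict(self, file_content, filename, audio_content=None):
        return {"filename": filename, "has_audio": audio_content is not None}

w = EchoWrapper()
w.load()
print(w.predict(b"*PAR: hello", "sample.cha"))
```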
models/model_v1/__pycache__/wrapper.cpython-310.pyc CHANGED
Binary files a/models/model_v1/__pycache__/wrapper.cpython-310.pyc and b/models/model_v1/__pycache__/wrapper.cpython-310.pyc differ
 
models/model_v3/config.py ADDED
@@ -0,0 +1,7 @@
1
+ import os
2
+
3
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
4
+ WEIGHTS_PATH = os.path.join(BASE_DIR, "weights", "multimodal_dementia_model.pth")
5
+ TEXT_MODEL_NAME = "microsoft/deberta-base"
6
+ MAX_LEN = 128
7
+ WHISPER_MODEL_SIZE = "base"
models/model_v3/model.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ from transformers import AutoModel
5
+
6
+ class TextBranch(nn.Module):
7
+ def __init__(self, model_name):
8
+ super().__init__()
9
+ self.bert = AutoModel.from_pretrained(model_name)
10
+ self.lstm = nn.LSTM(768, 128, batch_first=True, bidirectional=True)
11
+ self.fc = nn.Linear(256, 64)
12
+
13
+ def forward(self, input_ids, attention_mask):
14
+ out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
15
+ _, (h_n, _) = self.lstm(out.last_hidden_state)
16
+ context = torch.cat((h_n[-2], h_n[-1]), dim=1)
17
+ return self.fc(context)
18
+
19
+ class AudioBranch(nn.Module):
20
+ def __init__(self):
21
+ super().__init__()
22
+ # Pretrained ViT
23
+ self.vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
24
+ self.fc = nn.Linear(768, 64) # Project ViT dim to 64
25
+
26
+ def forward(self, pixel_values):
27
+ features = self.vit(pixel_values)
28
+ return self.fc(features)
29
+
30
+ class LinguisticBranch(nn.Module):
31
+ def __init__(self, input_dim=6):
32
+ super().__init__()
33
+ self.net = nn.Sequential(
34
+ nn.Linear(input_dim, 32),
35
+ nn.ReLU(),
36
+ nn.Linear(32, 16)
37
+ )
38
+ def forward(self, x):
39
+ return self.net(x)
40
+
41
+ class MultimodalFusion(nn.Module):
42
+ def __init__(self, text_model_name='microsoft/deberta-base'):
43
+ super().__init__()
44
+ self.text_branch = TextBranch(text_model_name)
45
+ self.audio_branch = AudioBranch()
46
+ self.ling_branch = LinguisticBranch(input_dim=6)
47
+
48
+ # Fusion: 64 (Text) + 64 (Audio) + 16 (Ling) = 144
49
+ self.classifier = nn.Sequential(
50
+ nn.Linear(64 + 64 + 16, 64),
51
+ nn.ReLU(),
52
+ nn.Dropout(0.5),
53
+ nn.Linear(64, 2)
54
+ )
55
+
56
+ def forward(self, input_ids, attention_mask, pixel_values, ling_features):
57
+ text_emb = self.text_branch(input_ids, attention_mask)
58
+ audio_emb = self.audio_branch(pixel_values)
59
+ ling_emb = self.ling_branch(ling_features)
60
+
61
+ # Concat
62
+ combined = torch.cat((text_emb, audio_emb, ling_emb), dim=1)
63
+ return self.classifier(combined)
models/model_v3/processor.py ADDED
@@ -0,0 +1,172 @@
1
+ import re
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ import librosa
6
+ import librosa.display
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ import whisper
12
+
13
+ # Force non-interactive backend for server environments
14
+ matplotlib.use('Agg')
15
+
16
+ # ==========================================
17
+ # 1. Linguistic Feature Extractor
18
+ # ==========================================
19
+ class LinguisticFeatureExtractor:
20
+ def __init__(self):
21
+ self.patterns = {
22
+ 'fillers': re.compile(r'&-([a-z]+)', re.IGNORECASE),
23
+ 'repetition': re.compile(r'\[/+\]'),
24
+ 'retracing': re.compile(r'\[//\]'),
25
+ 'incomplete': re.compile(r'\+[\./]+'),
26
+ 'errors': re.compile(r'\[\*.*?\]'),
27
+ 'pauses': re.compile(r'\(\.+\)')
28
+ }
29
+
30
+ def clean_for_bert(self, raw_text):
31
+ text = re.sub(r'^\*PAR:\s+', '', raw_text)
32
+ text = re.sub(r'\x15\d+_\d+\x15', '', text)
33
+ text = re.sub(r'[<>]', '', text)
34
+ text = re.sub(r'\[.*?\]', '', text)
35
+ text = re.sub(r'\(\.+\)', '[PAUSE]', text)
36
+ text = text.replace('_', ' ')
37
+ text = re.sub(r'\s+', ' ', text).strip()
38
+ if text.endswith('[PAUSE]'):
39
+ text = text[:-7].strip()
40
+ return text
41
+
42
+ def get_features(self, raw_text):
43
+ stats = {
44
+ 'filler_count': len(self.patterns['fillers'].findall(raw_text)),
45
+ 'repetition_count': len(self.patterns['repetition'].findall(raw_text)),
46
+ 'retracing_count': len(self.patterns['retracing'].findall(raw_text)),
47
+ 'incomplete_count': len(self.patterns['incomplete'].findall(raw_text)),
48
+ 'error_count': len(self.patterns['errors'].findall(raw_text)),
49
+ 'pause_count': len(self.patterns['pauses'].findall(raw_text))
50
+ }
51
+ clean_for_stats = re.sub(r'\[.*?\]', '', raw_text)
52
+ clean_for_stats = re.sub(r'&-([a-z]+)', '', clean_for_stats)
53
+ clean_for_stats = re.sub(r'[^\w\s]', '', clean_for_stats)
54
+ words = clean_for_stats.lower().split()
55
+ stats['word_count'] = len(words)
56
+ return stats
57
+
58
+ def get_feature_vector(self, raw_text):
59
+ stats = self.get_features(raw_text)
60
+ n = stats['word_count'] if stats['word_count'] > 0 else 1
61
+
62
+ # Calculate TTR (Type-Token Ratio)
63
+ clean_for_stats = re.sub(r'\[.*?\]', '', raw_text)
64
+ clean_for_stats = re.sub(r'&-([a-z]+)', '', clean_for_stats)
65
+ clean_for_stats = re.sub(r'[^\w\s]', '', clean_for_stats)
66
+ words = clean_for_stats.lower().split()
67
+ ttr = (len(set(words)) / n) if n > 0 else 0.0
68
+
69
+ return np.array([
70
+ ttr,
71
+ stats['filler_count'] / n,
72
+ stats['repetition_count'] / n,
73
+ stats['retracing_count'] / n,
74
+ stats['error_count'] / n,
75
+ stats['pause_count'] / n
76
+ ], dtype=np.float32)
77
+
78
+ # ==========================================
79
+ # 2. Audio Processor
80
+ # ==========================================
81
+ class AudioProcessor:
82
+ def __init__(self):
83
+ self.vit_transforms = transforms.Compose([
84
+ transforms.Resize((224, 224)),
85
+ transforms.ToTensor(),
86
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
87
+ ])
88
+
89
+ def create_spectrogram_tensor(self, audio_path, intervals=None):
90
+ """
91
+ Generates spectrogram image and transforms it to Tensor.
92
+ """
93
+ try:
94
+ fig = plt.figure(figsize=(2.24, 2.24), dpi=100)
95
+ ax = fig.add_subplot(1, 1, 1)
96
+ fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
97
+
98
+ if intervals:
99
+ # Load full audio then slice based on timestamps
100
+ y, sr = librosa.load(audio_path, sr=None)
101
+ clips = []
102
+ for start_ms, end_ms in intervals:
103
+ start_sample = int(start_ms * sr / 1000)
104
+ end_sample = int(end_ms * sr / 1000)
105
+ if end_sample > len(y): end_sample = len(y)
106
+ if start_sample < len(y):
107
+ clips.append(y[start_sample:end_sample])
108
+ if clips:
109
+ y = np.concatenate(clips)
110
+ else:
111
+ y = np.zeros(int(sr*30))
112
+
113
+ # Limit to 30s
114
+ if len(y) > 30 * sr:
115
+ y = y[:30 * sr]
116
+ else:
117
+ y, sr = librosa.load(audio_path, duration=30)
118
+
119
+ ms = librosa.feature.melspectrogram(y=y, sr=sr)
120
+ log_ms = librosa.power_to_db(ms, ref=np.max)
121
+ librosa.display.specshow(log_ms, sr=sr, ax=ax)
122
+
123
+ # Save to buffer instead of file
124
+ from io import BytesIO
125
+ buf = BytesIO()
126
+ fig.savefig(buf, format='png')
127
+ plt.close(fig)
128
+ buf.seek(0)
129
+
130
+ image = Image.open(buf).convert('RGB')
131
+ return self.vit_transforms(image).unsqueeze(0)
132
+
133
+ except Exception as e:
134
+ print(f"Spectrogram creation failed: {e}")
135
+ return torch.zeros((1, 3, 224, 224))
136
+
137
+ # ==========================================
138
+ # 3. ASR Helper (Whisper + CHAT Rules)
139
+ # ==========================================
140
+ def apply_chat_rules(transcription_result):
141
+ """
142
+ Converts Whisper result into CHAT-like format AND inserts [PAUSE] tokens.
143
+ """
144
+ formatted_text = []
145
+ segments = transcription_result.get('segments', [])
146
+ last_end = 0
147
+
148
+ for seg in segments:
149
+ gap = seg['start'] - last_end
150
+ # Insert [PAUSE] token + CHAT marker
151
+ if gap > 0.8:
152
+ formatted_text.append("[PAUSE] (..)")
153
+ elif gap > 0.3:
154
+ formatted_text.append("[PAUSE] (.)")
155
+
156
+ text = seg['text'].strip()
157
+
158
+ # Repetitions (Basic Detection)
159
+ words = text.split()
160
+ processed_words = []
161
+ for i, w in enumerate(words):
162
+ clean_w = re.sub(r'[^a-zA-Z]', '', w.lower())
163
+ if i > 0:
164
+ prev_clean = re.sub(r'[^a-zA-Z]', '', words[i-1].lower())
165
+ if clean_w == prev_clean and clean_w:
166
+ processed_words[-1] = f"{words[i-1]} [/]"
167
+ processed_words.append(w)
168
+
169
+ formatted_text.append(" ".join(processed_words))
170
+ last_end = seg['end']
171
+
172
+ return " ".join(formatted_text)
models/model_v3/wrapper.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import AutoTokenizer
4
+ from typing import Optional
5
+ import os
6
+ import tempfile
7
+ import whisper
8
+ import re
9
+
10
+ from models.base import BaseModelWrapper
11
+ from .model import MultimodalFusion
12
+ from .processor import LinguisticFeatureExtractor, AudioProcessor, apply_chat_rules
13
+ from .config import WEIGHTS_PATH, TEXT_MODEL_NAME, MAX_LEN
14
+
15
+ class MultimodalWrapper(BaseModelWrapper):
16
+ def __init__(self):
17
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ self.model = None
19
+ self.tokenizer = None
20
+ self.asr_model = None
21
+ self.ling_extractor = LinguisticFeatureExtractor()
22
+ self.audio_processor = AudioProcessor()
23
+
24
+ def load(self):
25
+ print("Loading Model V3 components...")
26
+ self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
27
+ self.model = MultimodalFusion(TEXT_MODEL_NAME)
28
+
29
+ # Load Weights
30
+ if torch.cuda.is_available():
31
+ state_dict = torch.load(WEIGHTS_PATH)
32
+ else:
33
+ state_dict = torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'))
34
+ self.model.load_state_dict(state_dict)
35
+ self.model.to(self.device)
36
+ self.model.eval()
37
+
38
+ # Load Whisper (Base model as per notebook)
39
+ print("Loading Whisper for Audio-Only Inference...")
40
+ self.asr_model = whisper.load_model("base")
41
+
42
+ def predict(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None) -> dict:
43
+ """
44
+ Handles 3 scenarios:
45
+ 1. CHA only: file_content is CHA.
46
+ 2. CHA + Audio: file_content is CHA, audio_content is Audio.
47
+ 3. Audio only: file_content is likely empty/dummy, audio_content is Audio.
48
+ """
49
+
50
+ # Determine Scenario
51
+ is_cha_provided = filename.endswith('.cha') and len(file_content) > 0
52
+ has_audio = audio_content is not None and len(audio_content) > 0
53
+
54
+ processed_text = ""
55
+ ling_features = None
56
+ audio_tensor = None
57
+ intervals = []
58
+
59
+ # --- SCENARIO 3: PURE AUDIO (New file, generate transcript) ---
60
+ if not is_cha_provided and has_audio:
61
+ print("Processing Mode: Audio Only (ASR)")
62
+
63
+ # Save audio to temp file for Whisper/Librosa
64
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
65
+ tmp_audio.write(audio_content)
66
+ tmp_path = tmp_audio.name
67
+
68
+ try:
69
+ # 1. Transcribe
70
+ result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
71
+ # 2. Apply Rules
72
+ chat_transcript = apply_chat_rules(result)
73
+ processed_text = chat_transcript # No BERT cleaning needed on Whisper output usually, or minimal
74
+
75
+ # 3. Extract Features from generated text
76
+ # We need to manually calculating stats like the ASR notebook section does
77
+ # because the ASR output doesn't have the exact same format as raw CHA
78
+ stats = self.ling_extractor.get_features(chat_transcript)
79
+ pause_count = chat_transcript.count("[PAUSE]")
80
+ repetition_count = chat_transcript.count("[/]")
81
+
82
+ # TTR Calc
83
+ clean_t = re.sub(r'\[.*?\]', '', chat_transcript)
84
+ clean_t = re.sub(r'[^\w\s]', '', clean_t)
85
+ words = clean_t.lower().split()
86
+ n = len(words) if len(words) > 0 else 1
87
+ ttr = len(set(words)) / n
88
+
89
+ ling_vec = [
90
+ ttr,
91
+ stats['filler_count'] / n,
92
+ repetition_count / n,
93
+ stats['retracing_count'] / n,
94
+ stats['error_count'] / n,
95
+ pause_count / n
96
+ ]
97
+ ling_features = torch.tensor(ling_vec, dtype=torch.float32).unsqueeze(0)
98
+
99
+ # 4. Generate Spectrogram (Whole file, no intervals)
100
+ audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=None)
101
+
102
+ finally:
103
+ os.remove(tmp_path)
104
+
105
+ # --- SCENARIO 1 & 2: CHA FILE PROVIDED ---
106
+ else:
107
+ # Parse Text from CHA
108
+ text_str = file_content.decode('utf-8', errors='replace')
109
+ par_lines = []
110
+
111
+ # Regex to find timestamps: 123_456
112
+ # Matches functionality in 'load_and_process_data' -> 'process_dir'
113
+ full_text_for_intervals = ""
114
+
115
+ for line in text_str.splitlines():
116
+ if line.startswith('*PAR:'):
117
+ content = line[5:].strip()
118
+ par_lines.append(content)
119
+ full_text_for_intervals += content + " "
120
+
121
+ raw_text = " ".join(par_lines)
122
+ processed_text = self.ling_extractor.clean_for_bert(raw_text)
123
+
124
+ # Extract Features
125
+ feats = self.ling_extractor.get_feature_vector(raw_text)
126
+ ling_features = torch.tensor(feats, dtype=torch.float32).unsqueeze(0)
127
+
128
+ # --- SCENARIO 2: CHA + AUDIO (Segmentation) ---
129
+ if has_audio:
130
+ print("Processing Mode: CHA + Audio (Segmentation)")
131
+ # Extract intervals from the raw text (containing the bullets)
132
+ # Notebook regex: re.findall(r'\x15(\d+)_(\d+)\x15', text_content)
133
+ found_intervals = re.findall(r'\x15(\d+)_(\d+)\x15', full_text_for_intervals)
134
+ intervals = [(int(s), int(e)) for s, e in found_intervals]
135
+
136
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
137
+ tmp_audio.write(audio_content)
138
+ tmp_path = tmp_audio.name
139
+
140
+ try:
141
+ # Pass intervals to slice specific PAR audio
142
+ audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
143
+ finally:
144
+ os.remove(tmp_path)
145
+
146
+ # --- SCENARIO 1: CHA ONLY ---
147
+ else:
148
+ print("Processing Mode: CHA Only")
149
+ audio_tensor = torch.zeros((1, 3, 224, 224))
150
+
151
+ # --- COMMON INFERENCE STEPS ---
152
+ encoding = self.tokenizer.encode_plus(
153
+ processed_text,
154
+ add_special_tokens=True,
155
+ max_length=MAX_LEN,
156
+ padding='max_length',
157
+ truncation=True,
158
+ return_attention_mask=True,
159
+ return_tensors='pt'
160
+ )
161
+
162
+ with torch.no_grad():
163
+ input_ids = encoding['input_ids'].to(self.device)
164
+ mask = encoding['attention_mask'].to(self.device)
165
+ pixel_values = audio_tensor.to(self.device)
166
+ ling_input = ling_features.to(self.device)
167
+
168
+ outputs = self.model(input_ids, mask, pixel_values, ling_input)
169
+ probs = F.softmax(outputs, dim=1)
170
+ pred_idx = torch.argmax(probs, dim=1).item()
171
+ confidence = probs[0][pred_idx].item()
172
+
173
+ label_map = {0: 'Control', 1: 'AD'}
174
+
175
+ return {
176
+ "model_version": "v3_multimodal",
177
+ "filename": filename if filename else "audio_only_upload",
178
+ "predicted_label": label_map[pred_idx],
179
+ "confidence": round(confidence, 4),
180
+ "modalities_used": ["text", "linguistic"] + (["audio"] if has_audio else []),
181
+ "generated_transcript": processed_text if not is_cha_provided else None
182
+ }