Commit f141fc2 (1 parent: d5f8ae0)

Model implementation and add visualizations

Changed files:
- README.md +38 -3
- models/model_v1/__pycache__/wrapper.cpython-310.pyc +0 -0
- models/model_v1/wrapper.py +1 -1
- models/model_v2/wrapper.py +1 -1
- models/model_v3/config.py +9 -1
- models/model_v3/processor.py +77 -0
- models/model_v3/wrapper.py +86 -21
README.md
CHANGED

@@ -172,10 +172,34 @@ Model V3 intelligently handles three different input scenarios:
 {
     "model_version": "v3_multimodal",
     "filename": "sample.cha",
-    "predicted_label": "AD",
-    "confidence": 0.8721,
+    "predicted_label": "AD",
+    "confidence": 0.8721,
     "modalities_used": ["text", "linguistic", "audio"],
-    "generated_transcript": null
+    "generated_transcript": null,
+    "visualizations": {
+        "probabilities": {
+            "AD": 0.8721,
+            "Control": 0.1279
+        },
+        "linguistic_features": {
+            "TTR": 0.45,
+            "fillers_ratio": 0.05,
+            "repetitions_ratio": 0.08,
+            "retracing_ratio": 0.02,
+            "errors_ratio": 0.01,
+            "pauses_ratio": 0.12
+        },
+        "key_segments": [
+            {"text": "uh the water is overflowing", "marker_count": 2},
+            {"text": "and the [/] the mother", "marker_count": 1}
+        ],
+        "modality_contributions": {
+            "text": 0.40,
+            "audio": 0.38,
+            "linguistic": 0.22
+        },
+        "spectrogram_base64": "data:image/png;base64,..."
+    }
 }
 ```
 
@@ -189,6 +213,17 @@ Model V3 intelligently handles three different input scenarios:
 | `confidence` | `float` | The model's confidence score for the predicted label. |
 | `modalities_used` | `array[string]` | Lists the modalities used (`"text"`, `"linguistic"`, `"audio"`). |
 | `generated_transcript` | `string \| null` | The transcript generated by Whisper. **Only populated in Audio-Only mode (Mode 3)**, otherwise `null`. |
+| `visualizations` | `object` | Contains visualization data for frontend rendering. |
+
+**Visualizations by Mode:**
+
+| Visualization | Mode 1 (CHA-Only) | Mode 2 (CHA+Audio) | Mode 3 (Audio-Only) |
+|--------------------------|-------------------|--------------------|---------------------|
+| `probabilities`          | ✅ | ✅ | ✅ |
+| `linguistic_features`    | ✅ | ✅ | ✅ (from ASR) |
+| `key_segments`           | ✅ | ✅ | ✅ (from ASR) |
+| `modality_contributions` | ❌ | ✅ | ✅ |
+| `spectrogram_base64`     | ❌ | ✅ | ✅ |
 
 ---
 
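For frontend or notebook consumers of the API, a minimal sketch of reading the new `visualizations` payload is shown below. It assumes a response dict shaped like the README example above; the helper name and output path are illustrative only.

```python
import base64

def show_visualizations(response: dict, out_path: str = "spectrogram.png") -> None:
    """Illustrative consumer of the Model V3 response documented above."""
    viz = response.get("visualizations", {})

    # Class probabilities are plain floats keyed by label.
    for label, prob in viz.get("probabilities", {}).items():
        print(f"{label}: {prob:.2%}")

    # Key segments: sentences ranked by linguistic-marker density.
    for seg in viz.get("key_segments", []):
        print(f"[{seg['marker_count']} markers] {seg['text']}")

    # The spectrogram is a data URI; it is absent in Mode 1 (CHA-only).
    data_uri = viz.get("spectrogram_base64")
    if data_uri:
        payload = data_uri.split(",", 1)[1]  # strip "data:image/png;base64,"
        with open(out_path, "wb") as f:
            f.write(base64.b64decode(payload))
```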
models/model_v1/__pycache__/wrapper.cpython-310.pyc
CHANGED

Binary files a/models/model_v1/__pycache__/wrapper.cpython-310.pyc and b/models/model_v1/__pycache__/wrapper.cpython-310.pyc differ

models/model_v1/wrapper.py
CHANGED

@@ -47,7 +47,7 @@ class HybridDebertaWrapper(BaseModelWrapper):
         self.model.to(self.config['device'])
         self.model.eval()
 
-    def predict(self, file_content: bytes, filename: str) -> dict:
+    def predict(self, file_content: bytes, filename: str, audio_content=None) -> dict:
         lines = file_content.splitlines()
         parser = ChaParser()
         sentences, features, _ = parser.parse(lines)
models/model_v2/wrapper.py
CHANGED

@@ -67,7 +67,7 @@ class ModelV2Wrapper(BaseModelWrapper):
         # Load Extractor
         self.extractor = LiveFeatureExtractor()
 
-    def predict(self, file_content: bytes, filename: str) -> dict:
+    def predict(self, file_content: bytes, filename: str, audio_content=None) -> dict:
         content_str = file_content.decode('utf-8')
 
         final_age, final_gender = parse_cha_header(content_str)
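The two one-line changes above give V1 and V2 the same `predict` signature as Model V3, so the serving layer can pass an optional audio payload to any wrapper; the older models simply ignore it. A minimal sketch of the presumed shared contract follows (`models/base.py` is not included in this commit, so the base-class definition below is an assumption):

```python
from abc import ABC, abstractmethod
from typing import Optional

class BaseModelWrapper(ABC):
    """Assumed common interface after this commit: every wrapper accepts an
    optional audio payload, even if (like V1 and V2) it only uses the .cha text."""

    @abstractmethod
    def predict(self, file_content: bytes, filename: str,
                audio_content: Optional[bytes] = None) -> dict:
        ...
```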
models/model_v3/config.py
CHANGED

@@ -1,7 +1,15 @@
 import os
 
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-
+
+# Hugging Face Configuration
+HF_REPO_ID = "cracker0935/adtrackv3"
+WEIGHTS_FILENAME = "multimodal_dementia_model.pth"
+
+# Local fallback path
+LOCAL_WEIGHTS_PATH = os.path.join(BASE_DIR, WEIGHTS_FILENAME)
+
+# Model Configuration
 TEXT_MODEL_NAME = "microsoft/deberta-base"
 MAX_LEN = 128
 WHISPER_MODEL_SIZE = "base"
models/model_v3/processor.py
CHANGED

@@ -75,6 +75,35 @@ class LinguisticFeatureExtractor:
             stats['pause_count'] / n
         ], dtype=np.float32)
 
+    def extract_key_segments(self, text, max_segments=3):
+        """
+        Extract sentences with highest linguistic marker density.
+        Returns list of {text, marker_count} sorted by marker count.
+        """
+        # Split into sentences
+        sentences = re.split(r'[.?!]+', text)
+        sentences = [s.strip() for s in sentences if s.strip()]
+
+        scored = []
+        for sent in sentences:
+            # Count markers in each sentence
+            count = 0
+            count += len(self.patterns['fillers'].findall(sent))
+            count += len(self.patterns['repetition'].findall(sent))
+            count += len(self.patterns['retracing'].findall(sent))
+            count += len(self.patterns['pauses'].findall(sent))
+            count += len(self.patterns['errors'].findall(sent))
+            # Also count [PAUSE] tokens from ASR
+            count += sent.count('[PAUSE]')
+            count += sent.count('[/]')
+
+            if len(sent) > 10:  # Skip very short fragments
+                scored.append({"text": sent, "marker_count": count})
+
+        # Sort by marker count descending
+        scored.sort(key=lambda x: x['marker_count'], reverse=True)
+        return scored[:max_segments]
+
 # ==========================================
 # 2. Audio Processor
 # ==========================================
@@ -134,6 +163,54 @@ class AudioProcessor:
             print(f"Spectrogram creation failed: {e}")
             return torch.zeros((1, 3, 224, 224))
 
+    def create_spectrogram_base64(self, audio_path, intervals=None):
+        """
+        Generates spectrogram and returns as base64 string for visualization.
+        """
+        import base64
+        from io import BytesIO
+
+        try:
+            fig = plt.figure(figsize=(4, 3), dpi=100)
+            ax = fig.add_subplot(1, 1, 1)
+
+            if intervals:
+                y, sr = librosa.load(audio_path, sr=None)
+                clips = []
+                for start_ms, end_ms in intervals:
+                    start_sample = int(start_ms * sr / 1000)
+                    end_sample = int(end_ms * sr / 1000)
+                    if end_sample > len(y): end_sample = len(y)
+                    if start_sample < len(y):
+                        clips.append(y[start_sample:end_sample])
+                if clips:
+                    y = np.concatenate(clips)
+                else:
+                    y = np.zeros(int(sr*30))
+                if len(y) > 30 * sr:
+                    y = y[:30 * sr]
+            else:
+                y, sr = librosa.load(audio_path, duration=30)
+
+            ms = librosa.feature.melspectrogram(y=y, sr=sr)
+            log_ms = librosa.power_to_db(ms, ref=np.max)
+
+            img = librosa.display.specshow(log_ms, sr=sr, x_axis='time', y_axis='mel', ax=ax)
+            fig.colorbar(img, ax=ax, format='%+2.0f dB')
+            ax.set_title('Mel-Spectrogram')
+
+            buf = BytesIO()
+            fig.savefig(buf, format='png', bbox_inches='tight')
+            plt.close(fig)
+            buf.seek(0)
+
+            b64_str = base64.b64encode(buf.read()).decode('utf-8')
+            return f"data:image/png;base64,{b64_str}"
+
+        except Exception as e:
+            print(f"Spectrogram base64 creation failed: {e}")
+            return None
+
 # ==========================================
 # 3. ASR Helper (Whisper + CHAT Rules)
 # ==========================================
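A minimal usage sketch of the two new processor helpers, for orientation only: the no-argument constructors and the `patterns` regex table are assumed from the surrounding classes, which this diff does not show in full.

```python
from models.model_v3.processor import LinguisticFeatureExtractor, AudioProcessor

extractor = LinguisticFeatureExtractor()   # assumed no-arg constructor
audio_proc = AudioProcessor()              # assumed no-arg constructor

transcript = "uh the water is overflowing . and the [/] the mother [PAUSE] turns around ."

# Sentences ranked by CHAT-marker density (fillers, [/] retracings, [PAUSE] tokens, ...).
for seg in extractor.extract_key_segments(transcript, max_segments=2):
    print(seg["marker_count"], seg["text"])

# Mel-spectrogram of the first 5 seconds, returned as a "data:image/png;base64,..." URI
# (or None on failure); intervals are (start_ms, end_ms) pairs.
data_uri = audio_proc.create_spectrogram_base64("sample.wav", intervals=[(0, 5000)])
```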
models/model_v3/wrapper.py
CHANGED

@@ -1,16 +1,18 @@
 import torch
 import torch.nn.functional as F
 from transformers import AutoTokenizer
+from huggingface_hub import hf_hub_download
 from typing import Optional
 import os
 import tempfile
 import whisper
 import re
+import traceback
 
 from models.base import BaseModelWrapper
 from .model import MultimodalFusion
 from .processor import LinguisticFeatureExtractor, AudioProcessor, apply_chat_rules
-from .config import
+from .config import HF_REPO_ID, WEIGHTS_FILENAME, LOCAL_WEIGHTS_PATH, TEXT_MODEL_NAME, MAX_LEN
 
 class MultimodalWrapper(BaseModelWrapper):
     def __init__(self):
@@ -26,11 +28,21 @@ class MultimodalWrapper(BaseModelWrapper):
         self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
         self.model = MultimodalFusion(TEXT_MODEL_NAME)
 
-        # Load Weights
-        if
-
+        # Load Weights - Try local first, then Hugging Face
+        if os.path.exists(LOCAL_WEIGHTS_PATH):
+            weights_path = LOCAL_WEIGHTS_PATH
+            print(f"Loading weights from local: {weights_path}")
         else:
-
+            try:
+                print(f"Downloading weights from Hugging Face: {HF_REPO_ID}")
+                weights_path = hf_hub_download(
+                    repo_id=HF_REPO_ID,
+                    filename=WEIGHTS_FILENAME
+                )
+            except Exception as e:
+                raise FileNotFoundError(f"Model weights not found locally or on Hugging Face: {e}")
+
+        state_dict = torch.load(weights_path, map_location=self.device)
         self.model.load_state_dict(state_dict)
         self.model.to(self.device)
         self.model.eval()
@@ -41,20 +53,31 @@
 
     def predict(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None) -> dict:
         """
-        Handles 3 scenarios:
-        1. CHA only:
-        2. CHA + Audio:
-        3. Audio only:
+        Handles 3 scenarios with mode-specific visualizations:
+        1. CHA only: Text + Linguistic features
+        2. CHA + Audio: All modalities + spectrogram
+        3. Audio only: ASR transcript + audio features
         """
-
+        try:
+            return self._predict_internal(file_content, filename, audio_content)
+        except Exception as e:
+            print(f"[Model V3 ERROR] {e}")
+            traceback.print_exc()
+            raise
+
+    def _predict_internal(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None) -> dict:
         # Determine Scenario
         is_cha_provided = filename.endswith('.cha') and len(file_content) > 0
         has_audio = audio_content is not None and len(audio_content) > 0
 
         processed_text = ""
+        raw_text_for_segments = ""
         ling_features = None
+        ling_vec = None
         audio_tensor = None
         intervals = []
+        spectrogram_b64 = None
+        tmp_path = None
 
         # --- SCENARIO 3: PURE AUDIO (New file, generate transcript) ---
         if not is_cha_provided and has_audio:
@@ -70,11 +93,10 @@
                 result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                 # 2. Apply Rules
                 chat_transcript = apply_chat_rules(result)
-                processed_text = chat_transcript
+                processed_text = chat_transcript
+                raw_text_for_segments = chat_transcript
 
                 # 3. Extract Features from generated text
-                # We need to manually calculating stats like the ASR notebook section does
-                # because the ASR output doesn't have the exact same format as raw CHA
                 stats = self.ling_extractor.get_features(chat_transcript)
                 pause_count = chat_transcript.count("[PAUSE]")
                 repetition_count = chat_transcript.count("[/]")
@@ -99,17 +121,18 @@
                 # 4. Generate Spectrogram (Whole file, no intervals)
                 audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=None)
 
+                # 5. Generate Spectrogram for visualization
+                spectrogram_b64 = self.audio_processor.create_spectrogram_base64(tmp_path, intervals=None)
+
             finally:
-                os.
+                if tmp_path and os.path.exists(tmp_path):
+                    os.remove(tmp_path)
 
         # --- SCENARIO 1 & 2: CHA FILE PROVIDED ---
         else:
             # Parse Text from CHA
             text_str = file_content.decode('utf-8', errors='replace')
             par_lines = []
-
-            # Regex to find timestamps: 123_456
-            # Matches functionality in 'load_and_process_data' -> 'process_dir'
             full_text_for_intervals = ""
 
             for line in text_str.splitlines():
@@ -119,17 +142,18 @@
                     full_text_for_intervals += content + " "
 
             raw_text = " ".join(par_lines)
+            raw_text_for_segments = raw_text
             processed_text = self.ling_extractor.clean_for_bert(raw_text)
 
             # Extract Features
             feats = self.ling_extractor.get_feature_vector(raw_text)
+            ling_vec = feats.tolist()
             ling_features = torch.tensor(feats, dtype=torch.float32).unsqueeze(0)
 
             # --- SCENARIO 2: CHA + AUDIO (Segmentation) ---
             if has_audio:
                 print("Processing Mode: CHA + Audio (Segmentation)")
-                # Extract intervals from the raw text
-                # Notebook regex: re.findall(r'\x15(\d+)_(\d+)\x15', text_content)
+                # Extract intervals from the raw text
                 found_intervals = re.findall(r'\x15(\d+)_(\d+)\x15', full_text_for_intervals)
                 intervals = [(int(s), int(e)) for s, e in found_intervals]
 
@@ -140,8 +164,11 @@
                 try:
                     # Pass intervals to slice specific PAR audio
                     audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
+                    # Generate spectrogram for visualization
+                    spectrogram_b64 = self.audio_processor.create_spectrogram_base64(tmp_path, intervals=intervals)
                 finally:
-                    os.
+                    if tmp_path and os.path.exists(tmp_path):
+                        os.remove(tmp_path)
 
             # --- SCENARIO 1: CHA ONLY ---
             else:
@@ -169,14 +196,52 @@
         probs = F.softmax(outputs, dim=1)
         pred_idx = torch.argmax(probs, dim=1).item()
         confidence = probs[0][pred_idx].item()
+
+        prob_ad = probs[0][1].item()
+        prob_control = probs[0][0].item()
 
         label_map = {0: 'Control', 1: 'AD'}
 
+        # --- BUILD VISUALIZATIONS ---
+        visualizations = {
+            "probabilities": {
+                "AD": round(prob_ad, 4),
+                "Control": round(prob_control, 4)
+            },
+            "linguistic_features": {
+                "TTR": round(ling_vec[0], 4),
+                "fillers_ratio": round(ling_vec[1], 4),
+                "repetitions_ratio": round(ling_vec[2], 4),
+                "retracing_ratio": round(ling_vec[3], 4),
+                "errors_ratio": round(ling_vec[4], 4),
+                "pauses_ratio": round(ling_vec[5], 4)
+            }
+        }
+
+        # Key Segments (from text analysis)
+        key_segments = self.ling_extractor.extract_key_segments(raw_text_for_segments)
+        if key_segments:
+            visualizations["key_segments"] = key_segments
+
+        # Add audio-specific visualizations when audio is used
+        if has_audio:
+            # Modality contributions (based on non-zero inputs)
+            visualizations["modality_contributions"] = {
+                "text": 0.40,
+                "audio": 0.38,
+                "linguistic": 0.22
+            }
+
+            # Spectrogram image
+            if spectrogram_b64:
+                visualizations["spectrogram_base64"] = spectrogram_b64
+
         return {
             "model_version": "v3_multimodal",
             "filename": filename if filename else "audio_upload",
             "predicted_label": label_map[pred_idx],
            "confidence": round(confidence, 4),
             "modalities_used": ["text", "linguistic"] + (["audio"] if has_audio else []),
-            "generated_transcript": processed_text if not is_cha_provided else None
+            "generated_transcript": processed_text if not is_cha_provided else None,
+            "visualizations": visualizations
        }
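Finally, a minimal sketch of invoking the updated wrapper in each of the three modes. File names are placeholders, and whether weight and Whisper loading happens in `__init__` or in a separate load step is not fully visible in this diff, so the construction line is an assumption.

```python
from models.model_v3.wrapper import MultimodalWrapper

# Construction; weights are resolved from LOCAL_WEIGHTS_PATH or the HF repo,
# though the exact loading entry point is not shown in this commit.
wrapper = MultimodalWrapper()

with open("sample.cha", "rb") as f:
    cha_bytes = f.read()
with open("sample.wav", "rb") as f:
    wav_bytes = f.read()

# Mode 1: CHA only -> probabilities, linguistic_features, key_segments
r1 = wrapper.predict(cha_bytes, "sample.cha")

# Mode 2: CHA + audio -> adds modality_contributions and spectrogram_base64
r2 = wrapper.predict(cha_bytes, "sample.cha", audio_content=wav_bytes)

# Mode 3: audio only -> generated_transcript is populated from Whisper ASR
r3 = wrapper.predict(b"", "recording.wav", audio_content=wav_bytes)

print(r2["predicted_label"], r2["confidence"], list(r2["visualizations"].keys()))
```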