Spaces:
Sleeping
Sleeping
File size: 12,475 Bytes
d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 e824b96 d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 e824b96 d5f8ae0 e824b96 d5f8ae0 f141fc2 e824b96 f141fc2 e824b96 d5f8ae0 e824b96 d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 e824b96 d5f8ae0 e824b96 f141fc2 d5f8ae0 e824b96 d5f8ae0 f141fc2 d5f8ae0 e824b96 d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 f141fc2 d5f8ae0 f141fc2 4fcfcbc f141fc2 4fcfcbc f141fc2 d5f8ae0 f141fc2 d5f8ae0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 | import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from typing import Optional
import os
import tempfile
import whisper
import re
import traceback
from models.base import BaseModelWrapper
from .model import MultimodalFusion
from .processor import LinguisticFeatureExtractor, AudioProcessor, apply_chat_rules, parse_segmentation_csv
from .config import HF_REPO_ID, WEIGHTS_FILENAME, LOCAL_WEIGHTS_PATH, TEXT_MODEL_NAME, MAX_LEN, SEGMENTATION_ROOT_PATH
def find_segmentation_file(audio_filename: str) -> Optional[bytes]:
    """
    Locate the segmentation CSV that matches an audio file's base name.

    Checks SEGMENTATION_ROOT_PATH/AD/<name>.csv and
    SEGMENTATION_ROOT_PATH/Control/<name>.csv, in that order.

    Returns the raw CSV bytes if a match exists, otherwise None.
    """
    root = SEGMENTATION_ROOT_PATH
    if not root or not os.path.exists(root):
        return None

    # Strip directory and extension to get the recording's base name.
    stem = os.path.splitext(os.path.basename(audio_filename))[0]

    for folder in ('AD', 'Control'):
        candidate = os.path.join(root, folder, f"{stem}.csv")
        if not os.path.exists(candidate):
            continue
        print(f" Found segmentation file: {candidate}")
        with open(candidate, 'rb') as handle:
            return handle.read()
    return None
class MultimodalWrapper(BaseModelWrapper):
    """Inference wrapper for the V3 multimodal (text + audio + linguistic) fusion model.

    Fuses three inputs for AD-vs-Control classification:
      * tokenized transcript text (BERT-style encoder, TEXT_MODEL_NAME),
      * a spectrogram tensor derived from participant audio,
      * a 6-dim hand-crafted linguistic feature vector.

    Heavy components are loaded lazily via load(); predict() dispatches over
    four input scenarios (CHA only, CHA+audio, audio+CSV, audio only).
    """

    def __init__(self):
        # Model / tokenizer / Whisper stay None until load() is called.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self.asr_model = None
        self.ling_extractor = LinguisticFeatureExtractor()
        self.audio_processor = AudioProcessor()

    def load(self):
        """Load tokenizer, fusion-model weights and the Whisper ASR model.

        Weight resolution order: LOCAL_WEIGHTS_PATH on disk first, then a
        download from the Hugging Face Hub (HF_REPO_ID / WEIGHTS_FILENAME).

        Raises:
            FileNotFoundError: if weights are neither on disk nor downloadable.
        """
        print("Loading Model V3 components...")
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
        self.model = MultimodalFusion(TEXT_MODEL_NAME)

        # Load Weights - Try local first, then Hugging Face
        if os.path.exists(LOCAL_WEIGHTS_PATH):
            weights_path = LOCAL_WEIGHTS_PATH
            print(f"Loading weights from local: {weights_path}")
        else:
            try:
                print(f"Downloading weights from Hugging Face: {HF_REPO_ID}")
                weights_path = hf_hub_download(
                    repo_id=HF_REPO_ID,
                    filename=WEIGHTS_FILENAME
                )
            except Exception as e:
                # Chain the original error so the download failure is visible.
                raise FileNotFoundError(f"Model weights not found locally or on Hugging Face: {e}") from e

        # NOTE(review): consider torch.load(..., weights_only=True) on torch>=2.0
        # since only a state dict is expected — confirm torch version first.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Load Whisper (Base model as per notebook)
        print("Loading Whisper for Audio-Only Inference...")
        self.asr_model = whisper.load_model("base")

    def predict(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None, segmentation_content: Optional[bytes] = None) -> dict:
        """
        Handles 4 scenarios with mode-specific visualizations:
        1. CHA only: Text + Linguistic features, zero audio
        2. CHA + Audio: Text from CHA, timestamps from CHA for slicing
        3. Audio + Segmentation CSV: Slice audio using CSV intervals, Whisper ASR
        4. Audio only (participant-only): Full audio with Whisper ASR

        Thin error-logging wrapper: logs any failure with a traceback, then
        re-raises so the caller (API layer) can surface it.
        """
        try:
            return self._predict_internal(file_content, filename, audio_content, segmentation_content)
        except Exception as e:
            print(f"[Model V3 ERROR] {e}")
            traceback.print_exc()
            raise

    def _predict_internal(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None, segmentation_content: Optional[bytes] = None) -> dict:
        """Core pipeline behind predict(); see predict() for the scenario matrix.

        Returns a dict with predicted label, confidence, modalities used and
        text-based visualizations.
        """
        # Determine Scenario
        is_cha_provided = filename.endswith('.cha') and len(file_content) > 0
        has_audio = audio_content is not None and len(audio_content) > 0

        processed_text = ""
        raw_text_for_segments = ""   # un-cleaned text fed to key-segment extraction
        ling_features = None         # torch tensor of shape (1, 6) for the fusion model
        ling_vec = None              # plain-Python copy of the linguistic features
        audio_tensor = None          # spectrogram tensor; zeros placeholder when no audio
        intervals = []               # (start_ms, end_ms) PAR speech intervals
        tmp_path = None

        # --- MODE 3 & 4: AUDIO WITHOUT CHA (generate transcript via Whisper) ---
        if not is_cha_provided and has_audio:
            # Save audio to temp file for Whisper/Librosa
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                tmp_audio.write(audio_content)
                tmp_path = tmp_audio.name
            try:
                # Try to get segmentation: 1) User provided, 2) Auto-discover from dataset
                seg_content = segmentation_content
                if not seg_content:
                    seg_content = find_segmentation_file(filename)

                # MODE 3: Audio + Segmentation CSV (8.2 style)
                if seg_content:
                    print("Processing Mode: Audio + Segmentation CSV (8.2 style)")
                    intervals = parse_segmentation_csv(seg_content)
                    print(f" Found {len(intervals)} PAR intervals from CSV")
                    # Slice audio and create spectrogram with intervals
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                    # For transcription, we transcribe the full audio and let Whisper handle it
                    # (In production, you could slice audio first, but Whisper handles full file well)
                    result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                    chat_transcript = apply_chat_rules(result)
                # MODE 4: Audio only - participant-only audio (8.3 style)
                else:
                    print("Processing Mode: Audio Only - Participant Audio (8.3 style)")
                    # No slicing needed - assume entire audio is participant
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=None)
                    # Transcribe full audio
                    result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                    chat_transcript = apply_chat_rules(result)

                processed_text = chat_transcript
                raw_text_for_segments = chat_transcript

                # Extract Features from generated text
                stats = self.ling_extractor.get_features(chat_transcript)
                pause_count = chat_transcript.count("[PAUSE]")
                repetition_count = chat_transcript.count("[/]")

                # Type-token ratio: strip CHAT markers and punctuation first.
                clean_t = re.sub(r'\[.*?\]', '', chat_transcript)
                clean_t = re.sub(r'[^\w\s]', '', clean_t)
                words = clean_t.lower().split()
                n = len(words) if len(words) > 0 else 1   # guard divide-by-zero on empty ASR output
                ttr = len(set(words)) / n

                # Feature order must match training: [TTR, fillers, repetitions,
                # retracings, errors, pauses], each normalized per word.
                ling_vec = [
                    ttr,
                    stats['filler_count'] / n,
                    repetition_count / n,
                    stats['retracing_count'] / n,
                    stats['error_count'] / n,
                    pause_count / n
                ]
                ling_features = torch.tensor(ling_vec, dtype=torch.float32).unsqueeze(0)
            finally:
                # Always remove the temp wav, even if ASR/feature extraction fails.
                if tmp_path and os.path.exists(tmp_path):
                    os.remove(tmp_path)

        # --- SCENARIO 1 & 2: CHA FILE PROVIDED ---
        else:
            # Parse Text from CHA: keep only participant (*PAR:) utterance lines.
            text_str = file_content.decode('utf-8', errors='replace')
            par_lines = []
            full_text_for_intervals = ""
            for line in text_str.splitlines():
                if line.startswith('*PAR:'):
                    content = line[5:].strip()
                    par_lines.append(content)
                    full_text_for_intervals += content + " "

            raw_text = " ".join(par_lines)
            raw_text_for_segments = raw_text
            processed_text = self.ling_extractor.clean_for_bert(raw_text)

            # Extract Features
            feats = self.ling_extractor.get_feature_vector(raw_text)
            ling_vec = feats.tolist()
            ling_features = torch.tensor(feats, dtype=torch.float32).unsqueeze(0)

            # --- SCENARIO 2: CHA + AUDIO (Segmentation) ---
            if has_audio:
                print("Processing Mode: CHA + Audio (Segmentation)")
                # CHAT time-alignment markers: \x15<start>_<end>\x15 (milliseconds).
                found_intervals = re.findall(r'\x15(\d+)_(\d+)\x15', full_text_for_intervals)
                intervals = [(int(s), int(e)) for s, e in found_intervals]

                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                    tmp_audio.write(audio_content)
                    tmp_path = tmp_audio.name
                try:
                    # Pass intervals to slice specific PAR audio.
                    # (The base64 spectrogram image was dropped from the response,
                    # so it is no longer computed here — it was dead work.)
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                finally:
                    if tmp_path and os.path.exists(tmp_path):
                        os.remove(tmp_path)
            # --- SCENARIO 1: CHA ONLY ---
            else:
                print("Processing Mode: CHA Only")
                # Silent placeholder so the audio branch still receives valid input.
                audio_tensor = torch.zeros((1, 3, 224, 224))

        # --- COMMON INFERENCE STEPS ---
        encoding = self.tokenizer.encode_plus(
            processed_text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        with torch.no_grad():
            input_ids = encoding['input_ids'].to(self.device)
            mask = encoding['attention_mask'].to(self.device)
            pixel_values = audio_tensor.to(self.device)
            ling_input = ling_features.to(self.device)

            outputs = self.model(input_ids, mask, pixel_values, ling_input)
            probs = F.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_idx].item()
            # Class index convention: 0 = Control, 1 = AD.
            prob_ad = probs[0][1].item()
            prob_control = probs[0][0].item()

        label_map = {0: 'Control', 1: 'AD'}

        # --- BUILD VISUALIZATIONS ---
        visualizations = {
            "probabilities": {
                "AD": round(prob_ad, 4),
                "Control": round(prob_control, 4)
            }
        }

        # Key Contribution Segments (from text analysis)
        # Denser color = more contribution to prediction
        raw_key_segments = self.ling_extractor.extract_key_segments(raw_text_for_segments)
        if raw_key_segments:
            # Transform to contribution-based format (remove markers, use normalized score)
            max_count = max(seg['marker_count'] for seg in raw_key_segments)
            key_segments = [
                {
                    "text": seg['text'],
                    "contribution_score": round(seg['marker_count'] / max(max_count, 1), 2)
                }
                for seg in raw_key_segments
            ]
            visualizations["key_contribution_segments"] = {
                "note": "Denser color indicates higher contribution to prediction",
                "segments": key_segments
            }

        # Audio-specific visualizations removed as per requirements
        return {
            "model_version": "v3_multimodal",
            "filename": filename if filename else "audio_upload",
            "predicted_label": label_map[pred_idx],
            "confidence": round(confidence, 4),
            "modalities_used": ["text", "linguistic"] + (["audio"] if has_audio else []),
            "visualizations": visualizations
        }