# Source: adtrack-v2 / models / model_v3 / wrapper.py
# Hugging Face commit e824b96 (cracker0935): "add mode to model 3"
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from typing import Optional
import os
import tempfile
import whisper
import re
import traceback
from models.base import BaseModelWrapper
from .model import MultimodalFusion
from .processor import LinguisticFeatureExtractor, AudioProcessor, apply_chat_rules, parse_segmentation_csv
from .config import HF_REPO_ID, WEIGHTS_FILENAME, LOCAL_WEIGHTS_PATH, TEXT_MODEL_NAME, MAX_LEN, SEGMENTATION_ROOT_PATH
def find_segmentation_file(audio_filename: str) -> Optional[bytes]:
    """Locate the segmentation CSV that matches an audio file.

    Probes ``SEGMENTATION_ROOT_PATH/AD/<stem>.csv`` and
    ``SEGMENTATION_ROOT_PATH/Control/<stem>.csv`` in that order, where
    ``<stem>`` is the audio filename without directory or extension.

    Returns the raw CSV bytes of the first match, or ``None`` when the
    segmentation root is unset/missing or no candidate file exists.
    """
    root = SEGMENTATION_ROOT_PATH
    if not root or not os.path.exists(root):
        return None
    # Base filename without extension, e.g. "S001.wav" -> "S001"
    stem = os.path.splitext(os.path.basename(audio_filename))[0]
    # Candidate CSV paths in both class folders
    candidates = (
        os.path.join(root, label, f"{stem}.csv")
        for label in ('AD', 'Control')
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            print(f" Found segmentation file: {candidate}")
            with open(candidate, 'rb') as f:
                return f.read()
    return None
class MultimodalWrapper(BaseModelWrapper):
    """Inference wrapper for Model V3: a text + audio + linguistic fusion classifier.

    Combines a transformer text encoder, an audio spectrogram branch, and a
    hand-crafted linguistic feature vector through `MultimodalFusion` to
    produce a two-class (Control vs AD) prediction. Heavy components
    (tokenizer, fusion weights, Whisper ASR) are loaded lazily in `load()`.
    """

    def __init__(self) -> None:
        # Prefer GPU when available; model and inputs are moved here in load()/predict().
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Populated by load(); None until then.
        self.model = None        # MultimodalFusion classifier
        self.tokenizer = None    # HF tokenizer for TEXT_MODEL_NAME
        self.asr_model = None    # Whisper ASR model ("base" variant)
        # Lightweight helpers are safe to construct eagerly.
        self.ling_extractor = LinguisticFeatureExtractor()
        self.audio_processor = AudioProcessor()

    def load(self) -> None:
        """Load tokenizer, fusion-model weights (local file or HF Hub), and Whisper.

        Raises:
            FileNotFoundError: if weights are neither on disk at
                LOCAL_WEIGHTS_PATH nor downloadable from HF_REPO_ID.
        """
        print("Loading Model V3 components...")
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
        self.model = MultimodalFusion(TEXT_MODEL_NAME)
        # Load Weights - Try local first, then Hugging Face
        if os.path.exists(LOCAL_WEIGHTS_PATH):
            weights_path = LOCAL_WEIGHTS_PATH
            print(f"Loading weights from local: {weights_path}")
        else:
            try:
                print(f"Downloading weights from Hugging Face: {HF_REPO_ID}")
                weights_path = hf_hub_download(
                    repo_id=HF_REPO_ID,
                    filename=WEIGHTS_FILENAME
                )
            except Exception as e:
                raise FileNotFoundError(f"Model weights not found locally or on Hugging Face: {e}")
        # NOTE(review): torch.load without weights_only=True unpickles arbitrary
        # objects; fine for trusted checkpoints, consider weights_only for safety.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        # Load Whisper (Base model as per notebook)
        print("Loading Whisper for Audio-Only Inference...")
        self.asr_model = whisper.load_model("base")

    def predict(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None, segmentation_content: Optional[bytes] = None) -> dict:
        """
        Handles 4 scenarios with mode-specific visualizations:
        1. CHA only: Text + Linguistic features, zero audio
        2. CHA + Audio: Text from CHA, timestamps from CHA for slicing
        3. Audio + Segmentation CSV: Slice audio using CSV intervals, Whisper ASR
        4. Audio only (participant-only): Full audio with Whisper ASR

        Thin error-logging shim around `_predict_internal`: prints the
        exception and traceback, then re-raises unchanged.
        """
        try:
            return self._predict_internal(file_content, filename, audio_content, segmentation_content)
        except Exception as e:
            print(f"[Model V3 ERROR] {e}")
            traceback.print_exc()
            raise

    def _predict_internal(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None, segmentation_content: Optional[bytes] = None) -> dict:
        """Core prediction pipeline; see `predict` for the scenario overview.

        Builds text, audio, and linguistic inputs according to the detected
        scenario, runs the fusion model, and returns a result dict with the
        predicted label, confidence, and visualization payload.
        """
        # Determine Scenario
        is_cha_provided = filename.endswith('.cha') and len(file_content) > 0
        has_audio = audio_content is not None and len(audio_content) > 0
        # NOTE(review): has_segmentation is computed but never read below;
        # the audio branch checks segmentation_content directly via seg_content.
        has_segmentation = segmentation_content is not None and len(segmentation_content) > 0
        processed_text = ""
        raw_text_for_segments = ""  # unprocessed transcript kept for key-segment extraction
        ling_features = None        # 1 x N float32 tensor fed to the fusion model
        ling_vec = None
        audio_tensor = None
        intervals = []              # (start, end) participant speech intervals
        tmp_path = None             # temp WAV path; cleaned up in finally blocks
        # --- MODE 3 & 4: AUDIO WITHOUT CHA (generate transcript via Whisper) ---
        if not is_cha_provided and has_audio:
            # Save audio to temp file for Whisper/Librosa
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                tmp_audio.write(audio_content)
                tmp_path = tmp_audio.name
            try:
                # Try to get segmentation: 1) User provided, 2) Auto-discover from dataset
                seg_content = segmentation_content
                if not seg_content:
                    seg_content = find_segmentation_file(filename)
                # MODE 3: Audio + Segmentation CSV (8.2 style)
                if seg_content:
                    print("Processing Mode: Audio + Segmentation CSV (8.2 style)")
                    intervals = parse_segmentation_csv(seg_content)
                    print(f" Found {len(intervals)} PAR intervals from CSV")
                    # Slice audio and create spectrogram with intervals
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                    # For transcription, we transcribe the full audio and let Whisper handle it
                    # (In production, you could slice audio first, but Whisper handles full file well)
                    result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                    chat_transcript = apply_chat_rules(result)
                # MODE 4: Audio only - participant-only audio (8.3 style)
                else:
                    print("Processing Mode: Audio Only - Participant Audio (8.3 style)")
                    # No slicing needed - assume entire audio is participant
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=None)
                    # Transcribe full audio
                    result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                    chat_transcript = apply_chat_rules(result)
                processed_text = chat_transcript
                raw_text_for_segments = chat_transcript
                # Extract Features from generated text
                stats = self.ling_extractor.get_features(chat_transcript)
                # CHAT-style markers inserted by apply_chat_rules are counted directly.
                pause_count = chat_transcript.count("[PAUSE]")
                repetition_count = chat_transcript.count("[/]")
                # TTR Calc: type-token ratio over the transcript with bracketed
                # markers and punctuation stripped.
                clean_t = re.sub(r'\[.*?\]', '', chat_transcript)
                clean_t = re.sub(r'[^\w\s]', '', clean_t)
                words = clean_t.lower().split()
                # Guard against division by zero on an empty transcript.
                n = len(words) if len(words) > 0 else 1
                ttr = len(set(words)) / n
                # Per-word-normalized 6-dim linguistic feature vector; order must
                # match what MultimodalFusion was trained with — TODO confirm.
                ling_vec = [
                    ttr,
                    stats['filler_count'] / n,
                    repetition_count / n,
                    stats['retracing_count'] / n,
                    stats['error_count'] / n,
                    pause_count / n
                ]
                ling_features = torch.tensor(ling_vec, dtype=torch.float32).unsqueeze(0)
            finally:
                # Always remove the temp WAV, even if ASR/feature extraction failed.
                if tmp_path and os.path.exists(tmp_path):
                    os.remove(tmp_path)
        # --- SCENARIO 1 & 2: CHA FILE PROVIDED ---
        else:
            # Parse Text from CHA: keep only participant (*PAR:) tier lines.
            text_str = file_content.decode('utf-8', errors='replace')
            par_lines = []
            full_text_for_intervals = ""
            for line in text_str.splitlines():
                if line.startswith('*PAR:'):
                    content = line[5:].strip()
                    par_lines.append(content)
                    full_text_for_intervals += content + " "
            raw_text = " ".join(par_lines)
            raw_text_for_segments = raw_text
            processed_text = self.ling_extractor.clean_for_bert(raw_text)
            # Extract Features (vector built by the extractor from raw CHAT text;
            # presumably array-like with .tolist() — provided by processor module)
            feats = self.ling_extractor.get_feature_vector(raw_text)
            ling_vec = feats.tolist()
            ling_features = torch.tensor(feats, dtype=torch.float32).unsqueeze(0)
            # --- SCENARIO 2: CHA + AUDIO (Segmentation) ---
            if has_audio:
                print("Processing Mode: CHA + Audio (Segmentation)")
                # Extract intervals from the raw text: CHAT encodes media
                # timestamps as \x15<start>_<end>\x15 (milliseconds).
                found_intervals = re.findall(r'\x15(\d+)_(\d+)\x15', full_text_for_intervals)
                intervals = [(int(s), int(e)) for s, e in found_intervals]
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                    tmp_audio.write(audio_content)
                    tmp_path = tmp_audio.name
                try:
                    # Pass intervals to slice specific PAR audio
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                    # Generate spectrogram for visualization
                    # NOTE(review): spectrogram_b64 is never used afterwards
                    # (audio visualizations were removed) — dead computation.
                    spectrogram_b64 = self.audio_processor.create_spectrogram_base64(tmp_path, intervals=intervals)
                finally:
                    if tmp_path and os.path.exists(tmp_path):
                        os.remove(tmp_path)
            # --- SCENARIO 1: CHA ONLY ---
            else:
                print("Processing Mode: CHA Only")
                # Zero image-shaped tensor stands in for the missing audio branch input.
                audio_tensor = torch.zeros((1, 3, 224, 224))
        # --- COMMON INFERENCE STEPS ---
        # Tokenize the processed transcript to fixed length MAX_LEN.
        encoding = self.tokenizer.encode_plus(
            processed_text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        with torch.no_grad():
            input_ids = encoding['input_ids'].to(self.device)
            mask = encoding['attention_mask'].to(self.device)
            pixel_values = audio_tensor.to(self.device)
            ling_input = ling_features.to(self.device)
            outputs = self.model(input_ids, mask, pixel_values, ling_input)
            # Class index 0 = Control, 1 = AD (see label_map below).
            probs = F.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_idx].item()
            prob_ad = probs[0][1].item()
            prob_control = probs[0][0].item()
        label_map = {0: 'Control', 1: 'AD'}
        # --- BUILD VISUALIZATIONS ---
        visualizations = {
            "probabilities": {
                "AD": round(prob_ad, 4),
                "Control": round(prob_control, 4)
            }
        }
        # Key Contribution Segments (from text analysis)
        # Denser color = more contribution to prediction
        raw_key_segments = self.ling_extractor.extract_key_segments(raw_text_for_segments)
        if raw_key_segments:
            # Transform to contribution-based format (remove markers, use normalized score)
            max_count = max(seg['marker_count'] for seg in raw_key_segments) if raw_key_segments else 1
            key_segments = [
                {
                    "text": seg['text'],
                    # Normalize marker count to [0, 1] against the densest segment.
                    "contribution_score": round(seg['marker_count'] / max(max_count, 1), 2)
                }
                for seg in raw_key_segments
            ]
            visualizations["key_contribution_segments"] = {
                "note": "Denser color indicates higher contribution to prediction",
                "segments": key_segments
            }
        # Audio-specific visualizations removed as per requirements
        return {
            "model_version": "v3_multimodal",
            "filename": filename if filename else "audio_upload",
            "predicted_label": label_map[pred_idx],
            "confidence": round(confidence, 4),
            # "audio" is reported whenever real audio bytes were supplied.
            "modalities_used": ["text", "linguistic"] + (["audio"] if has_audio else []),
            "visualizations": visualizations
        }