""" Attribution module for EchoPrime interpretability. All computation is on CPU — tensors are explicitly moved at entry points. """ import re from dataclasses import dataclass, field from typing import Optional import torch import torch.nn.functional as F import numpy as np # ── Data classes ── @dataclass class VideoScore: video_idx: int view_label: str view_idx: int mil_weight: float cosine_similarity: float combined_score: float dicom_path: Optional[str] = None @dataclass class PhraseAttribution: phrase: str section: str section_idx: int video_scores: list = field(default_factory=list) @dataclass class SectionAttribution: section: str section_idx: int mil_weights_per_video: np.ndarray similarities_per_video: np.ndarray combined_scores_per_video: np.ndarray ranked_video_indices: np.ndarray section_embedding: np.ndarray # ── Constants ── ALL_SECTIONS = [ "Left Ventricle", "Resting Segmental Wall Motion Analysis", "Right Ventricle", "Left Atrium", "Right Atrium", "Atrial Septum", "Mitral Valve", "Aortic Valve", "Tricuspid Valve", "Pulmonic Valve", "Pericardium", "Aorta", "IVC", "Pulmonary Artery", "Pulmonary Veins", "Postoperative Findings", ] COARSE_VIEWS = [ 'A2C', 'A3C', 'A4C', 'A5C', 'Apical_Doppler', 'Doppler_Parasternal_Long', 'Doppler_Parasternal_Short', 'Parasternal_Long', 'Parasternal_Short', 'SSN', 'Subcostal', ] # ── Report parsing ── def parse_report_to_phrases(report_text: str) -> list: blocks = report_text.split("[SEP]") results = [] for block in blocks: block = block.strip() if not block: continue section_name = None content = block for sec in ALL_SECTIONS: if block.startswith(sec): section_name = sec content = block[len(sec):].lstrip(":").strip() break if section_name is None: for sec in ALL_SECTIONS: if sec.lower() in block.lower()[:len(sec) + 5]: section_name = sec idx = block.lower().find(sec.lower()) content = block[idx + len(sec):].lstrip(":").strip() break if section_name is None: continue section_idx = ALL_SECTIONS.index(section_name) for phrase in _split_into_phrases(content): phrase = phrase.strip() if phrase and len(phrase) > 2: results.append(PhraseAttribution(phrase=phrase, section=section_name, section_idx=section_idx)) return results def _split_into_phrases(text: str) -> list: sentences = re.split(r'(?<=[a-z])\.\s+(?=[A-Z])', text) phrases = [s.strip().rstrip(".") for s in sentences if s.strip()] if not phrases: phrases = [p.strip() for p in re.split(r'\s{2,}', text) if p.strip()] if not phrases and text.strip(): phrases = [text.strip()] return phrases # ── Per-video attribution ── def compute_section_attributions( study_embedding: torch.Tensor, candidate_embeddings: torch.Tensor, section_weights: np.ndarray, non_empty_sections, view_labels: list, dicom_paths=None, k: int = 50, ) -> dict: # Force everything to CPU study_embedding = study_embedding.cpu() candidate_embeddings = candidate_embeddings.cpu() video_embeddings = study_embedding[:, :512] view_encodings = study_embedding[:, 512:] n_videos = video_embeddings.shape[0] results = {} for s_dx, sec in enumerate(non_empty_sections): sec = str(sec) mil_weights = np.zeros(n_videos, dtype=np.float32) for v_idx in range(n_videos): view_idx = torch.where(view_encodings[v_idx] == 1)[0] if len(view_idx) > 0: mil_weights[v_idx] = section_weights[s_dx][view_idx[0].item()] mil_weights_t = torch.tensor(mil_weights, dtype=torch.float32) weighted = video_embeddings * mil_weights_t.unsqueeze(1) section_embedding = F.normalize(weighted.mean(dim=0), dim=0) video_emb_norm = F.normalize(video_embeddings, dim=1) per_video_sims = video_emb_norm @ candidate_embeddings.T topk_sims = torch.topk(per_video_sims, k=min(k, per_video_sims.shape[1]), dim=1) avg_topk_sim = topk_sims.values.mean(dim=1).numpy() combined = mil_weights * avg_topk_sim ranked = np.argsort(combined)[::-1] results[sec] = SectionAttribution( section=sec, section_idx=s_dx, mil_weights_per_video=mil_weights, similarities_per_video=avg_topk_sim, combined_scores_per_video=combined, ranked_video_indices=ranked.copy(), section_embedding=section_embedding.numpy(), ) return results def build_phrase_attributions(phrase_list, section_attributions, view_labels, dicom_paths=None, top_k=10): for pa in phrase_list: sec = pa.section if sec not in section_attributions: continue sa = section_attributions[sec] top_indices = sa.ranked_video_indices[:top_k] pa.video_scores = [ VideoScore( video_idx=int(idx), view_label=view_labels[idx] if idx < len(view_labels) else "Unknown", view_idx=COARSE_VIEWS.index(view_labels[idx]) if view_labels[idx] in COARSE_VIEWS else -1, mil_weight=float(sa.mil_weights_per_video[idx]), cosine_similarity=float(sa.similarities_per_video[idx]), combined_score=float(sa.combined_scores_per_video[idx]), dicom_path=dicom_paths[idx] if dicom_paths and idx < len(dicom_paths) else None, ) for idx in top_indices ] return phrase_list # ── Serialization ── def phrase_attribution_to_dict(pa): return { "phrase": pa.phrase, "section": pa.section, "section_idx": pa.section_idx, "video_scores": [ {"video_idx": vs.video_idx, "view_label": vs.view_label, "mil_weight": round(vs.mil_weight, 4), "cosine_similarity": round(vs.cosine_similarity, 4), "combined_score": round(vs.combined_score, 4), "dicom_path": vs.dicom_path} for vs in pa.video_scores ], } def section_attribution_to_dict(sa): return { "section": sa.section, "section_idx": sa.section_idx, "top_videos": [ {"video_idx": int(idx), "mil_weight": round(float(sa.mil_weights_per_video[idx]), 4), "cosine_similarity": round(float(sa.similarities_per_video[idx]), 4), "combined_score": round(float(sa.combined_scores_per_video[idx]), 4)} for idx in sa.ranked_video_indices[:10] ], }