echo-prime-demo / attribution.py
amn23's picture
Upload 2 files
b87a44c verified
"""
Attribution module for EchoPrime interpretability.
All computation is on CPU — tensors are explicitly moved at entry points.
"""
import re
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn.functional as F
import numpy as np
# ── Data classes ──
@dataclass
class VideoScore:
    """Attribution scores of one study video with respect to a report phrase.

    Instances are created by ``build_phrase_attributions``, which copies the
    score fields out of a ``SectionAttribution``'s per-video arrays.
    """

    video_idx: int  # position of the video within the study
    view_label: str  # view name for the video ("Unknown" when no label exists)
    view_idx: int  # index of view_label in COARSE_VIEWS, or -1 if not listed
    mil_weight: float  # section MIL weight looked up for the video's view
    cosine_similarity: float  # mean of the video's top-k candidate similarities
    combined_score: float  # mil_weight * cosine_similarity
    dicom_path: Optional[str] = field(default=None)  # source file, when provided
@dataclass
class PhraseAttribution:
    """A report phrase, the section it came from, and its supporting videos."""

    phrase: str  # sentence-like fragment extracted from the report text
    section: str  # ALL_SECTIONS heading the phrase was found under
    section_idx: int  # index of `section` within ALL_SECTIONS
    # Populated by build_phrase_attributions; every instance gets a fresh list.
    video_scores: list = field(default_factory=list)
@dataclass
class SectionAttribution:
    """Per-video attribution results for a single report section.

    Built by ``compute_section_attributions``; every ``*_per_video`` array is
    indexed by the video's position within the study.
    """

    section: str  # section name (stringified)
    section_idx: int  # position in the section list given to compute_section_attributions
    mil_weights_per_video: np.ndarray  # MIL weight of each video's view for this section
    similarities_per_video: np.ndarray  # mean top-k candidate similarity per video
    combined_scores_per_video: np.ndarray  # elementwise product of the two arrays above
    ranked_video_indices: np.ndarray  # video indices ordered by combined score, best first
    section_embedding: np.ndarray  # L2-normalized, MIL-weighted mean video embedding
# ── Constants ──
# Report section headings, in canonical order. The position of a heading in
# this list is the `section_idx` assigned by parse_report_to_phrases, and
# headings are tried in this order when matching report blocks.
ALL_SECTIONS = [
    "Left Ventricle",
    "Resting Segmental Wall Motion Analysis",
    "Right Ventricle",
    "Left Atrium",
    "Right Atrium",
    "Atrial Septum",
    "Mitral Valve",
    "Aortic Valve",
    "Tricuspid Valve",
    "Pulmonic Valve",
    "Pericardium",
    "Aorta",
    "IVC",
    "Pulmonary Artery",
    "Pulmonary Veins",
    "Postoperative Findings",
]

# Coarse echo view names. build_phrase_attributions reports a label's index
# in this list as `view_idx` (-1 when the label is not listed).
# NOTE(review): presumably matches the column order of the view encoding in
# study_embedding[:, 512:] — confirm against the model.
COARSE_VIEWS = [
    'A2C',
    'A3C',
    'A4C',
    'A5C',
    'Apical_Doppler',
    'Doppler_Parasternal_Long',
    'Doppler_Parasternal_Short',
    'Parasternal_Long',
    'Parasternal_Short',
    'SSN',
    'Subcostal',
]
# ── Report parsing ──
def parse_report_to_phrases(report_text: str) -> list:
    """Split a ``[SEP]``-delimited report into per-section phrases.

    Each ``[SEP]`` block is assigned to a heading from ALL_SECTIONS (see
    ``_match_section``); blocks with no recognized heading are dropped. The
    text after the heading is split into phrases with ``_split_into_phrases``
    and phrases of 2 characters or fewer are discarded.

    Args:
        report_text: full report text with sections separated by ``[SEP]``.

    Returns:
        list of PhraseAttribution with ``video_scores`` left empty;
        ``section_idx`` is the heading's index in ALL_SECTIONS.
    """
    results = []
    for block in report_text.split("[SEP]"):
        block = block.strip()
        if not block:
            continue
        section_name, content = _match_section(block)
        if section_name is None:
            continue
        section_idx = ALL_SECTIONS.index(section_name)
        for phrase in _split_into_phrases(content):
            phrase = phrase.strip()
            # Skip tiny fragments (punctuation, stray initials, ...).
            if len(phrase) > 2:
                results.append(
                    PhraseAttribution(phrase=phrase, section=section_name, section_idx=section_idx)
                )
    return results


def _match_section(block: str):
    """Return ``(heading, content)`` for a report block, or ``(None, None)``.

    An exact, case-sensitive prefix match is tried first. Failing that, a
    heading is accepted case-insensitively if it occurs within the first
    ``len(heading) + 5`` characters of the block (tolerating short leading
    noise). Headings are tried in ALL_SECTIONS order.
    """
    for sec in ALL_SECTIONS:
        if block.startswith(sec):
            return sec, _strip_heading_separator(block[len(sec):])
    lowered = block.lower()
    for sec in ALL_SECTIONS:
        sec_lower = sec.lower()
        if sec_lower in lowered[:len(sec) + 5]:
            idx = lowered.find(sec_lower)
            return sec, _strip_heading_separator(block[idx + len(sec):])
    return None, None


def _strip_heading_separator(text: str) -> str:
    """Remove surrounding whitespace and a leading ':' left after a heading."""
    # Strip whitespace first so "Heading : text" also loses its colon.
    return text.strip().lstrip(":").strip()
def _split_into_phrases(text: str) -> list:
sentences = re.split(r'(?<=[a-z])\.\s+(?=[A-Z])', text)
phrases = [s.strip().rstrip(".") for s in sentences if s.strip()]
if not phrases:
phrases = [p.strip() for p in re.split(r'\s{2,}', text) if p.strip()]
if not phrases and text.strip():
phrases = [text.strip()]
return phrases
# ── Per-video attribution ──
def _view_index_per_video(view_encodings: torch.Tensor) -> list:
    """For each row, return the first column equal to 1, or None if there is none."""
    indices = []
    for row in view_encodings:
        hits = torch.where(row == 1)[0]
        indices.append(int(hits[0].item()) if len(hits) > 0 else None)
    return indices


def compute_section_attributions(
    study_embedding: torch.Tensor,
    candidate_embeddings: torch.Tensor,
    section_weights: np.ndarray,
    non_empty_sections,
    view_labels: list,
    dicom_paths=None,
    k: int = 50,
) -> dict:
    """Score every video of a study against every given report section.

    For a section ``s``, each video's combined score is the product of
      * its MIL weight ``section_weights[s][v]``, where ``v`` is the first
        column equal to 1 in the video's view encoding (weight 0 if none), and
      * the mean of its ``min(k, n_candidates)`` largest dot products between
        the L2-normalized video embedding and ``candidate_embeddings`` rows.

    Args:
        study_embedding: one row per video; columns ``[:512]`` are the video
            embedding and ``[512:]`` the view encoding.
        candidate_embeddings: rows compared against the normalized video
            embeddings. NOTE(review): this is cosine similarity only if these
            rows are already L2-normalized; that is not checked here.
        section_weights: indexed ``section_weights[s][view]`` with ``s`` the
            position of the section in ``non_empty_sections``.
        non_empty_sections: iterable of section names (converted with ``str``).
        view_labels: unused; kept for interface compatibility.
        dicom_paths: unused; kept for interface compatibility.
        k: number of top candidate similarities averaged per video.

    Returns:
        dict mapping section name -> SectionAttribution. Its ``section_idx`` is
        the position in ``non_empty_sections`` (not the ALL_SECTIONS index).
        Duplicate section names keep the last entry.
    """
    # Detached CPU copies: .numpy() below raises on tensors that require grad
    # or live on an accelerator.
    study_embedding = study_embedding.detach().cpu()
    candidate_embeddings = candidate_embeddings.detach().cpu()

    sections = [str(sec) for sec in non_empty_sections]
    if not sections:
        return {}

    video_embeddings = study_embedding[:, :512]
    view_encodings = study_embedding[:, 512:]
    n_videos = video_embeddings.shape[0]
    view_indices = _view_index_per_video(view_encodings)

    # Video-to-candidate similarity does not depend on the section, so it is
    # computed once instead of once per section.
    video_emb_norm = F.normalize(video_embeddings, dim=1)
    per_video_sims = video_emb_norm @ candidate_embeddings.T
    topk_sims = torch.topk(per_video_sims, k=min(k, per_video_sims.shape[1]), dim=1)
    avg_topk_sim = topk_sims.values.mean(dim=1).numpy()

    results = {}
    for s_dx, sec in enumerate(sections):
        mil_weights = np.zeros(n_videos, dtype=np.float32)
        for v_idx, view_idx in enumerate(view_indices):
            if view_idx is not None:
                mil_weights[v_idx] = section_weights[s_dx][view_idx]

        # Section embedding: MIL-weighted mean of the raw video embeddings,
        # then L2-normalized.
        mil_weights_t = torch.tensor(mil_weights, dtype=torch.float32)
        weighted = video_embeddings * mil_weights_t.unsqueeze(1)
        section_embedding = F.normalize(weighted.mean(dim=0), dim=0)

        combined = mil_weights * avg_topk_sim
        ranked = np.argsort(combined)[::-1]  # highest combined score first
        results[sec] = SectionAttribution(
            section=sec,
            section_idx=s_dx,
            mil_weights_per_video=mil_weights,
            # Separate copy per section so mutating one result cannot
            # silently change another.
            similarities_per_video=avg_topk_sim.copy(),
            combined_scores_per_video=combined,
            ranked_video_indices=ranked.copy(),
            section_embedding=section_embedding.numpy(),
        )
    return results
def build_phrase_attributions(phrase_list, section_attributions, view_labels, dicom_paths=None, top_k=10):
    """Attach the top-ranked videos of each phrase's section to the phrase.

    For every phrase whose ``section`` is a key of ``section_attributions``,
    ``pa.video_scores`` is replaced with VideoScore objects for the first
    ``top_k`` entries of that section's ``ranked_video_indices``. Phrases whose
    section is missing are left untouched.

    Args:
        phrase_list: PhraseAttribution objects; mutated in place.
        section_attributions: mapping of section name -> SectionAttribution.
        view_labels: per-video view names; videos beyond its length (or all
            videos, if None) get the label "Unknown".
        dicom_paths: optional per-video paths; videos beyond its length get None.
        top_k: number of top-ranked videos to attach per phrase.

    Returns:
        ``phrase_list`` (the same object).
    """
    n_labels = len(view_labels) if view_labels is not None else 0
    # Explicit None check: `dicom_paths and ...` raises for numpy arrays.
    n_paths = len(dicom_paths) if dicom_paths is not None else 0

    for pa in phrase_list:
        if pa.section not in section_attributions:
            continue
        sa = section_attributions[pa.section]
        scores = []
        for idx in sa.ranked_video_indices[:top_k]:
            idx = int(idx)
            # Guard the label lookup once and reuse it for view_idx, so a short
            # view_labels no longer raises IndexError.
            label = view_labels[idx] if idx < n_labels else "Unknown"
            scores.append(
                VideoScore(
                    video_idx=idx,
                    view_label=label,
                    view_idx=COARSE_VIEWS.index(label) if label in COARSE_VIEWS else -1,
                    mil_weight=float(sa.mil_weights_per_video[idx]),
                    cosine_similarity=float(sa.similarities_per_video[idx]),
                    combined_score=float(sa.combined_scores_per_video[idx]),
                    dicom_path=dicom_paths[idx] if idx < n_paths else None,
                )
            )
        pa.video_scores = scores
    return phrase_list
# ── Serialization ──
def phrase_attribution_to_dict(pa):
    """Convert a PhraseAttribution into a plain JSON-serializable dict.

    The three score floats of each video entry are rounded to 4 decimals;
    ``view_idx`` is not included in the output.
    """

    def _score_entry(score):
        # Build keys in the same order as the published output format.
        entry = {"video_idx": score.video_idx, "view_label": score.view_label}
        for key in ("mil_weight", "cosine_similarity", "combined_score"):
            entry[key] = round(getattr(score, key), 4)
        entry["dicom_path"] = score.dicom_path
        return entry

    serialized = {
        "phrase": pa.phrase,
        "section": pa.section,
        "section_idx": pa.section_idx,
    }
    serialized["video_scores"] = [_score_entry(s) for s in pa.video_scores]
    return serialized
def section_attribution_to_dict(sa, top_n=10):
    """Convert a SectionAttribution into a plain JSON-serializable dict.

    Args:
        sa: object with ``section``, ``section_idx`` and the per-video arrays
            ``mil_weights_per_video``, ``similarities_per_video``,
            ``combined_scores_per_video`` and ``ranked_video_indices``.
        top_n: how many of the top-ranked videos to include (default 10, the
            previous fixed limit); ``None`` includes all of them.

    Returns:
        dict with ``section``, ``section_idx`` and ``top_videos``; scores are
        converted to Python floats and rounded to 4 decimals.
    """
    return {
        "section": sa.section,
        "section_idx": sa.section_idx,
        "top_videos": [
            {
                "video_idx": int(idx),
                "mil_weight": round(float(sa.mil_weights_per_video[idx]), 4),
                "cosine_similarity": round(float(sa.similarities_per_video[idx]), 4),
                "combined_score": round(float(sa.combined_scores_per_video[idx]), 4),
            }
            for idx in sa.ranked_video_indices[:top_n]
        ],
    }