Spaces:
Running
Running
| """ | |
| Attribution module for EchoPrime interpretability. | |
| All computation is on CPU — tensors are explicitly moved at entry points. | |
| """ | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| # ── Data classes ── | |
@dataclass
class VideoScore:
    """Scores attributed to a single video for one report section.

    The ``@dataclass`` decorator is required: the module constructs this
    class with keyword arguments, which a plain class with bare annotations
    does not support.
    """

    # Index of the video within the study (taken from a section's ranking).
    video_idx: int
    # View label for the video, or "Unknown" when no label is available.
    view_label: str
    # Position of ``view_label`` in COARSE_VIEWS, or -1 when it is not listed.
    view_idx: int
    # Section-specific weight looked up for the video's view.
    mil_weight: float
    # Mean of the video's top-k similarities against the candidate embeddings.
    cosine_similarity: float
    # Product of ``mil_weight`` and ``cosine_similarity``.
    combined_score: float
    # Source DICOM path for the video, when the caller supplied one.
    dicom_path: Optional[str] = None
@dataclass
class PhraseAttribution:
    """A single report phrase together with the videos attributed to it.

    The ``@dataclass`` decorator is required: without it the class cannot be
    constructed with keyword arguments and ``field(default_factory=list)``
    would leave a ``Field`` object as a shared class attribute instead of
    giving every instance its own empty list.
    """

    # Phrase text extracted from the report.
    phrase: str
    # Name of the report section the phrase belongs to.
    section: str
    # Index associated with ``section`` (the parser uses its position in
    # ALL_SECTIONS).
    section_idx: int
    # Per-video scores; empty until attribution has been run for the phrase.
    video_scores: list = field(default_factory=list)
@dataclass
class SectionAttribution:
    """Per-video attribution arrays for one report section.

    The ``@dataclass`` decorator is required: the module constructs this
    class with keyword arguments, which a plain class with bare annotations
    does not support.

    Every ``*_per_video`` array is indexed by video position within the study.
    """

    # Section name (used as the dictionary key by the producer).
    section: str
    # Index of the section as enumerated by the producer.
    section_idx: int
    # Section-specific weight of each video, looked up through its view.
    mil_weights_per_video: np.ndarray
    # Mean top-k similarity of each video against the candidate embeddings.
    similarities_per_video: np.ndarray
    # Element-wise product of the weights and the similarities.
    combined_scores_per_video: np.ndarray
    # Video indices sorted by descending combined score.
    ranked_video_indices: np.ndarray
    # L2-normalised mean of the weight-scaled video embeddings.
    section_embedding: np.ndarray
| # ── Constants ── | |
# Report section headings recognised by the parser.
# Order matters: a heading's position in this list is used as its section index.
ALL_SECTIONS = [
    "Left Ventricle",
    "Resting Segmental Wall Motion Analysis",
    "Right Ventricle",
    "Left Atrium",
    "Right Atrium",
    "Atrial Septum",
    "Mitral Valve",
    "Aortic Valve",
    "Tricuspid Valve",
    "Pulmonic Valve",
    "Pericardium",
    "Aorta",
    "IVC",
    "Pulmonary Artery",
    "Pulmonary Veins",
    "Postoperative Findings",
]
# Coarse echocardiographic view labels.
# Order matters: a label's position in this list is reported as its view index.
# NOTE(review): presumably this order also matches the columns of the one-hot
# view encoding appended to each video embedding — confirm against the encoder.
COARSE_VIEWS = [
    "A2C",
    "A3C",
    "A4C",
    "A5C",
    "Apical_Doppler",
    "Doppler_Parasternal_Long",
    "Doppler_Parasternal_Short",
    "Parasternal_Long",
    "Parasternal_Short",
    "SSN",
    "Subcostal",
]
| # ── Report parsing ── | |
def parse_report_to_phrases(report_text: str) -> list:
    """Split a ``[SEP]``-delimited report into per-section phrases.

    Each non-blank block between ``[SEP]`` markers is matched to a heading in
    ALL_SECTIONS, the heading (and an optional ``:`` separator) is removed,
    and the remaining text is split into phrases.

    Args:
        report_text: Report text whose blocks are separated by ``[SEP]``.
            ``None`` or an empty string yields an empty list.

    Returns:
        A list of ``PhraseAttribution`` objects in report order, one per
        phrase longer than two characters. Blocks whose heading cannot be
        matched to any known section are skipped.
    """
    if not report_text:
        return []

    results = []
    for raw_block in report_text.split("[SEP]"):
        block = raw_block.strip()
        if not block:
            continue

        section_name = None
        content = block

        # First try an exact, case-sensitive heading at the start of the block.
        for sec in ALL_SECTIONS:
            if block.startswith(sec):
                section_name = sec
                # Strip whitespace before the colon as well, so a heading
                # written as "Heading : text" does not leave a stray ":".
                content = block[len(sec):].strip().lstrip(":").strip()
                break

        # Otherwise accept a case-insensitive heading that lies entirely
        # within the first len(heading) + 5 characters of the block.
        if section_name is None:
            block_lower = block.lower()
            for sec in ALL_SECTIONS:
                idx = block_lower.find(sec.lower(), 0, len(sec) + 5)
                if idx != -1:
                    section_name = sec
                    content = block[idx + len(sec):].strip().lstrip(":").strip()
                    break

        if section_name is None:
            continue

        section_idx = ALL_SECTIONS.index(section_name)
        for phrase in _split_into_phrases(content):
            phrase = phrase.strip()
            # Fragments of one or two characters are treated as noise.
            if len(phrase) > 2:
                results.append(
                    PhraseAttribution(
                        phrase=phrase,
                        section=section_name,
                        section_idx=section_idx,
                    )
                )
    return results
| def _split_into_phrases(text: str) -> list: | |
| sentences = re.split(r'(?<=[a-z])\.\s+(?=[A-Z])', text) | |
| phrases = [s.strip().rstrip(".") for s in sentences if s.strip()] | |
| if not phrases: | |
| phrases = [p.strip() for p in re.split(r'\s{2,}', text) if p.strip()] | |
| if not phrases and text.strip(): | |
| phrases = [text.strip()] | |
| return phrases | |
| # ── Per-video attribution ── | |
def compute_section_attributions(
    study_embedding: torch.Tensor,
    candidate_embeddings: torch.Tensor,
    section_weights: np.ndarray,
    non_empty_sections,
    view_labels: list,
    dicom_paths=None,
    k: int = 50,
) -> dict:
    """Score every video of a study for each requested report section.

    Each row of ``study_embedding`` is split into a 512-column video embedding
    and the remaining columns, which are read as a view encoding (the first
    column equal to 1 selects the view). For each section, a video's weight
    is ``section_weights[section_position][view_index]``; its similarity is
    the mean of its top-``k`` dot products between the L2-normalised video
    embedding and the candidate embeddings; the combined score is the product
    of the two.

    NOTE(review): the similarity term depends only on the video and the
    candidates, so it is identical for every section — only the weights
    differ. Confirm this is intended.

    Args:
        study_embedding: 2-D tensor, one row per video; columns ``[:512]`` are
            the video embedding, columns ``[512:]`` the view encoding.
        candidate_embeddings: 2-D tensor with 512 columns (required by the
            matrix product). Presumably already L2-normalised, since the
            result is reported as a cosine similarity — TODO confirm.
        section_weights: Indexed as ``[section_position][view_index]``, where
            ``section_position`` is the position in ``non_empty_sections``.
        non_empty_sections: Iterable of section names; each is converted with
            ``str`` and used as a key of the result.
        view_labels: Unused here; kept for interface compatibility.
        dicom_paths: Unused here; kept for interface compatibility.
        k: Maximum number of candidates averaged per video.

    Returns:
        Dict mapping section name to ``SectionAttribution``. Empty when
        ``non_empty_sections`` is empty.
    """
    # Force everything to CPU; detach so .numpy() also works on tensors that
    # require grad.
    study_embedding = study_embedding.detach().cpu()
    candidate_embeddings = candidate_embeddings.detach().cpu()

    video_embeddings = study_embedding[:, :512]
    view_encodings = study_embedding[:, 512:]
    n_videos = video_embeddings.shape[0]

    # The view of each video does not depend on the section: resolve it once.
    # None marks a video whose encoding has no column equal to 1.
    view_index_per_video = []
    for v_idx in range(n_videos):
        hot_columns = torch.where(view_encodings[v_idx] == 1)[0]
        view_index_per_video.append(
            hot_columns[0].item() if len(hot_columns) > 0 else None
        )

    # The similarity term is section-independent as well: compute it once.
    video_emb_norm = F.normalize(video_embeddings, dim=1)
    per_video_sims = video_emb_norm @ candidate_embeddings.T
    n_top = min(k, per_video_sims.shape[1])
    if n_top == 0:
        # Nothing to compare against: report zero similarity rather than the
        # NaN produced by averaging an empty selection.
        avg_topk_sim = np.zeros(n_videos, dtype=np.float32)
    else:
        topk_sims = torch.topk(per_video_sims, k=n_top, dim=1)
        avg_topk_sim = topk_sims.values.mean(dim=1).numpy()

    results = {}
    for s_dx, sec in enumerate(non_empty_sections):
        sec = str(sec)

        mil_weights = np.zeros(n_videos, dtype=np.float32)
        for v_idx, view_idx in enumerate(view_index_per_video):
            if view_idx is not None:
                mil_weights[v_idx] = section_weights[s_dx][view_idx]

        mil_weights_t = torch.tensor(mil_weights, dtype=torch.float32)
        weighted = video_embeddings * mil_weights_t.unsqueeze(1)
        section_embedding = F.normalize(weighted.mean(dim=0), dim=0)

        # Give every section its own copy so results never alias each other.
        similarities = avg_topk_sim.copy()
        combined = mil_weights * similarities
        # Highest combined score first.
        ranked = np.argsort(combined)[::-1]

        results[sec] = SectionAttribution(
            section=sec,
            section_idx=s_dx,
            mil_weights_per_video=mil_weights,
            similarities_per_video=similarities,
            combined_scores_per_video=combined,
            ranked_video_indices=ranked.copy(),
            section_embedding=section_embedding.numpy(),
        )
    return results
def build_phrase_attributions(phrase_list, section_attributions, view_labels, dicom_paths=None, top_k=10):
    """Attach the top-ranked videos of each phrase's section to the phrase.

    For every phrase whose section has an entry in ``section_attributions``,
    ``video_scores`` is replaced by one ``VideoScore`` per video among the
    first ``top_k`` entries of that section's ranking. Phrases whose section
    has no entry are left untouched.

    Args:
        phrase_list: Phrase objects exposing ``section`` and ``video_scores``.
        section_attributions: Mapping from section name to an object exposing
            ``ranked_video_indices`` and the three ``*_per_video`` arrays.
        view_labels: View label per video, indexed by video position.
        dicom_paths: Optional DICOM path per video, indexed by video position.
        top_k: Maximum number of videos attached to each phrase.

    Returns:
        The same ``phrase_list`` object, mutated in place.
    """
    for pa in phrase_list:
        if pa.section not in section_attributions:
            continue
        sa = section_attributions[pa.section]

        scores = []
        for raw_idx in sa.ranked_video_indices[:top_k]:
            idx = int(raw_idx)
            # Resolve the label once, so the view-index lookup uses the same
            # bounds-checked value. Indexing view_labels a second time without
            # the guard raised IndexError for videos that have no label.
            label = view_labels[idx] if idx < len(view_labels) else "Unknown"
            has_path = dicom_paths is not None and idx < len(dicom_paths)
            scores.append(
                VideoScore(
                    video_idx=idx,
                    view_label=label,
                    view_idx=COARSE_VIEWS.index(label) if label in COARSE_VIEWS else -1,
                    mil_weight=float(sa.mil_weights_per_video[idx]),
                    cosine_similarity=float(sa.similarities_per_video[idx]),
                    combined_score=float(sa.combined_scores_per_video[idx]),
                    dicom_path=dicom_paths[idx] if has_path else None,
                )
            )
        pa.video_scores = scores
    return phrase_list
| # ── Serialization ── | |
def phrase_attribution_to_dict(pa):
    """Convert a phrase attribution into a plain dictionary.

    Args:
        pa: Object exposing ``phrase``, ``section``, ``section_idx`` and an
            iterable ``video_scores`` of per-video score objects.

    Returns:
        A dict with the phrase fields and a ``video_scores`` list holding one
        dict per video. The three score values are rounded to four decimal
        places; the per-video view index is not included.
    """
    serialized_scores = []
    for score in pa.video_scores:
        entry = {
            "video_idx": score.video_idx,
            "view_label": score.view_label,
            "mil_weight": round(score.mil_weight, 4),
            "cosine_similarity": round(score.cosine_similarity, 4),
            "combined_score": round(score.combined_score, 4),
            "dicom_path": score.dicom_path,
        }
        serialized_scores.append(entry)

    return {
        "phrase": pa.phrase,
        "section": pa.section,
        "section_idx": pa.section_idx,
        "video_scores": serialized_scores,
    }
def section_attribution_to_dict(sa, top_k=10):
    """Convert a section attribution into a plain dictionary.

    Args:
        sa: Object exposing ``section``, ``section_idx``,
            ``ranked_video_indices`` and the three ``*_per_video`` arrays.
        top_k: Maximum number of ranked videos to include. Defaults to 10,
            the cut-off that was previously hard-coded.

    Returns:
        A dict with the section fields and a ``top_videos`` list holding one
        dict per included video, in ranking order. Score values are converted
        to Python floats and rounded to four decimal places.
    """
    top_videos = []
    for idx in sa.ranked_video_indices[:top_k]:
        top_videos.append(
            {
                "video_idx": int(idx),
                "mil_weight": round(float(sa.mil_weights_per_video[idx]), 4),
                "cosine_similarity": round(float(sa.similarities_per_video[idx]), 4),
                "combined_score": round(float(sa.combined_scores_per_video[idx]), 4),
            }
        )
    return {
        "section": sa.section,
        "section_idx": sa.section_idx,
        "top_videos": top_videos,
    }