Spaces:
Running
Running
| # Standard library imports | |
| import os | |
| import math | |
| import glob | |
| import json | |
| import pickle | |
| import random | |
| import sys | |
| from typing import AnyStr, List, Any, Dict, Optional, Union | |
| # Third-party library imports | |
| import torch | |
| import torchvision | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| import cv2 | |
| import pydicom | |
| import sklearn | |
| import sklearn.metrics | |
| import transformers | |
| # Local module imports | |
| import utils | |
class EchoPrime:
    """
    Echocardiography model wrapper: turns DICOM/MP4 clips into embeddings,
    classifies the view of each clip, retrieves report text section by section
    and predicts phenotype values from nearest candidate studies.

    Attributes:
        base_dir (str): Project root; ``ECHOPRIME_ROOT_OVERRIDE`` if set,
            otherwise the directory containing this file.
        echo_encoder (torchvision.models.video.MViT): Frozen ``mvit_v2_s``
            whose final head layer outputs 512 features.
        view_classifier (torchvision.models.ConvNeXt): Frozen
            ``convnext_base`` whose final classifier layer outputs 11 logits.
        frames_to_take (int): Number of leading frames considered per clip (32).
        frame_stride (int): Stride used when sub-sampling those frames (2).
        video_size (int): Spatial size of pre-processed frames (224).
        mean (torch.Tensor): Per-channel mean, shape ``(3, 1, 1, 1)``.
        std (torch.Tensor): Per-channel std, shape ``(3, 1, 1, 1)``.
        device (torch.device): Device the models and candidate embeddings
            live on.
        lang (str): Language code used by ``translate_sections``.
        MIL_weights (pd.DataFrame): Contents of ``assets/MIL_weights.csv``.
        non_empty_sections (pd.Series): The ``"Section"`` column of
            ``MIL_weights``.
        section_weights (np.ndarray): All columns of ``MIL_weights`` after the
            first, i.e. one row of per-view weights per section.
        candidate_studies (List[str]): ``"Study"`` column of the candidate
            study CSV.
        candidate_embeddings (torch.Tensor): Concatenation (dim 0) of the two
            candidate embedding files, on ``device``.
        candidate_reports (List[str]): Candidate reports decoded with
            ``utils.phrase_decode``.
        candidate_labels: Un-pickled label table, accessed as
            ``candidate_labels[phenotype][study]``.
        section_to_phenotypes (Dict[str, List[str]]): Section name to the
            phenotypes predicted from that section's neighbours.
    """

    # Sentinel that ``generate_report`` compares ``utils.extract_section``
    # results against to detect a missing section.
    _SECTION_NOT_FOUND: str = "Section not found."
    # Upper bound on candidates tried per section in ``generate_report``.
    _MAX_RETRIEVAL_ATTEMPTS: int = 100
    # Clips per forward pass in ``embed_videos``.
    _EMBED_BATCH_SIZE: int = 50
    # Width of the clip embedding; view encodings follow in later columns.
    _EMBED_DIM: int = 512
    # Number of view classes produced by the view classifier.
    _NUM_VIEWS: int = 11

    # Section heading translations, applied in insertion order.
    # To add a language, add another entry keyed by its language code.
    _SECTION_TRANSLATIONS: Dict[str, Dict[str, str]] = {
        "it": {
            "Left Ventricle": "Ventricolo Sinistro",
            "Resting Segmental Wall Motion Analysis": "Cinetica Segmentaria a Riposo",
            "Right Ventricle": "Ventricolo Destro",
            "Left Atrium": "Atrio Sinistro",
            "Right Atrium": "Atrio Destro",
            "Atrial Septum": "Setto Inter-Atriale",
            "Mitral Valve": "Valvola Mitrale",
            "Aortic Valve": "Valvola Aortica",
            "Tricuspid Valve": "Valvola Tricuspide",
            "Pulmonic Valve": "Valvola Polmonare",
            "Pericardium": "Pericardio",
            "Aorta": "Aorta",
            "IVC": "Vena Cava Inferiore",
            "Pulmonary Artery": "Arteria Polmonare",
            "Pulmonary Veins": "Vene Polmonari",
            "Postoperative Findings": "Esiti Post-Operatori",
        },
        "bs": {
            "Left Ventricle": "Lijeva komora",
            "Resting Segmental Wall Motion Analysis": "Analiza segmentalne pokretljivosti stijenke u mirovanju",
            "Right Ventricle": "Desna komora",
            "Left Atrium": "Lijeva pretkomora",
            "Right Atrium": "Desna pretkomora",
            "Atrial Septum": "Interatrijski septum",
            "Mitral Valve": "Mitralni zalisak",
            "Aortic Valve": "Aortni zalisak",
            "Tricuspid Valve": "Trikuspidalni zalisak",
            "Pulmonic Valve": "Pulmonalni zalisak",
            "Pericardium": "Perikard",
            "Aorta": "Aorta",
            "IVC": "Donja šuplja vena",
            "Pulmonary Artery": "Plućna arterija",
            "Pulmonary Veins": "Plućne vene",
            "Postoperative Findings": "Postoperativni nalazi",
        },
        "ru": {
            "Left Ventricle": "Левый желудочек",
            "Resting Segmental Wall Motion Analysis": "Анализ сегментарной сократимости в покое",
            "Right Ventricle": "Правый желудочек",
            "Left Atrium": "Левое предсердие",
            "Right Atrium": "Правое предсердие",
            "Atrial Septum": "Межпредсердная перегородка",
            "Mitral Valve": "Митральный клапан",
            "Aortic Valve": "Аортальный клапан",
            "Tricuspid Valve": "Трёхстворчатый клапан",
            "Pulmonic Valve": "Клапан лёгочной артерии",
            "Pericardium": "Перикард",
            "Aorta": "Аорта",
            "IVC": "Нижняя полая вена",
            "Pulmonary Artery": "Лёгочная артерия",
            "Pulmonary Veins": "Лёгочные вены",
            "Postoperative Findings": "Послеоперационные изменения",
        },
    }

    def __init__(self, device: Optional[torch.device] = None, lang: str = "en") -> None:
        """
        Load both models, the normalisation statistics, the MIL weights and
        the candidate study data from disk.

        Args:
            device (Optional[torch.device]): Device to run on. When ``None``
                (default) CUDA is used if available, otherwise CPU. Anything
                accepted by ``torch.device(...)`` works.
            lang (str): Language code forwarded to
                ``utils.initialize_language`` and used for section heading
                translation. Defaults to ``"en"``.

        Raises:
            FileNotFoundError: If the echo encoder weights file is missing.
        """
        self.base_dir: str = os.getenv("ECHOPRIME_ROOT_OVERRIDE") or os.path.dirname(
            os.path.abspath(__file__)
        )
        print(f"[EchoPrime] Initializing... (Root dir: {self.base_dir})")

        # Load language-specific files.
        utils.initialize_language(lang)

        # Honour an explicitly requested device; auto-detect only when the
        # caller did not ask for one.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(device)
        print(f"[EchoPrime] Using device: {device}")

        # ---- Video encoder -------------------------------------------------
        weights_path: str = self._asset_path("model_data/weights/echo_prime_encoder.pt")
        if not os.path.exists(weights_path):
            print(f"[ERROR] Expected weights at: {weights_path}")
            # Put the path in the exception too so it survives lost stdout.
            raise FileNotFoundError(f"Could not find model weights at: {weights_path}")

        checkpoint: Dict[str, torch.Tensor] = torch.load(weights_path, map_location=device)
        echo_encoder = torchvision.models.video.mvit_v2_s()
        echo_encoder.head[-1] = torch.nn.Linear(
            echo_encoder.head[-1].in_features, self._EMBED_DIM
        )
        echo_encoder.load_state_dict(checkpoint)
        echo_encoder.eval()
        echo_encoder.to(device)
        for param in echo_encoder.parameters():
            param.requires_grad = False

        # ---- View classifier -----------------------------------------------
        vc_state_dict: Dict[str, torch.Tensor] = torch.load(
            self._asset_path("model_data/weights/view_classifier.pt"), map_location=device
        )
        view_classifier = torchvision.models.convnext_base()
        view_classifier.classifier[-1] = torch.nn.Linear(
            view_classifier.classifier[-1].in_features, self._NUM_VIEWS
        )
        view_classifier.load_state_dict(vc_state_dict)
        view_classifier.to(device)
        view_classifier.eval()
        for param in view_classifier.parameters():
            param.requires_grad = False

        self.echo_encoder = echo_encoder
        self.view_classifier = view_classifier
        self.frames_to_take: int = 32
        self.frame_stride: int = 2
        self.video_size: int = 224
        # Statistics stay on the CPU: clips are normalised before being moved
        # to ``device`` batch by batch.
        self.mean: torch.Tensor = torch.tensor([29.110628, 28.076836, 29.096405]).reshape(3, 1, 1, 1)
        self.std: torch.Tensor = torch.tensor([47.989223, 46.456997, 47.20083]).reshape(3, 1, 1, 1)
        self.device: torch.device = device
        self.lang: str = lang

        # ---- Assets ----------------------------------------------------------
        # NOTE(security): ``torch.load`` / ``pd.read_pickle`` un-pickle data and
        # can execute code; only point ECHOPRIME_ROOT_OVERRIDE at trusted files.
        print("[EchoPrime] Loading assets...")
        self.MIL_weights: pd.DataFrame = pd.read_csv(self._asset_path("assets/MIL_weights.csv"))
        self.non_empty_sections: pd.Series = self.MIL_weights["Section"]
        self.section_weights: np.ndarray = self.MIL_weights.iloc[:, 1:].to_numpy()

        self.candidate_studies: List[str] = list(
            pd.read_csv(self._asset_path("model_data/candidates_data/candidate_studies.csv"))["Study"]
        )
        embedding_parts = [
            torch.load(
                self._asset_path(f"model_data/candidates_data/candidate_embeddings_{part}.pt"),
                map_location=device,
            )
            for part in ("p1", "p2")
        ]
        self.candidate_embeddings: torch.Tensor = torch.cat(embedding_parts, dim=0)

        encoded_reports = pd.read_pickle(
            self._asset_path("model_data/candidates_data/candidate_reports.pkl")
        )
        self.candidate_reports: List[str] = [
            utils.phrase_decode(vec_phr) for vec_phr in encoded_reports
        ]
        self.candidate_labels = pd.read_pickle(
            self._asset_path("model_data/candidates_data/candidate_labels.pkl")
        )
        self.section_to_phenotypes: Dict[str, List[str]] = pd.read_pickle(
            self._asset_path("assets/section_to_phenotypes.pkl")
        )
        print("[EchoPrime] Initialization Complete.")

    def _asset_path(self, rel_path: str) -> str:
        """Return ``rel_path`` joined onto ``self.base_dir``."""
        return os.path.join(self.base_dir, rel_path)

    def _frames_to_clip(self, pixels: np.ndarray) -> torch.Tensor:
        """
        Convert decoded frames into one normalised, sub-sampled clip tensor.

        Args:
            pixels (np.ndarray): Frames indexed along axis 0; each frame is
                passed to ``utils.crop_and_scale``, whose result must fit a
                ``(video_size, video_size, 3)`` slot.

        Returns:
            torch.Tensor: Float tensor of shape
            ``(3, ceil(frames_to_take / frame_stride), video_size, video_size)``.

        Raises:
            ValueError: If ``pixels`` contains no frames (such a clip would
                consist of padding only).
        """
        if len(pixels) == 0:
            raise ValueError("video contains no frames")

        frames = np.zeros((len(pixels), self.video_size, self.video_size, 3))
        for i, frame in enumerate(pixels):
            frames[i] = utils.crop_and_scale(frame)

        # (T, H, W, C) -> (C, T, H, W), then normalise per channel.
        clip = torch.as_tensor(frames, dtype=torch.float).permute([3, 0, 1, 2])
        clip.sub_(self.mean).div_(self.std)

        # Short clips are zero-padded (after normalisation) up to frames_to_take.
        missing = self.frames_to_take - clip.shape[1]
        if missing > 0:
            padding = torch.zeros(
                (3, missing, self.video_size, self.video_size), dtype=torch.float
            )
            clip = torch.cat((clip, padding), dim=1)

        # Always sample from the start of the clip.
        return clip[:, 0 : self.frames_to_take : self.frame_stride, :, :]

    def process_dicoms(self, INPUT: str) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Recursively find ``.dcm`` files under ``INPUT``, pre-process every
        video among them and stack the resulting clips.

        Arrays with fewer than three dimensions and three-dimensional arrays
        whose last axis has length 3 are treated as still images and skipped.
        Remaining three-dimensional arrays are treated as single-channel video
        and replicated to three channels before masking.

        Args:
            INPUT (str): Directory to search.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: Stacked clips of shape
            ``(N, 3, T, H, W)``, or ``torch.empty(0)`` when no file was found
            or none could be processed.
        """
        print(f"[EchoPrime] Scanning for DICOMs in: {INPUT}")
        dicom_paths: List[str] = glob.glob(f"{INPUT}/**/*.dcm", recursive=True)
        if not dicom_paths:
            print(f"[ERROR] No .dcm files found in {INPUT}")
            return torch.empty(0)

        clips: List[torch.Tensor] = []
        skipped_count = 0
        print(f"Found {len(dicom_paths)} DICOM files. Processing...")
        for dicom_path in tqdm(dicom_paths, total=len(dicom_paths), desc="Processing"):
            try:
                dcm = pydicom.dcmread(dicom_path)
                pixels: np.ndarray = dcm.pixel_array

                # (H, W): no time axis.
                if pixels.ndim < 3:
                    fname = os.path.basename(dicom_path)
                    print(f" > Skipped {fname}: Static 2D Image (Shape: {pixels.shape})")
                    skipped_count += 1
                    continue
                # (H, W, 3): third axis is colour, not time.
                if pixels.ndim == 3 and pixels.shape[2] == 3:
                    fname = os.path.basename(dicom_path)
                    print(f" > Skipped {fname}: Static RGB Screenshot (Shape: {pixels.shape})")
                    skipped_count += 1
                    continue

                # Single-channel video: replicate to three channels.
                if pixels.ndim == 3:
                    pixels = np.repeat(pixels[..., None], 3, axis=3)

                # Mask the converted array; re-reading ``dcm.pixel_array``
                # here would throw the channel replication away.
                pixels = utils.mask_outside_ultrasound(pixels)
                clips.append(self._frames_to_clip(pixels))
            except Exception as e:
                # Best effort: one unreadable file must not abort the study.
                print(f"Corrupt file {dicom_path}: {e}")

        if not clips:
            print("[ERROR] Found DICOMs but failed to process ANY of them.")
            return torch.empty(0)

        stacked: torch.Tensor = torch.stack(clips)
        print(f"\n[Summary] Total: {len(dicom_paths)} | Processed: {len(stacked)} | Skipped: {skipped_count}")
        return stacked

    def process_mp4s(self, INPUT: str) -> torch.Tensor:
        """
        Recursively find ``.mp4`` files under ``INPUT``, pre-process each one
        and stack the resulting clips.

        Args:
            INPUT (str): Directory to search.

        Returns:
            torch.Tensor: Stacked clips of shape ``(N, 3, T, H, W)``, or
            ``torch.empty(0)`` when no file could be processed. Files that
            fail to decode are reported on stdout and skipped.
        """
        video_paths: List[str] = glob.glob(f"{INPUT}/**/*.mp4", recursive=True)
        clips: List[torch.Tensor] = []
        for video_path in video_paths:
            try:
                frames, _, _ = torchvision.io.read_video(video_path)
                clips.append(self._frames_to_clip(np.array(frames)))
            except Exception as e:
                # Best effort: one unreadable file must not abort the study.
                print("corrupt file")
                print(str(e))

        if not clips:
            # torch.stack([]) would raise; mirror process_dicoms instead.
            print(f"[ERROR] No usable .mp4 files found in {INPUT}")
            return torch.empty(0)
        return torch.stack(clips)

    def embed_videos(self, stack_of_videos: torch.Tensor) -> torch.Tensor:
        """
        Run the echo encoder over the clips in batches with gradients
        disabled.

        Args:
            stack_of_videos (torch.Tensor): Clips of shape ``(N, 3, T, H, W)``.

        Returns:
            torch.Tensor: Encoder outputs concatenated along dim 0, one row
            per clip, or ``torch.empty(0)`` for an empty input.
        """
        if stack_of_videos.numel() == 0:
            return torch.empty(0)

        features: List[torch.Tensor] = []
        with torch.no_grad():
            # Batching bounds peak device memory on large studies.
            for batch in torch.split(stack_of_videos, self._EMBED_BATCH_SIZE):
                features.append(self.echo_encoder(batch.to(self.device)))
        return torch.cat(features, dim=0)

    def get_views(
        self,
        stack_of_videos: torch.Tensor,
        visualize: bool = False,
        return_view_list: bool = False,
    ) -> Union[torch.Tensor, List[str]]:
        """
        Classify the view of each clip from its first frame.

        Args:
            stack_of_videos (torch.Tensor): Clips of shape ``(N, 3, T, H, W)``.
            visualize (bool): When ``True``, show a grid of labelled first
                frames. Defaults to ``False``.
            return_view_list (bool): When ``True``, return view names from
                ``utils.COARSE_VIEWS`` instead of one-hot encodings.
                Defaults to ``False``.

        Returns:
            Union[torch.Tensor, List[str]]: One-hot encodings of shape
            ``(N, 11)`` on ``self.device``, or a list of ``N`` view names.
            An empty input always yields ``torch.empty(0)``.
        """
        if stack_of_videos.numel() == 0:
            return torch.empty(0)

        first_frames: torch.Tensor = stack_of_videos[:, :, 0, :, :].to(self.device)
        with torch.no_grad():
            out_logits: torch.Tensor = self.view_classifier(first_frames)
        out_views: torch.Tensor = torch.argmax(out_logits, dim=1)
        view_list: List[str] = [utils.COARSE_VIEWS[v] for v in out_views.tolist()]
        view_encodings: torch.Tensor = torch.nn.functional.one_hot(
            out_views, self._NUM_VIEWS
        ).to(self.device)

        if visualize:
            self._show_view_grid(first_frames, view_list)

        return view_list if return_view_list else view_encodings

    def _show_view_grid(self, first_frames: torch.Tensor, view_list: List[str]) -> None:
        """Display the first frame of every clip with its view label drawn on it."""
        cols = 12
        rows = (len(view_list) + cols - 1) // cols  # ceiling division
        print(f"[EchoPrime] Visualizing {len(view_list)} views in grid {rows}x{cols}")
        _, axes = plt.subplots(rows, cols, figsize=(cols, rows))
        axes = axes.flatten()
        for ax, frame, label in zip(axes, first_frames, view_list):
            display_image = (frame.cpu().permute([1, 2, 0]) * 255).numpy()
            display_image = np.clip(display_image, 0, 255).astype("uint8")
            # OpenCV drawing needs a contiguous buffer.
            display_image = np.ascontiguousarray(display_image)
            display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
            cv2.putText(
                display_image,
                label.replace("_", " "),
                (10, 25),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, 220, 255),
                2,
            )
            ax.imshow(display_image)
        # Hide the axes of used and unused cells alike.
        for ax in axes:
            ax.axis("off")
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        plt.show()

    def encode_study(
        self, stack_of_videos: torch.Tensor, visualize: bool = False
    ) -> torch.Tensor:
        """
        Concatenate each clip's embedding with its one-hot view encoding.

        Args:
            stack_of_videos (torch.Tensor): Clips of shape ``(N, 3, T, H, W)``.
            visualize (bool): Forwarded to ``get_views``. Defaults to ``False``.

        Returns:
            torch.Tensor: ``embed_videos`` output followed column-wise by the
            ``get_views`` encodings (512 + 11 columns per clip), or
            ``torch.empty(0)`` for an empty input.
        """
        if stack_of_videos.numel() == 0:
            print("[ERROR] Cannot encode empty video stack.")
            return torch.empty(0)
        stack_of_features = self.embed_videos(stack_of_videos)
        stack_of_view_encodings = self.get_views(stack_of_videos, visualize)
        return torch.cat((stack_of_features, stack_of_view_encodings), dim=1)

    def translate_sections(self, report: str) -> str:
        """
        Replace English section names in ``report`` with their translation
        for ``self.lang``.

        Replacement is a plain substring substitution applied to the whole
        text, so a section name that also occurs inside body text is replaced
        there as well.

        Args:
            report (str): Report text containing English section names.

        Returns:
            str: The translated report; ``report`` unchanged when no
            translation table exists for ``self.lang``.
        """
        for section, translated in self._SECTION_TRANSLATIONS.get(self.lang, {}).items():
            report = report.replace(section, translated)
        return report

    def generate_report(self, study_embedding: torch.Tensor) -> str:
        """
        Build a report by retrieving, for every section, the section text of
        the most similar candidate report that contains it.

        For each entry of ``non_empty_sections`` the clip embeddings are
        scaled by that section's weight for the clip's view, averaged,
        L2-normalised and compared with ``candidate_embeddings`` by dot
        product. Candidates are tried from most to least similar (at most 100)
        until ``utils.extract_section`` yields text; a section for which none
        does is left out of the report.

        Args:
            study_embedding (torch.Tensor): Output of ``encode_study``: clip
                embeddings in the first 512 columns, one-hot view encodings in
                the remaining columns.

        Returns:
            str: Concatenated section texts, translated when ``self.lang`` is
            not ``"en"``. ``"No data available to generate report."`` for an
            empty input.
        """
        if study_embedding.numel() == 0:
            return "No data available to generate report."

        print("[EchoPrime] Generating clinical report...")
        study_embedding = study_embedding.cpu()
        clip_embeddings = study_embedding[:, : self._EMBED_DIM]
        # Column index of the hot entry = view id of each clip.
        view_ids = study_embedding[:, self._EMBED_DIM :].argmax(dim=1)
        section_weights = torch.as_tensor(self.section_weights, dtype=torch.float)

        report_parts: List[str] = []
        for s_dx, sec in enumerate(self.non_empty_sections):
            clip_weights = section_weights[s_dx][view_ids].unsqueeze(1)  # (N, 1)
            section_embedding = torch.mean(clip_embeddings * clip_weights, dim=0)
            section_embedding = torch.nn.functional.normalize(section_embedding, dim=0)
            # Candidate embeddings live on self.device.
            similarities = section_embedding.to(self.device) @ self.candidate_embeddings.T

            for _ in range(min(self._MAX_RETRIEVAL_ATTEMPTS, similarities.numel())):
                max_id = int(torch.argmax(similarities).item())
                extracted = utils.extract_section(self.candidate_reports[max_id], sec)
                if extracted != self._SECTION_NOT_FOUND:
                    report_parts.append(extracted)
                    break
                # Exclude this candidate from the next argmax.
                similarities[max_id] = float("-inf")

        generated_report = "".join(report_parts)
        if self.lang != "en":
            generated_report = self.translate_sections(generated_report)
        return generated_report

    def predict_metrics(self, study_embedding: torch.Tensor, k: int = 50) -> Dict[str, float]:
        """
        Predict phenotype values by averaging the labels of the ``k`` most
        similar candidate studies per section.

        A section embedding is the view-weighted sum of the clip embeddings,
        L2-normalised; similarity is the dot product with
        ``candidate_embeddings``. Each phenotype listed under a section in
        ``section_to_phenotypes`` uses that section's neighbours. Sections are
        matched to ``non_empty_sections`` by name, falling back to the
        position in ``section_to_phenotypes`` when the name is not present.

        Args:
            study_embedding (torch.Tensor): Output of ``encode_study``: clip
                embeddings in the first 512 columns, one-hot view encodings in
                the remaining columns.
            k (int): Number of neighbours per section. Values larger than the
                number of candidates are clamped. Defaults to ``50``.

        Returns:
            Dict[str, float]: Phenotype name to the NaN-ignoring mean of its
            neighbours' labels; ``numpy.nan`` when no neighbour has a label.
            ``{}`` for an empty input.
        """
        if study_embedding.numel() == 0:
            return {}

        print("[EchoPrime] Predicting metrics...")
        study_embedding = study_embedding.cpu()
        clip_embeddings = study_embedding[:, : self._EMBED_DIM]
        # Column index of the hot entry = view id of each clip.
        view_ids = study_embedding[:, self._EMBED_DIM :].argmax(dim=1)

        # (n_sections, N) weights times (N, 512) embeddings: the weighted sum
        # over clips for every section in one product.
        clip_weights = torch.as_tensor(self.section_weights, dtype=torch.float)[:, view_ids]
        per_section_embedding = clip_weights @ clip_embeddings
        per_section_embedding = torch.nn.functional.normalize(per_section_embedding)

        # Candidate embeddings live on self.device.
        similarities = per_section_embedding.to(self.device) @ self.candidate_embeddings.T
        # topk raises when k exceeds the number of candidates.
        k = min(k, similarities.shape[1])
        top_candidate_ids: List[List[int]] = (
            torch.topk(similarities, k=k, dim=1).indices.cpu().tolist()
        )

        section_rows = {name: row for row, name in enumerate(self.non_empty_sections)}
        preds: Dict[str, float] = {}
        for position, (section, phenotypes) in enumerate(self.section_to_phenotypes.items()):
            neighbours = [
                self.candidate_studies[c_id]
                for c_id in top_candidate_ids[section_rows.get(section, position)]
            ]
            for pheno in phenotypes:
                labels = self.candidate_labels[pheno]
                values = [labels[study] for study in neighbours if study in labels]
                preds[pheno] = np.nanmean(values) if values else np.nan
        return preds
class EchoPrimeTextEncoder(torch.nn.Module):
    """
    Text encoder that maps a report string to a 512-dimensional vector.

    The backbone is a masked-language model built from the configuration of
    ``microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract``. The final-layer
    hidden state of the first token is projected from 768 to 512 dimensions
    by a linear layer. Forward passes run under ``torch.no_grad()``.

    Reports longer than 512 tokens are truncated by the tokenizer.

    NOTE(review): the backbone is created with ``from_config`` and no
    checkpoint is loaded here; callers presumably load a state dict and call
    ``.eval()`` before use — confirm against the calling code.

    Attributes:
        device (str): Device identifier the module was moved to.
        backbone: Masked-LM model built from the BiomedBERT configuration.
        text_projection (torch.nn.Linear): 768 -> 512 projection.
        tokenizer: Tokenizer loaded for the same model name.
    """

    _MODEL_NAME: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
    _MAX_TOKENS: int = 512

    def __init__(self, device: str = "cuda") -> None:
        """
        Build the backbone, projection layer and tokenizer and move the
        module to ``device``.

        Args:
            device (str): Device identifier passed to ``self.to()``.
                Defaults to ``"cuda"``.
        """
        super().__init__()
        self.device: str = device
        config = transformers.AutoConfig.from_pretrained(self._MODEL_NAME)
        self.backbone = transformers.AutoModelForMaskedLM.from_config(config)
        self.text_projection: torch.nn.Linear = torch.nn.Linear(768, 512)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self._MODEL_NAME)
        # Kept for compatibility with code that reads this attribute; the
        # effective limit is the ``max_length`` passed on every call below.
        self.tokenizer.max_length = self._MAX_TOKENS
        self.to(device)

    def forward(self, report: str) -> torch.Tensor:
        """
        Encode ``report`` into a 512-dimensional embedding.

        The text is tokenised, padded to exactly 512 tokens and truncated if
        longer, so the backbone always sees a ``(1, 512)`` input. The
        final-layer hidden state at position 0 is passed through
        ``text_projection``.

        Args:
            report (str): Report text.

        Returns:
            torch.Tensor: Tensor of shape ``(1, 512)``; it is not
            L2-normalised.
        """
        # Padding + truncation fix the sequence length at _MAX_TOKENS, so no
        # further windowing of long inputs is needed (or possible) afterwards.
        text = self.tokenizer(
            report,
            padding="max_length",
            max_length=self._MAX_TOKENS,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            text = text.to(self.device)
            last_hidden = self.backbone(**text, output_hidden_states=True).hidden_states[-1]
            text_emb: torch.Tensor = self.text_projection(last_hidden[:, 0, :])
        return text_emb