# Standard library imports
import os
import math
import glob
import json
import pickle
import random
import sys
from typing import AnyStr, List, Any, Dict, Optional, Union

# Third-party library imports
import torch
import torchvision
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import pydicom
import sklearn
import sklearn.metrics
import transformers

# Local module imports
import utils


class EchoPrime:
    """
    EchoPrime is an echocardiography AI model that encodes cardiac ultrasound
    studies (DICOM or MP4) into embeddings, classifies echocardiographic views,
    generates structured clinical reports, and predicts quantitative cardiac
    metrics via multi-instance learning (MIL) over a candidate study database.

    Attributes:
        base_dir (str): Absolute path to the EchoPrime project root directory
            (``$ECHOPRIME_ROOT_OVERRIDE`` if set, else this file's directory).
        echo_encoder (torchvision.models.video.MViT): Frozen MViT-v2-S video
            encoder producing 512-dimensional embeddings per video clip.
        view_classifier (torchvision.models.ConvNeXt): Frozen ConvNeXt-Base
            image classifier predicting one of 11 echocardiographic views from
            the first frame of each clip.
        frames_to_take (int): Number of frames taken from each video clip (32).
        frame_stride (int): Temporal stride applied when sampling frames (2).
        video_size (int): Spatial resolution (height and width) of each frame
            fed to the models (224 pixels).
        mean (torch.Tensor): Per-channel pixel mean used for normalisation,
            shape (3, 1, 1, 1).
        std (torch.Tensor): Per-channel pixel standard deviation used for
            normalisation, shape (3, 1, 1, 1).
        device (torch.device): Compute device: the ``device`` constructor
            argument when given, otherwise CUDA if available, else CPU.
        lang (str): ISO 639-1 language code controlling report output language.
        MIL_weights (pd.DataFrame): CSV-loaded table of per-section MIL
            attention weights; first column is ``Section``, remaining columns
            are per-view weights.
        non_empty_sections (pd.Series): Ordered cardiac section names taken
            from the ``Section`` column of ``MIL_weights``.
        section_weights (np.ndarray): Numeric weight matrix extracted from
            ``MIL_weights`` (all columns after the first), one row per section.
        candidate_studies (List[str]): Ordered list of candidate study
            identifiers used for nearest-neighbour retrieval.
        candidate_embeddings (torch.Tensor): Concatenation of the two candidate
            embedding shards along dim 0, on ``device``.
        candidate_reports (List[str]): Decoded text reports for each candidate
            study, aligned index-wise with the rows of ``candidate_embeddings``.
        candidate_labels (pd.DataFrame): Ground-truth phenotype labels for the
            candidate studies; ``candidate_labels[pheno]`` is indexed by study id.
        section_to_phenotypes (Dict[str, List[str]]): Mapping from cardiac
            section name to the list of phenotype labels predicted for it.
    """

    # Sentinel compared against the output of ``utils.extract_section`` when a
    # candidate report does not contain the requested section.
    _SECTION_NOT_FOUND: str = "Section not found."

    # Section-heading translations used by ``translate_sections``, keyed by
    # language code. Replacements are applied in insertion order. To add a
    # language, add another ``"<code>": {english_heading: translation, ...}``
    # entry here.
    _SECTION_TRANSLATIONS: Dict[str, Dict[str, str]] = {
        "it": {
            "Left Ventricle": "Ventricolo Sinistro",
            "Resting Segmental Wall Motion Analysis": "Cinetica Segmentaria a Riposo",
            "Right Ventricle": "Ventricolo Destro",
            "Left Atrium": "Atrio Sinistro",
            "Right Atrium": "Atrio Destro",
            "Atrial Septum": "Setto Inter-Atriale",
            "Mitral Valve": "Valvola Mitrale",
            "Aortic Valve": "Valvola Aortica",
            "Tricuspid Valve": "Valvola Tricuspide",
            "Pulmonic Valve": "Valvola Polmonare",
            "Pericardium": "Pericardio",
            "Aorta": "Aorta",
            "IVC": "Vena Cava Inferiore",
            "Pulmonary Artery": "Arteria Polmonare",
            "Pulmonary Veins": "Vene Polmonari",
            "Postoperative Findings": "Esiti Post-Operatori",
        },
        "bs": {
            "Left Ventricle": "Lijeva komora",
            "Resting Segmental Wall Motion Analysis": "Analiza segmentalne pokretljivosti stijenke u mirovanju",
            "Right Ventricle": "Desna komora",
            "Left Atrium": "Lijeva pretkomora",
            "Right Atrium": "Desna pretkomora",
            "Atrial Septum": "Interatrijski septum",
            "Mitral Valve": "Mitralni zalisak",
            "Aortic Valve": "Aortni zalisak",
            "Tricuspid Valve": "Trikuspidalni zalisak",
            "Pulmonic Valve": "Pulmonalni zalisak",
            "Pericardium": "Perikard",
            "Aorta": "Aorta",
            "IVC": "Donja šuplja vena",
            "Pulmonary Artery": "Plućna arterija",
            "Pulmonary Veins": "Plućne vene",
            "Postoperative Findings": "Postoperativni nalazi",
        },
        "ru": {
            "Left Ventricle": "Левый желудочек",
            "Resting Segmental Wall Motion Analysis": "Анализ сегментарной сократимости в покое",
            "Right Ventricle": "Правый желудочек",
            "Left Atrium": "Левое предсердие",
            "Right Atrium": "Правое предсердие",
            "Atrial Septum": "Межпредсердная перегородка",
            "Mitral Valve": "Митральный клапан",
            "Aortic Valve": "Аортальный клапан",
            "Tricuspid Valve": "Трёхстворчатый клапан",
            "Pulmonic Valve": "Клапан лёгочной артерии",
            "Pericardium": "Перикард",
            "Aorta": "Аорта",
            "IVC": "Нижняя полая вена",
            "Pulmonary Artery": "Лёгочная артерия",
            "Pulmonary Veins": "Лёгочные вены",
            "Postoperative Findings": "Послеоперационные изменения",
        },
    }

    def __init__(self, device: Optional[torch.device] = None, lang: str = "en") -> None:
        """
        Initialise EchoPrime by loading model weights, normalisation statistics,
        MIL attention weights, and candidate study data.

        Args:
            device (Optional[torch.device]): Compute device to use (anything
                accepted by ``torch.device``). When ``None`` (default), CUDA is
                used if available, otherwise CPU.
            lang (str): ISO 639-1 language code for report generation.
                Built-in section translations exist for ``'it'``, ``'bs'`` and
                ``'ru'``; ``'en'`` (default) leaves reports untranslated.

        Raises:
            FileNotFoundError: If the echo encoder weights file cannot be
                located at the expected path relative to ``base_dir``.
        """
        self.base_dir: str = os.getenv("ECHOPRIME_ROOT_OVERRIDE") or os.path.dirname(
            os.path.abspath(__file__)
        )

        def get_path(rel_path: str) -> str:
            """Resolve ``rel_path`` against the EchoPrime project root."""
            return os.path.join(self.base_dir, rel_path)

        print(f"[EchoPrime] Initializing... (Root dir: {self.base_dir})")

        # load language specific files
        utils.initialize_language(lang)

        # Honour an explicitly requested device; auto-detect only when none is given.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(device)
        print(f"[EchoPrime] Using device: {device}")

        # LOAD MODEL WEIGHTS
        # NOTE: torch.load / pd.read_pickle unpickle their inputs; only point
        # base_dir at trusted model data.
        weights_path: str = get_path("model_data/weights/echo_prime_encoder.pt")
        if not os.path.exists(weights_path):
            print(f"[ERROR] Expected weights at: {weights_path}")
            raise FileNotFoundError(f"Could not find model weights at: {weights_path}")

        echo_encoder: torchvision.models.video.MViT = torchvision.models.video.mvit_v2_s()
        # Replace the classification head with a 512-d embedding projection.
        echo_encoder.head[-1] = torch.nn.Linear(echo_encoder.head[-1].in_features, 512)
        self.echo_encoder: torchvision.models.video.MViT = self._freeze(
            echo_encoder, torch.load(weights_path, map_location=device), device
        )

        view_classifier: torchvision.models.ConvNeXt = torchvision.models.convnext_base()
        # 11 output classes, one per echocardiographic view.
        view_classifier.classifier[-1] = torch.nn.Linear(
            view_classifier.classifier[-1].in_features, 11
        )
        self.view_classifier: torchvision.models.ConvNeXt = self._freeze(
            view_classifier,
            torch.load(get_path("model_data/weights/view_classifier.pt"), map_location=device),
            device,
        )

        self.frames_to_take: int = 32
        self.frame_stride: int = 2
        self.video_size: int = 224
        self.mean: torch.Tensor = torch.tensor([29.110628, 28.076836, 29.096405]).reshape(3, 1, 1, 1)
        self.std: torch.Tensor = torch.tensor([47.989223, 46.456997, 47.20083]).reshape(3, 1, 1, 1)
        self.device: torch.device = device
        self.lang: str = lang

        # LOAD ASSETS
        print("[EchoPrime] Loading assets...")
        self.MIL_weights: pd.DataFrame = pd.read_csv(get_path("assets/MIL_weights.csv"))
        self.non_empty_sections: pd.Series = self.MIL_weights["Section"]
        self.section_weights: np.ndarray = self.MIL_weights.iloc[:, 1:].to_numpy()

        self.candidate_studies: List[str] = list(
            pd.read_csv(get_path("model_data/candidates_data/candidate_studies.csv"))["Study"]
        )
        # Candidate embeddings are stored as two shards; concatenated in order.
        candidate_embeddings_p1: torch.Tensor = torch.load(
            get_path("model_data/candidates_data/candidate_embeddings_p1.pt"), map_location=device
        )
        candidate_embeddings_p2: torch.Tensor = torch.load(
            get_path("model_data/candidates_data/candidate_embeddings_p2.pt"), map_location=device
        )
        self.candidate_embeddings: torch.Tensor = torch.cat(
            (candidate_embeddings_p1, candidate_embeddings_p2), dim=0
        )
        candidate_reports: pd.Series = pd.read_pickle(
            get_path("model_data/candidates_data/candidate_reports.pkl")
        )
        self.candidate_reports: List[str] = [
            utils.phrase_decode(vec_phr) for vec_phr in candidate_reports
        ]
        self.candidate_labels: pd.DataFrame = pd.read_pickle(
            get_path("model_data/candidates_data/candidate_labels.pkl")
        )
        self.section_to_phenotypes: Dict[str, List[str]] = pd.read_pickle(
            get_path("assets/section_to_phenotypes.pkl")
        )
        print("[EchoPrime] Initialization Complete.")

    @staticmethod
    def _freeze(
        model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], device: torch.device
    ) -> torch.nn.Module:
        """Load ``state_dict`` into ``model``, move it to ``device``, set eval mode and disable gradients."""
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        return model

    def _frames_to_clip(self, pixels: np.ndarray) -> torch.Tensor:
        """
        Turn a decoded frame stack into one model-ready clip.

        Each frame is cropped/resized with ``utils.crop_and_scale``, the stack is
        normalised with ``mean``/``std``, zero-padded along time up to
        ``frames_to_take`` frames, then the first ``frames_to_take`` frames are
        sampled with step ``frame_stride``.

        Args:
            pixels (np.ndarray): Frames indexed along axis 0. Each
                ``utils.crop_and_scale(frame)`` result must fit a
                ``(video_size, video_size, 3)`` slot.

        Returns:
            torch.Tensor: Float32 tensor of shape
            ``(3, ceil(frames_to_take / frame_stride), video_size, video_size)``
            — ``(3, 16, 224, 224)`` with the default settings.
        """
        x: np.ndarray = np.zeros((len(pixels), self.video_size, self.video_size, 3))
        for i, frame in enumerate(pixels):
            x[i] = utils.crop_and_scale(frame)
        # (T, H, W, C) -> (C, T, H, W)
        clip: torch.Tensor = torch.as_tensor(x, dtype=torch.float).permute([3, 0, 1, 2])
        clip.sub_(self.mean).div_(self.std)

        # If there are not enough frames, append zero frames (zero in normalised space).
        n_missing: int = self.frames_to_take - clip.shape[1]
        if n_missing > 0:
            padding: torch.Tensor = torch.zeros(
                (3, n_missing, self.video_size, self.video_size), dtype=torch.float
            )
            clip = torch.cat((clip, padding), dim=1)

        return clip[:, : self.frames_to_take : self.frame_stride, :, :]

    def process_dicoms(self, INPUT: str) -> torch.Tensor:
        """
        Scan a directory tree for DICOM video files, decode each file's pixel
        data, apply spatial pre-processing and temporal sampling, and return a
        stacked tensor ready for ``embed_videos``.

        Static 2D images (``pixels.ndim < 3``) and static RGB screenshots (3-D
        arrays whose last dimension is 3) are skipped with an informational
        message. Single-channel videos are expanded to 3 channels before
        masking. Files that raise during processing are reported and skipped.

        Args:
            INPUT (str): Path to a directory (searched recursively) that
                contains ``.dcm`` files.

        Returns:
            torch.Tensor: A float32 tensor of shape ``(N, 3, 16, H, W)`` where
            *N* is the number of successfully processed video DICOMs,
            *H* = *W* = ``video_size`` (224). Returns ``torch.empty(0)`` when no
            DICOMs are found or every file is skipped/fails.
        """
        print(f"[EchoPrime] Scanning for DICOMs in: {INPUT}")
        dicom_paths: List[str] = glob.glob(f"{INPUT}/**/*.dcm", recursive=True)
        if not dicom_paths:
            print(f"[ERROR] No .dcm files found in {INPUT}")
            return torch.empty(0)

        stack_of_videos: List[torch.Tensor] = []
        skipped_count: int = 0
        print(f"Found {len(dicom_paths)} DICOM files. Processing...")

        for dicom_path in tqdm(dicom_paths, desc="Processing"):
            try:
                dcm: pydicom.dataset.FileDataset = pydicom.dcmread(dicom_path)
                pixels: np.ndarray = dcm.pixel_array
                fname: str = os.path.basename(dicom_path)

                # 2D image (Height, Width): no time dimension.
                if pixels.ndim < 3:
                    print(f" > Skipped {fname}: Static 2D Image (Shape: {pixels.shape})")
                    skipped_count += 1
                    continue

                # (Height, Width, 3): the 3rd dimension is colour, not time.
                if pixels.ndim == 3 and pixels.shape[2] == 3:
                    print(f" > Skipped {fname}: Static RGB Screenshot (Shape: {pixels.shape})")
                    skipped_count += 1
                    continue

                # Single-channel video (T, H, W): repeat to 3 channels.
                if pixels.ndim == 3:
                    pixels = np.repeat(pixels[..., None], 3, axis=3)

                # Mask the channel-expanded frames (not the raw pixel_array) so
                # grayscale videos reach masking/cropping as (T, H, W, 3).
                pixels = utils.mask_outside_ultrasound(pixels)
                stack_of_videos.append(self._frames_to_clip(pixels))
            except Exception as e:
                # Best-effort batch processing: report and move on.
                print(f"Corrupt file {dicom_path}: {e}")

        if not stack_of_videos:
            print("[ERROR] Found DICOMs but failed to process ANY of them.")
            return torch.empty(0)

        stacked: torch.Tensor = torch.stack(stack_of_videos)
        print(f"\n[Summary] Total: {len(dicom_paths)} | Processed: {len(stacked)} | Skipped: {skipped_count}")
        return stacked

    def process_mp4s(self, INPUT: str) -> torch.Tensor:
        """
        Scan a directory tree for MP4 video files, decode each file's frames,
        apply spatial pre-processing and temporal sampling, and return a
        stacked tensor ready for ``embed_videos``.

        Args:
            INPUT (str): Path to a directory (searched recursively) that
                contains ``.mp4`` files.

        Returns:
            torch.Tensor: A float32 tensor of shape ``(N, 3, 16, H, W)`` where
            *N* is the number of successfully processed MP4 files and
            *H* = *W* = ``video_size`` (224). Files that fail to decode (or
            decode without video frames / fps metadata) are reported and
            skipped. Returns ``torch.empty(0)`` when no file could be processed.
        """
        mp4_paths: List[str] = glob.glob(f"{INPUT}/**/*.mp4", recursive=True)
        stack_of_videos: List[torch.Tensor] = []
        for mp4_path in mp4_paths:
            try:
                frames: torch.Tensor
                info: Dict[str, Any]
                frames, _, info = torchvision.io.read_video(mp4_path)
                if frames.shape[0] == 0 or "video_fps" not in info:
                    raise ValueError("no decodable video stream")
                stack_of_videos.append(self._frames_to_clip(np.array(frames)))
            except Exception as e:
                print("corrupt file")
                print(str(e))

        if not stack_of_videos:
            print(f"[ERROR] No .mp4 videos could be processed in {INPUT}")
            return torch.empty(0)
        return torch.stack(stack_of_videos)

    def embed_videos(self, stack_of_videos: torch.Tensor, bin_size: int = 50) -> torch.Tensor:
        """
        Pass pre-processed video clips through the frozen echo encoder in
        batches and return the resulting feature embeddings.

        Args:
            stack_of_videos (torch.Tensor): Float32 tensor of shape
                ``(N, 3, T, H, W)`` as produced by ``process_dicoms`` or
                ``process_mp4s``.
            bin_size (int): Number of clips forwarded per batch, to bound
                memory use. Defaults to ``50``.

        Returns:
            torch.Tensor: Feature tensor of shape ``(N, 512)`` on ``device``.
            Returns ``torch.empty(0)`` if ``stack_of_videos`` has no elements.

        Raises:
            ValueError: If ``bin_size`` is smaller than 1.
        """
        if stack_of_videos.numel() == 0:
            return torch.empty(0)
        if bin_size < 1:
            raise ValueError(f"bin_size must be >= 1, got {bin_size}")

        stack_of_features_list: List[torch.Tensor] = []
        with torch.no_grad():
            for start_idx in range(0, stack_of_videos.shape[0], bin_size):
                bin_videos: torch.Tensor = stack_of_videos[start_idx : start_idx + bin_size].to(self.device)
                stack_of_features_list.append(self.echo_encoder(bin_videos))
        return torch.cat(stack_of_features_list, dim=0)

    def get_views(
        self,
        stack_of_videos: torch.Tensor,
        visualize: bool = False,
        return_view_list: bool = False,
    ) -> Union[torch.Tensor, List[str]]:
        """
        Predict the echocardiographic view of each clip by running the frozen
        view classifier on the clip's first frame.

        Args:
            stack_of_videos (torch.Tensor): Float32 tensor of shape
                ``(N, 3, T, H, W)`` as produced by ``process_dicoms`` or
                ``process_mp4s`` (i.e. normalised with ``mean``/``std``).
            visualize (bool): When ``True``, show a 12-column grid of first
                frames annotated with their predicted view. Defaults to ``False``.
            return_view_list (bool): When ``True``, return the list of view
                names instead of the one-hot tensor. Defaults to ``False``.

        Returns:
            Union[torch.Tensor, List[str]]:
              - ``return_view_list=False`` (default): one-hot tensor of shape
                ``(N, 11)`` on ``device``.
              - ``return_view_list=True``: list of *N* names from
                ``utils.COARSE_VIEWS``.
              - ``torch.empty(0)`` when ``stack_of_videos`` has no elements.
        """
        if stack_of_videos.numel() == 0:
            return torch.empty(0)

        stack_of_first_frames: torch.Tensor = stack_of_videos[:, :, 0, :, :].to(self.device)
        with torch.no_grad():
            out_logits: torch.Tensor = self.view_classifier(stack_of_first_frames)
        out_views: torch.Tensor = torch.argmax(out_logits, dim=1)
        view_list: List[str] = [utils.COARSE_VIEWS[v] for v in out_views.tolist()]
        stack_of_view_encodings: torch.Tensor = F.one_hot(out_views, 11).to(self.device)

        if visualize:
            cols: int = 12
            rows: int = (len(view_list) + cols - 1) // cols
            print(f"[EchoPrime] Visualizing {len(view_list)} views in grid {rows}x{cols}")
            _, axes = plt.subplots(rows, cols, figsize=(cols, rows))
            axes = axes.flatten()
            # Invert the (x - mean) / std normalisation so frames are displayed on
            # their original 0-255 intensity scale instead of saturating.
            frame_mean: torch.Tensor = self.mean[:, 0]  # (3, 1, 1)
            frame_std: torch.Tensor = self.std[:, 0]  # (3, 1, 1)
            for ax, frame, view_name in zip(axes, stack_of_first_frames, view_list):
                restored: torch.Tensor = frame.cpu() * frame_std + frame_mean
                display_image: np.ndarray = restored.permute([1, 2, 0]).numpy()
                display_image = np.ascontiguousarray(np.clip(display_image, 0, 255).astype("uint8"))
                display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
                cv2.putText(
                    display_image,
                    view_name.replace("_", " "),
                    (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.7,
                    (0, 220, 255),
                    2,
                )
                ax.imshow(display_image)
            for ax in axes:
                ax.axis("off")
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            plt.show()

        if return_view_list:
            return view_list
        return stack_of_view_encodings

    @torch.no_grad()
    def encode_study(self, stack_of_videos: torch.Tensor, visualize: bool = False) -> torch.Tensor:
        """
        Produce per-clip study encodings: the echo-encoder embedding of each clip
        concatenated with its one-hot predicted view.

        Args:
            stack_of_videos (torch.Tensor): Float32 tensor of shape
                ``(N, 3, T, H, W)`` as produced by ``process_dicoms`` or
                ``process_mp4s``.
            visualize (bool): Passed to ``get_views`` to render the annotated
                view grid. Defaults to ``False``.

        Returns:
            torch.Tensor: Tensor of shape ``(N, 523)``: columns ``[:512]`` are
            clip embeddings, columns ``[512:]`` the 11 one-hot view columns.
            Returns ``torch.empty(0)`` when ``stack_of_videos`` is empty.
        """
        if stack_of_videos.numel() == 0:
            print("[ERROR] Cannot encode empty video stack.")
            return torch.empty(0)

        stack_of_features: torch.Tensor = self.embed_videos(stack_of_videos)
        stack_of_view_encodings: torch.Tensor = self.get_views(stack_of_videos, visualize)
        return torch.cat((stack_of_features, stack_of_view_encodings), dim=1)

    def translate_sections(self, report: str) -> str:
        """
        Replace English section headings in ``report`` with their translation
        for ``self.lang`` (built-in: ``'it'``, ``'bs'``, ``'ru'``).

        Replacement is a plain substring substitution applied in table order, so
        every occurrence of a heading string is translated, including
        occurrences inside section body text (e.g. ``"IVC"``).

        Args:
            report (str): Report text as returned by ``generate_report``.

        Returns:
            str: The report with headings substituted, or ``report`` unchanged
            when no table exists for ``self.lang``.
        """
        translations: Dict[str, str] = self._SECTION_TRANSLATIONS.get(self.lang, {})
        for section, translated in translations.items():
            report = report.replace(section, translated)
        return report

    def _section_clip_weights(self, s_dx: int, view_encodings: torch.Tensor) -> torch.Tensor:
        """
        Look up the MIL weight of every clip for section ``s_dx``.

        Args:
            s_dx (int): Row index into ``section_weights``.
            view_encodings (torch.Tensor): ``(N, n_views)`` one-hot view rows
                (columns ``[512:]`` of an ``encode_study`` output), on CPU.

        Returns:
            torch.Tensor: Float32 tensor of shape ``(N,)``.
        """
        weights = [
            self.section_weights[s_dx][torch.where(view_encoding == 1)[0]]
            for view_encoding in view_encodings
        ]
        return torch.tensor(weights, dtype=torch.float)

    def generate_report(self, study_embedding: torch.Tensor, max_attempts: int = 100) -> str:
        """
        Generate a multi-section report by retrieving, for each cardiac section,
        that section's text from the most similar candidate report.

        For each section in ``non_empty_sections``:
          1. weight each clip embedding by the section's MIL weight for its view;
          2. average and L2-normalise the weighted embeddings;
          3. walk candidates in decreasing similarity (``@ candidate_embeddings.T``)
             until one yields text for the section, trying at most
             ``max_attempts`` candidates;
          4. append the found text (if any) to the report.

        If ``self.lang`` is not ``'en'`` the result is passed through
        ``translate_sections``.

        Args:
            study_embedding (torch.Tensor): ``(N, 523)`` tensor returned by
                ``encode_study``.
            max_attempts (int): Maximum candidates examined per section.
                Defaults to ``100``.

        Returns:
            str: The report text, or ``"No data available to generate report."``
            when ``study_embedding`` is empty.
        """
        if study_embedding.numel() == 0:
            return "No data available to generate report."

        print("[EchoPrime] Generating clinical report...")
        # Weights are numpy/CPU, so do the weighting on CPU.
        study_embedding = study_embedding.cpu()
        clip_embeddings: torch.Tensor = study_embedding[:, :512]
        view_encodings: torch.Tensor = study_embedding[:, 512:]

        report_parts: List[str] = []
        for s_dx, sec in enumerate(self.non_empty_sections):
            weights: torch.Tensor = self._section_clip_weights(s_dx, view_encodings)
            section_embedding: torch.Tensor = torch.mean(clip_embeddings * weights.unsqueeze(1), dim=0)
            section_embedding = F.normalize(section_embedding, dim=0).to(self.device)
            similarities: torch.Tensor = section_embedding @ self.candidate_embeddings.T

            extracted_section: str = self._SECTION_NOT_FOUND
            attempts: int = 0
            while extracted_section == self._SECTION_NOT_FOUND and attempts < max_attempts:
                max_id: int = torch.argmax(similarities).item()
                extracted_section = utils.extract_section(self.candidate_reports[max_id], sec)
                # Exclude this candidate from subsequent argmax calls.
                similarities[max_id] = float("-inf")
                attempts += 1
            if extracted_section != self._SECTION_NOT_FOUND:
                report_parts.append(extracted_section)

        generated_report: str = "".join(report_parts)
        if self.lang != "en":
            generated_report = self.translate_sections(generated_report)
        return generated_report

    def predict_metrics(self, study_embedding: torch.Tensor, k: int = 50) -> Dict[str, float]:
        """
        Predict phenotype values by averaging the labels of the *k* candidate
        studies most similar to each section-specific study embedding.

        For each section the clip embeddings are weighted by MIL weight, summed,
        L2-normalised and compared with ``candidate_embeddings``; the labels of
        the top-*k* candidates are then averaged with ``np.nanmean``.

        Args:
            study_embedding (torch.Tensor): ``(N, 523)`` tensor returned by
                ``encode_study``.
            k (int): Number of nearest candidates per section. Clamped to the
                number of candidates. Defaults to ``50``.

        Returns:
            Dict[str, float]: Phenotype name -> predicted value; ``np.nan`` when
            none of the neighbours has a label for it. ``{}`` when
            ``study_embedding`` is empty.
        """
        if study_embedding.numel() == 0:
            return {}

        print("[EchoPrime] Predicting metrics...")
        study_embedding = study_embedding.cpu()
        clip_embeddings: torch.Tensor = study_embedding[:, :512]
        view_encodings: torch.Tensor = study_embedding[:, 512:]

        per_section_study_embedding: torch.Tensor = torch.zeros(len(self.non_empty_sections), 512)
        for s_dx in range(len(self.non_empty_sections)):
            weights: torch.Tensor = self._section_clip_weights(s_dx, view_encodings)
            per_section_study_embedding[s_dx] = torch.sum(clip_embeddings * weights.unsqueeze(1), dim=0)
        per_section_study_embedding = F.normalize(per_section_study_embedding).to(self.device)

        similarities: torch.Tensor = per_section_study_embedding @ self.candidate_embeddings.T
        # torch.topk raises if k exceeds the number of candidates.
        k = min(k, similarities.shape[1])
        top_candidate_ids: List[List[int]] = torch.topk(similarities, k=k, dim=1).indices.cpu().tolist()

        preds: Dict[str, float] = {}
        # NOTE(review): rows of top_candidate_ids follow non_empty_sections order,
        # but are indexed here by position in section_to_phenotypes; this assumes
        # both list the same sections in the same order — confirm.
        for s_dx, section in enumerate(self.section_to_phenotypes.keys()):
            top_studies: List[str] = [self.candidate_studies[c_id] for c_id in top_candidate_ids[s_dx]]
            for pheno in self.section_to_phenotypes[section]:
                pheno_labels = self.candidate_labels[pheno]
                values: List[float] = [pheno_labels[study] for study in top_studies if study in pheno_labels]
                preds[pheno] = np.nanmean(values) if values else np.nan
        return preds


class EchoPrimeTextEncoder(torch.nn.Module):
    """
    BiomedBERT-based text encoder that projects echocardiography report text to
    512 dimensions.

    The backbone is built from the configuration of
    ``microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract`` (weights are
    initialised from the config, not downloaded; load a state dict separately).
    The final-layer hidden state at position 0 (``[CLS]``) is projected from 768
    to 512 dimensions. Inputs are padded/truncated to exactly 512 tokens, so
    only the first 512 tokens of long reports are encoded. Forward passes run
    under ``torch.no_grad()``.

    Attributes:
        device (str): Compute device identifier (e.g. ``"cuda"``).
        backbone (transformers.BertForMaskedLM): BiomedBERT backbone model.
        text_projection (torch.nn.Linear): 768 -> 512 projection.
        tokenizer (transformers.BertTokenizer): Matching BiomedBERT tokenizer.
    """

    def __init__(self, device: str = "cuda") -> None:
        """
        Build the backbone from the BiomedBERT config, load its tokenizer from
        the Hugging Face Hub, and move the module to ``device``.

        Args:
            device (str): Compute device passed to ``self.to()``. Defaults to
                ``"cuda"``.
        """
        super().__init__()
        self.device: str = device
        config: transformers.PretrainedConfig = transformers.AutoConfig.from_pretrained(
            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
        )
        self.backbone: transformers.BertForMaskedLM = transformers.AutoModelForMaskedLM.from_config(config)
        self.text_projection: torch.nn.Linear = torch.nn.Linear(768, 512)
        self.tokenizer: transformers.BertTokenizer = transformers.AutoTokenizer.from_pretrained(
            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
        )
        self.tokenizer.max_length = 512
        self.to(device)

    def forward(self, report: str) -> torch.Tensor:
        """
        Encode a report string into a 512-dimensional embedding.

        The report is tokenised with ``padding="max_length"``,
        ``max_length=512`` and ``truncation=True``, so the token sequence is
        always exactly 512 long: shorter reports are padded and longer ones are
        cut to their first 512 tokens.

        Args:
            report (str): Raw clinical echocardiography report text.

        Returns:
            torch.Tensor: Tensor of shape ``(1, 512)`` (not L2-normalised) on
            ``self.device``.
        """
        text: transformers.BatchEncoding = self.tokenizer(
            report,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            text = text.to(self.device)
            hidden_states = self.backbone(**text, output_hidden_states=True).hidden_states
            text_emb: torch.Tensor = self.text_projection(hidden_states[-1][:, 0, :])
        return text_emb