# echo-prime-demo / model.py
# (page residue) amn23's picture
# (commit) Update model.py
# (revision) a9c2e29 verified
# Standard library imports
import os
import math
import glob
import json
import pickle
import random
import sys
from typing import AnyStr, List, Any, Dict, Optional, Union
# Third-party library imports
import torch
import torchvision
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import pydicom
import sklearn
import sklearn.metrics
import transformers
# Local module imports
import utils
class EchoPrime:
    """
    Echocardiography model wrapper: video encoding, view classification,
    retrieval-based report generation and kNN metric prediction.

    Studies (DICOM or MP4) are turned into fixed-size clips, embedded by a
    frozen video encoder, tagged with a predicted view, and compared against
    a database of candidate studies using per-section, per-view weights
    loaded from ``assets/MIL_weights.csv``.

    Attributes:
        base_dir (str): Project root; ``ECHOPRIME_ROOT_OVERRIDE`` if set,
            otherwise the directory containing this file.
        echo_encoder (torchvision.models.video.MViT): Frozen MViT-v2-S whose
            final head layer outputs 512 features per clip.
        view_classifier (torchvision.models.ConvNeXt): Frozen ConvNeXt-Base
            whose final layer outputs 11 view logits; applied to the first
            frame of each clip.
        frames_to_take (int): Number of leading frames considered per clip (32).
        frame_stride (int): Temporal stride used when sampling frames (2).
        video_size (int): Spatial side length of processed frames (224).
        mean (torch.Tensor): Per-channel mean, shape ``(3, 1, 1, 1)``.
        std (torch.Tensor): Per-channel std, shape ``(3, 1, 1, 1)``.
        device (torch.device): Device the models and candidates live on.
        lang (str): Language code used for report heading translation.
        MIL_weights (pd.DataFrame): Table read from ``MIL_weights.csv``.
        non_empty_sections (pd.Series): The ``"Section"`` column of
            ``MIL_weights``.
        section_weights (np.ndarray): All columns of ``MIL_weights`` after
            the first, one row per section and one column per view.
        candidate_studies (List[str]): ``"Study"`` column of
            ``candidate_studies.csv``.
        candidate_embeddings (torch.Tensor): Concatenation (dim 0) of the two
            candidate embedding files, on ``device``.
        candidate_reports (List[str]): Candidate reports decoded with
            ``utils.phrase_decode``.
        candidate_labels: Unpickled candidate labels, indexed as
            ``candidate_labels[phenotype][study]``.
        section_to_phenotypes (Dict[str, List[str]]): Unpickled mapping from
            section name to phenotype names.
    """

    # Section-heading translations keyed by language code.  Entries are
    # applied in insertion order by ``translate_sections``.  To support a new
    # language add ``"<code>": {"<English heading>": "<translation>", ...}``.
    _SECTION_TRANSLATIONS: Dict[str, Dict[str, str]] = {
        "it": {
            "Left Ventricle": "Ventricolo Sinistro",
            "Resting Segmental Wall Motion Analysis": "Cinetica Segmentaria a Riposo",
            "Right Ventricle": "Ventricolo Destro",
            "Left Atrium": "Atrio Sinistro",
            "Right Atrium": "Atrio Destro",
            "Atrial Septum": "Setto Inter-Atriale",
            "Mitral Valve": "Valvola Mitrale",
            "Aortic Valve": "Valvola Aortica",
            "Tricuspid Valve": "Valvola Tricuspide",
            "Pulmonic Valve": "Valvola Polmonare",
            "Pericardium": "Pericardio",
            "Aorta": "Aorta",
            "IVC": "Vena Cava Inferiore",
            "Pulmonary Artery": "Arteria Polmonare",
            "Pulmonary Veins": "Vene Polmonari",
            "Postoperative Findings": "Esiti Post-Operatori",
        },
        "bs": {
            "Left Ventricle": "Lijeva komora",
            "Resting Segmental Wall Motion Analysis": "Analiza segmentalne pokretljivosti stijenke u mirovanju",
            "Right Ventricle": "Desna komora",
            "Left Atrium": "Lijeva pretkomora",
            "Right Atrium": "Desna pretkomora",
            "Atrial Septum": "Interatrijski septum",
            "Mitral Valve": "Mitralni zalisak",
            "Aortic Valve": "Aortni zalisak",
            "Tricuspid Valve": "Trikuspidalni zalisak",
            "Pulmonic Valve": "Pulmonalni zalisak",
            "Pericardium": "Perikard",
            "Aorta": "Aorta",
            "IVC": "Donja šuplja vena",
            "Pulmonary Artery": "Plućna arterija",
            "Pulmonary Veins": "Plućne vene",
            "Postoperative Findings": "Postoperativni nalazi",
        },
        "ru": {
            "Left Ventricle": "Левый желудочек",
            "Resting Segmental Wall Motion Analysis": "Анализ сегментарной сократимости в покое",
            "Right Ventricle": "Правый желудочек",
            "Left Atrium": "Левое предсердие",
            "Right Atrium": "Правое предсердие",
            "Atrial Septum": "Межпредсердная перегородка",
            "Mitral Valve": "Митральный клапан",
            "Aortic Valve": "Аортальный клапан",
            "Tricuspid Valve": "Трёхстворчатый клапан",
            "Pulmonic Valve": "Клапан лёгочной артерии",
            "Pericardium": "Перикард",
            "Aorta": "Аорта",
            "IVC": "Нижняя полая вена",
            "Pulmonary Artery": "Лёгочная артерия",
            "Pulmonary Veins": "Лёгочные вены",
            "Postoperative Findings": "Послеоперационные изменения",
        },
    }

    def __init__(self, device: Optional[torch.device] = None, lang: str = "en") -> None:
        """
        Load model weights, normalisation statistics, MIL weights and the
        candidate study database.

        Args:
            device (Optional[torch.device]): Device to run on. When ``None``
                (default) CUDA is used if available, otherwise CPU. Anything
                accepted by ``torch.device`` (e.g. ``"cpu"``) also works.
            lang (str): Language code forwarded to
                ``utils.initialize_language`` and used by
                ``translate_sections`` (``'en'``, ``'it'``, ``'bs'``, ``'ru'``).

        Raises:
            FileNotFoundError: If either weights file is missing under
                ``base_dir``.
        """
        self.base_dir: str = os.getenv("ECHOPRIME_ROOT_OVERRIDE") or os.path.dirname(
            os.path.abspath(__file__)
        )

        def get_path(rel_path: str) -> str:
            """Return ``rel_path`` joined onto the project root."""
            return os.path.join(self.base_dir, rel_path)

        def require_weights(rel_path: str) -> str:
            """Resolve a weights path and fail early, with the path, if absent."""
            path = get_path(rel_path)
            if not os.path.exists(path):
                print(f"[ERROR] Expected weights at: {path}")
                raise FileNotFoundError(f"Could not find model weights at: {path}")
            return path

        print(f"[EchoPrime] Initializing... (Root dir: {self.base_dir})")
        # Load language-specific files.
        utils.initialize_language(lang)

        # Honour an explicitly requested device; auto-detect only when omitted.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(device)
        print(f"[EchoPrime] Using device: {device}")

        # --- Video encoder: MViT-v2-S with a 512-dimensional output head ---
        checkpoint: Dict[str, torch.Tensor] = torch.load(
            require_weights("model_data/weights/echo_prime_encoder.pt"), map_location=device
        )
        echo_encoder = torchvision.models.video.mvit_v2_s()
        echo_encoder.head[-1] = torch.nn.Linear(echo_encoder.head[-1].in_features, 512)
        echo_encoder.load_state_dict(checkpoint)
        echo_encoder.eval()
        echo_encoder.to(device)
        echo_encoder.requires_grad_(False)

        # --- View classifier: ConvNeXt-Base with an 11-way output layer ---
        vc_state_dict: Dict[str, torch.Tensor] = torch.load(
            require_weights("model_data/weights/view_classifier.pt"), map_location=device
        )
        view_classifier = torchvision.models.convnext_base()
        view_classifier.classifier[-1] = torch.nn.Linear(
            view_classifier.classifier[-1].in_features, 11
        )
        view_classifier.load_state_dict(vc_state_dict)
        view_classifier.to(device)
        view_classifier.eval()
        view_classifier.requires_grad_(False)

        self.echo_encoder = echo_encoder
        self.view_classifier = view_classifier
        self.frames_to_take: int = 32
        self.frame_stride: int = 2
        self.video_size: int = 224
        # Reshaped to (3, 1, 1, 1) so they broadcast over (C, T, H, W) clips.
        self.mean: torch.Tensor = torch.tensor([29.110628, 28.076836, 29.096405]).reshape(3, 1, 1, 1)
        self.std: torch.Tensor = torch.tensor([47.989223, 46.456997, 47.20083]).reshape(3, 1, 1, 1)
        self.device: torch.device = device
        self.lang: str = lang

        # --- Assets ---
        print("[EchoPrime] Loading assets...")
        self.MIL_weights: pd.DataFrame = pd.read_csv(get_path("assets/MIL_weights.csv"))
        self.non_empty_sections: pd.Series = self.MIL_weights["Section"]
        self.section_weights: np.ndarray = self.MIL_weights.iloc[:, 1:].to_numpy()

        self.candidate_studies: List[str] = list(
            pd.read_csv(get_path("model_data/candidates_data/candidate_studies.csv"))["Study"]
        )
        # The candidate embeddings are shipped as two files and joined row-wise.
        candidate_embeddings_p1: torch.Tensor = torch.load(
            get_path("model_data/candidates_data/candidate_embeddings_p1.pt"), map_location=device
        )
        candidate_embeddings_p2: torch.Tensor = torch.load(
            get_path("model_data/candidates_data/candidate_embeddings_p2.pt"), map_location=device
        )
        self.candidate_embeddings: torch.Tensor = torch.cat(
            (candidate_embeddings_p1, candidate_embeddings_p2), dim=0
        )

        # NOTE: pd.read_pickle unpickles arbitrary objects; only load these
        # asset files from a trusted source.
        candidate_reports = pd.read_pickle(
            get_path("model_data/candidates_data/candidate_reports.pkl")
        )
        self.candidate_reports: List[str] = [
            utils.phrase_decode(vec_phr) for vec_phr in candidate_reports
        ]
        self.candidate_labels = pd.read_pickle(
            get_path("model_data/candidates_data/candidate_labels.pkl")
        )
        self.section_to_phenotypes: Dict[str, List[str]] = pd.read_pickle(
            get_path("assets/section_to_phenotypes.pkl")
        )
        print("[EchoPrime] Initialization Complete.")

    def _frames_to_clip(self, pixels: np.ndarray) -> torch.Tensor:
        """
        Convert decoded video frames into one normalised, fixed-length clip.

        Each frame goes through ``utils.crop_and_scale``; the result is
        normalised with ``mean``/``std``, zero-padded (after normalisation) up
        to ``frames_to_take`` frames if shorter, and every ``frame_stride``-th
        of the first ``frames_to_take`` frames is kept.

        Args:
            pixels (np.ndarray): Frames indexed along the first axis.

        Returns:
            torch.Tensor: Float32 clip of shape
            ``(3, frames_to_take / frame_stride, video_size, video_size)``.
        """
        # Assumes utils.crop_and_scale returns video_size x video_size x 3
        # frames (224 by default) -- TODO confirm against utils.
        frames = np.zeros((len(pixels), self.video_size, self.video_size, 3))
        for i, frame in enumerate(pixels):
            frames[i] = utils.crop_and_scale(frame)

        # (T, H, W, C) -> (C, T, H, W)
        clip = torch.as_tensor(frames, dtype=torch.float).permute([3, 0, 1, 2])
        clip.sub_(self.mean).div_(self.std)

        missing = self.frames_to_take - clip.shape[1]
        if missing > 0:
            padding = torch.zeros(
                (3, missing, self.video_size, self.video_size), dtype=torch.float
            )
            clip = torch.cat((clip, padding), dim=1)
        return clip[:, : self.frames_to_take : self.frame_stride, :, :]

    def process_dicoms(self, INPUT: str) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Recursively load ``.dcm`` videos under ``INPUT`` and return them as a
        stack of pre-processed clips ready for ``embed_videos``.

        Static 2D images (``ndim < 3``) and static RGB screenshots (shape
        ``(H, W, 3)``) are skipped with a message. Single-channel videos
        ``(T, H, W)`` are replicated to three channels. Files that raise
        during decoding or processing are reported and skipped.

        Args:
            INPUT (str): Directory searched recursively for ``.dcm`` files.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: Float32 tensor of shape
            ``(N, 3, 16, 224, 224)`` with the default settings, or
            ``torch.empty(0)`` when no file is found or none can be
            processed.
        """
        print(f"[EchoPrime] Scanning for DICOMs in: {INPUT}")
        dicom_paths: List[str] = glob.glob(f"{INPUT}/**/*.dcm", recursive=True)
        if not dicom_paths:
            print(f"[ERROR] No .dcm files found in {INPUT}")
            return torch.empty(0)

        stack_of_videos: List[torch.Tensor] = []
        skipped_count = 0
        print(f"Found {len(dicom_paths)} DICOM files. Processing...")
        for dicom_path in tqdm(dicom_paths, total=len(dicom_paths), desc="Processing"):
            try:
                dcm = pydicom.dcmread(dicom_path)
                pixels: np.ndarray = dcm.pixel_array

                # (H, W): a single grayscale image, no time axis.
                if pixels.ndim < 3:
                    fname = os.path.basename(dicom_path)
                    print(f" > Skipped {fname}: Static 2D Image (Shape: {pixels.shape})")
                    skipped_count += 1
                    continue
                # (H, W, 3): third axis is colour, not time.
                if pixels.ndim == 3 and pixels.shape[2] == 3:
                    fname = os.path.basename(dicom_path)
                    print(f" > Skipped {fname}: Static RGB Screenshot (Shape: {pixels.shape})")
                    skipped_count += 1
                    continue

                # (T, H, W) grayscale video -> (T, H, W, 3).
                if pixels.ndim == 3:
                    pixels = np.repeat(pixels[..., None], 3, axis=3)

                # Mask the converted array. Re-reading dcm.pixel_array here
                # would silently discard the grayscale -> RGB conversion.
                pixels = utils.mask_outside_ultrasound(pixels)
                stack_of_videos.append(self._frames_to_clip(pixels))
            except Exception as e:
                # Best effort: one unreadable file must not abort the study.
                print(f"Corrupt file {dicom_path}: {e}")

        if not stack_of_videos:
            print("[ERROR] Found DICOMs but failed to process ANY of them.")
            return torch.empty(0)

        stacked = torch.stack(stack_of_videos)
        print(f"\n[Summary] Total: {len(dicom_paths)} | Processed: {len(stacked)} | Skipped: {skipped_count}")
        return stacked

    def process_mp4s(self, INPUT: str) -> torch.Tensor:
        """
        Recursively load ``.mp4`` videos under ``INPUT`` and return them as a
        stack of pre-processed clips ready for ``embed_videos``.

        Args:
            INPUT (str): Directory searched recursively for ``.mp4`` files.

        Returns:
            torch.Tensor: Float32 tensor of shape ``(N, 3, 16, 224, 224)``
            with the default settings. Files that fail to decode (or decode
            to zero frames) are reported and skipped. Returns
            ``torch.empty(0)`` when nothing could be processed, mirroring
            ``process_dicoms``.
        """
        mp4_paths: List[str] = glob.glob(f"{INPUT}/**/*.mp4", recursive=True)
        stack_of_videos: List[torch.Tensor] = []
        for mp4_path in mp4_paths:
            try:
                frames, _, _ = torchvision.io.read_video(mp4_path)
                pixels: np.ndarray = np.array(frames)
                if len(pixels) == 0:
                    # Without this a frame-less file would become an
                    # all-padding clip.
                    raise ValueError(f"no video frames decoded from {mp4_path}")
                stack_of_videos.append(self._frames_to_clip(pixels))
            except Exception as e:
                print("corrupt file")
                print(str(e))

        if not stack_of_videos:
            # torch.stack([]) raises; return the same sentinel as process_dicoms.
            print(f"[ERROR] No .mp4 files could be processed in {INPUT}")
            return torch.empty(0)
        return torch.stack(stack_of_videos)

    def embed_videos(self, stack_of_videos: torch.Tensor) -> torch.Tensor:
        """
        Run clips through the echo encoder in batches of 50 with gradients
        disabled.

        Args:
            stack_of_videos (torch.Tensor): Clips of shape ``(N, 3, T, H, W)``
                as produced by ``process_dicoms`` / ``process_mp4s``.

        Returns:
            torch.Tensor: Encoder outputs concatenated along dim 0 (one row
            per clip), or ``torch.empty(0)`` for an empty input.
        """
        if stack_of_videos.numel() == 0:
            return torch.empty(0)

        bin_size = 50  # bounds peak memory on large studies
        features: List[torch.Tensor] = []
        with torch.no_grad():
            for start_idx in range(0, stack_of_videos.shape[0], bin_size):
                batch = stack_of_videos[start_idx : start_idx + bin_size].to(self.device)
                features.append(self.echo_encoder(batch))
        return torch.cat(features, dim=0)

    def _show_view_grid(self, first_frames: torch.Tensor, view_list: List[str]) -> None:
        """Plot each clip's first frame, annotated with its predicted view."""
        cols = 12
        rows = (len(view_list) + cols - 1) // cols  # ceil division
        print(f"[EchoPrime] Visualizing {len(view_list)} views in grid {rows}x{cols}")
        _fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
        axes = axes.flatten()

        for ax, frame, view in zip(axes, first_frames, view_list):
            image = (frame.cpu().permute([1, 2, 0]) * 255).numpy()
            image = np.clip(image, 0, 255).astype("uint8")
            # cv2 drawing functions need a C-contiguous buffer.
            image = np.ascontiguousarray(image)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.putText(
                image,
                view.replace("_", " "),
                (10, 25),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, 220, 255),
                2,
            )
            ax.imshow(image)
            ax.axis("off")

        # Hide the unused cells of the last row.
        for ax in axes[len(view_list):]:
            ax.axis("off")
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        plt.show()

    def get_views(
        self,
        stack_of_videos: torch.Tensor,
        visualize: bool = False,
        return_view_list: bool = False,
    ) -> Union[torch.Tensor, List[str]]:
        """
        Predict the view of each clip from its first frame.

        Args:
            stack_of_videos (torch.Tensor): Clips of shape ``(N, 3, T, H, W)``.
            visualize (bool): When ``True``, show a matplotlib grid of first
                frames annotated with the predicted view.
            return_view_list (bool): When ``True``, return view names from
                ``utils.COARSE_VIEWS`` instead of one-hot encodings.

        Returns:
            Union[torch.Tensor, List[str]]: An ``(N, 11)`` integer one-hot
            tensor on ``self.device`` (default), or a list of *N* view names
            when ``return_view_list`` is ``True``. An empty input always
            yields ``torch.empty(0)``.
        """
        if stack_of_videos.numel() == 0:
            return torch.empty(0)

        first_frames = stack_of_videos[:, :, 0, :, :].to(self.device)
        with torch.no_grad():
            out_logits = self.view_classifier(first_frames)
        out_views = torch.argmax(out_logits, dim=1)
        view_list: List[str] = [utils.COARSE_VIEWS[v] for v in out_views.tolist()]
        stack_of_view_encodings = torch.nn.functional.one_hot(out_views, 11).to(self.device)

        if visualize:
            self._show_view_grid(first_frames, view_list)

        if return_view_list:
            return view_list
        return stack_of_view_encodings

    @torch.no_grad()
    def encode_study(
        self, stack_of_videos: torch.Tensor, visualize: bool = False
    ) -> torch.Tensor:
        """
        Encode a study as per-clip embeddings concatenated with per-clip
        one-hot view encodings.

        Args:
            stack_of_videos (torch.Tensor): Clips of shape ``(N, 3, T, H, W)``.
            visualize (bool): Forwarded to ``get_views``.

        Returns:
            torch.Tensor: Tensor of shape ``(N, 523)``: 512 embedding columns
            followed by 11 view columns. Returns ``torch.empty(0)`` for an
            empty input.
        """
        if stack_of_videos.numel() == 0:
            print("[ERROR] Cannot encode empty video stack.")
            return torch.empty(0)

        stack_of_features = self.embed_videos(stack_of_videos)
        stack_of_view_encodings = self.get_views(stack_of_videos, visualize)
        return torch.cat((stack_of_features, stack_of_view_encodings), dim=1)

    def translate_sections(self, report: str) -> str:
        """
        Replace English anatomical section names in ``report`` with their
        translation for ``self.lang``.

        The replacement is a plain substring substitution applied to the
        whole text, so a section name that also appears in the body of a
        section is translated there too.

        Args:
            report (str): Report text as built by ``generate_report``.

        Returns:
            str: The translated report, or ``report`` unchanged when
            ``self.lang`` has no entry in ``_SECTION_TRANSLATIONS``
            (supported: ``'it'``, ``'bs'``, ``'ru'``).
        """
        for heading, translated in self._SECTION_TRANSLATIONS.get(self.lang, {}).items():
            report = report.replace(heading, translated)
        return report

    def generate_report(self, study_embedding: torch.Tensor) -> str:
        """
        Build a report by retrieving, for every section, the text of that
        section from the most similar candidate report.

        For each section in ``non_empty_sections``:

        1. Each clip embedding is scaled by the section's weight for the
           clip's view, the scaled embeddings are averaged and L2-normalised.
        2. Similarities to ``candidate_embeddings`` are computed by dot
           product.
        3. Candidates are tried in decreasing similarity (at most 100) until
           ``utils.extract_section`` finds the section; its text is appended.
           If none is found the section is omitted.

        When ``self.lang`` is not ``'en'`` the result is passed through
        ``translate_sections``.

        Args:
            study_embedding (torch.Tensor): ``(N, 523)`` tensor from
                ``encode_study`` (512 embedding columns + 11 one-hot view
                columns).

        Returns:
            str: The assembled report, or
            ``"No data available to generate report."`` for an empty input.
        """
        if study_embedding.numel() == 0:
            return "No data available to generate report."
        print("[EchoPrime] Generating clinical report...")

        # The section weights are NumPy/CPU, so weighting happens on CPU.
        study_embedding = study_embedding.cpu()
        clip_embeddings = study_embedding[:, :512]
        # Index of the active view per clip. The view columns are one-hot, so
        # argmax recovers it; it is the same for every section.
        view_ids = study_embedding[:, 512:].argmax(dim=1).numpy()

        report_parts: List[str] = []
        for s_dx, sec in enumerate(self.non_empty_sections):
            clip_weights = torch.as_tensor(
                self.section_weights[s_dx][view_ids], dtype=torch.float
            )
            section_embedding = torch.mean(clip_embeddings * clip_weights.unsqueeze(1), dim=0)
            section_embedding = torch.nn.functional.normalize(section_embedding, dim=0)
            # candidate_embeddings live on self.device.
            section_embedding = section_embedding.to(self.device)
            similarities = section_embedding @ self.candidate_embeddings.T

            for _ in range(100):
                max_id = torch.argmax(similarities).item()
                extracted_section = utils.extract_section(self.candidate_reports[max_id], sec)
                if extracted_section != "Section not found.":
                    report_parts.append(extracted_section)
                    break
                # Exclude this candidate from the next argmax.
                similarities[max_id] = float("-inf")

        generated_report = "".join(report_parts)
        if self.lang != "en":
            generated_report = self.translate_sections(generated_report)
        return generated_report

    def predict_metrics(self, study_embedding: torch.Tensor, k: int = 50) -> Dict[str, float]:
        """
        Predict phenotype values by averaging the labels of the ``k`` most
        similar candidate studies, separately per section.

        For each section the clip embeddings are scaled by the section's
        weight for each clip's view and summed; the per-section vectors are
        L2-normalised and compared to ``candidate_embeddings`` by dot
        product.

        Args:
            study_embedding (torch.Tensor): ``(N, 523)`` tensor from
                ``encode_study``.
            k (int): Number of nearest candidates to average over. Values
                larger than the number of candidates are clamped.
                Defaults to ``50``.

        Returns:
            Dict[str, float]: Phenotype name -> ``np.nanmean`` of the
            neighbours' labels, ``np.nan`` when no neighbour has a label.
            Empty dict for an empty input.
        """
        if study_embedding.numel() == 0:
            return {}
        print("[EchoPrime] Predicting metrics...")

        # The section weights are NumPy/CPU, so weighting happens on CPU.
        study_embedding = study_embedding.cpu()
        clip_embeddings = study_embedding[:, :512]
        view_ids = study_embedding[:, 512:].argmax(dim=1).numpy()

        n_sections = len(self.non_empty_sections)
        per_section_study_embedding = torch.zeros(n_sections, 512)
        for s_dx in range(n_sections):
            clip_weights = torch.as_tensor(
                self.section_weights[s_dx][view_ids], dtype=torch.float
            )
            per_section_study_embedding[s_dx] = torch.sum(
                clip_embeddings * clip_weights.unsqueeze(1), dim=0
            )
        per_section_study_embedding = torch.nn.functional.normalize(per_section_study_embedding)
        per_section_study_embedding = per_section_study_embedding.to(self.device)

        similarities = per_section_study_embedding @ self.candidate_embeddings.T
        # torch.topk raises when k exceeds the size of the searched dimension.
        k = min(k, similarities.shape[1])
        top_candidate_ids: List[List[int]] = (
            torch.topk(similarities, k=k, dim=1).indices.cpu().tolist()
        )

        preds: Dict[str, float] = {}
        # NOTE(review): sections are matched to similarity rows by position,
        # i.e. this assumes section_to_phenotypes iterates in the same order
        # as non_empty_sections -- confirm against the shipped assets.
        for s_dx, section in enumerate(self.section_to_phenotypes):
            phenotypes = self.section_to_phenotypes[section]
            if not phenotypes:
                continue
            neighbours = [self.candidate_studies[c_id] for c_id in top_candidate_ids[s_dx]]
            for pheno in phenotypes:
                labels = self.candidate_labels[pheno]
                values = [labels[study] for study in neighbours if study in labels]
                preds[pheno] = np.nanmean(values) if values else np.nan
        return preds
class EchoPrimeTextEncoder(torch.nn.Module):
    """
    BiomedBERT-based text encoder that maps a report string to a
    512-dimensional embedding.

    The backbone is built from the configuration of
    ``microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract`` (architecture
    only: ``from_config`` does not load pretrained weights, so a state dict
    is expected to be loaded by the caller). The final-layer hidden state at
    token position 0 is projected from 768 to 512 dimensions.

    Inputs are padded and truncated to exactly 512 tokens, so reports longer
    than that are encoded from their first 512 tokens only.

    Attributes:
        device (str): Device the module was moved to.
        backbone (transformers.BertForMaskedLM): Masked-LM backbone.
        text_projection (torch.nn.Linear): 768 -> 512 projection.
        tokenizer (transformers.BertTokenizer): Tokenizer for the backbone.
    """

    def __init__(self, device: str = "cuda") -> None:
        """
        Build the backbone, projection layer and tokenizer.

        Args:
            device (str): Device identifier passed to ``self.to()``.
                Defaults to ``"cuda"``; if a CUDA device is requested but
                CUDA is unavailable, the encoder falls back to ``"cpu"``
                instead of failing.
        """
        super().__init__()
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            print("[EchoPrimeTextEncoder] CUDA is not available; falling back to CPU.")
            device = "cpu"
        self.device: str = device

        model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
        config = transformers.AutoConfig.from_pretrained(model_name)
        self.backbone = transformers.AutoModelForMaskedLM.from_config(config)
        self.text_projection: torch.nn.Linear = torch.nn.Linear(768, 512)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.max_length = 512
        self.to(device)

    def forward(self, report: str) -> torch.Tensor:
        """
        Encode a report into a 512-dimensional embedding without tracking
        gradients.

        The report is tokenised with ``padding="max_length"``,
        ``truncation=True`` and ``max_length=512``, so the encoded sequence
        always has exactly 512 tokens. (An earlier sentence-aligned window
        sampling step for inputs longer than 512 tokens could never execute
        under these tokenizer settings and has been removed; the output is
        unchanged.)

        Args:
            report (str): Report text.

        Returns:
            torch.Tensor: Tensor of shape ``(1, 512)``; not L2-normalised.
        """
        encoded = self.tokenizer(
            report,
            padding="max_length",  # always pad up to max_length
            max_length=512,
            truncation=True,  # drop tokens beyond max_length
            return_tensors="pt",
        )
        with torch.no_grad():
            encoded = encoded.to(self.device)
            last_hidden = self.backbone(**encoded, output_hidden_states=True).hidden_states[-1]
            # Position 0 is the first token of the sequence ([CLS] for BERT tokenizers).
            text_emb: torch.Tensor = self.text_projection(last_hidden[:, 0, :])
        return text_emb