File size: 38,106 Bytes
a9c2e29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b9d1e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
# Standard library imports
import os
import math
import glob
import json
import pickle
import random
import sys
from typing import AnyStr, List, Any, Dict, Optional, Union

# Third-party library imports
import torch
import torchvision
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import pydicom
import sklearn
import sklearn.metrics
import transformers

# Local module imports
import utils


class EchoPrime:
    """
    EchoPrime is an echocardiography AI model that encodes cardiac ultrasound
    studies (DICOM or MP4) into embeddings, classifies echocardiographic views,
    generates structured clinical reports, and predicts quantitative cardiac
    metrics via multi-instance learning (MIL) over a candidate study database.

    Attributes:
        base_dir (str): Absolute path to the EchoPrime project root directory.
        echo_encoder (torchvision.models.video.MViT): Frozen MViT-v2-S video
            encoder producing 512-dimensional embeddings per video clip.
        view_classifier (torchvision.models.ConvNeXt): Frozen ConvNeXt-Base
            image classifier predicting one of 11 echocardiographic views from
            the first frame of each clip.
        frames_to_take (int): Number of frames sampled from each video clip (32).
        frame_stride (int): Temporal stride applied when sampling frames (2).
        video_size (int): Spatial resolution (height and width) videos are
            resized to before encoding (224 pixels).
        mean (torch.Tensor): Per-channel pixel mean used for normalisation,
            shape (3, 1, 1, 1).
        std (torch.Tensor): Per-channel pixel standard deviation used for
            normalisation, shape (3, 1, 1, 1).
        device (torch.device): Compute device (CUDA if available, else CPU).
        lang (str): ISO 639-1 language code controlling report output language.
        MIL_weights (pd.DataFrame): CSV-loaded table of per-section MIL
            attention weights, shape (n_sections, n_views + 1).
        non_empty_sections (pd.Series): Ordered sequence of cardiac section
            names derived from the first column of ``MIL_weights``.
        section_weights (np.ndarray): Numeric weight matrix extracted from
            ``MIL_weights``, shape (n_sections, n_views).
        candidate_studies (List[str]): Ordered list of candidate study
            identifiers used for nearest-neighbour retrieval.
        candidate_embeddings (torch.Tensor): Concatenated embeddings for all
            candidate studies, shape (N_candidates, 512), on ``device``.
        candidate_reports (List[str]): Decoded text reports for each candidate
            study, aligned index-wise with ``candidate_studies``.
        candidate_labels (pd.DataFrame): Ground-truth phenotype labels for each
            candidate study, indexed by study identifier.
        section_to_phenotypes (Dict[str, List[str]]): Mapping from cardiac
            section name to the list of phenotype labels predicted for that
            section.
    """

    def __init__(self, device: Optional[torch.device] = None, lang: str = "en") -> None:
        """
        Initialise EchoPrime by loading model weights, normalisation statistics,
        MIL attention weights, and candidate study data.

        Args:
            device (Optional[torch.device]): Compute device to use. When
                ``None`` (default), CUDA is used if available, otherwise CPU.
                Any value accepted by ``torch.device`` (for example the
                string ``"cpu"``) is normalised to a ``torch.device``.
            lang (str): Language code for report generation. It is forwarded
                to ``utils.initialize_language`` and stored on ``self.lang``;
                ``translate_sections`` has mappings for ``'it'``, ``'bs'``
                and ``'ru'``. Defaults to ``'en'``.

        Raises:
            FileNotFoundError: If the echo encoder weights file cannot be
                located at the expected path relative to ``base_dir``.
        """
        # The project root can be redirected through an environment variable;
        # otherwise it is the directory containing this file.
        self.base_dir: str = os.getenv("ECHOPRIME_ROOT_OVERRIDE") or os.path.dirname(os.path.abspath(__file__))

        def get_path(rel_path: str) -> str:
            """
            Resolve a path relative to the EchoPrime project root.

            Args:
                rel_path (str): Relative path from the project root.

            Returns:
                str: ``base_dir`` joined with ``rel_path``.
            """
            return os.path.join(self.base_dir, rel_path)

        print(f"[EchoPrime] Initializing... (Root dir: {self.base_dir})")

        # load language specific files
        utils.initialize_language(lang)

        # Respect an explicitly requested device; auto-detect only when the
        # caller did not pass one. (Previously the argument was overwritten.)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(device)
        print(f"[EchoPrime] Using device: {device}")

        # LOAD MODEL WEIGHTS
        # NOTE: torch.load / pd.read_pickle below unpickle files from
        # ``base_dir``; only point ECHOPRIME_ROOT_OVERRIDE at trusted data.
        weights_path: str = get_path("model_data/weights/echo_prime_encoder.pt")
        if not os.path.exists(weights_path):
            # Print the exact path that was tried so the user can debug it.
            print(f"[ERROR] Expected weights at: {weights_path}")
            raise FileNotFoundError("Could not find model weights. Check the path above.")

        checkpoint: Dict[str, torch.Tensor] = torch.load(weights_path, map_location=device)
        echo_encoder: torchvision.models.video.MViT = torchvision.models.video.mvit_v2_s()
        # Replace the final head layer so the encoder emits 512 features.
        echo_encoder.head[-1] = torch.nn.Linear(echo_encoder.head[-1].in_features, 512)
        echo_encoder.load_state_dict(checkpoint)
        echo_encoder.eval()
        echo_encoder.to(device)
        for param in echo_encoder.parameters():
            param.requires_grad = False

        vc_state_dict: Dict[str, torch.Tensor] = torch.load(
            get_path("model_data/weights/view_classifier.pt"), map_location=device
        )
        view_classifier: torchvision.models.ConvNeXt = torchvision.models.convnext_base()
        # Replace the final classifier layer so it predicts 11 classes.
        view_classifier.classifier[-1] = torch.nn.Linear(
            view_classifier.classifier[-1].in_features, 11
        )
        view_classifier.load_state_dict(vc_state_dict)

        view_classifier.to(device)
        view_classifier.eval()
        for param in view_classifier.parameters():
            param.requires_grad = False

        self.echo_encoder: torchvision.models.video.MViT = echo_encoder
        self.view_classifier: torchvision.models.ConvNeXt = view_classifier
        self.frames_to_take: int = 32
        self.frame_stride: int = 2
        self.video_size: int = 224
        # Per-channel normalisation statistics, shaped (3, 1, 1, 1) so they
        # broadcast over (channel, frame, height, width) clips.
        self.mean: torch.Tensor = torch.tensor([29.110628, 28.076836, 29.096405]).reshape(3, 1, 1, 1)
        self.std: torch.Tensor = torch.tensor([47.989223, 46.456997, 47.20083]).reshape(3, 1, 1, 1)
        self.device: torch.device = device
        self.lang: str = lang

        # LOAD ASSETS
        print("[EchoPrime] Loading assets...")
        self.MIL_weights: pd.DataFrame = pd.read_csv(get_path("assets/MIL_weights.csv"))
        self.non_empty_sections: pd.Series = self.MIL_weights["Section"]
        # Every column after the first is taken as the numeric weight matrix.
        self.section_weights: np.ndarray = self.MIL_weights.iloc[:, 1:].to_numpy()

        self.candidate_studies: List[str] = list(
            pd.read_csv(get_path("model_data/candidates_data/candidate_studies.csv"))["Study"]
        )
        # Candidate embeddings are stored in two files and joined row-wise.
        candidate_embeddings_p1: torch.Tensor = torch.load(
            get_path("model_data/candidates_data/candidate_embeddings_p1.pt"), map_location=device
        )
        candidate_embeddings_p2: torch.Tensor = torch.load(
            get_path("model_data/candidates_data/candidate_embeddings_p2.pt"), map_location=device
        )
        self.candidate_embeddings: torch.Tensor = torch.cat(
            (candidate_embeddings_p1, candidate_embeddings_p2), dim=0
        )

        candidate_reports: pd.Series = pd.read_pickle(
            get_path("model_data/candidates_data/candidate_reports.pkl")
        )
        self.candidate_reports: List[str] = [utils.phrase_decode(vec_phr) for vec_phr in candidate_reports]

        self.candidate_labels: pd.DataFrame = pd.read_pickle(
            get_path("model_data/candidates_data/candidate_labels.pkl")
        )
        self.section_to_phenotypes: Dict[str, List[str]] = pd.read_pickle(
            get_path("assets/section_to_phenotypes.pkl")
        )
        print("[EchoPrime] Initialization Complete.")

    def process_dicoms(self, INPUT: str) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Scan a directory tree for DICOM video files, decode each file's pixel
        data, apply spatial pre-processing and temporal sampling, and return a
        stacked tensor ready for ``embed_videos``.

        Static 2D images (``pixels.ndim < 3``) and static RGB screenshots
        (shape ``(H, W, 3)``) are automatically detected and skipped with an
        informational message.

        Args:
            INPUT (str): Path to a directory (searched recursively) that
                contains ``.dcm`` files.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: A float32 tensor of shape
                ``(N, 3, 16, H, W)`` where *N* is the number of successfully
                processed video DICOMs, *H* = *W* = ``video_size`` (224), and
                the temporal dimension is ``frames_to_take // frame_stride``
                (16). Returns ``torch.empty(0)`` when no valid DICOMs are found
                or all files fail processing.
        """
        print(f"[EchoPrime] Scanning for DICOMs in: {INPUT}")
        dicom_paths: List[str] = glob.glob(f"{INPUT}/**/*.dcm", recursive=True)

        if not dicom_paths:
            print(f"[ERROR] No .dcm files found in {INPUT}")
            return torch.empty(0)

        stack_of_videos: List[torch.Tensor] = []
        skipped_count: int = 0

        print(f"Found {len(dicom_paths)} DICOM files. Processing...")

        for idx, dicom_path in tqdm(enumerate(dicom_paths), total=len(dicom_paths), desc="Processing"):
            try:
                dcm: pydicom.dataset.FileDataset = pydicom.dcmread(dicom_path)
                pixels: np.ndarray = dcm.pixel_array

                # --- VERIFICATION PRINT START ---
                # Check for 2D images (Height, Width) -> No time dimension
                if pixels.ndim < 3:
                    # Print only the filename, not the whole path, to keep it clean
                    fname: str = os.path.basename(dicom_path)
                    print(f"  > Skipped {fname}: Static 2D Image (Shape: {pixels.shape})")
                    skipped_count += 1
                    continue

                # Check for RGB static images (Height, Width, 3) -> 3rd dim is color, not time
                if pixels.ndim == 3 and pixels.shape[2] == 3:
                    fname = os.path.basename(dicom_path)
                    print(f"  > Skipped {fname}: Static RGB Screenshot (Shape: {pixels.shape})")
                    skipped_count += 1
                    continue
                # --- VERIFICATION PRINT END ---

                if pixels.ndim == 3:
                    pixels = np.repeat(pixels[..., None], 3, axis=3)

                pixels = utils.mask_outside_ultrasound(dcm.pixel_array)

                x: np.ndarray = np.zeros((len(pixels), 224, 224, 3))
                for i in range(len(x)):
                    x[i] = utils.crop_and_scale(pixels[i])

                x_tensor: torch.Tensor = torch.as_tensor(x, dtype=torch.float).permute([3, 0, 1, 2])
                x_tensor.sub_(self.mean).div_(self.std)

                if x_tensor.shape[1] < self.frames_to_take:
                    padding: torch.Tensor = torch.zeros(
                        (3, self.frames_to_take - x_tensor.shape[1], self.video_size, self.video_size),
                        dtype=torch.float,
                    )
                    x_tensor = torch.cat((x_tensor, padding), dim=1)

                start: int = 0
                processed_video: torch.Tensor = x_tensor[
                    :, start : (start + self.frames_to_take) : self.frame_stride, :, :
                ]
                stack_of_videos.append(processed_video)

            except Exception as e:
                print(f"Corrupt file {dicom_path}: {e}")
                pass

        if len(stack_of_videos) == 0:
            print("[ERROR] Found DICOMs but failed to process ANY of them.")
            return torch.empty(0)

        stacked: torch.Tensor = torch.stack(stack_of_videos)
        print(f"\n[Summary] Total: {len(dicom_paths)} | Processed: {len(stacked)} | Skipped: {skipped_count}")
        return stacked

    def process_mp4s(self, INPUT: str) -> torch.Tensor:
        """
        Scan a directory tree for MP4 video files, decode each file's frame
        data, apply spatial pre-processing and temporal sampling, and return a
        stacked tensor ready for ``embed_videos``.

        Args:
            INPUT (str): Path to a directory (searched recursively) that
                contains ``.mp4`` files.

        Returns:
            torch.Tensor: A float32 tensor of shape ``(N, 3, 16, H, W)`` where
                *N* is the number of successfully processed MP4 files, *H* =
                *W* = ``video_size`` (224), and the temporal dimension is
                ``frames_to_take // frame_stride`` (16). Corrupt files are
                silently skipped.
        """
        dicom_paths: List[str] = glob.glob(f"{INPUT}/**/*.mp4", recursive=True)
        stack_of_videos: List[torch.Tensor] = []
        for idx, dicom_path in enumerate(dicom_paths):
            try:
                # simple dicom_processing
                pixels_raw: torch.Tensor
                metadata: Dict[str, Any]
                pixels_raw, _, metadata = torchvision.io.read_video(dicom_path)
                fps: float = metadata["video_fps"]
                pixels: np.ndarray = np.array(pixels_raw)

                # model specific preprocessing
                x: np.ndarray = np.zeros((len(pixels), 224, 224, 3))
                for i in range(len(x)):
                    x[i] = utils.crop_and_scale(pixels[i])

                x_tensor: torch.Tensor = torch.as_tensor(x, dtype=torch.float).permute([3, 0, 1, 2])
                # normalize
                x_tensor.sub_(self.mean).div_(self.std)

                ## if not enough frames add padding
                if x_tensor.shape[1] < self.frames_to_take:
                    padding: torch.Tensor = torch.zeros(
                        (
                            3,
                            self.frames_to_take - x_tensor.shape[1],
                            self.video_size,
                            self.video_size,
                        ),
                        dtype=torch.float,
                    )
                    x_tensor = torch.cat((x_tensor, padding), dim=1)

                start: int = 0
                stack_of_videos.append(
                    x_tensor[:, start : (start + self.frames_to_take) : self.frame_stride, :, :]
                )

            except Exception as e:
                print("corrupt file")
                print(str(e))

        stacked: torch.Tensor = torch.stack(stack_of_videos)

        return stacked

    def embed_videos(self, stack_of_videos: torch.Tensor) -> torch.Tensor:
        """
        Pass a stack of pre-processed video clips through the frozen echo
        encoder in batches and return the resulting feature embeddings.

        Videos are forwarded through the encoder in bins of 50 to avoid
        out-of-memory errors on large studies. Gradient computation is
        disabled throughout.

        Args:
            stack_of_videos (torch.Tensor): Float32 tensor of shape
                ``(N, 3, T, H, W)`` as produced by ``process_dicoms`` or
                ``process_mp4s``.

        Returns:
            torch.Tensor: Float32 feature tensor of shape ``(N, 512)``
                containing one 512-dimensional embedding per input clip.
                Returns ``torch.empty(0)`` if ``stack_of_videos`` contains no
                elements.
        """
        if stack_of_videos.numel() == 0:
            return torch.empty(0)

        bin_size: int = 50
        n_bins: int = math.ceil(stack_of_videos.shape[0] / bin_size)
        stack_of_features_list: List[torch.Tensor] = []
        with torch.no_grad():
            for bin_idx in range(n_bins):
                start_idx: int = bin_idx * bin_size
                end_idx: int = min((bin_idx + 1) * bin_size, stack_of_videos.shape[0])
                bin_videos: torch.Tensor = stack_of_videos[start_idx:end_idx].to(self.device)
                bin_features: torch.Tensor = self.echo_encoder(bin_videos)
                stack_of_features_list.append(bin_features)
            stack_of_features: torch.Tensor = torch.cat(stack_of_features_list, dim=0)
        return stack_of_features

    def get_views(
        self,
        stack_of_videos: torch.Tensor,
        visualize: bool = False,
        return_view_list: bool = False,
    ) -> Union[torch.Tensor, List[str]]:
        """
        Predict the echocardiographic view for each video clip using the frozen
        view classifier applied to the first frame of each clip.

        Args:
            stack_of_videos (torch.Tensor): Float32 tensor of shape
                ``(N, 3, T, H, W)`` as produced by ``process_dicoms`` or
                ``process_mp4s``.
            visualize (bool): When ``True``, display a grid of first frames
                annotated with their predicted view label using matplotlib and
                OpenCV. Defaults to ``False``.
            return_view_list (bool): When ``True``, return a plain
                ``List[str]`` of human-readable view names instead of the
                one-hot encoded tensor. Defaults to ``False``.

        Returns:
            Union[torch.Tensor, List[str]]:
                - If ``return_view_list`` is ``False`` (default): a
                  ``torch.Tensor`` of shape ``(N, 11)`` containing one-hot
                  view encodings on ``self.device``.
                - If ``return_view_list`` is ``True``: a ``List[str]`` of
                  length *N* with coarse view name strings.
                - ``torch.empty(0)`` when ``stack_of_videos`` contains no
                  elements.
        """
        if stack_of_videos.numel() == 0:
            return torch.empty(0)

        stack_of_first_frames: torch.Tensor = stack_of_videos[:, :, 0, :, :].to(self.device)
        with torch.no_grad():
            out_logits: torch.Tensor = self.view_classifier(stack_of_first_frames)
        out_views: torch.Tensor = torch.argmax(out_logits, dim=1)
        view_list: List[str] = [utils.COARSE_VIEWS[v] for v in out_views]
        stack_of_view_encodings: torch.Tensor = (
            torch.stack([torch.nn.functional.one_hot(out_views, 11)]).squeeze(0).to(self.device)
        )

        if visualize:
            # FIX: Robust row calculation
            cols: int = 12
            rows: int = (len(view_list) + cols - 1) // cols
            print(f"[EchoPrime] Visualizing {len(view_list)} views in grid {rows}x{cols}")

            fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
            axes = axes.flatten()
            for i in range(len(view_list)):
                display_image: np.ndarray = (
                    stack_of_first_frames[i].cpu().permute([1, 2, 0]) * 255
                ).numpy()
                display_image = np.clip(display_image, 0, 255).astype("uint8")
                display_image = np.ascontiguousarray(display_image)
                display_image = cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR)
                cv2.putText(
                    display_image,
                    view_list[i].replace("_", " "),
                    (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.7,
                    (0, 220, 255),
                    2,
                )
                axes[i].imshow(display_image)
                axes[i].axis("off")

            for j in range(i + 1, len(axes)):
                axes[j].axis("off")
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            plt.show()

        if return_view_list:
            return view_list

        return stack_of_view_encodings

    @torch.no_grad()
    def encode_study(
        self, stack_of_videos: torch.Tensor, visualize: bool = False
    ) -> torch.Tensor:
        """
        Produce a per-clip study encoding by concatenating visual embeddings
        from the echo encoder with one-hot view encodings from the view
        classifier.

        This is the primary encoding step that aggregates both *what is shown*
        (clip embedding) and *which view it belongs to* (view encoding) into a
        unified representation used downstream by ``generate_report`` and
        ``predict_metrics``.

        Args:
            stack_of_videos (torch.Tensor): Float32 tensor of shape
                ``(N, 3, T, H, W)`` as produced by ``process_dicoms`` or
                ``process_mp4s``.
            visualize (bool): When ``True``, pass through to ``get_views`` to
                render an annotated view grid. Defaults to ``False``.

        Returns:
            torch.Tensor: Float32 tensor of shape ``(N, 523)`` where the first
                512 columns are clip embeddings and the remaining 11 columns are
                one-hot view encodings. Returns ``torch.empty(0)`` when
                ``stack_of_videos`` contains no elements.
        """
        if stack_of_videos.numel() == 0:
            print("[ERROR] Cannot encode empty video stack.")
            return torch.empty(0)

        stack_of_features: torch.Tensor = self.embed_videos(stack_of_videos)
        stack_of_view_encodings: torch.Tensor = self.get_views(stack_of_videos, visualize)
        encoded_study: torch.Tensor = torch.cat((stack_of_features, stack_of_view_encodings), dim=1)
        return encoded_study

    def translate_sections(self, report: str) -> str:
        """
        Translate anatomical section headings in a generated English report into
        the language specified by ``self.lang``.

        Only the section header strings (e.g. ``"Left Ventricle"``) are
        replaced; the body text of each section is left unchanged. If
        ``self.lang`` is not a recognised code, the report is returned
        unmodified.

        Supported language codes: ``'it'`` (Italian), ``'bs'`` (Bosnian),
        ``'ru'`` (Russian).

        Args:
            report (str): Full clinical report text in English as returned by
                ``generate_report``.

        Returns:
            str: Report with anatomical section headings replaced by their
                translated equivalents. Returns the original ``report``
                unchanged when no translation mapping is available for
                ``self.lang``.
        """
        translations: Dict[str, str] = {}

        if self.lang == "it":
            translations = {
                "Left Ventricle": "Ventricolo Sinistro",
                "Resting Segmental Wall Motion Analysis": "Cinetica Segmentaria a Riposo",
                "Right Ventricle": "Ventricolo Destro",
                "Left Atrium": "Atrio Sinistro",
                "Right Atrium": "Atrio Destro",
                "Atrial Septum": "Setto Inter-Atriale",
                "Mitral Valve": "Valvola Mitrale",
                "Aortic Valve": "Valvola Aortica",
                "Tricuspid Valve": "Valvola Tricuspide",
                "Pulmonic Valve": "Valvola Polmonare",
                "Pericardium": "Pericardio",
                "Aorta": "Aorta",
                "IVC": "Vena Cava Inferiore",
                "Pulmonary Artery": "Arteria Polmonare",
                "Pulmonary Veins": "Vene Polmonari",
                "Postoperative Findings": "Esiti Post-Operatori",
            }
        elif self.lang == "bs":
            translations = {
                "Left Ventricle": "Lijeva komora",
                "Resting Segmental Wall Motion Analysis": "Analiza segmentalne pokretljivosti stijenke u mirovanju",
                "Right Ventricle": "Desna komora",
                "Left Atrium": "Lijeva pretkomora",
                "Right Atrium": "Desna pretkomora",
                "Atrial Septum": "Interatrijski septum",
                "Mitral Valve": "Mitralni zalisak",
                "Aortic Valve": "Aortni zalisak",
                "Tricuspid Valve": "Trikuspidalni zalisak",
                "Pulmonic Valve": "Pulmonalni zalisak",
                "Pericardium": "Perikard",
                "Aorta": "Aorta",
                "IVC": "Donja šuplja vena",
                "Pulmonary Artery": "Plućna arterija",
                "Pulmonary Veins": "Plućne vene",
                "Postoperative Findings": "Postoperativni nalazi",
            }

        elif self.lang == "ru":
            translations = {
                "Left Ventricle": "Левый желудочек",
                "Resting Segmental Wall Motion Analysis": "Анализ сегментарной сократимости в покое",
                "Right Ventricle": "Правый желудочек",
                "Left Atrium": "Левое предсердие",
                "Right Atrium": "Правое предсердие",
                "Atrial Septum": "Межпредсердная перегородка",
                "Mitral Valve": "Митральный клапан",
                "Aortic Valve": "Аортальный клапан",
                "Tricuspid Valve": "Трёхстворчатый клапан",
                "Pulmonic Valve": "Клапан лёгочной артерии",
                "Pericardium": "Перикард",
                "Aorta": "Аорта",
                "IVC": "Нижняя полая вена",
                "Pulmonary Artery": "Лёгочная артерия",
                "Pulmonary Veins": "Лёгочные вены",
                "Postoperative Findings": "Послеоперационные изменения",
            }
        """
        elif self.lang=='your_language_code':
            translations = {
                # add your translations here
            }
        """

        for section, t in translations.items():
            report = report.replace(section, t)

        return report

    def generate_report(self, study_embedding: torch.Tensor) -> str:
        """
        Generate a multi-section clinical echocardiography report by retrieval.

        For each cardiac section in ``self.non_empty_sections`` the method:

        1. Looks up one weight per clip from ``self.section_weights`` using the
           clip's one-hot view encoding, and scales the clip embedding by it.
        2. Averages the weighted clip embeddings and L2-normalises the result.
        3. Scores every row of ``self.candidate_embeddings`` by dot product
           with that section embedding (a cosine similarity if the candidate
           embeddings are unit-norm -- presumably so; confirm at load time).
        4. Walks the candidates from most to least similar and keeps the first
           one for which ``utils.extract_section`` returns something other than
           ``"Section not found."``. At most 100 candidates (or the whole pool,
           if it is smaller) are tried; if none match, the section is omitted.

        If ``self.lang`` is not ``"en"``, ``translate_sections`` is applied to
        the assembled report before it is returned.

        Args:
            study_embedding (torch.Tensor): Tensor whose first 512 columns are
                clip embeddings and whose remaining columns are one-hot view
                encodings (documented upstream as shape ``(N, 523)``).

        Returns:
            str: The concatenated section texts. Returns the sentinel string
                ``"No data available to generate report."`` when
                ``study_embedding`` is empty.
        """
        if study_embedding.numel() == 0:
            return "No data available to generate report."

        print("[EchoPrime] Generating clinical report...")
        # The per-view weights are indexed on the CPU, so weight the clips there.
        study_embedding = study_embedding.cpu()

        # Loop-invariant views of the input: features and one-hot view labels.
        clip_embeddings: torch.Tensor = study_embedding[:, :512]
        view_encodings: torch.Tensor = study_embedding[:, 512:]

        max_attempts: int = 100
        report_parts: List[str] = []

        for s_dx, sec in enumerate(self.non_empty_sections):
            # One weight per clip: the entry of this section's weight row that
            # corresponds to the clip's active view.
            cur_weights = [
                self.section_weights[s_dx][torch.where(ten == 1)[0]]
                for ten in view_encodings
            ]

            weighted_clips: torch.Tensor = clip_embeddings * torch.tensor(
                cur_weights, dtype=torch.float
            ).unsqueeze(1)
            section_embedding: torch.Tensor = torch.nn.functional.normalize(
                torch.mean(weighted_clips, dim=0), dim=0
            )

            # The candidate bank lives on ``self.device``; match it before the matmul.
            section_embedding = section_embedding.to(self.device)
            similarities: torch.Tensor = section_embedding @ self.candidate_embeddings.T

            # Never ask for more candidates than exist. An empty pool simply
            # contributes nothing for this section.
            n_tries: int = min(max_attempts, similarities.shape[-1])
            if n_tries == 0:
                continue

            # Rank once instead of repeating argmax/.item() per attempt: a
            # single device-to-host transfer, and no in-place score masking.
            ranked_ids: List[int] = torch.topk(similarities, k=n_tries).indices.tolist()

            for candidate_id in ranked_ids:
                extracted_section: str = utils.extract_section(
                    self.candidate_reports[candidate_id], sec
                )
                if extracted_section != "Section not found.":
                    report_parts.append(extracted_section)
                    break

        generated_report: str = "".join(report_parts)

        if self.lang != "en":
            generated_report = self.translate_sections(generated_report)

        return generated_report

    def predict_metrics(self, study_embedding: torch.Tensor, k: int = 50) -> Dict[str, float]:
        """
        Predict cardiac phenotype values with a k-nearest-neighbour lookup.

        For each section in ``self.non_empty_sections`` the method:

        1. Looks up one weight per clip from ``self.section_weights`` using the
           clip's one-hot view encoding, scales each clip embedding by it and
           sums over clips to obtain a section-specific study embedding.
        2. L2-normalises the section embeddings and scores them against every
           row of ``self.candidate_embeddings`` by dot product.
        3. Takes the ``k`` best-scoring candidates per section.

        Then, for the i-th key of ``self.section_to_phenotypes`` (which is
        assumed to be ordered like ``self.non_empty_sections`` -- confirm where
        both are built), each phenotype is predicted as the NaN-ignoring mean
        of the labels found in ``self.candidate_labels[pheno]`` for the i-th
        section's neighbours.

        Args:
            study_embedding (torch.Tensor): Tensor whose first 512 columns are
                clip embeddings and whose remaining columns are one-hot view
                encodings (documented upstream as shape ``(N, 523)``).
            k (int): Number of nearest candidate studies to average per
                section. Values larger than the candidate pool are clamped to
                the pool size. Defaults to ``50``.

        Returns:
            Dict[str, float]: Mapping from phenotype name to predicted value.
                A phenotype with no labelled neighbour maps to ``numpy.nan``.
                Returns ``{}`` when ``study_embedding`` is empty.
        """
        if study_embedding.numel() == 0:
            return {}

        print("[EchoPrime] Predicting metrics...")
        # The per-view weights are indexed on the CPU, so weight the clips there.
        study_embedding = study_embedding.cpu()

        # Loop-invariant views of the input: features and one-hot view labels.
        clip_embeddings: torch.Tensor = study_embedding[:, :512]
        view_encodings: torch.Tensor = study_embedding[:, 512:]

        per_section_study_embedding: torch.Tensor = torch.zeros(len(self.non_empty_sections), 512)

        for s_dx, _ in enumerate(self.non_empty_sections):
            # One weight per clip: the entry of this section's weight row that
            # corresponds to the clip's active view.
            this_section_weights = [
                self.section_weights[s_dx][torch.where(view_encoding == 1)[0]]
                for view_encoding in view_encodings
            ]
            weighted_clips: torch.Tensor = clip_embeddings * torch.tensor(
                this_section_weights, dtype=torch.float
            ).unsqueeze(1)
            per_section_study_embedding[s_dx] = torch.sum(weighted_clips, dim=0)

        per_section_study_embedding = torch.nn.functional.normalize(per_section_study_embedding)

        # The candidate bank lives on ``self.device``; match it before the matmul.
        per_section_study_embedding = per_section_study_embedding.to(self.device)
        similarities: torch.Tensor = per_section_study_embedding @ self.candidate_embeddings.T

        # torch.topk raises when k exceeds the dimension size, so a small
        # candidate pool must shrink k rather than crash the prediction.
        k = min(k, similarities.shape[1])

        # Plain Python ints: one host transfer, and safe to use as indices.
        top_candidate_ids: List[List[int]] = torch.topk(similarities, k=k, dim=1).indices.tolist()

        preds: Dict[str, float] = {}
        for s_dx, phenotypes in enumerate(self.section_to_phenotypes.values()):
            # Resolve neighbour study ids once per section, not once per phenotype.
            neighbour_studies = [self.candidate_studies[c_id] for c_id in top_candidate_ids[s_dx]]

            for pheno in phenotypes:
                pheno_labels = self.candidate_labels[pheno]
                values: List[float] = [
                    pheno_labels[study] for study in neighbour_studies if study in pheno_labels
                ]
                # nanmean ignores NaN labels; an empty list would warn, so short-circuit.
                preds[pheno] = np.nanmean(values) if values else np.nan

        return preds


class EchoPrimeTextEncoder(torch.nn.Module):
    """
    BiomedBERT-based text encoder that maps clinical echocardiography report
    text to a 512-dimensional embedding.

    The backbone uses the architecture of
    ``microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract`` (a masked language
    model). The final-layer hidden state at token position 0 (the ``[CLS]``
    position) is linearly projected from 768 to 512 dimensions. The forward
    pass runs under ``torch.no_grad()``.

    Inputs are padded and truncated to exactly 512 tokens by the tokenizer, so
    a report longer than 512 tokens is cut to its first 512 tokens.

    NOTE(review): the backbone is built with ``from_config``, which creates the
    architecture without loading pretrained weights. Trained weights are
    presumably loaded into this module by the caller (e.g. via
    ``load_state_dict``) -- confirm at the call site. The module is also not
    switched to ``eval()`` here.

    Attributes:
        device (str): Identifier of the compute device (e.g. ``"cuda"``).
        backbone (transformers.BertForMaskedLM): BiomedBERT backbone model.
        text_projection (torch.nn.Linear): Linear layer mapping the 768-dim
            ``[CLS]`` representation to 512 dimensions.
        tokenizer (transformers.BertTokenizer): Tokenizer paired with the
            BiomedBERT backbone.
    """

    # Hugging Face Hub identifier shared by the config and the tokenizer.
    _BACKBONE_NAME = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"

    def __init__(self, device: str = "cuda") -> None:
        """
        Build the backbone architecture, projection head and tokenizer, then
        move the module to ``device``.

        The model configuration and the tokenizer are fetched with
        ``from_pretrained``; the backbone itself is instantiated from the
        configuration only.

        Args:
            device (str): Compute device identifier passed to ``self.to()``.
                Defaults to ``"cuda"``.
        """
        super().__init__()
        self.device: str = device

        config: transformers.PretrainedConfig = transformers.AutoConfig.from_pretrained(
            self._BACKBONE_NAME
        )
        # Architecture only: ``from_config`` does not load pretrained weights.
        self.backbone: transformers.BertForMaskedLM = transformers.AutoModelForMaskedLM.from_config(config)
        self.text_projection: torch.nn.Linear = torch.nn.Linear(768, 512)
        self.tokenizer: transformers.BertTokenizer = transformers.AutoTokenizer.from_pretrained(
            self._BACKBONE_NAME
        )
        self.tokenizer.max_length = 512
        self.to(device)

    def forward(self, report: str) -> torch.Tensor:
        """
        Encode a clinical report string into a 512-dimensional embedding.

        ``report`` is tokenised with ``padding="max_length"``,
        ``max_length=512`` and ``truncation=True``, so the encoded sequence is
        always exactly 512 tokens long: shorter reports are padded, longer
        ones keep only their first 512 tokens. The final-layer hidden state at
        position 0 (``[CLS]``) is projected to 512 dimensions via
        ``text_projection``. No L2 normalisation is applied.

        Args:
            report (str): Raw clinical echocardiography report text.

        Returns:
            torch.Tensor: Tensor of shape ``(1, 512)`` holding the
                unnormalised text embedding for ``report``.
        """
        text: transformers.BatchEncoding = self.tokenizer(
            report,
            padding="max_length",   # pad short inputs up to max_length
            max_length=512,         # fixed sequence length fed to the backbone
            truncation=True,        # cut long inputs down to max_length
            return_tensors="pt",
        )
        # Padding + truncation pin the sequence length to 512, so there is no
        # over-length case to handle here.
        with torch.no_grad():
            text = text.to(self.device)
            outputs = self.backbone(**text, output_hidden_states=True)
            # Last layer, first token ([CLS]) of every sequence in the batch.
            cls_state: torch.Tensor = outputs.hidden_states[-1][:, 0, :]
            text_emb: torch.Tensor = self.text_projection(cls_state)
        return text_emb