File size: 29,595 Bytes
00a0ce5
5109422
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
 
 
00a0ce5
 
 
5109422
00a0ce5
 
 
 
5109422
00a0ce5
 
 
 
 
5109422
00a0ce5
5109422
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
00a0ce5
 
 
 
 
 
 
 
 
 
 
0705c62
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
 
00a0ce5
 
 
 
 
 
 
0705c62
 
 
 
 
 
00a0ce5
 
0705c62
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
 
 
 
 
 
 
 
 
 
00a0ce5
0705c62
 
 
00a0ce5
 
 
 
 
 
 
 
 
bb1589d
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1589d
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1589d
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1589d
00a0ce5
 
 
bb1589d
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
bb1589d
00a0ce5
 
 
bb1589d
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1589d
00a0ce5
 
bb1589d
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5109422
00a0ce5
5109422
 
 
00a0ce5
 
 
 
 
0705c62
 
 
 
 
 
 
 
 
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
00a0ce5
 
 
 
 
 
 
 
0705c62
00a0ce5
0705c62
00a0ce5
 
 
 
 
 
0705c62
 
 
 
 
00a0ce5
 
0705c62
00a0ce5
 
 
 
 
0705c62
 
 
 
 
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
 
 
 
 
 
 
 
 
00a0ce5
 
 
 
 
 
 
 
0705c62
 
 
 
 
 
00a0ce5
0705c62
00a0ce5
 
0705c62
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
5109422
00a0ce5
5109422
 
 
 
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
00a0ce5
 
 
 
 
 
 
 
5109422
00a0ce5
5109422
00a0ce5
5109422
 
 
 
 
 
 
 
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1589d
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0705c62
00a0ce5
 
 
 
bb1589d
00a0ce5
 
 
 
 
 
 
 
 
 
 
 
 
bb1589d
00a0ce5
0705c62
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
"""
Утилиты для кэширования и загрузки активаций SAE.
"""

from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from analysis.models import _iqa_activations
from log_config import get_logger

# Module-level logger, obtained from the project's `log_config.get_logger` helper.
logger = get_logger(__name__)


def _df_mb(df: pd.DataFrame) -> float:
    """Объём памяти DataFrame в МБ."""
    return df.memory_usage(deep=True).sum() / 1024 ** 2


def _sparse_mb(mat: sp.csr_matrix) -> float:
    """Объём памяти CSR-матрицы (данные + индексы) в МБ."""
    return (mat.data.nbytes + mat.indices.nbytes + mat.indptr.nbytes) / 1024 ** 2


def _cache_paths(base_path: str) -> Tuple[str, str, str]:
    """
    Возвращает пути к трём файлам кэша из базового пути.

    Пример:
        'cache/kadid_acts.feather'
          → ('cache/kadid_acts_meta.feather', 'cache/kadid_acts_codes.npz',
              'cache/kadid_acts_steps.npz')
    """
    p = Path(base_path)
    stem = p.stem.removesuffix('.feather')
    return (
        str(p.parent / f'{stem}_meta.feather'),
        str(p.parent / f'{stem}_codes.npz'),
        str(p.parent / f'{stem}_steps.npz'),
    )


def _pristine_cache_paths(base_path: str) -> Tuple[str, str, str]:
    """Return the pristine-cache counterparts of the paths from ``_cache_paths``.

    Each path gets ``_pristine`` inserted before its extension
    (``.feather`` for metadata, ``.npz`` for codes and steps).

    NOTE: the insertion uses ``str.replace``, so every occurrence of the
    extension text in the path is rewritten, not only the final suffix.
    """
    regular_paths = _cache_paths(base_path)
    extensions = ('.feather', '.npz', '.npz')
    pristine_meta, pristine_codes, pristine_steps = (
        regular.replace(ext, f'_pristine{ext}')
        for regular, ext in zip(regular_paths, extensions)
    )
    return pristine_meta, pristine_codes, pristine_steps


def load_parquet_cache(cache_path: Optional[str], *, label: str = 'cache') -> Optional[pd.DataFrame]:
    """Read a cached parquet table.

    Args:
        cache_path: Location of the parquet file, or ``None`` to disable caching.
        label: Human-readable name used only in the debug log message.

    Returns:
        The loaded DataFrame, or ``None`` when ``cache_path`` is ``None`` or
        the file does not exist.
    """
    cache = None if cache_path is None else Path(cache_path)
    if cache is None or not cache.exists():
        return None

    logger.debug('[cache] Loading %s from %s', label, cache)
    return pd.read_parquet(cache)


def save_parquet_cache(df: pd.DataFrame, cache_path: Optional[str], *, label: str = 'cache') -> None:
    """Write ``df`` to a parquet cache file; do nothing when no path is given.

    Args:
        df: Table to persist.
        cache_path: Destination file, or ``None`` to skip saving. Missing
            parent directories are created.
        label: Human-readable name used only in the debug log message.
    """
    if cache_path is not None:
        target = Path(cache_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(target)
        logger.debug('[cache] Saved %s to %s', label, target)


# Metadata keys whose batch-level values `_process_dataloader` may skip when
# label-to-distortion mappings are supplied; they are then derived per patch
# from the mask labels instead.
_PATCH_LABEL_META_KEYS = frozenset({'dist_type', 'dist_group'})


def _patch_labels_to_dist_meta(
    patch_labels: np.ndarray,
    *,
    label_to_dist_type: Dict[int, str],
    label_to_dist_group: Dict[str, str],
) -> Tuple[List[str], List[str]]:
    flat_labels = patch_labels.reshape(-1)
    dist_types = [label_to_dist_type.get(int(label_id), 'background') for label_id in flat_labels]
    dist_groups = [label_to_dist_group.get(dist_type, dist_type) for dist_type in dist_types]
    return dist_types, dist_groups


def _patch_label_stats(masks, grid_shape, batch_size, device):
    """Reduce pixel-level label masks to one label and one coverage value per patch.

    For every label id in ``1..max_label`` the fraction of each grid cell
    covered by that label is computed with adaptive average pooling. A cell
    receives the label with the largest coverage, or ``0`` when no positive
    label touches it.

    Returns:
        ``(labels, coverage)`` as NumPy arrays of shape
        ``(batch_size, grid_h * grid_w)``; ``labels`` is ``int16``.
    """
    grid_h, grid_w = grid_shape
    n_patches = grid_h * grid_w

    # Same conversion chain as the caller always used: float32 first, then int64.
    mask_labels = masks.to(device=device, dtype=torch.float32).to(dtype=torch.int64)
    max_label = int(mask_labels.max().item())

    if max_label <= 0:
        # No positive label anywhere in the batch: every patch is background.
        labels_np = np.zeros((batch_size, n_patches), dtype=np.int16)
        coverage_np = np.zeros((batch_size, n_patches), dtype=np.float32)
        return labels_np, coverage_np

    # NOTE(review): concatenating on dim=1 assumes masks are (B, 1, H, W) so
    # that each pooled map contributes one class channel; confirm with the
    # dataset collate functions.
    coverages = torch.cat(
        [
            F.adaptive_avg_pool2d(
                (mask_labels == label_id).to(dtype=torch.float32),
                (grid_h, grid_w),
            )
            for label_id in range(1, max_label + 1)
        ],
        dim=1,
    )
    max_cov, max_idx = coverages.max(dim=1)
    patch_labels = torch.where(
        max_cov > 0,
        max_idx.to(dtype=torch.int64) + 1,  # channel 0 corresponds to label id 1
        torch.zeros_like(max_idx, dtype=torch.int64),
    )

    labels_np = patch_labels.reshape(batch_size, n_patches).cpu().numpy().astype(np.int16)
    coverage_np = max_cov.reshape(batch_size, n_patches).cpu().numpy()
    return labels_np, coverage_np


def _process_dataloader(
    dataloader,
    iqa,
    sae,
    layer_name,
    scaling_factor,
    device,
    patches_per_image,
    patch_grid_shape,
    meta_keys,
    max_batches,
    max_memory_gb,
    add_patch_mask_stats,
    show_progress_bars: bool = True,
    *,
    label_to_dist_type: Optional[Dict[int, str]] = None,
    label_to_dist_group: Optional[Dict[str, str]] = None,
):
    """Run the IQA model and SAE over a dataloader and collect sparse activations.

    For each batch the IQA model is called on ``batch['images']``, the
    activations stored in ``_iqa_activations[layer_name]`` are scaled by
    ``scaling_factor`` and encoded with ``sae.get_acts``. The encoder may
    return either codes alone or a ``(codes, activation_steps)`` tuple; when
    steps are absent an all-zero steps matrix is stored.

    Args:
        dataloader: Iterable of batch dicts containing at least ``'images'``.
        iqa: Callable IQA model; it is called only for its side effect of
            filling ``_iqa_activations``.
        sae: Object exposing ``get_acts``.
        layer_name: Key into ``_iqa_activations``.
        scaling_factor: Multiplier applied to the activations before encoding.
        device: Device the images, activations and masks are moved to.
        patches_per_image: Patches per image; inferred from the first batch
            as ``codes.shape[0] // batch_size`` when ``None``.
        patch_grid_shape: ``(grid_h, grid_w)`` of the patch grid, or ``None``.
        meta_keys: Batch keys copied into the metadata, repeated once per patch.
        max_batches: Stop after this many batches (``None`` = no limit).
        max_memory_gb: Accepted for interface compatibility; not used here.
        add_patch_mask_stats: When true and ``batch['masks']`` is present and
            the grid matches the patch count, add ``patch_mask_label``,
            ``patch_mask_coverage`` and ``patch_is_distorted`` columns.
        show_progress_bars: Toggle the tqdm progress bar.
        label_to_dist_type: Optional label id -> distortion type mapping.
        label_to_dist_group: Optional distortion type -> group mapping.

    When both mappings are given and patch labels can be computed,
    ``dist_type`` / ``dist_group`` are derived per patch from the mask labels.
    If patch labels cannot be computed for a batch, the batch-level values are
    kept instead of being dropped, and a warning is logged once.

    Returns:
        ``(meta_df, codes_csr, steps_csr)`` with one row per patch.

    Raises:
        ValueError: If no batch was processed (empty dataloader or
            ``max_batches == 0``).
    """
    all_sparse_codes = []
    all_meta = []
    all_sparse_steps = []

    n_patches_known = patches_per_image
    patch_grid_known = patch_grid_shape
    image_offset = 0

    # Loop-invariant: both mappings are needed to derive per-patch dist metadata.
    use_patch_label_meta = (
        label_to_dist_type is not None
        and label_to_dist_group is not None
    )
    warned_label_fallback = False

    for batch_i, batch in enumerate(tqdm(dataloader, desc='Caching activations', disable=not show_progress_bars)):
        if max_batches is not None and batch_i >= max_batches:
            break

        imgs = batch['images'].to(device)
        B = imgs.shape[0]

        with torch.no_grad():
            iqa(imgs)
            acts = _iqa_activations[layer_name].to(device)
            acts = acts * scaling_factor
            enc_out = sae.get_acts(acts)

            if isinstance(enc_out, tuple):
                codes, activation_steps = enc_out
            else:
                codes, activation_steps = enc_out, None

        codes_np = codes.cpu().float().numpy()

        if activation_steps is None:
            steps_np = np.zeros_like(codes_np, dtype=np.int32)
        else:
            steps_np = activation_steps.cpu().numpy().astype(np.int32)

        if n_patches_known is None:
            n_patches_known = codes_np.shape[0] // B
            logger.info('Detected %s patches per image', n_patches_known)

        P = n_patches_known

        # Patch-level mask statistics are only computable when masks exist and
        # the known grid covers exactly P patches.
        has_patch_masks = (
            add_patch_mask_stats
            and 'masks' in batch
            and patch_grid_known is not None
            and patch_grid_known[0] * patch_grid_known[1] == P
        )
        derive_dist_meta = use_patch_label_meta and has_patch_masks

        if use_patch_label_meta and not has_patch_masks and not warned_label_fallback:
            logger.warning(
                'Patch-level distortion labels cannot be derived (masks missing, mask stats '
                'disabled, or patch grid mismatch); keeping batch-level dist_type/dist_group values.'
            )
            warned_label_fallback = True

        meta = {}
        for k in meta_keys:
            if k not in batch:
                continue
            # Skip batch-level values only when they will be replaced below.
            if derive_dist_meta and k in _PATCH_LABEL_META_KEYS:
                continue
            meta[k] = [v for v in batch[k] for _ in range(P)]

        meta['patch_idx'] = list(range(P)) * B
        meta['image_idx'] = [image_offset + i for i in range(B) for _ in range(P)]

        if has_patch_masks:
            patch_labels_np, patch_coverage_np = _patch_label_stats(
                batch['masks'], patch_grid_known, B, device
            )
            patch_is_dist = (patch_labels_np > 0).astype(np.int8)

            meta['patch_mask_label'] = patch_labels_np.reshape(-1).tolist()
            meta['patch_mask_coverage'] = patch_coverage_np.reshape(-1).tolist()
            meta['patch_is_distorted'] = patch_is_dist.reshape(-1).tolist()

            if derive_dist_meta:
                dist_types, dist_groups = _patch_labels_to_dist_meta(
                    patch_labels_np,
                    label_to_dist_type=label_to_dist_type,
                    label_to_dist_group=label_to_dist_group,
                )
                if 'dist_type' in meta_keys:
                    meta['dist_type'] = dist_types
                if 'dist_group' in meta_keys:
                    meta['dist_group'] = dist_groups

        all_meta.append(pd.DataFrame(meta))
        all_sparse_codes.append(sp.csr_matrix(codes_np))
        all_sparse_steps.append(sp.csr_matrix(steps_np))

        image_offset += B

    if not all_meta:
        # pd.concat / sp.vstack would fail on empty input with a less helpful message.
        raise ValueError('No batches were processed: the dataloader is empty or max_batches is 0')

    meta_df = pd.concat(all_meta, ignore_index=True)
    codes_csr = sp.vstack(all_sparse_codes, format='csr')
    steps_csr = sp.vstack(all_sparse_steps, format='csr')

    return meta_df, codes_csr, steps_csr


def collect_and_cache(
    dataloader: DataLoader,
    iqa: torch.nn.Module,
    sae,
    layer_name: str,
    output_path: str,
    scaling_factor: float = 1.0,
    patches_per_image: Optional[int] = None,
    patch_grid_shape: Optional[Tuple[int, int]] = None,
    meta_keys: Sequence[str] = (
        'dist_type',
        'dist_group',
        'dist_level',
        'mos',
        'distorted_img_path',
        'original_img_path',
        'sample_id',
    ),
    device: str = 'cuda',
    max_batches: Optional[int] = None,
    max_memory_gb: Optional[float] = None,
    add_patch_mask_stats: bool = True,
    pristine_dataloader: Optional[DataLoader] = None,
    show_progress_bars: bool = True,
    label_to_dist_type: Optional[Dict[int, str]] = None,
    label_to_dist_group: Optional[Dict[str, str]] = None,
) -> Tuple[pd.DataFrame, sp.csr_matrix]:
    """Collect SAE activations for a dataset and save them as cache files.

    The main dataloader is processed with ``_process_dataloader`` and written
    to the three paths given by ``_cache_paths(output_path)`` (feather
    metadata, npz codes, npz activation steps). If ``pristine_dataloader`` is
    given, it is processed the same way (without patch mask statistics or
    label mappings) and written to the paths given by
    ``_pristine_cache_paths(output_path)``.

    The output directory is created if it does not exist.

    Returns:
        ``(meta_df, codes_csr)`` for the main dataloader. The steps matrix and
        the pristine results are only written to disk.
    """
    meta_path, codes_path, steps_path = _cache_paths(output_path)
    # All cache files share one directory; create it before the expensive
    # forward passes so a bad output location fails early.
    Path(meta_path).parent.mkdir(parents=True, exist_ok=True)

    meta_df, codes_csr, steps_csr = _process_dataloader(
        dataloader=dataloader,
        iqa=iqa,
        sae=sae,
        layer_name=layer_name,
        scaling_factor=scaling_factor,
        device=device,
        patches_per_image=patches_per_image,
        patch_grid_shape=patch_grid_shape,
        meta_keys=meta_keys,
        max_batches=max_batches,
        max_memory_gb=max_memory_gb,
        add_patch_mask_stats=add_patch_mask_stats,
        show_progress_bars=show_progress_bars,
        label_to_dist_type=label_to_dist_type,
        label_to_dist_group=label_to_dist_group,
    )
    sparse_mb = _sparse_mb(codes_csr)
    steps_mb = _sparse_mb(steps_csr)

    logger.info('Activations: shape=%s, %.1f МБ (sparse)', codes_csr.shape, sparse_mb)
    logger.info('Activation steps: shape=%s, %.1f МБ (sparse)', steps_csr.shape, steps_mb)

    meta_df.to_feather(meta_path)
    sp.save_npz(codes_path, codes_csr)
    sp.save_npz(steps_path, steps_csr)

    logger.info('Saved metadata (%s rows) -> %s', len(meta_df), meta_path)
    logger.info('Saved activations -> %s', codes_path)
    logger.info('Saved activation steps -> %s', steps_path)
    logger.info('  Metadata:    %.1f МБ', _df_mb(meta_df))
    logger.info('  Activations: %.1f МБ', sparse_mb)
    logger.info('  Steps:       %.1f МБ', steps_mb)

    if pristine_dataloader is not None:
        logger.info('Processing pristine dataset...')

        pristine_meta, pristine_codes, pristine_steps = _process_dataloader(
            dataloader=pristine_dataloader,
            iqa=iqa,
            sae=sae,
            layer_name=layer_name,
            scaling_factor=scaling_factor,
            device=device,
            patches_per_image=patches_per_image,
            patch_grid_shape=patch_grid_shape,
            meta_keys=meta_keys,
            max_batches=max_batches,
            max_memory_gb=max_memory_gb,
            add_patch_mask_stats=False,
            show_progress_bars=show_progress_bars,
        )
        pristine_sparse_mb = _sparse_mb(pristine_codes)
        pristine_steps_mb = _sparse_mb(pristine_steps)

        # Use the shared helper so the writer and `load_pristine_cache`
        # always agree on file names.
        pristine_meta_path, pristine_codes_path, pristine_steps_path = _pristine_cache_paths(output_path)

        pristine_meta.to_feather(pristine_meta_path)
        sp.save_npz(pristine_codes_path, pristine_codes)
        sp.save_npz(pristine_steps_path, pristine_steps)

        logger.info(
            'Pristine activations: shape=%s, %.1f МБ',
            pristine_codes.shape,
            pristine_sparse_mb,
        )
        logger.info(
            'Pristine steps: shape=%s, %.1f МБ',
            pristine_steps.shape,
            pristine_steps_mb,
        )

        logger.info('Saved pristine metadata (%s rows) -> %s', len(pristine_meta), pristine_meta_path)
        logger.info('Saved pristine activations -> %s', pristine_codes_path)
        logger.info('Saved pristine activation steps -> %s', pristine_steps_path)

    return meta_df, codes_csr


def build_activation_cache(
    *,
    dataset: str,
    cache_path: str,
    checkpoint_path: str,
    dataset_root: str,
    layer_num: int,
    iqa_metric: str,
    swin_num: int,
    device: str,
    batch_size: int,
    num_workers: int,
    crop_size: int,
    scaling_factor: float = 1.0,
    min_distortion_level: Optional[int] = None,
    max_batches: Optional[int] = None,
    max_memory_gb: Optional[float] = None,
    add_patch_mask_stats: bool = True,
    include_pristine: bool = True,
    show_progress_bars: bool = True,
    srground_include_sr_artifact: bool = False,
) -> Dict[str, Any]:
    """Build the activation cache end-to-end for one supported dataset.

    Supported ``dataset`` values: ``'local_kadid'``, ``'kadid10k'`` /
    ``'kadid'``, ``'QGround'`` and ``'SRGround'``. The function creates the
    dataset and dataloader(s), loads the IQA model and the SAE, runs one dummy
    forward pass to discover the activation grid of the hooked layer, and then
    delegates to ``collect_and_cache``.

    Notes:
        - ``min_distortion_level`` is validated for every dataset but only
          forwarded to the KADID-10k dataset (``None`` is treated as 1).
        - A pristine dataloader is built only for the two KADID variants and
          only when ``include_pristine`` is true.
        - ``srground_include_sr_artifact`` is forwarded to the SRGround dataset.
        - The SAE is loaded in float16 when ``iqa_metric == 'qalign'``,
          otherwise in float32.

    Returns:
        Dict with ``layer_name``, ``patch_grid_shape``, ``patches_per_image``
        and ``sae_config``.

    Raises:
        ValueError: On an out-of-range ``min_distortion_level`` or an
            unsupported ``dataset``.
        RuntimeError: If the dummy pass did not register activations or a grid
            for the selected layer.
    """
    from .datasets import (
        Kadid10kDataset,
        KadidPristineDataset,
        LocalKadidPresavedDataset,
        LocalKadidPristineDataset,
        QGroundDataset,
        SRGroundSmallDataset,
        available_distortions_qground,
        available_distortions_srground,
        distortion_types_mapping_qground,
        distortion_types_mapping_srground,
        kadid_collate_fn,
        kadid_pristine_collate_fn,
        local_kadid_collate_fn,
        local_kadid_pristine_collate_fn,
        qground_collate_fn,
        srground_collate_fn,
    )
    from .models import _iqa_activation_grids, load_iqa_model, load_sae, read_sae_config

    if min_distortion_level is not None:
        if not 1 <= min_distortion_level <= 5:
            raise ValueError('min_distortion_level must be in [1, 5]')

    # Defaults shared by all datasets; individual branches override as needed.
    label_to_dist_type = None
    label_to_dist_group = None
    pristine_data = None
    pristine_collate = None

    if dataset == 'local_kadid':
        meta_keys = [
            'dist_type',
            'dist_group',
            'dist_level',
            'mos',
            'local_dist_type',
            'local_dist_level',
            'mask_shape',
            'mask_coverage',
            'sample_id',
            'distorted_img_path',
            'original_img_path',
        ]
        collate_fn = local_kadid_collate_fn
        pristine_collate = local_kadid_pristine_collate_fn
        data = LocalKadidPresavedDataset(root=dataset_root, crop_size=crop_size)
        if include_pristine:
            pristine_data = LocalKadidPristineDataset(root=dataset_root, crop_size=crop_size)
    elif dataset in {'kadid10k', 'kadid'}:
        meta_keys = [
            'dist_type',
            'dist_group',
            'dist_level',
            'mos',
            'distorted_img_path',
            'original_img_path',
        ]
        collate_fn = kadid_collate_fn
        pristine_collate = kadid_pristine_collate_fn
        data = Kadid10kDataset(
            root=dataset_root,
            crop_size=crop_size,
            min_distortion_level=min_distortion_level or 1,
        )
        if include_pristine:
            pristine_data = KadidPristineDataset(root=dataset_root, crop_size=crop_size)
    elif dataset == 'QGround':
        meta_keys = [
            'dist_type',
            'dist_group',
            'dist_level',
            'mos',
            'mask_coverage',
            'qground_ann_id',
            'sample_id',
            'distorted_img_path',
            'original_img_path',
            'image_path',
            'mask_path',
            'split',
        ]
        collate_fn = qground_collate_fn
        data = QGroundDataset(
            root=dataset_root,
            split='test',
            crop_size=crop_size,
        )
        # Per-patch distortion metadata is derived from mask labels.
        label_to_dist_type = distortion_types_mapping_qground
        label_to_dist_group = available_distortions_qground
    elif dataset == 'SRGround':
        meta_keys = [
            'dist_type',
            'dist_group',
            'dist_level',
            'mos',
            'mask_coverage',
            'sample_id',
            'distorted_img_path',
            'image_path',
            'real_distortions_ann_path',
            'sr_artifacts_ann_path',
        ]
        collate_fn = srground_collate_fn
        data = SRGroundSmallDataset(
            root=dataset_root,
            json_path='srground_train.json',
            crop_size=crop_size,
            include_sr_artifact=srground_include_sr_artifact,
        )
        # Per-patch distortion metadata is derived from mask labels.
        label_to_dist_type = distortion_types_mapping_srground
        label_to_dist_group = available_distortions_srground
    else:
        raise ValueError(f'Unsupported dataset: {dataset}')

    # Both loaders iterate in dataset order with identical batching settings.
    loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)
    loader = DataLoader(data, collate_fn=collate_fn, **loader_options)
    pristine_loader = (
        DataLoader(pristine_data, collate_fn=pristine_collate, **loader_options)
        if pristine_data is not None
        else None
    )

    iqa_model, layer_name = load_iqa_model(
        layer_num=layer_num,
        device=device,
        iqa_metric=iqa_metric,
        swin_num=swin_num,
    )
    if iqa_metric == 'qalign':
        dtype = torch.float16
    else:
        dtype = torch.float32
    sae_cfg = read_sae_config(checkpoint_path)
    sae_model = load_sae(checkpoint_path, device=device, dtype=dtype, sae_config=sae_cfg)

    # One dummy forward pass so the activation and grid registries get an
    # entry for the selected layer.
    with torch.no_grad():
        dummy = torch.rand(1, 3, crop_size, crop_size, device=device).clamp(0, 1)
        iqa_model(dummy)

    if not (layer_name in _iqa_activations and layer_name in _iqa_activation_grids):
        raise RuntimeError(f'Cannot infer activation grid for layer {layer_name}')

    patch_grid_shape = _iqa_activation_grids[layer_name]
    patches_per_image = patch_grid_shape[0] * patch_grid_shape[1]

    Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
    collect_and_cache(
        dataloader=loader,
        iqa=iqa_model,
        sae=sae_model,
        layer_name=layer_name,
        output_path=cache_path,
        scaling_factor=scaling_factor,
        patches_per_image=patches_per_image,
        patch_grid_shape=patch_grid_shape,
        meta_keys=meta_keys,
        device=device,
        max_batches=max_batches,
        max_memory_gb=max_memory_gb,
        add_patch_mask_stats=add_patch_mask_stats,
        pristine_dataloader=pristine_loader,
        show_progress_bars=show_progress_bars,
        label_to_dist_type=label_to_dist_type,
        label_to_dist_group=label_to_dist_group,
    )

    summary: Dict[str, Any] = {}
    summary['layer_name'] = layer_name
    summary['patch_grid_shape'] = patch_grid_shape
    summary['patches_per_image'] = patches_per_image
    summary['sae_config'] = sae_cfg
    return summary


def load_cache(
    path: str,
    return_activation_steps: bool = False,
    min_distortion_level: Optional[int] = None,
    max_distortion_level: Optional[int] = None,
) -> Union[
    Tuple[pd.DataFrame, sp.csr_matrix],
    Tuple[pd.DataFrame, sp.csr_matrix, sp.csr_matrix],
]:
    """Load the SAE activation cache from its separate files.

    The metadata (feather) and codes (npz) files are located via
    ``_cache_paths(path)``. When ``return_activation_steps=True`` the CSR
    matrix of activation steps is returned as well; in it ``0`` means "not
    activated" and ``k > 0`` is step ``k`` of the pursuit. If the steps file
    is missing, an all-zero matrix of the same shape as the codes is used.

    If either distortion-level bound is given, rows are filtered on the
    ``dist_level`` metadata column (missing bounds default to 1 and 5). For
    paths containing ``'Ground'`` the requested range is currently ignored and
    rows with ``dist_level >= -1000`` are kept (temporary workaround).

    Raises:
        ValueError: If filtering is requested without a ``dist_level`` column,
            if the minimum exceeds the maximum, or if the steps matrix shape
            differs from the codes shape.
    """
    meta_path, codes_path, steps_path = _cache_paths(path)
    meta = pd.read_feather(meta_path)
    codes = sp.load_npz(codes_path)

    n_rows, n_cols = meta.shape[0], meta.shape[1]
    logger.debug('Loaded from %s: %s rows × %s cols', meta_path, n_rows, n_cols)
    logger.debug('Loaded from %s: shape=%s, dtype=%s', codes_path, codes.shape, codes.dtype)
    logger.debug('  Metadata:    %.1f МБ', _df_mb(meta))
    logger.debug('  Activations: %.1f МБ (sparse)', _sparse_mb(codes))

    # Row positions to keep; stays None when no level filter was requested.
    keep_idx: Optional[np.ndarray] = None
    filtering_requested = not (min_distortion_level is None and max_distortion_level is None)
    if filtering_requested:
        if 'dist_level' not in meta.columns:
            raise ValueError('Cannot filter by distortion level: metadata has no "dist_level" column')

        min_level = int(min_distortion_level) if min_distortion_level is not None else 1
        max_level = int(max_distortion_level) if max_distortion_level is not None else 5
        if min_level > max_level:
            raise ValueError(
                f'Invalid distortion-level range: min_distortion_level={min_level} > max_distortion_level={max_level}'
            )

        levels = meta['dist_level']
        if 'Ground' in path:
            # Temporary workaround -- fix later: the requested range is not
            # applied to *Ground caches.
            keep_mask = levels >= -1000
        else:
            keep_mask = (levels >= min_level) & (levels <= max_level)
        keep_idx = np.flatnonzero(keep_mask.to_numpy())

    # Steps are loaded (and shape-checked against the unfiltered codes)
    # before any row filtering is applied.
    steps = None
    if return_activation_steps:
        if Path(steps_path).exists():
            steps = sp.load_npz(steps_path)
            if steps.shape != codes.shape:
                raise ValueError(
                    f'Steps cache shape mismatch: expected {codes.shape}, got {steps.shape}'
                )
            logger.info('Loaded from %s: shape=%s, dtype=%s', steps_path, steps.shape, steps.dtype)
        else:
            logger.warning('No steps cache found. Using all-zero activation steps.')
            steps = sp.csr_matrix(codes.shape, dtype=np.int32)

    if keep_idx is not None:
        meta = meta.iloc[keep_idx].reset_index(drop=True)
        codes = codes[keep_idx]
        if steps is not None:
            steps = steps[keep_idx]
        logger.info(
            'Applied dist_level filter [%s, %s] -> %s rows kept',
            min_level,
            max_level,
            meta.shape[0],
        )

    if steps is None:
        return meta, codes

    logger.info('  Steps:       %.1f МБ (sparse)', _sparse_mb(steps))
    return meta, codes, steps


def load_pristine_cache(
    path: str,
    return_activation_steps: bool = False,
) -> Union[
    Tuple[pd.DataFrame, sp.csr_matrix],
    Tuple[pd.DataFrame, sp.csr_matrix, sp.csr_matrix],
]:
    """Load the pristine (original-image) activation cache.

    File locations come from ``_pristine_cache_paths(path)``. With
    ``return_activation_steps=True`` the steps matrix is returned as a third
    element; if its file is missing, an all-zero matrix with the codes' shape
    is used.

    Raises:
        ValueError: If the steps matrix shape differs from the codes shape.
    """
    meta_path, codes_path, steps_path = _pristine_cache_paths(path)
    meta = pd.read_feather(meta_path)
    codes = sp.load_npz(codes_path)

    n_rows, n_cols = meta.shape[0], meta.shape[1]
    logger.info('Loaded pristine from %s: %s rows × %s cols', meta_path, n_rows, n_cols)
    logger.info('Loaded pristine from %s: shape=%s, dtype=%s', codes_path, codes.shape, codes.dtype)
    logger.info('  Metadata:    %.1f МБ', _df_mb(meta))
    logger.info('  Activations: %.1f МБ (sparse)', _sparse_mb(codes))

    if not return_activation_steps:
        return meta, codes

    if not Path(steps_path).exists():
        logger.warning('No pristine steps cache found. Using all-zero activation steps.')
        steps = sp.csr_matrix(codes.shape, dtype=np.int32)
    else:
        steps = sp.load_npz(steps_path)
        if steps.shape != codes.shape:
            raise ValueError(
                f'Pristine steps cache shape mismatch: expected {codes.shape}, got {steps.shape}'
            )
        logger.info(
            'Loaded pristine from %s: shape=%s, dtype=%s',
            steps_path,
            steps.shape,
            steps.dtype,
        )

    logger.info('  Steps:       %.1f МБ (sparse)', _sparse_mb(steps))
    return meta, codes, steps


def ensure_cache_ready(
    cache_path: str,
    *,
    force_recache: bool = False,
    build_cache_if_missing: bool = True,
    load_cache_kwargs: Optional[Dict[str, Any]] = None,
    build_cache_fn: Optional[Callable[[], None]] = None,
) -> None:
    """Check that the activation cache is available and build it if needed.

    Behaviour:
    - unless ``force_recache`` is set, try to load the cache via ``load_cache``
      (with ``load_cache_kwargs``) and return if that succeeds;
    - if the cache files are missing or ``force_recache=True``, call
      ``build_cache_fn``;
    - if building is disabled, raise ``FileNotFoundError``.

    Raises:
        FileNotFoundError: A rebuild is required but ``build_cache_if_missing``
            is false.
        ValueError: A rebuild is required but ``build_cache_fn`` is ``None``.
    """
    if not force_recache:
        try:
            load_cache(cache_path, **(load_cache_kwargs or {}))
        except FileNotFoundError:
            # Missing cache files are the expected trigger for a rebuild.
            logger.debug('[cache] No usable cache at %s', cache_path)
        else:
            return

    # From here on a rebuild is always required (forced or cache missing).
    if not build_cache_if_missing:
        raise FileNotFoundError(
            f'Activation cache not found at {cache_path}, and build is disabled. '
            'Use --build-cache-if-missing or provide existing cache files.'
        )

    if build_cache_fn is None:
        raise ValueError(
            'build_cache_fn must be provided when cache rebuild is required '
            '(missing cache or force_recache=True).'
        )

    logger.debug('[cache] Building activation cache...')
    build_cache_fn()


def zero_codes_outside_activation_steps(
    codes_csr: sp.csr_matrix,
    activation_steps_csr: sp.csr_matrix,
    activation_steps_to_keep: List[int],
) -> sp.csr_matrix:
    """Zero out activations whose activation step is not in the allow-list.

    Parameters
    ----------
    codes_csr : CSR matrix of SAE activations.
    activation_steps_csr : CSR matrix of activation steps (0 = not activated).
    activation_steps_to_keep : steps whose activations are preserved.

    Returns
    -------
    A CSR matrix of the same shape and dtype in which entries outside the
    listed steps are removed. If the list of steps is empty, the input
    matrix itself is returned unchanged.

    Raises
    ------
    ValueError
        If the shapes differ, if any requested step is not a positive integer,
        or if the stored entries of the two matrices are not at identical
        coordinates in identical order.
    """
    if not activation_steps_to_keep:
        return codes_csr

    if codes_csr.shape != activation_steps_csr.shape:
        raise ValueError(
            f'Codes/steps shape mismatch: {codes_csr.shape} vs {activation_steps_csr.shape}'
        )

    allowed_steps = sorted(set(map(int, activation_steps_to_keep)))
    if any(step < 1 for step in allowed_steps):
        raise ValueError('activation_steps_to_keep must contain only positive integers')

    codes_coo = codes_csr.tocoo(copy=False)
    steps_coo = activation_steps_csr.tocoo(copy=False)

    # The steps matrix stores a step for each nonzero code entry, so the
    # coordinates of both matrices have to line up one-to-one.
    same_pattern = (
        codes_coo.nnz == steps_coo.nnz
        and np.array_equal(codes_coo.row, steps_coo.row)
        and np.array_equal(codes_coo.col, steps_coo.col)
    )
    if not same_pattern:
        raise ValueError('Codes and steps matrices must have the same sparsity pattern. Something weird is going on.')

    selected = np.isin(
        np.asarray(steps_coo.data),
        np.asarray(allowed_steps, dtype=np.int32),
    )
    kept_values = codes_coo.data[selected]
    kept_rows = codes_coo.row[selected]
    kept_cols = codes_coo.col[selected]

    filtered = sp.coo_matrix(
        (kept_values, (kept_rows, kept_cols)),
        shape=codes_csr.shape,
        dtype=codes_csr.dtype,
    )
    return filtered.tocsr()


def ensure_activation_cache(
    dataset: str,
    acts_cache_path: str,
    dataset_root: str,
    min_distortion_level: int,
    params: dict,
    include_pristine_cache: Optional[bool] = None,
) -> None:
    """Build the distorted (and, if needed, pristine) activation cache if missing.

    The cache is probed by loading it; a ``FileNotFoundError`` from either
    the main or the pristine cache triggers a full rebuild via
    ``build_activation_cache`` with settings read from ``params``.

    Args:
        dataset: Dataset name. ``'kadid10k'`` and its alias ``'kadid'`` are
            both treated as KADID-10k, matching ``build_activation_cache``.
        acts_cache_path: Base path of the cache files.
        dataset_root: Dataset location, forwarded to the builder.
        min_distortion_level: Minimum distortion level; used for the probe
            only for KADID-10k and always forwarded to the builder.
        params: Configuration mapping (``SAE_CHECKPOINT_PATH``, ``LAYER_NUM``,
            ``IQA_METRIC``, ``SWIN_NUM``, ``DEVICE``, ``BATCH_SIZE``,
            ``NUM_WORKERS``, ``CROP_SIZE``, ``SCALING_FACTOR``, and optionally
            ``KADID_MAX_DISTORTION_LEVEL`` / ``SRGROUND_INCLUDE_SR_ARTIFACT``).
        include_pristine_cache: Whether a pristine cache is required. ``None``
            means "required for the KADID variants only".
    """
    is_kadid10k = dataset in {'kadid10k', 'kadid'}

    cache_filter_min = int(min_distortion_level) if is_kadid10k else None
    cache_filter_max = params.get('KADID_MAX_DISTORTION_LEVEL') if is_kadid10k else None

    if include_pristine_cache is None:
        needs_pristine_cache = is_kadid10k or dataset == 'local_kadid'
    else:
        needs_pristine_cache = bool(include_pristine_cache)

    try:
        load_cache(
            acts_cache_path,
            return_activation_steps=True,
            min_distortion_level=cache_filter_min,
            max_distortion_level=cache_filter_max,
        )
        if needs_pristine_cache:
            load_pristine_cache(acts_cache_path, return_activation_steps=True)
        return
    except FileNotFoundError:
        # Missing files are the expected signal that the cache must be built.
        pass

    # A missing SCALING_FACTOR must not override the builder's default of 1.0
    # with None, which would break the activation scaling.
    scaling_factor = params.get('SCALING_FACTOR')
    if scaling_factor is None:
        scaling_factor = 1.0

    logger.info('[run] Activation cache not found for %s. Building cache...', acts_cache_path)
    build_activation_cache(
        dataset=dataset,
        cache_path=acts_cache_path,
        checkpoint_path=params.get('SAE_CHECKPOINT_PATH'),
        dataset_root=dataset_root,
        layer_num=params.get('LAYER_NUM'),
        iqa_metric=params.get('IQA_METRIC'),
        swin_num=params.get('SWIN_NUM'),
        device=params.get('DEVICE'),
        batch_size=params.get('BATCH_SIZE'),
        num_workers=params.get('NUM_WORKERS'),
        crop_size=params.get('CROP_SIZE'),
        scaling_factor=scaling_factor,
        min_distortion_level=min_distortion_level,
        max_batches=None,
        max_memory_gb=30.0,
        add_patch_mask_stats=True,
        include_pristine=needs_pristine_cache,
        srground_include_sr_artifact=bool(params.get('SRGROUND_INCLUDE_SR_ARTIFACT', False)),
    )
    logger.info('[run] Activation cache build completed.')