Spaces:
Running
Running
"""
Utilities for caching and loading SAE activations.

A cache built from a base path ``<dir>/<name>.feather`` consists of a metadata
feather file, a sparse ``.npz`` matrix of SAE codes and a sparse ``.npz``
matrix of activation steps (see ``_cache_paths``), plus optional ``_pristine``
counterparts for reference images.
"""
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
# NOTE(review): presumably a dict filled by forward hooks on the IQA model and
# keyed by layer name (it is indexed with ``layer_name`` right after each
# forward pass below) -- confirm against analysis.models.
from analysis.models import _iqa_activations
from log_config import get_logger

logger = get_logger(__name__)
| def _df_mb(df: pd.DataFrame) -> float: | |
| """Объём памяти DataFrame в МБ.""" | |
| return df.memory_usage(deep=True).sum() / 1024 ** 2 | |
| def _sparse_mb(mat: sp.csr_matrix) -> float: | |
| """Объём памяти CSR-матрицы (данные + индексы) в МБ.""" | |
| return (mat.data.nbytes + mat.indices.nbytes + mat.indptr.nbytes) / 1024 ** 2 | |
| def _cache_paths(base_path: str) -> Tuple[str, str, str]: | |
| """ | |
| Возвращает пути к трём файлам кэша из базового пути. | |
| Пример: | |
| 'cache/kadid_acts.feather' | |
| → ('cache/kadid_acts_meta.feather', 'cache/kadid_acts_codes.npz', | |
| 'cache/kadid_acts_steps.npz') | |
| """ | |
| p = Path(base_path) | |
| stem = p.stem.removesuffix('.feather') | |
| return ( | |
| str(p.parent / f'{stem}_meta.feather'), | |
| str(p.parent / f'{stem}_codes.npz'), | |
| str(p.parent / f'{stem}_steps.npz'), | |
| ) | |
def _pristine_cache_paths(base_path: str) -> Tuple[str, str, str]:
    """Return the pristine-cache counterparts of the paths from ``_cache_paths``.

    ``_pristine`` is inserted right before each file's own extension, e.g.
    ``cache/kadid_acts_meta.feather`` -> ``cache/kadid_acts_meta_pristine.feather``.
    Only the file name is touched: directory components, or dots earlier in the
    name, are never rewritten. A plain ``str.replace`` would rewrite every
    occurrence, e.g. inside a ``runs.npz/`` directory.
    """

    def _tag_pristine(path_str: str) -> str:
        p = Path(path_str)
        return str(p.with_name(f'{p.stem}_pristine{p.suffix}'))

    meta_path, codes_path, steps_path = _cache_paths(base_path)
    return _tag_pristine(meta_path), _tag_pristine(codes_path), _tag_pristine(steps_path)
def load_parquet_cache(cache_path: Optional[str], *, label: str = 'cache') -> Optional[pd.DataFrame]:
    """Read a parquet cache table, or return ``None`` when there is nothing to read.

    Returns ``None`` if *cache_path* is ``None`` or the file does not exist;
    *label* is only used in the debug log message.
    """
    cache = None if cache_path is None else Path(cache_path)
    if cache is None or not cache.exists():
        return None
    logger.debug('[cache] Loading %s from %s', label, cache)
    return pd.read_parquet(cache)
def save_parquet_cache(df: pd.DataFrame, cache_path: Optional[str], *, label: str = 'cache') -> None:
    """Write *df* to a parquet file, creating parent directories as needed.

    Does nothing when *cache_path* is ``None``; *label* only appears in the
    debug log message.
    """
    if cache_path is not None:
        target = Path(cache_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(target)
        logger.debug('[cache] Saved %s to %s', label, target)
| _PATCH_LABEL_META_KEYS = frozenset({'dist_type', 'dist_group'}) | |
| def _patch_labels_to_dist_meta( | |
| patch_labels: np.ndarray, | |
| *, | |
| label_to_dist_type: Dict[int, str], | |
| label_to_dist_group: Dict[str, str], | |
| ) -> Tuple[List[str], List[str]]: | |
| flat_labels = patch_labels.reshape(-1) | |
| dist_types = [label_to_dist_type.get(int(label_id), 'background') for label_id in flat_labels] | |
| dist_groups = [label_to_dist_group.get(dist_type, dist_type) for dist_type in dist_types] | |
| return dist_types, dist_groups | |
def _patch_mask_stats(
    masks: torch.Tensor,
    grid_h: int,
    grid_w: int,
    batch_size: int,
    device,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Reduce integer-labelled pixel masks to one label per activation patch.

    Each label ``k >= 1`` is average-pooled onto the ``(grid_h, grid_w)`` patch
    grid; a patch takes the label with the largest pooled coverage, or ``0``
    when no labelled pixel falls inside it.

    Returns ``(labels, coverage, is_distorted)``, each shaped
    ``(batch_size, grid_h * grid_w)`` with dtypes int16 / float32 / int8.
    All three are zero when the batch contains no positive label.
    """
    n_patches = grid_h * grid_w
    mask_labels = masks.to(device=device, dtype=torch.float32).to(dtype=torch.int64)
    max_label = int(mask_labels.max().item())
    if max_label <= 0:
        labels_np = np.zeros((batch_size, n_patches), dtype=np.int16)
        is_dist = np.zeros((batch_size, n_patches), dtype=np.int8)
        return labels_np, is_dist.astype(np.float32), is_dist
    # NOTE(review): assumes masks are (B, 1, H, W), so each pooled map is
    # (B, 1, grid_h, grid_w) and concatenating on dim=1 gives
    # (B, n_labels, grid_h, grid_w) -- confirm against the collate functions.
    coverages = torch.cat(
        [
            F.adaptive_avg_pool2d(
                (mask_labels == label_id).to(dtype=torch.float32),
                (grid_h, grid_w),
            )
            for label_id in range(1, max_label + 1)
        ],
        dim=1,
    )
    max_cov, max_idx = coverages.max(dim=1)
    # Channel c holds label c + 1; patches with zero coverage stay background (0).
    patch_labels = torch.where(
        max_cov > 0,
        max_idx.to(dtype=torch.int64) + 1,
        torch.zeros_like(max_idx, dtype=torch.int64),
    )
    labels_np = patch_labels.reshape(batch_size, n_patches).cpu().numpy().astype(np.int16)
    is_dist = (labels_np > 0).astype(np.int8)
    coverage_np = max_cov.reshape(batch_size, n_patches).cpu().numpy()
    return labels_np, coverage_np, is_dist


def _process_dataloader(
    dataloader,
    iqa,
    sae,
    layer_name,
    scaling_factor,
    device,
    patches_per_image,
    patch_grid_shape,
    meta_keys,
    max_batches,
    max_memory_gb,
    add_patch_mask_stats,
    show_progress_bars: bool = True,
    *,
    label_to_dist_type: Optional[Dict[int, str]] = None,
    label_to_dist_group: Optional[Dict[str, str]] = None,
):
    """Encode every batch of *dataloader* with the SAE and collect per-patch rows.

    For each batch, *iqa* is run on ``batch['images']``. The activations then
    read from ``_iqa_activations[layer_name]`` are multiplied by
    *scaling_factor* and encoded with ``sae.get_acts``. ``_iqa_activations``
    is presumably filled by a forward hook -- confirm. One output row is
    stored per (image, patch).

    Parameters
    ----------
    patches_per_image : patches per image; inferred from the first batch when ``None``.
    patch_grid_shape : ``(grid_h, grid_w)`` activation grid. Patch-mask stats
        are only added when ``grid_h * grid_w`` equals the patches per image.
    meta_keys : batch keys copied into the metadata, repeated once per patch.
    max_batches : stop after this many batches (``None`` = all).
    max_memory_gb : currently unused; accepted for API compatibility.
    add_patch_mask_stats : add ``patch_mask_label`` / ``patch_mask_coverage`` /
        ``patch_is_distorted`` columns from ``batch['masks']`` when present.
    label_to_dist_type, label_to_dist_group : when both are given, batches with
        patch labels get per-patch ``dist_type``/``dist_group`` derived from the
        mask labels. Batches without patch labels keep the batch's image-level
        values.

    Returns
    -------
    ``(meta_df, codes_csr, steps_csr)`` with one row per (image, patch).
    ``steps_csr`` is all-zero when the SAE returns no activation steps.

    Raises
    ------
    ValueError
        The SAE output row count differs from images x patches, or no batch
        was processed.
    """
    all_meta: List[pd.DataFrame] = []
    all_sparse_codes: List[sp.csr_matrix] = []
    all_sparse_steps: List[sp.csr_matrix] = []
    n_patches_known = patches_per_image
    image_offset = 0
    # Loop-invariant: per-patch dist_type/dist_group require both mappings.
    use_patch_label_meta = label_to_dist_type is not None and label_to_dist_group is not None
    grid_mismatch_warned = False

    batches = tqdm(dataloader, desc='Caching activations', disable=not show_progress_bars)
    for batch_i, batch in enumerate(batches):
        if max_batches is not None and batch_i >= max_batches:
            break
        imgs = batch['images'].to(device)
        B = imgs.shape[0]
        with torch.no_grad():
            iqa(imgs)
            acts = _iqa_activations[layer_name].to(device)
            acts = acts * scaling_factor
            enc_out = sae.get_acts(acts)
        if isinstance(enc_out, tuple):
            codes, activation_steps = enc_out
        else:
            codes, activation_steps = enc_out, None
        codes_np = codes.cpu().float().numpy()
        if activation_steps is None:
            steps_np = np.zeros_like(codes_np, dtype=np.int32)
        else:
            steps_np = activation_steps.cpu().numpy().astype(np.int32)

        if n_patches_known is None:
            n_patches_known = codes_np.shape[0] // B
            logger.info('Detected %s patches per image', n_patches_known)
        P = n_patches_known
        # Metadata is built with B * P rows; any other SAE row count would
        # silently misalign metadata and codes in the saved cache.
        if codes_np.shape[0] != B * P:
            raise ValueError(
                f'Batch {batch_i}: SAE produced {codes_np.shape[0]} rows for {B} images, '
                f'expected {B * P} ({P} patches per image)'
            )

        patch_stats = None
        if add_patch_mask_stats and 'masks' in batch and patch_grid_shape is not None:
            grid_h, grid_w = patch_grid_shape
            if grid_h * grid_w == P:
                patch_stats = _patch_mask_stats(batch['masks'], grid_h, grid_w, B, device)
            elif not grid_mismatch_warned:
                logger.warning(
                    'Patch grid %sx%s does not match %s patches per image; skipping patch-mask stats',
                    grid_h,
                    grid_w,
                    P,
                )
                grid_mismatch_warned = True
        # Batch-level dist_type/dist_group are replaced only when per-patch
        # labels actually exist for this batch; otherwise they would vanish.
        patch_level_dist = use_patch_label_meta and patch_stats is not None

        meta: Dict[str, list] = {}
        for k in meta_keys:
            if k not in batch or (patch_level_dist and k in _PATCH_LABEL_META_KEYS):
                continue
            meta[k] = [v for v in batch[k] for _ in range(P)]
        meta['patch_idx'] = list(range(P)) * B
        meta['image_idx'] = [image_offset + i for i in range(B) for _ in range(P)]
        if patch_stats is not None:
            patch_labels_np, patch_coverage_np, patch_is_dist = patch_stats
            meta['patch_mask_label'] = patch_labels_np.reshape(-1).tolist()
            meta['patch_mask_coverage'] = patch_coverage_np.reshape(-1).tolist()
            meta['patch_is_distorted'] = patch_is_dist.reshape(-1).tolist()
            if use_patch_label_meta:
                dist_types, dist_groups = _patch_labels_to_dist_meta(
                    patch_labels_np,
                    label_to_dist_type=label_to_dist_type,
                    label_to_dist_group=label_to_dist_group,
                )
                if 'dist_type' in meta_keys:
                    meta['dist_type'] = dist_types
                if 'dist_group' in meta_keys:
                    meta['dist_group'] = dist_groups

        all_meta.append(pd.DataFrame(meta))
        all_sparse_codes.append(sp.csr_matrix(codes_np))
        all_sparse_steps.append(sp.csr_matrix(steps_np))
        image_offset += B

    if not all_meta:
        raise ValueError('No batches were processed: the dataloader is empty or max_batches <= 0')
    meta_df = pd.concat(all_meta, ignore_index=True)
    codes_csr = sp.vstack(all_sparse_codes, format='csr')
    steps_csr = sp.vstack(all_sparse_steps, format='csr')
    return meta_df, codes_csr, steps_csr
def collect_and_cache(
    dataloader: DataLoader,
    iqa: torch.nn.Module,
    sae,
    layer_name: str,
    output_path: str,
    scaling_factor: float = 1.0,
    patches_per_image: Optional[int] = None,
    patch_grid_shape: Optional[Tuple[int, int]] = None,
    meta_keys: Sequence[str] = (
        'dist_type',
        'dist_group',
        'dist_level',
        'mos',
        'distorted_img_path',
        'original_img_path',
        'sample_id',
    ),
    device: str = 'cuda',
    max_batches: Optional[int] = None,
    max_memory_gb: Optional[float] = None,
    add_patch_mask_stats: bool = True,
    pristine_dataloader: Optional[DataLoader] = None,
    show_progress_bars: bool = True,
    label_to_dist_type: Optional[Dict[int, str]] = None,
    label_to_dist_group: Optional[Dict[str, str]] = None,
) -> Tuple[pd.DataFrame, sp.csr_matrix]:
    """Encode a dataset with the IQA model + SAE and write the activation cache.

    The distorted set from *dataloader* is written to the three files derived
    from *output_path* by ``_cache_paths``: metadata feather, sparse codes
    ``.npz`` and sparse activation-steps ``.npz``. If *pristine_dataloader* is
    given, it is processed too, without patch-mask stats or label mappings.
    Its files go to the ``_pristine`` paths from ``_pristine_cache_paths``, the
    same helper ``load_pristine_cache`` reads from. The parent directory is
    created if needed.

    Returns
    -------
    ``(meta_df, codes_csr)`` for the distorted set. The steps matrix and the
    pristine cache are only written to disk.
    """
    # Arguments shared by the distorted and the pristine pass.
    shared_kwargs = dict(
        iqa=iqa,
        sae=sae,
        layer_name=layer_name,
        scaling_factor=scaling_factor,
        device=device,
        patches_per_image=patches_per_image,
        patch_grid_shape=patch_grid_shape,
        meta_keys=meta_keys,
        max_batches=max_batches,
        max_memory_gb=max_memory_gb,
        show_progress_bars=show_progress_bars,
    )
    meta_df, codes_csr, steps_csr = _process_dataloader(
        dataloader=dataloader,
        add_patch_mask_stats=add_patch_mask_stats,
        label_to_dist_type=label_to_dist_type,
        label_to_dist_group=label_to_dist_group,
        **shared_kwargs,
    )
    sparse_mb = _sparse_mb(codes_csr)
    steps_mb = _sparse_mb(steps_csr)
    logger.info('Activations: shape=%s, %.1f МБ (sparse)', codes_csr.shape, sparse_mb)
    logger.info('Activation steps: shape=%s, %.1f МБ (sparse)', steps_csr.shape, steps_mb)

    meta_path, codes_path, steps_path = _cache_paths(output_path)
    # Callers other than build_activation_cache may pass a not-yet-existing directory.
    Path(meta_path).parent.mkdir(parents=True, exist_ok=True)
    meta_df.to_feather(meta_path)
    sp.save_npz(codes_path, codes_csr)
    sp.save_npz(steps_path, steps_csr)
    logger.info('Saved metadata (%s rows) -> %s', len(meta_df), meta_path)
    logger.info('Saved activations -> %s', codes_path)
    logger.info('Saved activation steps -> %s', steps_path)
    logger.info(' Metadata: %.1f МБ', _df_mb(meta_df))
    logger.info(' Activations: %.1f МБ', sparse_mb)
    logger.info(' Steps: %.1f МБ', steps_mb)

    if pristine_dataloader is not None:
        logger.info('Processing pristine dataset...')
        pristine_meta, pristine_codes, pristine_steps = _process_dataloader(
            dataloader=pristine_dataloader,
            add_patch_mask_stats=False,
            **shared_kwargs,
        )
        pristine_sparse_mb = _sparse_mb(pristine_codes)
        pristine_steps_mb = _sparse_mb(pristine_steps)
        # Same path helper as load_pristine_cache, so writer and reader agree.
        pristine_meta_path, pristine_codes_path, pristine_steps_path = _pristine_cache_paths(output_path)
        pristine_meta.to_feather(pristine_meta_path)
        sp.save_npz(pristine_codes_path, pristine_codes)
        sp.save_npz(pristine_steps_path, pristine_steps)
        logger.info(
            'Pristine activations: shape=%s, %.1f МБ',
            pristine_codes.shape,
            pristine_sparse_mb,
        )
        logger.info(
            'Pristine steps: shape=%s, %.1f МБ',
            pristine_steps.shape,
            pristine_steps_mb,
        )
        logger.info('Saved pristine metadata (%s rows) -> %s', len(pristine_meta), pristine_meta_path)
        logger.info('Saved pristine activations -> %s', pristine_codes_path)
        logger.info('Saved pristine activation steps -> %s', pristine_steps_path)
    return meta_df, codes_csr
def build_activation_cache(
    *,
    dataset: str,
    cache_path: str,
    checkpoint_path: str,
    dataset_root: str,
    layer_num: int,
    iqa_metric: str,
    swin_num: int,
    device: str,
    batch_size: int,
    num_workers: int,
    crop_size: int,
    scaling_factor: float = 1.0,
    min_distortion_level: Optional[int] = None,
    max_batches: Optional[int] = None,
    max_memory_gb: Optional[float] = None,
    add_patch_mask_stats: bool = True,
    include_pristine: bool = True,
    show_progress_bars: bool = True,
    srground_include_sr_artifact: bool = False,
) -> Dict[str, Any]:
    """Build an activation cache end to end for one supported dataset.

    Supported *dataset* values are ``'local_kadid'``, ``'kadid10k'`` / ``'kadid'``,
    ``'QGround'`` and ``'SRGround'``. A pristine (reference-image) loader is
    created only for the two KADID variants, and only when *include_pristine*
    is true. The grounding datasets instead provide label -> distortion
    mappings for per-patch metadata.

    The function builds the dataset(s) and DataLoaders, then loads the IQA
    model and the SAE (float16 for ``iqa_metric == 'qalign'``, float32
    otherwise). It runs one dummy forward pass so the activation grid of the
    chosen layer becomes available, then delegates to ``collect_and_cache``.

    Returns
    -------
    dict with ``layer_name``, ``patch_grid_shape``, ``patches_per_image`` and
    ``sae_config``.

    Raises
    ------
    ValueError
        ``min_distortion_level`` outside ``[1, 5]`` or an unsupported *dataset*.
    RuntimeError
        The activation grid for the selected layer could not be inferred.
    """
    from .datasets import (
        Kadid10kDataset,
        KadidPristineDataset,
        LocalKadidPresavedDataset,
        LocalKadidPristineDataset,
        QGroundDataset,
        SRGroundSmallDataset,
        available_distortions_qground,
        available_distortions_srground,
        distortion_types_mapping_qground,
        distortion_types_mapping_srground,
        kadid_collate_fn,
        kadid_pristine_collate_fn,
        local_kadid_collate_fn,
        local_kadid_pristine_collate_fn,
        qground_collate_fn,
        srground_collate_fn,
    )
    from .models import _iqa_activation_grids, load_iqa_model, load_sae, read_sae_config

    if min_distortion_level is not None and not 1 <= min_distortion_level <= 5:
        raise ValueError('min_distortion_level must be in [1, 5]')

    # Defaults: no per-patch label mapping and no pristine counterpart.
    label_to_dist_type = None
    label_to_dist_group = None
    pristine_data = None
    pristine_collate = None

    if dataset == 'local_kadid':
        data = LocalKadidPresavedDataset(root=dataset_root, crop_size=crop_size)
        collate_fn = local_kadid_collate_fn
        meta_keys = [
            'dist_type',
            'dist_group',
            'dist_level',
            'mos',
            'local_dist_type',
            'local_dist_level',
            'mask_shape',
            'mask_coverage',
            'sample_id',
            'distorted_img_path',
            'original_img_path',
        ]
        if include_pristine:
            pristine_data = LocalKadidPristineDataset(root=dataset_root, crop_size=crop_size)
        pristine_collate = local_kadid_pristine_collate_fn
    elif dataset in {'kadid10k', 'kadid'}:
        data = Kadid10kDataset(
            root=dataset_root,
            crop_size=crop_size,
            min_distortion_level=min_distortion_level or 1,
        )
        collate_fn = kadid_collate_fn
        meta_keys = [
            'dist_type',
            'dist_group',
            'dist_level',
            'mos',
            'distorted_img_path',
            'original_img_path',
        ]
        if include_pristine:
            pristine_data = KadidPristineDataset(root=dataset_root, crop_size=crop_size)
        pristine_collate = kadid_pristine_collate_fn
    elif dataset == 'QGround':
        data = QGroundDataset(root=dataset_root, split='test', crop_size=crop_size)
        collate_fn = qground_collate_fn
        meta_keys = [
            'dist_type',
            'dist_group',
            'dist_level',
            'mos',
            'mask_coverage',
            'qground_ann_id',
            'sample_id',
            'distorted_img_path',
            'original_img_path',
            'image_path',
            'mask_path',
            'split',
        ]
        label_to_dist_type = distortion_types_mapping_qground
        label_to_dist_group = available_distortions_qground
    elif dataset == 'SRGround':
        data = SRGroundSmallDataset(
            root=dataset_root,
            json_path='srground_train.json',
            crop_size=crop_size,
            include_sr_artifact=srground_include_sr_artifact,
        )
        collate_fn = srground_collate_fn
        meta_keys = [
            'dist_type',
            'dist_group',
            'dist_level',
            'mos',
            'mask_coverage',
            'sample_id',
            'distorted_img_path',
            'image_path',
            'real_distortions_ann_path',
            'sr_artifacts_ann_path',
        ]
        label_to_dist_type = distortion_types_mapping_srground
        label_to_dist_group = available_distortions_srground
    else:
        raise ValueError(f'Unsupported dataset: {dataset}')

    def make_loader(ds, collate):
        # Order must stay fixed (shuffle=False) so image_idx matches dataset order.
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate,
        )

    loader = make_loader(data, collate_fn)
    pristine_loader = make_loader(pristine_data, pristine_collate) if pristine_data is not None else None

    iqa_model, layer_name = load_iqa_model(
        layer_num=layer_num,
        device=device,
        iqa_metric=iqa_metric,
        swin_num=swin_num,
    )
    sae_dtype = torch.float16 if iqa_metric == 'qalign' else torch.float32
    sae_cfg = read_sae_config(checkpoint_path)
    sae_model = load_sae(checkpoint_path, device=device, dtype=sae_dtype, sae_config=sae_cfg)

    # A throwaway forward pass makes the layer's activation grid available.
    with torch.no_grad():
        iqa_model(torch.rand(1, 3, crop_size, crop_size, device=device).clamp(0, 1))
    grid_known = layer_name in _iqa_activations and layer_name in _iqa_activation_grids
    if not grid_known:
        raise RuntimeError(f'Cannot infer activation grid for layer {layer_name}')
    patch_grid_shape = _iqa_activation_grids[layer_name]
    grid_h, grid_w = patch_grid_shape[0], patch_grid_shape[1]
    patches_per_image = grid_h * grid_w

    Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
    collect_and_cache(
        dataloader=loader,
        iqa=iqa_model,
        sae=sae_model,
        layer_name=layer_name,
        output_path=cache_path,
        scaling_factor=scaling_factor,
        patches_per_image=patches_per_image,
        patch_grid_shape=patch_grid_shape,
        meta_keys=meta_keys,
        device=device,
        max_batches=max_batches,
        max_memory_gb=max_memory_gb,
        add_patch_mask_stats=add_patch_mask_stats,
        pristine_dataloader=pristine_loader,
        show_progress_bars=show_progress_bars,
        label_to_dist_type=label_to_dist_type,
        label_to_dist_group=label_to_dist_group,
    )
    return {
        'layer_name': layer_name,
        'patch_grid_shape': patch_grid_shape,
        'patches_per_image': patches_per_image,
        'sae_config': sae_cfg,
    }
def load_cache(
    path: str,
    return_activation_steps: bool = False,
    min_distortion_level: Optional[int] = None,
    max_distortion_level: Optional[int] = None,
) -> Union[
    Tuple[pd.DataFrame, sp.csr_matrix],
    Tuple[pd.DataFrame, sp.csr_matrix, sp.csr_matrix],
]:
    """Load an SAE activation cache from its separate metadata/codes/steps files.

    Parameters
    ----------
    path : base cache path (``str`` or ``pathlib.Path``); the three file names
        are derived with ``_cache_paths``.
    return_activation_steps : also return the CSR matrix of activation order,
        where ``0`` means "not activated" and ``k > 0`` means the activation
        appeared at pursuit step ``k``. If the steps file is missing, an
        all-zero int32 matrix is returned instead, with a warning.
    min_distortion_level, max_distortion_level : optional inclusive
        ``dist_level`` range. A missing bound defaults to 1 (min) or 5 (max).
        Rows outside the range are dropped from metadata, codes and steps
        alike.

    Returns
    -------
    ``(meta, codes)``, or ``(meta, codes, steps)`` when
    *return_activation_steps* is true.

    Raises
    ------
    ValueError
        Filtering was requested without a ``dist_level`` column, the level
        range is inverted, or the steps and codes shapes differ.
    """
    meta_path, codes_path, steps_path = _cache_paths(path)
    meta = pd.read_feather(meta_path)
    codes = sp.load_npz(codes_path)
    logger.debug(
        'Loaded from %s: %s rows × %s cols',
        meta_path,
        meta.shape[0],
        meta.shape[1],
    )
    logger.debug('Loaded from %s: shape=%s, dtype=%s', codes_path, codes.shape, codes.dtype)
    logger.debug(' Metadata: %.1f МБ', _df_mb(meta))
    logger.debug(' Activations: %.1f МБ (sparse)', _sparse_mb(codes))

    keep_idx: Optional[np.ndarray] = None
    min_level: Optional[int] = None
    max_level: Optional[int] = None
    if min_distortion_level is not None or max_distortion_level is not None:
        if 'dist_level' not in meta.columns:
            raise ValueError('Cannot filter by distortion level: metadata has no "dist_level" column')
        min_level = 1 if min_distortion_level is None else int(min_distortion_level)
        max_level = 5 if max_distortion_level is None else int(max_distortion_level)
        if min_level > max_level:
            raise ValueError(
                f'Invalid distortion-level range: min_distortion_level={min_level} > max_distortion_level={max_level}'
            )
        levels = meta['dist_level']
        # str(): `'Ground' in Path(...)` would raise TypeError for Path arguments.
        if 'Ground' not in str(path):
            keep_mask = (levels >= min_level) & (levels <= max_level)
        else:
            # Temporary workaround -- fix later. Caches whose path contains
            # 'Ground' (presumably QGround/SRGround) ignore the level range.
            # This still drops rows whose dist_level is NaN, because a NaN
            # comparison is False.
            keep_mask = levels >= -1000
        keep_idx = np.flatnonzero(keep_mask.to_numpy())

    steps: Optional[sp.csr_matrix] = None
    if return_activation_steps:
        if Path(steps_path).exists():
            steps = sp.load_npz(steps_path)
            # Compared before filtering: both matrices must describe the same rows.
            if steps.shape != codes.shape:
                raise ValueError(
                    f'Steps cache shape mismatch: expected {codes.shape}, got {steps.shape}'
                )
            logger.info('Loaded from %s: shape=%s, dtype=%s', steps_path, steps.shape, steps.dtype)
        else:
            logger.warning('No steps cache found. Using all-zero activation steps.')
            steps = sp.csr_matrix(codes.shape, dtype=np.int32)

    # A single filter application keeps meta, codes and steps row-aligned.
    if keep_idx is not None:
        meta = meta.iloc[keep_idx].reset_index(drop=True)
        codes = codes[keep_idx]
        if steps is not None:
            steps = steps[keep_idx]
        logger.info(
            'Applied dist_level filter [%s, %s] -> %s rows kept',
            min_level,
            max_level,
            meta.shape[0],
        )

    if not return_activation_steps:
        return meta, codes
    logger.info(' Steps: %.1f МБ (sparse)', _sparse_mb(steps))
    return meta, codes, steps
def load_pristine_cache(
    path: str,
    return_activation_steps: bool = False,
) -> Union[
    Tuple[pd.DataFrame, sp.csr_matrix],
    Tuple[pd.DataFrame, sp.csr_matrix, sp.csr_matrix],
]:
    """Load the pristine (original-image) activation cache written by ``collect_and_cache``.

    File names come from ``_pristine_cache_paths(path)``. When
    *return_activation_steps* is true the steps matrix is returned as well. A
    missing steps file yields an all-zero int32 matrix and a warning.

    Raises
    ------
    ValueError
        The steps matrix shape differs from the codes matrix shape.
    """
    meta_path, codes_path, steps_path = _pristine_cache_paths(path)
    meta = pd.read_feather(meta_path)
    codes = sp.load_npz(codes_path)
    n_rows, n_cols = meta.shape
    logger.info('Loaded pristine from %s: %s rows × %s cols', meta_path, n_rows, n_cols)
    logger.info('Loaded pristine from %s: shape=%s, dtype=%s', codes_path, codes.shape, codes.dtype)
    logger.info(' Metadata: %.1f МБ', _df_mb(meta))
    logger.info(' Activations: %.1f МБ (sparse)', _sparse_mb(codes))
    if not return_activation_steps:
        return meta, codes

    if not Path(steps_path).exists():
        logger.warning('No pristine steps cache found. Using all-zero activation steps.')
        steps = sp.csr_matrix(codes.shape, dtype=np.int32)
    else:
        steps = sp.load_npz(steps_path)
        if steps.shape != codes.shape:
            raise ValueError(
                f'Pristine steps cache shape mismatch: expected {codes.shape}, got {steps.shape}'
            )
        logger.info('Loaded pristine from %s: shape=%s, dtype=%s', steps_path, steps.shape, steps.dtype)
    logger.info(' Steps: %.1f МБ (sparse)', _sparse_mb(steps))
    return meta, codes, steps
def ensure_cache_ready(
    cache_path: str,
    *,
    force_recache: bool = False,
    build_cache_if_missing: bool = True,
    load_cache_kwargs: Optional[Dict[str, Any]] = None,
    build_cache_fn: Optional[Callable[[], None]] = None,
) -> None:
    """Make sure the activation cache at *cache_path* is usable, building it if needed.

    Behaviour:

    - unless *force_recache* is set, try ``load_cache(cache_path, **load_cache_kwargs)``
      and return if it succeeds;
    - if the cache is missing (``FileNotFoundError``) or *force_recache* is set,
      call *build_cache_fn*.

    Other errors raised by ``load_cache`` propagate unchanged.

    Raises
    ------
    FileNotFoundError
        A rebuild is needed but *build_cache_if_missing* is false. When the
        cache was missing, the original error is chained as ``__cause__``.
    ValueError
        A rebuild is needed but *build_cache_fn* is ``None``.
    """
    missing_error: Optional[FileNotFoundError] = None
    if not force_recache:
        try:
            load_cache(cache_path, **(load_cache_kwargs or {}))
        except FileNotFoundError as err:
            missing_error = err
        else:
            return  # Cache exists and loads; nothing to do.

    if not build_cache_if_missing:
        if missing_error is None:
            # Forced rebuild: the cache may well exist, so don't claim it is missing.
            raise FileNotFoundError(
                f'Rebuild of activation cache at {cache_path} was requested '
                '(force_recache=True), but build is disabled.'
            )
        raise FileNotFoundError(
            f'Activation cache not found at {cache_path}, and build is disabled. '
            'Use --build-cache-if-missing or provide existing cache files.'
        ) from missing_error
    if build_cache_fn is None:
        raise ValueError(
            'build_cache_fn must be provided when cache rebuild is required '
            '(missing cache or force_recache=True).'
        )
    logger.debug('[cache] Building activation cache...')
    build_cache_fn()
def zero_codes_outside_activation_steps(
    codes_csr: sp.csr_matrix,
    activation_steps_csr: sp.csr_matrix,
    activation_steps_to_keep: List[int],
) -> sp.csr_matrix:
    """Zero out activations whose pursuit step is not in an allow-list.

    Parameters
    ----------
    codes_csr : CSR matrix of SAE activations.
    activation_steps_csr : CSR matrix of activation steps (0 = not activated).
        It must have the same shape and sparsity pattern as *codes_csr*.
    activation_steps_to_keep : positive step numbers whose activations are kept.

    Returns
    -------
    A new CSR matrix with the shape and dtype of *codes_csr*, in which every
    activation at a step outside the allow-list is dropped. If the allow-list
    is empty, *codes_csr* itself is returned unchanged.

    Raises
    ------
    ValueError
        Mismatched shapes, a non-positive step in the allow-list, or differing
        sparsity patterns.
    """
    if not activation_steps_to_keep:
        return codes_csr
    if codes_csr.shape != activation_steps_csr.shape:
        raise ValueError(
            f'Codes/steps shape mismatch: {codes_csr.shape} vs {activation_steps_csr.shape}'
        )
    wanted = sorted({int(s) for s in activation_steps_to_keep})
    if wanted and wanted[0] <= 0:
        raise ValueError('activation_steps_to_keep must contain only positive integers')

    code_entries = codes_csr.tocoo(copy=False)
    step_entries = activation_steps_csr.tocoo(copy=False)
    # Each step value must sit at the same coordinate as its code value.
    same_pattern = (
        code_entries.nnz == step_entries.nnz
        and np.array_equal(code_entries.row, step_entries.row)
        and np.array_equal(code_entries.col, step_entries.col)
    )
    if not same_pattern:
        raise ValueError('Codes and steps matrices must have the same sparsity pattern. Something weird is going on.')

    selected = np.isin(np.asarray(step_entries.data), np.asarray(wanted, dtype=np.int32))
    rows = code_entries.row[selected]
    cols = code_entries.col[selected]
    values = code_entries.data[selected]
    return sp.coo_matrix((values, (rows, cols)), shape=codes_csr.shape, dtype=codes_csr.dtype).tocsr()
def ensure_activation_cache(
    dataset: str,
    acts_cache_path: str,
    dataset_root: str,
    min_distortion_level: int,
    params: dict,
    include_pristine_cache: Optional[bool] = None,
) -> None:
    """Build the distorted (+ optional pristine) activation cache if it cannot be loaded.

    The existing cache is probed with ``load_cache``, and with
    ``load_pristine_cache`` when a pristine cache is needed. If either raises
    ``FileNotFoundError``, the whole cache is rebuilt with
    ``build_activation_cache``, using settings read from *params*.

    Parameters
    ----------
    dataset : dataset name as accepted by ``build_activation_cache``
        (``'kadid'`` is treated like ``'kadid10k'``).
    min_distortion_level : lower ``dist_level`` bound. For KADID-10k it is also
        used when probing the existing cache.
    params : run configuration. Keys read: ``SAE_CHECKPOINT_PATH``,
        ``LAYER_NUM``, ``IQA_METRIC``, ``SWIN_NUM``, ``DEVICE``,
        ``BATCH_SIZE``, ``NUM_WORKERS``, ``CROP_SIZE``, ``SCALING_FACTOR``
        (default 1.0), ``KADID_MAX_DISTORTION_LEVEL`` and
        ``SRGROUND_INCLUDE_SR_ARTIFACT``.
    include_pristine_cache : force (``True``) or skip (``False``) the pristine
        cache. ``None`` means "only for KADID-10k and local KADID".
    """
    # build_activation_cache accepts 'kadid' as an alias of 'kadid10k'.
    is_kadid10k = dataset in {'kadid10k', 'kadid'}
    cache_filter_min = int(min_distortion_level) if is_kadid10k else None
    cache_filter_max = params.get('KADID_MAX_DISTORTION_LEVEL') if is_kadid10k else None
    if include_pristine_cache is None:
        needs_pristine_cache = is_kadid10k or dataset == 'local_kadid'
    else:
        needs_pristine_cache = bool(include_pristine_cache)

    try:
        load_cache(
            acts_cache_path,
            return_activation_steps=True,
            min_distortion_level=cache_filter_min,
            max_distortion_level=cache_filter_max,
        )
        if needs_pristine_cache:
            load_pristine_cache(acts_cache_path, return_activation_steps=True)
    except FileNotFoundError:
        logger.info('[run] Activation cache not found for %s. Building cache...', acts_cache_path)
    else:
        return

    # A missing key must not turn into `acts * None`; use the builder's default.
    scaling_factor = params.get('SCALING_FACTOR')
    if scaling_factor is None:
        scaling_factor = 1.0
    build_activation_cache(
        dataset=dataset,
        cache_path=acts_cache_path,
        checkpoint_path=params.get('SAE_CHECKPOINT_PATH'),
        dataset_root=dataset_root,
        layer_num=params.get('LAYER_NUM'),
        iqa_metric=params.get('IQA_METRIC'),
        swin_num=params.get('SWIN_NUM'),
        device=params.get('DEVICE'),
        batch_size=params.get('BATCH_SIZE'),
        num_workers=params.get('NUM_WORKERS'),
        crop_size=params.get('CROP_SIZE'),
        scaling_factor=scaling_factor,
        min_distortion_level=min_distortion_level,
        max_batches=None,
        max_memory_gb=30.0,
        add_patch_mask_stats=True,
        include_pristine=needs_pristine_cache,
        srground_include_sr_artifact=bool(params.get('SRGROUND_INCLUDE_SR_ARTIFACT', False)),
    )
    logger.info('[run] Activation cache build completed.')