import json
import math
import os
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Tuple

import pandas as pd
import torch

from analysis.qalign_utils import QAlignVisionOnlyWrapper, flatten_blc_drop_cls
from log_config import get_logger

logger = get_logger(__name__)

# Latest flattened activations per hooked layer name (written by hooks from _make_hook).
_iqa_activations: Dict[str, torch.Tensor] = {}
# Spatial grid (rows, cols) that produced each entry of _iqa_activations.
_iqa_activation_grids: Dict[str, Tuple[int, int]] = {}
# Handle of the hook registered by load_iqa_model; removed on its next call.
_hook_handle = None
# Attribute on the hooked module where ensure_iqa_activation_hook stores its handle.
_ACTIVATION_HOOK_ATTR = "_xiqa_activation_hook_handle"


def _make_hook(name: str, sequence_layout: str = 'blc'):
    """Build a forward hook that stores a layer's output as a 2D activation matrix.

    The hook writes the detached, flattened output to ``_iqa_activations[name]``
    and the grid it came from to ``_iqa_activation_grids[name]``.

    Supported layer output shapes:
      - (B, C)        -> (B, C), grid (1, 1)
      - (B, C, H, W)  -> (B*H*W, C), grid (H, W)
      - (B, L, C)     -> (B*L, C) with ``sequence_layout='blc'`` (MANIQA)
      - (L, B, C)     -> (B*L, C) with ``sequence_layout='lbc'`` (LIQE/LIQE-MIX);
                         a leading CLS token is dropped when L == 1 + k*k
      - (B, L, C)     -> ``flatten_blc_drop_cls(out)`` with
                         ``sequence_layout='blc_drop_cls'`` (Q-Align)

    For 3D outputs the grid is (s, s) when the remaining per-sample token count
    is a perfect square s*s, otherwise (tokens, 1).

    Parameters
    ----------
    name : str
        Key under which activations and grid are stored.
    sequence_layout : str
        Axis order of 3D outputs: 'blc', 'lbc' or 'blc_drop_cls'.

    Raises
    ------
    ValueError
        Raised by the hook (at forward time) for an unknown ``sequence_layout``
        on a 3D output, or for an output that is not 2D, 3D or 4D.
    """

    def hook(module, inp, out):
        # Some hooked modules return an output object rather than a tensor;
        # use its last_hidden_state then. Temporary workaround.
        if hasattr(out, 'last_hidden_state'):
            out_detached = out.last_hidden_state.detach()
        else:
            out_detached = out.detach()

        if out_detached.ndim == 2:
            _iqa_activation_grids[name] = (1, 1)
            _iqa_activations[name] = out_detached.reshape(-1, out_detached.shape[-1])
            return

        if out_detached.ndim == 4:
            tensor_perm = out_detached.permute(0, 2, 3, 1)  # (B, H, W, C)
            _iqa_activation_grids[name] = tuple(tensor_perm.shape[1:3])
            _iqa_activations[name] = tensor_perm.flatten(0, -2)  # (B*H*W, C)
            return

        if out_detached.ndim == 3:
            if sequence_layout == 'lbc':
                # LIQE CLIP resblocks usually output (L, B, C).
                tokens, batch, channels = out_detached.shape
                # Drop the CLS token when the sequence is 1 + a perfect square
                # (e.g. 50 = 1 + 49). Require at least one spatial token so a
                # length-1 sequence is not emptied.
                spatial_wo_cls = int(math.isqrt(max(tokens - 1, 0)))
                if tokens > 1 and spatial_wo_cls * spatial_wo_cls == (tokens - 1):
                    out_detached = out_detached[1:, :, :]
                    tokens = tokens - 1
                flat_acts = out_detached.permute(1, 0, 2).reshape(batch * tokens, channels)
            elif sequence_layout == 'blc':
                # MANIQA Swin layers usually output (B, L, C).
                batch, tokens, channels = out_detached.shape
                flat_acts = out_detached.reshape(batch * tokens, channels)
            elif sequence_layout == 'blc_drop_cls':
                batch, tokens, channels = out_detached.shape
                flat_acts = flatten_blc_drop_cls(out_detached)
                # Per-sample token count after flattening; guard an empty batch.
                tokens = flat_acts.shape[0] // batch if batch else 0
            else:
                raise ValueError(
                    f'Unsupported sequence_layout={sequence_layout!r} for layer {name!r}; '
                    "expected one of ('blc', 'lbc', 'blc_drop_cls')"
                )

            spatial = int(math.isqrt(tokens))
            if spatial * spatial == tokens:
                _iqa_activation_grids[name] = (spatial, spatial)
            else:
                _iqa_activation_grids[name] = (tokens, 1)
            _iqa_activations[name] = flat_acts
            return

        raise ValueError(
            f'Unsupported hooked activation ndim={out_detached.ndim} for layer {name!r}; '
            'expected 2D, 3D or 4D output'
        )

    return hook


SAE_CONFIG_FILENAME = "sae_config.json"


def read_sae_config(
    checkpoint_path: str,
    config_path: Optional[str] = None,
    **overrides: Any,
) -> Dict[str, Any]:
    """Read the SAE JSON config that sits next to a checkpoint.

    Parameters
    ----------
    checkpoint_path : str
        Checkpoint directory or weights file. When ``config_path`` is omitted,
        the config is read from ``os.path.dirname(os.path.abspath(checkpoint_path))``:
        the parent of a checkpoint directory, or the directory of a weights file.
    config_path : str, optional
        Explicit path to the JSON config.
    **overrides
        Fields that replace the corresponding values read from the file.

    Returns
    -------
    dict
        The parsed config with ``overrides`` applied.
    """
    if config_path is None:
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(checkpoint_path)),
            SAE_CONFIG_FILENAME,
        )
    with open(config_path, encoding='utf-8') as f:
        cfg: Dict[str, Any] = json.load(f)
    cfg.update(overrides)
    return cfg


def load_sae(
    checkpoint_path: str,
    config_path: Optional[str] = None,
    device: str = 'cpu',
    dtype: torch.dtype = torch.float32,
    sae_config: Optional[Dict[str, Any]] = None,
    **overrides: Any,
) -> torch.nn.Module:
    """Load an SAE from a checkpoint, choosing the architecture from its JSON config.

    Parameters
    ----------
    checkpoint_path : str
        Checkpoint directory (``output_dir/checkpoint-*/``) or weights file,
        passed to ``accelerate.load_checkpoint_in_model``.
    config_path : str, optional
        Path to ``sae_config.json``. Defaults to the file in
        ``os.path.dirname(os.path.abspath(checkpoint_path))``; for a checkpoint
        directory this is its parent (``output_dir/sae_config.json``).
        Ignored when ``sae_config`` is given.
    device : str
        Device the model is moved to.
    dtype : torch.dtype
        Dtype the model is cast to.
    sae_config : dict, optional
        Already-loaded config. It is copied before ``overrides`` are applied.
    **overrides
        Override individual config fields (e.g. ``mp_threshold=0.05``).

    Returns
    -------
    torch.nn.Module
        The SAE in eval mode on ``device`` with ``dtype``.
    """
    import accelerate

    from training.models.registry import build_sae_from_config

    if sae_config is None:
        cfg = read_sae_config(checkpoint_path, config_path=config_path, **overrides)
    else:
        cfg = dict(sae_config)
        cfg.update(overrides)

    model = build_sae_from_config(cfg)
    accelerate.load_checkpoint_in_model(model, checkpoint_path)
    model.to(device, dtype)
    model.eval()
    return model


def sae_encode_acts(
    sae: torch.nn.Module,
    acts: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Return sparse codes and optional activation steps from any SAE variant.

    Standard ``SAE.get_acts`` returns a single codes tensor;
    ``MatchingPursuitSAE`` returns ``(codes, activation_steps)``.
    """
    enc_out = sae.get_acts(acts)
    if isinstance(enc_out, tuple):
        codes, activation_steps = enc_out
        return codes, activation_steps
    return enc_out, None


def _resolve_hook_target(
    model: torch.nn.Module,
    iqa_metric: str,
    layer_num: int,
    swin_num: int,
    score_linear_idx: Optional[int],
) -> Tuple[str, torch.nn.Module, str]:
    """Map an IQA metric to (activation key, module to hook, 3D sequence layout).

    Raises
    ------
    ValueError
        If ``iqa_metric`` (case-insensitive) is not supported.
    """
    metric_key = iqa_metric.lower()
    if metric_key == 'arniqa-kadid':
        return f'arniqa_enc_{layer_num}', model.net.encoder[layer_num], 'blc'
    if metric_key == 'maniqa':
        swin = getattr(model.net, f'swintransformer{swin_num}')
        return (
            f'maniqa_swintransformer{swin_num}_layers_{layer_num}',
            swin.layers[layer_num],
            'blc',
        )
    if metric_key == 'topiq_nr':
        idx = score_linear_idx if score_linear_idx is not None else layer_num
        return f'topiq_score_linear_{idx}', model.net.score_linear[idx], 'blc'
    if metric_key in {'liqe', 'liqe_mix'}:
        return (
            f'liqe_resblock_{layer_num}',
            model.net.clip_model.visual.transformer.resblocks[layer_num],
            'lbc',
        )
    if metric_key == 'qalign':
        return 'visual_abstractor', model.visual_abstractor, 'blc_drop_cls'
    raise ValueError(
        f'Unsupported iqa_metric={iqa_metric!r}; expected one of '
        "('arniqa-kadid', 'maniqa', 'qalign', 'liqe', 'liqe_mix', 'topiq_nr')"
    )


def load_iqa_model(
    layer_num: int,
    device: str = 'cuda',
    iqa_metric: str = 'arniqa-kadid',
    swin_num: int = 2,
    score_linear_idx: Optional[int] = None,
):
    """Create a pyiqa metric and register an activation hook on the chosen layer.

    Supported metrics (case-insensitive):
      - arniqa-kadid: hook on ``iqa.net.encoder[layer_num]``
      - maniqa: hook on ``iqa.net.swintransformer{swin_num}.layers[layer_num]``
      - liqe / liqe_mix: hook on
        ``iqa.net.clip_model.visual.transformer.resblocks[layer_num]``; all later
        resblocks are replaced with ``torch.nn.Identity``
      - topiq_nr: hook on ``iqa.net.score_linear[score_linear_idx]``
        (falls back to ``layer_num`` when ``score_linear_idx`` is None)
      - qalign: the metric is wrapped in ``QAlignVisionOnlyWrapper`` and the
        hook is placed on its ``visual_abstractor``

    Any hook previously registered by this function is removed first.

    Returns
    -------
    iqa
        The loaded (and possibly wrapped) IQA model, in eval mode.
    layer_name : str
        Key under which activations are stored in ``_iqa_activations``.

    Raises
    ------
    ValueError
        If ``iqa_metric`` is not supported.
    """
    import pyiqa

    global _hook_handle

    metric_key = iqa_metric.lower()
    create_kwargs: Dict[str, Any] = {
        'as_loss': False,
        'device': device,
        'loss_reduction': 'none',
    }
    if metric_key == 'maniqa':
        create_kwargs['test_sample'] = 1
    iqa = pyiqa.create_metric(metric_key, **create_kwargs)
    # Compare the normalized key: the qalign hook target (visual_abstractor)
    # is accessed on the wrapper, so e.g. 'QAlign' must be wrapped too.
    if metric_key == 'qalign':
        iqa = QAlignVisionOnlyWrapper(iqa)
    iqa.eval()

    # Keep at most one hook registered by this function alive at a time.
    if _hook_handle is not None:
        _hook_handle.remove()
        _hook_handle = None

    if metric_key in {'liqe', 'liqe_mix'}:
        # Skip every resblock after the hooked one. Downstream outputs,
        # including the metric score, no longer match the unmodified model.
        resblocks = iqa.net.clip_model.visual.transformer.resblocks
        for i in range(layer_num + 1, len(resblocks)):
            resblocks[i] = torch.nn.Identity()

    layer_name, target_layer, hook_layout = _resolve_hook_target(
        iqa, iqa_metric, layer_num, swin_num, score_linear_idx
    )
    _hook_handle = target_layer.register_forward_hook(
        _make_hook(layer_name, sequence_layout=hook_layout)
    )
    return iqa, layer_name


def ensure_iqa_activation_hook(
    model: torch.nn.Module,
    layer_num: int,
    iqa_metric: str = 'arniqa-kadid',
    swin_num: int = 2,
    score_linear_idx: Optional[int] = None,
) -> str:
    """Ensure a forward hook is registered on the target IQA layer.

    The handle is stored on the target module under ``_ACTIVATION_HOOK_ATTR``;
    if that attribute is already set, no new hook is registered.

    Returns
    -------
    str
        Key under which activations are stored in ``_iqa_activations``.

    Raises
    ------
    ValueError
        If ``iqa_metric`` is not supported.
    """
    layer_name, target_layer, hook_layout = _resolve_hook_target(
        model, iqa_metric, layer_num, swin_num, score_linear_idx
    )

    if getattr(target_layer, _ACTIVATION_HOOK_ATTR, None) is not None:
        return layer_name

    handle = target_layer.register_forward_hook(
        _make_hook(layer_name, sequence_layout=hook_layout)
    )
    setattr(target_layer, _ACTIVATION_HOOK_ATTR, handle)
    return layer_name


def _save_decoder_weight_norms_cache(cache_dir: str, norms: Mapping[int, float]) -> None:
    """Write ``{feature_id: norm}`` to ``cache_dir/sae_decoder_weight_norms.parquet``."""
    cache_path = Path(cache_dir) / 'sae_decoder_weight_norms.parquet'
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    rows = [{'feature_id': int(k), 'norm': float(v)} for k, v in norms.items()]
    pd.DataFrame(rows, columns=['feature_id', 'norm']).to_parquet(cache_path, index=False)


def _load_decoder_weight_norms_cache(cache_dir: str) -> Optional[Dict[int, float]]:
    """Read cached decoder norms; return None if the cache is missing or empty."""
    cache_path = Path(cache_dir) / 'sae_decoder_weight_norms.parquet'
    if not cache_path.exists():
        return None
    logger.debug('Reading decoder weight norms %s', cache_path)
    df = pd.read_parquet(cache_path)
    logger.debug('Decoder weight norms loaded successfully')
    if df.empty:
        return None
    return {int(row.feature_id): float(row.norm) for row in df.itertuples(index=False)}


def extract_decoder_weight_norms(checkpoint_path: str, cache_dir: str) -> Dict[int, float]:
    """Return per-column L2 norms of the SAE decoder weight, cached in ``cache_dir``.

    Uses ``sae.decoder.weight`` when present, otherwise ``sae.W.mT``. Norms are
    taken over dim 0, giving one value per column.
    NOTE(review): assumes columns index SAE features; verify for each SAE variant.
    NOTE(review): the cache is not keyed by checkpoint, so reusing ``cache_dir``
    across checkpoints returns stale norms.
    """
    cached_norms = _load_decoder_weight_norms_cache(cache_dir)
    if cached_norms is not None:
        return cached_norms

    sae_model = load_sae(checkpoint_path=checkpoint_path, device='cpu')
    try:
        decoder_weight = sae_model.decoder.weight
    except AttributeError:
        decoder_weight = sae_model.W.mT

    norms = decoder_weight.detach().cpu().norm(dim=0)
    norms_dict = dict(enumerate(norms.tolist()))
    _save_decoder_weight_norms_cache(cache_dir, norms_dict)
    return norms_dict


def extract_model_hyperparameters(
    sae_config: Optional[Dict[str, Any]],
    checkpoint_path: str,
) -> Dict[str, Any]:
    """Extract model hyperparameters from an SAE config dict or its JSON file.

    When ``sae_config`` is None, ``sae_config.json`` is searched first next to
    the checkpoint (the checkpoint directory itself, or a weights file's
    directory), then in the checkpoint's parent directory, which is where
    ``read_sae_config`` looks by default. A missing or unreadable file yields
    defaults (``iqa_layer=3``, ``iqa_metric='arniqa-kadid'``).
    """
    hyperparams: Dict[str, Any] = {
        'iqa_layer': 3,
        'iqa_metric': 'arniqa-kadid',
    }

    if sae_config is None:
        checkpoint = Path(checkpoint_path).resolve()
        config_dir = checkpoint.parent if checkpoint.is_file() else checkpoint
        candidates = (
            config_dir / SAE_CONFIG_FILENAME,
            checkpoint.parent / SAE_CONFIG_FILENAME,
        )
        config_path = next((p for p in candidates if p.exists()), None)
        if config_path is None:
            return hyperparams
        try:
            with config_path.open('r', encoding='utf-8') as f:
                sae_config = json.load(f)
        except Exception as exc:  # best effort: fall back to defaults, but log why
            logger.warning('Failed to load SAE config from %s: %s', config_path, exc)
            return hyperparams

    if sae_config and isinstance(sae_config, dict):
        hyperparams['iqa_layer'] = sae_config.get(
            'layer_num',
            sae_config.get('score_linear_idx', hyperparams['iqa_layer']),
        )
        hyperparams['iqa_metric'] = sae_config.get('iqa_metric', hyperparams['iqa_metric'])
        hyperparams['sae_type'] = sae_config.get('sae_type', 'sae')
        hyperparams['lambda_param'] = sae_config.get('lambda_param', 'unknown')
        hyperparams['sae_inner_dim'] = sae_config.get('inner_dim', 'unknown')
        hyperparams['sae_input_dim'] = sae_config.get('sae_input_dim', 'unknown')
        if hyperparams['sae_type'] == 'mp_sae':
            hyperparams['mp_threshold'] = sae_config.get('mp_threshold', 0.1)
            hyperparams['mp_normalize'] = sae_config.get('mp_normalize', True)
        else:
            hyperparams['weight_norm_init'] = sae_config.get('weight_norm_init', 0.3)

    return hyperparams