Spaces:
Running
Running
| import json | |
| import math | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, Mapping, Optional, Tuple | |
| import pandas as pd | |
| import torch | |
| from analysis.qalign_utils import QAlignVisionOnlyWrapper, flatten_blc_drop_cls | |
| from log_config import get_logger | |
# Module-level logger obtained from the project's logging configuration.
logger = get_logger(__name__)
# Most recent flattened activations captured by the forward hooks, keyed by layer name.
_iqa_activations: Dict[str, torch.Tensor] = {}
# Token grid recorded together with each captured activation, keyed by layer name.
_iqa_activation_grids: Dict[str, Tuple[int, int]] = {}
# Handle of the hook most recently registered by load_iqa_model (None until the first call).
_hook_handle = None
# Attribute name under which ensure_iqa_activation_hook stores its hook handle on the hooked layer.
_ACTIVATION_HOOK_ATTR = "_xiqa_activation_hook_handle"
def _make_hook(name: str, sequence_layout: str = 'blc'):
    """Build a forward hook that stores a layer's output as a flat token matrix.

    The hook writes two module-level entries under the key ``name``:
    ``_iqa_activations[name]`` (a 2D tensor whose last dimension is the channel
    dimension) and ``_iqa_activation_grids[name]`` (a 2-tuple describing the
    token grid for one batch item).

    Supported layer outputs:
      - ``(N, C)``        -> stored as is, grid ``(1, 1)``
      - ``(B, C, H, W)``  -> ``(B*H*W, C)``, grid ``(H, W)``
      - ``(B, L, C)``     -> ``(B*L, C)`` with ``sequence_layout='blc'`` (MANIQA)
      - ``(L, B, C)``     -> ``(B*L, C)`` with ``sequence_layout='lbc'``
        (LIQE / LIQE-MIX); the first token is dropped when ``L - 1`` is a
        perfect square
      - ``(B, L, C)``     -> flattened by ``flatten_blc_drop_cls`` with
        ``sequence_layout='blc_drop_cls'``

    For 3D outputs the grid is ``(s, s)`` when the per-item token count is a
    perfect square ``s*s``, otherwise ``(tokens, 1)``.

    Parameters
    ----------
    name : str
        Key under which the activation and its grid are stored.
    sequence_layout : str
        Layout of 3D outputs: ``'blc'``, ``'lbc'`` or ``'blc_drop_cls'``.
        It is only consulted when the hooked layer produces a 3D output.

    Returns
    -------
    callable
        A ``hook(module, inp, out)`` function suitable for
        ``register_forward_hook``.

    Raises
    ------
    ValueError
        Raised by the hook (not by this factory) when a 3D output arrives with
        an unknown ``sequence_layout`` or when the output is not 2D, 3D or 4D.
    """
    supported_layouts = ('blc', 'lbc', 'blc_drop_cls')

    def hook(module, inp, out):
        # Some layers return an output object instead of a bare tensor; in that
        # case the tensor of interest is taken from ``last_hidden_state``.
        if hasattr(out, 'last_hidden_state'):
            out_detached = out.last_hidden_state.detach()  # Temporary workaround
        else:
            out_detached = out.detach()

        if out_detached.ndim == 2:
            _iqa_activation_grids[name] = (1, 1)
            _iqa_activations[name] = out_detached.reshape(-1, out_detached.shape[-1])
            return

        if out_detached.ndim == 4:
            tensor_perm = out_detached.permute(0, 2, 3, 1)  # (B, H, W, C)
            _iqa_activation_grids[name] = tuple(tensor_perm.shape[1:3])
            _iqa_activations[name] = tensor_perm.flatten(0, -2)  # (B*H*W, C)
            return

        if out_detached.ndim == 3:
            if sequence_layout == 'lbc':
                # LIQE CLIP resblocks usually output (L, B, C).
                tokens, batch, channels = out_detached.shape
                # Drop CLS token when sequence length is 1 + square (e.g. 50 = 1 + 49).
                spatial_wo_cls = int(math.isqrt(max(tokens - 1, 0)))
                if spatial_wo_cls * spatial_wo_cls == (tokens - 1):
                    out_detached = out_detached[1:, :, :]
                    tokens = tokens - 1
                flat_acts = out_detached.permute(1, 0, 2).reshape(batch * tokens, channels)
            elif sequence_layout == 'blc':
                # MANIQA Swin layers usually output (B, L, C).
                batch, tokens, channels = out_detached.shape
                flat_acts = out_detached.reshape(batch * tokens, channels)
            elif sequence_layout == 'blc_drop_cls':
                batch = out_detached.shape[0]
                flat_acts = flatten_blc_drop_cls(out_detached)
                # The helper may remove tokens, so recompute the per-item count.
                tokens = flat_acts.shape[0] // batch
            else:
                raise ValueError(
                    f'Unsupported sequence_layout={sequence_layout!r} for layer {name!r}; '
                    f'expected one of {supported_layouts!r}'
                )

            spatial = int(math.isqrt(tokens))
            if spatial * spatial == tokens:
                _iqa_activation_grids[name] = (spatial, spatial)
            else:
                _iqa_activation_grids[name] = (tokens, 1)
            _iqa_activations[name] = flat_acts
            return

        raise ValueError(
            f'Unsupported hooked activation ndim={out_detached.ndim} for layer {name!r}; '
            'expected 2D, 3D or 4D output'
        )

    return hook
# File name of the SAE JSON config looked up next to a checkpoint.
SAE_CONFIG_FILENAME = "sae_config.json"


def read_sae_config(
    checkpoint_path: str,
    config_path: Optional[str] = None,
    **overrides: Any,
) -> Dict[str, Any]:
    """Read the SAE JSON config for a checkpoint and apply field overrides.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint. It is only used to locate the config when
        ``config_path`` is not given: the config is then expected at
        ``<dirname of the absolute checkpoint path>/sae_config.json``.
    config_path : str, optional
        Explicit path to the config file; takes precedence over the lookup
        derived from ``checkpoint_path``.
    **overrides
        Key/value pairs written on top of the values read from the file.

    Returns
    -------
    dict
        The parsed config with ``overrides`` applied.

    Raises
    ------
    OSError
        If the config file cannot be opened.
    json.JSONDecodeError
        If the file does not contain valid JSON.
    """
    if config_path is None:
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(checkpoint_path)),
            SAE_CONFIG_FILENAME,
        )
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default text encoding.
    with open(config_path, encoding='utf-8') as f:
        cfg: Dict[str, Any] = json.load(f)
    cfg.update(overrides)
    return cfg
def load_sae(
    checkpoint_path: str,
    config_path: Optional[str] = None,
    device: str = 'cpu',
    dtype: torch.dtype = torch.float32,
    sae_config: Optional[Dict[str, Any]] = None,
    **overrides: Any,
) -> torch.nn.Module:
    """Load an SAE from a checkpoint, taking its architecture from a JSON config.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint directory (output_dir/checkpoint-<step>/) or to
        a weights file.
    config_path : str, optional
        Path to sae_config.json. By default it is looked up in the parent
        directory of the checkpoint (i.e. output_dir/sae_config.json). Ignored
        when ``sae_config`` is given.
    device : str
        Device the model is moved to.
    dtype : torch.dtype
        Data type the model is converted to.
    sae_config : dict, optional
        An already loaded config; when given, no config file is read.
    **overrides
        Override individual config fields (for example, mp_threshold=0.05).

    Returns
    -------
    torch.nn.Module
        The model built from the config, with checkpoint weights loaded, moved
        to ``device``/``dtype`` and switched to eval mode.
    """
    import accelerate
    from training.models.registry import build_sae_from_config

    if sae_config is not None:
        # Copy so the caller's dict is not modified by the overrides.
        cfg = dict(sae_config)
        cfg.update(overrides)
    else:
        cfg = read_sae_config(checkpoint_path, config_path=config_path, **overrides)

    sae = build_sae_from_config(cfg)
    accelerate.load_checkpoint_in_model(sae, checkpoint_path)
    sae.to(device, dtype)
    sae.eval()
    return sae
| def sae_encode_acts( | |
| sae: torch.nn.Module, | |
| acts: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor | None]: | |
| """Return sparse codes and optional activation steps from any SAE variant. | |
| Standard ``SAE.get_acts`` returns a single codes tensor; ``MatchingPursuitSAE`` | |
| returns ``(codes, activation_steps)``. | |
| """ | |
| enc_out = sae.get_acts(acts) | |
| if isinstance(enc_out, tuple): | |
| codes, activation_steps = enc_out | |
| return codes, activation_steps | |
| return enc_out, None | |
def load_iqa_model(
    layer_num: int,
    device: str = 'cuda',
    iqa_metric: str = 'arniqa-kadid',
    swin_num: int = 2,
    score_linear_idx: Optional[int] = None,
):
    """Load an IQA model and register an activation hook on the selected layer.

    Supported metrics (matched case-insensitively) and the hooked layer:
      - arniqa-kadid: iqa.net.encoder[layer_num]
      - maniqa: iqa.net.swintransformer{swin_num}.layers[layer_num]
      - liqe / liqe_mix: iqa.net.clip_model.visual.transformer.resblocks[layer_num];
        every resblock after the hooked one is replaced by ``torch.nn.Identity``
      - topiq_nr: iqa.net.score_linear[score_linear_idx]
        (``layer_num`` is used when ``score_linear_idx`` is None)
      - qalign: visual_abstractor of the ``QAlignVisionOnlyWrapper`` around the metric

    A hook registered by a previous call to this function is removed first, so
    at most one hook created here is active at a time.

    Parameters
    ----------
    layer_num : int
        Index of the layer to hook (not used for qalign).
    device : str
        Device passed to ``pyiqa.create_metric``.
    iqa_metric : str
        Metric name.
    swin_num : int
        Which ``swintransformer{swin_num}`` attribute to use for maniqa.
    score_linear_idx : int, optional
        Index into ``score_linear`` for topiq_nr.

    Returns
    -------
    iqa
        The loaded IQA model, in eval mode.
    layer_name : str
        Key under which the hook stores activations in ``_iqa_activations``.

    Raises
    ------
    ValueError
        If ``iqa_metric`` is not one of the supported metrics.
    """
    import pyiqa

    global _hook_handle

    metric_key = iqa_metric.lower()
    create_kwargs: Dict[str, Any] = {
        'as_loss': False,
        'device': device,
        'loss_reduction': 'none',
    }
    if metric_key == 'maniqa':
        create_kwargs['test_sample'] = 1
    iqa = pyiqa.create_metric(metric_key, **create_kwargs)
    # Use the normalized key here as well: comparing the raw ``iqa_metric``
    # skipped the wrapper for spellings such as 'QAlign' even though the rest
    # of the function treats them as qalign.
    if metric_key == 'qalign':
        iqa = QAlignVisionOnlyWrapper(iqa)
    iqa.eval()

    if _hook_handle is not None:
        _hook_handle.remove()
        # Do not keep a reference to a removed handle if a later step raises.
        _hook_handle = None

    hook_layout = 'blc'
    if metric_key == 'arniqa-kadid':
        layer_name = f'arniqa_enc_{layer_num}'
        target_layer = iqa.net.encoder[layer_num]
    elif metric_key == 'maniqa':
        layer_name = f'maniqa_swintransformer{swin_num}_layers_{layer_num}'
        target_layer = getattr(iqa.net, f'swintransformer{swin_num}').layers[layer_num]
    elif metric_key == 'qalign':
        layer_name = 'visual_abstractor'
        target_layer = iqa.visual_abstractor
        hook_layout = 'blc_drop_cls'
    elif metric_key == 'topiq_nr':
        idx = score_linear_idx if score_linear_idx is not None else layer_num
        layer_name = f'topiq_score_linear_{idx}'
        target_layer = iqa.net.score_linear[idx]
    elif metric_key in {'liqe', 'liqe_mix'}:
        resblocks = iqa.net.clip_model.visual.transformer.resblocks
        num_visual_layers = len(resblocks)
        layer_name = f'liqe_resblock_{layer_num}'
        # Resolve the target before mutating the model so an out-of-range index
        # raises without leaving the model partially truncated.
        target_layer = resblocks[layer_num]
        # Blocks after the hooked one are replaced by Identity. A negative
        # ``layer_num`` is converted to its position from the start first;
        # otherwise ``range(layer_num + 1, ...)`` would also replace the hooked
        # block itself.
        first_unused = layer_num + 1 if layer_num >= 0 else num_visual_layers + layer_num + 1
        for i in range(first_unused, num_visual_layers):
            resblocks[i] = torch.nn.Identity()
        hook_layout = 'lbc'
    else:
        raise ValueError(
            f'Unsupported iqa_metric={iqa_metric!r}; expected one of '
            "('arniqa-kadid', 'maniqa', 'qalign', 'liqe', 'liqe_mix', 'topiq_nr')"
        )

    _hook_handle = target_layer.register_forward_hook(
        _make_hook(layer_name, sequence_layout=hook_layout)
    )
    return iqa, layer_name
| def ensure_iqa_activation_hook( | |
| model: torch.nn.Module, | |
| layer_num: int, | |
| iqa_metric: str = 'arniqa-kadid', | |
| swin_num: int = 2, | |
| score_linear_idx: Optional[int] = None, | |
| ) -> str: | |
| """Ensure a forward hook is registered on the target IQA layer.""" | |
| metric_key = iqa_metric.lower() | |
| hook_layout = 'blc' | |
| if metric_key == 'arniqa-kadid': | |
| layer_name = f'arniqa_enc_{layer_num}' | |
| target_layer = model.net.encoder[layer_num] | |
| elif metric_key == 'maniqa': | |
| layer_name = f'maniqa_swintransformer{swin_num}_layers_{layer_num}' | |
| target_layer = getattr(model.net, f'swintransformer{swin_num}').layers[layer_num] | |
| elif metric_key == 'topiq_nr': | |
| idx = score_linear_idx if score_linear_idx is not None else layer_num | |
| layer_name = f'topiq_score_linear_{idx}' | |
| target_layer = model.net.score_linear[idx] | |
| elif metric_key in {'liqe', 'liqe_mix'}: | |
| layer_name = f'liqe_resblock_{layer_num}' | |
| target_layer = model.net.clip_model.visual.transformer.resblocks[layer_num] | |
| hook_layout = 'lbc' | |
| elif metric_key == 'qalign': | |
| layer_name = 'visual_abstractor' | |
| target_layer = model.visual_abstractor | |
| hook_layout = 'blc_drop_cls' | |
| else: | |
| raise ValueError(f'Unsupported iqa_metric={iqa_metric!r}') | |
| existing = getattr(target_layer, _ACTIVATION_HOOK_ATTR, None) | |
| if existing is not None: | |
| return layer_name | |
| handle = target_layer.register_forward_hook( | |
| _make_hook(layer_name, sequence_layout=hook_layout) | |
| ) | |
| setattr(target_layer, _ACTIVATION_HOOK_ATTR, handle) | |
| return layer_name | |
def _save_decoder_weight_norms_cache(cache_dir: str, norms: Mapping[int, float]) -> None:
    """Write the feature-id -> norm mapping to ``sae_decoder_weight_norms.parquet`` in *cache_dir*.

    The directory is created if it does not exist. The file has two columns,
    ``feature_id`` and ``norm``, and no index.
    """
    target_dir = Path(cache_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    records = []
    for feature_id, norm in norms.items():
        records.append({'feature_id': int(feature_id), 'norm': float(norm)})
    frame = pd.DataFrame(records, columns=['feature_id', 'norm'])
    frame.to_parquet(target_dir / 'sae_decoder_weight_norms.parquet', index=False)
def _load_decoder_weight_norms_cache(cache_dir: str) -> Optional[Dict[int, float]]:
    """Read the cached feature-id -> norm mapping from *cache_dir*.

    Returns ``None`` when ``sae_decoder_weight_norms.parquet`` does not exist in
    the directory or when the file contains no rows.
    """
    parquet_file = Path(cache_dir) / 'sae_decoder_weight_norms.parquet'
    if not parquet_file.exists():
        return None

    logger.debug('Reading decoder weight norms %s', parquet_file)
    frame = pd.read_parquet(parquet_file)
    logger.debug('Decoder weight norms loaded successfully')
    if frame.empty:
        return None

    norms_by_feature: Dict[int, float] = {}
    for record in frame.itertuples(index=False):
        norms_by_feature[int(record.feature_id)] = float(record.norm)
    return norms_by_feature
def extract_decoder_weight_norms(checkpoint_path: str, cache_dir: str) -> Dict[int, float]:
    """Return decoder weight norms indexed by position, using the parquet cache when present.

    When the cache in *cache_dir* yields a mapping it is returned directly.
    Otherwise the SAE is loaded on CPU, the norm over dim 0 of its decoder
    weight is computed, and the result is written to the cache before being
    returned.
    """
    norms_by_feature = _load_decoder_weight_norms_cache(cache_dir)
    if norms_by_feature is None:
        sae_model = load_sae(checkpoint_path=checkpoint_path, device='cpu')
        try:
            weight = sae_model.decoder.weight
        except AttributeError:
            # Models without a ``decoder`` module expose the weight as ``W``;
            # its matrix transpose is used in place of ``decoder.weight``.
            weight = sae_model.W.mT
        column_norms = weight.detach().cpu().norm(dim=0).numpy()
        norms_by_feature = {
            position: float(value) for position, value in enumerate(column_norms)
        }
        _save_decoder_weight_norms_cache(cache_dir, norms_by_feature)
    return norms_by_feature
def extract_model_hyperparameters(
    sae_config: Optional[Dict[str, Any]],
    checkpoint_path: str,
) -> Dict[str, Any]:
    """Collect model hyperparameters from an SAE config.

    Parameters
    ----------
    sae_config : dict, optional
        An already loaded SAE config. When ``None``, ``sae_config.json`` is
        searched for relative to ``checkpoint_path``.
    checkpoint_path : str
        Checkpoint directory or weights file; only used when ``sae_config`` is
        ``None``. The config is looked up first in the checkpoint directory
        itself (or the directory containing the weights file) and then in its
        parent directory, which is where ``read_sae_config`` looks for a
        checkpoint directory.

    Returns
    -------
    dict
        Always contains ``iqa_layer`` and ``iqa_metric`` (defaults ``3`` and
        ``'arniqa-kadid'``). When a non-empty dict config is available it also
        contains ``sae_type``, ``lambda_param``, ``sae_inner_dim`` and
        ``sae_input_dim``, plus ``mp_threshold``/``mp_normalize`` for
        ``sae_type == 'mp_sae'`` or ``weight_norm_init`` otherwise. If no config
        file is found, or it cannot be read or parsed, only the two defaults are
        returned (a warning is logged on read/parse failure).
    """
    hyperparams: Dict[str, Any] = {
        'iqa_layer': 3,
        'iqa_metric': 'arniqa-kadid',
    }

    if sae_config is None:
        checkpoint = Path(checkpoint_path).resolve()
        search_dirs = [checkpoint.parent if checkpoint.is_file() else checkpoint]
        # Fallback: for a checkpoint directory the config may sit next to it
        # (output_dir/sae_config.json) rather than inside it.
        if checkpoint.parent not in search_dirs:
            search_dirs.append(checkpoint.parent)
        config_path = next(
            (
                directory / SAE_CONFIG_FILENAME
                for directory in search_dirs
                if (directory / SAE_CONFIG_FILENAME).exists()
            ),
            None,
        )
        if config_path is None:
            return hyperparams
        try:
            with config_path.open('r', encoding='utf-8') as f:
                sae_config = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning('Failed to load SAE config from %s: %s', config_path, exc)
            return hyperparams

    if sae_config and isinstance(sae_config, dict):
        # ``layer_num`` wins; ``score_linear_idx`` is used only when it is absent.
        hyperparams['iqa_layer'] = sae_config.get(
            'layer_num',
            sae_config.get('score_linear_idx', hyperparams['iqa_layer']),
        )
        hyperparams['iqa_metric'] = sae_config.get('iqa_metric', hyperparams['iqa_metric'])
        hyperparams['sae_type'] = sae_config.get('sae_type', 'sae')
        hyperparams['lambda_param'] = sae_config.get('lambda_param', 'unknown')
        hyperparams['sae_inner_dim'] = sae_config.get('inner_dim', 'unknown')
        hyperparams['sae_input_dim'] = sae_config.get('sae_input_dim', 'unknown')
        if hyperparams['sae_type'] == 'mp_sae':
            hyperparams['mp_threshold'] = sae_config.get('mp_threshold', 0.1)
            hyperparams['mp_normalize'] = sae_config.get('mp_normalize', True)
        else:
            hyperparams['weight_norm_init'] = sae_config.get('weight_norm_init', 0.3)

    return hyperparams