dvarfe's picture
sync with github version
0705c62
Raw
History Blame Contribute Delete
13.9 kB
import json
import math
import os
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Tuple
import pandas as pd
import torch
from analysis.qalign_utils import QAlignVisionOnlyWrapper, flatten_blc_drop_cls
from log_config import get_logger
# Module-level logger for this file.
logger = get_logger(__name__)
# Latest flattened activations written by hooks built with _make_hook, keyed by
# layer name; the entry is overwritten on every forward pass of the hooked layer.
_iqa_activations: Dict[str, torch.Tensor] = {}
# Grid shape matching _iqa_activations[name]: (H, W) for 4-D outputs, (s, s)
# when a 3-D output has a square token count s*s, (tokens, 1) otherwise, and
# (1, 1) for 2-D outputs.
_iqa_activation_grids: Dict[str, Tuple[int, int]] = {}
# Handle of the hook most recently registered by load_iqa_model(); that
# function removes it before registering a new one.
_hook_handle: Optional[Any] = None
# Attribute name that ensure_iqa_activation_hook() sets on a hooked module to
# store its hook handle, so the hook is not registered twice on that module.
_ACTIVATION_HOOK_ATTR: str = "_xiqa_activation_hook_handle"
def _make_hook(name: str, sequence_layout: str = 'blc'):
    """Build a forward hook that stores a layer's output as a flat activation matrix.

    The returned hook writes the flattened activations to
    ``_iqa_activations[name]`` and the grid they were laid out on to
    ``_iqa_activation_grids[name]`` (used by ARNIQA/MANIQA/LIQE/Q-Align).

    Supported layer output formats:
      - (B, C)        -> (B, C), grid (1, 1)
      - (B, C, H, W)  -> (B*H*W, C), grid (H, W)
      - (B, L, C)     -> (B*L, C)                     layout ``'blc'`` (MANIQA)
      - (L, B, C)     -> (B*L', C)                    layout ``'lbc'`` (LIQE/LIQE-MIX);
                         the first token (CLS) is dropped when L = 1 + k*k
      - (B, L, C)     -> ``flatten_blc_drop_cls(out)`` layout ``'blc_drop_cls'``

    For 3-D outputs the grid is (s, s) when the remaining per-sample token
    count is a perfect square s*s, otherwise (tokens, 1).

    Parameters
    ----------
    name : str
        Key under which activations and grid are stored.
    sequence_layout : str
        How to interpret 3-D outputs: ``'blc'``, ``'lbc'`` or ``'blc_drop_cls'``.
        Not used for 2-D or 4-D outputs.

    Returns
    -------
    callable
        A ``hook(module, inp, out)`` suitable for ``register_forward_hook``.
        When called it raises ``ValueError`` for a 3-D output with an unknown
        ``sequence_layout`` or for an output that is not 2-D, 3-D or 4-D.
    """
    def hook(module, inp, out):
        # Some modules return an object carrying ``last_hidden_state`` instead
        # of a plain tensor; use that tensor in that case.
        acts = out.last_hidden_state if hasattr(out, 'last_hidden_state') else out
        acts = acts.detach()

        if acts.ndim == 2:
            _iqa_activation_grids[name] = (1, 1)
            _iqa_activations[name] = acts.reshape(-1, acts.shape[-1])
            return

        if acts.ndim == 4:
            # (B, C, H, W) -> (B, H, W, C) so channels become the feature axis.
            channels_last = acts.permute(0, 2, 3, 1)
            _iqa_activation_grids[name] = tuple(channels_last.shape[1:3])
            _iqa_activations[name] = channels_last.flatten(0, -2)  # (B*H*W, C)
            return

        if acts.ndim != 3:
            raise ValueError(
                f'Unsupported hooked activation ndim={acts.ndim} for layer {name!r}; '
                'expected 2D, 3D or 4D output'
            )

        if sequence_layout == 'lbc':
            # Sequence-first layout, e.g. CLIP resblocks in LIQE: (L, B, C).
            tokens, batch, channels = acts.shape
            # Drop the CLS token when the length is 1 + a perfect square
            # (e.g. 50 = 1 + 7*7). Require more than one token so a lone token
            # is kept rather than dropped into an empty activation matrix.
            if tokens > 1 and math.isqrt(tokens - 1) ** 2 == tokens - 1:
                acts = acts[1:]
                tokens -= 1
            flat_acts = acts.permute(1, 0, 2).reshape(batch * tokens, channels)
        elif sequence_layout == 'blc':
            # Batch-first layout, e.g. MANIQA Swin layers: (B, L, C).
            batch, tokens, channels = acts.shape
            flat_acts = acts.reshape(batch * tokens, channels)
        elif sequence_layout == 'blc_drop_cls':
            batch = acts.shape[0]
            flat_acts = flatten_blc_drop_cls(acts)
            # Per-sample token count after the helper removed the CLS token.
            tokens = flat_acts.shape[0] // batch
        else:
            raise ValueError(
                f'Unsupported sequence_layout={sequence_layout!r} for layer {name!r}; '
                "expected one of ('blc', 'lbc', 'blc_drop_cls')"
            )

        side = math.isqrt(tokens)
        if side * side == tokens:
            _iqa_activation_grids[name] = (side, side)
        else:
            _iqa_activation_grids[name] = (tokens, 1)
        _iqa_activations[name] = flat_acts

    return hook
# File name of the SAE JSON config; read_sae_config() looks for it in the
# directory containing the checkpoint path unless an explicit path is given.
SAE_CONFIG_FILENAME: str = "sae_config.json"
def read_sae_config(
    checkpoint_path: str,
    config_path: Optional[str] = None,
    **overrides: Any,
) -> Dict[str, Any]:
    """Read the SAE JSON config for a checkpoint and apply field overrides.

    Parameters
    ----------
    checkpoint_path : str
        Checkpoint directory or weights file. Only used to locate the config
        when ``config_path`` is not given: the config is then expected at
        ``<directory containing checkpoint_path>/SAE_CONFIG_FILENAME``
        (for ``output_dir/checkpoint-<step>/`` that is ``output_dir``).
    config_path : str, optional
        Explicit path to the JSON config file.
    **overrides
        Fields that replace (or add to) the values read from the file.

    Returns
    -------
    dict
        The parsed config with ``overrides`` applied.

    Raises
    ------
    OSError
        If the config file cannot be opened (e.g. ``FileNotFoundError``).
    ValueError
        If the file is not valid JSON (``json.JSONDecodeError``) or its
        top-level value is not a JSON object.
    """
    if config_path is None:
        # abspath() strips a trailing separator, so for a checkpoint directory
        # dirname() yields its parent (the training output directory).
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(checkpoint_path)),
            SAE_CONFIG_FILENAME,
        )
    # Explicit UTF-8 (as extract_model_hyperparameters uses) rather than the
    # platform-dependent locale encoding.
    with open(config_path, encoding='utf-8') as f:
        cfg: Dict[str, Any] = json.load(f)
    if not isinstance(cfg, dict):
        raise ValueError(
            f'SAE config {config_path!r} must contain a JSON object, '
            f'got {type(cfg).__name__}'
        )
    cfg.update(overrides)
    return cfg
def load_sae(
    checkpoint_path: str,
    config_path: Optional[str] = None,
    device: str = 'cpu',
    dtype: torch.dtype = torch.float32,
    sae_config: Optional[Dict[str, Any]] = None,
    **overrides: Any,
) -> torch.nn.Module:
    """Build an SAE from its JSON config and load checkpoint weights into it.

    Parameters
    ----------
    checkpoint_path : str
        Checkpoint directory (``output_dir/checkpoint-<step>/``) or weights
        file, passed to ``accelerate.load_checkpoint_in_model``.
    config_path : str, optional
        Path to ``sae_config.json``. When omitted (and ``sae_config`` is not
        given), ``read_sae_config`` looks in the directory containing
        ``checkpoint_path`` — i.e. ``output_dir/sae_config.json`` for a
        checkpoint directory.
    device : str
        Device the model is moved to.
    dtype : torch.dtype
        Dtype the model is cast to.
    sae_config : dict, optional
        Already-loaded config; when given, ``config_path`` is not used and the
        dict itself is not modified.
    **overrides
        Individual config fields to override (e.g. ``mp_threshold=0.05``).

    Returns
    -------
    torch.nn.Module
        The SAE on ``device`` with ``dtype``, in eval mode.
    """
    import accelerate
    from training.models.registry import build_sae_from_config

    if sae_config is not None:
        # Copy so the caller's dict is left untouched by the overrides.
        cfg = {**sae_config, **overrides}
    else:
        cfg = read_sae_config(checkpoint_path, config_path=config_path, **overrides)

    sae = build_sae_from_config(cfg)
    accelerate.load_checkpoint_in_model(sae, checkpoint_path)
    sae.to(device, dtype)
    sae.eval()
    return sae
def sae_encode_acts(
sae: torch.nn.Module,
acts: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Return sparse codes and optional activation steps from any SAE variant.
Standard ``SAE.get_acts`` returns a single codes tensor; ``MatchingPursuitSAE``
returns ``(codes, activation_steps)``.
"""
enc_out = sae.get_acts(acts)
if isinstance(enc_out, tuple):
codes, activation_steps = enc_out
return codes, activation_steps
return enc_out, None
def load_iqa_model(
    layer_num: int,
    device: str = 'cuda',
    iqa_metric: str = 'arniqa-kadid',
    swin_num: int = 2,
    score_linear_idx: Optional[int] = None,
):
    """Load a pyiqa IQA model and register an activation hook on one of its layers.

    Supported metrics (``iqa_metric`` is matched case-insensitively):
      - arniqa-kadid: hook on ``iqa.net.encoder[layer_num]``
      - maniqa: hook on ``iqa.net.swintransformer{swin_num}.layers[layer_num]``
      - liqe / liqe_mix: hook on
        ``iqa.net.clip_model.visual.transformer.resblocks[layer_num]``; all
        resblocks after ``layer_num`` are replaced with ``torch.nn.Identity``
      - topiq_nr: hook on ``iqa.net.score_linear[idx]``, where ``idx`` is
        ``score_linear_idx`` if given, else ``layer_num``
      - qalign: the metric is wrapped in ``QAlignVisionOnlyWrapper`` and the
        hook goes on its ``visual_abstractor``

    Only one hook registered by this function is active at a time: the hook
    from the previous call is removed once the new target layer is resolved.

    Parameters
    ----------
    layer_num : int
        Index of the layer to hook (see the per-metric list above).
    device : str
        Device passed to ``pyiqa.create_metric``.
    iqa_metric : str
        Name of the IQA metric.
    swin_num : int
        MANIQA Swin transformer stage (``swintransformer{swin_num}``).
    score_linear_idx : int, optional
        TOPIQ-NR ``score_linear`` index; defaults to ``layer_num``.

    Returns
    -------
    iqa
        The loaded (possibly wrapped) IQA model in eval mode.
    layer_name : str
        Key under which the hook stores activations in ``_iqa_activations``.

    Raises
    ------
    ValueError
        If ``iqa_metric`` is not supported. This is checked before any model
        is created, so the currently active hook is left in place.
    """
    global _hook_handle

    metric_key = iqa_metric.lower()
    if metric_key not in {'arniqa-kadid', 'maniqa', 'qalign', 'liqe', 'liqe_mix', 'topiq_nr'}:
        raise ValueError(
            f'Unsupported iqa_metric={iqa_metric!r}; expected one of '
            "('arniqa-kadid', 'maniqa', 'qalign', 'liqe', 'liqe_mix', 'topiq_nr')"
        )

    import pyiqa

    create_kwargs: Dict[str, Any] = {
        'as_loss': False,
        'device': device,
        'loss_reduction': 'none',
    }
    if metric_key == 'maniqa':
        create_kwargs['test_sample'] = 1
    iqa = pyiqa.create_metric(metric_key, **create_kwargs)
    # Compare the normalised key (not the raw argument) so e.g. 'QAlign' is
    # wrapped as well; the hook lookup below needs the wrapper's attribute.
    if metric_key == 'qalign':
        iqa = QAlignVisionOnlyWrapper(iqa)
    iqa.eval()

    hook_layout = 'blc'
    if metric_key == 'arniqa-kadid':
        layer_name = f'arniqa_enc_{layer_num}'
        target_layer = iqa.net.encoder[layer_num]
    elif metric_key == 'maniqa':
        layer_name = f'maniqa_swintransformer{swin_num}_layers_{layer_num}'
        target_layer = getattr(iqa.net, f'swintransformer{swin_num}').layers[layer_num]
    elif metric_key == 'qalign':
        layer_name = 'visual_abstractor'
        target_layer = iqa.visual_abstractor
        hook_layout = 'blc_drop_cls'
    elif metric_key == 'topiq_nr':
        idx = score_linear_idx if score_linear_idx is not None else layer_num
        layer_name = f'topiq_score_linear_{idx}'
        target_layer = iqa.net.score_linear[idx]
    else:  # 'liqe' / 'liqe_mix' (validated above)
        resblocks = iqa.net.clip_model.visual.transformer.resblocks
        # Blocks after the hooked one are replaced with Identity; the hooked
        # activations are unaffected, but the model's final scores change.
        for i in range(layer_num + 1, len(resblocks)):
            resblocks[i] = torch.nn.Identity()
        layer_name = f'liqe_resblock_{layer_num}'
        target_layer = resblocks[layer_num]
        # CLIP resblocks are treated as sequence-first: (L, B, C).
        hook_layout = 'lbc'

    # Swap hooks only after the target layer was resolved, so an invalid
    # layer index does not silently drop the previously active hook.
    if _hook_handle is not None:
        _hook_handle.remove()
    _hook_handle = target_layer.register_forward_hook(
        _make_hook(layer_name, sequence_layout=hook_layout)
    )
    return iqa, layer_name
def ensure_iqa_activation_hook(
model: torch.nn.Module,
layer_num: int,
iqa_metric: str = 'arniqa-kadid',
swin_num: int = 2,
score_linear_idx: Optional[int] = None,
) -> str:
"""Ensure a forward hook is registered on the target IQA layer."""
metric_key = iqa_metric.lower()
hook_layout = 'blc'
if metric_key == 'arniqa-kadid':
layer_name = f'arniqa_enc_{layer_num}'
target_layer = model.net.encoder[layer_num]
elif metric_key == 'maniqa':
layer_name = f'maniqa_swintransformer{swin_num}_layers_{layer_num}'
target_layer = getattr(model.net, f'swintransformer{swin_num}').layers[layer_num]
elif metric_key == 'topiq_nr':
idx = score_linear_idx if score_linear_idx is not None else layer_num
layer_name = f'topiq_score_linear_{idx}'
target_layer = model.net.score_linear[idx]
elif metric_key in {'liqe', 'liqe_mix'}:
layer_name = f'liqe_resblock_{layer_num}'
target_layer = model.net.clip_model.visual.transformer.resblocks[layer_num]
hook_layout = 'lbc'
elif metric_key == 'qalign':
layer_name = 'visual_abstractor'
target_layer = model.visual_abstractor
hook_layout = 'blc_drop_cls'
else:
raise ValueError(f'Unsupported iqa_metric={iqa_metric!r}')
existing = getattr(target_layer, _ACTIVATION_HOOK_ATTR, None)
if existing is not None:
return layer_name
handle = target_layer.register_forward_hook(
_make_hook(layer_name, sequence_layout=hook_layout)
)
setattr(target_layer, _ACTIVATION_HOOK_ATTR, handle)
return layer_name
def _save_decoder_weight_norms_cache(cache_dir: str, norms: Mapping[int, float]) -> None:
    """Write per-feature decoder norms to ``<cache_dir>/sae_decoder_weight_norms.parquet``.

    The table has two columns, ``feature_id`` (cast to int) and ``norm``
    (cast to float). ``cache_dir`` is created if it does not exist.
    """
    out_file = Path(cache_dir).joinpath('sae_decoder_weight_norms.parquet')
    out_file.parent.mkdir(parents=True, exist_ok=True)
    table = pd.DataFrame(
        [(int(feature), float(value)) for feature, value in norms.items()],
        columns=['feature_id', 'norm'],
    )
    table.to_parquet(out_file, index=False)
def _load_decoder_weight_norms_cache(cache_dir: str) -> Optional[Dict[int, float]]:
    """Return cached per-feature decoder norms, or ``None`` when unavailable.

    Reads ``<cache_dir>/sae_decoder_weight_norms.parquet`` (the file written by
    ``_save_decoder_weight_norms_cache``). ``None`` is returned when the file
    does not exist or has no rows.
    """
    parquet_file = Path(cache_dir).joinpath('sae_decoder_weight_norms.parquet')
    if not parquet_file.exists():
        return None

    logger.debug('Reading decoder weight norms %s', parquet_file)
    frame = pd.read_parquet(parquet_file)
    logger.debug('Decoder weight norms loaded successfully')

    if frame.empty:
        return None
    return {
        int(record.feature_id): float(record.norm)
        for record in frame.itertuples(index=False)
    }
def extract_decoder_weight_norms(checkpoint_path: str, cache_dir: str) -> Dict[int, float]:
    """Return the norm of every SAE decoder column, keyed by feature index.

    If ``cache_dir`` already holds a non-empty norms table, that table is
    returned without loading the checkpoint. Otherwise the SAE is loaded on
    CPU and norms are taken over dim 0 of ``sae.decoder.weight``. SAEs where
    that lookup raises ``AttributeError`` use ``sae.W.mT`` instead. The
    result is written to the cache before being returned.

    NOTE(review): the cache is keyed only by ``cache_dir``, not by
    ``checkpoint_path`` — reusing a cache_dir for another checkpoint returns
    the previously cached norms.
    """
    cached = _load_decoder_weight_norms_cache(cache_dir)
    if cached is not None:
        return cached

    sae = load_sae(checkpoint_path=checkpoint_path, device='cpu')
    try:
        weight = sae.decoder.weight
    except AttributeError:
        # Presumably W is (n_features, input_dim); its transpose puts one
        # feature per column, matching decoder.weight — confirm per SAE type.
        weight = sae.W.mT
    per_feature = weight.detach().cpu().norm(dim=0).numpy()
    result = dict(zip(range(per_feature.shape[0]), map(float, per_feature)))
    _save_decoder_weight_norms_cache(cache_dir, result)
    return result
def extract_model_hyperparameters(
    sae_config: Optional[Dict[str, Any]],
    checkpoint_path: str,
) -> Dict[str, Any]:
    """Summarise IQA/SAE hyperparameters from an SAE config.

    Parameters
    ----------
    sae_config : dict or None
        Already-loaded SAE config. When ``None``, the config file
        ``SAE_CONFIG_FILENAME`` is looked up for ``checkpoint_path``:
        next to a weights file, or for a checkpoint directory first inside
        it and then in its parent (``output_dir/sae_config.json``, the
        location ``read_sae_config``/``load_sae`` use).
    checkpoint_path : str
        Checkpoint directory or weights file; only used to find the config.

    Returns
    -------
    dict
        Always contains ``iqa_layer`` (default 3) and ``iqa_metric`` (default
        ``'arniqa-kadid'``). When a config dict is available it also contains
        ``sae_type``, ``lambda_param``, ``sae_inner_dim``, ``sae_input_dim`` and
        either ``mp_threshold``/``mp_normalize`` (``sae_type == 'mp_sae'``) or
        ``weight_norm_init``. Missing or unreadable config files yield only the
        two defaults; read/parse failures are logged as warnings.
    """
    hyperparams: Dict[str, Any] = {
        'iqa_layer': 3,
        'iqa_metric': 'arniqa-kadid',
    }
    if sae_config is None:
        checkpoint = Path(checkpoint_path).resolve()
        if checkpoint.is_file():
            search_dirs = [checkpoint.parent]
        else:
            # Keep the original location first, then fall back to the parent
            # directory used by read_sae_config() for directory checkpoints.
            search_dirs = [checkpoint, checkpoint.parent]
        candidates = [directory / SAE_CONFIG_FILENAME for directory in search_dirs]
        config_path = next((path for path in candidates if path.exists()), None)
        if config_path is None:
            return hyperparams
        try:
            with config_path.open('r', encoding='utf-8') as f:
                sae_config = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning('Failed to load SAE config from %s: %s', config_path, exc)
            return hyperparams

    if not sae_config or not isinstance(sae_config, dict):
        return hyperparams

    # NOTE(review): 'layer_num' takes precedence over 'score_linear_idx' here,
    # whereas load_iqa_model() prefers score_linear_idx for topiq_nr — confirm
    # which value should be reported for topiq_nr runs.
    hyperparams['iqa_layer'] = sae_config.get(
        'layer_num',
        sae_config.get('score_linear_idx', hyperparams['iqa_layer']),
    )
    hyperparams['iqa_metric'] = sae_config.get('iqa_metric', hyperparams['iqa_metric'])
    hyperparams['sae_type'] = sae_config.get('sae_type', 'sae')
    hyperparams['lambda_param'] = sae_config.get('lambda_param', 'unknown')
    hyperparams['sae_inner_dim'] = sae_config.get('inner_dim', 'unknown')
    hyperparams['sae_input_dim'] = sae_config.get('sae_input_dim', 'unknown')
    if hyperparams['sae_type'] == 'mp_sae':
        hyperparams['mp_threshold'] = sae_config.get('mp_threshold', 0.1)
        hyperparams['mp_normalize'] = sae_config.get('mp_normalize', True)
    else:
        hyperparams['weight_norm_init'] = sae_config.get('weight_norm_init', 0.3)
    return hyperparams