dvarfe's picture
move to plotly backend
5109422
Raw
History Blame Contribute Delete
4.51 kB
"""
Общие вспомогательные утилиты пакета analysis.
"""
from typing import List
import numpy as np
import pandas as pd
from analysis.features.feature_indexing import FeatureMatrix
def get_top_images_for_feature(
features: FeatureMatrix,
meta: pd.DataFrame,
feature_id: int,
top_n: int = 10,
aggregation: str = 'mean_acts',
) -> List[int]:
"""
Возвращает индексы изображений, на которых признак feature_id активируется сильнее всего.
Активации патчей внутри одного изображения агрегируются в одно скалярное значение,
после чего изображения сортируются по убыванию.
Параметры
----------
features : CSR activations with global id per column
meta : DataFrame с колонкой 'image_idx' (один патч — одна строка)
feature_id : global SAE feature id
top_n : число возвращаемых изображений
aggregation : 'mean_acts' | 'max' | 'sum'
mean_acts — среднее по патчам с активацией > 0
max — максимальная активация среди патчей
sum — сумма всех активаций
Возвращает
----------
List[int] — image_idx в порядке убывания агрегированной активации
"""
assert aggregation in ('mean_acts', 'max', 'sum'), (
f"aggregation must be 'mean_acts', 'max' or 'sum', got {aggregation!r}"
)
col = features.column_for(feature_id)
feature_acts = np.asarray(features.codes[:, col].todense()).ravel() # (n_patches,)
image_idx_arr = meta['image_idx'].values
unique_images = np.unique(image_idx_arr)
scores = np.empty(len(unique_images), dtype=np.float32)
for i, img_idx in enumerate(unique_images):
mask = image_idx_arr == img_idx
vals = feature_acts[mask]
if aggregation == 'mean_acts':
active = vals[vals > 0]
scores[i] = active.mean() if len(active) > 0 else 0.0
elif aggregation == 'max':
scores[i] = vals.max()
else:
scores[i] = vals.sum()
order = np.argsort(scores)[::-1]
return unique_images[order[:top_n]].tolist()
def get_top_images_for_feature_by_iou(
features: FeatureMatrix,
meta: pd.DataFrame,
feature_id: int,
top_n: int = 10,
dataset: str | None = None,
) -> List[int]:
"""
Возвращает индексы изображений с наибольшим IoU между бинарной картой
активаций признака и маской искажений для каждого изображения.
Требует, чтобы в `meta` были колонки `image_idx` и либо `patch_mask_label`,
либо `patch_is_distorted` (см. iou_utils._load_patch_mask_for_group).
"""
from analysis.metrics import iou_utils
col = features.column_for(feature_id)
feature_acts = np.asarray(features.codes[:, col].todense()).ravel()
image_groups = meta.groupby('image_idx')
scores = [] # list of (image_idx, iou)
for image_idx, group_df in image_groups:
sample_indices = group_df.index.to_numpy()
vals = feature_acts[sample_indices]
# binary activation map for this image (per-patch)
act_binary = (vals > 0).astype(np.uint8)
try:
patch_masks = iou_utils._load_patch_mask_for_group(
group_df,
target_dist_type=None,
dataset=(dataset or ''),
)
except Exception:
# If no patch masks available, IoU is undefined — treat as 0
scores.append((int(image_idx), 0.0))
continue
if patch_masks.shape[0] != act_binary.shape[0]:
# mismatch: skip or treat as 0
scores.append((int(image_idx), 0.0))
continue
try:
iou = float(iou_utils._compute_patch_iou(act_binary, patch_masks))
except Exception:
iou = 0.0
scores.append((int(image_idx), iou))
if not scores:
return []
scores.sort(key=lambda x: x[1], reverse=True)
return [img for img, _ in scores[:top_n]]