Spaces:
Running
Running
| from __future__ import annotations | |
| from collections.abc import Collection | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
def list_ica_layer_keys(ica_dir: Path) -> list[str]:
    """List the layer keys that have a FastICA artifact in ``ica_dir``.

    A layer key is the name of a ``*_fastica.pt`` entry with that suffix
    stripped. The scan is not recursive, and keys are ordered with
    ``_layer_sort_key``.

    Args:
        ica_dir: Directory to scan.

    Returns:
        The sorted layer keys, or an empty list when ``ica_dir`` is not a
        directory.
    """
    suffix = "_fastica.pt"
    if ica_dir.is_dir():
        keys = []
        for artifact in ica_dir.glob("*" + suffix):
            # Drop the trailing suffix to recover the layer key.
            keys.append(artifact.name[: -len(suffix)])
        keys.sort(key=_layer_sort_key)
        return keys
    return []
def fastica_artifact_path(ica_dir: Path, layer: str) -> Path:
    """Return the path of the FastICA artifact for ``layer`` inside ``ica_dir``.

    The path is only constructed; nothing checks that the file exists.
    """
    filename = f"{layer}_fastica.pt"
    return ica_dir.joinpath(filename)
def _greedy_token_predictions(tokenizer: Any, logits: torch.Tensor) -> list[dict[str, Any]]:
    """Describe the argmax token at every position of ``logits``."""
    pred_ids = torch.argmax(logits, dim=-1).detach().cpu().tolist()
    return [
        {
            "token_id": int(pred_id),
            "token": tokenizer.convert_ids_to_tokens(int(pred_id)),
            "token_text": _decode_token_text(tokenizer, int(pred_id)),
        }
        for pred_id in pred_ids
    ]


def _forward_with_block_capture(
    model: torch.nn.Module,
    *,
    block_index: int,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor | None,
) -> tuple[Any, torch.Tensor]:
    """Run ``model`` once and return its outputs plus the output of one ``model.transformer.h`` block."""
    blocks = getattr(getattr(model, "transformer", None), "h", None)
    if blocks is None:
        raise RuntimeError("raw_gpt2_block_index requires a GPT-2 style model.transformer.h module list.")

    captured: dict[str, torch.Tensor] = {}

    def hook(_module: Any, _inputs: Any, output: Any) -> None:
        # A block may return a tuple; its first element is taken as the hidden state.
        captured["hidden"] = output[0].detach() if isinstance(output, tuple) else output.detach()

    handle = blocks[int(block_index)].register_forward_hook(hook)
    try:
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        )
    finally:
        # Always unregister so the hook never outlives this call.
        handle.remove()

    captured_hidden = captured.get("hidden")
    if captured_hidden is None:
        raise RuntimeError("GPT-2 raw block hook did not capture hidden states.")
    # Index 0 selects the first sequence of the batch.
    return outputs, captured_hidden[0]


def interpret_text_probe(
    *,
    model: torch.nn.Module,
    tokenizer: Any,
    text: str,
    layer: str,
    ica_artifact_path: Path,
    top_k: int,
    highlight_components: Collection[int] | None,
    max_length: int,
    raw_gpt2_block_index: int | None = None,
) -> dict[str, Any]:
    """Score every token of ``text`` against the FastICA components of one layer.

    The text is tokenized (truncated to ``max_length``), the activations of
    ``layer`` are collected, and they are projected with the artifact at
    ``ica_artifact_path``. For each token the ``top_k`` components with the
    largest absolute score are reported with their signed scores.

    Args:
        model: Model exposing ``config.num_hidden_layers`` and
            ``get_input_embeddings()``, and returning ``hidden_states`` and
            ``logits`` when called.
        tokenizer: Tokenizer providing ``encode``, ``__call__``, ``decode`` and
            ``convert_ids_to_tokens``.
        text: Input text to analyse.
        layer: ``"embedding"`` or ``"layer_<index>"``.
        ica_artifact_path: File read by ``_load_fastica_artifact``.
        top_k: Number of components reported per token (capped at the number
            of available components).
        highlight_components: Component indices that are always reported; any
            that is not already in a token's top-k is appended with
            ``"highlighted": True``.
        max_length: Maximum number of token ids passed to the model.
        raw_gpt2_block_index: When given (and ``layer`` is not ``"embedding"``),
            activations are captured with a forward hook on
            ``model.transformer.h[raw_gpt2_block_index]`` instead of being read
            from ``hidden_states``.

    Returns:
        A dict with the keys ``layer``, ``top_k``, ``max_length``, ``tokens``,
        ``seq_len``, ``truncated``, ``n_components`` and
        ``predictions_available``. Each entry of ``tokens`` holds ``position``,
        ``token_id``, ``token``, ``token_text`` and ``top``, plus
        ``prediction`` when ``layer`` is the last ``layer_NN`` key of the model.

    Raises:
        ValueError: For an unknown or out-of-range ``layer``, a non-positive
            ``top_k`` or an out-of-range highlight component.
        RuntimeError: When the requested hidden states cannot be obtained.
    """
    num_transformer_layers = int(model.config.num_hidden_layers)
    hidden_index = _layer_key_to_hidden_index(layer, num_transformer_layers=num_transformer_layers)
    # Next-token predictions are only attached when probing the last layer key.
    show_predictions = layer == f"layer_{num_transformer_layers - 1:02d}"

    # The untruncated encoding is used only to report whether truncation happened.
    full_ids = tokenizer.encode(text, add_special_tokens=True)
    truncated = len(full_ids) > max_length

    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    input_ids = encoded["input_ids"].to(next(model.parameters()).device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(input_ids.device)

    predictions: list[dict[str, Any]] | None = None
    with torch.inference_mode():
        if layer == "embedding":
            hidden = model.get_input_embeddings()(input_ids)[0]
        else:
            if raw_gpt2_block_index is not None:
                outputs, hidden = _forward_with_block_capture(
                    model,
                    block_index=raw_gpt2_block_index,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
            else:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    use_cache=False,
                )
                hidden_states = outputs.hidden_states
                if hidden_states is None:
                    raise RuntimeError("Model did not return hidden states.")
                hidden = hidden_states[hidden_index][0]
            if show_predictions:
                predictions = _greedy_token_predictions(tokenizer, outputs.logits[0])

    artifact = _load_fastica_artifact(ica_artifact_path)
    scores = _all_source_scores(hidden, **artifact)
    idx, vals = _topk_components_per_token(scores, top_k=top_k)

    n_components = int(scores.shape[1])
    forced = sorted({int(component) for component in (highlight_components or [])})
    out_of_range = [component for component in forced if component < 0 or component >= n_components]
    if out_of_range:
        raise ValueError(f"highlight component out of range: {out_of_range[0]}")

    ids = input_ids[0].detach().cpu().tolist()
    idx_cpu = idx.detach().cpu().tolist()
    vals_cpu = vals.detach().cpu().tolist()
    forced_scores = scores[:, forced].detach().cpu().tolist() if forced else []

    tokens = []
    for pos, token_id in enumerate(ids):
        top = [
            {"component": int(component), "score": float(score)}
            for component, score in zip(idx_cpu[pos], vals_cpu[pos])
        ]
        seen = {item["component"] for item in top}
        for j, component in enumerate(forced):
            if component not in seen:
                top.append({"component": component, "score": float(forced_scores[pos][j]), "highlighted": True})
        token_item = {
            "position": pos,
            "token_id": int(token_id),
            "token": tokenizer.convert_ids_to_tokens(int(token_id)),
            "token_text": _decode_token_text(tokenizer, int(token_id)),
            "top": top,
        }
        if predictions is not None:
            token_item["prediction"] = predictions[pos]
        tokens.append(token_item)

    return {
        "layer": layer,
        "top_k": int(top_k),
        "max_length": int(max_length),
        "tokens": tokens,
        "seq_len": len(tokens),
        "truncated": truncated,
        "n_components": n_components,
        "predictions_available": predictions is not None,
    }
| def _load_fastica_artifact(path: Path) -> dict[str, torch.Tensor | float]: | |
| blob = torch.load(path, map_location="cpu") | |
| tensors = blob["tensors"] | |
| meta = blob.get("metadata") or {} | |
| mean = tensors["mean"].to(torch.float32) | |
| if mean.dim() == 1: | |
| mean = mean.unsqueeze(0) | |
| return { | |
| "mean": mean, | |
| "components": tensors["components"].to(torch.float32), | |
| "norm_eps": float(meta.get("norm_eps", 1e-12)), | |
| } | |
| def _all_source_scores( | |
| activations: torch.Tensor, | |
| *, | |
| mean: torch.Tensor, | |
| components: torch.Tensor, | |
| norm_eps: float, | |
| ) -> torch.Tensor: | |
| x = activations.to(dtype=torch.float32) | |
| normalized = x / torch.linalg.vector_norm(x, dim=1, keepdim=True).clamp_min(norm_eps) | |
| return (normalized - mean.to(x.device)) @ components.to(x.device).T | |
| def _topk_components_per_token(scores: torch.Tensor, *, top_k: int) -> tuple[torch.Tensor, torch.Tensor]: | |
| k = min(int(top_k), int(scores.shape[1])) | |
| if k <= 0: | |
| raise ValueError("top_k must be positive.") | |
| idx = torch.topk(scores.abs(), k, dim=1).indices | |
| row_idx = torch.arange(scores.shape[0], device=scores.device, dtype=torch.long).unsqueeze(1).expand_as(idx) | |
| return idx, scores[row_idx, idx] | |
| def _layer_key_to_hidden_index(layer: str, *, num_transformer_layers: int) -> int: | |
| if layer == "embedding": | |
| return 0 | |
| if not layer.startswith("layer_"): | |
| raise ValueError(f"Unknown layer key: {layer!r}") | |
| idx = int(layer.split("_", maxsplit=1)[1]) | |
| if idx < 0 or idx >= num_transformer_layers: | |
| raise ValueError(f"Layer index out of range: {layer!r}") | |
| return idx + 1 | |
| def _layer_sort_key(layer: str) -> tuple[int, int | str]: | |
| if layer == "embedding": | |
| return (0, 0) | |
| if layer.startswith("layer_"): | |
| return (1, int(layer.removeprefix("layer_"))) | |
| return (2, layer) | |
| def _decode_token_text(tokenizer: Any, token_id: int) -> str: | |
| try: | |
| decoded = str(tokenizer.decode([token_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)) | |
| except Exception: | |
| decoded = "" | |
| raw_token = str(tokenizer.convert_ids_to_tokens(token_id)) | |
| if "\ufffd" in decoded: | |
| byte_text = _decode_byte_level_token(raw_token) | |
| if byte_text is not None: | |
| return byte_text | |
| return decoded or raw_token | |
def _decode_byte_level_token(token: str) -> str | None:
    """Turn a byte-level vocabulary token back into readable text.

    Every character of ``token`` is mapped to a byte via
    ``_gpt2_byte_decoder``.

    Args:
        token: Raw vocabulary token.

    Returns:
        The UTF-8 decoding of the bytes; a ``\\xNN`` escape sequence (upper-case
        hex) when the bytes are not valid UTF-8; or ``None`` when some
        character has no byte mapping.
    """
    table = _gpt2_byte_decoder()
    collected = bytearray()
    for ch in token:
        value = table.get(ch)
        if value is None:
            # Not a byte-level token after all.
            return None
        collected.append(value)
    raw = bytes(collected)
    try:
        return raw.decode("utf-8")
    except UnicodeDecodeError:
        # A partial multi-byte sequence: show the bytes explicitly.
        return "".join(f"\\x{byte:02X}" for byte in raw)
| def _gpt2_byte_decoder() -> dict[str, int]: | |
| bs = ( | |
| list(range(ord("!"), ord("~") + 1)) | |
| + list(range(ord("¡"), ord("¬") + 1)) | |
| + list(range(ord("®"), ord("ÿ") + 1)) | |
| ) | |
| cs = bs[:] | |
| n = 0 | |
| for byte in range(256): | |
| if byte not in bs: | |
| bs.append(byte) | |
| cs.append(256 + n) | |
| n += 1 | |
| return {chr(char): byte for byte, char in zip(bs, cs, strict=True)} | |