Spaces:
Sleeping
Sleeping
Search
sync: update from main — H6 continuous bias, attention detection, auto-calibration, new scripts
78c132b | from __future__ import annotations | |
| from dataclasses import dataclass | |
| import math | |
| import re | |
| from collections.abc import Mapping | |
| from typing import Any | |
| import torch | |
| import torch.nn.functional as F | |
# Magnitudes at or below this value are treated as zero by the norm / energy /
# range guards in this module.
_EPS = 1e-8

# Prefixes that a tokenizer vocabulary puts on a raw token surface to signal
# that the token was preceded by whitespace.  They are written as escapes so
# the non-ASCII characters cannot be damaged by a re-encoding of the file.
# NOTE(review): the previous literals were mojibake (Greek letters, one of them
# duplicated) and could never match a real token surface.  The entries below
# are the inferred originals plus the remaining byte-level whitespace markers;
# confirm against the tokenizers actually in use.
_LEADING_WHITESPACE_MARKERS = (
    "\u0120",  # "Ġ": byte-level BPE (GPT-2 style) space
    "\u2581",  # "▁": SentencePiece word-boundary marker
    "\u010a",  # "Ċ": byte-level BPE newline
    "\u0109",  # "ĉ": byte-level BPE tab
    "\u010d",  # "č": byte-level BPE carriage return
)
@dataclass
class AnchorSpanMatch:
    """Location of an anchor phrase inside a tokenized prompt.

    Every matcher in this module builds instances with keyword arguments, so
    the class has to be a dataclass (a plain class with bare annotations has
    no such constructor).
    """

    # Phrase that was searched for ("" for attention-based detection).
    anchor_text: str
    # Index of the first token of the span.
    token_start: int
    # Index of the last token of the span (inclusive).
    token_end: int
    # Number of tokens in the span: token_end - token_start + 1.
    token_count: int
    # Character span of the phrase in the prompt text; None when the match was
    # not derived from character offsets.
    char_start: int | None
    char_end: int | None
    # Strategy that produced the match, e.g. "offset_mapping",
    # "token_ids_exact", "token_ids_prefixed_space", "decoded_pieces" or
    # "attention_mass".
    match_method: str
    # Text covered by the matched span as recovered by that strategy.
    matched_text: str
def select_representative_layers(
    num_hidden_layers: int,
    count: int = 4,
) -> list[int]:
    """Pick layer numbers spread through the depth of the model.

    Layers are taken at 25%, 50%, 75% and 96% of the depth, clamped to
    ``1 .. num_hidden_layers``.  When that yields fewer than ``count``
    distinct layers, evenly spaced layers from ``torch.linspace`` are merged
    in as well.

    Args:
        num_hidden_layers: depth of the model; must be positive.
        count: number of layers requested; must be positive.

    Returns:
        Sorted list of distinct layer numbers.  It can be shorter than
        ``count`` for shallow models and longer than ``count`` when ``count``
        exceeds the four built-in fractions.

    Raises:
        ValueError: if either argument is not positive.
    """
    if num_hidden_layers <= 0:
        raise ValueError("num_hidden_layers must be positive")
    if count <= 0:
        raise ValueError("count must be positive")

    picked: set[int] = set()
    for fraction in (0.25, 0.50, 0.75, 0.96)[:count]:
        candidate = int(round(num_hidden_layers * fraction))
        # Clamp into the valid 1 .. num_hidden_layers range.
        picked.add(min(max(candidate, 1), num_hidden_layers))

    if len(picked) >= count:
        return sorted(picked)[:count]

    # Not enough distinct fraction-based layers: merge in an even spread.
    steps = min(num_hidden_layers, count)
    even_spread = torch.linspace(1, num_hidden_layers, steps=steps).round().int().tolist()
    picked.update(int(layer) for layer in even_spread)
    return sorted(picked)
def list_model_layers(num_hidden_layers: int) -> list[int]:
    """Return the zero-based indices ``0 .. num_hidden_layers - 1``.

    Raises:
        ValueError: if ``num_hidden_layers`` is not positive.
    """
    if num_hidden_layers <= 0:
        raise ValueError("num_hidden_layers must be positive")
    layer_total = int(num_hidden_layers)
    return [*range(layer_total)]
def select_tail_probe_layers(
    num_hidden_layers: int,
    count: int = 10,
) -> list[int]:
    """Return the zero-based indices of the last ``count`` layers.

    Args:
        num_hidden_layers: depth of the model; must be positive.
        count: how many trailing layers to return; must be positive.  It is
            capped at ``num_hidden_layers``.

    Returns:
        Consecutive ascending indices ending at ``num_hidden_layers - 1``.

    Raises:
        ValueError: if either argument is not positive.
    """
    if num_hidden_layers <= 0:
        raise ValueError("num_hidden_layers must be positive")
    if count <= 0:
        raise ValueError("count must be positive")

    total = int(num_hidden_layers)
    tail_width = min(int(count), total)
    return list(range(total - tail_width, total))
def build_tail_reference_layers(
    probe_layers: list[int],
) -> dict[str, int]:
    """Derive the named reference layers from a list of probe layers.

    The probe layers are sorted first.  The "mature" layer is the fourth from
    the end, the "template" pair is the last two layers, and the slope runs
    from the first layer to the mature layer.  Positions that would fall
    before the start of a short list resolve to its first layer.

    Raises:
        ValueError: if ``probe_layers`` is empty.
    """
    if not probe_layers:
        raise ValueError("probe_layers must not be empty")

    ordered = sorted(int(layer) for layer in probe_layers)
    size = len(ordered)
    # Negative indexing capped at the list length is the same as
    # max(0, size - k) counted from the front.
    mature = ordered[-min(4, size)]
    return {
        "slope_start_layer": ordered[0],
        "slope_end_layer": mature,
        "mature_layer": mature,
        "template_prev_layer": ordered[-min(2, size)],
        "template_curr_layer": ordered[-1],
    }
def _normalize_text(text: str) -> str:
    """Lower-case ``text`` and collapse every whitespace run to one space."""
    words = text.lower().split()
    return " ".join(words)
def _find_unique_substring(text: str, anchor_text: str) -> tuple[int, int] | None:
    """Locate the first literal occurrence of ``anchor_text`` in ``text``.

    Despite the name, no uniqueness check is made: when the phrase occurs
    several times the first occurrence is reported.

    Returns:
        ``(start, end)`` character offsets with ``end`` exclusive, or ``None``
        when the phrase does not occur.
    """
    start = text.find(anchor_text)
    if start < 0:
        return None
    return start, start + len(anchor_text)
def _decode_anchor_tokens(
    tokenizer: Any,
    token_ids: list[int],
) -> str:
    """Decode ``token_ids`` to text, keeping special tokens.

    Best effort: any exception raised while decoding (or while converting the
    result to ``str``) yields an empty string instead of propagating.
    """
    try:
        decoded = str(tokenizer.decode(token_ids, skip_special_tokens=False))
    except Exception:
        decoded = ""
    return decoded
def decode_token_surfaces(
    tokenizer: Any,
    token_ids: list[int],
) -> list[str]:
    """Return the raw vocabulary surface of every token id.

    ``tokenizer.convert_ids_to_tokens`` is preferred when it exists, is
    callable and returns a list with one entry per id.  Otherwise (including
    when it raises) each id is decoded on its own.
    """
    converter = getattr(tokenizer, "convert_ids_to_tokens", None)
    if callable(converter):
        try:
            surfaces = converter(token_ids)
            one_per_id = isinstance(surfaces, list) and len(surfaces) == len(token_ids)
            if one_per_id:
                return [str(surface) for surface in surfaces]
        except Exception:
            # Best effort: fall through to per-token decoding.
            pass

    decoded: list[str] = []
    for token_id in token_ids:
        decoded.append(_decode_anchor_tokens(tokenizer, [int(token_id)]))
    return decoded
def decode_token_pieces(
    tokenizer: Any,
    token_ids: list[int],
) -> list[str]:
    """Decode every token id on its own and return the resulting strings.

    The result has one entry per id; an id that cannot be decoded contributes
    an empty string.
    """
    pieces: list[str] = []
    for token_id in token_ids:
        single = [int(token_id)]
        pieces.append(_decode_anchor_tokens(tokenizer, single))
    return pieces
def token_has_leading_whitespace(
    raw_surface: str,
    decoded_piece: str,
) -> bool:
    """Tell whether a token was preceded by whitespace.

    True when the decoded piece starts with a space, or when the raw
    vocabulary surface starts with one of ``_LEADING_WHITESPACE_MARKERS``.
    """
    return decoded_piece.startswith(" ") or raw_surface.startswith(
        _LEADING_WHITESPACE_MARKERS
    )
def _search_subsequence(
    full_ids: list[int],
    sub_ids: list[int],
) -> list[tuple[int, int]]:
    """Find every occurrence of ``sub_ids`` as a contiguous run in ``full_ids``.

    Returns:
        ``(start, end)`` index pairs with ``end`` inclusive, in ascending
        order; overlapping occurrences are all reported.  Empty when
        ``sub_ids`` is empty or longer than ``full_ids``.
    """
    width = len(sub_ids)
    if width == 0 or width > len(full_ids):
        return []
    last_offset = width - 1
    return [
        (start, start + last_offset)
        for start in range(len(full_ids) - last_offset)
        if full_ids[start : start + width] == sub_ids
    ]
def _match_from_offsets(
    *,
    text: str,
    anchor_text: str,
    offsets: list[tuple[int, int]],
) -> AnchorSpanMatch | None:
    """Map the anchor phrase to tokens through per-token character offsets.

    The phrase is located in ``text``; every token with a non-empty offset
    range that overlaps the phrase belongs to the span.

    Args:
        text: prompt text the offsets refer to.
        anchor_text: phrase to locate.
        offsets: one ``(start, end)`` character range per token.

    Returns:
        A match with ``match_method="offset_mapping"``, or ``None`` when the
        phrase is absent, no token overlaps it, or the text covered by the
        tokens does not contain the phrase after normalization.
    """
    located = _find_unique_substring(text, anchor_text)
    if located is None:
        return None
    char_start, char_end = located

    overlapping: list[int] = []
    for index, (begin, end) in enumerate(offsets):
        if end <= begin:
            # Zero-width entry: never part of a span.
            continue
        if begin >= char_end or end <= char_start:
            continue
        overlapping.append(index)
    if not overlapping:
        return None

    first, last = overlapping[0], overlapping[-1]
    covered = text[offsets[first][0] : offsets[last][1]]
    # Sanity check: the tokens picked must really spell out the phrase.
    if _normalize_text(anchor_text) not in _normalize_text(covered):
        return None
    return AnchorSpanMatch(
        anchor_text=anchor_text,
        token_start=first,
        token_end=last,
        token_count=last - first + 1,
        char_start=char_start,
        char_end=char_end,
        match_method="offset_mapping",
        matched_text=covered,
    )
def _match_from_token_ids(
    *,
    anchor_text: str,
    input_ids: list[int],
    tokenizer: Any,
) -> AnchorSpanMatch | None:
    """Locate the anchor by tokenizing it and searching for its ids.

    The phrase is tokenized twice, as given and with one leading space, and
    each id sequence is searched for in ``input_ids``.  A variant the
    tokenizer rejects is skipped.

    Returns:
        The first span recorded, labelled ``"token_ids_exact"`` or
        ``"token_ids_prefixed_space"`` (when both variants hit the same span
        the later label wins), or ``None`` when nothing matched.
    """

    def _phrase_ids(encoded: Any) -> Any:
        # Tokenizers return mappings, objects with an ``input_ids`` attribute,
        # or the bare id sequence; normalise all of them.
        if isinstance(encoded, Mapping):
            return encoded.get("input_ids")
        if hasattr(encoded, "input_ids"):
            return getattr(encoded, "input_ids")
        if hasattr(encoded, "__getitem__"):
            try:
                return encoded["input_ids"]
            except Exception:
                return encoded
        return encoded

    variants = (
        (anchor_text, "token_ids_exact"),
        (f" {anchor_text}", "token_ids_prefixed_space"),
    )
    spans: dict[tuple[int, int], str] = {}
    for variant, label in variants:
        try:
            encoded = tokenizer(variant, add_special_tokens=False)
        except Exception:
            continue
        raw_ids = _phrase_ids(encoded)
        if raw_ids is None:
            continue
        if isinstance(raw_ids, torch.Tensor):
            raw_ids = raw_ids.reshape(-1).tolist()
        phrase = [int(token) for token in raw_ids]
        for span in _search_subsequence(input_ids, phrase):
            spans[span] = label

    if not spans:
        return None

    # Dicts keep insertion order, so this is the earliest span recorded.
    (token_start, token_end), label = next(iter(spans.items()))
    return AnchorSpanMatch(
        anchor_text=anchor_text,
        token_start=token_start,
        token_end=token_end,
        token_count=token_end - token_start + 1,
        char_start=None,
        char_end=None,
        match_method=label,
        matched_text=_decode_anchor_tokens(tokenizer, input_ids[token_start : token_end + 1]),
    )
def _match_from_decoded_pieces(
    *,
    anchor_text: str,
    input_ids: list[int],
    tokenizer: Any,
) -> AnchorSpanMatch | None:
    """Fallback: decode each token individually, build a char-to-token map, search.

    The per-token decodings are concatenated and the anchor is searched for
    in that string case-insensitively; the first occurrence wins.

    Args:
        anchor_text: phrase to locate.
        input_ids: token ids of the prompt.
        tokenizer: object whose ``decode`` method is used per token.

    Returns:
        A match with ``match_method="decoded_pieces"`` (character offsets are
        left as ``None``), or ``None`` when the anchor is empty, nothing
        decodes, or the phrase is not found.
    """
    needle = anchor_text.lower()
    if not needle:
        # An empty needle "matches" at offset 0 with char_end - 1 == -1, which
        # used to index the last token and report the whole sequence.
        return None

    pieces = decode_token_pieces(tokenizer, input_ids)
    if not pieces:
        return None
    decoded_full = "".join(pieces)
    if not decoded_full:
        return None

    # One entry per decoded character: the index of the token it came from.
    char_to_token = [tok_idx for tok_idx, piece in enumerate(pieces) for _ in piece]

    # Case-insensitive search to tolerate tokenizer casing quirks.
    char_start = decoded_full.lower().find(needle)
    if char_start == -1:
        return None
    char_end = char_start + len(needle)
    # Lower-casing can change the length of some strings; refuse a hit that
    # falls outside the character map instead of indexing past it.
    if char_end > len(char_to_token):
        return None

    token_start = char_to_token[char_start]
    token_end = char_to_token[char_end - 1]
    return AnchorSpanMatch(
        anchor_text=anchor_text,
        token_start=token_start,
        token_end=token_end,
        token_count=token_end - token_start + 1,
        char_start=None,
        char_end=None,
        match_method="decoded_pieces",
        matched_text=decoded_full[char_start:char_end],
    )
def match_anchor_span(
    *,
    text: str,
    anchor_text: str,
    input_ids: list[int],
    tokenizer: Any,
    offsets: list[tuple[int, int]] | None = None,
) -> AnchorSpanMatch | None:
    """Locate ``anchor_text`` in a tokenized prompt, trying three strategies.

    Order of attempts: character offsets (only when ``offsets`` is given),
    token-id subsequence search, then per-token decoding.  The first strategy
    that produces a match wins.

    Returns:
        The match, or ``None`` when every strategy fails.
    """
    if offsets is not None:
        found = _match_from_offsets(
            text=text,
            anchor_text=anchor_text,
            offsets=offsets,
        )
        if found is not None:
            return found

    found = _match_from_token_ids(
        anchor_text=anchor_text,
        input_ids=input_ids,
        tokenizer=tokenizer,
    )
    if found is None:
        found = _match_from_decoded_pieces(
            anchor_text=anchor_text,
            input_ids=input_ids,
            tokenizer=tokenizer,
        )
    return found
def detect_anchor_span(
    attentions: tuple[torch.Tensor, ...],
    probe_layers: list[int] | None = None,
    *,
    min_width: int = 2,
    max_width: int = 8,
    skip_special: int = 1,
    use_last_n: int = 4,
) -> AnchorSpanMatch | None:
    """Detect the anchor span from the attention paid by the last token.

    For every selected attention entry the attention row of the last token is
    averaged over heads; the rows are then averaged over entries.  The
    contiguous window (``min_width`` to ``max_width`` tokens) with the largest
    summed score is reported.  Only the first batch element is inspected.

    Args:
        attentions: tuple of [batch, heads, seq, seq] tensors from the model
            output.  May be a subset of all layers.
        probe_layers: layer indices to try.  Layer ``i`` is read from tuple
            position ``i + 1``; positions outside the tuple are ignored.
            NOTE(review): the ``+ 1`` mirrors the usual hidden-state layout;
            confirm it is intended for the attention tuple.
        min_width: minimum span width in tokens.
        max_width: maximum span width in tokens.
        skip_special: number of leading tokens (BOS / special) to exclude.
        use_last_n: number of trailing attention entries to use when no probe
            layer maps into the tuple.

    Returns:
        AnchorSpanMatch with ``match_method="attention_mass"`` and empty text
        fields, or None when there are no attentions, the sequence is too
        short, or no attention entry is selected.
    """
    if not attentions:
        return None
    seq_len = attentions[0].size(-1)
    if seq_len < skip_special + min_width + 1:
        return None

    available = len(attentions)
    chosen = [
        layer + 1
        for layer in (probe_layers or [])
        if 0 <= layer + 1 < available
    ]
    if not chosen:
        # No probe layer maps into the tuple: use the trailing entries.
        chosen = list(range(max(0, available - use_last_n), available))
    if not chosen:
        return None

    mass = torch.zeros(seq_len)
    for position in chosen:
        per_head = attentions[position][0]  # [heads, seq, seq]
        mass += per_head[:, -1, :].mean(dim=0).detach().cpu()  # [seq]
    mass /= len(chosen)

    # Special tokens and the querying token itself cannot be the anchor.
    mass[:skip_special] = 0.0
    mass[-1] = 0.0

    top_total = -1.0
    span_start = skip_special
    span_width = min_width
    widest = min(max_width, seq_len - skip_special - 1)
    candidates = mass[skip_special : seq_len - 1]
    for width in range(min_width, widest + 1):
        if width > seq_len:
            break
        windows = candidates.unfold(0, width, 1)
        if windows.numel() == 0:
            continue
        totals = windows.sum(dim=-1)
        offset = int(totals.argmax().item())
        total = float(totals[offset].item())
        if total > top_total:
            top_total = total
            span_start = skip_special + offset
            span_width = width

    span_end = span_start + span_width - 1
    return AnchorSpanMatch(
        anchor_text="",
        token_start=span_start,
        token_end=span_end,
        token_count=span_end - span_start + 1,
        char_start=None,
        char_end=None,
        match_method="attention_mass",
        matched_text="",
    )
def extract_delta_vectors(
    hidden_states: torch.Tensor,
    token_start: int,
    token_end: int,
) -> torch.Tensor:
    """Return the differences between consecutive hidden states of a span.

    Args:
        hidden_states: tensor shaped [seq_len, hidden_dim].
        token_start: first token of the span.
        token_end: last token of the span (inclusive).

    Returns:
        float32 tensor shaped [span_len - 1, hidden_dim]; it has zero rows
        for a single-token span.

    Raises:
        ValueError: if ``hidden_states`` is not 2-D or the span does not lie
            inside the sequence.
    """
    if hidden_states.ndim != 2:
        raise ValueError("hidden_states must be shaped [seq_len, hidden_dim]")
    span_is_valid = 0 <= token_start <= token_end < hidden_states.size(0)
    if not span_is_valid:
        raise ValueError("invalid token span")

    window = hidden_states[token_start : token_end + 1].to(dtype=torch.float32)
    if window.size(0) < 2:
        hidden_dim = window.size(-1)
        return window.new_zeros((0, hidden_dim))
    following, preceding = window[1:], window[:-1]
    return following - preceding
def compute_geometry_metrics(
    delta_vectors: torch.Tensor,
) -> dict[str, float | int | None]:
    """Summarize the geometry of a trajectory given its step vectors.

    Args:
        delta_vectors: tensor shaped [delta_count, hidden_dim].

    Returns:
        Dict with ``token_count``, ``delta_count``, ``mean_direction_norm``,
        ``mean_step_norm``, ``adjacent_cosine_coherence``,
        ``path_tortuosity``, ``rank1_explained_variance`` and
        ``curvature_proxy``.  Every metric is ``None`` without steps; the
        last four are only attempted for spans of at least four tokens and
        stay ``None`` when they are degenerate.

    Raises:
        ValueError: if ``delta_vectors`` is not 2-D.
    """
    if delta_vectors.ndim != 2:
        raise ValueError("delta_vectors must be shaped [delta_count, hidden_dim]")

    delta_count = int(delta_vectors.size(0))
    token_count = delta_count + 1 if delta_count > 0 else 0
    metrics: dict[str, float | int | None] = {
        "token_count": token_count,
        "delta_count": delta_count,
        "mean_direction_norm": None,
        "mean_step_norm": None,
        "adjacent_cosine_coherence": None,
        "path_tortuosity": None,
        "rank1_explained_variance": None,
        "curvature_proxy": None,
    }
    if delta_count == 0:
        return metrics

    steps = delta_vectors.to(dtype=torch.float32)
    lengths = steps.norm(dim=-1)
    metrics["mean_direction_norm"] = float(steps.mean(dim=0).norm().item())
    metrics["mean_step_norm"] = float(lengths.mean().item())
    if token_count < 4:
        return metrics

    # Cosine between each pair of consecutive steps; pairs that contain a
    # (near-)zero step are left out.
    cosines: list[float] = []
    for prev_step, next_step in zip(steps[:-1], steps[1:]):
        if float(prev_step.norm().item()) <= _EPS or float(next_step.norm().item()) <= _EPS:
            continue
        cosine = F.cosine_similarity(prev_step.unsqueeze(0), next_step.unsqueeze(0), dim=-1)
        cosines.append(float(cosine.item()))
    if cosines:
        coherence = float(sum(cosines) / len(cosines))
        metrics["adjacent_cosine_coherence"] = coherence
        metrics["curvature_proxy"] = float(1.0 - coherence)

    # Path length divided by end-to-end displacement.
    net_norm = float(steps.sum(dim=0).norm().item())
    if net_norm > _EPS:
        metrics["path_tortuosity"] = float(lengths.sum().item() / net_norm)

    # Share of the squared singular values carried by the largest one.
    singular_values = torch.linalg.svdvals(steps)
    energy = float((singular_values.square().sum()).item())
    if energy > _EPS:
        metrics["rank1_explained_variance"] = float((singular_values[0].item() ** 2) / energy)
    else:
        # delta_count is at least 3 here, so the degenerate case is reported as 1.0.
        metrics["rank1_explained_variance"] = 1.0
    return metrics
def build_computability_flags(metrics: dict[str, float | int | None]) -> dict[str, bool]:
    """Report which geometry metrics are available in ``metrics``.

    ``span_mean_direction`` is true when ``delta_count`` is present and
    positive; every other flag is true when the metric of the same name is
    present and not ``None``.
    """
    flags: dict[str, bool] = {
        "span_mean_direction": int(metrics.get("delta_count") or 0) > 0,
    }
    metric_names = (
        "mean_direction_norm",
        "mean_step_norm",
        "adjacent_cosine_coherence",
        "path_tortuosity",
        "rank1_explained_variance",
        "curvature_proxy",
    )
    for metric_name in metric_names:
        flags[metric_name] = metrics.get(metric_name) is not None
    return flags
def compute_mean_direction(
    delta_vectors: torch.Tensor,
) -> torch.Tensor | None:
    """Average the step vectors into one float32 direction vector.

    Args:
        delta_vectors: tensor shaped [delta_count, hidden_dim].

    Returns:
        Tensor shaped [hidden_dim], or ``None`` when there are no steps.

    Raises:
        ValueError: if ``delta_vectors`` is not 2-D.
    """
    if delta_vectors.ndim != 2:
        raise ValueError("delta_vectors must be shaped [delta_count, hidden_dim]")
    has_steps = delta_vectors.size(0) != 0
    if not has_steps:
        return None
    as_float = delta_vectors.to(dtype=torch.float32)
    return as_float.mean(dim=0)
def compute_cross_prompt_stability(
    directions: list[torch.Tensor],
) -> dict[str, float | int | None]:
    """Measure how well direction vectors from different prompts agree.

    Directions with a (near-)zero norm are dropped, the rest are scaled to
    unit length, and the cosine similarity of every unordered pair is
    computed.

    Returns:
        Dict with ``pair_count`` and the mean, population standard deviation,
        minimum and maximum of the pairwise cosines.  With fewer than two
        usable directions ``pair_count`` is 0 and the statistics are ``None``.
    """
    units: list[torch.Tensor] = []
    for direction in directions:
        length = float(direction.norm().item())
        if length > _EPS:
            units.append(direction / length)

    if len(units) < 2:
        return {
            "pair_count": 0,
            "mean_pairwise_cosine": None,
            "std_pairwise_cosine": None,
            "min_pairwise_cosine": None,
            "max_pairwise_cosine": None,
        }

    cosines: list[float] = []
    for position, left in enumerate(units):
        for right in units[position + 1 :]:
            similarity = F.cosine_similarity(left.unsqueeze(0), right.unsqueeze(0), dim=-1)
            cosines.append(float(similarity.item()))

    mean_value = sum(cosines) / len(cosines)
    variance = sum((value - mean_value) ** 2 for value in cosines) / len(cosines)
    return {
        "pair_count": len(cosines),
        "mean_pairwise_cosine": float(mean_value),
        "std_pairwise_cosine": float(math.sqrt(max(variance, 0.0))),
        "min_pairwise_cosine": float(min(cosines)),
        "max_pairwise_cosine": float(max(cosines)),
    }
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Auto-calibration: k-means clustering for mature/template/flat thresholds | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def auto_calibrate_thresholds(
    r1_references: list[float],
    delta_templates: list[float],
    n_clusters: int = 3,
    max_iter: int = 50,
    seed: int = 42,
) -> dict[str, Any]:
    """Compute cluster thresholds from observed r1/delta values via k-means.

    Both features are min-max normalized, clustered with a seeded k-means,
    and the centers are mapped back to the original scale.  The center with
    the highest r1 is "mature", the remaining center with the highest delta
    is "template", and every other center is "flat".  Each threshold is the
    midpoint between the relevant center and a flat center.

    Args:
        r1_references: observed r1 values, one per sample.
        delta_templates: observed delta values, one per sample.
        n_clusters: requested number of clusters; capped at the sample count.
        max_iter: maximum number of k-means update rounds.
        seed: seed for the random choice of the initial centers.

    Returns:
        Dict with:
        - mature_r1_threshold: float
        - template_delta_threshold: float
        - cluster_centers: list of {"label", "r1", "delta"} dicts
        - cluster_labels: list of "mature"/"template"/"flat" per sample
        - n_samples: int
        - method: "kmeans_<k>", or "fallback_default" (fixed thresholds 0.65
          and 0.08, empty lists) when there are fewer than three samples, the
          inputs differ in length, or fewer than two clusters are possible.
    """
    import random

    n = len(r1_references)
    # Never ask for more clusters than there are samples.
    k = min(int(n_clusters), n)
    if n < 3 or len(delta_templates) != n or k < 2:
        return {
            "mature_r1_threshold": 0.65,
            "template_delta_threshold": 0.08,
            "cluster_centers": [],
            "cluster_labels": [],
            "n_samples": n,
            "method": "fallback_default",
        }

    rng = random.Random(seed)

    # Normalize features to the [0, 1] range for balanced clustering.
    r1_arr = list(r1_references)
    dt_arr = list(delta_templates)
    r1_min, r1_max = min(r1_arr), max(r1_arr)
    dt_min, dt_max = min(dt_arr), max(dt_arr)
    r1_range = max(r1_max - r1_min, _EPS)
    dt_range = max(dt_max - dt_min, _EPS)
    points = [
        ((r - r1_min) / r1_range, (d - dt_min) / dt_range)
        for r, d in zip(r1_arr, dt_arr)
    ]

    def assign(current: list[tuple[float, float]]) -> list[int]:
        # Index of the nearest center for every point (first one on ties).
        nearest: list[int] = []
        for p in points:
            dists = [(p[0] - c[0]) ** 2 + (p[1] - c[1]) ** 2 for c in current]
            nearest.append(min(range(len(dists)), key=dists.__getitem__))
        return nearest

    # k-means: seeded initialization, then alternate assign / update.
    centers = [points[i] for i in rng.sample(range(n), k)]
    for _ in range(max_iter):
        assignments = assign(centers)
        new_centers: list[tuple[float, float]] = []
        for cluster in range(k):
            members = [p for p, owner in zip(points, assignments) if owner == cluster]
            if not members:
                # An empty cluster keeps its previous center.
                new_centers.append(centers[cluster])
            else:
                new_centers.append((
                    sum(m[0] for m in members) / len(members),
                    sum(m[1] for m in members) / len(members),
                ))
        if new_centers == centers:
            break
        centers = new_centers
    # Label against the centers that are actually returned.  This also covers
    # max_iter <= 0 and a loop that stopped without converging.
    assignments = assign(centers)

    # Denormalize centers back to the original scale.
    real_centers = [
        (c[0] * r1_range + r1_min, c[1] * dt_range + dt_min)
        for c in centers
    ]

    # Highest r1 center = mature, highest remaining delta center = template,
    # rest = flat.
    mature_idx = max(range(k), key=lambda i: real_centers[i][0])
    remaining = [i for i in range(k) if i != mature_idx]
    template_idx = max(remaining, key=lambda i: real_centers[i][1])
    flat_candidates = [i for i in remaining if i != template_idx]
    # With two clusters the non-mature one doubles as the flat reference.
    flat_idx = flat_candidates[0] if flat_candidates else template_idx
    cluster_map = {mature_idx: "mature", template_idx: "template", flat_idx: "flat"}
    labels = [cluster_map.get(owner, "flat") for owner in assignments]

    # Thresholds are midpoints between cluster centers.
    mature_r1_threshold = (real_centers[mature_idx][0] + real_centers[flat_idx][0]) / 2.0
    template_delta_threshold = (real_centers[template_idx][1] + real_centers[flat_idx][1]) / 2.0

    return {
        "mature_r1_threshold": float(mature_r1_threshold),
        "template_delta_threshold": float(template_delta_threshold),
        "cluster_centers": [
            {"label": cluster_map.get(i, "flat"), "r1": center[0], "delta": center[1]}
            for i, center in enumerate(real_centers)
        ],
        "cluster_labels": labels,
        "n_samples": n,
        "method": f"kmeans_{k}",
    }