Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import sqlite3 | |
| from pathlib import Path | |
| from typing import Any, Iterable | |
# DDL for the tables and index this module creates itself. Every statement
# uses IF NOT EXISTS, so running the script against a database that already
# has these objects leaves them untouched.
#   annotations              - keyed by (model_name, layer, component).
#   chosen_random_components - keyed by (selection_name, model_name,
#                              selection_index), with an extra index on
#                              (model_name, layer, component).
SCHEMA = """
CREATE TABLE IF NOT EXISTS annotations (
model_name TEXT NOT NULL,
layer TEXT NOT NULL,
component INTEGER NOT NULL,
label TEXT,
confidence TEXT,
positive_label TEXT,
positive_confidence TEXT,
negative_label TEXT,
negative_confidence TEXT,
positive_interpretation_types_json TEXT,
negative_interpretation_types_json TEXT,
interpretation_types_json TEXT,
summary TEXT,
notes TEXT,
include_as_case_study INTEGER NOT NULL DEFAULT 0,
updated_at TEXT NOT NULL,
PRIMARY KEY (model_name, layer, component)
);
CREATE TABLE IF NOT EXISTS chosen_random_components (
selection_name TEXT NOT NULL,
model_name TEXT NOT NULL,
selection_index INTEGER NOT NULL,
layer TEXT NOT NULL,
component INTEGER NOT NULL,
component_id TEXT,
source_json TEXT,
seed INTEGER,
requested_n INTEGER,
inventory_size INTEGER,
selected_size INTEGER,
fit_converged INTEGER,
fit_iterations INTEGER,
fit_final_lim REAL,
fit_final_lim_p95 REAL,
fit_seed INTEGER,
inserted_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (selection_name, model_name, selection_index)
);
CREATE INDEX IF NOT EXISTS idx_chosen_random_components_component
ON chosen_random_components(model_name, layer, component);
"""
def connect(db_path: Path) -> sqlite3.Connection:
    """Open the SQLite database at ``db_path``, creating parent directories.

    The returned connection is opened with ``check_same_thread=False``,
    yields ``sqlite3.Row`` rows, and has foreign-key enforcement switched on.
    """
    parent_dir = db_path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(str(db_path), check_same_thread=False)
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA foreign_keys = ON")
    return connection
def init_db(conn: sqlite3.Connection) -> None:
    """Create the module's tables, add missing annotation columns, commit."""
    # Step 1: run the idempotent DDL script.
    conn.executescript(SCHEMA)
    # Step 2: upgrade an older ``annotations`` table in place.
    _ensure_annotation_columns(conn)
    # Step 3: persist whatever the two steps above changed.
    conn.commit()
def validate_db(conn: sqlite3.Connection, model_name: str | None = None) -> None:
    """Check that the database has the tables the rest of the module reads.

    Raises:
        RuntimeError: if any of ``components``, ``examples`` or
            ``annotations`` is missing, or if a truthy ``model_name`` is
            given and ``components`` has no rows for it.
    """
    table_rows = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'").fetchall()
    present = {str(record[0]) for record in table_rows}
    absent = sorted({"components", "examples", "annotations"}.difference(present))
    if absent:
        raise RuntimeError(f"SQLite DB is missing required table(s): {', '.join(absent)}")

    if not model_name:
        return

    n_components = conn.execute(
        "SELECT COUNT(*) FROM components WHERE model_name = ?",
        (model_name,),
    ).fetchone()[0]
    if int(n_components) == 0:
        raise RuntimeError(f"SQLite DB has no components for model_name={model_name!r}")
def list_models(conn: sqlite3.Connection) -> list[str]:
    """Return every distinct model name in ``components`` or ``annotations``.

    The names are de-duplicated and ordered by the database (UNION plus
    ORDER BY).
    """
    sql = """
        SELECT model_name FROM components
        UNION
        SELECT model_name FROM annotations
        ORDER BY model_name
    """
    names: list[str] = []
    for record in conn.execute(sql).fetchall():
        names.append(str(record[0]))
    return names
def list_layers(conn: sqlite3.Connection, model_name: str) -> list[str]:
    """Return the distinct layers of ``model_name``, ordered by ``layer_sort_key``."""
    cursor = conn.execute(
        "SELECT DISTINCT layer FROM components WHERE model_name = ?",
        (model_name,),
    )
    layers = [str(record[0]) for record in cursor.fetchall()]
    layers.sort(key=layer_sort_key)
    return layers
def list_components(
    conn: sqlite3.Connection,
    model_name: str,
    layer: str | None = None,
    component: int | None = None,
    search: str | None = None,
) -> list[dict[str, Any]]:
    """List components of a model together with their annotation fields.

    Args:
        conn: open database connection (rows are read by column name).
        model_name: model whose components are listed.
        layer: when not ``None``, restrict to this layer.
        component: when not ``None``, restrict to this component index.
        search: optional case-insensitive substring; a row is kept when the
            substring occurs in its component number, labels, summary or
            notes joined by spaces.

    Returns:
        One dict per matching component, sorted by layer then component.
    """
    erf_join, erf_select = _erf_join_sql(conn)
    sql = f"""
        SELECT c.layer,
               c.component,
               c.excess_kurtosis,
               {erf_select}
               a.positive_label,
               a.positive_confidence,
               a.negative_label,
               a.negative_confidence,
               a.positive_interpretation_types_json,
               a.negative_interpretation_types_json,
               a.summary,
               a.notes,
               a.include_as_case_study
        FROM components c
        LEFT JOIN annotations a
            ON a.model_name = c.model_name
            AND a.layer = c.layer
            AND a.component = c.component
        {erf_join}
        WHERE c.model_name = ?
            AND (? IS NULL OR c.layer = ?)
            AND (? IS NULL OR c.component = ?)
    """
    # Each optional filter is bound twice: once for the IS NULL test and once
    # for the comparison.
    params = (model_name, layer, layer, component, component)

    needle = str(search or "").strip().lower()
    searched_keys = ("component", "positive_label", "negative_label", "summary", "notes")
    matches: list[dict[str, Any]] = []
    for record in conn.execute(sql, params).fetchall():
        kurtosis = record["excess_kurtosis"]
        erf_mean = record["effective_context_mean"]
        entry: dict[str, Any] = {
            "layer": str(record["layer"]),
            "component": int(record["component"]),
            "excess_kurtosis": None if kurtosis is None else float(kurtosis),
            "effective_context_mean": None if erf_mean is None else float(erf_mean),
        }
        entry.update(_annotation_row(record))
        haystack = " ".join(str(entry.get(key, "")) for key in searched_keys).lower()
        if not needle or needle in haystack:
            matches.append(entry)

    matches.sort(key=lambda entry: (layer_sort_key(entry["layer"]), entry["component"]))
    return matches
def list_component_stats(conn: sqlite3.Connection, model_name: str) -> list[dict[str, Any]]:
    """Return per-component statistics and annotation fields for a model.

    Each dict holds ``layer``, ``effective_context_mean`` and everything
    ``_annotation_row`` produces; the list is sorted by layer then component.
    """
    erf_join, erf_select = _erf_join_sql(conn)
    sql = f"""
        SELECT c.layer,
               c.component,
               c.excess_kurtosis,
               {erf_select}
               a.positive_label,
               a.positive_confidence,
               a.negative_label,
               a.negative_confidence,
               a.positive_interpretation_types_json,
               a.negative_interpretation_types_json,
               a.notes
        FROM components c
        LEFT JOIN annotations a
            ON a.model_name = c.model_name
            AND a.layer = c.layer
            AND a.component = c.component
        {erf_join}
        WHERE c.model_name = ?
    """
    stats: list[dict[str, Any]] = []
    for record in conn.execute(sql, (model_name,)).fetchall():
        erf_mean = record["effective_context_mean"]
        entry: dict[str, Any] = {
            "layer": str(record["layer"]),
            "effective_context_mean": None if erf_mean is None else float(erf_mean),
        }
        entry.update(_annotation_row(record))
        stats.append(entry)

    stats.sort(key=lambda entry: (layer_sort_key(entry["layer"]), entry["component"]))
    return stats
def list_annotated_components(conn: sqlite3.Connection, model_name: str, layer: str) -> list[dict[str, Any]]:
    """Return annotations in one layer that carry a non-blank label.

    A row qualifies when its positive or negative label is non-NULL and not
    empty after trimming. Results are ordered by component.
    """
    erf_join, erf_select = _erf_join_sql(conn)
    sql = f"""
        SELECT a.component,
               {erf_select}
               a.positive_label, a.positive_confidence,
               a.negative_label, a.negative_confidence,
               a.positive_interpretation_types_json,
               a.negative_interpretation_types_json,
               a.notes,
               c.excess_kurtosis
        FROM annotations a
        LEFT JOIN components c
            ON c.model_name = a.model_name
            AND c.layer = a.layer
            AND c.component = a.component
        {erf_join}
        WHERE a.model_name = ? AND a.layer = ?
            AND (
                (a.positive_label IS NOT NULL AND TRIM(a.positive_label) != '')
                OR (a.negative_label IS NOT NULL AND TRIM(a.negative_label) != '')
            )
        ORDER BY a.component
    """
    annotated: list[dict[str, Any]] = []
    for record in conn.execute(sql, (model_name, layer)).fetchall():
        entry = _annotation_row(record)
        erf_mean = record["effective_context_mean"]
        entry["effective_context_mean"] = None if erf_mean is None else float(erf_mean)
        annotated.append(entry)
    return annotated
def list_component_metadata(
    conn: sqlite3.Connection,
    model_name: str,
    layer: str,
    components: list[int],
) -> list[dict[str, Any]]:
    """Return annotation and statistic fields for selected components of a layer.

    ``components`` is de-duplicated and cast to ``int`` before querying. An
    empty input returns ``[]`` without touching the database. Rows are
    ordered by component.
    """
    if not components:
        return []

    erf_join, erf_select = _erf_join_sql(conn)
    wanted = sorted({int(component) for component in components})
    # One "?" per requested component for the IN (...) clause.
    placeholders = ",".join(["?"] * len(wanted))
    sql = f"""
        SELECT c.component,
               c.excess_kurtosis,
               {erf_select}
               a.positive_label,
               a.positive_confidence,
               a.negative_label,
               a.negative_confidence,
               a.positive_interpretation_types_json,
               a.negative_interpretation_types_json,
               a.summary,
               a.notes,
               a.include_as_case_study
        FROM components c
        LEFT JOIN annotations a
            ON a.model_name = c.model_name
            AND a.layer = c.layer
            AND a.component = c.component
        {erf_join}
        WHERE c.model_name = ? AND c.layer = ? AND c.component IN ({placeholders})
        ORDER BY c.component
    """
    metadata: list[dict[str, Any]] = []
    for record in conn.execute(sql, (model_name, layer, *wanted)).fetchall():
        entry = _annotation_row(record)
        erf_mean = record["effective_context_mean"]
        entry["effective_context_mean"] = None if erf_mean is None else float(erf_mean)
        metadata.append(entry)
    return metadata
def list_component_neighbors(
    conn: sqlite3.Connection,
    model_name: str,
    layer: str,
    component: int,
) -> list[dict[str, Any]]:
    """Return the neighbour rows recorded for one component.

    Returns ``[]`` when the ``component_neighbors`` table does not exist.
    When the table has no ``neighbor_sign`` column, the sign is derived from
    ``signed_cosine`` (1 when >= 0, else -1). Rows with direction ``prev``
    come first, then ``next``, then anything else. Annotation fields in each
    dict describe the neighbour, not the source component.
    """
    if not _table_exists(conn, "component_neighbors"):
        return []

    if "neighbor_sign" in _table_columns(conn, "component_neighbors"):
        neighbor_sign_select = "n.neighbor_sign"
    else:
        neighbor_sign_select = "CASE WHEN n.signed_cosine >= 0 THEN 1 ELSE -1 END"
    erf_join, erf_select = _component_neighbor_erf_join_sql(conn)

    sql = f"""
        SELECT n.direction,
               n.neighbor_layer,
               n.neighbor_component AS component,
               n.abs_cosine,
               n.signed_cosine,
               {neighbor_sign_select} AS neighbor_sign,
               n.source_n_components,
               n.neighbor_n_components,
               {erf_select}
               a.positive_label,
               a.positive_confidence,
               a.negative_label,
               a.negative_confidence,
               a.positive_interpretation_types_json,
               a.negative_interpretation_types_json,
               a.summary,
               a.notes,
               a.include_as_case_study,
               c.excess_kurtosis
        FROM component_neighbors n
        LEFT JOIN annotations a
            ON a.model_name = n.model_name
            AND a.layer = n.neighbor_layer
            AND a.component = n.neighbor_component
        LEFT JOIN components c
            ON c.model_name = n.model_name
            AND c.layer = n.neighbor_layer
            AND c.component = n.neighbor_component
        {erf_join}
        WHERE n.model_name = ?
            AND n.layer = ?
            AND n.component = ?
        ORDER BY CASE n.direction WHEN 'prev' THEN 0 WHEN 'next' THEN 1 ELSE 2 END
    """
    neighbors: list[dict[str, Any]] = []
    for record in conn.execute(sql, (model_name, layer, int(component))).fetchall():
        entry = _annotation_row(record)
        entry["direction"] = str(record["direction"])
        entry["neighbor_layer"] = str(record["neighbor_layer"])
        # The query aliases neighbor_component to "component".
        entry["neighbor_component"] = int(record["component"])
        entry["abs_cosine"] = float(record["abs_cosine"])
        entry["signed_cosine"] = float(record["signed_cosine"])
        entry["neighbor_sign"] = int(record["neighbor_sign"])
        entry["source_n_components"] = int(record["source_n_components"])
        entry["neighbor_n_components"] = int(record["neighbor_n_components"])
        erf_mean = record["effective_context_mean"]
        entry["effective_context_mean"] = None if erf_mean is None else float(erf_mean)
        neighbors.append(entry)
    return neighbors
def list_chosen_random_components(
    conn: sqlite3.Connection,
    model_name: str | None = None,
    selection_name: str | None = None,
) -> list[dict[str, Any]]:
    """Return rows of ``chosen_random_components`` joined with annotations.

    Returns ``[]`` when the table does not exist. ``model_name`` and
    ``selection_name`` each filter the result when not ``None``. Rows are
    ordered by model, selection name, then selection index. Nullable numeric
    columns come back as ``None`` when NULL.
    """
    if not _table_exists(conn, "chosen_random_components"):
        return []

    def optional(cast: Any, value: Any) -> Any:
        """Apply ``cast`` unless the stored value is NULL."""
        return None if value is None else cast(value)

    erf_join, erf_select = _erf_join_sql(conn)
    sql = f"""
        SELECT r.selection_name,
               r.model_name,
               r.selection_index,
               r.layer,
               r.component,
               r.component_id,
               r.source_json,
               r.seed,
               r.requested_n,
               r.inventory_size,
               r.selected_size,
               r.fit_converged,
               r.fit_iterations,
               r.fit_final_lim,
               r.fit_final_lim_p95,
               r.fit_seed,
               c.excess_kurtosis,
               {erf_select}
               a.positive_label,
               a.positive_confidence,
               a.negative_label,
               a.negative_confidence,
               a.positive_interpretation_types_json,
               a.negative_interpretation_types_json,
               a.summary,
               a.notes,
               a.include_as_case_study
        FROM chosen_random_components r
        LEFT JOIN components c
            ON c.model_name = r.model_name
            AND c.layer = r.layer
            AND c.component = r.component
        LEFT JOIN annotations a
            ON a.model_name = r.model_name
            AND a.layer = r.layer
            AND a.component = r.component
        {erf_join}
        WHERE (? IS NULL OR r.model_name = ?)
            AND (? IS NULL OR r.selection_name = ?)
        ORDER BY r.model_name, r.selection_name, r.selection_index
    """
    params = (model_name, model_name, selection_name, selection_name)

    chosen: list[dict[str, Any]] = []
    for record in conn.execute(sql, params).fetchall():
        entry = _annotation_row(record)
        entry["selection_name"] = str(record["selection_name"])
        entry["model_name"] = str(record["model_name"])
        entry["selection_index"] = int(record["selection_index"])
        entry["layer"] = str(record["layer"])
        entry["component"] = int(record["component"])
        entry["component_id"] = str(record["component_id"] or "")
        entry["source_json"] = str(record["source_json"] or "")
        entry["seed"] = optional(int, record["seed"])
        entry["requested_n"] = optional(int, record["requested_n"])
        entry["inventory_size"] = optional(int, record["inventory_size"])
        entry["selected_size"] = optional(int, record["selected_size"])
        entry["fit_converged"] = optional(bool, record["fit_converged"])
        entry["fit_iterations"] = optional(int, record["fit_iterations"])
        entry["fit_final_lim"] = optional(float, record["fit_final_lim"])
        entry["fit_final_lim_p95"] = optional(float, record["fit_final_lim_p95"])
        entry["fit_seed"] = optional(int, record["fit_seed"])
        entry["effective_context_mean"] = optional(float, record["effective_context_mean"])
        chosen.append(entry)
    return chosen
def list_component_examples(
    conn: sqlite3.Connection,
    model_name: str,
    layer: str | None = None,
    component: int | None = None,
    *,
    region: str | None = None,
    limit: int | None = None,
) -> dict[tuple[str, int], list[dict[str, Any]]]:
    """Group stored examples of a model by ``(layer, component)``.

    Args:
        conn: open database connection (rows are read by column name).
        model_name: model whose examples are read.
        layer: optional layer filter.
        component: optional component filter.
        region: optional region filter.
        limit: when given, only rows with ``rank <= limit`` are kept. The
            bound applies to the stored rank, not to the number of rows.

    Returns:
        Mapping from ``(layer, component)`` to its examples, in
        layer/component/region/rank order.
    """
    sql = """
        SELECT layer,
               component,
               region,
               rank,
               row_index,
               doc_id,
               token_id,
               token,
               source_score,
               direction_cosine,
               position,
               context_to_target,
               context,
               context_score_max_abs
        FROM examples
        WHERE model_name = ?
            AND (? IS NULL OR layer = ?)
            AND (? IS NULL OR component = ?)
            AND (? IS NULL OR region = ?)
            AND (? IS NULL OR rank <= ?)
        ORDER BY layer, component, region, rank
    """
    params = (model_name, layer, layer, component, component, region, region, limit, limit)

    grouped: dict[tuple[str, int], list[dict[str, Any]]] = {}
    for record in conn.execute(sql, params).fetchall():
        group_key = (str(record["layer"]), int(record["component"]))
        bucket = grouped.get(group_key)
        if bucket is None:
            bucket = grouped[group_key] = []
        bucket.append(_example_row(record))
    return grouped
def list_component_example_details(
    conn: sqlite3.Connection,
    model_name: str,
    *,
    layer: str,
    component: int,
) -> list[dict[str, Any]]:
    """Return one component's examples with their per-token context scores.

    Examples are ordered by region then rank. Each gets a
    ``context_token_scores`` list filled from ``context_tokens`` in ``seq``
    order; it stays empty when no token rows reference the example.
    """
    example_sql = """
        SELECT id,
               region,
               rank,
               row_index,
               doc_id,
               token_id,
               token,
               source_score,
               direction_cosine,
               position,
               context_to_target,
               context,
               context_score_max_abs
        FROM examples
        WHERE model_name = ? AND layer = ? AND component = ?
        ORDER BY region, rank
    """
    details: list[dict[str, Any]] = []
    detail_by_id: dict[int, dict[str, Any]] = {}
    for record in conn.execute(example_sql, (model_name, layer, component)).fetchall():
        detail = _example_row(record)
        detail["context_token_scores"] = []
        details.append(detail)
        detail_by_id[int(record["id"])] = detail

    # Token rows are fetched in batches of 500 ids per IN (...) clause.
    for id_batch in _chunks(list(detail_by_id), 500):
        if not id_batch:
            continue
        placeholders = ",".join(["?"] * len(id_batch))
        token_sql = f"""
            SELECT example_id, seq, token_position, token, source_score, direction_cosine, is_target
            FROM context_tokens
            WHERE example_id IN ({placeholders})
            ORDER BY example_id, seq
        """
        for token_record in conn.execute(token_sql, tuple(id_batch)).fetchall():
            owner = detail_by_id.get(int(token_record["example_id"]))
            if owner is None:
                continue
            position = token_record["token_position"]
            score = token_record["source_score"]
            cosine = token_record["direction_cosine"]
            owner["context_token_scores"].append(
                {
                    "position": None if position is None else int(position),
                    "token": str(token_record["token"] or ""),
                    "source_score": None if score is None else float(score),
                    "direction_cosine": None if cosine is None else float(cosine),
                    "is_target": bool(token_record["is_target"]),
                }
            )
    return details
def get_component_row(conn: sqlite3.Connection, model_name: str, layer: str, component: int) -> dict[str, Any] | None:
    """Return the ``list_components`` entry for one component, or ``None``."""
    matches = list_components(conn, model_name, layer=layer, component=component)
    if not matches:
        return None
    return matches[0]
def get_annotation(conn: sqlite3.Connection, model_name: str, layer: str, component: int) -> dict[str, Any] | None:
    """Fetch the raw ``annotations`` row for one component as a dict.

    Returns ``None`` when no row exists. The connection must produce
    mapping-style rows (e.g. ``sqlite3.Row``) for the dict conversion.
    """
    sql = """
        SELECT * FROM annotations
        WHERE model_name = ? AND layer = ? AND component = ?
    """
    record = conn.execute(sql, (model_name, layer, component)).fetchone()
    if record:
        return dict(record)
    return None
def update_annotation(
    conn: sqlite3.Connection,
    *,
    model_name: str,
    layer: str,
    component: int,
    positive_label: str,
    positive_confidence: str,
    positive_interpretation_types: list[str],
    negative_label: str,
    negative_confidence: str,
    negative_interpretation_types: list[str],
    summary: str,
    notes: str,
    include_as_case_study: bool,
) -> None:
    """Insert or overwrite the annotation of one component and commit.

    Confidences go through ``_normalize_confidence``; interpretation types
    are stringified, emptied of blank entries and stored as JSON arrays;
    ``updated_at`` is set to the database's current time on both insert and
    update.
    """
    positive_types_json = json.dumps([str(item) for item in positive_interpretation_types if str(item)])
    negative_types_json = json.dumps([str(item) for item in negative_interpretation_types if str(item)])
    case_study_flag = 1 if include_as_case_study else 0

    upsert_sql = """
        INSERT INTO annotations(
            model_name, layer, component,
            positive_label, positive_confidence, positive_interpretation_types_json,
            negative_label, negative_confidence, negative_interpretation_types_json,
            summary, notes, include_as_case_study, updated_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
        ON CONFLICT(model_name, layer, component) DO UPDATE SET
            positive_label = excluded.positive_label,
            positive_confidence = excluded.positive_confidence,
            positive_interpretation_types_json = excluded.positive_interpretation_types_json,
            negative_label = excluded.negative_label,
            negative_confidence = excluded.negative_confidence,
            negative_interpretation_types_json = excluded.negative_interpretation_types_json,
            summary = excluded.summary,
            notes = excluded.notes,
            include_as_case_study = excluded.include_as_case_study,
            updated_at = datetime('now')
    """
    params = (
        model_name,
        layer,
        int(component),
        positive_label,
        _normalize_confidence(positive_confidence),
        positive_types_json,
        negative_label,
        _normalize_confidence(negative_confidence),
        negative_types_json,
        summary,
        notes,
        case_study_flag,
    )
    conn.execute(upsert_sql, params)
    conn.commit()
def infer_default_annotation_sign(conn: sqlite3.Connection, model_name: str, layer: str, component: int) -> int:
    """Return ``1`` or ``-1`` from the sign of the strongest top-abs example.

    Only regions ``top_abs``, ``top_abs_sample_500`` and
    ``top_abs_sample_5000`` are considered. The example with the largest
    absolute ``source_score`` (ties broken by lowest rank) decides the sign.
    Returns ``1`` when no such example exists or its score is NULL.
    """
    sql = """
        SELECT source_score
        FROM examples
        WHERE model_name = ? AND layer = ? AND component = ?
            AND region IN ('top_abs', 'top_abs_sample_500', 'top_abs_sample_5000')
        ORDER BY ABS(source_score) DESC, rank ASC
        LIMIT 1
    """
    strongest = conn.execute(sql, (model_name, layer, component)).fetchone()
    if strongest is None:
        return 1
    score = strongest["source_score"]
    if score is None:
        return 1
    if float(score) >= 0:
        return 1
    return -1
def get_examples_by_region(
    conn: sqlite3.Connection,
    model_name: str,
    layer: str,
    component: int,
) -> tuple[list[str], dict[str, list[dict[str, Any]]]]:
    """Group a component's detailed examples by region.

    Returns:
        ``(regions, by_region)`` where ``regions`` is the list of region
        names ordered by ``_example_band_sort_key`` and ``by_region`` maps
        each name to its examples. A falsy region is filed under
        ``"examples"``.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for example in list_component_example_details(conn, model_name, layer=layer, component=component):
        region_name = str(example["region"] or "examples")
        if region_name not in grouped:
            grouped[region_name] = []
        grouped[region_name].append(example)

    ordered_regions = sorted(grouped, key=_example_band_sort_key)
    return ordered_regions, grouped
def pick_default_region(regions: list[str], examples_by_region: dict[str, list[dict[str, Any]]]) -> str:
    """Choose which region to show first.

    Prefers ``top_abs``, then ``top_abs_sample_500``, then
    ``top_abs_sample_5000`` when present in ``examples_by_region``; otherwise
    the first entry of ``regions``, or ``""`` when ``regions`` is empty.
    """
    preferred = ("top_abs", "top_abs_sample_500", "top_abs_sample_5000")
    found = next((name for name in preferred if name in examples_by_region), None)
    if found is not None:
        return found
    if regions:
        return regions[0]
    return ""
def search_components(
    conn: sqlite3.Connection,
    *,
    model_name: str | None,
    query: str,
    confidence: str,
    annotation_type: str,
    include_examples: bool,
    limit: int,
) -> list[dict[str, Any]]:
    """Search annotated components across one or all models.

    Args:
        conn: open database connection.
        model_name: model to search; a falsy value searches every model
            returned by ``list_models``.
        query: case-insensitive substring looked up in the labels, summary
            and notes; blank matches everything.
        confidence: when non-empty, must equal the positive or the negative
            confidence of the component.
        annotation_type: when non-empty, must appear among the positive or
            negative interpretation types.
        include_examples: when true, attach the examples with ``rank <= 3``
            under the ``"examples"`` key.
        limit: maximum number of results. A value of zero or less yields an
            empty list.

    Returns:
        Up to ``limit`` component dicts, each extended with ``model_name``.
    """
    # The size check below runs only after an append, so without this guard a
    # non-positive limit would still return one match.
    if limit <= 0:
        return []

    models = [model_name] if model_name else list_models(conn)
    needle = query.strip().lower()
    results: list[dict[str, Any]] = []
    for model in models:
        for item in list_components(conn, model):
            searched = [
                item.get("positive_label", ""),
                item.get("negative_label", ""),
                item.get("summary", ""),
                item.get("notes", ""),
            ]
            haystack = " ".join(str(value) for value in searched).lower()
            if needle and needle not in haystack:
                continue
            if confidence and confidence not in {item.get("positive_confidence"), item.get("negative_confidence")}:
                continue
            if annotation_type and annotation_type not in set(item.get("positive_types", []) + item.get("negative_types", [])):
                continue

            row = {"model_name": model, **item}
            if include_examples:
                key = (item["layer"], item["component"])
                by_component = list_component_examples(
                    conn,
                    model,
                    layer=item["layer"],
                    component=item["component"],
                    limit=3,
                )
                row["examples"] = by_component.get(key, [])
            results.append(row)
            if len(results) >= limit:
                return results
    return results
def layer_sort_key(layer: str) -> tuple[int, int | str]:
    """Return a sort key that orders layer names naturally.

    ``"embedding"`` sorts first, then ``"layer_<int>"`` names by their
    numeric suffix, then every other name alphabetically. A ``layer_`` name
    whose suffix is not an integer (for example ``"layer_final"``) is
    treated as an "other" name instead of raising ``ValueError``.
    """
    if layer == "embedding":
        return (0, 0)
    if layer.startswith("layer_"):
        suffix = layer.removeprefix("layer_")
        try:
            return (1, int(suffix))
        except ValueError:
            # Non-numeric suffix: fall through to the generic bucket so one
            # odd layer name cannot break sorting for the whole model.
            pass
    return (2, layer)
def _build_erf_join_sql(
    conn: sqlite3.Connection,
    *,
    source_alias: str,
    layer_column: str,
    component_column: str,
) -> tuple[str, str]:
    """Build the optional effective-receptive-field JOIN and SELECT fragments.

    Looks for the first existing table among ``effective_receptive_fields``
    and ``effective_context_lengths`` and, inside it, the first known
    mean-value column. The join matches ``erf`` rows against
    ``<source_alias>.model_name``, ``<source_alias>.<layer_column>`` and
    ``<source_alias>.<component_column>``.

    Returns:
        ``(join_sql, select_sql)``. When no table or value column is found,
        ``join_sql`` is empty and ``select_sql`` selects NULL, so callers can
        always read ``effective_context_mean``. ``select_sql`` ends with a
        comma because callers splice it before further columns.
    """
    no_erf = ("", "NULL AS effective_context_mean,")
    table = _first_existing_table(conn, ["effective_receptive_fields", "effective_context_lengths"])
    if table is None:
        return no_erf
    columns = _table_columns(conn, table)
    value_column = _first_existing_name(
        columns,
        ["mean_erf", "erf_mean", "mean_length", "effective_receptive_field_mean", "effective_context_mean"],
    )
    if value_column is None:
        return no_erf
    quoted_table = _quote_identifier(table)
    quoted_value = _quote_identifier(value_column)
    return (
        f"""
        LEFT JOIN {quoted_table} erf
            ON erf.model_name = {source_alias}.model_name
            AND erf.layer = {source_alias}.{layer_column}
            AND erf.component = {source_alias}.{component_column}
        """,
        f"erf.{quoted_value} AS effective_context_mean,",
    )


def _erf_join_sql(conn: sqlite3.Connection) -> tuple[str, str]:
    """Return ERF join/select fragments keyed on the ``c`` (components) alias."""
    return _build_erf_join_sql(
        conn,
        source_alias="c",
        layer_column="layer",
        component_column="component",
    )


def _component_neighbor_erf_join_sql(conn: sqlite3.Connection) -> tuple[str, str]:
    """Return ERF join/select fragments keyed on the neighbour columns of ``n``."""
    return _build_erf_join_sql(
        conn,
        source_alias="n",
        layer_column="neighbor_layer",
        component_column="neighbor_component",
    )
def _first_existing_table(conn: sqlite3.Connection, names: list[str]) -> str | None:
    """Return the first name in ``names`` that is an existing table, else ``None``."""
    return next((name for name in names if _table_exists(conn, name)), None)
def _table_columns(conn: sqlite3.Connection, table_name: str) -> set[str]:
    """Return the column names of ``table_name`` (empty set if it has none).

    The name is read by position (index 1 of each ``PRAGMA table_info`` row)
    rather than by key, so the helper works whether or not the connection
    uses ``sqlite3.Row`` as its row factory.
    """
    pragma = f"PRAGMA table_info({_quote_identifier(table_name)})"
    return {str(row[1]) for row in conn.execute(pragma).fetchall()}
| def _first_existing_name(names: set[str], candidates: list[str]) -> str | None: | |
| for candidate in candidates: | |
| if candidate in names: | |
| return candidate | |
| return None | |
| def _quote_identifier(name: str) -> str: | |
| return '"' + str(name).replace('"', '""') + '"' | |
| def _table_exists(conn: sqlite3.Connection, table_name: str) -> bool: | |
| return conn.execute("SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?", (table_name,)).fetchone() is not None | |
| def _ensure_annotation_columns(conn: sqlite3.Connection) -> None: | |
| columns = {str(row["name"]) for row in conn.execute("PRAGMA table_info(annotations)").fetchall()} | |
| required = { | |
| "label": "TEXT", | |
| "confidence": "TEXT", | |
| "positive_label": "TEXT", | |
| "positive_confidence": "TEXT", | |
| "negative_label": "TEXT", | |
| "negative_confidence": "TEXT", | |
| "positive_interpretation_types_json": "TEXT", | |
| "negative_interpretation_types_json": "TEXT", | |
| "interpretation_types_json": "TEXT", | |
| "summary": "TEXT", | |
| "notes": "TEXT", | |
| "include_as_case_study": "INTEGER NOT NULL DEFAULT 0", | |
| "updated_at": "TEXT", | |
| } | |
| for column, ddl in required.items(): | |
| if column not in columns: | |
| conn.execute(f"ALTER TABLE annotations ADD COLUMN {column} {ddl}") | |
def _annotation_row(row: sqlite3.Row | dict[str, Any]) -> dict[str, Any]:
    """Normalise the annotation-related columns of a row into a plain dict.

    Columns absent from ``row`` are treated as ``None`` (except
    ``component``, which must be present and convertible to ``int``).
    Labels, summary and notes become strings, confidences are normalised,
    the JSON type columns become lists, and ``auto_annotated`` is true when
    the notes contain ``"auto-annotation"`` (case-insensitive).
    """
    available = row.keys() if hasattr(row, "keys") else row

    def field(name: str, default: Any = None) -> Any:
        if name in available:
            return row[name]
        return default

    positive_label = str(field("positive_label") or "")
    negative_label = str(field("negative_label") or "")
    auto_annotated = "auto-annotation" in str(field("notes") or "").lower()

    result: dict[str, Any] = {}
    result["component"] = int(field("component"))
    result["positive_label"] = positive_label
    result["positive_confidence"] = _normalize_confidence(field("positive_confidence"))
    result["negative_label"] = negative_label
    result["negative_confidence"] = _normalize_confidence(field("negative_confidence"))
    result["positive_types"] = _json_list(field("positive_interpretation_types_json"))
    result["negative_types"] = _json_list(field("negative_interpretation_types_json"))
    result["summary"] = str(field("summary") or "")
    result["notes"] = str(field("notes") or "")
    result["include_as_case_study"] = bool(field("include_as_case_study") or 0)
    kurtosis = field("excess_kurtosis")
    result["excess_kurtosis"] = None if kurtosis is None else float(kurtosis)
    result["auto_annotated"] = auto_annotated
    return result
| def _example_row(row: sqlite3.Row) -> dict[str, Any]: | |
| return { | |
| "region": str(row["region"] or ""), | |
| "rank": int(row["rank"]), | |
| "row_index": int(row["row_index"]) if "row_index" in row.keys() and row["row_index"] is not None else None, | |
| "doc_id": int(row["doc_id"]) if "doc_id" in row.keys() and row["doc_id"] is not None else None, | |
| "token_id": int(row["token_id"]) if "token_id" in row.keys() and row["token_id"] is not None else None, | |
| "token": str(row["token"] or ""), | |
| "source_score": float(row["source_score"]) if row["source_score"] is not None else None, | |
| "direction_cosine": float(row["direction_cosine"]) if row["direction_cosine"] is not None else None, | |
| "position": int(row["position"]) if row["position"] is not None else None, | |
| "context_to_target": str(row["context_to_target"] or ""), | |
| "context": str(row["context"] or ""), | |
| "context_score_max_abs": float(row["context_score_max_abs"]) if row["context_score_max_abs"] is not None else None, | |
| } | |
| def _json_list(value: Any) -> list[str]: | |
| try: | |
| parsed = json.loads(value or "[]") | |
| except json.JSONDecodeError: | |
| return [] | |
| return [str(item) for item in parsed] if isinstance(parsed, list) else [] | |
| def _normalize_confidence(value: Any) -> str: | |
| text = str(value or "unclear").strip().lower() | |
| return text if text in {"high", "medium", "low", "unclear"} else "unclear" | |
| def _chunks(items: list[int], size: int) -> Iterable[list[int]]: | |
| for index in range(0, len(items), size): | |
| yield items[index : index + size] | |
| def _example_band_sort_key(region: str) -> tuple[int, str]: | |
| name = str(region or "").lower().replace("-", "_") | |
| compact = name.replace("_", "") | |
| if name == "top_abs" or compact == "topabs": | |
| return (0, name) | |
| if "sample" in name: | |
| return (1, name) | |
| if "near_zero" in name or "nearzero" in compact: | |
| return (2, name) | |
| if "opposite" in name: | |
| return (3, name) | |
| return (4, name) | |