# ICAExplorer / server/store.py
# Commit 34d520a by sida: "Deploy ICA explorer app"
from __future__ import annotations
import json
import sqlite3
from pathlib import Path
from typing import Any, Iterable
# DDL applied by init_db() via executescript(). Every statement uses IF NOT EXISTS,
# so re-running it against an existing database is harmless.
# - annotations: one row per (model_name, layer, component). The legacy columns
#   label / confidence / interpretation_types_json are declared but are not written
#   by update_annotation() in this module.
# - chosen_random_components: ordered random selections of components, keyed by
#   (selection_name, model_name, selection_index), plus an index for looking rows up
#   by (model_name, layer, component).
# NOTE: the "components" and "examples" tables that validate_db() requires are NOT
# created here; they are expected to be present in the database already.
SCHEMA = """
CREATE TABLE IF NOT EXISTS annotations (
model_name TEXT NOT NULL,
layer TEXT NOT NULL,
component INTEGER NOT NULL,
label TEXT,
confidence TEXT,
positive_label TEXT,
positive_confidence TEXT,
negative_label TEXT,
negative_confidence TEXT,
positive_interpretation_types_json TEXT,
negative_interpretation_types_json TEXT,
interpretation_types_json TEXT,
summary TEXT,
notes TEXT,
include_as_case_study INTEGER NOT NULL DEFAULT 0,
updated_at TEXT NOT NULL,
PRIMARY KEY (model_name, layer, component)
);
CREATE TABLE IF NOT EXISTS chosen_random_components (
selection_name TEXT NOT NULL,
model_name TEXT NOT NULL,
selection_index INTEGER NOT NULL,
layer TEXT NOT NULL,
component INTEGER NOT NULL,
component_id TEXT,
source_json TEXT,
seed INTEGER,
requested_n INTEGER,
inventory_size INTEGER,
selected_size INTEGER,
fit_converged INTEGER,
fit_iterations INTEGER,
fit_final_lim REAL,
fit_final_lim_p95 REAL,
fit_seed INTEGER,
inserted_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (selection_name, model_name, selection_index)
);
CREATE INDEX IF NOT EXISTS idx_chosen_random_components_component
ON chosen_random_components(model_name, layer, component);
"""
def connect(db_path: Path) -> sqlite3.Connection:
    """Open a SQLite connection to ``db_path``, creating parent directories first.

    The returned connection:
    - yields ``sqlite3.Row`` objects,
    - has sqlite3's same-thread check disabled (``check_same_thread=False``),
    - has foreign-key enforcement switched on.
    """
    target_dir = db_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(str(db_path), check_same_thread=False)
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA foreign_keys = ON")
    return connection
def init_db(conn: sqlite3.Connection) -> None:
    """Create the ``SCHEMA`` tables/index if absent, backfill any missing
    annotation columns, then commit."""
    ddl_script = SCHEMA
    conn.executescript(ddl_script)
    # Older databases may predate some annotation columns; add them in place.
    _ensure_annotation_columns(conn)
    conn.commit()
def validate_db(conn: sqlite3.Connection, model_name: str | None = None) -> None:
    """Check that the database has the tables this module reads from.

    Raises ``RuntimeError`` if any of ``components``, ``examples`` or
    ``annotations`` is missing. When ``model_name`` is truthy, also raises
    ``RuntimeError`` if ``components`` has no rows for that model.
    """
    table_rows = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'").fetchall()
    present = {str(record[0]) for record in table_rows}
    absent = sorted(name for name in ("annotations", "components", "examples") if name not in present)
    if absent:
        raise RuntimeError(f"SQLite DB is missing required table(s): {', '.join(absent)}")
    if not model_name:
        return
    component_count = conn.execute("SELECT COUNT(*) FROM components WHERE model_name = ?", (model_name,)).fetchone()[0]
    if int(component_count) == 0:
        raise RuntimeError(f"SQLite DB has no components for model_name={model_name!r}")
def list_models(conn: sqlite3.Connection) -> list[str]:
    """Return the distinct model names found in ``components`` or ``annotations``,
    ordered by name."""
    union_sql = """
SELECT model_name FROM components
UNION
SELECT model_name FROM annotations
ORDER BY model_name
"""
    names: list[str] = []
    for record in conn.execute(union_sql).fetchall():
        names.append(str(record[0]))
    return names
def list_layers(conn: sqlite3.Connection, model_name: str) -> list[str]:
    """Return the distinct layers of ``model_name`` in ``components``, ordered by
    ``layer_sort_key``."""
    cursor = conn.execute(
        "SELECT DISTINCT layer FROM components WHERE model_name = ?",
        (model_name,),
    )
    layer_names = [str(record[0]) for record in cursor.fetchall()]
    layer_names.sort(key=layer_sort_key)
    return layer_names
def list_components(
    conn: sqlite3.Connection,
    model_name: str,
    layer: str | None = None,
    component: int | None = None,
    search: str | None = None,
) -> list[dict[str, Any]]:
    """List the components of ``model_name``, merged with their annotations.

    ``layer`` and ``component`` are optional exact-match filters applied in SQL.
    ``search`` is an optional case-insensitive substring filter, applied in Python,
    over the component id, positive/negative labels, summary and notes.
    The result is ordered by ``layer_sort_key`` and then by component index.
    """
    erf_join, erf_select = _erf_join_sql(conn)
    sql = f"""
SELECT c.layer,
c.component,
c.excess_kurtosis,
{erf_select}
a.positive_label,
a.positive_confidence,
a.negative_label,
a.negative_confidence,
a.positive_interpretation_types_json,
a.negative_interpretation_types_json,
a.summary,
a.notes,
a.include_as_case_study
FROM components c
LEFT JOIN annotations a
ON a.model_name = c.model_name
AND a.layer = c.layer
AND a.component = c.component
{erf_join}
WHERE c.model_name = ?
AND (? IS NULL OR c.layer = ?)
AND (? IS NULL OR c.component = ?)
"""
    records = conn.execute(sql, (model_name, layer, layer, component, component)).fetchall()
    needle = str(search or "").strip().lower()
    searchable = ("component", "positive_label", "negative_label", "summary", "notes")

    def as_float(value: Any) -> float | None:
        return float(value) if value is not None else None

    matches: list[dict[str, Any]] = []
    for record in records:
        entry = {
            "layer": str(record["layer"]),
            "component": int(record["component"]),
            "excess_kurtosis": as_float(record["excess_kurtosis"]),
            "effective_context_mean": as_float(record["effective_context_mean"]),
            **_annotation_row(record),
        }
        if needle:
            haystack = " ".join(str(entry.get(field, "")) for field in searchable).lower()
            if needle not in haystack:
                continue
        matches.append(entry)
    matches.sort(key=lambda entry: (layer_sort_key(entry["layer"]), entry["component"]))
    return matches
def list_component_stats(conn: sqlite3.Connection, model_name: str) -> list[dict[str, Any]]:
    """Return every component of ``model_name`` with its annotation fields and
    effective-context mean, ordered by ``layer_sort_key`` then component index."""
    erf_join, erf_select = _erf_join_sql(conn)
    sql = f"""
SELECT c.layer,
c.component,
c.excess_kurtosis,
{erf_select}
a.positive_label,
a.positive_confidence,
a.negative_label,
a.negative_confidence,
a.positive_interpretation_types_json,
a.negative_interpretation_types_json,
a.notes
FROM components c
LEFT JOIN annotations a
ON a.model_name = c.model_name
AND a.layer = c.layer
AND a.component = c.component
{erf_join}
WHERE c.model_name = ?
"""
    stats: list[dict[str, Any]] = []
    for record in conn.execute(sql, (model_name,)).fetchall():
        mean_context = record["effective_context_mean"]
        stats.append(
            {
                "layer": str(record["layer"]),
                "effective_context_mean": None if mean_context is None else float(mean_context),
                **_annotation_row(record),
            }
        )
    stats.sort(key=lambda entry: (layer_sort_key(entry["layer"]), entry["component"]))
    return stats
def list_annotated_components(conn: sqlite3.Connection, model_name: str, layer: str) -> list[dict[str, Any]]:
    """Return the annotations in ``layer`` of ``model_name`` that have a non-blank
    positive or negative label, ordered by component index.

    Component data (excess kurtosis, effective-context mean) is LEFT JOINed, so
    it is ``None`` for annotations without a matching ``components`` row.
    """
    erf_join, erf_select = _erf_join_sql(conn)
    sql = f"""
SELECT a.component,
{erf_select}
a.positive_label, a.positive_confidence,
a.negative_label, a.negative_confidence,
a.positive_interpretation_types_json,
a.negative_interpretation_types_json,
a.notes,
c.excess_kurtosis
FROM annotations a
LEFT JOIN components c
ON c.model_name = a.model_name
AND c.layer = a.layer
AND c.component = a.component
{erf_join}
WHERE a.model_name = ? AND a.layer = ?
AND (
(a.positive_label IS NOT NULL AND TRIM(a.positive_label) != '')
OR (a.negative_label IS NOT NULL AND TRIM(a.negative_label) != '')
)
ORDER BY a.component
"""
    records = conn.execute(sql, (model_name, layer)).fetchall()
    return [
        {
            **_annotation_row(record),
            "effective_context_mean": (
                float(record["effective_context_mean"]) if record["effective_context_mean"] is not None else None
            ),
        }
        for record in records
    ]
def list_component_metadata(conn: sqlite3.Connection, model_name: str, layer: str, components: list[int]) -> list[dict[str, Any]]:
    """Return annotation/metadata rows for selected components of one layer.

    Args:
        conn: Open SQLite connection with ``sqlite3.Row`` row factory.
        model_name: Model whose components are looked up.
        layer: Layer the components belong to.
        components: Component indices to fetch; duplicates are ignored.

    Returns:
        One dict per matching ``components`` row (annotation fields plus
        ``effective_context_mean``), ordered by component index. Requested
        components with no ``components`` row are simply absent.
    """
    if not components:
        return []
    erf_join, erf_select = _erf_join_sql(conn)
    unique_components = sorted({int(component) for component in components})
    out = []
    # Query in batches: SQLite caps the number of bound parameters per statement
    # (SQLITE_MAX_VARIABLE_NUMBER, 999 before SQLite 3.32.0), and the caller decides
    # how many component ids arrive here. Batches are ascending and disjoint, so
    # concatenating the per-batch ORDER BY results keeps the overall order.
    for batch in _chunks(unique_components, 500):
        placeholders = ",".join("?" for _ in batch)
        rows = conn.execute(
            f"""
SELECT c.component,
c.excess_kurtosis,
{erf_select}
a.positive_label,
a.positive_confidence,
a.negative_label,
a.negative_confidence,
a.positive_interpretation_types_json,
a.negative_interpretation_types_json,
a.summary,
a.notes,
a.include_as_case_study
FROM components c
LEFT JOIN annotations a
ON a.model_name = c.model_name
AND a.layer = c.layer
AND a.component = c.component
{erf_join}
WHERE c.model_name = ? AND c.layer = ? AND c.component IN ({placeholders})
ORDER BY c.component
""",
            (model_name, layer, *batch),
        ).fetchall()
        for row in rows:
            item = _annotation_row(row)
            item["effective_context_mean"] = float(row["effective_context_mean"]) if row["effective_context_mean"] is not None else None
            out.append(item)
    return out
def list_component_neighbors(conn: sqlite3.Connection, model_name: str, layer: str, component: int) -> list[dict[str, Any]]:
    """Return the cross-layer neighbours recorded for one component.

    Reads ``component_neighbors``; returns ``[]`` when that table does not exist.
    Each neighbour carries its own annotation/kurtosis/effective-context data plus
    the stored cosine statistics. Rows are ordered 'prev' first, then 'next',
    then any other direction.

    When the table has no ``neighbor_sign`` column, the sign is derived from
    ``signed_cosine`` (>= 0 -> 1, otherwise -1).
    """
    if not _table_exists(conn, "component_neighbors"):
        return []
    if "neighbor_sign" in _table_columns(conn, "component_neighbors"):
        neighbor_sign_select = "n.neighbor_sign"
    else:
        neighbor_sign_select = "CASE WHEN n.signed_cosine >= 0 THEN 1 ELSE -1 END"
    erf_join, erf_select = _component_neighbor_erf_join_sql(conn)
    sql = f"""
SELECT n.direction,
n.neighbor_layer,
n.neighbor_component AS component,
n.abs_cosine,
n.signed_cosine,
{neighbor_sign_select} AS neighbor_sign,
n.source_n_components,
n.neighbor_n_components,
{erf_select}
a.positive_label,
a.positive_confidence,
a.negative_label,
a.negative_confidence,
a.positive_interpretation_types_json,
a.negative_interpretation_types_json,
a.summary,
a.notes,
a.include_as_case_study,
c.excess_kurtosis
FROM component_neighbors n
LEFT JOIN annotations a
ON a.model_name = n.model_name
AND a.layer = n.neighbor_layer
AND a.component = n.neighbor_component
LEFT JOIN components c
ON c.model_name = n.model_name
AND c.layer = n.neighbor_layer
AND c.component = n.neighbor_component
{erf_join}
WHERE n.model_name = ?
AND n.layer = ?
AND n.component = ?
ORDER BY CASE n.direction WHEN 'prev' THEN 0 WHEN 'next' THEN 1 ELSE 2 END
"""
    records = conn.execute(sql, (model_name, layer, int(component))).fetchall()
    neighbors: list[dict[str, Any]] = []
    for record in records:
        mean_context = record["effective_context_mean"]
        neighbors.append(
            {
                **_annotation_row(record),
                "direction": str(record["direction"]),
                "neighbor_layer": str(record["neighbor_layer"]),
                "neighbor_component": int(record["component"]),
                "abs_cosine": float(record["abs_cosine"]),
                "signed_cosine": float(record["signed_cosine"]),
                "neighbor_sign": int(record["neighbor_sign"]),
                "source_n_components": int(record["source_n_components"]),
                "neighbor_n_components": int(record["neighbor_n_components"]),
                "effective_context_mean": None if mean_context is None else float(mean_context),
            }
        )
    return neighbors
def list_chosen_random_components(conn: sqlite3.Connection, model_name: str | None = None, selection_name: str | None = None) -> list[dict[str, Any]]:
    """Return stored random component selections joined with component/annotation data.

    Returns ``[]`` when ``chosen_random_components`` does not exist.
    ``model_name`` / ``selection_name`` are optional exact-match filters. Rows are
    ordered by model name, selection name and position within the selection.
    Nullable numeric columns come back as ``None`` rather than being coerced.
    """
    if not _table_exists(conn, "chosen_random_components"):
        return []
    erf_join, erf_select = _erf_join_sql(conn)
    sql = f"""
SELECT r.selection_name,
r.model_name,
r.selection_index,
r.layer,
r.component,
r.component_id,
r.source_json,
r.seed,
r.requested_n,
r.inventory_size,
r.selected_size,
r.fit_converged,
r.fit_iterations,
r.fit_final_lim,
r.fit_final_lim_p95,
r.fit_seed,
c.excess_kurtosis,
{erf_select}
a.positive_label,
a.positive_confidence,
a.negative_label,
a.negative_confidence,
a.positive_interpretation_types_json,
a.negative_interpretation_types_json,
a.summary,
a.notes,
a.include_as_case_study
FROM chosen_random_components r
LEFT JOIN components c
ON c.model_name = r.model_name
AND c.layer = r.layer
AND c.component = r.component
LEFT JOIN annotations a
ON a.model_name = r.model_name
AND a.layer = r.layer
AND a.component = r.component
{erf_join}
WHERE (? IS NULL OR r.model_name = ?)
AND (? IS NULL OR r.selection_name = ?)
ORDER BY r.model_name, r.selection_name, r.selection_index
"""
    records = conn.execute(sql, (model_name, model_name, selection_name, selection_name)).fetchall()

    def optional(cast: Any, value: Any) -> Any:
        # Apply ``cast`` only to non-NULL values; NULL stays None.
        return cast(value) if value is not None else None

    selections: list[dict[str, Any]] = []
    for record in records:
        selections.append(
            {
                **_annotation_row(record),
                "selection_name": str(record["selection_name"]),
                "model_name": str(record["model_name"]),
                "selection_index": int(record["selection_index"]),
                "layer": str(record["layer"]),
                "component": int(record["component"]),
                "component_id": str(record["component_id"] or ""),
                "source_json": str(record["source_json"] or ""),
                "seed": optional(int, record["seed"]),
                "requested_n": optional(int, record["requested_n"]),
                "inventory_size": optional(int, record["inventory_size"]),
                "selected_size": optional(int, record["selected_size"]),
                "fit_converged": optional(bool, record["fit_converged"]),
                "fit_iterations": optional(int, record["fit_iterations"]),
                "fit_final_lim": optional(float, record["fit_final_lim"]),
                "fit_final_lim_p95": optional(float, record["fit_final_lim_p95"]),
                "fit_seed": optional(int, record["fit_seed"]),
                "effective_context_mean": optional(float, record["effective_context_mean"]),
            }
        )
    return selections
def list_component_examples(
    conn: sqlite3.Connection,
    model_name: str,
    layer: str | None = None,
    component: int | None = None,
    *,
    region: str | None = None,
    limit: int | None = None,
) -> dict[tuple[str, int], list[dict[str, Any]]]:
    """Return example rows for ``model_name`` grouped by ``(layer, component)``.

    ``layer``, ``component`` and ``region`` are optional exact-match filters.
    ``limit`` keeps only rows with ``rank <= limit``; rank is compared per row,
    so each region of a component can contribute up to ``limit`` rows.
    Within each group, rows keep the SQL order (region, then rank).
    """
    sql = """
SELECT layer,
component,
region,
rank,
row_index,
doc_id,
token_id,
token,
source_score,
direction_cosine,
position,
context_to_target,
context,
context_score_max_abs
FROM examples
WHERE model_name = ?
AND (? IS NULL OR layer = ?)
AND (? IS NULL OR component = ?)
AND (? IS NULL OR region = ?)
AND (? IS NULL OR rank <= ?)
ORDER BY layer, component, region, rank
"""
    params = (model_name, layer, layer, component, component, region, region, limit, limit)
    grouped: dict[tuple[str, int], list[dict[str, Any]]] = {}
    for record in conn.execute(sql, params).fetchall():
        group_key = (str(record["layer"]), int(record["component"]))
        if group_key not in grouped:
            grouped[group_key] = []
        grouped[group_key].append(_example_row(record))
    return grouped
def list_component_example_details(conn: sqlite3.Connection, model_name: str, *, layer: str, component: int) -> list[dict[str, Any]]:
    """Return every example of one component, each with its per-token context scores.

    Examples are ordered by region then rank. Each example gains a
    ``context_token_scores`` list built from ``context_tokens`` rows, ordered by
    ``seq``. Token rows are fetched in batches of 500 example ids so the
    ``IN (...)`` list stays small.
    """
    detail_sql = """
SELECT id,
region,
rank,
row_index,
doc_id,
token_id,
token,
source_score,
direction_cosine,
position,
context_to_target,
context,
context_score_max_abs
FROM examples
WHERE model_name = ? AND layer = ? AND component = ?
ORDER BY region, rank
"""
    ordered_examples: list[dict[str, Any]] = []
    examples_by_id: dict[int, dict[str, Any]] = {}
    for record in conn.execute(detail_sql, (model_name, layer, component)).fetchall():
        example = _example_row(record)
        example["context_token_scores"] = []
        ordered_examples.append(example)
        examples_by_id[int(record["id"])] = example

    def token_score(token_row: Any) -> dict[str, Any]:
        # Convert one context_tokens row; NULL numeric columns stay None.
        position = token_row["token_position"]
        score = token_row["source_score"]
        cosine = token_row["direction_cosine"]
        return {
            "position": int(position) if position is not None else None,
            "token": str(token_row["token"] or ""),
            "source_score": float(score) if score is not None else None,
            "direction_cosine": float(cosine) if cosine is not None else None,
            "is_target": bool(token_row["is_target"]),
        }

    for id_batch in _chunks(list(examples_by_id), 500):
        if not id_batch:
            continue
        placeholders = ",".join("?" for _ in id_batch)
        token_rows = conn.execute(
            f"""
SELECT example_id, seq, token_position, token, source_score, direction_cosine, is_target
FROM context_tokens
WHERE example_id IN ({placeholders})
ORDER BY example_id, seq
""",
            tuple(id_batch),
        ).fetchall()
        for token_row in token_rows:
            owner = examples_by_id.get(int(token_row["example_id"]))
            if owner is not None:
                owner["context_token_scores"].append(token_score(token_row))
    return ordered_examples
def get_component_row(conn: sqlite3.Connection, model_name: str, layer: str, component: int) -> dict[str, Any] | None:
    """Return the ``list_components`` entry for one component, or ``None`` if absent."""
    matches = list_components(conn, model_name, layer=layer, component=component)
    return next(iter(matches), None)
def get_annotation(conn: sqlite3.Connection, model_name: str, layer: str, component: int) -> dict[str, Any] | None:
    """Return the raw ``annotations`` row for one component as a dict, or ``None``.

    Assumes the connection uses ``sqlite3.Row`` (as ``connect`` configures), so
    the row can be converted with ``dict()``.
    """
    cursor = conn.execute(
        """
SELECT * FROM annotations
WHERE model_name = ? AND layer = ? AND component = ?
""",
        (model_name, layer, component),
    )
    found = cursor.fetchone()
    if not found:
        return None
    return dict(found)
def update_annotation(
    conn: sqlite3.Connection,
    *,
    model_name: str,
    layer: str,
    component: int,
    positive_label: str,
    positive_confidence: str,
    positive_interpretation_types: list[str],
    negative_label: str,
    negative_confidence: str,
    negative_interpretation_types: list[str],
    summary: str,
    notes: str,
    include_as_case_study: bool,
) -> None:
    """Insert or overwrite the annotation of one component, then commit.

    Confidences are normalised with ``_normalize_confidence``. Interpretation
    types are stringified and stored as JSON arrays, with empty strings dropped.
    ``updated_at`` is set to SQLite's ``datetime('now')``. The legacy
    ``label``/``confidence``/``interpretation_types_json`` columns are not written.
    """
    upsert_sql = """
INSERT INTO annotations(
model_name, layer, component,
positive_label, positive_confidence, positive_interpretation_types_json,
negative_label, negative_confidence, negative_interpretation_types_json,
summary, notes, include_as_case_study, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
ON CONFLICT(model_name, layer, component) DO UPDATE SET
positive_label = excluded.positive_label,
positive_confidence = excluded.positive_confidence,
positive_interpretation_types_json = excluded.positive_interpretation_types_json,
negative_label = excluded.negative_label,
negative_confidence = excluded.negative_confidence,
negative_interpretation_types_json = excluded.negative_interpretation_types_json,
summary = excluded.summary,
notes = excluded.notes,
include_as_case_study = excluded.include_as_case_study,
updated_at = datetime('now')
"""
    positive_types_json = json.dumps([text for text in map(str, positive_interpretation_types) if text])
    negative_types_json = json.dumps([text for text in map(str, negative_interpretation_types) if text])
    case_study_flag = 1 if include_as_case_study else 0
    params = (
        model_name,
        layer,
        int(component),
        positive_label,
        _normalize_confidence(positive_confidence),
        positive_types_json,
        negative_label,
        _normalize_confidence(negative_confidence),
        negative_types_json,
        summary,
        notes,
        case_study_flag,
    )
    conn.execute(upsert_sql, params)
    conn.commit()
def infer_default_annotation_sign(conn: sqlite3.Connection, model_name: str, layer: str, component: int) -> int:
    """Return +1 or -1: the sign of the strongest top-|score| example of a component.

    Only the 'top_abs' family of regions is considered. Ties on |score| are
    broken by lower rank. Defaults to +1 when there is no such example or its
    score is NULL.
    """
    strongest = conn.execute(
        """
SELECT source_score
FROM examples
WHERE model_name = ? AND layer = ? AND component = ?
AND region IN ('top_abs', 'top_abs_sample_500', 'top_abs_sample_5000')
ORDER BY ABS(source_score) DESC, rank ASC
LIMIT 1
""",
        (model_name, layer, component),
    ).fetchone()
    score = None if strongest is None else strongest["source_score"]
    if score is None:
        return 1
    return 1 if float(score) >= 0 else -1
def get_examples_by_region(conn: sqlite3.Connection, model_name: str, layer: str, component: int) -> tuple[list[str], dict[str, list[dict[str, Any]]]]:
    """Group a component's detailed examples by region.

    Returns ``(regions, by_region)``:
    - ``regions`` is ordered by ``_example_band_sort_key``;
    - an example with an empty region is filed under ``"examples"``.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for example in list_component_example_details(conn, model_name, layer=layer, component=component):
        bucket_name = str(example["region"] or "examples")
        if bucket_name not in grouped:
            grouped[bucket_name] = []
        grouped[bucket_name].append(example)
    ordered_regions = sorted(grouped, key=_example_band_sort_key)
    return ordered_regions, grouped
def pick_default_region(regions: list[str], examples_by_region: dict[str, list[dict[str, Any]]]) -> str:
    """Choose the region to show first.

    Prefers 'top_abs', then 'top_abs_sample_500', then 'top_abs_sample_5000'.
    Otherwise falls back to the first entry of ``regions``, or ``""`` if it is empty.
    """
    preferred = ("top_abs", "top_abs_sample_500", "top_abs_sample_5000")
    chosen = next((name for name in preferred if name in examples_by_region), None)
    if chosen is not None:
        return chosen
    return regions[0] if regions else ""
def search_components(conn: sqlite3.Connection, *, model_name: str | None, query: str, confidence: str, annotation_type: str, include_examples: bool, limit: int) -> list[dict[str, Any]]:
    """Search components (with annotations) across one model or all models.

    Args:
        conn: Open SQLite connection.
        model_name: Model to search, or ``None``/empty to search every model
            returned by ``list_models``.
        query: Case-insensitive substring matched against positive/negative
            labels, summary and notes; blank matches everything.
        confidence: If non-empty, keep components whose positive or negative
            confidence equals it.
        annotation_type: If non-empty, keep components whose positive or
            negative interpretation types include it.
        include_examples: Attach up to 3 ranked examples per region under
            ``"examples"``.
        limit: Maximum number of results; ``limit <= 0`` returns ``[]``.

    Returns:
        Matching component dicts (prefixed with ``model_name``), in model order
        and then ``list_components`` order.
    """
    if limit <= 0:
        # The in-loop check below only runs after a match has been appended, so
        # it cannot enforce a zero or negative limit on its own.
        return []
    models = [model_name] if model_name else list_models(conn)
    results = []
    query_l = query.strip().lower()
    for model in models:
        for item in list_components(conn, model):
            labels = [item.get("positive_label", ""), item.get("negative_label", ""), item.get("summary", ""), item.get("notes", "")]
            haystack = " ".join(str(x) for x in labels).lower()
            if query_l and query_l not in haystack:
                continue
            if confidence and confidence not in {item.get("positive_confidence"), item.get("negative_confidence")}:
                continue
            if annotation_type and annotation_type not in set(item.get("positive_types", []) + item.get("negative_types", [])):
                continue
            row = {"model_name": model, **item}
            if include_examples:
                ex = list_component_examples(conn, model, layer=item["layer"], component=item["component"], limit=3).get((item["layer"], item["component"]), [])
                row["examples"] = ex
            results.append(row)
            if len(results) >= limit:
                return results
    return results
def layer_sort_key(layer: str) -> tuple[int, int | str]:
    """Return a sort key that orders layers as: embedding, numbered, everything else.

    Groups:
    - ``"embedding"`` sorts first.
    - ``"layer_<int>"`` sorts next, numerically (so ``layer_2`` < ``layer_10``).
    - Everything else sorts last, lexically. This includes ``layer_`` names whose
      suffix is not an integer, which previously raised ``ValueError`` and broke
      every sort over a model's layers.
    """
    if layer == "embedding":
        return (0, 0)
    if layer.startswith("layer_"):
        try:
            return (1, int(layer.removeprefix("layer_")))
        except ValueError:
            pass  # non-numeric suffix: fall through to the lexical group
    return (2, layer)
def _erf_value_source(conn: sqlite3.Connection) -> tuple[str, str] | None:
    """Locate the effective-receptive-field table and its mean-value column.

    Returns ``(quoted_table, quoted_column)`` ready to splice into SQL. Returns
    ``None`` when neither candidate table exists, or when the table found has
    none of the recognised value columns.
    """
    table = _first_existing_table(conn, ["effective_receptive_fields", "effective_context_lengths"])
    if table is None:
        return None
    value_column = _first_existing_name(
        _table_columns(conn, table),
        ["mean_erf", "erf_mean", "mean_length", "effective_receptive_field_mean", "effective_context_mean"],
    )
    if value_column is None:
        return None
    return _quote_identifier(table), _quote_identifier(value_column)


def _build_erf_join_sql(conn: sqlite3.Connection, model_expr: str, layer_expr: str, component_expr: str) -> tuple[str, str]:
    """Build ``(join_clause, select_fragment)`` exposing ``effective_context_mean``.

    The three ``*_expr`` arguments are trusted SQL column expressions (e.g.
    ``"c.layer"``) that the ERF table's model_name/layer/component are joined on.
    The select fragment ends with a comma so it can sit mid-column-list. When no
    ERF data is available, the join is empty and the column is ``NULL``.
    """
    source = _erf_value_source(conn)
    if source is None:
        return "", "NULL AS effective_context_mean,"
    quoted_table, quoted_value = source
    join_clause = f"""
LEFT JOIN {quoted_table} erf
ON erf.model_name = {model_expr}
AND erf.layer = {layer_expr}
AND erf.component = {component_expr}
"""
    return join_clause, f"erf.{quoted_value} AS effective_context_mean,"


def _erf_join_sql(conn: sqlite3.Connection) -> tuple[str, str]:
    """ERF join/select fragments keyed on the ``components`` alias ``c``."""
    return _build_erf_join_sql(conn, "c.model_name", "c.layer", "c.component")


def _component_neighbor_erf_join_sql(conn: sqlite3.Connection) -> tuple[str, str]:
    """ERF join/select fragments keyed on the neighbour side of ``component_neighbors`` alias ``n``."""
    return _build_erf_join_sql(conn, "n.model_name", "n.neighbor_layer", "n.neighbor_component")
def _first_existing_table(conn: sqlite3.Connection, names: list[str]) -> str | None:
    """Return the first name in ``names`` that is an existing table, else ``None``.

    Candidates are probed lazily, in order.
    """
    return next((candidate for candidate in names if _table_exists(conn, candidate)), None)
def _table_columns(conn: sqlite3.Connection, table_name: str) -> set[str]:
    """Return the column names of ``table_name`` via ``PRAGMA table_info``.

    Reads the ``name`` field by key, so the connection must use ``sqlite3.Row``.
    """
    pragma = f"PRAGMA table_info({_quote_identifier(table_name)})"
    column_names: set[str] = set()
    for info in conn.execute(pragma).fetchall():
        column_names.add(str(info["name"]))
    return column_names
def _first_existing_name(names: set[str], candidates: list[str]) -> str | None:
for candidate in candidates:
if candidate in names:
return candidate
return None
def _quote_identifier(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
def _table_exists(conn: sqlite3.Connection, table_name: str) -> bool:
return conn.execute("SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?", (table_name,)).fetchone() is not None
def _ensure_annotation_columns(conn: sqlite3.Connection) -> None:
columns = {str(row["name"]) for row in conn.execute("PRAGMA table_info(annotations)").fetchall()}
required = {
"label": "TEXT",
"confidence": "TEXT",
"positive_label": "TEXT",
"positive_confidence": "TEXT",
"negative_label": "TEXT",
"negative_confidence": "TEXT",
"positive_interpretation_types_json": "TEXT",
"negative_interpretation_types_json": "TEXT",
"interpretation_types_json": "TEXT",
"summary": "TEXT",
"notes": "TEXT",
"include_as_case_study": "INTEGER NOT NULL DEFAULT 0",
"updated_at": "TEXT",
}
for column, ddl in required.items():
if column not in columns:
conn.execute(f"ALTER TABLE annotations ADD COLUMN {column} {ddl}")
def _annotation_row(row: sqlite3.Row | dict[str, Any]) -> dict[str, Any]:
    """Normalise annotation columns of a query row into an API-friendly dict.

    Columns absent from ``row`` are treated as NULL, so callers may select any
    subset of annotation columns. ``component`` is required and is cast to int.
    ``auto_annotated`` is True when the notes mention "auto-annotation"
    (case-insensitive).
    """
    available = row.keys() if hasattr(row, "keys") else row

    def lookup(column: str, fallback: Any = None) -> Any:
        if column in available:
            return row[column]
        return fallback

    positive_text = str(lookup("positive_label") or "")
    negative_text = str(lookup("negative_label") or "")
    notes_text = str(lookup("notes") or "")
    kurtosis = lookup("excess_kurtosis")
    return {
        "component": int(lookup("component")),
        "positive_label": positive_text,
        "positive_confidence": _normalize_confidence(lookup("positive_confidence")),
        "negative_label": negative_text,
        "negative_confidence": _normalize_confidence(lookup("negative_confidence")),
        "positive_types": _json_list(lookup("positive_interpretation_types_json")),
        "negative_types": _json_list(lookup("negative_interpretation_types_json")),
        "summary": str(lookup("summary") or ""),
        "notes": notes_text,
        "include_as_case_study": bool(lookup("include_as_case_study") or 0),
        "excess_kurtosis": None if kurtosis is None else float(kurtosis),
        "auto_annotated": "auto-annotation" in notes_text.lower(),
    }
def _example_row(row: sqlite3.Row) -> dict[str, Any]:
return {
"region": str(row["region"] or ""),
"rank": int(row["rank"]),
"row_index": int(row["row_index"]) if "row_index" in row.keys() and row["row_index"] is not None else None,
"doc_id": int(row["doc_id"]) if "doc_id" in row.keys() and row["doc_id"] is not None else None,
"token_id": int(row["token_id"]) if "token_id" in row.keys() and row["token_id"] is not None else None,
"token": str(row["token"] or ""),
"source_score": float(row["source_score"]) if row["source_score"] is not None else None,
"direction_cosine": float(row["direction_cosine"]) if row["direction_cosine"] is not None else None,
"position": int(row["position"]) if row["position"] is not None else None,
"context_to_target": str(row["context_to_target"] or ""),
"context": str(row["context"] or ""),
"context_score_max_abs": float(row["context_score_max_abs"]) if row["context_score_max_abs"] is not None else None,
}
def _json_list(value: Any) -> list[str]:
try:
parsed = json.loads(value or "[]")
except json.JSONDecodeError:
return []
return [str(item) for item in parsed] if isinstance(parsed, list) else []
def _normalize_confidence(value: Any) -> str:
text = str(value or "unclear").strip().lower()
return text if text in {"high", "medium", "low", "unclear"} else "unclear"
def _chunks(items: list[int], size: int) -> Iterable[list[int]]:
for index in range(0, len(items), size):
yield items[index : index + size]
def _example_band_sort_key(region: str) -> tuple[int, str]:
name = str(region or "").lower().replace("-", "_")
compact = name.replace("_", "")
if name == "top_abs" or compact == "topabs":
return (0, name)
if "sample" in name:
return (1, name)
if "near_zero" in name or "nearzero" in compact:
return (2, name)
if "opposite" in name:
return (3, name)
return (4, name)