diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..e3547b7c994f25f7043895495003e1c136c5ffff
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,10 @@
+.git
+.venv
+.pytest_cache
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+outputs
+tests
+cot_anc.egg-info
diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000000000000000000000000000000000000..81b48eb2a2dfadbc477417d63ed25755c0d0cbe2
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,13 @@
+MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+DEVICE_PREFERENCE=auto
+DTYPE_PREFERENCE=auto
+ATTN_IMPLEMENTATION=eager
+TRUST_REMOTE_CODE=true
+LOW_CPU_MEM_USAGE=true
+MAX_TRACE_TOKENS=256
+MAX_SENTENCES=16
+TAKE_LOG=true
+PRELOAD_MODEL=true
+REQUIRE_AUTH=true
+MAX_QUEUED_JOBS=8
+MAX_ACTIVE_JOBS_PER_USER=2
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..9e2cf02e23713e8c4e41182df287739c27083ff0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,11 @@
+/.venv
+/.pytest_cache
+/__pycache__
+/data/app.db
+.DS_Store
+*.pyc
+*.pyo
+*.pyd
+cot_anc.egg-info/
+outputs/
+uv.lock
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..c4a0c8f9234795d6d4ad54151b1c46149caddb83
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,45 @@
+FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+ENV UV_LINK_MODE=copy
+ENV HOME=/home/user
+ENV PATH=/home/user/.local/bin:$PATH
+ENV HF_HOME=/home/user/.cache/huggingface
+ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface
+ENV API_HOST=0.0.0.0
+ENV API_PORT=7860
+ENV DEVICE_PREFERENCE=auto
+ENV DTYPE_PREFERENCE=auto
+ENV ATTN_IMPLEMENTATION=eager
+ENV LOW_CPU_MEM_USAGE=true
+ENV TRUST_REMOTE_CODE=true
+ENV PRELOAD_MODEL=true
+ENV REQUIRE_AUTH=true
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ ca-certificates \
+ curl \
+ git \
+ python3 \
+ python3-pip \
+ python3-venv \
+ && rm -rf /var/lib/apt/lists/*
+
+RUN useradd -m -u 1000 user
+USER user
+WORKDIR $HOME/app
+
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+
+COPY --chown=user pyproject.toml uv.lock README.md .env.example ./
+COPY --chown=user app ./app
+COPY --chown=user notebooks ./notebooks
+COPY --chown=user tests ./tests
+
+RUN uv sync --frozen
+
+EXPOSE 7860
+
+CMD ["uv", "run", "python", "-m", "app.cli.run_api"]
diff --git a/README.md b/README.md
index a67b9f869e759e6a2d88942ddb32892e8a96e719..83ba1c28c3145af95059e8d725922b3cba4717d1 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,85 @@
---
-title: Cot Anc
-emoji: š
-colorFrom: indigo
-colorTo: green
+title: Thought Anchors
+emoji: š§
+colorFrom: blue
+colorTo: red
sdk: docker
+app_port: 7860
+hf_oauth: true
pinned: false
+models:
+ - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Thought Anchors
+
+Thought Anchors generates visible reasoning traces from open-weight models and
+computes sentence-to-sentence influence with gradient x attention attribution.
+
+Current product shape:
+
+- Hugging Face `Docker Space` first
+- Hugging Face OAuth sign-in
+- web UI + API
+- per-user ephemeral sessions
+- JSON / CSV export
+- adaptive CPU / MPS / CUDA loading
+
+## Quick Start
+
+Install deps:
+
+```bash
+uv sync --extra dev
+```
+
+Run API:
+
+```bash
+uv run python -m app.cli.run_api
+```
+
+Run CLI:
+
+```bash
+uv run python -m app.cli.run_prototype "Explain why the derivative of x^2 is 2x"
+```
+
+Run tests:
+
+```bash
+uv run python -m pytest -q
+```
+
+## Main Endpoints
+
+- `GET /healthz`
+- `GET /api/me`
+- `POST /api/warmup`
+- `POST /api/analyze`
+- `GET /api/sessions`
+- `POST /api/sessions`
+- `GET /api/sessions/{id}`
+- `GET /api/sessions/{id}/result`
+- `GET /api/sessions/{id}/export.json`
+- `GET /api/sessions/{id}/export.csv`
+
+## Docs
+
+- [Hugging Face deployment](./docs/deploy-huggingface.md)
+- [Runtime and model support](./docs/runtime.md)
+- [API and product behavior](./docs/api.md)
+- [Notebook usage](./docs/notebook.md)
+
+## Notebook
+
+Colab / Kaggle smoke-test notebook:
+
+- [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
+
+## Key Constraints
+
+- Attribution needs `attn_implementation="eager"`.
+- Model must expose supported decoder layers and attention modules.
+- Long traces stay capped because analysis uses full backward pass.
+- Space disk is ephemeral; export results you want to keep.
diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4814867a127241148d25839faacb615d6a89c1ba
--- /dev/null
+++ b/app/__init__.py
@@ -0,0 +1 @@
+"""Application package for chain-of-thought attribution analysis."""
diff --git a/app/analysis/__init__.py b/app/analysis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0277200e78b87d2e8d336dbb5fd5b171a3114bdb
--- /dev/null
+++ b/app/analysis/__init__.py
@@ -0,0 +1 @@
+"""Analysis utilities for attribution patching."""
diff --git a/app/analysis/hooks.py b/app/analysis/hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0580c700238bf4a239c64fc6aeb0ed805f7187d
--- /dev/null
+++ b/app/analysis/hooks.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+import torch
+
+from app.core.model_support import get_decoder_layers
+
+_ATTENTION_STORE: dict[int, torch.Tensor] = {}
+
+
+def clear_stored_attentions() -> None:
+ _ATTENTION_STORE.clear()
+
+
+def get_stored_attentions() -> dict[int, torch.Tensor]:
+ return dict(_ATTENTION_STORE)
+
+
+def _extract_attention_tensor(output: Any) -> torch.Tensor | None:
+ if isinstance(output, torch.Tensor):
+ return output if output.dim() == 4 else None
+
+ if isinstance(output, dict):
+ for value in output.values():
+ if isinstance(value, torch.Tensor) and value.dim() == 4:
+ return value
+
+ if isinstance(output, Iterable) and not isinstance(output, (str, bytes)):
+ for item in output:
+ if isinstance(item, torch.Tensor) and item.dim() == 4:
+ return item
+
+ return None
+
+
+def _get_attention_impl(model: Any) -> str | None:
+ config = getattr(model, "config", None)
+ if config is None:
+ return None
+ return getattr(config, "_attn_implementation", None) or getattr(
+ config,
+ "attn_implementation",
+ None,
+ )
+
+
+def make_attn_hook(layer_idx: int):
+ def hook(_module: Any, _inputs: Any, output: Any) -> None:
+ attn = _extract_attention_tensor(output)
+ if attn is None:
+ return
+ if attn.dim() != 4:
+ raise RuntimeError(f"Expected 4D attention tensor at layer {layer_idx}, got {attn.shape}.")
+ attn.retain_grad()
+ _ATTENTION_STORE[layer_idx] = attn
+
+ return hook
+
+
+def register_hooks(model: Any) -> list[Any]:
+ clear_stored_attentions()
+ layers, _layer_path, attention_attr = get_decoder_layers(model)
+
+ handles: list[Any] = []
+ for layer_idx, layer in enumerate(layers):
+ self_attn = getattr(layer, attention_attr, None)
+ if self_attn is None:
+ raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.")
+ handles.append(self_attn.register_forward_hook(make_attn_hook(layer_idx)))
+ return handles
+
+
+def remove_hooks(handles: list[Any]) -> None:
+ for handle in handles:
+ handle.remove()
+ clear_stored_attentions()
diff --git a/app/analysis/sentence_split.py b/app/analysis/sentence_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..c34f234eda78e1eb728f3d3a65c747224865023c
--- /dev/null
+++ b/app/analysis/sentence_split.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+
+THINK_TAG_RE = re.compile(r"?think>", re.IGNORECASE)
+
+
+@dataclass(slots=True)
+class SentenceSpan:
+ text: str
+ start_char: int
+ end_char: int
+
+
+def normalize_trace_text(raw_trace_text: str) -> str:
+ return THINK_TAG_RE.sub("", raw_trace_text)
+
+
+def _non_whitespace_token_count(text: str) -> int:
+ return len(re.findall(r"\S+", text))
+
+
+def split_sentences(
+ text: str,
+ *,
+ min_token_like_units: int = 2,
+) -> list[SentenceSpan]:
+ if not text:
+ return []
+
+ raw_spans: list[tuple[int, int]] = []
+ start = 0
+ index = 0
+ text_length = len(text)
+
+ while index < text_length:
+ if text[index : index + 2] == "\n\n":
+ end = index + 2
+ raw_spans.append((start, end))
+ start = end
+ index = end
+ continue
+
+ if text[index] in ".!?":
+ end = index + 1
+ while end < text_length and text[end] in "\"')]}":
+ end += 1
+ while end < text_length and text[end].isspace() and text[end : end + 2] != "\n\n":
+ end += 1
+ raw_spans.append((start, end))
+ start = end
+ index = end
+ continue
+
+ index += 1
+
+ if start < text_length:
+ raw_spans.append((start, text_length))
+
+ merged: list[tuple[int, int]] = []
+ for span_start, span_end in raw_spans:
+ fragment = text[span_start:span_end]
+ if not fragment:
+ continue
+ if merged and _non_whitespace_token_count(fragment) < min_token_like_units:
+ previous_start, _ = merged[-1]
+ merged[-1] = (previous_start, span_end)
+ continue
+ merged.append((span_start, span_end))
+
+ if len(merged) > 1:
+ last_start, last_end = merged[-1]
+ if _non_whitespace_token_count(text[last_start:last_end]) < min_token_like_units:
+ prev_start, _ = merged[-2]
+ merged[-2] = (prev_start, last_end)
+ merged.pop()
+
+ return [
+ SentenceSpan(
+ text=text[span_start:span_end],
+ start_char=span_start,
+ end_char=span_end,
+ )
+ for span_start, span_end in merged
+ if text[span_start:span_end]
+ ]
diff --git a/app/analysis/summaries.py b/app/analysis/summaries.py
new file mode 100644
index 0000000000000000000000000000000000000000..0af4693dea6e613917bd04df19902bf0d8815799
--- /dev/null
+++ b/app/analysis/summaries.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+import numpy as np
+
+from app.core.schemas import TopEdge
+
+
+def compute_outgoing_importance(matrix: np.ndarray) -> list[float]:
+ sentence_count = matrix.shape[0]
+ scores: list[float] = []
+ for source_idx in range(sentence_count):
+ column = matrix[source_idx + 1 :, source_idx]
+ scores.append(float(column.mean()) if column.size else 0.0)
+ return scores
+
+
+def compute_incoming_importance(matrix: np.ndarray) -> list[float]:
+ sentence_count = matrix.shape[0]
+ scores: list[float] = []
+ for target_idx in range(sentence_count):
+ row = matrix[target_idx, :target_idx]
+ scores.append(float(row.mean()) if row.size else 0.0)
+ return scores
+
+
+def compute_top_edges(matrix: np.ndarray, top_k: int = 10) -> list[TopEdge]:
+ sentence_count = matrix.shape[0]
+ candidates: list[TopEdge] = []
+ for target_idx in range(sentence_count):
+ for source_idx in range(target_idx):
+ candidates.append(
+ TopEdge(
+ source_sentence_idx=source_idx,
+ target_sentence_idx=target_idx,
+ score=float(matrix[target_idx, source_idx]),
+ )
+ )
+ candidates.sort(key=lambda edge: edge.score, reverse=True)
+ return candidates[:top_k]
diff --git a/app/analysis/suppression.py b/app/analysis/suppression.py
new file mode 100644
index 0000000000000000000000000000000000000000..43c60e801dcfd41a829d16f2c32e8b4bd33e2545
--- /dev/null
+++ b/app/analysis/suppression.py
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import time
+from dataclasses import asdict
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import torch
+
+from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks
+from app.core.model_support import describe_model_support
+from app.core.schemas import ModelCapability, RuntimeMetadata
+
+
+@dataclass(slots=True)
+class AttributionMatrixComputation:
+ matrix: np.ndarray
+ raw_matrix: np.ndarray
+ token_nll: np.ndarray
+ runtime_metadata: RuntimeMetadata
+
+
+def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
+ if logits.ndim != 3 or input_ids.ndim != 2:
+ raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].")
+ if logits.shape[0] != 1 or input_ids.shape[0] != 1:
+ raise ValueError("Only batch size 1 is supported for the prototype.")
+ if input_ids.shape[1] < 2:
+ raise ValueError("Need at least two tokens to compute next-token loss.")
+
+ shifted_logits = logits[:, :-1, :]
+ shifted_targets = input_ids[:, 1:]
+ log_probs = torch.log_softmax(shifted_logits, dim=-1)
+ gathered = log_probs.gather(-1, shifted_targets.unsqueeze(-1)).squeeze(-1)
+ return -gathered[0]
+
+
+def _current_memory_mb(device: torch.device) -> float | None:
+ if device.type != "cuda":
+ return None
+ return float(torch.cuda.memory_allocated(device) / (1024 * 1024))
+
+
+def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray:
+ if not take_log:
+ return raw_matrix.copy()
+ presentation = np.zeros_like(raw_matrix)
+ positive = raw_matrix > 0
+ presentation[positive] = np.log(raw_matrix[positive] + 1e-9)
+ return presentation
+
+
+def compute_attribution_matrix(
+ input_ids: torch.Tensor,
+ token_ranges: list[tuple[int, int]],
+ model: Any,
+ take_log: bool = True,
+ max_trace_tokens: int = 0,
+ max_sentences: int = 0,
+) -> AttributionMatrixComputation:
+ device = input_ids.device
+ handles = register_hooks(model)
+ model.zero_grad(set_to_none=True)
+ forward_start = time.perf_counter()
+ memory_before_mb = _current_memory_mb(device)
+
+ try:
+ with torch.enable_grad():
+ outputs = model(
+ input_ids=input_ids,
+ output_attentions=True,
+ return_dict=True,
+ )
+ forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0
+
+ logits = outputs.logits
+ token_nll = compute_self_token_nll(logits, input_ids)
+ loss = token_nll.sum()
+
+ backward_start = time.perf_counter()
+ loss.backward()
+ backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0
+
+ attentions = get_stored_attentions()
+ if not attentions:
+ raise RuntimeError("No attention tensors were captured. Check eager attention mode.")
+
+ matrix_start = time.perf_counter()
+ sentence_count = len(token_ranges)
+ raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32)
+
+ ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)]
+ first_attention = ordered_layers[0]
+ num_layers = len(ordered_layers)
+ num_heads = int(first_attention.shape[1])
+
+ for source_idx, (source_start, source_end) in enumerate(token_ranges):
+ for target_idx, (target_start, target_end) in enumerate(token_ranges):
+ if target_idx <= source_idx:
+ continue
+
+ total = 0.0
+ for attention in ordered_layers:
+ grad = attention.grad
+ if grad is None:
+ raise RuntimeError("Attention gradient was not retained for one or more layers.")
+ total += -(
+ grad[0, :, target_start:target_end, source_start:source_end]
+ * attention[0, :, target_start:target_end, source_start:source_end]
+ ).sum().item()
+
+ denominator = max(1, target_end - target_start)
+ raw_matrix[target_idx, source_idx] = total / denominator
+
+ matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0
+ total_analysis_ms = (
+ forward_pass_ms + backward_pass_ms + matrix_computation_ms
+ )
+ presentation_matrix = _build_presentation_matrix(raw_matrix, take_log)
+
+ attention_impl = getattr(model.config, "_attn_implementation", "unknown")
+ capability = describe_model_support(model)
+ runtime_metadata = RuntimeMetadata(
+ forward_pass_ms=forward_pass_ms,
+ backward_pass_ms=backward_pass_ms,
+ matrix_computation_ms=matrix_computation_ms,
+ total_analysis_ms=total_analysis_ms,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ sequence_length_tokens=int(input_ids.shape[1]),
+ sentence_count=sentence_count,
+ device=str(device),
+ dtype=str(first_attention.dtype),
+ attention_impl=str(attention_impl),
+ max_trace_tokens=max_trace_tokens,
+ max_sentences=max_sentences,
+ capability=ModelCapability.model_validate(asdict(capability)),
+ )
+
+ memory_after_mb = _current_memory_mb(device)
+ if memory_before_mb is not None and memory_after_mb is not None:
+ runtime_metadata = runtime_metadata.model_copy(
+ update={
+ "device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)"
+ }
+ )
+
+ return AttributionMatrixComputation(
+ matrix=presentation_matrix,
+ raw_matrix=raw_matrix,
+ token_nll=token_nll.detach().cpu().numpy(),
+ runtime_metadata=runtime_metadata,
+ )
+ finally:
+ for attention in get_stored_attentions().values():
+ attention.grad = None
+ remove_hooks(handles)
+ model.zero_grad(set_to_none=True)
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
diff --git a/app/analysis/token_boundaries.py b/app/analysis/token_boundaries.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfc31a75675744f4dfce0b1b311e335a0ecd5d50
--- /dev/null
+++ b/app/analysis/token_boundaries.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from app.analysis.sentence_split import SentenceSpan
+
+
+@dataclass(slots=True)
+class TokenizedSentenceMapping:
+ input_ids: torch.Tensor
+ token_ranges: list[tuple[int, int]]
+ offsets: list[tuple[int, int]]
+ text: str
+
+
+def truncate_text_to_token_limit(text: str, tokenizer: Any, max_tokens: int) -> str:
+ if max_tokens <= 0:
+ raise ValueError("max_tokens must be positive.")
+ encoded = tokenizer(
+ text,
+ add_special_tokens=False,
+ return_offsets_mapping=True,
+ )
+ offsets = encoded["offset_mapping"]
+ if len(offsets) <= max_tokens:
+ return text
+ end_char = offsets[max_tokens - 1][1]
+ return text[:end_char]
+
+
+def tokenize_with_sentence_ranges(
+ text: str,
+ sentence_spans: list[SentenceSpan],
+ tokenizer: Any,
+) -> TokenizedSentenceMapping:
+ encoded = tokenizer(
+ text,
+ add_special_tokens=False,
+ return_offsets_mapping=True,
+ return_tensors="pt",
+ )
+
+ input_ids = encoded["input_ids"]
+ raw_offsets = encoded["offset_mapping"][0].tolist()
+ offsets = [(int(start), int(end)) for start, end in raw_offsets]
+ token_ranges: list[tuple[int, int]] = []
+
+ for span in sentence_spans:
+ overlapping = [
+ token_index
+ for token_index, (token_start, token_end) in enumerate(offsets)
+ if token_end > span.start_char and token_start < span.end_char
+ ]
+ if not overlapping:
+ raise ValueError(
+ f"Sentence span {span.start_char}:{span.end_char} mapped to zero tokens."
+ )
+ token_ranges.append((overlapping[0], overlapping[-1] + 1))
+
+ if token_ranges:
+ if token_ranges[0][0] != 0 or token_ranges[-1][1] != len(offsets):
+ raise ValueError("Sentence token ranges do not cover the full analyzed sequence.")
+ for previous, current in zip(token_ranges, token_ranges[1:]):
+ if previous[1] != current[0]:
+ raise ValueError("Sentence token ranges are not contiguous.")
+
+ return TokenizedSentenceMapping(
+ input_ids=input_ids,
+ token_ranges=token_ranges,
+ offsets=offsets,
+ text=text,
+ )
diff --git a/app/analysis/validation.py b/app/analysis/validation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b47ac29965bb8f4a07e9c31d4bde42a70a10516b
--- /dev/null
+++ b/app/analysis/validation.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import torch
+from scipy.stats import pearsonr, spearmanr
+
+from app.analysis.suppression import compute_self_token_nll
+from app.core.schemas import TopEdge, ValidationMetadata
+
+
+def _nll_slice_for_token_range(token_range: tuple[int, int]) -> slice:
+ start, end = token_range
+ return slice(max(0, start - 1), max(0, end - 1))
+
+
+def build_exact_suppression_mask(
+ *,
+ sequence_length: int,
+ source_range: tuple[int, int],
+ target_range: tuple[int, int],
+ device: torch.device,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ fill_value = torch.finfo(dtype).min
+ mask = torch.zeros((sequence_length, sequence_length), device=device, dtype=dtype)
+ future_positions = torch.triu(
+ torch.ones((sequence_length, sequence_length), device=device, dtype=torch.bool),
+ diagonal=1,
+ )
+ mask = mask.masked_fill(future_positions, fill_value)
+ source_start, source_end = source_range
+ target_start, target_end = target_range
+ mask[target_start:target_end, source_start:source_end] = fill_value
+ return mask.unsqueeze(0).unsqueeze(0)
+
+
+def compute_exact_edge_score(
+ *,
+ model: Any,
+ input_ids: torch.Tensor,
+ source_range: tuple[int, int],
+ target_range: tuple[int, int],
+ baseline_token_nll: np.ndarray,
+) -> float:
+ model_dtype = next(model.parameters()).dtype
+ attention_mask = build_exact_suppression_mask(
+ sequence_length=int(input_ids.shape[1]),
+ source_range=source_range,
+ target_range=target_range,
+ device=input_ids.device,
+ dtype=model_dtype,
+ )
+ with torch.no_grad():
+ outputs = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=False,
+ return_dict=True,
+ )
+ suppressed_nll = compute_self_token_nll(outputs.logits, input_ids).detach().cpu().numpy()
+ nll_slice = _nll_slice_for_token_range(target_range)
+ return float(suppressed_nll[nll_slice].sum() - baseline_token_nll[nll_slice].sum())
+
+
+def validate_top_edges(
+ *,
+ model: Any,
+ input_ids: torch.Tensor,
+ token_ranges: list[tuple[int, int]],
+ top_edges: list[TopEdge],
+ baseline_token_nll: np.ndarray,
+ top_k: int,
+) -> ValidationMetadata:
+ if top_k <= 0 or not top_edges:
+ return ValidationMetadata(enabled=False, top_k=0)
+
+ selected_edges = top_edges[:top_k]
+ exact_scores: list[float] = []
+ attributed_scores: list[float] = []
+ compared_edges: list[TopEdge] = []
+
+ try:
+ for edge in selected_edges:
+ exact_score = compute_exact_edge_score(
+ model=model,
+ input_ids=input_ids,
+ source_range=token_ranges[edge.source_sentence_idx],
+ target_range=token_ranges[edge.target_sentence_idx],
+ baseline_token_nll=baseline_token_nll,
+ )
+ exact_scores.append(exact_score)
+ attributed_scores.append(edge.score)
+ compared_edges.append(
+ TopEdge(
+ source_sentence_idx=edge.source_sentence_idx,
+ target_sentence_idx=edge.target_sentence_idx,
+ score=exact_score,
+ )
+ )
+ except Exception as exc: # pragma: no cover - environment/model dependent
+ return ValidationMetadata(
+ enabled=True,
+ top_k=top_k,
+ compared_edges=[],
+ notes=f"Exact suppression validation failed: {exc}",
+ )
+
+ pearson = float(pearsonr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None
+ spearman = float(spearmanr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None
+
+ return ValidationMetadata(
+ enabled=True,
+ top_k=top_k,
+ pearson=pearson,
+ spearman=spearman,
+ compared_edges=compared_edges,
+ notes="Exact suppression compares sentence-level NLL deltas for selected edges.",
+ )
diff --git a/app/api/__init__.py b/app/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1abc8addac74151fc1a80fea35ceb9836e9145ce
--- /dev/null
+++ b/app/api/__init__.py
@@ -0,0 +1 @@
+"""API package for GPU-hosted analysis service."""
diff --git a/app/api/auth.py b/app/api/auth.py
new file mode 100644
index 0000000000000000000000000000000000000000..717879f41d752e903626c810d17d334a311dacd6
--- /dev/null
+++ b/app/api/auth.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from fastapi import HTTPException, Request
+from huggingface_hub import parse_huggingface_oauth
+
+from app.core.config import Settings
+
+
+@dataclass(frozen=True, slots=True)
+class UserContext:
+ id: str
+ username: str
+ display_name: str | None
+ avatar_url: str | None
+ authenticated: bool
+
+
+def get_optional_user(request: Request) -> UserContext | None:
+ if "session" not in request.scope:
+ return None
+ try:
+ oauth_info = parse_huggingface_oauth(request)
+ except AssertionError:
+ return None
+ if oauth_info is None:
+ return None
+
+ user_info = oauth_info.user_info
+ username = user_info.preferred_username or user_info.sub or "hf-user"
+ display_name = user_info.name or username
+ return UserContext(
+ id=user_info.sub or username,
+ username=username,
+ display_name=display_name,
+ avatar_url=user_info.picture,
+ authenticated=True,
+ )
+
+
+def require_user(request: Request, settings: Settings) -> UserContext:
+ user = get_optional_user(request)
+ if user is not None:
+ return user
+ if settings.require_auth:
+ raise HTTPException(status_code=401, detail="Sign in with Hugging Face to use this service.")
+ return UserContext(
+ id="anonymous",
+ username="anonymous",
+ display_name="Anonymous",
+ avatar_url=None,
+ authenticated=False,
+ )
diff --git a/app/api/main.py b/app/api/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..e94e4946f3b41d44976e5267dbfb4f8737bb3540
--- /dev/null
+++ b/app/api/main.py
@@ -0,0 +1,310 @@
+from __future__ import annotations
+
+import logging
+from contextlib import asynccontextmanager
+from io import StringIO
+from pathlib import Path
+from dataclasses import asdict
+import os
+
+import torch
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import FileResponse
+from fastapi.responses import StreamingResponse
+from fastapi.staticfiles import StaticFiles
+from huggingface_hub import attach_huggingface_oauth
+
+from app.api.auth import get_optional_user, require_user
+from app.core.config import get_settings
+from app.core.runtime import load_model_bundle
+from app.core.runtime_pipeline import compute_attribution_analysis
+from app.core.schemas import (
+ AnalysisRequest,
+ AnalysisResult,
+ CurrentUserResponse,
+ HealthResponse,
+ SessionCreateRequest,
+ SessionResponse,
+ SessionResultResponse,
+ WarmupResponse,
+)
+from app.services.sessions import SessionAccessError, SessionLimitError, SessionService
+from app.storage.repository import SessionRepository
+from app.workers.jobs import build_job_runner
+
+logger = logging.getLogger(__name__)
+FRONTEND_DIR = Path(__file__).resolve().parents[1] / "frontend"
+
+
+@asynccontextmanager
+async def lifespan(_app: FastAPI):
+ settings = get_settings()
+ repository = SessionRepository(settings.database_path)
+ jobs = build_job_runner(settings.job_workers)
+ _app.state.repository = repository
+ _app.state.jobs = jobs
+ _app.state.session_service = SessionService(settings=settings, repository=repository, jobs=jobs)
+ if settings.preload_model:
+ logger.info(
+ "Preloading model '%s' on device preference '%s'.",
+ settings.model_name,
+ settings.device_preference,
+ )
+ load_model_bundle(
+ settings.model_name,
+ device_preference=settings.device_preference,
+ dtype_preference=settings.dtype_preference,
+ attn_implementation=settings.attn_implementation,
+ trust_remote_code=settings.trust_remote_code,
+ low_cpu_mem_usage=settings.low_cpu_mem_usage,
+ )
+ yield
+ jobs.shutdown()
+
+
+app = FastAPI(
+ title="CoT Attribution Analysis API",
+ version="0.1.0",
+ lifespan=lifespan,
+)
+if os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"):
+ attach_huggingface_oauth(app)
+app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
+
+
+def get_session_service() -> SessionService:
+ return app.state.session_service
+
+
+def _to_session_response(payload: dict) -> SessionResponse:
+ return SessionResponse(
+ id=payload["id"],
+ status=payload["status"],
+ question=payload["question"],
+ model_name=payload["model_name"],
+ error=payload.get("error"),
+ created_at=payload["created_at"],
+ updated_at=payload["updated_at"],
+ answer=payload.get("answer"),
+ raw_trace_text=payload.get("raw_trace_text"),
+ normalized_trace_text=payload.get("normalized_trace_text"),
+ sentences=payload.get("sentences"),
+ generation_metadata=payload.get("generation_metadata"),
+ )
+
+
+@app.get("/", include_in_schema=False)
+def index() -> FileResponse:
+ return FileResponse(FRONTEND_DIR / "index.html")
+
+
+@app.get("/healthz", response_model=HealthResponse)
+def healthz() -> HealthResponse:
+ settings = get_settings()
+ return HealthResponse(
+ status="ok",
+ model_name=settings.model_name,
+ device_preference=settings.device_preference,
+ dtype_preference=settings.dtype_preference,
+ preload_model=settings.preload_model,
+ cuda_available=torch.cuda.is_available(),
+ mps_available=torch.backends.mps.is_available(),
+ require_auth=settings.require_auth,
+ public_api_enabled=settings.public_api_enabled,
+ max_queued_jobs=settings.max_queued_jobs,
+ max_active_jobs_per_user=settings.max_active_jobs_per_user,
+ )
+
+
+@app.get("/api/me", response_model=CurrentUserResponse)
+def me(request: Request) -> CurrentUserResponse:
+ settings = get_settings()
+ user = get_optional_user(request)
+ return CurrentUserResponse(
+ authenticated=user is not None,
+ auth_required=settings.require_auth,
+ username=user.username if user else None,
+ full_name=user.display_name if user else None,
+ avatar_url=user.avatar_url if user else None,
+ )
+
+
+@app.post("/api/warmup", response_model=WarmupResponse)
+def warmup(model_name: str | None = None, device_preference: str | None = None) -> WarmupResponse:
+ settings = get_settings()
+ bundle = load_model_bundle(
+ model_name or settings.model_name,
+ device_preference=device_preference or settings.device_preference,
+ dtype_preference=settings.dtype_preference,
+ attn_implementation=settings.attn_implementation,
+ trust_remote_code=settings.trust_remote_code,
+ low_cpu_mem_usage=settings.low_cpu_mem_usage,
+ )
+ return WarmupResponse(
+ status="ready",
+ model_name=bundle.model_name,
+ device=str(bundle.device),
+ dtype=str(bundle.dtype),
+ capability=asdict(bundle.capability),
+ )
+
+
+@app.post("/api/analyze", response_model=AnalysisResult)
+def analyze(request: AnalysisRequest, http_request: Request) -> AnalysisResult:
+ settings = get_settings()
+ require_user(http_request, settings)
+ try:
+ return compute_attribution_analysis(
+ question=request.question,
+ model_name=request.model_name,
+ take_log=request.take_log,
+ max_sentences=request.max_sentences,
+ max_trace_tokens=request.max_trace_tokens,
+ validate_top_k=request.validate_top_k,
+ max_new_tokens=request.max_new_tokens,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ device_preference=request.device_preference,
+ dtype_preference=request.dtype_preference,
+ attn_implementation=request.attn_implementation,
+ trust_remote_code=request.trust_remote_code,
+ low_cpu_mem_usage=request.low_cpu_mem_usage,
+ )
+ except Exception as exc: # pragma: no cover - runtime path
+ logger.exception("Analysis request failed")
+ raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+
+@app.get("/api/sessions", response_model=list[SessionResponse])
+def list_sessions(request: Request, limit: int = 20) -> list[SessionResponse]:
+ settings = get_settings()
+ user = require_user(request, settings)
+ service = get_session_service()
+ payloads = service.list_sessions(user.id, limit=limit)
+ return [_to_session_response(payload) for payload in payloads]
+
+
+@app.post("/api/sessions", response_model=SessionResponse)
+def create_session(request: SessionCreateRequest, http_request: Request) -> SessionResponse:
+ settings = get_settings()
+ user = require_user(http_request, settings)
+ service = get_session_service()
+ try:
+ session = service.create_session(
+ AnalysisRequest(
+ question=request.question,
+ model_name=request.model_name,
+ take_log=request.take_log,
+ max_sentences=request.max_sentences,
+ max_trace_tokens=request.max_trace_tokens,
+ validate_top_k=request.validate_top_k,
+ max_new_tokens=request.max_new_tokens,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ device_preference=request.device_preference,
+ dtype_preference=request.dtype_preference,
+ attn_implementation=request.attn_implementation,
+ trust_remote_code=request.trust_remote_code,
+ low_cpu_mem_usage=request.low_cpu_mem_usage,
+ ),
+ owner_id=user.id,
+ owner_name=user.display_name,
+ )
+ except SessionLimitError as exc:
+ raise HTTPException(status_code=429, detail=str(exc)) from exc
+ payload = service.get_session_payload(session.id, owner_id=user.id)
+ return _to_session_response(payload)
+
+
+@app.get("/api/sessions/{session_id}", response_model=SessionResponse)
+def get_session(session_id: str, request: Request) -> SessionResponse:
+ settings = get_settings()
+ user = require_user(request, settings)
+ service = get_session_service()
+ try:
+ payload = service.get_session_payload(session_id, owner_id=user.id)
+ except KeyError as exc:
+ raise HTTPException(status_code=404, detail="Session not found") from exc
+ except SessionAccessError as exc:
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
+ return _to_session_response(payload)
+
+
+@app.post("/api/sessions/{session_id}/analyze", response_model=SessionResponse)
+def analyze_session(session_id: str, request: Request) -> SessionResponse:
+ settings = get_settings()
+ user = require_user(request, settings)
+ service = get_session_service()
+ try:
+ session = service.start_analysis(session_id, owner_id=user.id)
+ payload = service.get_session_payload(session.id, owner_id=user.id)
+ except KeyError as exc:
+ raise HTTPException(status_code=404, detail="Session not found") from exc
+ except SessionAccessError as exc:
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
+ return _to_session_response(payload)
+
+
+@app.get("/api/sessions/{session_id}/result", response_model=SessionResultResponse)
+def get_session_result(session_id: str, request: Request) -> SessionResultResponse:
+ settings = get_settings()
+ user = require_user(request, settings)
+ service = get_session_service()
+ try:
+ payload = service.get_session_payload(session_id, owner_id=user.id)
+ except KeyError as exc:
+ raise HTTPException(status_code=404, detail="Session not found") from exc
+ except SessionAccessError as exc:
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
+
+ session_response = _to_session_response(payload)
+ analysis_payload = payload.get("analysis")
+ return SessionResultResponse(
+ session=session_response,
+ analysis=AnalysisResult.model_validate(analysis_payload) if analysis_payload else None,
+ )
+
+
+@app.get("/api/sessions/{session_id}/export.json")
+def export_session_json(session_id: str, request: Request) -> StreamingResponse:
+ settings = get_settings()
+ user = require_user(request, settings)
+ service = get_session_service()
+ try:
+ payload = service.get_session_payload(session_id, owner_id=user.id)
+ except KeyError as exc:
+ raise HTTPException(status_code=404, detail="Session not found") from exc
+ except SessionAccessError as exc:
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
+ result = SessionResultResponse(
+ session=_to_session_response(payload),
+ analysis=AnalysisResult.model_validate(payload["analysis"]) if payload.get("analysis") else None,
+ )
+ return StreamingResponse(
+ iter([result.model_dump_json(indent=2)]),
+ media_type="application/json",
+ headers={"content-disposition": f'attachment; filename="{session_id}.json"'},
+ )
+
+
+@app.get("/api/sessions/{session_id}/export.csv")
+def export_session_csv(session_id: str, request: Request) -> StreamingResponse:
+ settings = get_settings()
+ user = require_user(request, settings)
+ service = get_session_service()
+ try:
+ result = service.get_analysis_result(session_id, owner_id=user.id)
+ except KeyError as exc:
+ raise HTTPException(status_code=404, detail="Analysis result not found") from exc
+ except SessionAccessError as exc:
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
+
+ buffer = StringIO()
+ buffer.write("source_sentence_idx,target_sentence_idx,score\n")
+ for edge in result.top_edges:
+ buffer.write(f"{edge.source_sentence_idx},{edge.target_sentence_idx},{edge.score:.6f}\n")
+ return StreamingResponse(
+ iter([buffer.getvalue()]),
+ media_type="text/csv",
+ headers={"content-disposition": f'attachment; filename="{session_id}.csv"'},
+ )
diff --git a/app/cli/__init__.py b/app/cli/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..67651742c4d69c795deddd211606da6f12a020a3
--- /dev/null
+++ b/app/cli/__init__.py
@@ -0,0 +1 @@
+"""CLI entrypoints."""
diff --git a/app/cli/run_api.py b/app/cli/run_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ae1d032c56ec9a51ebe767070e647d23a11293
--- /dev/null
+++ b/app/cli/run_api.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+import uvicorn
+
+from app.core.config import get_settings
+
+
+def main() -> None:
+ settings = get_settings()
+ uvicorn.run(
+ "app.api.main:app",
+ host=settings.api_host,
+ port=settings.api_port,
+ reload=False,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/app/cli/run_prototype.py b/app/cli/run_prototype.py
new file mode 100644
index 0000000000000000000000000000000000000000..1536f50ca0f8086ebd11f748e261196b26bd1933
--- /dev/null
+++ b/app/cli/run_prototype.py
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import json
+from datetime import UTC, datetime
+from pathlib import Path
+
+import typer
+
+from app.core.runtime_pipeline import compute_attribution_analysis
+
+cli = typer.Typer(add_completion=False)
+
+
+def _write_heatmap(path: Path, matrix: list[list[float]]) -> None:
+ try:
+ import matplotlib.pyplot as plt
+ except ImportError as exc: # pragma: no cover - optional dependency
+ raise RuntimeError(
+ "Heatmap output requires the optional viz dependency: pip install .[viz]"
+ ) from exc
+
+ figure, axis = plt.subplots(figsize=(8, 6))
+ image = axis.imshow(matrix, aspect="auto", cmap="viridis")
+ axis.set_xlabel("Source sentence")
+ axis.set_ylabel("Target sentence")
+ axis.set_title("Sentence influence matrix")
+ figure.colorbar(image, ax=axis)
+ figure.tight_layout()
+ figure.savefig(path)
+ plt.close(figure)
+
+
+@cli.command()
+def main(
+ question: str,
+ model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+ output_json: Path | None = None,
+ output_heatmap: Path | None = None,
+ max_new_tokens: int = 512,
+ max_trace_tokens: int = 1024,
+ max_sentences: int = 40,
+ take_log: bool = True,
+ validate_top_k: int = 3,
+ temperature: float = 0.6,
+ top_p: float = 0.95,
+ device_preference: str = "auto",
+ dtype_preference: str = "auto",
+ attn_implementation: str = "eager",
+ trust_remote_code: bool = True,
+ low_cpu_mem_usage: bool = True,
+) -> None:
+ result = compute_attribution_analysis(
+ question=question,
+ model_name=model_name,
+ take_log=take_log,
+ max_sentences=max_sentences,
+ max_trace_tokens=max_trace_tokens,
+ validate_top_k=validate_top_k,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ device_preference=device_preference,
+ dtype_preference=dtype_preference,
+ attn_implementation=attn_implementation,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ output_dir = Path("outputs")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
+
+ json_path = output_json or output_dir / f"analysis_{timestamp}.json"
+ json_path.write_text(json.dumps(result.model_dump(), indent=2), encoding="utf-8")
+
+ if output_heatmap is not None:
+ _write_heatmap(output_heatmap, result.suppression_matrix)
+
+ typer.echo(f"Wrote analysis JSON to {json_path}")
+ typer.echo(f"Sentences: {len(result.sentences)}")
+ typer.echo(f"Top edge count: {len(result.top_edges)}")
+ if result.validation_metadata and result.validation_metadata.enabled:
+ typer.echo(
+ "Validation:"
+ f" pearson={result.validation_metadata.pearson}"
+ f" spearman={result.validation_metadata.spearman}"
+ )
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/app/core/__init__.py b/app/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb889f7aec0bdce7f283cdd62173ac724201e81
--- /dev/null
+++ b/app/core/__init__.py
@@ -0,0 +1 @@
+"""Core runtime, config, and schema definitions."""
diff --git a/app/core/config.py b/app/core/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a95c2e534fa165a8c4f0d91d134756c52165f201
--- /dev/null
+++ b/app/core/config.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+import os
+from functools import lru_cache
+from dataclasses import dataclass
+from typing import Literal
+
+
+@dataclass(frozen=True, slots=True)
+class Settings:
+ model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+ max_trace_tokens: int = 1024
+ max_sentences: int = 40
+ take_log: bool = True
+ device_preference: Literal["auto", "cuda", "cpu", "mps"] = "auto"
+ dtype_preference: Literal["auto", "float32", "float16", "bfloat16"] = "auto"
+ attn_implementation: str = "eager"
+ trust_remote_code: bool = True
+ low_cpu_mem_usage: bool = True
+ preload_model: bool = False
+ api_host: str = "0.0.0.0"
+ api_port: int = 7860
+ database_path: str = "data/app.db"
+ job_workers: int = 1
+ max_queued_jobs: int = 8
+ max_active_jobs_per_user: int = 2
+ require_auth: bool = True
+ public_api_enabled: bool = True
+
+
+DEFAULT_SETTINGS = Settings()
+
+
+@lru_cache(maxsize=1)
+def get_settings() -> Settings:
+ take_log = os.getenv("TAKE_LOG", "true").strip().lower() in {"1", "true", "yes", "on"}
+ trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "true").strip().lower() in {"1", "true", "yes", "on"}
+ low_cpu_mem_usage = os.getenv("LOW_CPU_MEM_USAGE", "true").strip().lower() in {"1", "true", "yes", "on"}
+ require_auth = os.getenv("REQUIRE_AUTH", "true").strip().lower() in {"1", "true", "yes", "on"}
+ public_api_enabled = os.getenv("PUBLIC_API_ENABLED", "true").strip().lower() in {"1", "true", "yes", "on"}
+ return Settings(
+ model_name=os.getenv("MODEL_NAME", DEFAULT_SETTINGS.model_name),
+ max_trace_tokens=int(os.getenv("MAX_TRACE_TOKENS", DEFAULT_SETTINGS.max_trace_tokens)),
+ max_sentences=int(os.getenv("MAX_SENTENCES", DEFAULT_SETTINGS.max_sentences)),
+ take_log=take_log,
+ device_preference=os.getenv("DEVICE_PREFERENCE", DEFAULT_SETTINGS.device_preference), # type: ignore[arg-type]
+ dtype_preference=os.getenv("DTYPE_PREFERENCE", DEFAULT_SETTINGS.dtype_preference), # type: ignore[arg-type]
+ attn_implementation=os.getenv("ATTN_IMPLEMENTATION", DEFAULT_SETTINGS.attn_implementation),
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ preload_model=os.getenv("PRELOAD_MODEL", "false").strip().lower() in {"1", "true", "yes", "on"},
+ api_host=os.getenv("API_HOST", DEFAULT_SETTINGS.api_host),
+ api_port=int(os.getenv("API_PORT", DEFAULT_SETTINGS.api_port)),
+ database_path=os.getenv("DATABASE_PATH", DEFAULT_SETTINGS.database_path),
+ job_workers=int(os.getenv("JOB_WORKERS", DEFAULT_SETTINGS.job_workers)),
+ max_queued_jobs=int(os.getenv("MAX_QUEUED_JOBS", DEFAULT_SETTINGS.max_queued_jobs)),
+ max_active_jobs_per_user=int(
+ os.getenv("MAX_ACTIVE_JOBS_PER_USER", DEFAULT_SETTINGS.max_active_jobs_per_user)
+ ),
+ require_auth=require_auth,
+ public_api_enabled=public_api_enabled,
+ )
diff --git a/app/core/model_support.py b/app/core/model_support.py
new file mode 100644
index 0000000000000000000000000000000000000000..df0769cb13f23a1af10c4891c224b3f2a0e6c268
--- /dev/null
+++ b/app/core/model_support.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+
+@dataclass(frozen=True, slots=True)
+class ModelSupport:
+ supports_attribution: bool
+ reason: str | None
+ layer_path: str | None
+ attention_attr: str | None
+ layer_count: int
+ attention_impl: str | None
+
+
+_LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = (
+ ("model", "layers"),
+ ("model", "model", "layers"),
+ ("transformer", "h"),
+ ("gpt_neox", "layers"),
+)
+_ATTENTION_ATTR_CANDIDATES: tuple[str, ...] = ("self_attn", "attn", "attention")
+
+
+def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None:
+ current = obj
+ for segment in path:
+ current = getattr(current, segment, None)
+ if current is None:
+ return None
+ return current
+
+
+def _get_attention_impl(model: Any) -> str | None:
+ config = getattr(model, "config", None)
+ if config is None:
+ return None
+ return getattr(config, "_attn_implementation", None) or getattr(config, "attn_implementation", None)
+
+
+def describe_model_support(model: Any) -> ModelSupport:
+ attn_impl = _get_attention_impl(model)
+ layers = None
+ layer_path = None
+ for candidate in _LAYER_PATH_CANDIDATES:
+ maybe_layers = _resolve_attr_chain(model, candidate)
+ if maybe_layers is not None:
+ layers = list(maybe_layers)
+ layer_path = ".".join(candidate)
+ break
+
+ if not layers:
+ return ModelSupport(
+ supports_attribution=False,
+ reason="Unsupported model structure: unable to locate decoder layers.",
+ layer_path=layer_path,
+ attention_attr=None,
+ layer_count=0,
+ attention_impl=attn_impl,
+ )
+
+ for attention_attr in _ATTENTION_ATTR_CANDIDATES:
+ if all(getattr(layer, attention_attr, None) is not None for layer in layers):
+ if attn_impl != "eager":
+ return ModelSupport(
+ supports_attribution=False,
+ reason="Attention gradients require attn_implementation='eager'.",
+ layer_path=layer_path,
+ attention_attr=attention_attr,
+ layer_count=len(layers),
+ attention_impl=attn_impl,
+ )
+ return ModelSupport(
+ supports_attribution=True,
+ reason=None,
+ layer_path=layer_path,
+ attention_attr=attention_attr,
+ layer_count=len(layers),
+ attention_impl=attn_impl,
+ )
+
+ return ModelSupport(
+ supports_attribution=False,
+ reason="Unsupported attention module layout: no known attention attribute found on decoder layers.",
+ layer_path=layer_path,
+ attention_attr=None,
+ layer_count=len(layers),
+ attention_impl=attn_impl,
+ )
+
+
+def get_decoder_layers(model: Any) -> tuple[list[Any], str, str]:
+ support = describe_model_support(model)
+ if not support.supports_attribution or support.layer_path is None or support.attention_attr is None:
+ reason = support.reason or "Model does not support attribution analysis."
+ raise RuntimeError(reason)
+
+ layers = _resolve_attr_chain(model, tuple(support.layer_path.split(".")))
+ if layers is None:
+ raise RuntimeError("Model support metadata became inconsistent while resolving layers.")
+ return list(layers), support.layer_path, support.attention_attr
diff --git a/app/core/runtime.py b/app/core/runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b511211f0204881f17cd13558a5305497a0089d
--- /dev/null
+++ b/app/core/runtime.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from functools import lru_cache
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
+
+from app.core.model_support import ModelSupport, describe_model_support
+
+
+@dataclass(slots=True)
+class ModelBundle:
+ model_name: str
+ model: PreTrainedModel
+ tokenizer: PreTrainedTokenizerBase
+ device: torch.device
+ dtype: torch.dtype
+ capability: ModelSupport
+
+
+def resolve_dtype(preference: str, device: torch.device) -> torch.dtype:
+ if preference == "float32":
+ return torch.float32
+ if preference == "float16":
+ return torch.float16
+ if preference == "bfloat16":
+ return torch.bfloat16
+ if device.type == "cuda":
+ return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ if device.type == "mps":
+ return torch.float16
+ return torch.float32
+
+
+def resolve_device(preference: str = "auto") -> torch.device:
+ if preference == "cuda":
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA requested but not available.")
+ return torch.device("cuda")
+ if preference == "mps":
+ if not torch.backends.mps.is_available():
+ raise RuntimeError("MPS requested but not available.")
+ return torch.device("mps")
+ if preference == "cpu":
+ return torch.device("cpu")
+ if torch.cuda.is_available():
+ return torch.device("cuda")
+ if torch.backends.mps.is_available():
+ return torch.device("mps")
+ return torch.device("cpu")
+
+
+@lru_cache(maxsize=2)
+def load_model_bundle(
+ model_name: str,
+ device_preference: str = "auto",
+ dtype_preference: str = "auto",
+ attn_implementation: str = "eager",
+ trust_remote_code: bool = True,
+ low_cpu_mem_usage: bool = True,
+) -> ModelBundle:
+ device = resolve_device(device_preference)
+ dtype = resolve_dtype(dtype_preference, device)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ trust_remote_code=trust_remote_code,
+ attn_implementation=attn_implementation,
+ torch_dtype=dtype,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+ model.to(device)
+ model.eval()
+ capability = describe_model_support(model)
+
+ return ModelBundle(
+ model_name=model_name,
+ model=model,
+ tokenizer=tokenizer,
+ device=device,
+ dtype=dtype,
+ capability=capability,
+ )
+
+
+def compute_attribution_analysis(**kwargs):
+ from app.core.runtime_pipeline import compute_attribution_analysis as _compute_attribution_analysis
+
+ return _compute_attribution_analysis(**kwargs)
diff --git a/app/core/runtime_pipeline.py b/app/core/runtime_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..700e3c20059144b292668bc4de6229c652f28f7b
--- /dev/null
+++ b/app/core/runtime_pipeline.py
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+
+from app.analysis.sentence_split import split_sentences
+from app.analysis.summaries import (
+ compute_incoming_importance,
+ compute_outgoing_importance,
+ compute_top_edges,
+)
+from app.analysis.suppression import compute_attribution_matrix
+from app.analysis.token_boundaries import tokenize_with_sentence_ranges, truncate_text_to_token_limit
+from app.analysis.validation import validate_top_edges
+from app.core.config import get_settings
+from app.core.runtime import load_model_bundle
+from app.core.schemas import AnalysisResult, GenerationResult
+from app.generation.service import generate_answer_and_trace
+
+
+def compute_attribution_analysis(
+ *,
+ question: str,
+ model_name: str | None = None,
+ take_log: bool | None = None,
+ max_sentences: int | None = None,
+ max_trace_tokens: int | None = None,
+ validate_top_k: int = 0,
+ max_new_tokens: int = 512,
+ temperature: float = 0.6,
+ top_p: float = 0.95,
+ device_preference: str | None = None,
+ dtype_preference: str | None = None,
+ attn_implementation: str | None = None,
+ trust_remote_code: bool | None = None,
+ low_cpu_mem_usage: bool | None = None,
+) -> AnalysisResult:
+ generation = None
+ return analyze_generation_result(
+ question=question,
+ generation=generation,
+ model_name=model_name,
+ take_log=take_log,
+ max_sentences=max_sentences,
+ max_trace_tokens=max_trace_tokens,
+ validate_top_k=validate_top_k,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ device_preference=device_preference,
+ dtype_preference=dtype_preference,
+ attn_implementation=attn_implementation,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+
+def analyze_generation_result(
+ *,
+ question: str,
+ generation: GenerationResult | None = None,
+ model_name: str | None = None,
+ take_log: bool | None = None,
+ max_sentences: int | None = None,
+ max_trace_tokens: int | None = None,
+ validate_top_k: int = 0,
+ max_new_tokens: int = 512,
+ temperature: float = 0.6,
+ top_p: float = 0.95,
+ device_preference: str | None = None,
+ dtype_preference: str | None = None,
+ attn_implementation: str | None = None,
+ trust_remote_code: bool | None = None,
+ low_cpu_mem_usage: bool | None = None,
+) -> AnalysisResult:
+ settings = get_settings()
+ resolved_model_name = model_name or settings.model_name
+ resolved_take_log = settings.take_log if take_log is None else take_log
+ resolved_max_sentences = max_sentences or settings.max_sentences
+ resolved_max_trace_tokens = max_trace_tokens or settings.max_trace_tokens
+ resolved_device = device_preference or settings.device_preference
+ resolved_dtype = dtype_preference or settings.dtype_preference
+ resolved_attn_implementation = attn_implementation or settings.attn_implementation
+ resolved_trust_remote_code = settings.trust_remote_code if trust_remote_code is None else trust_remote_code
+ resolved_low_cpu_mem_usage = (
+ settings.low_cpu_mem_usage if low_cpu_mem_usage is None else low_cpu_mem_usage
+ )
+
+ bundle = load_model_bundle(
+ resolved_model_name,
+ device_preference=resolved_device,
+ dtype_preference=resolved_dtype,
+ attn_implementation=resolved_attn_implementation,
+ trust_remote_code=resolved_trust_remote_code,
+ low_cpu_mem_usage=resolved_low_cpu_mem_usage,
+ )
+ if not bundle.capability.supports_attribution:
+ reason = bundle.capability.reason or "Model does not support attribution analysis."
+ raise RuntimeError(reason)
+ if generation is None:
+ generation = generate_answer_and_trace(
+ question=question,
+ model_name=resolved_model_name,
+ model=bundle.model,
+ tokenizer=bundle.tokenizer,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ )
+
+ truncated_text = truncate_text_to_token_limit(
+ generation.normalized_trace_text,
+ bundle.tokenizer,
+ resolved_max_trace_tokens,
+ )
+ sentence_spans = split_sentences(truncated_text)
+ if resolved_max_sentences > 0 and len(sentence_spans) > resolved_max_sentences:
+ sentence_spans = sentence_spans[:resolved_max_sentences]
+ truncated_text = truncated_text[: sentence_spans[-1].end_char]
+ sentence_spans = split_sentences(truncated_text)
+
+ if not sentence_spans:
+ raise RuntimeError("Trace normalization produced no analyzable sentences.")
+
+ mapping = tokenize_with_sentence_ranges(truncated_text, sentence_spans, bundle.tokenizer)
+ input_ids = mapping.input_ids.to(bundle.device)
+ computation = compute_attribution_matrix(
+ input_ids=input_ids,
+ token_ranges=mapping.token_ranges,
+ model=bundle.model,
+ take_log=resolved_take_log,
+ max_trace_tokens=resolved_max_trace_tokens,
+ max_sentences=resolved_max_sentences,
+ )
+
+ outgoing = compute_outgoing_importance(computation.raw_matrix)
+ incoming = compute_incoming_importance(computation.raw_matrix)
+ top_edges = compute_top_edges(computation.raw_matrix, top_k=10)
+ validation = validate_top_edges(
+ model=bundle.model,
+ input_ids=input_ids,
+ token_ranges=mapping.token_ranges,
+ top_edges=top_edges,
+ baseline_token_nll=computation.token_nll,
+ top_k=validate_top_k,
+ )
+
+ return AnalysisResult(
+ question=question,
+ model_name=resolved_model_name,
+ answer=generation.answer,
+ raw_trace_text=generation.raw_trace_text,
+ normalized_trace_text=truncated_text,
+ sentences=[span.text for span in sentence_spans],
+ sentence_token_ranges=mapping.token_ranges,
+ suppression_matrix=computation.matrix.tolist(),
+ raw_suppression_matrix=computation.raw_matrix.tolist(),
+ outgoing_importance=outgoing,
+ incoming_importance=incoming,
+ top_edges=top_edges,
+ runtime_metadata=computation.runtime_metadata,
+ validation_metadata=validation,
+ extra_metadata={
+ "raw_generation_text": generation.raw_generation_text,
+ "generation_metadata": generation.generation_metadata.model_dump(),
+ "effective_runtime": {
+ "device_preference": resolved_device,
+ "dtype_preference": resolved_dtype,
+ "attn_implementation": resolved_attn_implementation,
+ "trust_remote_code": resolved_trust_remote_code,
+ "low_cpu_mem_usage": resolved_low_cpu_mem_usage,
+ },
+ },
+ )
diff --git a/app/core/schemas.py b/app/core/schemas.py
new file mode 100644
index 0000000000000000000000000000000000000000..a900d77eb63eeaba691c6583cdb6cea47025e1ba
--- /dev/null
+++ b/app/core/schemas.py
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class TopEdge(BaseModel):
+ source_sentence_idx: int
+ target_sentence_idx: int
+ score: float
+
+
+class ModelCapability(BaseModel):
+ supports_attribution: bool
+ reason: str | None = None
+ layer_path: str | None = None
+ attention_attr: str | None = None
+ layer_count: int = 0
+ attention_impl: str | None = None
+
+
+class RuntimeMetadata(BaseModel):
+ forward_pass_ms: float = 0.0
+ backward_pass_ms: float = 0.0
+ matrix_computation_ms: float = 0.0
+ total_analysis_ms: float = 0.0
+ num_layers: int = 0
+ num_heads: int = 0
+ sequence_length_tokens: int = 0
+ sentence_count: int = 0
+ device: str = "unknown"
+ dtype: str = "unknown"
+ attention_impl: str = "unknown"
+ max_trace_tokens: int = 0
+ max_sentences: int = 0
+ capability: ModelCapability = Field(default_factory=lambda: ModelCapability(supports_attribution=False))
+
+
+class ValidationMetadata(BaseModel):
+ enabled: bool = False
+ top_k: int = 0
+ pearson: float | None = None
+ spearman: float | None = None
+ compared_edges: list[TopEdge] = Field(default_factory=list)
+ notes: str | None = None
+
+
+class GenerationMetadata(BaseModel):
+ max_new_tokens: int
+ temperature: float
+ top_p: float
+ do_sample: bool
+
+
+class GenerationResult(BaseModel):
+ question: str
+ model_name: str
+ answer: str
+ raw_generation_text: str
+ raw_trace_text: str
+ normalized_trace_text: str
+ generation_metadata: GenerationMetadata
+
+
+class AnalysisResult(BaseModel):
+ question: str
+ model_name: str
+ answer: str
+ raw_trace_text: str
+ normalized_trace_text: str
+ sentences: list[str]
+ sentence_token_ranges: list[tuple[int, int]]
+ suppression_matrix: list[list[float]]
+ raw_suppression_matrix: list[list[float]] | None = None
+ outgoing_importance: list[float]
+ incoming_importance: list[float]
+ top_edges: list[TopEdge]
+ runtime_metadata: RuntimeMetadata
+ validation_metadata: ValidationMetadata | None = None
+ extra_metadata: dict[str, Any] = Field(default_factory=dict)
+
+
+class AnalysisRequest(BaseModel):
+ question: str
+ model_name: str | None = None
+ take_log: bool | None = None
+ max_sentences: int | None = None
+ max_trace_tokens: int | None = None
+ validate_top_k: int = 0
+ max_new_tokens: int = 256
+ temperature: float = 0.6
+ top_p: float = 0.95
+ device_preference: str | None = None
+ dtype_preference: str | None = None
+ attn_implementation: str | None = None
+ trust_remote_code: bool | None = None
+ low_cpu_mem_usage: bool | None = None
+
+
+class HealthResponse(BaseModel):
+ status: str
+ model_name: str
+ device_preference: str
+ dtype_preference: str
+ preload_model: bool
+ cuda_available: bool
+ mps_available: bool
+ require_auth: bool
+ public_api_enabled: bool
+ max_queued_jobs: int
+ max_active_jobs_per_user: int
+
+
+class WarmupResponse(BaseModel):
+ status: str
+ model_name: str
+ device: str
+ dtype: str
+ capability: ModelCapability
+
+
+class CurrentUserResponse(BaseModel):
+ authenticated: bool
+ auth_required: bool
+ username: str | None = None
+ full_name: str | None = None
+ avatar_url: str | None = None
+ login_url: str = "/oauth/huggingface/login"
+ logout_url: str = "/oauth/huggingface/logout"
+
+
+class SessionCreateRequest(BaseModel):
+ question: str
+ model_name: str | None = None
+ take_log: bool | None = None
+ max_sentences: int | None = None
+ max_trace_tokens: int | None = None
+ validate_top_k: int = 0
+ max_new_tokens: int = 256
+ temperature: float = 0.6
+ top_p: float = 0.95
+ device_preference: str | None = None
+ dtype_preference: str | None = None
+ attn_implementation: str | None = None
+ trust_remote_code: bool | None = None
+ low_cpu_mem_usage: bool | None = None
+
+
+class SessionResponse(BaseModel):
+ id: str
+ status: str
+ question: str
+ model_name: str
+ error: str | None = None
+ created_at: str
+ updated_at: str
+ answer: str | None = None
+ raw_trace_text: str | None = None
+ normalized_trace_text: str | None = None
+ sentences: list[str] | None = None
+ generation_metadata: dict[str, Any] | None = None
+
+
+class SessionResultResponse(BaseModel):
+ session: SessionResponse
+ analysis: AnalysisResult | None = None
diff --git a/app/frontend/app.js b/app/frontend/app.js
new file mode 100644
index 0000000000000000000000000000000000000000..473ef1254a1d56a84a9caeb970281c9b4206cf06
--- /dev/null
+++ b/app/frontend/app.js
@@ -0,0 +1,340 @@
+const form = document.querySelector("#question-form");
+const submitButton = document.querySelector("#submit-button");
+const statusChip = document.querySelector("#status-chip");
+const statusDetail = document.querySelector("#status-detail");
+const authSummary = document.querySelector("#auth-summary");
+const loginLink = document.querySelector("#login-link");
+const logoutLink = document.querySelector("#logout-link");
+const recentSessions = document.querySelector("#recent-sessions");
+const answerOutput = document.querySelector("#answer-output");
+const sessionIdOutput = document.querySelector("#session-id");
+const sentenceList = document.querySelector("#sentence-list");
+const traceCount = document.querySelector("#trace-count");
+const heatmap = document.querySelector("#heatmap");
+const selectionDetails = document.querySelector("#selection-details");
+const metrics = document.querySelector("#metrics");
+const topEdges = document.querySelector("#top-edges");
+const exportJson = document.querySelector("#export-json");
+const exportCsv = document.querySelector("#export-csv");
+
+let activeSessionId = null;
+let pollHandle = null;
+let activePayload = null;
+let selectedSentenceIdx = null;
+let selectedCell = null;
+let currentUser = null;
+
+function currentSentences() {
+ return activePayload?.analysis?.sentences || activePayload?.session?.sentences || [];
+}
+
+function setStatus(label, detail) {
+ statusChip.textContent = label;
+ statusDetail.textContent = detail;
+}
+
+function clearSelection() {
+ selectedSentenceIdx = null;
+ selectedCell = null;
+ selectionDetails.textContent = "Choose a sentence or heatmap cell.";
+ metrics.innerHTML = "";
+}
+
+function setExportLinks(sessionId, enabled) {
+ exportJson.href = enabled ? `/api/sessions/${sessionId}/export.json` : "#";
+ exportCsv.href = enabled ? `/api/sessions/${sessionId}/export.csv` : "#";
+ exportJson.classList.toggle("is-disabled", !enabled);
+ exportCsv.classList.toggle("is-disabled", !enabled);
+ exportJson.setAttribute("aria-disabled", String(!enabled));
+ exportCsv.setAttribute("aria-disabled", String(!enabled));
+}
+
+function renderRecentSessions(sessions) {
+ recentSessions.innerHTML = "";
+ if (!currentUser?.authenticated) {
+ recentSessions.textContent = currentUser?.auth_required
+ ? "Sign in to view your jobs."
+ : "Anonymous mode is enabled.";
+ return;
+ }
+ if (!sessions.length) {
+ recentSessions.textContent = "No jobs yet on this running Space instance.";
+ return;
+ }
+ sessions.forEach((session) => {
+ const item = document.createElement("button");
+ item.type = "button";
+ item.className = "session-card";
+ if (session.id === activeSessionId) {
+ item.classList.add("is-active");
+ }
+ item.innerHTML = `${session.status} ${session.question}`;
+ item.addEventListener("click", () => {
+ startPolling(session.id);
+ });
+ recentSessions.appendChild(item);
+ });
+}
+
+async function loadRecentSessions() {
+ if (!currentUser?.authenticated && currentUser?.auth_required) {
+ renderRecentSessions([]);
+ return;
+ }
+ const response = await fetch("/api/sessions?limit=8");
+ if (!response.ok) {
+ renderRecentSessions([]);
+ return;
+ }
+ renderRecentSessions(await response.json());
+}
+
+function renderAuth() {
+ if (!currentUser) {
+ authSummary.textContent = "Checking sign-in status.";
+ return;
+ }
+ loginLink.hidden = currentUser.authenticated;
+ logoutLink.hidden = !currentUser.authenticated;
+ if (currentUser.authenticated) {
+ authSummary.textContent = `Signed in as ${currentUser.full_name || currentUser.username}.`;
+ return;
+ }
+ authSummary.textContent = currentUser.auth_required
+ ? "Sign in with Hugging Face to run analysis jobs."
+ : "Anonymous access is enabled for this Space.";
+}
+
+function renderSession(session) {
+ sessionIdOutput.textContent = session.id ? `Session ${session.id.slice(0, 8)}` : "";
+ answerOutput.textContent = session.answer || "Waiting for generated answer.";
+ const sentences = currentSentences();
+ traceCount.textContent = sentences.length ? `${sentences.length} sentences` : "";
+
+ sentenceList.innerHTML = "";
+ sentences.forEach((sentence, index) => {
+ const item = document.createElement("li");
+ item.className = "sentence-item";
+ item.innerHTML = `${index} ${sentence}`;
+ item.addEventListener("click", () => {
+ selectedSentenceIdx = index;
+ selectedCell = null;
+ renderSelection();
+ });
+ sentenceList.appendChild(item);
+ });
+}
+
+function colorForValue(value, maxAbs) {
+ if (!maxAbs) {
+ return "rgba(224, 223, 218, 0.8)";
+ }
+ const normalized = Math.max(-1, Math.min(1, value / maxAbs));
+ if (normalized >= 0) {
+ const alpha = 0.12 + normalized * 0.88;
+ return `rgba(14, 90, 138, ${alpha.toFixed(3)})`;
+ }
+ const alpha = 0.12 + Math.abs(normalized) * 0.88;
+ return `rgba(215, 106, 52, ${alpha.toFixed(3)})`;
+}
+
+function renderHeatmap(result) {
+ const matrix = result?.suppression_matrix;
+ if (!matrix || !matrix.length) {
+ heatmap.className = "heatmap placeholder-box";
+ heatmap.textContent = "Analysis pending.";
+ return;
+ }
+
+ const flatValues = matrix.flat();
+ const maxAbs = Math.max(...flatValues.map((value) => Math.abs(value)));
+ heatmap.className = "heatmap";
+ heatmap.innerHTML = "";
+ const grid = document.createElement("div");
+ grid.className = "heatmap-grid";
+ grid.style.gridTemplateColumns = `repeat(${matrix.length}, 32px)`;
+
+ matrix.forEach((row, rowIndex) => {
+ row.forEach((value, colIndex) => {
+ const cell = document.createElement("button");
+ cell.type = "button";
+ cell.className = "heatmap-cell";
+ cell.style.background = colorForValue(value, maxAbs);
+ cell.title = `target ${rowIndex} ā source ${colIndex}: ${value.toFixed(4)}`;
+ if (selectedCell && selectedCell.row === rowIndex && selectedCell.col === colIndex) {
+ cell.classList.add("is-selected");
+ }
+ cell.addEventListener("click", () => {
+ selectedSentenceIdx = null;
+ selectedCell = { row: rowIndex, col: colIndex, value };
+ renderSelection();
+ });
+ grid.appendChild(cell);
+ });
+ });
+ heatmap.appendChild(grid);
+}
+
+function renderTopEdges(result) {
+ topEdges.innerHTML = "";
+ const edges = result?.top_edges || [];
+ if (!edges.length) {
+ return;
+ }
+ edges.slice(0, 5).forEach((edge) => {
+ const item = document.createElement("div");
+ item.className = "edge-card";
+ item.innerHTML = `${edge.source_sentence_idx} ā ${edge.target_sentence_idx} ${edge.score.toFixed(4)}`;
+ topEdges.appendChild(item);
+ });
+}
+
+function renderSelection() {
+ Array.from(sentenceList.children).forEach((item, index) => {
+ item.classList.toggle("is-active", selectedSentenceIdx === index);
+ });
+
+ const result = activePayload?.analysis;
+ const session = activePayload?.session;
+ if (!session) {
+ clearSelection();
+ return;
+ }
+
+ metrics.innerHTML = "";
+ if (selectedSentenceIdx != null && result) {
+ const outgoing = result.outgoing_importance[selectedSentenceIdx] ?? 0;
+ const incoming = result.incoming_importance[selectedSentenceIdx] ?? 0;
+ selectionDetails.innerHTML = `Sentence ${selectedSentenceIdx} ${currentSentences()[selectedSentenceIdx] || ""}`;
+ metrics.innerHTML = `
+
Outgoing impact ${outgoing.toFixed(4)}
+ Incoming dependence ${incoming.toFixed(4)}
+ `;
+ return;
+ }
+
+ if (selectedCell) {
+ selectionDetails.innerHTML = `Edge ${selectedCell.col} ā ${selectedCell.row} Influence score ${selectedCell.value.toFixed(4)}`;
+ metrics.innerHTML = `
+ Source sentence ${selectedCell.col}
+ Target sentence ${selectedCell.row}
+ `;
+ return;
+ }
+
+ selectionDetails.textContent = "Choose a sentence or heatmap cell.";
+}
+
+async function fetchSession(sessionId) {
+ const response = await fetch(`/api/sessions/${sessionId}/result`);
+ if (!response.ok) {
+ throw new Error(`Failed to fetch session ${sessionId}`);
+ }
+ return response.json();
+}
+
+function updateFromPayload(payload) {
+ activePayload = payload;
+ activeSessionId = payload.session.id;
+ renderSession(payload.session);
+ renderHeatmap(payload.analysis);
+ renderTopEdges(payload.analysis);
+ renderSelection();
+ setExportLinks(payload.session.id, payload.session.status === "completed");
+
+ const { status, error } = payload.session;
+ if (status === "queued") setStatus("Queued", "Waiting for a worker slot.");
+ if (status === "generating") setStatus("Generating", "The model is producing an answer and visible trace.");
+ if (status === "answer_ready") setStatus("Analysis pending", "Answer is ready. Attribution analysis is starting.");
+ if (status === "analyzing") setStatus("Analyzing", "Running forward and backward passes for sentence influence.");
+ if (status === "completed") setStatus("Completed", "Analysis finished.");
+ if (status === "failed") setStatus("Failed", error || "The session failed.");
+
+ if (["completed", "failed"].includes(status) && pollHandle) {
+ window.clearInterval(pollHandle);
+ pollHandle = null;
+ }
+
+ loadRecentSessions().catch(() => {});
+}
+
+async function startPolling(sessionId) {
+ activeSessionId = sessionId;
+ if (pollHandle) {
+ window.clearInterval(pollHandle);
+ }
+ const tick = async () => {
+ try {
+ const payload = await fetchSession(sessionId);
+ updateFromPayload(payload);
+ } catch (error) {
+ setStatus("Error", String(error));
+ window.clearInterval(pollHandle);
+ pollHandle = null;
+ }
+ };
+ await tick();
+ pollHandle = window.setInterval(tick, 2500);
+}
+
+form.addEventListener("submit", async (event) => {
+ event.preventDefault();
+ if (!currentUser?.authenticated && currentUser?.auth_required) {
+ window.location.href = loginLink.href;
+ return;
+ }
+ submitButton.disabled = true;
+ clearSelection();
+ sentenceList.innerHTML = "";
+ heatmap.className = "heatmap placeholder-box";
+ heatmap.textContent = "Analysis pending.";
+ topEdges.innerHTML = "";
+ answerOutput.textContent = "Waiting for generated answer.";
+ sessionIdOutput.textContent = "";
+ traceCount.textContent = "";
+ setExportLinks("", false);
+ setStatus("Submitting", "Creating session.");
+
+ const payload = {
+ question: document.querySelector("#question").value,
+ max_new_tokens: Number(document.querySelector("#max-new-tokens").value),
+ max_trace_tokens: Number(document.querySelector("#max-trace-tokens").value),
+ max_sentences: Number(document.querySelector("#max-sentences").value),
+ };
+
+ try {
+ const response = await fetch("/api/sessions", {
+ method: "POST",
+ headers: { "content-type": "application/json" },
+ body: JSON.stringify(payload),
+ });
+ if (!response.ok) {
+ throw new Error(await response.text());
+ }
+ const session = await response.json();
+ await startPolling(session.id);
+ } catch (error) {
+ setStatus("Error", String(error));
+ } finally {
+ submitButton.disabled = false;
+ }
+});
+
+async function initialize() {
+ setExportLinks("", false);
+ try {
+ const response = await fetch("/api/me");
+ currentUser = await response.json();
+ } catch (_error) {
+ currentUser = {
+ authenticated: false,
+ auth_required: true,
+ login_url: "/oauth/huggingface/login",
+ logout_url: "/oauth/huggingface/logout",
+ };
+ }
+ renderAuth();
+ await loadRecentSessions();
+}
+
+initialize();
diff --git a/app/frontend/index.html b/app/frontend/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..ac4c978cdedd7c8a0c4be401c81bcd1e4294a435
--- /dev/null
+++ b/app/frontend/index.html
@@ -0,0 +1,108 @@
+
+
+
+
+
+ Thought Anchors
+
+
+
+
+
+
+
Public Hugging Face Space
+
Checking sign-in status.
+
+
+
+
+
+ Online Chain-of-Thought Analysis
+ Reasoning traces with white-box sentence influence.
+
+ Submit a question, watch the answer appear first, then inspect the sentence-to-sentence
+ attribution matrix once analysis completes. Results can be exported as JSON or CSV.
+
+
+
+
+
+
+ Idle
+ Submit a question to create a session.
+
+
+
+
+
+
Your Sessions
+ Ephemeral instance history
+
+ Sign in to view your jobs.
+
+
+
+
+
Answer
+
+
+ No answer yet.
+
+
+
+
+
Reasoning Trace
+
+
+
+
+
+
+
+
Sentence Influence Matrix
+ Rows: targets, columns: sources
+
+ Analysis pending.
+
+
+
+
+
Selection
+
+ Choose a sentence or heatmap cell.
+
+
+
+
+
+
+
+
+
+
diff --git a/app/frontend/styles.css b/app/frontend/styles.css
new file mode 100644
index 0000000000000000000000000000000000000000..c4f8f94e7dcf9e3d1cc2a6d5f34072b5c5a98db5
--- /dev/null
+++ b/app/frontend/styles.css
@@ -0,0 +1,335 @@
+:root {
+ --bg: #f3efe6;
+ --panel: rgba(255, 252, 247, 0.9);
+ --ink: #1e1c19;
+ --muted: #6f655c;
+ --accent: #0e5a8a;
+ --accent-soft: #d2ecff;
+ --line: rgba(30, 28, 25, 0.12);
+ --warm: #d76a34;
+ --shadow: 0 20px 60px rgba(34, 25, 16, 0.08);
+}
+
+* {
+ box-sizing: border-box;
+}
+
+body {
+ margin: 0;
+ font-family: "Iowan Old Style", "Palatino Linotype", serif;
+ color: var(--ink);
+ background:
+ radial-gradient(circle at top left, rgba(14, 90, 138, 0.12), transparent 32%),
+ radial-gradient(circle at top right, rgba(215, 106, 52, 0.18), transparent 26%),
+ linear-gradient(180deg, #f8f4ec 0%, var(--bg) 100%);
+}
+
+.page {
+ max-width: 1440px;
+ margin: 0 auto;
+ padding: 40px 24px 48px;
+}
+
+.hero {
+ max-width: 760px;
+ margin-bottom: 28px;
+}
+
+.compact {
+ margin: 0;
+}
+
+.auth-strip {
+ display: flex;
+ align-items: center;
+ justify-content: space-between;
+ gap: 16px;
+ margin-bottom: 18px;
+}
+
+.auth-actions,
+.export-actions {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 10px;
+}
+
+.eyebrow {
+ margin: 0 0 8px;
+ text-transform: uppercase;
+ letter-spacing: 0.12em;
+ font-size: 0.78rem;
+ color: var(--accent);
+}
+
+h1, h2 {
+ margin: 0;
+ font-weight: 600;
+}
+
+h1 {
+ font-size: clamp(2.4rem, 5vw, 4.4rem);
+ line-height: 0.95;
+}
+
+.lede {
+ margin: 18px 0 0;
+ font-size: 1.08rem;
+ color: var(--muted);
+ max-width: 60ch;
+}
+
+.panel {
+ background: var(--panel);
+ border: 1px solid var(--line);
+ border-radius: 24px;
+ padding: 20px;
+ box-shadow: var(--shadow);
+ backdrop-filter: blur(18px);
+}
+
+.composer {
+ margin-bottom: 16px;
+}
+
+.label,
+.controls span {
+ display: block;
+ margin-bottom: 8px;
+ font-size: 0.92rem;
+ color: var(--muted);
+}
+
+textarea,
+input {
+ width: 100%;
+ border: 1px solid rgba(30, 28, 25, 0.16);
+ border-radius: 14px;
+ padding: 14px 16px;
+ font: inherit;
+ background: rgba(255, 255, 255, 0.72);
+}
+
+textarea {
+ resize: vertical;
+ min-height: 132px;
+}
+
+.controls {
+ display: grid;
+ grid-template-columns: repeat(3, minmax(0, 1fr));
+ gap: 12px;
+ margin: 16px 0;
+}
+
+button {
+ appearance: none;
+ border: none;
+ border-radius: 999px;
+ padding: 14px 20px;
+ font: inherit;
+ font-weight: 600;
+ color: white;
+ background: linear-gradient(135deg, var(--accent), #123954);
+ cursor: pointer;
+}
+
+.pill-link {
+ display: inline-flex;
+ align-items: center;
+ justify-content: center;
+ border-radius: 999px;
+ padding: 12px 18px;
+ text-decoration: none;
+ font-weight: 600;
+ color: white;
+ background: linear-gradient(135deg, var(--accent), #123954);
+}
+
+.pill-link.secondary {
+ color: var(--ink);
+ background: rgba(255, 255, 255, 0.78);
+ border: 1px solid var(--line);
+}
+
+.pill-link.is-disabled {
+ pointer-events: none;
+ opacity: 0.45;
+}
+
+button:disabled {
+ opacity: 0.5;
+ cursor: wait;
+}
+
+.status-row {
+ display: flex;
+ align-items: center;
+ gap: 12px;
+ margin-bottom: 16px;
+}
+
+.status-chip {
+ display: inline-flex;
+ align-items: center;
+ border-radius: 999px;
+ padding: 8px 14px;
+ background: var(--accent-soft);
+ color: var(--accent);
+ font-size: 0.9rem;
+}
+
+.status-detail,
+.muted,
+.placeholder {
+ color: var(--muted);
+}
+
+.grid {
+ display: grid;
+ grid-template-columns: 1.2fr 1fr;
+ gap: 16px;
+}
+
+.jobs-panel {
+ grid-column: 1 / -1;
+}
+
+.answer-panel,
+.trace-panel,
+.heatmap-panel,
+.details-panel {
+ min-height: 260px;
+}
+
+.answer-panel,
+.trace-panel {
+ grid-column: span 1;
+}
+
+.heatmap-panel,
+.details-panel {
+ grid-column: span 1;
+}
+
+.panel-heading {
+ display: flex;
+ justify-content: space-between;
+ gap: 12px;
+ align-items: baseline;
+ margin-bottom: 16px;
+}
+
+.sentence-list {
+ list-style: none;
+ margin: 0;
+ padding: 0;
+ max-height: 420px;
+ overflow: auto;
+}
+
+.sentence-item {
+ padding: 10px 12px;
+ border-radius: 14px;
+ border: 1px solid transparent;
+ margin-bottom: 8px;
+ background: rgba(255,255,255,0.45);
+ cursor: pointer;
+}
+
+.sentence-item:hover,
+.sentence-item.is-active {
+ border-color: rgba(14, 90, 138, 0.25);
+ background: rgba(210, 236, 255, 0.5);
+}
+
+.sentence-item strong {
+ display: inline-block;
+ min-width: 2.5rem;
+ color: var(--accent);
+}
+
+.heatmap {
+ display: grid;
+ gap: 4px;
+ overflow: auto;
+ min-height: 300px;
+}
+
+.heatmap-grid {
+ display: grid;
+ gap: 4px;
+}
+
+.heatmap-cell {
+ width: 32px;
+ height: 32px;
+ border-radius: 8px;
+ border: 1px solid rgba(255,255,255,0.4);
+ cursor: pointer;
+ position: relative;
+}
+
+.heatmap-cell.is-selected {
+ outline: 2px solid var(--ink);
+}
+
+.placeholder-box {
+ display: grid;
+ place-items: center;
+ border: 1px dashed var(--line);
+ border-radius: 18px;
+}
+
+.metric-card,
+.edge-card {
+ padding: 12px 14px;
+ border-radius: 16px;
+ background: rgba(255,255,255,0.45);
+ border: 1px solid var(--line);
+ margin-top: 10px;
+}
+
+.recent-sessions {
+ display: grid;
+ gap: 10px;
+}
+
+.session-card {
+ color: var(--ink);
+ border: 1px solid var(--line);
+ border-radius: 16px;
+ background: rgba(255, 255, 255, 0.48);
+ padding: 14px;
+ cursor: pointer;
+ text-align: left;
+}
+
+.session-card.is-active {
+ border-color: rgba(14, 90, 138, 0.25);
+ background: rgba(210, 236, 255, 0.42);
+}
+
+.session-card strong {
+ display: block;
+ margin-bottom: 4px;
+ color: var(--accent);
+}
+
+.metric-card strong,
+.edge-card strong {
+ display: block;
+ color: var(--accent);
+}
+
+@media (max-width: 960px) {
+ .auth-strip,
+ .grid,
+ .controls {
+ grid-template-columns: 1fr;
+ }
+
+ .auth-strip {
+ align-items: flex-start;
+ }
+}
diff --git a/app/generation/__init__.py b/app/generation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e50e4a879de95703f25268152e15662957f5050b
--- /dev/null
+++ b/app/generation/__init__.py
@@ -0,0 +1 @@
+"""Generation utilities for trace-producing model calls."""
diff --git a/app/generation/prompting.py b/app/generation/prompting.py
new file mode 100644
index 0000000000000000000000000000000000000000..29e12964c333740fddad3b0ced214f01961a028e
--- /dev/null
+++ b/app/generation/prompting.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from typing import Any
+
+
+SYSTEM_PROMPT = (
+ "You are a careful reasoning assistant. Respond with your full reasoning inside "
+ "... and then provide a concise final answer after the closing tag."
+)
+
+
+def build_messages(question: str) -> list[dict[str, str]]:
+ return [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": question},
+ ]
+
+
+def render_prompt(tokenizer: Any, question: str) -> str:
+ messages = build_messages(question)
+ if hasattr(tokenizer, "apply_chat_template"):
+ return tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+
+ return f"{SYSTEM_PROMPT}\n\nQuestion: {question}\n"
diff --git a/app/generation/service.py b/app/generation/service.py
new file mode 100644
index 0000000000000000000000000000000000000000..923e48e7a1cb59e82348e6b28b8a61d9b2eda509
--- /dev/null
+++ b/app/generation/service.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+import re
+from typing import Any
+
+import torch
+
+from app.analysis.sentence_split import normalize_trace_text
+from app.core.schemas import GenerationMetadata, GenerationResult
+from app.generation.prompting import render_prompt
+
+THINK_BLOCK_RE = re.compile(r"(.*?) ", re.IGNORECASE | re.DOTALL)
+ANSWER_MARKER_RE = re.compile(r"(?:^|\n)(?:final answer|answer)\s*:\s*", re.IGNORECASE)
+
+
+def _extract_trace_and_answer(text: str) -> tuple[str, str]:
+ match = THINK_BLOCK_RE.search(text)
+ if match:
+ raw_trace = match.group(0)
+ answer = text[match.end() :].strip()
+ if not answer:
+ answer = match.group(1).strip()
+ return raw_trace, answer
+
+ raw_trace = text.strip()
+ answer_match = ANSWER_MARKER_RE.search(text)
+ if answer_match:
+ answer = text[answer_match.end() :].strip()
+ else:
+ paragraphs = [part.strip() for part in text.split("\n\n") if part.strip()]
+ answer = paragraphs[-1] if paragraphs else raw_trace
+ return raw_trace, answer
+
+
+def generate_answer_and_trace(
+ *,
+ question: str,
+ model_name: str,
+ model: Any,
+ tokenizer: Any,
+ max_new_tokens: int = 512,
+ temperature: float = 0.6,
+ top_p: float = 0.95,
+) -> GenerationResult:
+ prompt_text = render_prompt(tokenizer, question)
+ encoded = tokenizer(prompt_text, return_tensors="pt")
+ model_device = next(model.parameters()).device
+ encoded = {key: value.to(model_device) for key, value in encoded.items()}
+ input_length = int(encoded["input_ids"].shape[-1])
+ do_sample = temperature > 0.0
+
+ generation_kwargs: dict[str, Any] = {
+ "max_new_tokens": max_new_tokens,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "pad_token_id": tokenizer.pad_token_id,
+ "eos_token_id": tokenizer.eos_token_id,
+ }
+ if do_sample:
+ generation_kwargs["temperature"] = temperature
+
+ with torch.no_grad():
+ output_ids = model.generate(**encoded, **generation_kwargs)
+
+ generated_ids = output_ids[0, input_length:]
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
+ raw_trace_text, answer = _extract_trace_and_answer(generated_text)
+ normalized_trace_text = normalize_trace_text(raw_trace_text)
+
+ return GenerationResult(
+ question=question,
+ model_name=model_name,
+ answer=answer,
+ raw_generation_text=generated_text,
+ raw_trace_text=raw_trace_text,
+ normalized_trace_text=normalized_trace_text,
+ generation_metadata=GenerationMetadata(
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ do_sample=do_sample,
+ ),
+ )
diff --git a/app/services/__init__.py b/app/services/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..df6f31315514e5aac192dc34aaeb4e48a44f8c4d
--- /dev/null
+++ b/app/services/__init__.py
@@ -0,0 +1 @@
+"""Application services for session orchestration."""
diff --git a/app/services/sessions.py b/app/services/sessions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e06155db33880dbdb5a568016f5695c95949679
--- /dev/null
+++ b/app/services/sessions.py
@@ -0,0 +1,157 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from app.analysis.sentence_split import split_sentences
+from app.core.config import Settings
+from app.core.runtime import load_model_bundle
+from app.core.runtime_pipeline import analyze_generation_result
+from app.core.schemas import AnalysisRequest, AnalysisResult, GenerationMetadata, GenerationResult
+from app.generation.service import generate_answer_and_trace
+from app.storage.repository import SessionRecord, SessionRepository
+from app.workers.jobs import JobRunner
+
+
+class SessionLimitError(RuntimeError):
+ pass
+
+
+class SessionAccessError(PermissionError):
+ pass
+
+
+@dataclass(slots=True)
+class SessionService:
+ settings: Settings
+ repository: SessionRepository
+ jobs: JobRunner
+
+ def create_session(self, request: AnalysisRequest, *, owner_id: str, owner_name: str | None) -> SessionRecord:
+ if self.repository.count_incomplete_sessions() >= self.settings.max_queued_jobs:
+ raise SessionLimitError("The service queue is full. Try again after a few minutes.")
+ if self.repository.count_incomplete_sessions_for_owner(owner_id) >= self.settings.max_active_jobs_per_user:
+ raise SessionLimitError("You already have the maximum number of active analysis jobs.")
+ model_name = request.model_name or self.settings.model_name
+ session = self.repository.create_session(
+ question=request.question,
+ model_name=model_name,
+ owner_id=owner_id,
+ owner_name=owner_name,
+ )
+ self.jobs.submit(self._run_session_pipeline, session.id, request)
+ return session
+
+ def start_analysis(
+ self,
+ session_id: str,
+ *,
+ owner_id: str,
+ request: AnalysisRequest | None = None,
+ ) -> SessionRecord:
+ session = self.repository.get_session(session_id)
+ self._assert_owner(session, owner_id)
+ effective_request = request or AnalysisRequest(question=session.question, model_name=session.model_name)
+ self.jobs.submit(self._run_analysis_only, session_id, effective_request)
+ return session
+
+ def get_session_payload(self, session_id: str, *, owner_id: str) -> dict:
+ session = self.repository.get_session(session_id)
+ self._assert_owner(session, owner_id)
+ return self.repository.list_session_payload(session_id)
+
+ def list_sessions(self, owner_id: str, *, limit: int = 20) -> list[dict]:
+ return self.repository.list_sessions_for_owner(owner_id, limit=limit)
+
+ def get_analysis_result(self, session_id: str, *, owner_id: str) -> AnalysisResult:
+ payload = self.get_session_payload(session_id, owner_id=owner_id)
+ analysis = payload.get("analysis")
+ if analysis is None:
+ raise KeyError(session_id)
+ return AnalysisResult.model_validate(analysis)
+
+ @staticmethod
+ def _assert_owner(session: SessionRecord, owner_id: str) -> None:
+ if session.owner_id != owner_id:
+ raise SessionAccessError("Session belongs to a different user.")
+
+ def _run_session_pipeline(self, session_id: str, request: AnalysisRequest) -> None:
+ try:
+ self.repository.update_status(session_id, status="generating")
+ bundle = load_model_bundle(
+ request.model_name or self.settings.model_name,
+ device_preference=request.device_preference or self.settings.device_preference,
+ dtype_preference=request.dtype_preference or self.settings.dtype_preference,
+ attn_implementation=request.attn_implementation or self.settings.attn_implementation,
+ trust_remote_code=(
+ self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
+ ),
+ low_cpu_mem_usage=(
+ self.settings.low_cpu_mem_usage
+ if request.low_cpu_mem_usage is None
+ else request.low_cpu_mem_usage
+ ),
+ )
+ generation = generate_answer_and_trace(
+ question=request.question,
+ model_name=bundle.model_name,
+ model=bundle.model,
+ tokenizer=bundle.tokenizer,
+ max_new_tokens=request.max_new_tokens,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ )
+ sentences = [span.text for span in split_sentences(generation.normalized_trace_text)]
+ self.repository.save_generation_result(session_id, generation, sentences)
+ self.repository.update_status(session_id, status="answer_ready")
+ self._run_analysis_only(session_id, request, generation=generation)
+ except Exception as exc:
+ self.repository.update_status(session_id, status="failed", error=str(exc))
+
+ def _run_analysis_only(
+ self,
+ session_id: str,
+ request: AnalysisRequest,
+ *,
+ generation=None,
+ ) -> None:
+ try:
+ self.repository.update_status(session_id, status="analyzing")
+ if generation is None:
+ payload = self.repository.list_session_payload(session_id)
+ if payload.get("generation_metadata") is not None:
+ generation = GenerationResult(
+ question=payload["question"],
+ model_name=payload["model_name"],
+ answer=payload["answer"],
+ raw_generation_text=payload.get("raw_generation_text", ""),
+ raw_trace_text=payload["raw_trace_text"],
+ normalized_trace_text=payload["normalized_trace_text"],
+ generation_metadata=GenerationMetadata.model_validate(payload["generation_metadata"]),
+ )
+ result = analyze_generation_result(
+ question=request.question,
+ generation=generation,
+ model_name=request.model_name or self.settings.model_name,
+ take_log=request.take_log,
+ max_sentences=request.max_sentences,
+ max_trace_tokens=request.max_trace_tokens,
+ validate_top_k=request.validate_top_k,
+ device_preference=request.device_preference or self.settings.device_preference,
+ dtype_preference=request.dtype_preference or self.settings.dtype_preference,
+ attn_implementation=request.attn_implementation or self.settings.attn_implementation,
+ trust_remote_code=(
+ self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
+ ),
+ low_cpu_mem_usage=(
+ self.settings.low_cpu_mem_usage
+ if request.low_cpu_mem_usage is None
+ else request.low_cpu_mem_usage
+ ),
+ max_new_tokens=request.max_new_tokens,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ )
+ self.repository.save_analysis_result(session_id, result)
+ self.repository.update_status(session_id, status="completed")
+ except Exception as exc:
+ self.repository.update_status(session_id, status="failed", error=str(exc))
diff --git a/app/storage/__init__.py b/app/storage/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bc5cba5510dc993738cdcd0e74c464a43755fbb
--- /dev/null
+++ b/app/storage/__init__.py
@@ -0,0 +1 @@
+"""Persistence layer for sessions and analysis results."""
diff --git a/app/storage/db.py b/app/storage/db.py
new file mode 100644
index 0000000000000000000000000000000000000000..312ca13d854196e94d0d72f6ac30dc3868111dea
--- /dev/null
+++ b/app/storage/db.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import sqlite3
+from pathlib import Path
+
+
+def connect(database_path: str) -> sqlite3.Connection:
+ path = Path(database_path)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ connection = sqlite3.connect(path, check_same_thread=False)
+ connection.row_factory = sqlite3.Row
+ connection.execute("PRAGMA journal_mode=WAL")
+ connection.execute("PRAGMA foreign_keys=ON")
+ return connection
+
+
+def initialize_schema(connection: sqlite3.Connection) -> None:
+ connection.executescript(
+ """
+ CREATE TABLE IF NOT EXISTS sessions (
+ id TEXT PRIMARY KEY,
+ status TEXT NOT NULL,
+ question TEXT NOT NULL,
+ model_name TEXT NOT NULL,
+ owner_id TEXT NOT NULL,
+ owner_name TEXT,
+ error TEXT,
+ created_at TEXT NOT NULL,
+ updated_at TEXT NOT NULL
+ );
+
+ CREATE TABLE IF NOT EXISTS generation_results (
+ session_id TEXT PRIMARY KEY REFERENCES sessions(id) ON DELETE CASCADE,
+ answer TEXT NOT NULL,
+ raw_generation_text TEXT NOT NULL,
+ raw_trace_text TEXT NOT NULL,
+ normalized_trace_text TEXT NOT NULL,
+ sentences_json TEXT NOT NULL,
+ generation_metadata_json TEXT NOT NULL
+ );
+
+ CREATE TABLE IF NOT EXISTS analysis_results (
+ session_id TEXT PRIMARY KEY REFERENCES sessions(id) ON DELETE CASCADE,
+ result_json TEXT NOT NULL
+ );
+ """
+ )
+ columns = {
+ row["name"] for row in connection.execute("PRAGMA table_info(generation_results)").fetchall()
+ }
+ if "raw_generation_text" not in columns:
+ connection.execute(
+ "ALTER TABLE generation_results ADD COLUMN raw_generation_text TEXT NOT NULL DEFAULT ''"
+ )
+ session_columns = {row["name"] for row in connection.execute("PRAGMA table_info(sessions)").fetchall()}
+ if "owner_id" not in session_columns:
+ connection.execute("ALTER TABLE sessions ADD COLUMN owner_id TEXT NOT NULL DEFAULT 'legacy-user'")
+ if "owner_name" not in session_columns:
+ connection.execute("ALTER TABLE sessions ADD COLUMN owner_name TEXT")
+ connection.commit()
diff --git a/app/storage/repository.py b/app/storage/repository.py
new file mode 100644
index 0000000000000000000000000000000000000000..53e20d809e4a937fef38f0d3c50257b180069cf0
--- /dev/null
+++ b/app/storage/repository.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import json
+import threading
+import uuid
+from dataclasses import dataclass
+from datetime import UTC, datetime
+from typing import Any
+
+from app.core.schemas import AnalysisResult, GenerationResult
+from app.storage.db import connect, initialize_schema
+
+
+def _utc_now() -> str:
+ return datetime.now(UTC).isoformat()
+
+
+@dataclass(slots=True)
+class SessionRecord:
+ id: str
+ status: str
+ question: str
+ model_name: str
+ owner_id: str
+ owner_name: str | None
+ error: str | None
+ created_at: str
+ updated_at: str
+
+
+class SessionRepository:
+ def __init__(self, database_path: str) -> None:
+ self.connection = connect(database_path)
+ initialize_schema(self.connection)
+ self.lock = threading.Lock()
+
+ def create_session(self, *, question: str, model_name: str, owner_id: str, owner_name: str | None) -> SessionRecord:
+ session_id = str(uuid.uuid4())
+ now = _utc_now()
+ with self.lock:
+ self.connection.execute(
+ """
+ INSERT INTO sessions (id, status, question, model_name, owner_id, owner_name, error, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """,
+ (session_id, "queued", question, model_name, owner_id, owner_name, None, now, now),
+ )
+ self.connection.commit()
+ return self.get_session(session_id)
+
+ def get_session(self, session_id: str) -> SessionRecord:
+ row = self.connection.execute(
+ "SELECT * FROM sessions WHERE id = ?",
+ (session_id,),
+ ).fetchone()
+ if row is None:
+ raise KeyError(session_id)
+ return SessionRecord(**dict(row))
+
+ def list_session_payload(self, session_id: str) -> dict[str, Any]:
+ session = self.get_session(session_id)
+ payload: dict[str, Any] = {
+ "id": session.id,
+ "status": session.status,
+ "question": session.question,
+ "model_name": session.model_name,
+ "owner_id": session.owner_id,
+ "owner_name": session.owner_name,
+ "error": session.error,
+ "created_at": session.created_at,
+ "updated_at": session.updated_at,
+ }
+ generation_row = self.connection.execute(
+ "SELECT * FROM generation_results WHERE session_id = ?",
+ (session_id,),
+ ).fetchone()
+ if generation_row is not None:
+ payload["answer"] = generation_row["answer"]
+ payload["raw_generation_text"] = generation_row["raw_generation_text"]
+ payload["raw_trace_text"] = generation_row["raw_trace_text"]
+ payload["normalized_trace_text"] = generation_row["normalized_trace_text"]
+ payload["sentences"] = json.loads(generation_row["sentences_json"])
+ payload["generation_metadata"] = json.loads(generation_row["generation_metadata_json"])
+ analysis_row = self.connection.execute(
+ "SELECT result_json FROM analysis_results WHERE session_id = ?",
+ (session_id,),
+ ).fetchone()
+ if analysis_row is not None:
+ payload["analysis"] = json.loads(analysis_row["result_json"])
+ return payload
+
+ def list_sessions_for_owner(self, owner_id: str, *, limit: int = 20) -> list[dict[str, Any]]:
+ rows = self.connection.execute(
+ """
+ SELECT id
+ FROM sessions
+ WHERE owner_id = ?
+ ORDER BY updated_at DESC
+ LIMIT ?
+ """,
+ (owner_id, limit),
+ ).fetchall()
+ return [self.list_session_payload(row["id"]) for row in rows]
+
+ def update_status(self, session_id: str, *, status: str, error: str | None = None) -> None:
+ now = _utc_now()
+ with self.lock:
+ self.connection.execute(
+ "UPDATE sessions SET status = ?, error = ?, updated_at = ? WHERE id = ?",
+ (status, error, now, session_id),
+ )
+ self.connection.commit()
+
+ def count_incomplete_sessions(self) -> int:
+ row = self.connection.execute(
+ "SELECT COUNT(*) AS count FROM sessions WHERE status NOT IN ('completed', 'failed')"
+ ).fetchone()
+ return int(row["count"])
+
+ def count_incomplete_sessions_for_owner(self, owner_id: str) -> int:
+ row = self.connection.execute(
+ """
+ SELECT COUNT(*) AS count
+ FROM sessions
+ WHERE owner_id = ? AND status NOT IN ('completed', 'failed')
+ """,
+ (owner_id,),
+ ).fetchone()
+ return int(row["count"])
+
+ def save_generation_result(self, session_id: str, generation: GenerationResult, sentences: list[str]) -> None:
+ with self.lock:
+ self.connection.execute(
+ """
+ INSERT INTO generation_results (
+ session_id,
+ answer,
+ raw_generation_text,
+ raw_trace_text,
+ normalized_trace_text,
+ sentences_json,
+ generation_metadata_json
+ ) VALUES (?, ?, ?, ?, ?, ?, ?)
+ ON CONFLICT(session_id) DO UPDATE SET
+ answer = excluded.answer,
+ raw_generation_text = excluded.raw_generation_text,
+ raw_trace_text = excluded.raw_trace_text,
+ normalized_trace_text = excluded.normalized_trace_text,
+ sentences_json = excluded.sentences_json,
+ generation_metadata_json = excluded.generation_metadata_json
+ """,
+ (
+ session_id,
+ generation.answer,
+ generation.raw_generation_text,
+ generation.raw_trace_text,
+ generation.normalized_trace_text,
+ json.dumps(sentences),
+ json.dumps(generation.generation_metadata.model_dump()),
+ ),
+ )
+ self.connection.execute(
+ "UPDATE sessions SET updated_at = ? WHERE id = ?",
+ (_utc_now(), session_id),
+ )
+ self.connection.commit()
+
+ def save_analysis_result(self, session_id: str, result: AnalysisResult) -> None:
+ with self.lock:
+ self.connection.execute(
+ """
+ INSERT INTO analysis_results (session_id, result_json)
+ VALUES (?, ?)
+ ON CONFLICT(session_id) DO UPDATE SET result_json = excluded.result_json
+ """,
+ (session_id, result.model_dump_json()),
+ )
+ self.connection.execute(
+ "UPDATE sessions SET updated_at = ? WHERE id = ?",
+ (_utc_now(), session_id),
+ )
+ self.connection.commit()
diff --git a/app/workers/__init__.py b/app/workers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c667ea692b3eb86f2af714fd0215f23805d3fd50
--- /dev/null
+++ b/app/workers/__init__.py
@@ -0,0 +1 @@
+"""Background job runner."""
diff --git a/app/workers/jobs.py b/app/workers/jobs.py
new file mode 100644
index 0000000000000000000000000000000000000000..df0fb3cc64026cedcee7481b0af6bd116e0cf24a
--- /dev/null
+++ b/app/workers/jobs.py
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+from concurrent.futures import Future, ThreadPoolExecutor
+from dataclasses import dataclass
+from typing import Callable
+
+
+@dataclass(slots=True)
+class JobRunner:
+ executor: ThreadPoolExecutor
+
+ def submit(self, fn: Callable[..., None], *args, **kwargs) -> Future:
+ return self.executor.submit(fn, *args, **kwargs)
+
+ def shutdown(self) -> None:
+ self.executor.shutdown(wait=False, cancel_futures=True)
+
+
+def build_job_runner(max_workers: int) -> JobRunner:
+ return JobRunner(executor=ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="cot-anc"))
diff --git a/cot_anc.egg-info/PKG-INFO b/cot_anc.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..0495e7caeedf7cd9f6d462fa0628ccc46b0d2ce8
--- /dev/null
+++ b/cot_anc.egg-info/PKG-INFO
@@ -0,0 +1,147 @@
+Metadata-Version: 2.4
+Name: cot-anc
+Version: 0.1.0
+Summary: Online chain-of-thought analysis with attribution patching
+Requires-Python: >=3.11
+Description-Content-Type: text/markdown
+Requires-Dist: fastapi>=0.115.0
+Requires-Dist: huggingface_hub[oauth]>=0.33.0
+Requires-Dist: numpy>=2.0.0
+Requires-Dist: pydantic>=2.7.0
+Requires-Dist: scipy>=1.13.0
+Requires-Dist: torch>=2.2.0
+Requires-Dist: transformers>=4.44.0
+Requires-Dist: typer>=0.12.3
+Requires-Dist: uvicorn>=0.30.0
+Provides-Extra: dev
+Requires-Dist: pytest>=8.2.0; extra == "dev"
+Provides-Extra: viz
+Requires-Dist: matplotlib>=3.8.0; extra == "viz"
+
+---
+title: Thought Anchors
+emoji: š§
+colorFrom: blue
+colorTo: orange
+sdk: docker
+app_port: 7860
+hf_oauth: true
+pinned: false
+models:
+ - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+---
+
+# Thought Anchors
+
+Public-facing FastAPI service for generating a visible reasoning trace and computing
+sentence-to-sentence attribution on open-weight reasoning models.
+
+The app is now shaped for deployment as a Hugging Face `Docker Space`:
+
+- Hugging Face OAuth sign-in for end users
+- browser UI plus programmatic API
+- per-user ephemeral sessions on the running instance
+- JSON and CSV export for completed analyses
+- adaptive device and dtype loading for CPU, MPS, CUDA, Colab, Kaggle, and cloud GPUs
+
+## What It Can Do
+
+- generate an answer plus visible `... ` trace from a supported causal LM
+- normalize and split the trace into sentences
+- compute a sentence influence matrix with gradient x attention attribution
+- summarize incoming / outgoing importance and top edges
+- expose the workflow through:
+ - CLI
+ - FastAPI
+ - web UI
+ - async session queue
+
+## Current Deployment Target
+
+The primary deployment target is a Hugging Face `Docker Space` running on upgraded GPU
+hardware. The same app can also be run locally or on other cloud GPU hosts.
+
+Important runtime constraints:
+
+- attribution requires `attn_implementation="eager"`
+- the model must expose usable attention tensors and a supported decoder-layer layout
+- long traces are intentionally capped because the analysis path uses a full backward pass
+
+## Local Development
+
+Install dependencies:
+
+```bash
+uv sync
+```
+
+Run the API:
+
+```bash
+uv run python -m app.cli.run_api
+```
+
+Run the CLI:
+
+```bash
+uv run python -m app.cli.run_prototype "Explain why the derivative of x^2 is 2x"
+```
+
+## API
+
+Main endpoints:
+
+- `GET /healthz`
+- `GET /api/me`
+- `POST /api/warmup`
+- `POST /api/analyze`
+- `GET /api/sessions`
+- `POST /api/sessions`
+- `GET /api/sessions/{id}`
+- `GET /api/sessions/{id}/result`
+- `GET /api/sessions/{id}/export.json`
+- `GET /api/sessions/{id}/export.csv`
+
+Example:
+
+```bash
+curl -X POST http://localhost:7860/api/analyze \
+ -H 'content-type: application/json' \
+ -d '{
+ "question": "Explain why the derivative of x^2 is 2x",
+ "max_new_tokens": 128,
+ "max_trace_tokens": 256,
+ "max_sentences": 16,
+ "validate_top_k": 0
+ }'
+```
+
+## Hugging Face Space Setup
+
+Recommended environment variables:
+
+- `MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`
+- `DEVICE_PREFERENCE=auto`
+- `DTYPE_PREFERENCE=auto`
+- `ATTN_IMPLEMENTATION=eager`
+- `LOW_CPU_MEM_USAGE=true`
+- `TRUST_REMOTE_CODE=true`
+- `PRELOAD_MODEL=true`
+- `MAX_TRACE_TOKENS=256`
+- `MAX_SENTENCES=16`
+- `JOB_WORKERS=1`
+- `MAX_QUEUED_JOBS=8`
+- `MAX_ACTIVE_JOBS_PER_USER=2`
+- `REQUIRE_AUTH=true`
+
+Notes:
+
+- local disk is ephemeral; users should export results they want to keep
+- use upgraded GPU hardware for real attribution runs
+- keep trace limits conservative for public traffic
+
+## Notebook
+
+Use [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
+for Colab or Kaggle smoke testing. It installs dependencies, warms the model, runs one
+short attribution analysis, prints the top edges, and renders a simple heatmap.
diff --git a/cot_anc.egg-info/SOURCES.txt b/cot_anc.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..df4de4b428753501cbfab92d1cd5500a55564e64
--- /dev/null
+++ b/cot_anc.egg-info/SOURCES.txt
@@ -0,0 +1,44 @@
+README.md
+pyproject.toml
+app/__init__.py
+app/analysis/__init__.py
+app/analysis/hooks.py
+app/analysis/sentence_split.py
+app/analysis/summaries.py
+app/analysis/suppression.py
+app/analysis/token_boundaries.py
+app/analysis/validation.py
+app/api/__init__.py
+app/api/auth.py
+app/api/main.py
+app/cli/__init__.py
+app/cli/run_api.py
+app/cli/run_prototype.py
+app/core/__init__.py
+app/core/config.py
+app/core/model_support.py
+app/core/runtime.py
+app/core/runtime_pipeline.py
+app/core/schemas.py
+app/generation/__init__.py
+app/generation/prompting.py
+app/generation/service.py
+app/services/__init__.py
+app/services/sessions.py
+app/storage/__init__.py
+app/storage/db.py
+app/storage/repository.py
+app/workers/__init__.py
+app/workers/jobs.py
+cot_anc.egg-info/PKG-INFO
+cot_anc.egg-info/SOURCES.txt
+cot_anc.egg-info/dependency_links.txt
+cot_anc.egg-info/requires.txt
+cot_anc.egg-info/top_level.txt
+tests/test_api.py
+tests/test_model_support.py
+tests/test_runtime_pipeline.py
+tests/test_sentence_split.py
+tests/test_summaries.py
+tests/test_suppression_shape.py
+tests/test_token_boundaries.py
\ No newline at end of file
diff --git a/cot_anc.egg-info/dependency_links.txt b/cot_anc.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/cot_anc.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/cot_anc.egg-info/requires.txt b/cot_anc.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9d7d47c9cfcdbea0a30307aae3d9c1fb1d6b8f49
--- /dev/null
+++ b/cot_anc.egg-info/requires.txt
@@ -0,0 +1,15 @@
+fastapi>=0.115.0
+huggingface_hub[oauth]>=0.33.0
+numpy>=2.0.0
+pydantic>=2.7.0
+scipy>=1.13.0
+torch>=2.2.0
+transformers>=4.44.0
+typer>=0.12.3
+uvicorn>=0.30.0
+
+[dev]
+pytest>=8.2.0
+
+[viz]
+matplotlib>=3.8.0
diff --git a/cot_anc.egg-info/top_level.txt b/cot_anc.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b80f0bd60822d4fa4893de455958ef32f6c521bf
--- /dev/null
+++ b/cot_anc.egg-info/top_level.txt
@@ -0,0 +1 @@
+app
diff --git a/docs/api.md b/docs/api.md
new file mode 100644
index 0000000000000000000000000000000000000000..387813ee7d50df1ada46c496aa861d0d92398283
--- /dev/null
+++ b/docs/api.md
@@ -0,0 +1,83 @@
+# API And Product Behavior
+
+## Auth Model
+
+- If `REQUIRE_AUTH=true`, analysis endpoints require Hugging Face sign-in.
+- Sessions are scoped to user identity.
+- One user cannot fetch another userās session or export.
+
+## Session Model
+
+Statuses:
+
+- `queued`
+- `generating`
+- `answer_ready`
+- `analyzing`
+- `completed`
+- `failed`
+
+Sessions are ephemeral for current running instance.
+
+## Endpoints
+
+### `GET /healthz`
+
+Returns:
+
+- model name
+- device preference
+- dtype preference
+- auth requirement
+- queue limits
+- CUDA / MPS availability
+
+### `GET /api/me`
+
+Returns current auth state plus login/logout URLs.
+
+### `POST /api/warmup`
+
+Loads model with current runtime policy and returns:
+
+- resolved device
+- resolved dtype
+- model attribution capability
+
+### `POST /api/analyze`
+
+Direct synchronous analysis call. Good for trusted programmatic use, not best path for UI.
+
+### `GET /api/sessions`
+
+List current userās recent sessions on current instance.
+
+### `POST /api/sessions`
+
+Create async session job.
+
+Queue protections:
+
+- global queue cap
+- per-user active-job cap
+
+### `GET /api/sessions/{id}`
+
+Return session summary.
+
+### `GET /api/sessions/{id}/result`
+
+Return session summary + full analysis if ready.
+
+### Export
+
+- `GET /api/sessions/{id}/export.json`
+- `GET /api/sessions/{id}/export.csv`
+
+JSON contains session + analysis payload.
+
+CSV contains top edges:
+
+- `source_sentence_idx`
+- `target_sentence_idx`
+- `score`
diff --git a/docs/deploy-huggingface.md b/docs/deploy-huggingface.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad342b5483eb281849f96816dae6de6cc3ee606e
--- /dev/null
+++ b/docs/deploy-huggingface.md
@@ -0,0 +1,69 @@
+# Hugging Face Deployment
+
+Primary target: Hugging Face `Docker Space` on upgraded GPU hardware.
+
+## What Gets Deployed
+
+- FastAPI backend
+- static web frontend
+- Hugging Face OAuth routes
+- ephemeral SQLite-backed session queue
+
+## Required Space Settings
+
+- SDK: `Docker`
+- Port: `7860`
+- OAuth: enabled via README metadata
+- Hardware: upgraded GPU recommended
+
+## Recommended Runtime Variables
+
+Core:
+
+- `MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`
+- `DEVICE_PREFERENCE=auto`
+- `DTYPE_PREFERENCE=auto`
+- `ATTN_IMPLEMENTATION=eager`
+- `LOW_CPU_MEM_USAGE=true`
+- `TRUST_REMOTE_CODE=true`
+- `PRELOAD_MODEL=true`
+
+Traffic limits:
+
+- `MAX_TRACE_TOKENS=256`
+- `MAX_SENTENCES=16`
+- `JOB_WORKERS=1`
+- `MAX_QUEUED_JOBS=8`
+- `MAX_ACTIVE_JOBS_PER_USER=2`
+- `REQUIRE_AUTH=true`
+
+## Deploy Flow
+
+1. Create new Hugging Face Space with `Docker` SDK.
+2. Push repo contents.
+3. Set runtime variables in Space settings.
+4. Upgrade hardware.
+5. Wait for build.
+6. Verify:
+ - `GET /healthz`
+ - sign-in works
+ - one short analysis completes
+ - JSON / CSV export works
+
+## Operational Notes
+
+- Local disk is ephemeral. Session history disappears on restart.
+- OAuth helper is mocked locally but real inside Space.
+- Keep public defaults conservative. Long traces can OOM small GPUs.
+- If queue pressure grows, lower token caps before increasing worker count.
+
+## Common Failure Modes
+
+- `attn_implementation` not eager:
+ - attribution disabled for model
+- unsupported model layout:
+ - generation may work, attribution fails early with clear error
+- OOM:
+ - reduce `MAX_TRACE_TOKENS`, `MAX_SENTENCES`, or choose larger GPU
+- cold start slow:
+ - keep `PRELOAD_MODEL=true`
diff --git a/docs/notebook.md b/docs/notebook.md
new file mode 100644
index 0000000000000000000000000000000000000000..c9df71bf6c96f78f1b3f1c08fbaa593b652c7244
--- /dev/null
+++ b/docs/notebook.md
@@ -0,0 +1,38 @@
+# Notebook Usage
+
+Notebook path:
+
+- [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
+
+## Purpose
+
+Short smoke test for:
+
+- Colab GPU
+- Kaggle GPU
+- quick local validation
+
+## What Notebook Does
+
+1. Install runtime deps.
+2. Set conservative env vars.
+3. Import project pipeline.
+4. Run one short attribution job.
+5. Print answer, runtime metadata, top edges.
+6. Render heatmap.
+
+## Best Use
+
+- validate model availability
+- validate driver / torch / transformers stack
+- sanity-check latency before public deploy
+
+## If Notebook Fails
+
+Check:
+
+- GPU available
+- model access permissions
+- enough VRAM
+- eager attention enabled
+- trace limits not too high
diff --git a/docs/runtime.md b/docs/runtime.md
new file mode 100644
index 0000000000000000000000000000000000000000..b55208ebc23888f8fd266d62e89869ed9cca9c00
--- /dev/null
+++ b/docs/runtime.md
@@ -0,0 +1,77 @@
+# Runtime And Model Support
+
+## Execution Model
+
+Pipeline:
+
+1. Generate visible reasoning trace.
+2. Normalize and split trace into sentences.
+3. Map sentence spans to token spans.
+4. Run forward + backward pass.
+5. Build sentence influence matrix from gradient x attention.
+6. Summarize top edges and importance scores.
+
+## Device And Dtype Policy
+
+Default policy:
+
+- CUDA:
+ - `bfloat16` if supported
+ - else `float16`
+- MPS:
+ - `float16`
+- CPU:
+ - `float32`
+
+Override with:
+
+- `DTYPE_PREFERENCE`
+- request `dtype_preference`
+
+## Model Requirements
+
+Model must support all of:
+
+- causal LM generation
+- `output_attentions=True`
+- eager attention
+- supported decoder layer layout
+- supported attention module attribute
+
+Supported layer paths:
+
+- `model.layers`
+- `model.model.layers`
+- `transformer.h`
+- `gpt_neox.layers`
+
+Supported attention attrs:
+
+- `self_attn`
+- `attn`
+- `attention`
+
+## Why Trace Limits Exist
+
+Attribution path uses full backward pass over attention tensors. Cost grows with:
+
+- sequence length
+- layer count
+- head count
+- sentence count
+
+Public defaults stay small to protect uptime.
+
+## Good First Runtime Settings
+
+For public demo:
+
+- `max_new_tokens=128`
+- `max_trace_tokens=256`
+- `max_sentences=16`
+- `validate_top_k=0`
+
+For deeper analysis on bigger GPU:
+
+- raise trace tokens slowly
+- watch latency and memory first
diff --git a/notebooks/hf_space_demo.ipynb b/notebooks/hf_space_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..104b532bf3b4df279509ef862509efe2066e8845
--- /dev/null
+++ b/notebooks/hf_space_demo.ipynb
@@ -0,0 +1,119 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Thought Anchors Demo\n",
+ "\n",
+ "Colab/Kaggle notebook for a short end-to-end attribution smoke test."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -q uv\n",
+ "!uv pip install --system fastapi huggingface_hub matplotlib numpy pydantic scipy torch transformers typer uvicorn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from pathlib import Path\n",
+ "\n",
+ "if not Path('app').exists():\n",
+ " raise RuntimeError('Upload or clone the repository before running the notebook.')\n",
+ "\n",
+ "os.environ.setdefault('MODEL_NAME', 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B')\n",
+ "os.environ.setdefault('DEVICE_PREFERENCE', 'auto')\n",
+ "os.environ.setdefault('DTYPE_PREFERENCE', 'auto')\n",
+ "os.environ.setdefault('ATTN_IMPLEMENTATION', 'eager')\n",
+ "os.environ.setdefault('LOW_CPU_MEM_USAGE', 'true')\n",
+ "os.environ.setdefault('TRUST_REMOTE_CODE', 'true')\n",
+ "os.environ.setdefault('MAX_TRACE_TOKENS', '256')\n",
+ "os.environ.setdefault('MAX_SENTENCES', '16')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from app.core.runtime_pipeline import compute_attribution_analysis"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "question = 'Explain why the derivative of x^2 is 2x.'\n",
+ "\n",
+ "result = compute_attribution_analysis(\n",
+ " question=question,\n",
+ " max_new_tokens=128,\n",
+ " max_trace_tokens=256,\n",
+ " max_sentences=16,\n",
+ " validate_top_k=0,\n",
+ " temperature=0.0,\n",
+ " top_p=1.0,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print('Answer:\\n', result.answer)\n",
+ "print('\\nSentences:', len(result.sentences))\n",
+ "print('Runtime:', result.runtime_metadata.model_dump())\n",
+ "print('\\nTop edges:')\n",
+ "for edge in result.top_edges[:10]:\n",
+ " print(edge)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "matrix = np.array(result.suppression_matrix)\n",
+ "figure, axis = plt.subplots(figsize=(7, 5))\n",
+ "image = axis.imshow(matrix, aspect='auto', cmap='viridis')\n",
+ "axis.set_title('Sentence Influence Matrix')\n",
+ "axis.set_xlabel('Source sentence')\n",
+ "axis.set_ylabel('Target sentence')\n",
+ "figure.colorbar(image, ax=axis)\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..86da7f3af083f7169c8bb511a1a420f487f1d937
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+[project]
+name = "cot-anc"
+version = "0.1.0"
+description = "Online chain-of-thought analysis with attribution patching"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+ "fastapi>=0.115.0",
+ "huggingface_hub[oauth]>=0.33.0",
+ "numpy>=2.0.0",
+ "pydantic>=2.7.0",
+ "scipy>=1.13.0",
+ "torch>=2.2.0",
+ "transformers>=4.44.0",
+ "typer>=0.12.3",
+ "uvicorn>=0.30.0",
+]
+
+[project.optional-dependencies]
+dev = [
+ "pytest>=8.2.0",
+]
+viz = [
+ "matplotlib>=3.8.0",
+]
+
+[build-system]
+requires = ["setuptools>=68.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["app*"]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+pythonpath = ["."]
diff --git a/runpod_start.sh b/runpod_start.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d0ee8e47a5ac27f90fbfcb1cd24f577f2a118c8a
--- /dev/null
+++ b/runpod_start.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+export API_HOST="${API_HOST:-0.0.0.0}"
+export API_PORT="${API_PORT:-7860}"
+export DEVICE_PREFERENCE="${DEVICE_PREFERENCE:-auto}"
+export DTYPE_PREFERENCE="${DTYPE_PREFERENCE:-auto}"
+export ATTN_IMPLEMENTATION="${ATTN_IMPLEMENTATION:-eager}"
+export PRELOAD_MODEL="${PRELOAD_MODEL:-true}"
+
+exec uv run python -m app.cli.run_api
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1f73d112e061885503c9912f69061645e398801
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import importlib.util
+import re
+from types import SimpleNamespace
+
+import pytest
+
+HAS_TORCH = importlib.util.find_spec("torch") is not None
+if HAS_TORCH:
+ import torch
+else: # pragma: no cover - environment dependent
+ torch = None
+
+
+class MockTokenizer:
+ def __call__(
+ self,
+ text: str,
+ *,
+ add_special_tokens: bool = False,
+ return_offsets_mapping: bool = False,
+ return_tensors: str | None = None,
+ ):
+ if torch is None: # pragma: no cover - environment dependent
+ raise RuntimeError("MockTokenizer requires torch.")
+ matches = list(re.finditer(r"\S+", text))
+ input_ids = [index + 1 for index, _match in enumerate(matches)]
+ result = {
+ "input_ids": torch.tensor([input_ids], dtype=torch.long)
+ if return_tensors == "pt"
+ else [input_ids]
+ }
+ if return_offsets_mapping:
+ offsets = [[match.start(), match.end()] for match in matches]
+ result["offset_mapping"] = (
+ torch.tensor([offsets], dtype=torch.long)
+ if return_tensors == "pt"
+ else offsets
+ )
+ return result
+
+
+if torch is not None:
+ class FakeSelfAttention(torch.nn.Module):
+ def __init__(self, hidden_size: int, heads: int) -> None:
+ super().__init__()
+ self.heads = heads
+ self.scale = hidden_size ** -0.5
+
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False, **_kwargs):
+ scores = torch.einsum("bqd,bkd->bqk", hidden_states, hidden_states) * self.scale
+ scores = scores.unsqueeze(1).repeat(1, self.heads, 1, 1)
+ causal_mask = torch.triu(
+ torch.ones(
+ scores.shape[-2:],
+ dtype=torch.bool,
+ device=hidden_states.device,
+ ),
+ diagonal=1,
+ )
+ scores = scores.masked_fill(
+ causal_mask.unsqueeze(0).unsqueeze(0),
+ torch.finfo(scores.dtype).min,
+ )
+ if attention_mask is not None and attention_mask.dim() == 4:
+ scores = scores + attention_mask
+ attention = torch.softmax(scores, dim=-1)
+ context = torch.einsum("bhqk,bkd->bhqd", attention, hidden_states).mean(dim=1)
+ return hidden_states + context, attention
+
+
+ class FakeLayer(torch.nn.Module):
+ def __init__(self, hidden_size: int, heads: int) -> None:
+ super().__init__()
+ self.self_attn = FakeSelfAttention(hidden_size=hidden_size, heads=heads)
+
+
+ class FakeCausalLM(torch.nn.Module):
+ def __init__(self, vocab_size: int = 32, hidden_size: int = 8, num_layers: int = 2, heads: int = 2):
+ super().__init__()
+ self.config = SimpleNamespace(_attn_implementation="eager")
+ self.embed = torch.nn.Embedding(vocab_size, hidden_size)
+ self.model = SimpleNamespace(
+ layers=torch.nn.ModuleList(
+ [FakeLayer(hidden_size=hidden_size, heads=heads) for _ in range(num_layers)]
+ )
+ )
+ self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
+
+ def forward(self, input_ids, attention_mask=None, output_attentions=False, return_dict=True, **_kwargs):
+ hidden_states = self.embed(input_ids)
+ attentions = []
+ for layer in self.model.layers:
+ hidden_states, attention = layer.self_attn(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ attentions.append(attention)
+ logits = self.lm_head(hidden_states)
+ if return_dict:
+ return SimpleNamespace(logits=logits, attentions=tuple(attentions))
+ return logits, tuple(attentions)
+
+
+@pytest.fixture()
+def mock_tokenizer() -> MockTokenizer:
+ if torch is None: # pragma: no cover - environment dependent
+ pytest.skip("torch not installed")
+ return MockTokenizer()
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..44af1161c76d6d393f2a4c876ec3741f4c11aa48
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+from fastapi.testclient import TestClient
+import pytest
+
+from app.api.main import app
+from app.core.schemas import (
+ AnalysisResult,
+ CurrentUserResponse,
+ ModelCapability,
+ RuntimeMetadata,
+ SessionResponse,
+ ValidationMetadata,
+)
+
+
+@pytest.fixture(autouse=True)
+def auth_override(monkeypatch):
+ monkeypatch.setattr(
+ "app.api.main.require_user",
+ lambda _request, _settings: type(
+ "User",
+ (),
+ {"id": "user-123", "display_name": "Test User", "authenticated": True},
+ )(),
+ )
+ monkeypatch.setattr(
+ "app.api.main.get_optional_user",
+ lambda _request: type(
+ "User",
+ (),
+ {
+ "id": "user-123",
+ "username": "tester",
+ "display_name": "Test User",
+ "avatar_url": "https://example.com/avatar.png",
+ "authenticated": True,
+ },
+ )(),
+ )
+
+
+def test_healthz_returns_runtime_flags() -> None:
+ with TestClient(app) as client:
+ response = client.get("/healthz")
+
+ assert response.status_code == 200
+ payload = response.json()
+ assert payload["status"] == "ok"
+ assert "cuda_available" in payload
+ assert "dtype_preference" in payload
+
+
+def test_analyze_delegates_to_runtime(monkeypatch) -> None:
+ def fake_compute_attribution_analysis(**_kwargs):
+ return AnalysisResult(
+ question="Why?",
+ model_name="fake-model",
+ answer="Because.",
+ raw_trace_text="Alpha. ",
+ normalized_trace_text="Alpha.",
+ sentences=["Alpha."],
+ sentence_token_ranges=[(0, 1)],
+ suppression_matrix=[[0.0]],
+ raw_suppression_matrix=[[0.0]],
+ outgoing_importance=[0.0],
+ incoming_importance=[0.0],
+ top_edges=[],
+ runtime_metadata=RuntimeMetadata(
+ device="cpu",
+ capability=ModelCapability(supports_attribution=True, layer_count=2, attention_impl="eager"),
+ ),
+ validation_metadata=ValidationMetadata(enabled=False, top_k=0),
+ )
+
+ monkeypatch.setattr("app.api.main.compute_attribution_analysis", fake_compute_attribution_analysis)
+
+ with TestClient(app) as client:
+ response = client.post(
+ "/api/analyze",
+ json={
+ "question": "Why?",
+ "max_new_tokens": 8,
+ "validate_top_k": 0,
+ },
+ )
+
+ assert response.status_code == 200
+ payload = response.json()
+ assert payload["answer"] == "Because."
+ assert payload["model_name"] == "fake-model"
+
+
+def test_me_reports_current_user() -> None:
+ with TestClient(app) as client:
+ response = client.get("/api/me")
+
+ assert response.status_code == 200
+ payload = CurrentUserResponse.model_validate(response.json())
+ assert payload.authenticated is True
+ assert payload.username == "tester"
+
+
+def test_root_serves_frontend() -> None:
+ with TestClient(app) as client:
+ response = client.get("/")
+
+ assert response.status_code == 200
+ assert "Thought Anchors" in response.text
+
+
+def test_session_routes_use_service(monkeypatch) -> None:
+ class FakeSessionService:
+ def __init__(self) -> None:
+ self.payload = {
+ "id": "session-123",
+ "status": "completed",
+ "question": "Why?",
+ "model_name": "fake-model",
+ "error": None,
+ "created_at": "2026-04-06T00:00:00+00:00",
+ "updated_at": "2026-04-06T00:00:05+00:00",
+ "answer": "Because.",
+ "raw_trace_text": "Alpha. ",
+ "normalized_trace_text": "Alpha.",
+ "sentences": ["Alpha."],
+ "generation_metadata": {"max_new_tokens": 8},
+ "analysis": AnalysisResult(
+ question="Why?",
+ model_name="fake-model",
+ answer="Because.",
+ raw_trace_text="Alpha. ",
+ normalized_trace_text="Alpha.",
+ sentences=["Alpha."],
+ sentence_token_ranges=[(0, 1)],
+ suppression_matrix=[[0.0]],
+ raw_suppression_matrix=[[0.0]],
+ outgoing_importance=[0.0],
+ incoming_importance=[0.0],
+ top_edges=[],
+ runtime_metadata=RuntimeMetadata(
+ device="cpu",
+ capability=ModelCapability(
+ supports_attribution=True,
+ layer_count=2,
+ attention_impl="eager",
+ ),
+ ),
+ validation_metadata=ValidationMetadata(enabled=False, top_k=0),
+ ).model_dump(),
+ }
+
+ def create_session(self, _request, **_kwargs):
+ return SessionResponse.model_validate(self.payload)
+
+ def get_session_payload(self, _session_id: str, **_kwargs):
+ return self.payload
+
+ def start_analysis(self, _session_id: str, **_kwargs):
+ return SessionResponse.model_validate(self.payload)
+
+ def list_sessions(self, _owner_id: str, **_kwargs):
+ return [self.payload]
+
+ def get_analysis_result(self, _session_id: str, **_kwargs):
+ return AnalysisResult.model_validate(self.payload["analysis"])
+
+ monkeypatch.setattr("app.api.main.get_session_service", lambda: FakeSessionService())
+
+ with TestClient(app) as client:
+ listing = client.get("/api/sessions")
+ created = client.post("/api/sessions", json={"question": "Why?"})
+ session = client.get("/api/sessions/session-123")
+ result = client.get("/api/sessions/session-123/result")
+ exported_json = client.get("/api/sessions/session-123/export.json")
+ exported_csv = client.get("/api/sessions/session-123/export.csv")
+
+ assert listing.status_code == 200
+ assert created.status_code == 200
+ assert session.status_code == 200
+ assert result.status_code == 200
+ assert exported_json.status_code == 200
+ assert exported_csv.status_code == 200
+ assert created.json()["id"] == "session-123"
+ assert session.json()["answer"] == "Because."
+ assert result.json()["analysis"]["model_name"] == "fake-model"
diff --git a/tests/test_model_support.py b/tests/test_model_support.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08e85992d90a7f5eeba3457b5942e75c50411ce
--- /dev/null
+++ b/tests/test_model_support.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+pytest.importorskip("torch")
+
+from app.core.model_support import describe_model_support
+
+
+def test_model_support_accepts_transformer_h_layout() -> None:
+ layer = SimpleNamespace(attn=object())
+ model = SimpleNamespace(
+ transformer=SimpleNamespace(h=[layer, layer]),
+ config=SimpleNamespace(_attn_implementation="eager"),
+ )
+
+ support = describe_model_support(model)
+
+ assert support.supports_attribution is True
+ assert support.layer_path == "transformer.h"
+ assert support.attention_attr == "attn"
+
+
+def test_model_support_rejects_non_eager_attention() -> None:
+ layer = SimpleNamespace(self_attn=object())
+ model = SimpleNamespace(
+ model=SimpleNamespace(layers=[layer]),
+ config=SimpleNamespace(_attn_implementation="sdpa"),
+ )
+
+ support = describe_model_support(model)
+
+ assert support.supports_attribution is False
+ assert "eager" in (support.reason or "")
diff --git a/tests/test_runtime_pipeline.py b/tests/test_runtime_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..c492ee6c9c246012af7af7b3f37f8c2f02c5d689
--- /dev/null
+++ b/tests/test_runtime_pipeline.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+pytest.importorskip("torch")
+
+from app.core.schemas import GenerationMetadata, GenerationResult
+from app.core.runtime_pipeline import compute_attribution_analysis
+
+
+def test_runtime_pipeline_with_mocked_generation(monkeypatch, mock_tokenizer) -> None:
+ from tests.conftest import FakeCausalLM
+
+ fake_model = FakeCausalLM()
+ fake_bundle = SimpleNamespace(
+ model=fake_model,
+ tokenizer=mock_tokenizer,
+ device=next(fake_model.parameters()).device,
+ dtype=next(fake_model.parameters()).dtype,
+ capability=SimpleNamespace(
+ supports_attribution=True,
+ reason=None,
+ layer_path="model.layers",
+ attention_attr="self_attn",
+ layer_count=len(fake_model.model.layers),
+ attention_impl="eager",
+ ),
+ )
+
+ def fake_load_model_bundle(_model_name: str, **_kwargs):
+ return fake_bundle
+
+ def fake_generate_answer_and_trace(**_kwargs):
+ return GenerationResult(
+ question="Why?",
+ model_name="fake-model",
+ answer="Because.",
+ raw_generation_text="Alpha beta. Gamma delta. Epsilon zeta. Because.",
+ raw_trace_text="Alpha beta. Gamma delta. Epsilon zeta. ",
+ normalized_trace_text="Alpha beta. Gamma delta. Epsilon zeta.",
+ generation_metadata=GenerationMetadata(
+ max_new_tokens=32,
+ temperature=0.0,
+ top_p=1.0,
+ do_sample=False,
+ ),
+ )
+
+ monkeypatch.setattr("app.core.runtime_pipeline.load_model_bundle", fake_load_model_bundle)
+ monkeypatch.setattr("app.core.runtime_pipeline.generate_answer_and_trace", fake_generate_answer_and_trace)
+
+ result = compute_attribution_analysis(
+ question="Why?",
+ model_name="fake-model",
+ validate_top_k=0,
+ max_trace_tokens=32,
+ max_sentences=5,
+ take_log=False,
+ )
+
+ assert result.answer == "Because."
+ assert len(result.sentences) == 3
+ assert len(result.suppression_matrix) == 3
+ assert result.validation_metadata is not None
+ assert result.validation_metadata.enabled is False
diff --git a/tests/test_sentence_split.py b/tests/test_sentence_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..f23e563bcff4c3b326b2a9344ade5c92ead76b2d
--- /dev/null
+++ b/tests/test_sentence_split.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from app.analysis.sentence_split import normalize_trace_text, split_sentences
+
+
+def test_sentence_split_preserves_spacing_and_newlines() -> None:
+ raw = "First sentence. Second sentence?\n\nThird sentence! "
+ normalized = normalize_trace_text(raw)
+ spans = split_sentences(normalized)
+
+ assert normalized == "First sentence. Second sentence?\n\nThird sentence!"
+ assert [span.text for span in spans] == [
+ "First sentence. ",
+ "Second sentence?\n\n",
+ "Third sentence!",
+ ]
+ assert spans[1].start_char == len("First sentence. ")
+
diff --git a/tests/test_summaries.py b/tests/test_summaries.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf76bb61cb746138403e22111c885ee4b2049d9d
--- /dev/null
+++ b/tests/test_summaries.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import numpy as np
+
+from app.analysis.summaries import (
+ compute_incoming_importance,
+ compute_outgoing_importance,
+ compute_top_edges,
+)
+
+
+def test_summary_metrics_follow_upper_triangle() -> None:
+ matrix = np.array(
+ [
+ [0.0, 0.0, 0.0],
+ [1.0, 0.0, 0.0],
+ [2.0, 3.0, 0.0],
+ ],
+ dtype=np.float32,
+ )
+
+ outgoing = compute_outgoing_importance(matrix)
+ incoming = compute_incoming_importance(matrix)
+ top_edges = compute_top_edges(matrix, top_k=2)
+
+ assert outgoing == [1.5, 3.0, 0.0]
+ assert incoming == [0.0, 1.0, 2.5]
+ assert [(edge.source_sentence_idx, edge.target_sentence_idx) for edge in top_edges] == [
+ (1, 2),
+ (0, 2),
+ ]
+
diff --git a/tests/test_suppression_shape.py b/tests/test_suppression_shape.py
new file mode 100644
index 0000000000000000000000000000000000000000..34b829bb053fd1e98872d20fb1821dc66df6d03b
--- /dev/null
+++ b/tests/test_suppression_shape.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+torch = pytest.importorskip("torch")
+
+from app.analysis.hooks import get_stored_attentions
+from app.analysis.suppression import compute_attribution_matrix
+
+
+def test_attribution_matrix_shape_and_cleanup() -> None:
+ from tests.conftest import FakeCausalLM
+
+ model = FakeCausalLM()
+ input_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)
+ token_ranges = [(0, 2), (2, 4), (4, 5)]
+
+ result = compute_attribution_matrix(
+ input_ids=input_ids,
+ token_ranges=token_ranges,
+ model=model,
+ take_log=False,
+ )
+
+ assert result.matrix.shape == (3, 3)
+ assert np.allclose(np.diag(result.raw_matrix), 0.0)
+ assert np.allclose(np.triu(result.raw_matrix), 0.0)
+ assert np.isfinite(result.matrix).all()
+ assert get_stored_attentions() == {}
diff --git a/tests/test_token_boundaries.py b/tests/test_token_boundaries.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d4a4f227a084c4c3ebf5cb04bc6dc40d3e1427
--- /dev/null
+++ b/tests/test_token_boundaries.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import pytest
+
+pytest.importorskip("torch")
+
+from app.analysis.sentence_split import split_sentences
+from app.analysis.token_boundaries import tokenize_with_sentence_ranges
+
+
+def test_token_boundaries_cover_full_sequence(mock_tokenizer) -> None:
+ text = "Alpha beta. Gamma delta."
+ spans = split_sentences(text)
+ mapping = tokenize_with_sentence_ranges(text, spans, mock_tokenizer)
+
+ assert mapping.token_ranges == [(0, 2), (2, 4)]
+ assert mapping.input_ids.shape == (1, 4)