BART-ender commited on
Commit
fda8fb3
·
verified ·
1 Parent(s): fca1de6

Deploy Thought Anchors

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.dockerignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .venv
3
+ .pytest_cache
4
+ __pycache__
5
+ *.pyc
6
+ *.pyo
7
+ *.pyd
8
+ outputs
9
+ tests
10
+ cot_anc.egg-info
.env.example ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
2
+ DEVICE_PREFERENCE=auto
3
+ DTYPE_PREFERENCE=auto
4
+ ATTN_IMPLEMENTATION=eager
5
+ TRUST_REMOTE_CODE=true
6
+ LOW_CPU_MEM_USAGE=true
7
+ MAX_TRACE_TOKENS=256
8
+ MAX_SENTENCES=16
9
+ TAKE_LOG=true
10
+ PRELOAD_MODEL=true
11
+ REQUIRE_AUTH=true
12
+ MAX_QUEUED_JOBS=8
13
+ MAX_ACTIVE_JOBS_PER_USER=2
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /.venv
2
+ /.pytest_cache
3
+ /__pycache__
4
+ /data/app.db
5
+ .DS_Store
6
+ *.pyc
7
+ *.pyo
8
+ *.pyd
9
+ cot_anc.egg-info/
10
+ outputs/
11
+ uv.lock
Dockerfile ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+ ENV PYTHONDONTWRITEBYTECODE=1
5
+ ENV PYTHONUNBUFFERED=1
6
+ ENV UV_LINK_MODE=copy
7
+ ENV HOME=/home/user
8
+ ENV PATH=/home/user/.local/bin:$PATH
9
+ ENV HF_HOME=/home/user/.cache/huggingface
10
+ ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface
11
+ ENV API_HOST=0.0.0.0
12
+ ENV API_PORT=7860
13
+ ENV DEVICE_PREFERENCE=auto
14
+ ENV DTYPE_PREFERENCE=auto
15
+ ENV ATTN_IMPLEMENTATION=eager
16
+ ENV LOW_CPU_MEM_USAGE=true
17
+ ENV TRUST_REMOTE_CODE=true
18
+ ENV PRELOAD_MODEL=true
19
+ ENV REQUIRE_AUTH=true
20
+
21
+ RUN apt-get update && apt-get install -y --no-install-recommends \
22
+ ca-certificates \
23
+ curl \
24
+ git \
25
+ python3 \
26
+ python3-pip \
27
+ python3-venv \
28
+ && rm -rf /var/lib/apt/lists/*
29
+
30
+ RUN useradd -m -u 1000 user
31
+ USER user
32
+ WORKDIR $HOME/app
33
+
34
+ RUN curl -LsSf https://astral.sh/uv/install.sh | sh
35
+
36
+ COPY --chown=user pyproject.toml uv.lock README.md .env.example ./
37
+ COPY --chown=user app ./app
38
+ COPY --chown=user notebooks ./notebooks
39
+ COPY --chown=user tests ./tests
40
+
41
+ RUN uv sync --frozen
42
+
43
+ EXPOSE 7860
44
+
45
+ CMD ["uv", "run", "python", "-m", "app.cli.run_api"]
README.md CHANGED
@@ -1,10 +1,85 @@
1
  ---
2
- title: Cot Anc
3
- emoji: 👁
4
- colorFrom: indigo
5
- colorTo: green
6
  sdk: docker
 
 
7
  pinned: false
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Thought Anchors
3
+ emoji: 🧠
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: docker
7
+ app_port: 7860
8
+ hf_oauth: true
9
  pinned: false
10
+ models:
11
+ - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
12
  ---
13
 
14
+ # Thought Anchors
15
+
16
+ Thought Anchors generates visible reasoning traces from open-weight models and
17
+ computes sentence-to-sentence influence with gradient x attention attribution.
18
+
19
+ Current product shape:
20
+
21
+ - Hugging Face `Docker Space` first
22
+ - Hugging Face OAuth sign-in
23
+ - web UI + API
24
+ - per-user ephemeral sessions
25
+ - JSON / CSV export
26
+ - adaptive CPU / MPS / CUDA loading
27
+
28
+ ## Quick Start
29
+
30
+ Install deps:
31
+
32
+ ```bash
33
+ uv sync --extra dev
34
+ ```
35
+
36
+ Run API:
37
+
38
+ ```bash
39
+ uv run python -m app.cli.run_api
40
+ ```
41
+
42
+ Run CLI:
43
+
44
+ ```bash
45
+ uv run python -m app.cli.run_prototype "Explain why the derivative of x^2 is 2x"
46
+ ```
47
+
48
+ Run tests:
49
+
50
+ ```bash
51
+ uv run python -m pytest -q
52
+ ```
53
+
54
+ ## Main Endpoints
55
+
56
+ - `GET /healthz`
57
+ - `GET /api/me`
58
+ - `POST /api/warmup`
59
+ - `POST /api/analyze`
60
+ - `GET /api/sessions`
61
+ - `POST /api/sessions`
62
+ - `GET /api/sessions/{id}`
63
+ - `GET /api/sessions/{id}/result`
64
+ - `GET /api/sessions/{id}/export.json`
65
+ - `GET /api/sessions/{id}/export.csv`
66
+
67
+ ## Docs
68
+
69
+ - [Hugging Face deployment](./docs/deploy-huggingface.md)
70
+ - [Runtime and model support](./docs/runtime.md)
71
+ - [API and product behavior](./docs/api.md)
72
+ - [Notebook usage](./docs/notebook.md)
73
+
74
+ ## Notebook
75
+
76
+ Colab / Kaggle smoke-test notebook:
77
+
78
+ - [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
79
+
80
+ ## Key Constraints
81
+
82
+ - Attribution needs `attn_implementation="eager"`.
83
+ - Model must expose supported decoder layers and attention modules.
84
+ - Long traces stay capped because analysis uses full backward pass.
85
+ - Space disk is ephemeral; export results you want to keep.
app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Application package for chain-of-thought attribution analysis."""
app/analysis/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Analysis utilities for attribution patching."""
app/analysis/hooks.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable
4
+ from typing import Any
5
+
6
+ import torch
7
+
8
+ from app.core.model_support import get_decoder_layers
9
+
10
+ _ATTENTION_STORE: dict[int, torch.Tensor] = {}
11
+
12
+
13
+ def clear_stored_attentions() -> None:
14
+ _ATTENTION_STORE.clear()
15
+
16
+
17
+ def get_stored_attentions() -> dict[int, torch.Tensor]:
18
+ return dict(_ATTENTION_STORE)
19
+
20
+
21
+ def _extract_attention_tensor(output: Any) -> torch.Tensor | None:
22
+ if isinstance(output, torch.Tensor):
23
+ return output if output.dim() == 4 else None
24
+
25
+ if isinstance(output, dict):
26
+ for value in output.values():
27
+ if isinstance(value, torch.Tensor) and value.dim() == 4:
28
+ return value
29
+
30
+ if isinstance(output, Iterable) and not isinstance(output, (str, bytes)):
31
+ for item in output:
32
+ if isinstance(item, torch.Tensor) and item.dim() == 4:
33
+ return item
34
+
35
+ return None
36
+
37
+
38
+ def _get_attention_impl(model: Any) -> str | None:
39
+ config = getattr(model, "config", None)
40
+ if config is None:
41
+ return None
42
+ return getattr(config, "_attn_implementation", None) or getattr(
43
+ config,
44
+ "attn_implementation",
45
+ None,
46
+ )
47
+
48
+
49
+ def make_attn_hook(layer_idx: int):
50
+ def hook(_module: Any, _inputs: Any, output: Any) -> None:
51
+ attn = _extract_attention_tensor(output)
52
+ if attn is None:
53
+ return
54
+ if attn.dim() != 4:
55
+ raise RuntimeError(f"Expected 4D attention tensor at layer {layer_idx}, got {attn.shape}.")
56
+ attn.retain_grad()
57
+ _ATTENTION_STORE[layer_idx] = attn
58
+
59
+ return hook
60
+
61
+
62
+ def register_hooks(model: Any) -> list[Any]:
63
+ clear_stored_attentions()
64
+ layers, _layer_path, attention_attr = get_decoder_layers(model)
65
+
66
+ handles: list[Any] = []
67
+ for layer_idx, layer in enumerate(layers):
68
+ self_attn = getattr(layer, attention_attr, None)
69
+ if self_attn is None:
70
+ raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.")
71
+ handles.append(self_attn.register_forward_hook(make_attn_hook(layer_idx)))
72
+ return handles
73
+
74
+
75
+ def remove_hooks(handles: list[Any]) -> None:
76
+ for handle in handles:
77
+ handle.remove()
78
+ clear_stored_attentions()
app/analysis/sentence_split.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+
6
+ THINK_TAG_RE = re.compile(r"</?think>", re.IGNORECASE)
7
+
8
+
9
+ @dataclass(slots=True)
10
+ class SentenceSpan:
11
+ text: str
12
+ start_char: int
13
+ end_char: int
14
+
15
+
16
+ def normalize_trace_text(raw_trace_text: str) -> str:
17
+ return THINK_TAG_RE.sub("", raw_trace_text)
18
+
19
+
20
+ def _non_whitespace_token_count(text: str) -> int:
21
+ return len(re.findall(r"\S+", text))
22
+
23
+
24
+ def split_sentences(
25
+ text: str,
26
+ *,
27
+ min_token_like_units: int = 2,
28
+ ) -> list[SentenceSpan]:
29
+ if not text:
30
+ return []
31
+
32
+ raw_spans: list[tuple[int, int]] = []
33
+ start = 0
34
+ index = 0
35
+ text_length = len(text)
36
+
37
+ while index < text_length:
38
+ if text[index : index + 2] == "\n\n":
39
+ end = index + 2
40
+ raw_spans.append((start, end))
41
+ start = end
42
+ index = end
43
+ continue
44
+
45
+ if text[index] in ".!?":
46
+ end = index + 1
47
+ while end < text_length and text[end] in "\"')]}":
48
+ end += 1
49
+ while end < text_length and text[end].isspace() and text[end : end + 2] != "\n\n":
50
+ end += 1
51
+ raw_spans.append((start, end))
52
+ start = end
53
+ index = end
54
+ continue
55
+
56
+ index += 1
57
+
58
+ if start < text_length:
59
+ raw_spans.append((start, text_length))
60
+
61
+ merged: list[tuple[int, int]] = []
62
+ for span_start, span_end in raw_spans:
63
+ fragment = text[span_start:span_end]
64
+ if not fragment:
65
+ continue
66
+ if merged and _non_whitespace_token_count(fragment) < min_token_like_units:
67
+ previous_start, _ = merged[-1]
68
+ merged[-1] = (previous_start, span_end)
69
+ continue
70
+ merged.append((span_start, span_end))
71
+
72
+ if len(merged) > 1:
73
+ last_start, last_end = merged[-1]
74
+ if _non_whitespace_token_count(text[last_start:last_end]) < min_token_like_units:
75
+ prev_start, _ = merged[-2]
76
+ merged[-2] = (prev_start, last_end)
77
+ merged.pop()
78
+
79
+ return [
80
+ SentenceSpan(
81
+ text=text[span_start:span_end],
82
+ start_char=span_start,
83
+ end_char=span_end,
84
+ )
85
+ for span_start, span_end in merged
86
+ if text[span_start:span_end]
87
+ ]
app/analysis/summaries.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from app.core.schemas import TopEdge
6
+
7
+
8
+ def compute_outgoing_importance(matrix: np.ndarray) -> list[float]:
9
+ sentence_count = matrix.shape[0]
10
+ scores: list[float] = []
11
+ for source_idx in range(sentence_count):
12
+ column = matrix[source_idx + 1 :, source_idx]
13
+ scores.append(float(column.mean()) if column.size else 0.0)
14
+ return scores
15
+
16
+
17
+ def compute_incoming_importance(matrix: np.ndarray) -> list[float]:
18
+ sentence_count = matrix.shape[0]
19
+ scores: list[float] = []
20
+ for target_idx in range(sentence_count):
21
+ row = matrix[target_idx, :target_idx]
22
+ scores.append(float(row.mean()) if row.size else 0.0)
23
+ return scores
24
+
25
+
26
+ def compute_top_edges(matrix: np.ndarray, top_k: int = 10) -> list[TopEdge]:
27
+ sentence_count = matrix.shape[0]
28
+ candidates: list[TopEdge] = []
29
+ for target_idx in range(sentence_count):
30
+ for source_idx in range(target_idx):
31
+ candidates.append(
32
+ TopEdge(
33
+ source_sentence_idx=source_idx,
34
+ target_sentence_idx=target_idx,
35
+ score=float(matrix[target_idx, source_idx]),
36
+ )
37
+ )
38
+ candidates.sort(key=lambda edge: edge.score, reverse=True)
39
+ return candidates[:top_k]
app/analysis/suppression.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import asdict
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks
12
+ from app.core.model_support import describe_model_support
13
+ from app.core.schemas import ModelCapability, RuntimeMetadata
14
+
15
+
16
+ @dataclass(slots=True)
17
+ class AttributionMatrixComputation:
18
+ matrix: np.ndarray
19
+ raw_matrix: np.ndarray
20
+ token_nll: np.ndarray
21
+ runtime_metadata: RuntimeMetadata
22
+
23
+
24
+ def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
25
+ if logits.ndim != 3 or input_ids.ndim != 2:
26
+ raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].")
27
+ if logits.shape[0] != 1 or input_ids.shape[0] != 1:
28
+ raise ValueError("Only batch size 1 is supported for the prototype.")
29
+ if input_ids.shape[1] < 2:
30
+ raise ValueError("Need at least two tokens to compute next-token loss.")
31
+
32
+ shifted_logits = logits[:, :-1, :]
33
+ shifted_targets = input_ids[:, 1:]
34
+ log_probs = torch.log_softmax(shifted_logits, dim=-1)
35
+ gathered = log_probs.gather(-1, shifted_targets.unsqueeze(-1)).squeeze(-1)
36
+ return -gathered[0]
37
+
38
+
39
+ def _current_memory_mb(device: torch.device) -> float | None:
40
+ if device.type != "cuda":
41
+ return None
42
+ return float(torch.cuda.memory_allocated(device) / (1024 * 1024))
43
+
44
+
45
+ def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray:
46
+ if not take_log:
47
+ return raw_matrix.copy()
48
+ presentation = np.zeros_like(raw_matrix)
49
+ positive = raw_matrix > 0
50
+ presentation[positive] = np.log(raw_matrix[positive] + 1e-9)
51
+ return presentation
52
+
53
+
54
+ def compute_attribution_matrix(
55
+ input_ids: torch.Tensor,
56
+ token_ranges: list[tuple[int, int]],
57
+ model: Any,
58
+ take_log: bool = True,
59
+ max_trace_tokens: int = 0,
60
+ max_sentences: int = 0,
61
+ ) -> AttributionMatrixComputation:
62
+ device = input_ids.device
63
+ handles = register_hooks(model)
64
+ model.zero_grad(set_to_none=True)
65
+ forward_start = time.perf_counter()
66
+ memory_before_mb = _current_memory_mb(device)
67
+
68
+ try:
69
+ with torch.enable_grad():
70
+ outputs = model(
71
+ input_ids=input_ids,
72
+ output_attentions=True,
73
+ return_dict=True,
74
+ )
75
+ forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0
76
+
77
+ logits = outputs.logits
78
+ token_nll = compute_self_token_nll(logits, input_ids)
79
+ loss = token_nll.sum()
80
+
81
+ backward_start = time.perf_counter()
82
+ loss.backward()
83
+ backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0
84
+
85
+ attentions = get_stored_attentions()
86
+ if not attentions:
87
+ raise RuntimeError("No attention tensors were captured. Check eager attention mode.")
88
+
89
+ matrix_start = time.perf_counter()
90
+ sentence_count = len(token_ranges)
91
+ raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32)
92
+
93
+ ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)]
94
+ first_attention = ordered_layers[0]
95
+ num_layers = len(ordered_layers)
96
+ num_heads = int(first_attention.shape[1])
97
+
98
+ for source_idx, (source_start, source_end) in enumerate(token_ranges):
99
+ for target_idx, (target_start, target_end) in enumerate(token_ranges):
100
+ if target_idx <= source_idx:
101
+ continue
102
+
103
+ total = 0.0
104
+ for attention in ordered_layers:
105
+ grad = attention.grad
106
+ if grad is None:
107
+ raise RuntimeError("Attention gradient was not retained for one or more layers.")
108
+ total += -(
109
+ grad[0, :, target_start:target_end, source_start:source_end]
110
+ * attention[0, :, target_start:target_end, source_start:source_end]
111
+ ).sum().item()
112
+
113
+ denominator = max(1, target_end - target_start)
114
+ raw_matrix[target_idx, source_idx] = total / denominator
115
+
116
+ matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0
117
+ total_analysis_ms = (
118
+ forward_pass_ms + backward_pass_ms + matrix_computation_ms
119
+ )
120
+ presentation_matrix = _build_presentation_matrix(raw_matrix, take_log)
121
+
122
+ attention_impl = getattr(model.config, "_attn_implementation", "unknown")
123
+ capability = describe_model_support(model)
124
+ runtime_metadata = RuntimeMetadata(
125
+ forward_pass_ms=forward_pass_ms,
126
+ backward_pass_ms=backward_pass_ms,
127
+ matrix_computation_ms=matrix_computation_ms,
128
+ total_analysis_ms=total_analysis_ms,
129
+ num_layers=num_layers,
130
+ num_heads=num_heads,
131
+ sequence_length_tokens=int(input_ids.shape[1]),
132
+ sentence_count=sentence_count,
133
+ device=str(device),
134
+ dtype=str(first_attention.dtype),
135
+ attention_impl=str(attention_impl),
136
+ max_trace_tokens=max_trace_tokens,
137
+ max_sentences=max_sentences,
138
+ capability=ModelCapability.model_validate(asdict(capability)),
139
+ )
140
+
141
+ memory_after_mb = _current_memory_mb(device)
142
+ if memory_before_mb is not None and memory_after_mb is not None:
143
+ runtime_metadata = runtime_metadata.model_copy(
144
+ update={
145
+ "device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)"
146
+ }
147
+ )
148
+
149
+ return AttributionMatrixComputation(
150
+ matrix=presentation_matrix,
151
+ raw_matrix=raw_matrix,
152
+ token_nll=token_nll.detach().cpu().numpy(),
153
+ runtime_metadata=runtime_metadata,
154
+ )
155
+ finally:
156
+ for attention in get_stored_attentions().values():
157
+ attention.grad = None
158
+ remove_hooks(handles)
159
+ model.zero_grad(set_to_none=True)
160
+ if device.type == "cuda":
161
+ torch.cuda.empty_cache()
app/analysis/token_boundaries.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ import torch
7
+
8
+ from app.analysis.sentence_split import SentenceSpan
9
+
10
+
11
+ @dataclass(slots=True)
12
+ class TokenizedSentenceMapping:
13
+ input_ids: torch.Tensor
14
+ token_ranges: list[tuple[int, int]]
15
+ offsets: list[tuple[int, int]]
16
+ text: str
17
+
18
+
19
+ def truncate_text_to_token_limit(text: str, tokenizer: Any, max_tokens: int) -> str:
20
+ if max_tokens <= 0:
21
+ raise ValueError("max_tokens must be positive.")
22
+ encoded = tokenizer(
23
+ text,
24
+ add_special_tokens=False,
25
+ return_offsets_mapping=True,
26
+ )
27
+ offsets = encoded["offset_mapping"]
28
+ if len(offsets) <= max_tokens:
29
+ return text
30
+ end_char = offsets[max_tokens - 1][1]
31
+ return text[:end_char]
32
+
33
+
34
+ def tokenize_with_sentence_ranges(
35
+ text: str,
36
+ sentence_spans: list[SentenceSpan],
37
+ tokenizer: Any,
38
+ ) -> TokenizedSentenceMapping:
39
+ encoded = tokenizer(
40
+ text,
41
+ add_special_tokens=False,
42
+ return_offsets_mapping=True,
43
+ return_tensors="pt",
44
+ )
45
+
46
+ input_ids = encoded["input_ids"]
47
+ raw_offsets = encoded["offset_mapping"][0].tolist()
48
+ offsets = [(int(start), int(end)) for start, end in raw_offsets]
49
+ token_ranges: list[tuple[int, int]] = []
50
+
51
+ for span in sentence_spans:
52
+ overlapping = [
53
+ token_index
54
+ for token_index, (token_start, token_end) in enumerate(offsets)
55
+ if token_end > span.start_char and token_start < span.end_char
56
+ ]
57
+ if not overlapping:
58
+ raise ValueError(
59
+ f"Sentence span {span.start_char}:{span.end_char} mapped to zero tokens."
60
+ )
61
+ token_ranges.append((overlapping[0], overlapping[-1] + 1))
62
+
63
+ if token_ranges:
64
+ if token_ranges[0][0] != 0 or token_ranges[-1][1] != len(offsets):
65
+ raise ValueError("Sentence token ranges do not cover the full analyzed sequence.")
66
+ for previous, current in zip(token_ranges, token_ranges[1:]):
67
+ if previous[1] != current[0]:
68
+ raise ValueError("Sentence token ranges are not contiguous.")
69
+
70
+ return TokenizedSentenceMapping(
71
+ input_ids=input_ids,
72
+ token_ranges=token_ranges,
73
+ offsets=offsets,
74
+ text=text,
75
+ )
app/analysis/validation.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import torch
7
+ from scipy.stats import pearsonr, spearmanr
8
+
9
+ from app.analysis.suppression import compute_self_token_nll
10
+ from app.core.schemas import TopEdge, ValidationMetadata
11
+
12
+
13
+ def _nll_slice_for_token_range(token_range: tuple[int, int]) -> slice:
14
+ start, end = token_range
15
+ return slice(max(0, start - 1), max(0, end - 1))
16
+
17
+
18
+ def build_exact_suppression_mask(
19
+ *,
20
+ sequence_length: int,
21
+ source_range: tuple[int, int],
22
+ target_range: tuple[int, int],
23
+ device: torch.device,
24
+ dtype: torch.dtype,
25
+ ) -> torch.Tensor:
26
+ fill_value = torch.finfo(dtype).min
27
+ mask = torch.zeros((sequence_length, sequence_length), device=device, dtype=dtype)
28
+ future_positions = torch.triu(
29
+ torch.ones((sequence_length, sequence_length), device=device, dtype=torch.bool),
30
+ diagonal=1,
31
+ )
32
+ mask = mask.masked_fill(future_positions, fill_value)
33
+ source_start, source_end = source_range
34
+ target_start, target_end = target_range
35
+ mask[target_start:target_end, source_start:source_end] = fill_value
36
+ return mask.unsqueeze(0).unsqueeze(0)
37
+
38
+
39
+ def compute_exact_edge_score(
40
+ *,
41
+ model: Any,
42
+ input_ids: torch.Tensor,
43
+ source_range: tuple[int, int],
44
+ target_range: tuple[int, int],
45
+ baseline_token_nll: np.ndarray,
46
+ ) -> float:
47
+ model_dtype = next(model.parameters()).dtype
48
+ attention_mask = build_exact_suppression_mask(
49
+ sequence_length=int(input_ids.shape[1]),
50
+ source_range=source_range,
51
+ target_range=target_range,
52
+ device=input_ids.device,
53
+ dtype=model_dtype,
54
+ )
55
+ with torch.no_grad():
56
+ outputs = model(
57
+ input_ids=input_ids,
58
+ attention_mask=attention_mask,
59
+ output_attentions=False,
60
+ return_dict=True,
61
+ )
62
+ suppressed_nll = compute_self_token_nll(outputs.logits, input_ids).detach().cpu().numpy()
63
+ nll_slice = _nll_slice_for_token_range(target_range)
64
+ return float(suppressed_nll[nll_slice].sum() - baseline_token_nll[nll_slice].sum())
65
+
66
+
67
+ def validate_top_edges(
68
+ *,
69
+ model: Any,
70
+ input_ids: torch.Tensor,
71
+ token_ranges: list[tuple[int, int]],
72
+ top_edges: list[TopEdge],
73
+ baseline_token_nll: np.ndarray,
74
+ top_k: int,
75
+ ) -> ValidationMetadata:
76
+ if top_k <= 0 or not top_edges:
77
+ return ValidationMetadata(enabled=False, top_k=0)
78
+
79
+ selected_edges = top_edges[:top_k]
80
+ exact_scores: list[float] = []
81
+ attributed_scores: list[float] = []
82
+ compared_edges: list[TopEdge] = []
83
+
84
+ try:
85
+ for edge in selected_edges:
86
+ exact_score = compute_exact_edge_score(
87
+ model=model,
88
+ input_ids=input_ids,
89
+ source_range=token_ranges[edge.source_sentence_idx],
90
+ target_range=token_ranges[edge.target_sentence_idx],
91
+ baseline_token_nll=baseline_token_nll,
92
+ )
93
+ exact_scores.append(exact_score)
94
+ attributed_scores.append(edge.score)
95
+ compared_edges.append(
96
+ TopEdge(
97
+ source_sentence_idx=edge.source_sentence_idx,
98
+ target_sentence_idx=edge.target_sentence_idx,
99
+ score=exact_score,
100
+ )
101
+ )
102
+ except Exception as exc: # pragma: no cover - environment/model dependent
103
+ return ValidationMetadata(
104
+ enabled=True,
105
+ top_k=top_k,
106
+ compared_edges=[],
107
+ notes=f"Exact suppression validation failed: {exc}",
108
+ )
109
+
110
+ pearson = float(pearsonr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None
111
+ spearman = float(spearmanr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None
112
+
113
+ return ValidationMetadata(
114
+ enabled=True,
115
+ top_k=top_k,
116
+ pearson=pearson,
117
+ spearman=spearman,
118
+ compared_edges=compared_edges,
119
+ notes="Exact suppression compares sentence-level NLL deltas for selected edges.",
120
+ )
app/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """API package for GPU-hosted analysis service."""
app/api/auth.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from fastapi import HTTPException, Request
6
+ from huggingface_hub import parse_huggingface_oauth
7
+
8
+ from app.core.config import Settings
9
+
10
+
11
+ @dataclass(frozen=True, slots=True)
12
+ class UserContext:
13
+ id: str
14
+ username: str
15
+ display_name: str | None
16
+ avatar_url: str | None
17
+ authenticated: bool
18
+
19
+
20
+ def get_optional_user(request: Request) -> UserContext | None:
21
+ if "session" not in request.scope:
22
+ return None
23
+ try:
24
+ oauth_info = parse_huggingface_oauth(request)
25
+ except AssertionError:
26
+ return None
27
+ if oauth_info is None:
28
+ return None
29
+
30
+ user_info = oauth_info.user_info
31
+ username = user_info.preferred_username or user_info.sub or "hf-user"
32
+ display_name = user_info.name or username
33
+ return UserContext(
34
+ id=user_info.sub or username,
35
+ username=username,
36
+ display_name=display_name,
37
+ avatar_url=user_info.picture,
38
+ authenticated=True,
39
+ )
40
+
41
+
42
+ def require_user(request: Request, settings: Settings) -> UserContext:
43
+ user = get_optional_user(request)
44
+ if user is not None:
45
+ return user
46
+ if settings.require_auth:
47
+ raise HTTPException(status_code=401, detail="Sign in with Hugging Face to use this service.")
48
+ return UserContext(
49
+ id="anonymous",
50
+ username="anonymous",
51
+ display_name="Anonymous",
52
+ avatar_url=None,
53
+ authenticated=False,
54
+ )
app/api/main.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from contextlib import asynccontextmanager
5
+ from io import StringIO
6
+ from pathlib import Path
7
+ from dataclasses import asdict
8
+ import os
9
+
10
+ import torch
11
+ from fastapi import FastAPI, HTTPException, Request
12
+ from fastapi.responses import FileResponse
13
+ from fastapi.responses import StreamingResponse
14
+ from fastapi.staticfiles import StaticFiles
15
+ from huggingface_hub import attach_huggingface_oauth
16
+
17
+ from app.api.auth import get_optional_user, require_user
18
+ from app.core.config import get_settings
19
+ from app.core.runtime import load_model_bundle
20
+ from app.core.runtime_pipeline import compute_attribution_analysis
21
+ from app.core.schemas import (
22
+ AnalysisRequest,
23
+ AnalysisResult,
24
+ CurrentUserResponse,
25
+ HealthResponse,
26
+ SessionCreateRequest,
27
+ SessionResponse,
28
+ SessionResultResponse,
29
+ WarmupResponse,
30
+ )
31
+ from app.services.sessions import SessionAccessError, SessionLimitError, SessionService
32
+ from app.storage.repository import SessionRepository
33
+ from app.workers.jobs import build_job_runner
34
+
35
+ logger = logging.getLogger(__name__)
36
+ FRONTEND_DIR = Path(__file__).resolve().parents[1] / "frontend"
37
+
38
+
39
+ @asynccontextmanager
40
+ async def lifespan(_app: FastAPI):
41
+ settings = get_settings()
42
+ repository = SessionRepository(settings.database_path)
43
+ jobs = build_job_runner(settings.job_workers)
44
+ _app.state.repository = repository
45
+ _app.state.jobs = jobs
46
+ _app.state.session_service = SessionService(settings=settings, repository=repository, jobs=jobs)
47
+ if settings.preload_model:
48
+ logger.info(
49
+ "Preloading model '%s' on device preference '%s'.",
50
+ settings.model_name,
51
+ settings.device_preference,
52
+ )
53
+ load_model_bundle(
54
+ settings.model_name,
55
+ device_preference=settings.device_preference,
56
+ dtype_preference=settings.dtype_preference,
57
+ attn_implementation=settings.attn_implementation,
58
+ trust_remote_code=settings.trust_remote_code,
59
+ low_cpu_mem_usage=settings.low_cpu_mem_usage,
60
+ )
61
+ yield
62
+ jobs.shutdown()
63
+
64
+
65
+ app = FastAPI(
66
+ title="CoT Attribution Analysis API",
67
+ version="0.1.0",
68
+ lifespan=lifespan,
69
+ )
70
+ if os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"):
71
+ attach_huggingface_oauth(app)
72
+ app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
73
+
74
+
75
+ def get_session_service() -> SessionService:
76
+ return app.state.session_service
77
+
78
+
79
+ def _to_session_response(payload: dict) -> SessionResponse:
80
+ return SessionResponse(
81
+ id=payload["id"],
82
+ status=payload["status"],
83
+ question=payload["question"],
84
+ model_name=payload["model_name"],
85
+ error=payload.get("error"),
86
+ created_at=payload["created_at"],
87
+ updated_at=payload["updated_at"],
88
+ answer=payload.get("answer"),
89
+ raw_trace_text=payload.get("raw_trace_text"),
90
+ normalized_trace_text=payload.get("normalized_trace_text"),
91
+ sentences=payload.get("sentences"),
92
+ generation_metadata=payload.get("generation_metadata"),
93
+ )
94
+
95
+
96
+ @app.get("/", include_in_schema=False)
97
+ def index() -> FileResponse:
98
+ return FileResponse(FRONTEND_DIR / "index.html")
99
+
100
+
101
+ @app.get("/healthz", response_model=HealthResponse)
102
+ def healthz() -> HealthResponse:
103
+ settings = get_settings()
104
+ return HealthResponse(
105
+ status="ok",
106
+ model_name=settings.model_name,
107
+ device_preference=settings.device_preference,
108
+ dtype_preference=settings.dtype_preference,
109
+ preload_model=settings.preload_model,
110
+ cuda_available=torch.cuda.is_available(),
111
+ mps_available=torch.backends.mps.is_available(),
112
+ require_auth=settings.require_auth,
113
+ public_api_enabled=settings.public_api_enabled,
114
+ max_queued_jobs=settings.max_queued_jobs,
115
+ max_active_jobs_per_user=settings.max_active_jobs_per_user,
116
+ )
117
+
118
+
119
+ @app.get("/api/me", response_model=CurrentUserResponse)
120
+ def me(request: Request) -> CurrentUserResponse:
121
+ settings = get_settings()
122
+ user = get_optional_user(request)
123
+ return CurrentUserResponse(
124
+ authenticated=user is not None,
125
+ auth_required=settings.require_auth,
126
+ username=user.username if user else None,
127
+ full_name=user.display_name if user else None,
128
+ avatar_url=user.avatar_url if user else None,
129
+ )
130
+
131
+
132
+ @app.post("/api/warmup", response_model=WarmupResponse)
133
+ def warmup(model_name: str | None = None, device_preference: str | None = None) -> WarmupResponse:
134
+ settings = get_settings()
135
+ bundle = load_model_bundle(
136
+ model_name or settings.model_name,
137
+ device_preference=device_preference or settings.device_preference,
138
+ dtype_preference=settings.dtype_preference,
139
+ attn_implementation=settings.attn_implementation,
140
+ trust_remote_code=settings.trust_remote_code,
141
+ low_cpu_mem_usage=settings.low_cpu_mem_usage,
142
+ )
143
+ return WarmupResponse(
144
+ status="ready",
145
+ model_name=bundle.model_name,
146
+ device=str(bundle.device),
147
+ dtype=str(bundle.dtype),
148
+ capability=asdict(bundle.capability),
149
+ )
150
+
151
+
152
+ @app.post("/api/analyze", response_model=AnalysisResult)
153
+ def analyze(request: AnalysisRequest, http_request: Request) -> AnalysisResult:
154
+ settings = get_settings()
155
+ require_user(http_request, settings)
156
+ try:
157
+ return compute_attribution_analysis(
158
+ question=request.question,
159
+ model_name=request.model_name,
160
+ take_log=request.take_log,
161
+ max_sentences=request.max_sentences,
162
+ max_trace_tokens=request.max_trace_tokens,
163
+ validate_top_k=request.validate_top_k,
164
+ max_new_tokens=request.max_new_tokens,
165
+ temperature=request.temperature,
166
+ top_p=request.top_p,
167
+ device_preference=request.device_preference,
168
+ dtype_preference=request.dtype_preference,
169
+ attn_implementation=request.attn_implementation,
170
+ trust_remote_code=request.trust_remote_code,
171
+ low_cpu_mem_usage=request.low_cpu_mem_usage,
172
+ )
173
+ except Exception as exc: # pragma: no cover - runtime path
174
+ logger.exception("Analysis request failed")
175
+ raise HTTPException(status_code=500, detail=str(exc)) from exc
176
+
177
+
178
+ @app.get("/api/sessions", response_model=list[SessionResponse])
179
+ def list_sessions(request: Request, limit: int = 20) -> list[SessionResponse]:
180
+ settings = get_settings()
181
+ user = require_user(request, settings)
182
+ service = get_session_service()
183
+ payloads = service.list_sessions(user.id, limit=limit)
184
+ return [_to_session_response(payload) for payload in payloads]
185
+
186
+
187
+ @app.post("/api/sessions", response_model=SessionResponse)
188
+ def create_session(request: SessionCreateRequest, http_request: Request) -> SessionResponse:
189
+ settings = get_settings()
190
+ user = require_user(http_request, settings)
191
+ service = get_session_service()
192
+ try:
193
+ session = service.create_session(
194
+ AnalysisRequest(
195
+ question=request.question,
196
+ model_name=request.model_name,
197
+ take_log=request.take_log,
198
+ max_sentences=request.max_sentences,
199
+ max_trace_tokens=request.max_trace_tokens,
200
+ validate_top_k=request.validate_top_k,
201
+ max_new_tokens=request.max_new_tokens,
202
+ temperature=request.temperature,
203
+ top_p=request.top_p,
204
+ device_preference=request.device_preference,
205
+ dtype_preference=request.dtype_preference,
206
+ attn_implementation=request.attn_implementation,
207
+ trust_remote_code=request.trust_remote_code,
208
+ low_cpu_mem_usage=request.low_cpu_mem_usage,
209
+ ),
210
+ owner_id=user.id,
211
+ owner_name=user.display_name,
212
+ )
213
+ except SessionLimitError as exc:
214
+ raise HTTPException(status_code=429, detail=str(exc)) from exc
215
+ payload = service.get_session_payload(session.id, owner_id=user.id)
216
+ return _to_session_response(payload)
217
+
218
+
219
+ @app.get("/api/sessions/{session_id}", response_model=SessionResponse)
220
+ def get_session(session_id: str, request: Request) -> SessionResponse:
221
+ settings = get_settings()
222
+ user = require_user(request, settings)
223
+ service = get_session_service()
224
+ try:
225
+ payload = service.get_session_payload(session_id, owner_id=user.id)
226
+ except KeyError as exc:
227
+ raise HTTPException(status_code=404, detail="Session not found") from exc
228
+ except SessionAccessError as exc:
229
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
230
+ return _to_session_response(payload)
231
+
232
+
233
+ @app.post("/api/sessions/{session_id}/analyze", response_model=SessionResponse)
234
+ def analyze_session(session_id: str, request: Request) -> SessionResponse:
235
+ settings = get_settings()
236
+ user = require_user(request, settings)
237
+ service = get_session_service()
238
+ try:
239
+ session = service.start_analysis(session_id, owner_id=user.id)
240
+ payload = service.get_session_payload(session.id, owner_id=user.id)
241
+ except KeyError as exc:
242
+ raise HTTPException(status_code=404, detail="Session not found") from exc
243
+ except SessionAccessError as exc:
244
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
245
+ return _to_session_response(payload)
246
+
247
+
248
+ @app.get("/api/sessions/{session_id}/result", response_model=SessionResultResponse)
249
+ def get_session_result(session_id: str, request: Request) -> SessionResultResponse:
250
+ settings = get_settings()
251
+ user = require_user(request, settings)
252
+ service = get_session_service()
253
+ try:
254
+ payload = service.get_session_payload(session_id, owner_id=user.id)
255
+ except KeyError as exc:
256
+ raise HTTPException(status_code=404, detail="Session not found") from exc
257
+ except SessionAccessError as exc:
258
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
259
+
260
+ session_response = _to_session_response(payload)
261
+ analysis_payload = payload.get("analysis")
262
+ return SessionResultResponse(
263
+ session=session_response,
264
+ analysis=AnalysisResult.model_validate(analysis_payload) if analysis_payload else None,
265
+ )
266
+
267
+
268
+ @app.get("/api/sessions/{session_id}/export.json")
269
+ def export_session_json(session_id: str, request: Request) -> StreamingResponse:
270
+ settings = get_settings()
271
+ user = require_user(request, settings)
272
+ service = get_session_service()
273
+ try:
274
+ payload = service.get_session_payload(session_id, owner_id=user.id)
275
+ except KeyError as exc:
276
+ raise HTTPException(status_code=404, detail="Session not found") from exc
277
+ except SessionAccessError as exc:
278
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
279
+ result = SessionResultResponse(
280
+ session=_to_session_response(payload),
281
+ analysis=AnalysisResult.model_validate(payload["analysis"]) if payload.get("analysis") else None,
282
+ )
283
+ return StreamingResponse(
284
+ iter([result.model_dump_json(indent=2)]),
285
+ media_type="application/json",
286
+ headers={"content-disposition": f'attachment; filename="{session_id}.json"'},
287
+ )
288
+
289
+
290
+ @app.get("/api/sessions/{session_id}/export.csv")
291
+ def export_session_csv(session_id: str, request: Request) -> StreamingResponse:
292
+ settings = get_settings()
293
+ user = require_user(request, settings)
294
+ service = get_session_service()
295
+ try:
296
+ result = service.get_analysis_result(session_id, owner_id=user.id)
297
+ except KeyError as exc:
298
+ raise HTTPException(status_code=404, detail="Analysis result not found") from exc
299
+ except SessionAccessError as exc:
300
+ raise HTTPException(status_code=403, detail=str(exc)) from exc
301
+
302
+ buffer = StringIO()
303
+ buffer.write("source_sentence_idx,target_sentence_idx,score\n")
304
+ for edge in result.top_edges:
305
+ buffer.write(f"{edge.source_sentence_idx},{edge.target_sentence_idx},{edge.score:.6f}\n")
306
+ return StreamingResponse(
307
+ iter([buffer.getvalue()]),
308
+ media_type="text/csv",
309
+ headers={"content-disposition": f'attachment; filename="{session_id}.csv"'},
310
+ )
app/cli/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """CLI entrypoints."""
app/cli/run_api.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import uvicorn
4
+
5
+ from app.core.config import get_settings
6
+
7
+
8
+ def main() -> None:
9
+ settings = get_settings()
10
+ uvicorn.run(
11
+ "app.api.main:app",
12
+ host=settings.api_host,
13
+ port=settings.api_port,
14
+ reload=False,
15
+ )
16
+
17
+
18
+ if __name__ == "__main__":
19
+ main()
app/cli/run_prototype.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from datetime import UTC, datetime
5
+ from pathlib import Path
6
+
7
+ import typer
8
+
9
+ from app.core.runtime_pipeline import compute_attribution_analysis
10
+
11
+ cli = typer.Typer(add_completion=False)
12
+
13
+
14
+ def _write_heatmap(path: Path, matrix: list[list[float]]) -> None:
15
+ try:
16
+ import matplotlib.pyplot as plt
17
+ except ImportError as exc: # pragma: no cover - optional dependency
18
+ raise RuntimeError(
19
+ "Heatmap output requires the optional viz dependency: pip install .[viz]"
20
+ ) from exc
21
+
22
+ figure, axis = plt.subplots(figsize=(8, 6))
23
+ image = axis.imshow(matrix, aspect="auto", cmap="viridis")
24
+ axis.set_xlabel("Source sentence")
25
+ axis.set_ylabel("Target sentence")
26
+ axis.set_title("Sentence influence matrix")
27
+ figure.colorbar(image, ax=axis)
28
+ figure.tight_layout()
29
+ figure.savefig(path)
30
+ plt.close(figure)
31
+
32
+
33
+ @cli.command()
34
+ def main(
35
+ question: str,
36
+ model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
37
+ output_json: Path | None = None,
38
+ output_heatmap: Path | None = None,
39
+ max_new_tokens: int = 512,
40
+ max_trace_tokens: int = 1024,
41
+ max_sentences: int = 40,
42
+ take_log: bool = True,
43
+ validate_top_k: int = 3,
44
+ temperature: float = 0.6,
45
+ top_p: float = 0.95,
46
+ device_preference: str = "auto",
47
+ dtype_preference: str = "auto",
48
+ attn_implementation: str = "eager",
49
+ trust_remote_code: bool = True,
50
+ low_cpu_mem_usage: bool = True,
51
+ ) -> None:
52
+ result = compute_attribution_analysis(
53
+ question=question,
54
+ model_name=model_name,
55
+ take_log=take_log,
56
+ max_sentences=max_sentences,
57
+ max_trace_tokens=max_trace_tokens,
58
+ validate_top_k=validate_top_k,
59
+ max_new_tokens=max_new_tokens,
60
+ temperature=temperature,
61
+ top_p=top_p,
62
+ device_preference=device_preference,
63
+ dtype_preference=dtype_preference,
64
+ attn_implementation=attn_implementation,
65
+ trust_remote_code=trust_remote_code,
66
+ low_cpu_mem_usage=low_cpu_mem_usage,
67
+ )
68
+
69
+ output_dir = Path("outputs")
70
+ output_dir.mkdir(parents=True, exist_ok=True)
71
+ timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
72
+
73
+ json_path = output_json or output_dir / f"analysis_{timestamp}.json"
74
+ json_path.write_text(json.dumps(result.model_dump(), indent=2), encoding="utf-8")
75
+
76
+ if output_heatmap is not None:
77
+ _write_heatmap(output_heatmap, result.suppression_matrix)
78
+
79
+ typer.echo(f"Wrote analysis JSON to {json_path}")
80
+ typer.echo(f"Sentences: {len(result.sentences)}")
81
+ typer.echo(f"Top edge count: {len(result.top_edges)}")
82
+ if result.validation_metadata and result.validation_metadata.enabled:
83
+ typer.echo(
84
+ "Validation:"
85
+ f" pearson={result.validation_metadata.pearson}"
86
+ f" spearman={result.validation_metadata.spearman}"
87
+ )
88
+
89
+
90
+ if __name__ == "__main__":
91
+ cli()
app/core/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Core runtime, config, and schema definitions."""
app/core/config.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from functools import lru_cache
5
+ from dataclasses import dataclass
6
+ from typing import Literal
7
+
8
+
9
+ @dataclass(frozen=True, slots=True)
10
+ class Settings:
11
+ model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
12
+ max_trace_tokens: int = 1024
13
+ max_sentences: int = 40
14
+ take_log: bool = True
15
+ device_preference: Literal["auto", "cuda", "cpu", "mps"] = "auto"
16
+ dtype_preference: Literal["auto", "float32", "float16", "bfloat16"] = "auto"
17
+ attn_implementation: str = "eager"
18
+ trust_remote_code: bool = True
19
+ low_cpu_mem_usage: bool = True
20
+ preload_model: bool = False
21
+ api_host: str = "0.0.0.0"
22
+ api_port: int = 7860
23
+ database_path: str = "data/app.db"
24
+ job_workers: int = 1
25
+ max_queued_jobs: int = 8
26
+ max_active_jobs_per_user: int = 2
27
+ require_auth: bool = True
28
+ public_api_enabled: bool = True
29
+
30
+
31
+ DEFAULT_SETTINGS = Settings()
32
+
33
+
34
+ @lru_cache(maxsize=1)
35
+ def get_settings() -> Settings:
36
+ take_log = os.getenv("TAKE_LOG", "true").strip().lower() in {"1", "true", "yes", "on"}
37
+ trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "true").strip().lower() in {"1", "true", "yes", "on"}
38
+ low_cpu_mem_usage = os.getenv("LOW_CPU_MEM_USAGE", "true").strip().lower() in {"1", "true", "yes", "on"}
39
+ require_auth = os.getenv("REQUIRE_AUTH", "true").strip().lower() in {"1", "true", "yes", "on"}
40
+ public_api_enabled = os.getenv("PUBLIC_API_ENABLED", "true").strip().lower() in {"1", "true", "yes", "on"}
41
+ return Settings(
42
+ model_name=os.getenv("MODEL_NAME", DEFAULT_SETTINGS.model_name),
43
+ max_trace_tokens=int(os.getenv("MAX_TRACE_TOKENS", DEFAULT_SETTINGS.max_trace_tokens)),
44
+ max_sentences=int(os.getenv("MAX_SENTENCES", DEFAULT_SETTINGS.max_sentences)),
45
+ take_log=take_log,
46
+ device_preference=os.getenv("DEVICE_PREFERENCE", DEFAULT_SETTINGS.device_preference), # type: ignore[arg-type]
47
+ dtype_preference=os.getenv("DTYPE_PREFERENCE", DEFAULT_SETTINGS.dtype_preference), # type: ignore[arg-type]
48
+ attn_implementation=os.getenv("ATTN_IMPLEMENTATION", DEFAULT_SETTINGS.attn_implementation),
49
+ trust_remote_code=trust_remote_code,
50
+ low_cpu_mem_usage=low_cpu_mem_usage,
51
+ preload_model=os.getenv("PRELOAD_MODEL", "false").strip().lower() in {"1", "true", "yes", "on"},
52
+ api_host=os.getenv("API_HOST", DEFAULT_SETTINGS.api_host),
53
+ api_port=int(os.getenv("API_PORT", DEFAULT_SETTINGS.api_port)),
54
+ database_path=os.getenv("DATABASE_PATH", DEFAULT_SETTINGS.database_path),
55
+ job_workers=int(os.getenv("JOB_WORKERS", DEFAULT_SETTINGS.job_workers)),
56
+ max_queued_jobs=int(os.getenv("MAX_QUEUED_JOBS", DEFAULT_SETTINGS.max_queued_jobs)),
57
+ max_active_jobs_per_user=int(
58
+ os.getenv("MAX_ACTIVE_JOBS_PER_USER", DEFAULT_SETTINGS.max_active_jobs_per_user)
59
+ ),
60
+ require_auth=require_auth,
61
+ public_api_enabled=public_api_enabled,
62
+ )
app/core/model_support.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+
7
+ @dataclass(frozen=True, slots=True)
8
+ class ModelSupport:
9
+ supports_attribution: bool
10
+ reason: str | None
11
+ layer_path: str | None
12
+ attention_attr: str | None
13
+ layer_count: int
14
+ attention_impl: str | None
15
+
16
+
17
+ _LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = (
18
+ ("model", "layers"),
19
+ ("model", "model", "layers"),
20
+ ("transformer", "h"),
21
+ ("gpt_neox", "layers"),
22
+ )
23
+ _ATTENTION_ATTR_CANDIDATES: tuple[str, ...] = ("self_attn", "attn", "attention")
24
+
25
+
26
+ def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None:
27
+ current = obj
28
+ for segment in path:
29
+ current = getattr(current, segment, None)
30
+ if current is None:
31
+ return None
32
+ return current
33
+
34
+
35
+ def _get_attention_impl(model: Any) -> str | None:
36
+ config = getattr(model, "config", None)
37
+ if config is None:
38
+ return None
39
+ return getattr(config, "_attn_implementation", None) or getattr(config, "attn_implementation", None)
40
+
41
+
42
+ def describe_model_support(model: Any) -> ModelSupport:
43
+ attn_impl = _get_attention_impl(model)
44
+ layers = None
45
+ layer_path = None
46
+ for candidate in _LAYER_PATH_CANDIDATES:
47
+ maybe_layers = _resolve_attr_chain(model, candidate)
48
+ if maybe_layers is not None:
49
+ layers = list(maybe_layers)
50
+ layer_path = ".".join(candidate)
51
+ break
52
+
53
+ if not layers:
54
+ return ModelSupport(
55
+ supports_attribution=False,
56
+ reason="Unsupported model structure: unable to locate decoder layers.",
57
+ layer_path=layer_path,
58
+ attention_attr=None,
59
+ layer_count=0,
60
+ attention_impl=attn_impl,
61
+ )
62
+
63
+ for attention_attr in _ATTENTION_ATTR_CANDIDATES:
64
+ if all(getattr(layer, attention_attr, None) is not None for layer in layers):
65
+ if attn_impl != "eager":
66
+ return ModelSupport(
67
+ supports_attribution=False,
68
+ reason="Attention gradients require attn_implementation='eager'.",
69
+ layer_path=layer_path,
70
+ attention_attr=attention_attr,
71
+ layer_count=len(layers),
72
+ attention_impl=attn_impl,
73
+ )
74
+ return ModelSupport(
75
+ supports_attribution=True,
76
+ reason=None,
77
+ layer_path=layer_path,
78
+ attention_attr=attention_attr,
79
+ layer_count=len(layers),
80
+ attention_impl=attn_impl,
81
+ )
82
+
83
+ return ModelSupport(
84
+ supports_attribution=False,
85
+ reason="Unsupported attention module layout: no known attention attribute found on decoder layers.",
86
+ layer_path=layer_path,
87
+ attention_attr=None,
88
+ layer_count=len(layers),
89
+ attention_impl=attn_impl,
90
+ )
91
+
92
+
93
+ def get_decoder_layers(model: Any) -> tuple[list[Any], str, str]:
94
+ support = describe_model_support(model)
95
+ if not support.supports_attribution or support.layer_path is None or support.attention_attr is None:
96
+ reason = support.reason or "Model does not support attribution analysis."
97
+ raise RuntimeError(reason)
98
+
99
+ layers = _resolve_attr_chain(model, tuple(support.layer_path.split(".")))
100
+ if layers is None:
101
+ raise RuntimeError("Model support metadata became inconsistent while resolving layers.")
102
+ return list(layers), support.layer_path, support.attention_attr
app/core/runtime.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from functools import lru_cache
5
+
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
8
+
9
+ from app.core.model_support import ModelSupport, describe_model_support
10
+
11
+
12
+ @dataclass(slots=True)
13
+ class ModelBundle:
14
+ model_name: str
15
+ model: PreTrainedModel
16
+ tokenizer: PreTrainedTokenizerBase
17
+ device: torch.device
18
+ dtype: torch.dtype
19
+ capability: ModelSupport
20
+
21
+
22
+ def resolve_dtype(preference: str, device: torch.device) -> torch.dtype:
23
+ if preference == "float32":
24
+ return torch.float32
25
+ if preference == "float16":
26
+ return torch.float16
27
+ if preference == "bfloat16":
28
+ return torch.bfloat16
29
+ if device.type == "cuda":
30
+ return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
31
+ if device.type == "mps":
32
+ return torch.float16
33
+ return torch.float32
34
+
35
+
36
+ def resolve_device(preference: str = "auto") -> torch.device:
37
+ if preference == "cuda":
38
+ if not torch.cuda.is_available():
39
+ raise RuntimeError("CUDA requested but not available.")
40
+ return torch.device("cuda")
41
+ if preference == "mps":
42
+ if not torch.backends.mps.is_available():
43
+ raise RuntimeError("MPS requested but not available.")
44
+ return torch.device("mps")
45
+ if preference == "cpu":
46
+ return torch.device("cpu")
47
+ if torch.cuda.is_available():
48
+ return torch.device("cuda")
49
+ if torch.backends.mps.is_available():
50
+ return torch.device("mps")
51
+ return torch.device("cpu")
52
+
53
+
54
+ @lru_cache(maxsize=2)
55
+ def load_model_bundle(
56
+ model_name: str,
57
+ device_preference: str = "auto",
58
+ dtype_preference: str = "auto",
59
+ attn_implementation: str = "eager",
60
+ trust_remote_code: bool = True,
61
+ low_cpu_mem_usage: bool = True,
62
+ ) -> ModelBundle:
63
+ device = resolve_device(device_preference)
64
+ dtype = resolve_dtype(dtype_preference, device)
65
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
66
+ if tokenizer.pad_token is None:
67
+ tokenizer.pad_token = tokenizer.eos_token
68
+
69
+ model = AutoModelForCausalLM.from_pretrained(
70
+ model_name,
71
+ trust_remote_code=trust_remote_code,
72
+ attn_implementation=attn_implementation,
73
+ torch_dtype=dtype,
74
+ low_cpu_mem_usage=low_cpu_mem_usage,
75
+ )
76
+ model.to(device)
77
+ model.eval()
78
+ capability = describe_model_support(model)
79
+
80
+ return ModelBundle(
81
+ model_name=model_name,
82
+ model=model,
83
+ tokenizer=tokenizer,
84
+ device=device,
85
+ dtype=dtype,
86
+ capability=capability,
87
+ )
88
+
89
+
90
+ def compute_attribution_analysis(**kwargs):
91
+ from app.core.runtime_pipeline import compute_attribution_analysis as _compute_attribution_analysis
92
+
93
+ return _compute_attribution_analysis(**kwargs)
app/core/runtime_pipeline.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+
7
+ from app.analysis.sentence_split import split_sentences
8
+ from app.analysis.summaries import (
9
+ compute_incoming_importance,
10
+ compute_outgoing_importance,
11
+ compute_top_edges,
12
+ )
13
+ from app.analysis.suppression import compute_attribution_matrix
14
+ from app.analysis.token_boundaries import tokenize_with_sentence_ranges, truncate_text_to_token_limit
15
+ from app.analysis.validation import validate_top_edges
16
+ from app.core.config import get_settings
17
+ from app.core.runtime import load_model_bundle
18
+ from app.core.schemas import AnalysisResult, GenerationResult
19
+ from app.generation.service import generate_answer_and_trace
20
+
21
+
22
+ def compute_attribution_analysis(
23
+ *,
24
+ question: str,
25
+ model_name: str | None = None,
26
+ take_log: bool | None = None,
27
+ max_sentences: int | None = None,
28
+ max_trace_tokens: int | None = None,
29
+ validate_top_k: int = 0,
30
+ max_new_tokens: int = 512,
31
+ temperature: float = 0.6,
32
+ top_p: float = 0.95,
33
+ device_preference: str | None = None,
34
+ dtype_preference: str | None = None,
35
+ attn_implementation: str | None = None,
36
+ trust_remote_code: bool | None = None,
37
+ low_cpu_mem_usage: bool | None = None,
38
+ ) -> AnalysisResult:
39
+ generation = None
40
+ return analyze_generation_result(
41
+ question=question,
42
+ generation=generation,
43
+ model_name=model_name,
44
+ take_log=take_log,
45
+ max_sentences=max_sentences,
46
+ max_trace_tokens=max_trace_tokens,
47
+ validate_top_k=validate_top_k,
48
+ max_new_tokens=max_new_tokens,
49
+ temperature=temperature,
50
+ top_p=top_p,
51
+ device_preference=device_preference,
52
+ dtype_preference=dtype_preference,
53
+ attn_implementation=attn_implementation,
54
+ trust_remote_code=trust_remote_code,
55
+ low_cpu_mem_usage=low_cpu_mem_usage,
56
+ )
57
+
58
+
59
+ def analyze_generation_result(
60
+ *,
61
+ question: str,
62
+ generation: GenerationResult | None = None,
63
+ model_name: str | None = None,
64
+ take_log: bool | None = None,
65
+ max_sentences: int | None = None,
66
+ max_trace_tokens: int | None = None,
67
+ validate_top_k: int = 0,
68
+ max_new_tokens: int = 512,
69
+ temperature: float = 0.6,
70
+ top_p: float = 0.95,
71
+ device_preference: str | None = None,
72
+ dtype_preference: str | None = None,
73
+ attn_implementation: str | None = None,
74
+ trust_remote_code: bool | None = None,
75
+ low_cpu_mem_usage: bool | None = None,
76
+ ) -> AnalysisResult:
77
+ settings = get_settings()
78
+ resolved_model_name = model_name or settings.model_name
79
+ resolved_take_log = settings.take_log if take_log is None else take_log
80
+ resolved_max_sentences = max_sentences or settings.max_sentences
81
+ resolved_max_trace_tokens = max_trace_tokens or settings.max_trace_tokens
82
+ resolved_device = device_preference or settings.device_preference
83
+ resolved_dtype = dtype_preference or settings.dtype_preference
84
+ resolved_attn_implementation = attn_implementation or settings.attn_implementation
85
+ resolved_trust_remote_code = settings.trust_remote_code if trust_remote_code is None else trust_remote_code
86
+ resolved_low_cpu_mem_usage = (
87
+ settings.low_cpu_mem_usage if low_cpu_mem_usage is None else low_cpu_mem_usage
88
+ )
89
+
90
+ bundle = load_model_bundle(
91
+ resolved_model_name,
92
+ device_preference=resolved_device,
93
+ dtype_preference=resolved_dtype,
94
+ attn_implementation=resolved_attn_implementation,
95
+ trust_remote_code=resolved_trust_remote_code,
96
+ low_cpu_mem_usage=resolved_low_cpu_mem_usage,
97
+ )
98
+ if not bundle.capability.supports_attribution:
99
+ reason = bundle.capability.reason or "Model does not support attribution analysis."
100
+ raise RuntimeError(reason)
101
+ if generation is None:
102
+ generation = generate_answer_and_trace(
103
+ question=question,
104
+ model_name=resolved_model_name,
105
+ model=bundle.model,
106
+ tokenizer=bundle.tokenizer,
107
+ max_new_tokens=max_new_tokens,
108
+ temperature=temperature,
109
+ top_p=top_p,
110
+ )
111
+
112
+ truncated_text = truncate_text_to_token_limit(
113
+ generation.normalized_trace_text,
114
+ bundle.tokenizer,
115
+ resolved_max_trace_tokens,
116
+ )
117
+ sentence_spans = split_sentences(truncated_text)
118
+ if resolved_max_sentences > 0 and len(sentence_spans) > resolved_max_sentences:
119
+ sentence_spans = sentence_spans[:resolved_max_sentences]
120
+ truncated_text = truncated_text[: sentence_spans[-1].end_char]
121
+ sentence_spans = split_sentences(truncated_text)
122
+
123
+ if not sentence_spans:
124
+ raise RuntimeError("Trace normalization produced no analyzable sentences.")
125
+
126
+ mapping = tokenize_with_sentence_ranges(truncated_text, sentence_spans, bundle.tokenizer)
127
+ input_ids = mapping.input_ids.to(bundle.device)
128
+ computation = compute_attribution_matrix(
129
+ input_ids=input_ids,
130
+ token_ranges=mapping.token_ranges,
131
+ model=bundle.model,
132
+ take_log=resolved_take_log,
133
+ max_trace_tokens=resolved_max_trace_tokens,
134
+ max_sentences=resolved_max_sentences,
135
+ )
136
+
137
+ outgoing = compute_outgoing_importance(computation.raw_matrix)
138
+ incoming = compute_incoming_importance(computation.raw_matrix)
139
+ top_edges = compute_top_edges(computation.raw_matrix, top_k=10)
140
+ validation = validate_top_edges(
141
+ model=bundle.model,
142
+ input_ids=input_ids,
143
+ token_ranges=mapping.token_ranges,
144
+ top_edges=top_edges,
145
+ baseline_token_nll=computation.token_nll,
146
+ top_k=validate_top_k,
147
+ )
148
+
149
+ return AnalysisResult(
150
+ question=question,
151
+ model_name=resolved_model_name,
152
+ answer=generation.answer,
153
+ raw_trace_text=generation.raw_trace_text,
154
+ normalized_trace_text=truncated_text,
155
+ sentences=[span.text for span in sentence_spans],
156
+ sentence_token_ranges=mapping.token_ranges,
157
+ suppression_matrix=computation.matrix.tolist(),
158
+ raw_suppression_matrix=computation.raw_matrix.tolist(),
159
+ outgoing_importance=outgoing,
160
+ incoming_importance=incoming,
161
+ top_edges=top_edges,
162
+ runtime_metadata=computation.runtime_metadata,
163
+ validation_metadata=validation,
164
+ extra_metadata={
165
+ "raw_generation_text": generation.raw_generation_text,
166
+ "generation_metadata": generation.generation_metadata.model_dump(),
167
+ "effective_runtime": {
168
+ "device_preference": resolved_device,
169
+ "dtype_preference": resolved_dtype,
170
+ "attn_implementation": resolved_attn_implementation,
171
+ "trust_remote_code": resolved_trust_remote_code,
172
+ "low_cpu_mem_usage": resolved_low_cpu_mem_usage,
173
+ },
174
+ },
175
+ )
app/core/schemas.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ class TopEdge(BaseModel):
9
+ source_sentence_idx: int
10
+ target_sentence_idx: int
11
+ score: float
12
+
13
+
14
+ class ModelCapability(BaseModel):
15
+ supports_attribution: bool
16
+ reason: str | None = None
17
+ layer_path: str | None = None
18
+ attention_attr: str | None = None
19
+ layer_count: int = 0
20
+ attention_impl: str | None = None
21
+
22
+
23
+ class RuntimeMetadata(BaseModel):
24
+ forward_pass_ms: float = 0.0
25
+ backward_pass_ms: float = 0.0
26
+ matrix_computation_ms: float = 0.0
27
+ total_analysis_ms: float = 0.0
28
+ num_layers: int = 0
29
+ num_heads: int = 0
30
+ sequence_length_tokens: int = 0
31
+ sentence_count: int = 0
32
+ device: str = "unknown"
33
+ dtype: str = "unknown"
34
+ attention_impl: str = "unknown"
35
+ max_trace_tokens: int = 0
36
+ max_sentences: int = 0
37
+ capability: ModelCapability = Field(default_factory=lambda: ModelCapability(supports_attribution=False))
38
+
39
+
40
+ class ValidationMetadata(BaseModel):
41
+ enabled: bool = False
42
+ top_k: int = 0
43
+ pearson: float | None = None
44
+ spearman: float | None = None
45
+ compared_edges: list[TopEdge] = Field(default_factory=list)
46
+ notes: str | None = None
47
+
48
+
49
+ class GenerationMetadata(BaseModel):
50
+ max_new_tokens: int
51
+ temperature: float
52
+ top_p: float
53
+ do_sample: bool
54
+
55
+
56
+ class GenerationResult(BaseModel):
57
+ question: str
58
+ model_name: str
59
+ answer: str
60
+ raw_generation_text: str
61
+ raw_trace_text: str
62
+ normalized_trace_text: str
63
+ generation_metadata: GenerationMetadata
64
+
65
+
66
+ class AnalysisResult(BaseModel):
67
+ question: str
68
+ model_name: str
69
+ answer: str
70
+ raw_trace_text: str
71
+ normalized_trace_text: str
72
+ sentences: list[str]
73
+ sentence_token_ranges: list[tuple[int, int]]
74
+ suppression_matrix: list[list[float]]
75
+ raw_suppression_matrix: list[list[float]] | None = None
76
+ outgoing_importance: list[float]
77
+ incoming_importance: list[float]
78
+ top_edges: list[TopEdge]
79
+ runtime_metadata: RuntimeMetadata
80
+ validation_metadata: ValidationMetadata | None = None
81
+ extra_metadata: dict[str, Any] = Field(default_factory=dict)
82
+
83
+
84
+ class AnalysisRequest(BaseModel):
85
+ question: str
86
+ model_name: str | None = None
87
+ take_log: bool | None = None
88
+ max_sentences: int | None = None
89
+ max_trace_tokens: int | None = None
90
+ validate_top_k: int = 0
91
+ max_new_tokens: int = 256
92
+ temperature: float = 0.6
93
+ top_p: float = 0.95
94
+ device_preference: str | None = None
95
+ dtype_preference: str | None = None
96
+ attn_implementation: str | None = None
97
+ trust_remote_code: bool | None = None
98
+ low_cpu_mem_usage: bool | None = None
99
+
100
+
101
+ class HealthResponse(BaseModel):
102
+ status: str
103
+ model_name: str
104
+ device_preference: str
105
+ dtype_preference: str
106
+ preload_model: bool
107
+ cuda_available: bool
108
+ mps_available: bool
109
+ require_auth: bool
110
+ public_api_enabled: bool
111
+ max_queued_jobs: int
112
+ max_active_jobs_per_user: int
113
+
114
+
115
+ class WarmupResponse(BaseModel):
116
+ status: str
117
+ model_name: str
118
+ device: str
119
+ dtype: str
120
+ capability: ModelCapability
121
+
122
+
123
+ class CurrentUserResponse(BaseModel):
124
+ authenticated: bool
125
+ auth_required: bool
126
+ username: str | None = None
127
+ full_name: str | None = None
128
+ avatar_url: str | None = None
129
+ login_url: str = "/oauth/huggingface/login"
130
+ logout_url: str = "/oauth/huggingface/logout"
131
+
132
+
133
+ class SessionCreateRequest(BaseModel):
134
+ question: str
135
+ model_name: str | None = None
136
+ take_log: bool | None = None
137
+ max_sentences: int | None = None
138
+ max_trace_tokens: int | None = None
139
+ validate_top_k: int = 0
140
+ max_new_tokens: int = 256
141
+ temperature: float = 0.6
142
+ top_p: float = 0.95
143
+ device_preference: str | None = None
144
+ dtype_preference: str | None = None
145
+ attn_implementation: str | None = None
146
+ trust_remote_code: bool | None = None
147
+ low_cpu_mem_usage: bool | None = None
148
+
149
+
150
+ class SessionResponse(BaseModel):
151
+ id: str
152
+ status: str
153
+ question: str
154
+ model_name: str
155
+ error: str | None = None
156
+ created_at: str
157
+ updated_at: str
158
+ answer: str | None = None
159
+ raw_trace_text: str | None = None
160
+ normalized_trace_text: str | None = None
161
+ sentences: list[str] | None = None
162
+ generation_metadata: dict[str, Any] | None = None
163
+
164
+
165
+ class SessionResultResponse(BaseModel):
166
+ session: SessionResponse
167
+ analysis: AnalysisResult | None = None
app/frontend/app.js ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const form = document.querySelector("#question-form");
2
+ const submitButton = document.querySelector("#submit-button");
3
+ const statusChip = document.querySelector("#status-chip");
4
+ const statusDetail = document.querySelector("#status-detail");
5
+ const authSummary = document.querySelector("#auth-summary");
6
+ const loginLink = document.querySelector("#login-link");
7
+ const logoutLink = document.querySelector("#logout-link");
8
+ const recentSessions = document.querySelector("#recent-sessions");
9
+ const answerOutput = document.querySelector("#answer-output");
10
+ const sessionIdOutput = document.querySelector("#session-id");
11
+ const sentenceList = document.querySelector("#sentence-list");
12
+ const traceCount = document.querySelector("#trace-count");
13
+ const heatmap = document.querySelector("#heatmap");
14
+ const selectionDetails = document.querySelector("#selection-details");
15
+ const metrics = document.querySelector("#metrics");
16
+ const topEdges = document.querySelector("#top-edges");
17
+ const exportJson = document.querySelector("#export-json");
18
+ const exportCsv = document.querySelector("#export-csv");
19
+
20
+ let activeSessionId = null;
21
+ let pollHandle = null;
22
+ let activePayload = null;
23
+ let selectedSentenceIdx = null;
24
+ let selectedCell = null;
25
+ let currentUser = null;
26
+
27
+ function currentSentences() {
28
+ return activePayload?.analysis?.sentences || activePayload?.session?.sentences || [];
29
+ }
30
+
31
+ function setStatus(label, detail) {
32
+ statusChip.textContent = label;
33
+ statusDetail.textContent = detail;
34
+ }
35
+
36
+ function clearSelection() {
37
+ selectedSentenceIdx = null;
38
+ selectedCell = null;
39
+ selectionDetails.textContent = "Choose a sentence or heatmap cell.";
40
+ metrics.innerHTML = "";
41
+ }
42
+
43
+ function setExportLinks(sessionId, enabled) {
44
+ exportJson.href = enabled ? `/api/sessions/${sessionId}/export.json` : "#";
45
+ exportCsv.href = enabled ? `/api/sessions/${sessionId}/export.csv` : "#";
46
+ exportJson.classList.toggle("is-disabled", !enabled);
47
+ exportCsv.classList.toggle("is-disabled", !enabled);
48
+ exportJson.setAttribute("aria-disabled", String(!enabled));
49
+ exportCsv.setAttribute("aria-disabled", String(!enabled));
50
+ }
51
+
52
+ function renderRecentSessions(sessions) {
53
+ recentSessions.innerHTML = "";
54
+ if (!currentUser?.authenticated) {
55
+ recentSessions.textContent = currentUser?.auth_required
56
+ ? "Sign in to view your jobs."
57
+ : "Anonymous mode is enabled.";
58
+ return;
59
+ }
60
+ if (!sessions.length) {
61
+ recentSessions.textContent = "No jobs yet on this running Space instance.";
62
+ return;
63
+ }
64
+ sessions.forEach((session) => {
65
+ const item = document.createElement("button");
66
+ item.type = "button";
67
+ item.className = "session-card";
68
+ if (session.id === activeSessionId) {
69
+ item.classList.add("is-active");
70
+ }
71
+ item.innerHTML = `<strong>${session.status}</strong>${session.question}`;
72
+ item.addEventListener("click", () => {
73
+ startPolling(session.id);
74
+ });
75
+ recentSessions.appendChild(item);
76
+ });
77
+ }
78
+
79
+ async function loadRecentSessions() {
80
+ if (!currentUser?.authenticated && currentUser?.auth_required) {
81
+ renderRecentSessions([]);
82
+ return;
83
+ }
84
+ const response = await fetch("/api/sessions?limit=8");
85
+ if (!response.ok) {
86
+ renderRecentSessions([]);
87
+ return;
88
+ }
89
+ renderRecentSessions(await response.json());
90
+ }
91
+
92
+ function renderAuth() {
93
+ if (!currentUser) {
94
+ authSummary.textContent = "Checking sign-in status.";
95
+ return;
96
+ }
97
+ loginLink.hidden = currentUser.authenticated;
98
+ logoutLink.hidden = !currentUser.authenticated;
99
+ if (currentUser.authenticated) {
100
+ authSummary.textContent = `Signed in as ${currentUser.full_name || currentUser.username}.`;
101
+ return;
102
+ }
103
+ authSummary.textContent = currentUser.auth_required
104
+ ? "Sign in with Hugging Face to run analysis jobs."
105
+ : "Anonymous access is enabled for this Space.";
106
+ }
107
+
108
+ function renderSession(session) {
109
+ sessionIdOutput.textContent = session.id ? `Session ${session.id.slice(0, 8)}` : "";
110
+ answerOutput.textContent = session.answer || "Waiting for generated answer.";
111
+ const sentences = currentSentences();
112
+ traceCount.textContent = sentences.length ? `${sentences.length} sentences` : "";
113
+
114
+ sentenceList.innerHTML = "";
115
+ sentences.forEach((sentence, index) => {
116
+ const item = document.createElement("li");
117
+ item.className = "sentence-item";
118
+ item.innerHTML = `<strong>${index}</strong> ${sentence}`;
119
+ item.addEventListener("click", () => {
120
+ selectedSentenceIdx = index;
121
+ selectedCell = null;
122
+ renderSelection();
123
+ });
124
+ sentenceList.appendChild(item);
125
+ });
126
+ }
127
+
128
+ function colorForValue(value, maxAbs) {
129
+ if (!maxAbs) {
130
+ return "rgba(224, 223, 218, 0.8)";
131
+ }
132
+ const normalized = Math.max(-1, Math.min(1, value / maxAbs));
133
+ if (normalized >= 0) {
134
+ const alpha = 0.12 + normalized * 0.88;
135
+ return `rgba(14, 90, 138, ${alpha.toFixed(3)})`;
136
+ }
137
+ const alpha = 0.12 + Math.abs(normalized) * 0.88;
138
+ return `rgba(215, 106, 52, ${alpha.toFixed(3)})`;
139
+ }
140
+
141
+ function renderHeatmap(result) {
142
+ const matrix = result?.suppression_matrix;
143
+ if (!matrix || !matrix.length) {
144
+ heatmap.className = "heatmap placeholder-box";
145
+ heatmap.textContent = "Analysis pending.";
146
+ return;
147
+ }
148
+
149
+ const flatValues = matrix.flat();
150
+ const maxAbs = Math.max(...flatValues.map((value) => Math.abs(value)));
151
+ heatmap.className = "heatmap";
152
+ heatmap.innerHTML = "";
153
+ const grid = document.createElement("div");
154
+ grid.className = "heatmap-grid";
155
+ grid.style.gridTemplateColumns = `repeat(${matrix.length}, 32px)`;
156
+
157
+ matrix.forEach((row, rowIndex) => {
158
+ row.forEach((value, colIndex) => {
159
+ const cell = document.createElement("button");
160
+ cell.type = "button";
161
+ cell.className = "heatmap-cell";
162
+ cell.style.background = colorForValue(value, maxAbs);
163
+ cell.title = `target ${rowIndex} ← source ${colIndex}: ${value.toFixed(4)}`;
164
+ if (selectedCell && selectedCell.row === rowIndex && selectedCell.col === colIndex) {
165
+ cell.classList.add("is-selected");
166
+ }
167
+ cell.addEventListener("click", () => {
168
+ selectedSentenceIdx = null;
169
+ selectedCell = { row: rowIndex, col: colIndex, value };
170
+ renderSelection();
171
+ });
172
+ grid.appendChild(cell);
173
+ });
174
+ });
175
+ heatmap.appendChild(grid);
176
+ }
177
+
178
+ function renderTopEdges(result) {
179
+ topEdges.innerHTML = "";
180
+ const edges = result?.top_edges || [];
181
+ if (!edges.length) {
182
+ return;
183
+ }
184
+ edges.slice(0, 5).forEach((edge) => {
185
+ const item = document.createElement("div");
186
+ item.className = "edge-card";
187
+ item.innerHTML = `<strong>${edge.source_sentence_idx} → ${edge.target_sentence_idx}</strong>${edge.score.toFixed(4)}`;
188
+ topEdges.appendChild(item);
189
+ });
190
+ }
191
+
192
+ function renderSelection() {
193
+ Array.from(sentenceList.children).forEach((item, index) => {
194
+ item.classList.toggle("is-active", selectedSentenceIdx === index);
195
+ });
196
+
197
+ const result = activePayload?.analysis;
198
+ const session = activePayload?.session;
199
+ if (!session) {
200
+ clearSelection();
201
+ return;
202
+ }
203
+
204
+ metrics.innerHTML = "";
205
+ if (selectedSentenceIdx != null && result) {
206
+ const outgoing = result.outgoing_importance[selectedSentenceIdx] ?? 0;
207
+ const incoming = result.incoming_importance[selectedSentenceIdx] ?? 0;
208
+ selectionDetails.innerHTML = `<strong>Sentence ${selectedSentenceIdx}</strong><br>${currentSentences()[selectedSentenceIdx] || ""}`;
209
+ metrics.innerHTML = `
210
+ <div class="metric-card"><strong>Outgoing impact</strong>${outgoing.toFixed(4)}</div>
211
+ <div class="metric-card"><strong>Incoming dependence</strong>${incoming.toFixed(4)}</div>
212
+ `;
213
+ return;
214
+ }
215
+
216
+ if (selectedCell) {
217
+ selectionDetails.innerHTML = `<strong>Edge ${selectedCell.col} → ${selectedCell.row}</strong><br>Influence score ${selectedCell.value.toFixed(4)}`;
218
+ metrics.innerHTML = `
219
+ <div class="metric-card"><strong>Source sentence</strong>${selectedCell.col}</div>
220
+ <div class="metric-card"><strong>Target sentence</strong>${selectedCell.row}</div>
221
+ `;
222
+ return;
223
+ }
224
+
225
+ selectionDetails.textContent = "Choose a sentence or heatmap cell.";
226
+ }
227
+
228
+ async function fetchSession(sessionId) {
229
+ const response = await fetch(`/api/sessions/${sessionId}/result`);
230
+ if (!response.ok) {
231
+ throw new Error(`Failed to fetch session ${sessionId}`);
232
+ }
233
+ return response.json();
234
+ }
235
+
236
+ function updateFromPayload(payload) {
237
+ activePayload = payload;
238
+ activeSessionId = payload.session.id;
239
+ renderSession(payload.session);
240
+ renderHeatmap(payload.analysis);
241
+ renderTopEdges(payload.analysis);
242
+ renderSelection();
243
+ setExportLinks(payload.session.id, payload.session.status === "completed");
244
+
245
+ const { status, error } = payload.session;
246
+ if (status === "queued") setStatus("Queued", "Waiting for a worker slot.");
247
+ if (status === "generating") setStatus("Generating", "The model is producing an answer and visible trace.");
248
+ if (status === "answer_ready") setStatus("Analysis pending", "Answer is ready. Attribution analysis is starting.");
249
+ if (status === "analyzing") setStatus("Analyzing", "Running forward and backward passes for sentence influence.");
250
+ if (status === "completed") setStatus("Completed", "Analysis finished.");
251
+ if (status === "failed") setStatus("Failed", error || "The session failed.");
252
+
253
+ if (["completed", "failed"].includes(status) && pollHandle) {
254
+ window.clearInterval(pollHandle);
255
+ pollHandle = null;
256
+ }
257
+
258
+ loadRecentSessions().catch(() => {});
259
+ }
260
+
261
+ async function startPolling(sessionId) {
262
+ activeSessionId = sessionId;
263
+ if (pollHandle) {
264
+ window.clearInterval(pollHandle);
265
+ }
266
+ const tick = async () => {
267
+ try {
268
+ const payload = await fetchSession(sessionId);
269
+ updateFromPayload(payload);
270
+ } catch (error) {
271
+ setStatus("Error", String(error));
272
+ window.clearInterval(pollHandle);
273
+ pollHandle = null;
274
+ }
275
+ };
276
+ await tick();
277
+ pollHandle = window.setInterval(tick, 2500);
278
+ }
279
+
280
+ form.addEventListener("submit", async (event) => {
281
+ event.preventDefault();
282
+ if (!currentUser?.authenticated && currentUser?.auth_required) {
283
+ window.location.href = loginLink.href;
284
+ return;
285
+ }
286
+ submitButton.disabled = true;
287
+ clearSelection();
288
+ sentenceList.innerHTML = "";
289
+ heatmap.className = "heatmap placeholder-box";
290
+ heatmap.textContent = "Analysis pending.";
291
+ topEdges.innerHTML = "";
292
+ answerOutput.textContent = "Waiting for generated answer.";
293
+ sessionIdOutput.textContent = "";
294
+ traceCount.textContent = "";
295
+ setExportLinks("", false);
296
+ setStatus("Submitting", "Creating session.");
297
+
298
+ const payload = {
299
+ question: document.querySelector("#question").value,
300
+ max_new_tokens: Number(document.querySelector("#max-new-tokens").value),
301
+ max_trace_tokens: Number(document.querySelector("#max-trace-tokens").value),
302
+ max_sentences: Number(document.querySelector("#max-sentences").value),
303
+ };
304
+
305
+ try {
306
+ const response = await fetch("/api/sessions", {
307
+ method: "POST",
308
+ headers: { "content-type": "application/json" },
309
+ body: JSON.stringify(payload),
310
+ });
311
+ if (!response.ok) {
312
+ throw new Error(await response.text());
313
+ }
314
+ const session = await response.json();
315
+ await startPolling(session.id);
316
+ } catch (error) {
317
+ setStatus("Error", String(error));
318
+ } finally {
319
+ submitButton.disabled = false;
320
+ }
321
+ });
322
+
323
+ async function initialize() {
324
+ setExportLinks("", false);
325
+ try {
326
+ const response = await fetch("/api/me");
327
+ currentUser = await response.json();
328
+ } catch (_error) {
329
+ currentUser = {
330
+ authenticated: false,
331
+ auth_required: true,
332
+ login_url: "/oauth/huggingface/login",
333
+ logout_url: "/oauth/huggingface/logout",
334
+ };
335
+ }
336
+ renderAuth();
337
+ await loadRecentSessions();
338
+ }
339
+
340
+ initialize();
app/frontend/index.html ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
6
+ <title>Thought Anchors</title>
7
+ <link rel="stylesheet" href="/static/styles.css" />
8
+ </head>
9
+ <body>
10
+ <main class="page">
11
+ <section class="auth-strip panel">
12
+ <div>
13
+ <p class="eyebrow">Public Hugging Face Space</p>
14
+ <p id="auth-summary" class="lede compact">Checking sign-in status.</p>
15
+ </div>
16
+ <div class="auth-actions">
17
+ <a id="login-link" class="pill-link" href="/oauth/huggingface/login">Sign in with Hugging Face</a>
18
+ <a id="logout-link" class="pill-link secondary" href="/oauth/huggingface/logout">Sign out</a>
19
+ </div>
20
+ </section>
21
+
22
+ <section class="hero">
23
+ <p class="eyebrow">Online Chain-of-Thought Analysis</p>
24
+ <h1>Reasoning traces with white-box sentence influence.</h1>
25
+ <p class="lede">
26
+ Submit a question, watch the answer appear first, then inspect the sentence-to-sentence
27
+ attribution matrix once analysis completes. Results can be exported as JSON or CSV.
28
+ </p>
29
+ </section>
30
+
31
+ <section class="panel composer">
32
+ <form id="question-form">
33
+ <label class="label" for="question">Question</label>
34
+ <textarea id="question" name="question" rows="5" placeholder="Explain why the derivative of x^2 is 2x" required></textarea>
35
+ <div class="controls">
36
+ <label>
37
+ <span>Max new tokens</span>
38
+ <input id="max-new-tokens" type="number" min="16" max="512" value="128" />
39
+ </label>
40
+ <label>
41
+ <span>Max trace tokens</span>
42
+ <input id="max-trace-tokens" type="number" min="64" max="1024" value="256" />
43
+ </label>
44
+ <label>
45
+ <span>Max sentences</span>
46
+ <input id="max-sentences" type="number" min="4" max="40" value="16" />
47
+ </label>
48
+ </div>
49
+ <button id="submit-button" type="submit">Start analysis</button>
50
+ </form>
51
+ </section>
52
+
53
+ <section class="status-row">
54
+ <div class="status-chip" id="status-chip">Idle</div>
55
+ <div class="status-detail" id="status-detail">Submit a question to create a session.</div>
56
+ </section>
57
+
58
+ <section class="grid">
59
+ <article class="panel jobs-panel">
60
+ <div class="panel-heading">
61
+ <h2>Your Sessions</h2>
62
+ <span class="muted">Ephemeral instance history</span>
63
+ </div>
64
+ <div id="recent-sessions" class="recent-sessions placeholder">Sign in to view your jobs.</div>
65
+ </article>
66
+
67
+ <article class="panel answer-panel">
68
+ <div class="panel-heading">
69
+ <h2>Answer</h2>
70
+ <span id="session-id" class="muted"></span>
71
+ </div>
72
+ <p id="answer-output" class="placeholder">No answer yet.</p>
73
+ </article>
74
+
75
+ <article class="panel trace-panel">
76
+ <div class="panel-heading">
77
+ <h2>Reasoning Trace</h2>
78
+ <span class="muted" id="trace-count"></span>
79
+ </div>
80
+ <ol id="sentence-list" class="sentence-list"></ol>
81
+ </article>
82
+
83
+ <article class="panel heatmap-panel">
84
+ <div class="panel-heading">
85
+ <h2>Sentence Influence Matrix</h2>
86
+ <span class="muted">Rows: targets, columns: sources</span>
87
+ </div>
88
+ <div id="heatmap" class="heatmap placeholder-box">Analysis pending.</div>
89
+ </article>
90
+
91
+ <article class="panel details-panel">
92
+ <div class="panel-heading">
93
+ <h2>Selection</h2>
94
+ </div>
95
+ <div id="selection-details" class="placeholder">Choose a sentence or heatmap cell.</div>
96
+ <div class="metrics" id="metrics"></div>
97
+ <div class="edges" id="top-edges"></div>
98
+ <div class="export-actions">
99
+ <a id="export-json" class="pill-link is-disabled" href="#" aria-disabled="true">Export JSON</a>
100
+ <a id="export-csv" class="pill-link secondary is-disabled" href="#" aria-disabled="true">Export CSV</a>
101
+ </div>
102
+ </article>
103
+ </section>
104
+ </main>
105
+
106
+ <script src="/static/app.js" defer></script>
107
+ </body>
108
+ </html>
app/frontend/styles.css ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --bg: #f3efe6;
3
+ --panel: rgba(255, 252, 247, 0.9);
4
+ --ink: #1e1c19;
5
+ --muted: #6f655c;
6
+ --accent: #0e5a8a;
7
+ --accent-soft: #d2ecff;
8
+ --line: rgba(30, 28, 25, 0.12);
9
+ --warm: #d76a34;
10
+ --shadow: 0 20px 60px rgba(34, 25, 16, 0.08);
11
+ }
12
+
13
+ * {
14
+ box-sizing: border-box;
15
+ }
16
+
17
+ body {
18
+ margin: 0;
19
+ font-family: "Iowan Old Style", "Palatino Linotype", serif;
20
+ color: var(--ink);
21
+ background:
22
+ radial-gradient(circle at top left, rgba(14, 90, 138, 0.12), transparent 32%),
23
+ radial-gradient(circle at top right, rgba(215, 106, 52, 0.18), transparent 26%),
24
+ linear-gradient(180deg, #f8f4ec 0%, var(--bg) 100%);
25
+ }
26
+
27
+ .page {
28
+ max-width: 1440px;
29
+ margin: 0 auto;
30
+ padding: 40px 24px 48px;
31
+ }
32
+
33
+ .hero {
34
+ max-width: 760px;
35
+ margin-bottom: 28px;
36
+ }
37
+
38
+ .compact {
39
+ margin: 0;
40
+ }
41
+
42
+ .auth-strip {
43
+ display: flex;
44
+ align-items: center;
45
+ justify-content: space-between;
46
+ gap: 16px;
47
+ margin-bottom: 18px;
48
+ }
49
+
50
+ .auth-actions,
51
+ .export-actions {
52
+ display: flex;
53
+ flex-wrap: wrap;
54
+ gap: 10px;
55
+ }
56
+
57
+ .eyebrow {
58
+ margin: 0 0 8px;
59
+ text-transform: uppercase;
60
+ letter-spacing: 0.12em;
61
+ font-size: 0.78rem;
62
+ color: var(--accent);
63
+ }
64
+
65
+ h1, h2 {
66
+ margin: 0;
67
+ font-weight: 600;
68
+ }
69
+
70
+ h1 {
71
+ font-size: clamp(2.4rem, 5vw, 4.4rem);
72
+ line-height: 0.95;
73
+ }
74
+
75
+ .lede {
76
+ margin: 18px 0 0;
77
+ font-size: 1.08rem;
78
+ color: var(--muted);
79
+ max-width: 60ch;
80
+ }
81
+
82
+ .panel {
83
+ background: var(--panel);
84
+ border: 1px solid var(--line);
85
+ border-radius: 24px;
86
+ padding: 20px;
87
+ box-shadow: var(--shadow);
88
+ backdrop-filter: blur(18px);
89
+ }
90
+
91
+ .composer {
92
+ margin-bottom: 16px;
93
+ }
94
+
95
+ .label,
96
+ .controls span {
97
+ display: block;
98
+ margin-bottom: 8px;
99
+ font-size: 0.92rem;
100
+ color: var(--muted);
101
+ }
102
+
103
+ textarea,
104
+ input {
105
+ width: 100%;
106
+ border: 1px solid rgba(30, 28, 25, 0.16);
107
+ border-radius: 14px;
108
+ padding: 14px 16px;
109
+ font: inherit;
110
+ background: rgba(255, 255, 255, 0.72);
111
+ }
112
+
113
+ textarea {
114
+ resize: vertical;
115
+ min-height: 132px;
116
+ }
117
+
118
+ .controls {
119
+ display: grid;
120
+ grid-template-columns: repeat(3, minmax(0, 1fr));
121
+ gap: 12px;
122
+ margin: 16px 0;
123
+ }
124
+
125
+ button {
126
+ appearance: none;
127
+ border: none;
128
+ border-radius: 999px;
129
+ padding: 14px 20px;
130
+ font: inherit;
131
+ font-weight: 600;
132
+ color: white;
133
+ background: linear-gradient(135deg, var(--accent), #123954);
134
+ cursor: pointer;
135
+ }
136
+
137
+ .pill-link {
138
+ display: inline-flex;
139
+ align-items: center;
140
+ justify-content: center;
141
+ border-radius: 999px;
142
+ padding: 12px 18px;
143
+ text-decoration: none;
144
+ font-weight: 600;
145
+ color: white;
146
+ background: linear-gradient(135deg, var(--accent), #123954);
147
+ }
148
+
149
+ .pill-link.secondary {
150
+ color: var(--ink);
151
+ background: rgba(255, 255, 255, 0.78);
152
+ border: 1px solid var(--line);
153
+ }
154
+
155
+ .pill-link.is-disabled {
156
+ pointer-events: none;
157
+ opacity: 0.45;
158
+ }
159
+
160
+ button:disabled {
161
+ opacity: 0.5;
162
+ cursor: wait;
163
+ }
164
+
165
+ .status-row {
166
+ display: flex;
167
+ align-items: center;
168
+ gap: 12px;
169
+ margin-bottom: 16px;
170
+ }
171
+
172
+ .status-chip {
173
+ display: inline-flex;
174
+ align-items: center;
175
+ border-radius: 999px;
176
+ padding: 8px 14px;
177
+ background: var(--accent-soft);
178
+ color: var(--accent);
179
+ font-size: 0.9rem;
180
+ }
181
+
182
+ .status-detail,
183
+ .muted,
184
+ .placeholder {
185
+ color: var(--muted);
186
+ }
187
+
188
+ .grid {
189
+ display: grid;
190
+ grid-template-columns: 1.2fr 1fr;
191
+ gap: 16px;
192
+ }
193
+
194
+ .jobs-panel {
195
+ grid-column: 1 / -1;
196
+ }
197
+
198
+ .answer-panel,
199
+ .trace-panel,
200
+ .heatmap-panel,
201
+ .details-panel {
202
+ min-height: 260px;
203
+ }
204
+
205
+ .answer-panel,
206
+ .trace-panel {
207
+ grid-column: span 1;
208
+ }
209
+
210
+ .heatmap-panel,
211
+ .details-panel {
212
+ grid-column: span 1;
213
+ }
214
+
215
+ .panel-heading {
216
+ display: flex;
217
+ justify-content: space-between;
218
+ gap: 12px;
219
+ align-items: baseline;
220
+ margin-bottom: 16px;
221
+ }
222
+
223
+ .sentence-list {
224
+ list-style: none;
225
+ margin: 0;
226
+ padding: 0;
227
+ max-height: 420px;
228
+ overflow: auto;
229
+ }
230
+
231
+ .sentence-item {
232
+ padding: 10px 12px;
233
+ border-radius: 14px;
234
+ border: 1px solid transparent;
235
+ margin-bottom: 8px;
236
+ background: rgba(255,255,255,0.45);
237
+ cursor: pointer;
238
+ }
239
+
240
+ .sentence-item:hover,
241
+ .sentence-item.is-active {
242
+ border-color: rgba(14, 90, 138, 0.25);
243
+ background: rgba(210, 236, 255, 0.5);
244
+ }
245
+
246
+ .sentence-item strong {
247
+ display: inline-block;
248
+ min-width: 2.5rem;
249
+ color: var(--accent);
250
+ }
251
+
252
+ .heatmap {
253
+ display: grid;
254
+ gap: 4px;
255
+ overflow: auto;
256
+ min-height: 300px;
257
+ }
258
+
259
+ .heatmap-grid {
260
+ display: grid;
261
+ gap: 4px;
262
+ }
263
+
264
+ .heatmap-cell {
265
+ width: 32px;
266
+ height: 32px;
267
+ border-radius: 8px;
268
+ border: 1px solid rgba(255,255,255,0.4);
269
+ cursor: pointer;
270
+ position: relative;
271
+ }
272
+
273
+ .heatmap-cell.is-selected {
274
+ outline: 2px solid var(--ink);
275
+ }
276
+
277
+ .placeholder-box {
278
+ display: grid;
279
+ place-items: center;
280
+ border: 1px dashed var(--line);
281
+ border-radius: 18px;
282
+ }
283
+
284
+ .metric-card,
285
+ .edge-card {
286
+ padding: 12px 14px;
287
+ border-radius: 16px;
288
+ background: rgba(255,255,255,0.45);
289
+ border: 1px solid var(--line);
290
+ margin-top: 10px;
291
+ }
292
+
293
+ .recent-sessions {
294
+ display: grid;
295
+ gap: 10px;
296
+ }
297
+
298
+ .session-card {
299
+ color: var(--ink);
300
+ border: 1px solid var(--line);
301
+ border-radius: 16px;
302
+ background: rgba(255, 255, 255, 0.48);
303
+ padding: 14px;
304
+ cursor: pointer;
305
+ text-align: left;
306
+ }
307
+
308
+ .session-card.is-active {
309
+ border-color: rgba(14, 90, 138, 0.25);
310
+ background: rgba(210, 236, 255, 0.42);
311
+ }
312
+
313
+ .session-card strong {
314
+ display: block;
315
+ margin-bottom: 4px;
316
+ color: var(--accent);
317
+ }
318
+
319
+ .metric-card strong,
320
+ .edge-card strong {
321
+ display: block;
322
+ color: var(--accent);
323
+ }
324
+
325
+ @media (max-width: 960px) {
326
+ .auth-strip,
327
+ .grid,
328
+ .controls {
329
+ grid-template-columns: 1fr;
330
+ }
331
+
332
+ .auth-strip {
333
+ align-items: flex-start;
334
+ }
335
+ }
app/generation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Generation utilities for trace-producing model calls."""
app/generation/prompting.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
+ SYSTEM_PROMPT = (
7
+ "You are a careful reasoning assistant. Respond with your full reasoning inside "
8
+ "<think>...</think> and then provide a concise final answer after the closing tag."
9
+ )
10
+
11
+
12
+ def build_messages(question: str) -> list[dict[str, str]]:
13
+ return [
14
+ {"role": "system", "content": SYSTEM_PROMPT},
15
+ {"role": "user", "content": question},
16
+ ]
17
+
18
+
19
+ def render_prompt(tokenizer: Any, question: str) -> str:
20
+ messages = build_messages(question)
21
+ if hasattr(tokenizer, "apply_chat_template"):
22
+ return tokenizer.apply_chat_template(
23
+ messages,
24
+ tokenize=False,
25
+ add_generation_prompt=True,
26
+ )
27
+
28
+ return f"{SYSTEM_PROMPT}\n\nQuestion: {question}\n"
app/generation/service.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from typing import Any
5
+
6
+ import torch
7
+
8
+ from app.analysis.sentence_split import normalize_trace_text
9
+ from app.core.schemas import GenerationMetadata, GenerationResult
10
+ from app.generation.prompting import render_prompt
11
+
12
+ THINK_BLOCK_RE = re.compile(r"<think>(.*?)</think>", re.IGNORECASE | re.DOTALL)
13
+ ANSWER_MARKER_RE = re.compile(r"(?:^|\n)(?:final answer|answer)\s*:\s*", re.IGNORECASE)
14
+
15
+
16
+ def _extract_trace_and_answer(text: str) -> tuple[str, str]:
17
+ match = THINK_BLOCK_RE.search(text)
18
+ if match:
19
+ raw_trace = match.group(0)
20
+ answer = text[match.end() :].strip()
21
+ if not answer:
22
+ answer = match.group(1).strip()
23
+ return raw_trace, answer
24
+
25
+ raw_trace = text.strip()
26
+ answer_match = ANSWER_MARKER_RE.search(text)
27
+ if answer_match:
28
+ answer = text[answer_match.end() :].strip()
29
+ else:
30
+ paragraphs = [part.strip() for part in text.split("\n\n") if part.strip()]
31
+ answer = paragraphs[-1] if paragraphs else raw_trace
32
+ return raw_trace, answer
33
+
34
+
35
+ def generate_answer_and_trace(
36
+ *,
37
+ question: str,
38
+ model_name: str,
39
+ model: Any,
40
+ tokenizer: Any,
41
+ max_new_tokens: int = 512,
42
+ temperature: float = 0.6,
43
+ top_p: float = 0.95,
44
+ ) -> GenerationResult:
45
+ prompt_text = render_prompt(tokenizer, question)
46
+ encoded = tokenizer(prompt_text, return_tensors="pt")
47
+ model_device = next(model.parameters()).device
48
+ encoded = {key: value.to(model_device) for key, value in encoded.items()}
49
+ input_length = int(encoded["input_ids"].shape[-1])
50
+ do_sample = temperature > 0.0
51
+
52
+ generation_kwargs: dict[str, Any] = {
53
+ "max_new_tokens": max_new_tokens,
54
+ "do_sample": do_sample,
55
+ "top_p": top_p,
56
+ "pad_token_id": tokenizer.pad_token_id,
57
+ "eos_token_id": tokenizer.eos_token_id,
58
+ }
59
+ if do_sample:
60
+ generation_kwargs["temperature"] = temperature
61
+
62
+ with torch.no_grad():
63
+ output_ids = model.generate(**encoded, **generation_kwargs)
64
+
65
+ generated_ids = output_ids[0, input_length:]
66
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
67
+ raw_trace_text, answer = _extract_trace_and_answer(generated_text)
68
+ normalized_trace_text = normalize_trace_text(raw_trace_text)
69
+
70
+ return GenerationResult(
71
+ question=question,
72
+ model_name=model_name,
73
+ answer=answer,
74
+ raw_generation_text=generated_text,
75
+ raw_trace_text=raw_trace_text,
76
+ normalized_trace_text=normalized_trace_text,
77
+ generation_metadata=GenerationMetadata(
78
+ max_new_tokens=max_new_tokens,
79
+ temperature=temperature,
80
+ top_p=top_p,
81
+ do_sample=do_sample,
82
+ ),
83
+ )
app/services/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Application services for session orchestration."""
app/services/sessions.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from app.analysis.sentence_split import split_sentences
6
+ from app.core.config import Settings
7
+ from app.core.runtime import load_model_bundle
8
+ from app.core.runtime_pipeline import analyze_generation_result
9
+ from app.core.schemas import AnalysisRequest, AnalysisResult, GenerationMetadata, GenerationResult
10
+ from app.generation.service import generate_answer_and_trace
11
+ from app.storage.repository import SessionRecord, SessionRepository
12
+ from app.workers.jobs import JobRunner
13
+
14
+
15
+ class SessionLimitError(RuntimeError):
16
+ pass
17
+
18
+
19
+ class SessionAccessError(PermissionError):
20
+ pass
21
+
22
+
23
+ @dataclass(slots=True)
24
+ class SessionService:
25
+ settings: Settings
26
+ repository: SessionRepository
27
+ jobs: JobRunner
28
+
29
+ def create_session(self, request: AnalysisRequest, *, owner_id: str, owner_name: str | None) -> SessionRecord:
30
+ if self.repository.count_incomplete_sessions() >= self.settings.max_queued_jobs:
31
+ raise SessionLimitError("The service queue is full. Try again after a few minutes.")
32
+ if self.repository.count_incomplete_sessions_for_owner(owner_id) >= self.settings.max_active_jobs_per_user:
33
+ raise SessionLimitError("You already have the maximum number of active analysis jobs.")
34
+ model_name = request.model_name or self.settings.model_name
35
+ session = self.repository.create_session(
36
+ question=request.question,
37
+ model_name=model_name,
38
+ owner_id=owner_id,
39
+ owner_name=owner_name,
40
+ )
41
+ self.jobs.submit(self._run_session_pipeline, session.id, request)
42
+ return session
43
+
44
+ def start_analysis(
45
+ self,
46
+ session_id: str,
47
+ *,
48
+ owner_id: str,
49
+ request: AnalysisRequest | None = None,
50
+ ) -> SessionRecord:
51
+ session = self.repository.get_session(session_id)
52
+ self._assert_owner(session, owner_id)
53
+ effective_request = request or AnalysisRequest(question=session.question, model_name=session.model_name)
54
+ self.jobs.submit(self._run_analysis_only, session_id, effective_request)
55
+ return session
56
+
57
+ def get_session_payload(self, session_id: str, *, owner_id: str) -> dict:
58
+ session = self.repository.get_session(session_id)
59
+ self._assert_owner(session, owner_id)
60
+ return self.repository.list_session_payload(session_id)
61
+
62
+ def list_sessions(self, owner_id: str, *, limit: int = 20) -> list[dict]:
63
+ return self.repository.list_sessions_for_owner(owner_id, limit=limit)
64
+
65
+ def get_analysis_result(self, session_id: str, *, owner_id: str) -> AnalysisResult:
66
+ payload = self.get_session_payload(session_id, owner_id=owner_id)
67
+ analysis = payload.get("analysis")
68
+ if analysis is None:
69
+ raise KeyError(session_id)
70
+ return AnalysisResult.model_validate(analysis)
71
+
72
+ @staticmethod
73
+ def _assert_owner(session: SessionRecord, owner_id: str) -> None:
74
+ if session.owner_id != owner_id:
75
+ raise SessionAccessError("Session belongs to a different user.")
76
+
77
+ def _run_session_pipeline(self, session_id: str, request: AnalysisRequest) -> None:
78
+ try:
79
+ self.repository.update_status(session_id, status="generating")
80
+ bundle = load_model_bundle(
81
+ request.model_name or self.settings.model_name,
82
+ device_preference=request.device_preference or self.settings.device_preference,
83
+ dtype_preference=request.dtype_preference or self.settings.dtype_preference,
84
+ attn_implementation=request.attn_implementation or self.settings.attn_implementation,
85
+ trust_remote_code=(
86
+ self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
87
+ ),
88
+ low_cpu_mem_usage=(
89
+ self.settings.low_cpu_mem_usage
90
+ if request.low_cpu_mem_usage is None
91
+ else request.low_cpu_mem_usage
92
+ ),
93
+ )
94
+ generation = generate_answer_and_trace(
95
+ question=request.question,
96
+ model_name=bundle.model_name,
97
+ model=bundle.model,
98
+ tokenizer=bundle.tokenizer,
99
+ max_new_tokens=request.max_new_tokens,
100
+ temperature=request.temperature,
101
+ top_p=request.top_p,
102
+ )
103
+ sentences = [span.text for span in split_sentences(generation.normalized_trace_text)]
104
+ self.repository.save_generation_result(session_id, generation, sentences)
105
+ self.repository.update_status(session_id, status="answer_ready")
106
+ self._run_analysis_only(session_id, request, generation=generation)
107
+ except Exception as exc:
108
+ self.repository.update_status(session_id, status="failed", error=str(exc))
109
+
110
+ def _run_analysis_only(
111
+ self,
112
+ session_id: str,
113
+ request: AnalysisRequest,
114
+ *,
115
+ generation=None,
116
+ ) -> None:
117
+ try:
118
+ self.repository.update_status(session_id, status="analyzing")
119
+ if generation is None:
120
+ payload = self.repository.list_session_payload(session_id)
121
+ if payload.get("generation_metadata") is not None:
122
+ generation = GenerationResult(
123
+ question=payload["question"],
124
+ model_name=payload["model_name"],
125
+ answer=payload["answer"],
126
+ raw_generation_text=payload.get("raw_generation_text", ""),
127
+ raw_trace_text=payload["raw_trace_text"],
128
+ normalized_trace_text=payload["normalized_trace_text"],
129
+ generation_metadata=GenerationMetadata.model_validate(payload["generation_metadata"]),
130
+ )
131
+ result = analyze_generation_result(
132
+ question=request.question,
133
+ generation=generation,
134
+ model_name=request.model_name or self.settings.model_name,
135
+ take_log=request.take_log,
136
+ max_sentences=request.max_sentences,
137
+ max_trace_tokens=request.max_trace_tokens,
138
+ validate_top_k=request.validate_top_k,
139
+ device_preference=request.device_preference or self.settings.device_preference,
140
+ dtype_preference=request.dtype_preference or self.settings.dtype_preference,
141
+ attn_implementation=request.attn_implementation or self.settings.attn_implementation,
142
+ trust_remote_code=(
143
+ self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
144
+ ),
145
+ low_cpu_mem_usage=(
146
+ self.settings.low_cpu_mem_usage
147
+ if request.low_cpu_mem_usage is None
148
+ else request.low_cpu_mem_usage
149
+ ),
150
+ max_new_tokens=request.max_new_tokens,
151
+ temperature=request.temperature,
152
+ top_p=request.top_p,
153
+ )
154
+ self.repository.save_analysis_result(session_id, result)
155
+ self.repository.update_status(session_id, status="completed")
156
+ except Exception as exc:
157
+ self.repository.update_status(session_id, status="failed", error=str(exc))
app/storage/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Persistence layer for sessions and analysis results."""
app/storage/db.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sqlite3
4
+ from pathlib import Path
5
+
6
+
7
+ def connect(database_path: str) -> sqlite3.Connection:
8
+ path = Path(database_path)
9
+ path.parent.mkdir(parents=True, exist_ok=True)
10
+ connection = sqlite3.connect(path, check_same_thread=False)
11
+ connection.row_factory = sqlite3.Row
12
+ connection.execute("PRAGMA journal_mode=WAL")
13
+ connection.execute("PRAGMA foreign_keys=ON")
14
+ return connection
15
+
16
+
17
+ def initialize_schema(connection: sqlite3.Connection) -> None:
18
+ connection.executescript(
19
+ """
20
+ CREATE TABLE IF NOT EXISTS sessions (
21
+ id TEXT PRIMARY KEY,
22
+ status TEXT NOT NULL,
23
+ question TEXT NOT NULL,
24
+ model_name TEXT NOT NULL,
25
+ owner_id TEXT NOT NULL,
26
+ owner_name TEXT,
27
+ error TEXT,
28
+ created_at TEXT NOT NULL,
29
+ updated_at TEXT NOT NULL
30
+ );
31
+
32
+ CREATE TABLE IF NOT EXISTS generation_results (
33
+ session_id TEXT PRIMARY KEY REFERENCES sessions(id) ON DELETE CASCADE,
34
+ answer TEXT NOT NULL,
35
+ raw_generation_text TEXT NOT NULL,
36
+ raw_trace_text TEXT NOT NULL,
37
+ normalized_trace_text TEXT NOT NULL,
38
+ sentences_json TEXT NOT NULL,
39
+ generation_metadata_json TEXT NOT NULL
40
+ );
41
+
42
+ CREATE TABLE IF NOT EXISTS analysis_results (
43
+ session_id TEXT PRIMARY KEY REFERENCES sessions(id) ON DELETE CASCADE,
44
+ result_json TEXT NOT NULL
45
+ );
46
+ """
47
+ )
48
+ columns = {
49
+ row["name"] for row in connection.execute("PRAGMA table_info(generation_results)").fetchall()
50
+ }
51
+ if "raw_generation_text" not in columns:
52
+ connection.execute(
53
+ "ALTER TABLE generation_results ADD COLUMN raw_generation_text TEXT NOT NULL DEFAULT ''"
54
+ )
55
+ session_columns = {row["name"] for row in connection.execute("PRAGMA table_info(sessions)").fetchall()}
56
+ if "owner_id" not in session_columns:
57
+ connection.execute("ALTER TABLE sessions ADD COLUMN owner_id TEXT NOT NULL DEFAULT 'legacy-user'")
58
+ if "owner_name" not in session_columns:
59
+ connection.execute("ALTER TABLE sessions ADD COLUMN owner_name TEXT")
60
+ connection.commit()
app/storage/repository.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import threading
5
+ import uuid
6
+ from dataclasses import dataclass
7
+ from datetime import UTC, datetime
8
+ from typing import Any
9
+
10
+ from app.core.schemas import AnalysisResult, GenerationResult
11
+ from app.storage.db import connect, initialize_schema
12
+
13
+
14
+ def _utc_now() -> str:
15
+ return datetime.now(UTC).isoformat()
16
+
17
+
18
+ @dataclass(slots=True)
19
+ class SessionRecord:
20
+ id: str
21
+ status: str
22
+ question: str
23
+ model_name: str
24
+ owner_id: str
25
+ owner_name: str | None
26
+ error: str | None
27
+ created_at: str
28
+ updated_at: str
29
+
30
+
31
+ class SessionRepository:
32
+ def __init__(self, database_path: str) -> None:
33
+ self.connection = connect(database_path)
34
+ initialize_schema(self.connection)
35
+ self.lock = threading.Lock()
36
+
37
+ def create_session(self, *, question: str, model_name: str, owner_id: str, owner_name: str | None) -> SessionRecord:
38
+ session_id = str(uuid.uuid4())
39
+ now = _utc_now()
40
+ with self.lock:
41
+ self.connection.execute(
42
+ """
43
+ INSERT INTO sessions (id, status, question, model_name, owner_id, owner_name, error, created_at, updated_at)
44
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
45
+ """,
46
+ (session_id, "queued", question, model_name, owner_id, owner_name, None, now, now),
47
+ )
48
+ self.connection.commit()
49
+ return self.get_session(session_id)
50
+
51
+ def get_session(self, session_id: str) -> SessionRecord:
52
+ row = self.connection.execute(
53
+ "SELECT * FROM sessions WHERE id = ?",
54
+ (session_id,),
55
+ ).fetchone()
56
+ if row is None:
57
+ raise KeyError(session_id)
58
+ return SessionRecord(**dict(row))
59
+
60
+ def list_session_payload(self, session_id: str) -> dict[str, Any]:
61
+ session = self.get_session(session_id)
62
+ payload: dict[str, Any] = {
63
+ "id": session.id,
64
+ "status": session.status,
65
+ "question": session.question,
66
+ "model_name": session.model_name,
67
+ "owner_id": session.owner_id,
68
+ "owner_name": session.owner_name,
69
+ "error": session.error,
70
+ "created_at": session.created_at,
71
+ "updated_at": session.updated_at,
72
+ }
73
+ generation_row = self.connection.execute(
74
+ "SELECT * FROM generation_results WHERE session_id = ?",
75
+ (session_id,),
76
+ ).fetchone()
77
+ if generation_row is not None:
78
+ payload["answer"] = generation_row["answer"]
79
+ payload["raw_generation_text"] = generation_row["raw_generation_text"]
80
+ payload["raw_trace_text"] = generation_row["raw_trace_text"]
81
+ payload["normalized_trace_text"] = generation_row["normalized_trace_text"]
82
+ payload["sentences"] = json.loads(generation_row["sentences_json"])
83
+ payload["generation_metadata"] = json.loads(generation_row["generation_metadata_json"])
84
+ analysis_row = self.connection.execute(
85
+ "SELECT result_json FROM analysis_results WHERE session_id = ?",
86
+ (session_id,),
87
+ ).fetchone()
88
+ if analysis_row is not None:
89
+ payload["analysis"] = json.loads(analysis_row["result_json"])
90
+ return payload
91
+
92
+ def list_sessions_for_owner(self, owner_id: str, *, limit: int = 20) -> list[dict[str, Any]]:
93
+ rows = self.connection.execute(
94
+ """
95
+ SELECT id
96
+ FROM sessions
97
+ WHERE owner_id = ?
98
+ ORDER BY updated_at DESC
99
+ LIMIT ?
100
+ """,
101
+ (owner_id, limit),
102
+ ).fetchall()
103
+ return [self.list_session_payload(row["id"]) for row in rows]
104
+
105
+ def update_status(self, session_id: str, *, status: str, error: str | None = None) -> None:
106
+ now = _utc_now()
107
+ with self.lock:
108
+ self.connection.execute(
109
+ "UPDATE sessions SET status = ?, error = ?, updated_at = ? WHERE id = ?",
110
+ (status, error, now, session_id),
111
+ )
112
+ self.connection.commit()
113
+
114
+ def count_incomplete_sessions(self) -> int:
115
+ row = self.connection.execute(
116
+ "SELECT COUNT(*) AS count FROM sessions WHERE status NOT IN ('completed', 'failed')"
117
+ ).fetchone()
118
+ return int(row["count"])
119
+
120
+ def count_incomplete_sessions_for_owner(self, owner_id: str) -> int:
121
+ row = self.connection.execute(
122
+ """
123
+ SELECT COUNT(*) AS count
124
+ FROM sessions
125
+ WHERE owner_id = ? AND status NOT IN ('completed', 'failed')
126
+ """,
127
+ (owner_id,),
128
+ ).fetchone()
129
+ return int(row["count"])
130
+
131
+ def save_generation_result(self, session_id: str, generation: GenerationResult, sentences: list[str]) -> None:
132
+ with self.lock:
133
+ self.connection.execute(
134
+ """
135
+ INSERT INTO generation_results (
136
+ session_id,
137
+ answer,
138
+ raw_generation_text,
139
+ raw_trace_text,
140
+ normalized_trace_text,
141
+ sentences_json,
142
+ generation_metadata_json
143
+ ) VALUES (?, ?, ?, ?, ?, ?, ?)
144
+ ON CONFLICT(session_id) DO UPDATE SET
145
+ answer = excluded.answer,
146
+ raw_generation_text = excluded.raw_generation_text,
147
+ raw_trace_text = excluded.raw_trace_text,
148
+ normalized_trace_text = excluded.normalized_trace_text,
149
+ sentences_json = excluded.sentences_json,
150
+ generation_metadata_json = excluded.generation_metadata_json
151
+ """,
152
+ (
153
+ session_id,
154
+ generation.answer,
155
+ generation.raw_generation_text,
156
+ generation.raw_trace_text,
157
+ generation.normalized_trace_text,
158
+ json.dumps(sentences),
159
+ json.dumps(generation.generation_metadata.model_dump()),
160
+ ),
161
+ )
162
+ self.connection.execute(
163
+ "UPDATE sessions SET updated_at = ? WHERE id = ?",
164
+ (_utc_now(), session_id),
165
+ )
166
+ self.connection.commit()
167
+
168
+ def save_analysis_result(self, session_id: str, result: AnalysisResult) -> None:
169
+ with self.lock:
170
+ self.connection.execute(
171
+ """
172
+ INSERT INTO analysis_results (session_id, result_json)
173
+ VALUES (?, ?)
174
+ ON CONFLICT(session_id) DO UPDATE SET result_json = excluded.result_json
175
+ """,
176
+ (session_id, result.model_dump_json()),
177
+ )
178
+ self.connection.execute(
179
+ "UPDATE sessions SET updated_at = ? WHERE id = ?",
180
+ (_utc_now(), session_id),
181
+ )
182
+ self.connection.commit()
app/workers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Background job runner."""
app/workers/jobs.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from concurrent.futures import Future, ThreadPoolExecutor
4
+ from dataclasses import dataclass
5
+ from typing import Callable
6
+
7
+
8
+ @dataclass(slots=True)
9
+ class JobRunner:
10
+ executor: ThreadPoolExecutor
11
+
12
+ def submit(self, fn: Callable[..., None], *args, **kwargs) -> Future:
13
+ return self.executor.submit(fn, *args, **kwargs)
14
+
15
+ def shutdown(self) -> None:
16
+ self.executor.shutdown(wait=False, cancel_futures=True)
17
+
18
+
19
+ def build_job_runner(max_workers: int) -> JobRunner:
20
+ return JobRunner(executor=ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="cot-anc"))
cot_anc.egg-info/PKG-INFO ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: cot-anc
3
+ Version: 0.1.0
4
+ Summary: Online chain-of-thought analysis with attribution patching
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: fastapi>=0.115.0
8
+ Requires-Dist: huggingface_hub[oauth]>=0.33.0
9
+ Requires-Dist: numpy>=2.0.0
10
+ Requires-Dist: pydantic>=2.7.0
11
+ Requires-Dist: scipy>=1.13.0
12
+ Requires-Dist: torch>=2.2.0
13
+ Requires-Dist: transformers>=4.44.0
14
+ Requires-Dist: typer>=0.12.3
15
+ Requires-Dist: uvicorn>=0.30.0
16
+ Provides-Extra: dev
17
+ Requires-Dist: pytest>=8.2.0; extra == "dev"
18
+ Provides-Extra: viz
19
+ Requires-Dist: matplotlib>=3.8.0; extra == "viz"
20
+
21
+ ---
22
+ title: Thought Anchors
23
+ emoji: 🧠
24
+ colorFrom: blue
25
+ colorTo: orange
26
+ sdk: docker
27
+ app_port: 7860
28
+ hf_oauth: true
29
+ pinned: false
30
+ models:
31
+ - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
32
+ ---
33
+
34
+ # Thought Anchors
35
+
36
+ Public-facing FastAPI service for generating a visible reasoning trace and computing
37
+ sentence-to-sentence attribution on open-weight reasoning models.
38
+
39
+ The app is now shaped for deployment as a Hugging Face `Docker Space`:
40
+
41
+ - Hugging Face OAuth sign-in for end users
42
+ - browser UI plus programmatic API
43
+ - per-user ephemeral sessions on the running instance
44
+ - JSON and CSV export for completed analyses
45
+ - adaptive device and dtype loading for CPU, MPS, CUDA, Colab, Kaggle, and cloud GPUs
46
+
47
+ ## What It Can Do
48
+
49
+ - generate an answer plus visible `<think>...</think>` trace from a supported causal LM
50
+ - normalize and split the trace into sentences
51
+ - compute a sentence influence matrix with gradient x attention attribution
52
+ - summarize incoming / outgoing importance and top edges
53
+ - expose the workflow through:
54
+ - CLI
55
+ - FastAPI
56
+ - web UI
57
+ - async session queue
58
+
59
+ ## Current Deployment Target
60
+
61
+ The primary deployment target is a Hugging Face `Docker Space` running on upgraded GPU
62
+ hardware. The same app can also be run locally or on other cloud GPU hosts.
63
+
64
+ Important runtime constraints:
65
+
66
+ - attribution requires `attn_implementation="eager"`
67
+ - the model must expose usable attention tensors and a supported decoder-layer layout
68
+ - long traces are intentionally capped because the analysis path uses a full backward pass
69
+
70
+ ## Local Development
71
+
72
+ Install dependencies:
73
+
74
+ ```bash
75
+ uv sync
76
+ ```
77
+
78
+ Run the API:
79
+
80
+ ```bash
81
+ uv run python -m app.cli.run_api
82
+ ```
83
+
84
+ Run the CLI:
85
+
86
+ ```bash
87
+ uv run python -m app.cli.run_prototype "Explain why the derivative of x^2 is 2x"
88
+ ```
89
+
90
+ ## API
91
+
92
+ Main endpoints:
93
+
94
+ - `GET /healthz`
95
+ - `GET /api/me`
96
+ - `POST /api/warmup`
97
+ - `POST /api/analyze`
98
+ - `GET /api/sessions`
99
+ - `POST /api/sessions`
100
+ - `GET /api/sessions/{id}`
101
+ - `GET /api/sessions/{id}/result`
102
+ - `GET /api/sessions/{id}/export.json`
103
+ - `GET /api/sessions/{id}/export.csv`
104
+
105
+ Example:
106
+
107
+ ```bash
108
+ curl -X POST http://localhost:7860/api/analyze \
109
+ -H 'content-type: application/json' \
110
+ -d '{
111
+ "question": "Explain why the derivative of x^2 is 2x",
112
+ "max_new_tokens": 128,
113
+ "max_trace_tokens": 256,
114
+ "max_sentences": 16,
115
+ "validate_top_k": 0
116
+ }'
117
+ ```
118
+
119
+ ## Hugging Face Space Setup
120
+
121
+ Recommended environment variables:
122
+
123
+ - `MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`
124
+ - `DEVICE_PREFERENCE=auto`
125
+ - `DTYPE_PREFERENCE=auto`
126
+ - `ATTN_IMPLEMENTATION=eager`
127
+ - `LOW_CPU_MEM_USAGE=true`
128
+ - `TRUST_REMOTE_CODE=true`
129
+ - `PRELOAD_MODEL=true`
130
+ - `MAX_TRACE_TOKENS=256`
131
+ - `MAX_SENTENCES=16`
132
+ - `JOB_WORKERS=1`
133
+ - `MAX_QUEUED_JOBS=8`
134
+ - `MAX_ACTIVE_JOBS_PER_USER=2`
135
+ - `REQUIRE_AUTH=true`
136
+
137
+ Notes:
138
+
139
+ - local disk is ephemeral; users should export results they want to keep
140
+ - use upgraded GPU hardware for real attribution runs
141
+ - keep trace limits conservative for public traffic
142
+
143
+ ## Notebook
144
+
145
+ Use [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
146
+ for Colab or Kaggle smoke testing. It installs dependencies, warms the model, runs one
147
+ short attribution analysis, prints the top edges, and renders a simple heatmap.
cot_anc.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ app/__init__.py
4
+ app/analysis/__init__.py
5
+ app/analysis/hooks.py
6
+ app/analysis/sentence_split.py
7
+ app/analysis/summaries.py
8
+ app/analysis/suppression.py
9
+ app/analysis/token_boundaries.py
10
+ app/analysis/validation.py
11
+ app/api/__init__.py
12
+ app/api/auth.py
13
+ app/api/main.py
14
+ app/cli/__init__.py
15
+ app/cli/run_api.py
16
+ app/cli/run_prototype.py
17
+ app/core/__init__.py
18
+ app/core/config.py
19
+ app/core/model_support.py
20
+ app/core/runtime.py
21
+ app/core/runtime_pipeline.py
22
+ app/core/schemas.py
23
+ app/generation/__init__.py
24
+ app/generation/prompting.py
25
+ app/generation/service.py
26
+ app/services/__init__.py
27
+ app/services/sessions.py
28
+ app/storage/__init__.py
29
+ app/storage/db.py
30
+ app/storage/repository.py
31
+ app/workers/__init__.py
32
+ app/workers/jobs.py
33
+ cot_anc.egg-info/PKG-INFO
34
+ cot_anc.egg-info/SOURCES.txt
35
+ cot_anc.egg-info/dependency_links.txt
36
+ cot_anc.egg-info/requires.txt
37
+ cot_anc.egg-info/top_level.txt
38
+ tests/test_api.py
39
+ tests/test_model_support.py
40
+ tests/test_runtime_pipeline.py
41
+ tests/test_sentence_split.py
42
+ tests/test_summaries.py
43
+ tests/test_suppression_shape.py
44
+ tests/test_token_boundaries.py
cot_anc.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
cot_anc.egg-info/requires.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.115.0
2
+ huggingface_hub[oauth]>=0.33.0
3
+ numpy>=2.0.0
4
+ pydantic>=2.7.0
5
+ scipy>=1.13.0
6
+ torch>=2.2.0
7
+ transformers>=4.44.0
8
+ typer>=0.12.3
9
+ uvicorn>=0.30.0
10
+
11
+ [dev]
12
+ pytest>=8.2.0
13
+
14
+ [viz]
15
+ matplotlib>=3.8.0
cot_anc.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ app
docs/api.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # API And Product Behavior
2
+
3
+ ## Auth Model
4
+
5
+ - If `REQUIRE_AUTH=true`, analysis endpoints require Hugging Face sign-in.
6
+ - Sessions are scoped to user identity.
7
+ - One user cannot fetch another user’s session or export.
8
+
9
+ ## Session Model
10
+
11
+ Statuses:
12
+
13
+ - `queued`
14
+ - `generating`
15
+ - `answer_ready`
16
+ - `analyzing`
17
+ - `completed`
18
+ - `failed`
19
+
20
+ Sessions are ephemeral for current running instance.
21
+
22
+ ## Endpoints
23
+
24
+ ### `GET /healthz`
25
+
26
+ Returns:
27
+
28
+ - model name
29
+ - device preference
30
+ - dtype preference
31
+ - auth requirement
32
+ - queue limits
33
+ - CUDA / MPS availability
34
+
35
+ ### `GET /api/me`
36
+
37
+ Returns current auth state plus login/logout URLs.
38
+
39
+ ### `POST /api/warmup`
40
+
41
+ Loads model with current runtime policy and returns:
42
+
43
+ - resolved device
44
+ - resolved dtype
45
+ - model attribution capability
46
+
47
+ ### `POST /api/analyze`
48
+
49
+ Direct synchronous analysis call. Good for trusted programmatic use, not best path for UI.
50
+
51
+ ### `GET /api/sessions`
52
+
53
+ List current user’s recent sessions on current instance.
54
+
55
+ ### `POST /api/sessions`
56
+
57
+ Create async session job.
58
+
59
+ Queue protections:
60
+
61
+ - global queue cap
62
+ - per-user active-job cap
63
+
64
+ ### `GET /api/sessions/{id}`
65
+
66
+ Return session summary.
67
+
68
+ ### `GET /api/sessions/{id}/result`
69
+
70
+ Return session summary + full analysis if ready.
71
+
72
+ ### Export
73
+
74
+ - `GET /api/sessions/{id}/export.json`
75
+ - `GET /api/sessions/{id}/export.csv`
76
+
77
+ JSON contains session + analysis payload.
78
+
79
+ CSV contains top edges:
80
+
81
+ - `source_sentence_idx`
82
+ - `target_sentence_idx`
83
+ - `score`
docs/deploy-huggingface.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Deployment
2
+
3
+ Primary target: Hugging Face `Docker Space` on upgraded GPU hardware.
4
+
5
+ ## What Gets Deployed
6
+
7
+ - FastAPI backend
8
+ - static web frontend
9
+ - Hugging Face OAuth routes
10
+ - ephemeral SQLite-backed session queue
11
+
12
+ ## Required Space Settings
13
+
14
+ - SDK: `Docker`
15
+ - Port: `7860`
16
+ - OAuth: enabled via README metadata
17
+ - Hardware: upgraded GPU recommended
18
+
19
+ ## Recommended Runtime Variables
20
+
21
+ Core:
22
+
23
+ - `MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`
24
+ - `DEVICE_PREFERENCE=auto`
25
+ - `DTYPE_PREFERENCE=auto`
26
+ - `ATTN_IMPLEMENTATION=eager`
27
+ - `LOW_CPU_MEM_USAGE=true`
28
+ - `TRUST_REMOTE_CODE=true`
29
+ - `PRELOAD_MODEL=true`
30
+
31
+ Traffic limits:
32
+
33
+ - `MAX_TRACE_TOKENS=256`
34
+ - `MAX_SENTENCES=16`
35
+ - `JOB_WORKERS=1`
36
+ - `MAX_QUEUED_JOBS=8`
37
+ - `MAX_ACTIVE_JOBS_PER_USER=2`
38
+ - `REQUIRE_AUTH=true`
39
+
40
+ ## Deploy Flow
41
+
42
+ 1. Create new Hugging Face Space with `Docker` SDK.
43
+ 2. Push repo contents.
44
+ 3. Set runtime variables in Space settings.
45
+ 4. Upgrade hardware.
46
+ 5. Wait for build.
47
+ 6. Verify:
48
+ - `GET /healthz`
49
+ - sign-in works
50
+ - one short analysis completes
51
+ - JSON / CSV export works
52
+
53
+ ## Operational Notes
54
+
55
+ - Local disk is ephemeral. Session history disappears on restart.
56
+ - OAuth helper is mocked locally but real inside Space.
57
+ - Keep public defaults conservative. Long traces can OOM small GPUs.
58
+ - If queue pressure grows, lower token caps before increasing worker count.
59
+
60
+ ## Common Failure Modes
61
+
62
+ - `attn_implementation` not eager:
63
+ - attribution disabled for model
64
+ - unsupported model layout:
65
+ - generation may work, attribution fails early with clear error
66
+ - OOM:
67
+ - reduce `MAX_TRACE_TOKENS`, `MAX_SENTENCES`, or choose larger GPU
68
+ - cold start slow:
69
+ - keep `PRELOAD_MODEL=true`
docs/notebook.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Notebook Usage
2
+
3
+ Notebook path:
4
+
5
+ - [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
6
+
7
+ ## Purpose
8
+
9
+ Short smoke test for:
10
+
11
+ - Colab GPU
12
+ - Kaggle GPU
13
+ - quick local validation
14
+
15
+ ## What Notebook Does
16
+
17
+ 1. Install runtime deps.
18
+ 2. Set conservative env vars.
19
+ 3. Import project pipeline.
20
+ 4. Run one short attribution job.
21
+ 5. Print answer, runtime metadata, top edges.
22
+ 6. Render heatmap.
23
+
24
+ ## Best Use
25
+
26
+ - validate model availability
27
+ - validate driver / torch / transformers stack
28
+ - sanity-check latency before public deploy
29
+
30
+ ## If Notebook Fails
31
+
32
+ Check:
33
+
34
+ - GPU available
35
+ - model access permissions
36
+ - enough VRAM
37
+ - eager attention enabled
38
+ - trace limits not too high
docs/runtime.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Runtime And Model Support
2
+
3
+ ## Execution Model
4
+
5
+ Pipeline:
6
+
7
+ 1. Generate visible reasoning trace.
8
+ 2. Normalize and split trace into sentences.
9
+ 3. Map sentence spans to token spans.
10
+ 4. Run forward + backward pass.
11
+ 5. Build sentence influence matrix from gradient x attention.
12
+ 6. Summarize top edges and importance scores.
13
+
14
+ ## Device And Dtype Policy
15
+
16
+ Default policy:
17
+
18
+ - CUDA:
19
+ - `bfloat16` if supported
20
+ - else `float16`
21
+ - MPS:
22
+ - `float16`
23
+ - CPU:
24
+ - `float32`
25
+
26
+ Override with:
27
+
28
+ - `DTYPE_PREFERENCE`
29
+ - request `dtype_preference`
30
+
31
+ ## Model Requirements
32
+
33
+ Model must support all of:
34
+
35
+ - causal LM generation
36
+ - `output_attentions=True`
37
+ - eager attention
38
+ - supported decoder layer layout
39
+ - supported attention module attribute
40
+
41
+ Supported layer paths:
42
+
43
+ - `model.layers`
44
+ - `model.model.layers`
45
+ - `transformer.h`
46
+ - `gpt_neox.layers`
47
+
48
+ Supported attention attrs:
49
+
50
+ - `self_attn`
51
+ - `attn`
52
+ - `attention`
53
+
54
+ ## Why Trace Limits Exist
55
+
56
+ Attribution path uses full backward pass over attention tensors. Cost grows with:
57
+
58
+ - sequence length
59
+ - layer count
60
+ - head count
61
+ - sentence count
62
+
63
+ Public defaults stay small to protect uptime.
64
+
65
+ ## Good First Runtime Settings
66
+
67
+ For public demo:
68
+
69
+ - `max_new_tokens=128`
70
+ - `max_trace_tokens=256`
71
+ - `max_sentences=16`
72
+ - `validate_top_k=0`
73
+
74
+ For deeper analysis on bigger GPU:
75
+
76
+ - raise trace tokens slowly
77
+ - watch latency and memory first
notebooks/hf_space_demo.ipynb ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Thought Anchors Demo\n",
8
+ "\n",
9
+ "Colab/Kaggle notebook for a short end-to-end attribution smoke test."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "!pip install -q uv\n",
19
+ "!uv pip install --system fastapi huggingface_hub matplotlib numpy pydantic scipy torch transformers typer uvicorn"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import os\n",
29
+ "from pathlib import Path\n",
30
+ "\n",
31
+ "if not Path('app').exists():\n",
32
+ " raise RuntimeError('Upload or clone the repository before running the notebook.')\n",
33
+ "\n",
34
+ "os.environ.setdefault('MODEL_NAME', 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B')\n",
35
+ "os.environ.setdefault('DEVICE_PREFERENCE', 'auto')\n",
36
+ "os.environ.setdefault('DTYPE_PREFERENCE', 'auto')\n",
37
+ "os.environ.setdefault('ATTN_IMPLEMENTATION', 'eager')\n",
38
+ "os.environ.setdefault('LOW_CPU_MEM_USAGE', 'true')\n",
39
+ "os.environ.setdefault('TRUST_REMOTE_CODE', 'true')\n",
40
+ "os.environ.setdefault('MAX_TRACE_TOKENS', '256')\n",
41
+ "os.environ.setdefault('MAX_SENTENCES', '16')"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "from app.core.runtime_pipeline import compute_attribution_analysis"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "question = 'Explain why the derivative of x^2 is 2x.'\n",
60
+ "\n",
61
+ "result = compute_attribution_analysis(\n",
62
+ " question=question,\n",
63
+ " max_new_tokens=128,\n",
64
+ " max_trace_tokens=256,\n",
65
+ " max_sentences=16,\n",
66
+ " validate_top_k=0,\n",
67
+ " temperature=0.0,\n",
68
+ " top_p=1.0,\n",
69
+ ")"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "print('Answer:\\n', result.answer)\n",
79
+ "print('\\nSentences:', len(result.sentences))\n",
80
+ "print('Runtime:', result.runtime_metadata.model_dump())\n",
81
+ "print('\\nTop edges:')\n",
82
+ "for edge in result.top_edges[:10]:\n",
83
+ " print(edge)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "import matplotlib.pyplot as plt\n",
93
+ "import numpy as np\n",
94
+ "\n",
95
+ "matrix = np.array(result.suppression_matrix)\n",
96
+ "figure, axis = plt.subplots(figsize=(7, 5))\n",
97
+ "image = axis.imshow(matrix, aspect='auto', cmap='viridis')\n",
98
+ "axis.set_title('Sentence Influence Matrix')\n",
99
+ "axis.set_xlabel('Source sentence')\n",
100
+ "axis.set_ylabel('Target sentence')\n",
101
+ "figure.colorbar(image, ax=axis)\n",
102
+ "plt.show()"
103
+ ]
104
+ }
105
+ ],
106
+ "metadata": {
107
+ "kernelspec": {
108
+ "display_name": "Python 3",
109
+ "language": "python",
110
+ "name": "python3"
111
+ },
112
+ "language_info": {
113
+ "name": "python",
114
+ "version": "3.11"
115
+ }
116
+ },
117
+ "nbformat": 4,
118
+ "nbformat_minor": 5
119
+ }
pyproject.toml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "cot-anc"
3
+ version = "0.1.0"
4
+ description = "Online chain-of-thought analysis with attribution patching"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "fastapi>=0.115.0",
9
+ "huggingface_hub[oauth]>=0.33.0",
10
+ "numpy>=2.0.0",
11
+ "pydantic>=2.7.0",
12
+ "scipy>=1.13.0",
13
+ "torch>=2.2.0",
14
+ "transformers>=4.44.0",
15
+ "typer>=0.12.3",
16
+ "uvicorn>=0.30.0",
17
+ ]
18
+
19
+ [project.optional-dependencies]
20
+ dev = [
21
+ "pytest>=8.2.0",
22
+ ]
23
+ viz = [
24
+ "matplotlib>=3.8.0",
25
+ ]
26
+
27
+ [build-system]
28
+ requires = ["setuptools>=68.0"]
29
+ build-backend = "setuptools.build_meta"
30
+
31
+ [tool.setuptools.packages.find]
32
+ include = ["app*"]
33
+
34
+ [tool.pytest.ini_options]
35
+ testpaths = ["tests"]
36
+ pythonpath = ["."]
runpod_start.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ export API_HOST="${API_HOST:-0.0.0.0}"
5
+ export API_PORT="${API_PORT:-7860}"
6
+ export DEVICE_PREFERENCE="${DEVICE_PREFERENCE:-auto}"
7
+ export DTYPE_PREFERENCE="${DTYPE_PREFERENCE:-auto}"
8
+ export ATTN_IMPLEMENTATION="${ATTN_IMPLEMENTATION:-eager}"
9
+ export PRELOAD_MODEL="${PRELOAD_MODEL:-true}"
10
+
11
+ exec uv run python -m app.cli.run_api