Spaces:
Sleeping
Sleeping
Deploy Thought Anchors
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +10 -0
- .env.example +13 -0
- .gitignore +11 -0
- Dockerfile +45 -0
- README.md +80 -5
- app/__init__.py +1 -0
- app/analysis/__init__.py +1 -0
- app/analysis/hooks.py +78 -0
- app/analysis/sentence_split.py +87 -0
- app/analysis/summaries.py +39 -0
- app/analysis/suppression.py +161 -0
- app/analysis/token_boundaries.py +75 -0
- app/analysis/validation.py +120 -0
- app/api/__init__.py +1 -0
- app/api/auth.py +54 -0
- app/api/main.py +310 -0
- app/cli/__init__.py +1 -0
- app/cli/run_api.py +19 -0
- app/cli/run_prototype.py +91 -0
- app/core/__init__.py +1 -0
- app/core/config.py +62 -0
- app/core/model_support.py +102 -0
- app/core/runtime.py +93 -0
- app/core/runtime_pipeline.py +175 -0
- app/core/schemas.py +167 -0
- app/frontend/app.js +340 -0
- app/frontend/index.html +108 -0
- app/frontend/styles.css +335 -0
- app/generation/__init__.py +1 -0
- app/generation/prompting.py +28 -0
- app/generation/service.py +83 -0
- app/services/__init__.py +1 -0
- app/services/sessions.py +157 -0
- app/storage/__init__.py +1 -0
- app/storage/db.py +60 -0
- app/storage/repository.py +182 -0
- app/workers/__init__.py +1 -0
- app/workers/jobs.py +20 -0
- cot_anc.egg-info/PKG-INFO +147 -0
- cot_anc.egg-info/SOURCES.txt +44 -0
- cot_anc.egg-info/dependency_links.txt +1 -0
- cot_anc.egg-info/requires.txt +15 -0
- cot_anc.egg-info/top_level.txt +1 -0
- docs/api.md +83 -0
- docs/deploy-huggingface.md +69 -0
- docs/notebook.md +38 -0
- docs/runtime.md +77 -0
- notebooks/hf_space_demo.ipynb +119 -0
- pyproject.toml +36 -0
- runpod_start.sh +11 -0
.dockerignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.venv
|
| 3 |
+
.pytest_cache
|
| 4 |
+
__pycache__
|
| 5 |
+
*.pyc
|
| 6 |
+
*.pyo
|
| 7 |
+
*.pyd
|
| 8 |
+
outputs
|
| 9 |
+
tests
|
| 10 |
+
cot_anc.egg-info
|
.env.example
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
| 2 |
+
DEVICE_PREFERENCE=auto
|
| 3 |
+
DTYPE_PREFERENCE=auto
|
| 4 |
+
ATTN_IMPLEMENTATION=eager
|
| 5 |
+
TRUST_REMOTE_CODE=true
|
| 6 |
+
LOW_CPU_MEM_USAGE=true
|
| 7 |
+
MAX_TRACE_TOKENS=256
|
| 8 |
+
MAX_SENTENCES=16
|
| 9 |
+
TAKE_LOG=true
|
| 10 |
+
PRELOAD_MODEL=true
|
| 11 |
+
REQUIRE_AUTH=true
|
| 12 |
+
MAX_QUEUED_JOBS=8
|
| 13 |
+
MAX_ACTIVE_JOBS_PER_USER=2
|
.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/.venv
|
| 2 |
+
/.pytest_cache
|
| 3 |
+
/__pycache__
|
| 4 |
+
/data/app.db
|
| 5 |
+
.DS_Store
|
| 6 |
+
*.pyc
|
| 7 |
+
*.pyo
|
| 8 |
+
*.pyd
|
| 9 |
+
cot_anc.egg-info/
|
| 10 |
+
outputs/
|
| 11 |
+
uv.lock
|
Dockerfile
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
|
| 2 |
+
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 4 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 5 |
+
ENV PYTHONUNBUFFERED=1
|
| 6 |
+
ENV UV_LINK_MODE=copy
|
| 7 |
+
ENV HOME=/home/user
|
| 8 |
+
ENV PATH=/home/user/.local/bin:$PATH
|
| 9 |
+
ENV HF_HOME=/home/user/.cache/huggingface
|
| 10 |
+
ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface
|
| 11 |
+
ENV API_HOST=0.0.0.0
|
| 12 |
+
ENV API_PORT=7860
|
| 13 |
+
ENV DEVICE_PREFERENCE=auto
|
| 14 |
+
ENV DTYPE_PREFERENCE=auto
|
| 15 |
+
ENV ATTN_IMPLEMENTATION=eager
|
| 16 |
+
ENV LOW_CPU_MEM_USAGE=true
|
| 17 |
+
ENV TRUST_REMOTE_CODE=true
|
| 18 |
+
ENV PRELOAD_MODEL=true
|
| 19 |
+
ENV REQUIRE_AUTH=true
|
| 20 |
+
|
| 21 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 22 |
+
ca-certificates \
|
| 23 |
+
curl \
|
| 24 |
+
git \
|
| 25 |
+
python3 \
|
| 26 |
+
python3-pip \
|
| 27 |
+
python3-venv \
|
| 28 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 29 |
+
|
| 30 |
+
RUN useradd -m -u 1000 user
|
| 31 |
+
USER user
|
| 32 |
+
WORKDIR $HOME/app
|
| 33 |
+
|
| 34 |
+
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 35 |
+
|
| 36 |
+
COPY --chown=user pyproject.toml uv.lock README.md .env.example ./
|
| 37 |
+
COPY --chown=user app ./app
|
| 38 |
+
COPY --chown=user notebooks ./notebooks
|
| 39 |
+
COPY --chown=user tests ./tests
|
| 40 |
+
|
| 41 |
+
RUN uv sync --frozen
|
| 42 |
+
|
| 43 |
+
EXPOSE 7860
|
| 44 |
+
|
| 45 |
+
CMD ["uv", "run", "python", "-m", "app.cli.run_api"]
|
README.md
CHANGED
|
@@ -1,10 +1,85 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
|
|
|
| 7 |
pinned: false
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Thought Anchors
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
hf_oauth: true
|
| 9 |
pinned: false
|
| 10 |
+
models:
|
| 11 |
+
- deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Thought Anchors
|
| 15 |
+
|
| 16 |
+
Thought Anchors generates visible reasoning traces from open-weight models and
|
| 17 |
+
computes sentence-to-sentence influence with gradient x attention attribution.
|
| 18 |
+
|
| 19 |
+
Current product shape:
|
| 20 |
+
|
| 21 |
+
- Hugging Face `Docker Space` first
|
| 22 |
+
- Hugging Face OAuth sign-in
|
| 23 |
+
- web UI + API
|
| 24 |
+
- per-user ephemeral sessions
|
| 25 |
+
- JSON / CSV export
|
| 26 |
+
- adaptive CPU / MPS / CUDA loading
|
| 27 |
+
|
| 28 |
+
## Quick Start
|
| 29 |
+
|
| 30 |
+
Install deps:
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
uv sync --extra dev
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Run API:
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
uv run python -m app.cli.run_api
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Run CLI:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
uv run python -m app.cli.run_prototype "Explain why the derivative of x^2 is 2x"
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Run tests:
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
uv run python -m pytest -q
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Main Endpoints
|
| 55 |
+
|
| 56 |
+
- `GET /healthz`
|
| 57 |
+
- `GET /api/me`
|
| 58 |
+
- `POST /api/warmup`
|
| 59 |
+
- `POST /api/analyze`
|
| 60 |
+
- `GET /api/sessions`
|
| 61 |
+
- `POST /api/sessions`
|
| 62 |
+
- `GET /api/sessions/{id}`
|
| 63 |
+
- `GET /api/sessions/{id}/result`
|
| 64 |
+
- `GET /api/sessions/{id}/export.json`
|
| 65 |
+
- `GET /api/sessions/{id}/export.csv`
|
| 66 |
+
|
| 67 |
+
## Docs
|
| 68 |
+
|
| 69 |
+
- [Hugging Face deployment](./docs/deploy-huggingface.md)
|
| 70 |
+
- [Runtime and model support](./docs/runtime.md)
|
| 71 |
+
- [API and product behavior](./docs/api.md)
|
| 72 |
+
- [Notebook usage](./docs/notebook.md)
|
| 73 |
+
|
| 74 |
+
## Notebook
|
| 75 |
+
|
| 76 |
+
Colab / Kaggle smoke-test notebook:
|
| 77 |
+
|
| 78 |
+
- [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
|
| 79 |
+
|
| 80 |
+
## Key Constraints
|
| 81 |
+
|
| 82 |
+
- Attribution needs `attn_implementation="eager"`.
|
| 83 |
+
- Model must expose supported decoder layers and attention modules.
|
| 84 |
+
- Long traces stay capped because analysis uses full backward pass.
|
| 85 |
+
- Space disk is ephemeral; export results you want to keep.
|
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Application package for chain-of-thought attribution analysis."""
|
app/analysis/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Analysis utilities for attribution patching."""
|
app/analysis/hooks.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from app.core.model_support import get_decoder_layers
|
| 9 |
+
|
| 10 |
+
_ATTENTION_STORE: dict[int, torch.Tensor] = {}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def clear_stored_attentions() -> None:
|
| 14 |
+
_ATTENTION_STORE.clear()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_stored_attentions() -> dict[int, torch.Tensor]:
|
| 18 |
+
return dict(_ATTENTION_STORE)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _extract_attention_tensor(output: Any) -> torch.Tensor | None:
|
| 22 |
+
if isinstance(output, torch.Tensor):
|
| 23 |
+
return output if output.dim() == 4 else None
|
| 24 |
+
|
| 25 |
+
if isinstance(output, dict):
|
| 26 |
+
for value in output.values():
|
| 27 |
+
if isinstance(value, torch.Tensor) and value.dim() == 4:
|
| 28 |
+
return value
|
| 29 |
+
|
| 30 |
+
if isinstance(output, Iterable) and not isinstance(output, (str, bytes)):
|
| 31 |
+
for item in output:
|
| 32 |
+
if isinstance(item, torch.Tensor) and item.dim() == 4:
|
| 33 |
+
return item
|
| 34 |
+
|
| 35 |
+
return None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_attention_impl(model: Any) -> str | None:
|
| 39 |
+
config = getattr(model, "config", None)
|
| 40 |
+
if config is None:
|
| 41 |
+
return None
|
| 42 |
+
return getattr(config, "_attn_implementation", None) or getattr(
|
| 43 |
+
config,
|
| 44 |
+
"attn_implementation",
|
| 45 |
+
None,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def make_attn_hook(layer_idx: int):
|
| 50 |
+
def hook(_module: Any, _inputs: Any, output: Any) -> None:
|
| 51 |
+
attn = _extract_attention_tensor(output)
|
| 52 |
+
if attn is None:
|
| 53 |
+
return
|
| 54 |
+
if attn.dim() != 4:
|
| 55 |
+
raise RuntimeError(f"Expected 4D attention tensor at layer {layer_idx}, got {attn.shape}.")
|
| 56 |
+
attn.retain_grad()
|
| 57 |
+
_ATTENTION_STORE[layer_idx] = attn
|
| 58 |
+
|
| 59 |
+
return hook
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def register_hooks(model: Any) -> list[Any]:
|
| 63 |
+
clear_stored_attentions()
|
| 64 |
+
layers, _layer_path, attention_attr = get_decoder_layers(model)
|
| 65 |
+
|
| 66 |
+
handles: list[Any] = []
|
| 67 |
+
for layer_idx, layer in enumerate(layers):
|
| 68 |
+
self_attn = getattr(layer, attention_attr, None)
|
| 69 |
+
if self_attn is None:
|
| 70 |
+
raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.")
|
| 71 |
+
handles.append(self_attn.register_forward_hook(make_attn_hook(layer_idx)))
|
| 72 |
+
return handles
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def remove_hooks(handles: list[Any]) -> None:
|
| 76 |
+
for handle in handles:
|
| 77 |
+
handle.remove()
|
| 78 |
+
clear_stored_attentions()
|
app/analysis/sentence_split.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
THINK_TAG_RE = re.compile(r"</?think>", re.IGNORECASE)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(slots=True)
|
| 10 |
+
class SentenceSpan:
|
| 11 |
+
text: str
|
| 12 |
+
start_char: int
|
| 13 |
+
end_char: int
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def normalize_trace_text(raw_trace_text: str) -> str:
|
| 17 |
+
return THINK_TAG_RE.sub("", raw_trace_text)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _non_whitespace_token_count(text: str) -> int:
|
| 21 |
+
return len(re.findall(r"\S+", text))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def split_sentences(
|
| 25 |
+
text: str,
|
| 26 |
+
*,
|
| 27 |
+
min_token_like_units: int = 2,
|
| 28 |
+
) -> list[SentenceSpan]:
|
| 29 |
+
if not text:
|
| 30 |
+
return []
|
| 31 |
+
|
| 32 |
+
raw_spans: list[tuple[int, int]] = []
|
| 33 |
+
start = 0
|
| 34 |
+
index = 0
|
| 35 |
+
text_length = len(text)
|
| 36 |
+
|
| 37 |
+
while index < text_length:
|
| 38 |
+
if text[index : index + 2] == "\n\n":
|
| 39 |
+
end = index + 2
|
| 40 |
+
raw_spans.append((start, end))
|
| 41 |
+
start = end
|
| 42 |
+
index = end
|
| 43 |
+
continue
|
| 44 |
+
|
| 45 |
+
if text[index] in ".!?":
|
| 46 |
+
end = index + 1
|
| 47 |
+
while end < text_length and text[end] in "\"')]}":
|
| 48 |
+
end += 1
|
| 49 |
+
while end < text_length and text[end].isspace() and text[end : end + 2] != "\n\n":
|
| 50 |
+
end += 1
|
| 51 |
+
raw_spans.append((start, end))
|
| 52 |
+
start = end
|
| 53 |
+
index = end
|
| 54 |
+
continue
|
| 55 |
+
|
| 56 |
+
index += 1
|
| 57 |
+
|
| 58 |
+
if start < text_length:
|
| 59 |
+
raw_spans.append((start, text_length))
|
| 60 |
+
|
| 61 |
+
merged: list[tuple[int, int]] = []
|
| 62 |
+
for span_start, span_end in raw_spans:
|
| 63 |
+
fragment = text[span_start:span_end]
|
| 64 |
+
if not fragment:
|
| 65 |
+
continue
|
| 66 |
+
if merged and _non_whitespace_token_count(fragment) < min_token_like_units:
|
| 67 |
+
previous_start, _ = merged[-1]
|
| 68 |
+
merged[-1] = (previous_start, span_end)
|
| 69 |
+
continue
|
| 70 |
+
merged.append((span_start, span_end))
|
| 71 |
+
|
| 72 |
+
if len(merged) > 1:
|
| 73 |
+
last_start, last_end = merged[-1]
|
| 74 |
+
if _non_whitespace_token_count(text[last_start:last_end]) < min_token_like_units:
|
| 75 |
+
prev_start, _ = merged[-2]
|
| 76 |
+
merged[-2] = (prev_start, last_end)
|
| 77 |
+
merged.pop()
|
| 78 |
+
|
| 79 |
+
return [
|
| 80 |
+
SentenceSpan(
|
| 81 |
+
text=text[span_start:span_end],
|
| 82 |
+
start_char=span_start,
|
| 83 |
+
end_char=span_end,
|
| 84 |
+
)
|
| 85 |
+
for span_start, span_end in merged
|
| 86 |
+
if text[span_start:span_end]
|
| 87 |
+
]
|
app/analysis/summaries.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from app.core.schemas import TopEdge
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def compute_outgoing_importance(matrix: np.ndarray) -> list[float]:
|
| 9 |
+
sentence_count = matrix.shape[0]
|
| 10 |
+
scores: list[float] = []
|
| 11 |
+
for source_idx in range(sentence_count):
|
| 12 |
+
column = matrix[source_idx + 1 :, source_idx]
|
| 13 |
+
scores.append(float(column.mean()) if column.size else 0.0)
|
| 14 |
+
return scores
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def compute_incoming_importance(matrix: np.ndarray) -> list[float]:
|
| 18 |
+
sentence_count = matrix.shape[0]
|
| 19 |
+
scores: list[float] = []
|
| 20 |
+
for target_idx in range(sentence_count):
|
| 21 |
+
row = matrix[target_idx, :target_idx]
|
| 22 |
+
scores.append(float(row.mean()) if row.size else 0.0)
|
| 23 |
+
return scores
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def compute_top_edges(matrix: np.ndarray, top_k: int = 10) -> list[TopEdge]:
|
| 27 |
+
sentence_count = matrix.shape[0]
|
| 28 |
+
candidates: list[TopEdge] = []
|
| 29 |
+
for target_idx in range(sentence_count):
|
| 30 |
+
for source_idx in range(target_idx):
|
| 31 |
+
candidates.append(
|
| 32 |
+
TopEdge(
|
| 33 |
+
source_sentence_idx=source_idx,
|
| 34 |
+
target_sentence_idx=target_idx,
|
| 35 |
+
score=float(matrix[target_idx, source_idx]),
|
| 36 |
+
)
|
| 37 |
+
)
|
| 38 |
+
candidates.sort(key=lambda edge: edge.score, reverse=True)
|
| 39 |
+
return candidates[:top_k]
|
app/analysis/suppression.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from dataclasses import asdict
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks
|
| 12 |
+
from app.core.model_support import describe_model_support
|
| 13 |
+
from app.core.schemas import ModelCapability, RuntimeMetadata
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass(slots=True)
|
| 17 |
+
class AttributionMatrixComputation:
|
| 18 |
+
matrix: np.ndarray
|
| 19 |
+
raw_matrix: np.ndarray
|
| 20 |
+
token_nll: np.ndarray
|
| 21 |
+
runtime_metadata: RuntimeMetadata
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
|
| 25 |
+
if logits.ndim != 3 or input_ids.ndim != 2:
|
| 26 |
+
raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].")
|
| 27 |
+
if logits.shape[0] != 1 or input_ids.shape[0] != 1:
|
| 28 |
+
raise ValueError("Only batch size 1 is supported for the prototype.")
|
| 29 |
+
if input_ids.shape[1] < 2:
|
| 30 |
+
raise ValueError("Need at least two tokens to compute next-token loss.")
|
| 31 |
+
|
| 32 |
+
shifted_logits = logits[:, :-1, :]
|
| 33 |
+
shifted_targets = input_ids[:, 1:]
|
| 34 |
+
log_probs = torch.log_softmax(shifted_logits, dim=-1)
|
| 35 |
+
gathered = log_probs.gather(-1, shifted_targets.unsqueeze(-1)).squeeze(-1)
|
| 36 |
+
return -gathered[0]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _current_memory_mb(device: torch.device) -> float | None:
|
| 40 |
+
if device.type != "cuda":
|
| 41 |
+
return None
|
| 42 |
+
return float(torch.cuda.memory_allocated(device) / (1024 * 1024))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray:
|
| 46 |
+
if not take_log:
|
| 47 |
+
return raw_matrix.copy()
|
| 48 |
+
presentation = np.zeros_like(raw_matrix)
|
| 49 |
+
positive = raw_matrix > 0
|
| 50 |
+
presentation[positive] = np.log(raw_matrix[positive] + 1e-9)
|
| 51 |
+
return presentation
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def compute_attribution_matrix(
|
| 55 |
+
input_ids: torch.Tensor,
|
| 56 |
+
token_ranges: list[tuple[int, int]],
|
| 57 |
+
model: Any,
|
| 58 |
+
take_log: bool = True,
|
| 59 |
+
max_trace_tokens: int = 0,
|
| 60 |
+
max_sentences: int = 0,
|
| 61 |
+
) -> AttributionMatrixComputation:
|
| 62 |
+
device = input_ids.device
|
| 63 |
+
handles = register_hooks(model)
|
| 64 |
+
model.zero_grad(set_to_none=True)
|
| 65 |
+
forward_start = time.perf_counter()
|
| 66 |
+
memory_before_mb = _current_memory_mb(device)
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
with torch.enable_grad():
|
| 70 |
+
outputs = model(
|
| 71 |
+
input_ids=input_ids,
|
| 72 |
+
output_attentions=True,
|
| 73 |
+
return_dict=True,
|
| 74 |
+
)
|
| 75 |
+
forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0
|
| 76 |
+
|
| 77 |
+
logits = outputs.logits
|
| 78 |
+
token_nll = compute_self_token_nll(logits, input_ids)
|
| 79 |
+
loss = token_nll.sum()
|
| 80 |
+
|
| 81 |
+
backward_start = time.perf_counter()
|
| 82 |
+
loss.backward()
|
| 83 |
+
backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0
|
| 84 |
+
|
| 85 |
+
attentions = get_stored_attentions()
|
| 86 |
+
if not attentions:
|
| 87 |
+
raise RuntimeError("No attention tensors were captured. Check eager attention mode.")
|
| 88 |
+
|
| 89 |
+
matrix_start = time.perf_counter()
|
| 90 |
+
sentence_count = len(token_ranges)
|
| 91 |
+
raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32)
|
| 92 |
+
|
| 93 |
+
ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)]
|
| 94 |
+
first_attention = ordered_layers[0]
|
| 95 |
+
num_layers = len(ordered_layers)
|
| 96 |
+
num_heads = int(first_attention.shape[1])
|
| 97 |
+
|
| 98 |
+
for source_idx, (source_start, source_end) in enumerate(token_ranges):
|
| 99 |
+
for target_idx, (target_start, target_end) in enumerate(token_ranges):
|
| 100 |
+
if target_idx <= source_idx:
|
| 101 |
+
continue
|
| 102 |
+
|
| 103 |
+
total = 0.0
|
| 104 |
+
for attention in ordered_layers:
|
| 105 |
+
grad = attention.grad
|
| 106 |
+
if grad is None:
|
| 107 |
+
raise RuntimeError("Attention gradient was not retained for one or more layers.")
|
| 108 |
+
total += -(
|
| 109 |
+
grad[0, :, target_start:target_end, source_start:source_end]
|
| 110 |
+
* attention[0, :, target_start:target_end, source_start:source_end]
|
| 111 |
+
).sum().item()
|
| 112 |
+
|
| 113 |
+
denominator = max(1, target_end - target_start)
|
| 114 |
+
raw_matrix[target_idx, source_idx] = total / denominator
|
| 115 |
+
|
| 116 |
+
matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0
|
| 117 |
+
total_analysis_ms = (
|
| 118 |
+
forward_pass_ms + backward_pass_ms + matrix_computation_ms
|
| 119 |
+
)
|
| 120 |
+
presentation_matrix = _build_presentation_matrix(raw_matrix, take_log)
|
| 121 |
+
|
| 122 |
+
attention_impl = getattr(model.config, "_attn_implementation", "unknown")
|
| 123 |
+
capability = describe_model_support(model)
|
| 124 |
+
runtime_metadata = RuntimeMetadata(
|
| 125 |
+
forward_pass_ms=forward_pass_ms,
|
| 126 |
+
backward_pass_ms=backward_pass_ms,
|
| 127 |
+
matrix_computation_ms=matrix_computation_ms,
|
| 128 |
+
total_analysis_ms=total_analysis_ms,
|
| 129 |
+
num_layers=num_layers,
|
| 130 |
+
num_heads=num_heads,
|
| 131 |
+
sequence_length_tokens=int(input_ids.shape[1]),
|
| 132 |
+
sentence_count=sentence_count,
|
| 133 |
+
device=str(device),
|
| 134 |
+
dtype=str(first_attention.dtype),
|
| 135 |
+
attention_impl=str(attention_impl),
|
| 136 |
+
max_trace_tokens=max_trace_tokens,
|
| 137 |
+
max_sentences=max_sentences,
|
| 138 |
+
capability=ModelCapability.model_validate(asdict(capability)),
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
memory_after_mb = _current_memory_mb(device)
|
| 142 |
+
if memory_before_mb is not None and memory_after_mb is not None:
|
| 143 |
+
runtime_metadata = runtime_metadata.model_copy(
|
| 144 |
+
update={
|
| 145 |
+
"device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)"
|
| 146 |
+
}
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
return AttributionMatrixComputation(
|
| 150 |
+
matrix=presentation_matrix,
|
| 151 |
+
raw_matrix=raw_matrix,
|
| 152 |
+
token_nll=token_nll.detach().cpu().numpy(),
|
| 153 |
+
runtime_metadata=runtime_metadata,
|
| 154 |
+
)
|
| 155 |
+
finally:
|
| 156 |
+
for attention in get_stored_attentions().values():
|
| 157 |
+
attention.grad = None
|
| 158 |
+
remove_hooks(handles)
|
| 159 |
+
model.zero_grad(set_to_none=True)
|
| 160 |
+
if device.type == "cuda":
|
| 161 |
+
torch.cuda.empty_cache()
|
app/analysis/token_boundaries.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from app.analysis.sentence_split import SentenceSpan
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(slots=True)
|
| 12 |
+
class TokenizedSentenceMapping:
|
| 13 |
+
input_ids: torch.Tensor
|
| 14 |
+
token_ranges: list[tuple[int, int]]
|
| 15 |
+
offsets: list[tuple[int, int]]
|
| 16 |
+
text: str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def truncate_text_to_token_limit(text: str, tokenizer: Any, max_tokens: int) -> str:
|
| 20 |
+
if max_tokens <= 0:
|
| 21 |
+
raise ValueError("max_tokens must be positive.")
|
| 22 |
+
encoded = tokenizer(
|
| 23 |
+
text,
|
| 24 |
+
add_special_tokens=False,
|
| 25 |
+
return_offsets_mapping=True,
|
| 26 |
+
)
|
| 27 |
+
offsets = encoded["offset_mapping"]
|
| 28 |
+
if len(offsets) <= max_tokens:
|
| 29 |
+
return text
|
| 30 |
+
end_char = offsets[max_tokens - 1][1]
|
| 31 |
+
return text[:end_char]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def tokenize_with_sentence_ranges(
|
| 35 |
+
text: str,
|
| 36 |
+
sentence_spans: list[SentenceSpan],
|
| 37 |
+
tokenizer: Any,
|
| 38 |
+
) -> TokenizedSentenceMapping:
|
| 39 |
+
encoded = tokenizer(
|
| 40 |
+
text,
|
| 41 |
+
add_special_tokens=False,
|
| 42 |
+
return_offsets_mapping=True,
|
| 43 |
+
return_tensors="pt",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
input_ids = encoded["input_ids"]
|
| 47 |
+
raw_offsets = encoded["offset_mapping"][0].tolist()
|
| 48 |
+
offsets = [(int(start), int(end)) for start, end in raw_offsets]
|
| 49 |
+
token_ranges: list[tuple[int, int]] = []
|
| 50 |
+
|
| 51 |
+
for span in sentence_spans:
|
| 52 |
+
overlapping = [
|
| 53 |
+
token_index
|
| 54 |
+
for token_index, (token_start, token_end) in enumerate(offsets)
|
| 55 |
+
if token_end > span.start_char and token_start < span.end_char
|
| 56 |
+
]
|
| 57 |
+
if not overlapping:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
f"Sentence span {span.start_char}:{span.end_char} mapped to zero tokens."
|
| 60 |
+
)
|
| 61 |
+
token_ranges.append((overlapping[0], overlapping[-1] + 1))
|
| 62 |
+
|
| 63 |
+
if token_ranges:
|
| 64 |
+
if token_ranges[0][0] != 0 or token_ranges[-1][1] != len(offsets):
|
| 65 |
+
raise ValueError("Sentence token ranges do not cover the full analyzed sequence.")
|
| 66 |
+
for previous, current in zip(token_ranges, token_ranges[1:]):
|
| 67 |
+
if previous[1] != current[0]:
|
| 68 |
+
raise ValueError("Sentence token ranges are not contiguous.")
|
| 69 |
+
|
| 70 |
+
return TokenizedSentenceMapping(
|
| 71 |
+
input_ids=input_ids,
|
| 72 |
+
token_ranges=token_ranges,
|
| 73 |
+
offsets=offsets,
|
| 74 |
+
text=text,
|
| 75 |
+
)
|
app/analysis/validation.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from scipy.stats import pearsonr, spearmanr
|
| 8 |
+
|
| 9 |
+
from app.analysis.suppression import compute_self_token_nll
|
| 10 |
+
from app.core.schemas import TopEdge, ValidationMetadata
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _nll_slice_for_token_range(token_range: tuple[int, int]) -> slice:
|
| 14 |
+
start, end = token_range
|
| 15 |
+
return slice(max(0, start - 1), max(0, end - 1))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def build_exact_suppression_mask(
|
| 19 |
+
*,
|
| 20 |
+
sequence_length: int,
|
| 21 |
+
source_range: tuple[int, int],
|
| 22 |
+
target_range: tuple[int, int],
|
| 23 |
+
device: torch.device,
|
| 24 |
+
dtype: torch.dtype,
|
| 25 |
+
) -> torch.Tensor:
|
| 26 |
+
fill_value = torch.finfo(dtype).min
|
| 27 |
+
mask = torch.zeros((sequence_length, sequence_length), device=device, dtype=dtype)
|
| 28 |
+
future_positions = torch.triu(
|
| 29 |
+
torch.ones((sequence_length, sequence_length), device=device, dtype=torch.bool),
|
| 30 |
+
diagonal=1,
|
| 31 |
+
)
|
| 32 |
+
mask = mask.masked_fill(future_positions, fill_value)
|
| 33 |
+
source_start, source_end = source_range
|
| 34 |
+
target_start, target_end = target_range
|
| 35 |
+
mask[target_start:target_end, source_start:source_end] = fill_value
|
| 36 |
+
return mask.unsqueeze(0).unsqueeze(0)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def compute_exact_edge_score(
|
| 40 |
+
*,
|
| 41 |
+
model: Any,
|
| 42 |
+
input_ids: torch.Tensor,
|
| 43 |
+
source_range: tuple[int, int],
|
| 44 |
+
target_range: tuple[int, int],
|
| 45 |
+
baseline_token_nll: np.ndarray,
|
| 46 |
+
) -> float:
|
| 47 |
+
model_dtype = next(model.parameters()).dtype
|
| 48 |
+
attention_mask = build_exact_suppression_mask(
|
| 49 |
+
sequence_length=int(input_ids.shape[1]),
|
| 50 |
+
source_range=source_range,
|
| 51 |
+
target_range=target_range,
|
| 52 |
+
device=input_ids.device,
|
| 53 |
+
dtype=model_dtype,
|
| 54 |
+
)
|
| 55 |
+
with torch.no_grad():
|
| 56 |
+
outputs = model(
|
| 57 |
+
input_ids=input_ids,
|
| 58 |
+
attention_mask=attention_mask,
|
| 59 |
+
output_attentions=False,
|
| 60 |
+
return_dict=True,
|
| 61 |
+
)
|
| 62 |
+
suppressed_nll = compute_self_token_nll(outputs.logits, input_ids).detach().cpu().numpy()
|
| 63 |
+
nll_slice = _nll_slice_for_token_range(target_range)
|
| 64 |
+
return float(suppressed_nll[nll_slice].sum() - baseline_token_nll[nll_slice].sum())
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def validate_top_edges(
|
| 68 |
+
*,
|
| 69 |
+
model: Any,
|
| 70 |
+
input_ids: torch.Tensor,
|
| 71 |
+
token_ranges: list[tuple[int, int]],
|
| 72 |
+
top_edges: list[TopEdge],
|
| 73 |
+
baseline_token_nll: np.ndarray,
|
| 74 |
+
top_k: int,
|
| 75 |
+
) -> ValidationMetadata:
|
| 76 |
+
if top_k <= 0 or not top_edges:
|
| 77 |
+
return ValidationMetadata(enabled=False, top_k=0)
|
| 78 |
+
|
| 79 |
+
selected_edges = top_edges[:top_k]
|
| 80 |
+
exact_scores: list[float] = []
|
| 81 |
+
attributed_scores: list[float] = []
|
| 82 |
+
compared_edges: list[TopEdge] = []
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
for edge in selected_edges:
|
| 86 |
+
exact_score = compute_exact_edge_score(
|
| 87 |
+
model=model,
|
| 88 |
+
input_ids=input_ids,
|
| 89 |
+
source_range=token_ranges[edge.source_sentence_idx],
|
| 90 |
+
target_range=token_ranges[edge.target_sentence_idx],
|
| 91 |
+
baseline_token_nll=baseline_token_nll,
|
| 92 |
+
)
|
| 93 |
+
exact_scores.append(exact_score)
|
| 94 |
+
attributed_scores.append(edge.score)
|
| 95 |
+
compared_edges.append(
|
| 96 |
+
TopEdge(
|
| 97 |
+
source_sentence_idx=edge.source_sentence_idx,
|
| 98 |
+
target_sentence_idx=edge.target_sentence_idx,
|
| 99 |
+
score=exact_score,
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
except Exception as exc: # pragma: no cover - environment/model dependent
|
| 103 |
+
return ValidationMetadata(
|
| 104 |
+
enabled=True,
|
| 105 |
+
top_k=top_k,
|
| 106 |
+
compared_edges=[],
|
| 107 |
+
notes=f"Exact suppression validation failed: {exc}",
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
pearson = float(pearsonr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None
|
| 111 |
+
spearman = float(spearmanr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None
|
| 112 |
+
|
| 113 |
+
return ValidationMetadata(
|
| 114 |
+
enabled=True,
|
| 115 |
+
top_k=top_k,
|
| 116 |
+
pearson=pearson,
|
| 117 |
+
spearman=spearman,
|
| 118 |
+
compared_edges=compared_edges,
|
| 119 |
+
notes="Exact suppression compares sentence-level NLL deltas for selected edges.",
|
| 120 |
+
)
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""API package for GPU-hosted analysis service."""
|
app/api/auth.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
from fastapi import HTTPException, Request
|
| 6 |
+
from huggingface_hub import parse_huggingface_oauth
|
| 7 |
+
|
| 8 |
+
from app.core.config import Settings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(frozen=True, slots=True)
|
| 12 |
+
class UserContext:
|
| 13 |
+
id: str
|
| 14 |
+
username: str
|
| 15 |
+
display_name: str | None
|
| 16 |
+
avatar_url: str | None
|
| 17 |
+
authenticated: bool
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_optional_user(request: Request) -> UserContext | None:
|
| 21 |
+
if "session" not in request.scope:
|
| 22 |
+
return None
|
| 23 |
+
try:
|
| 24 |
+
oauth_info = parse_huggingface_oauth(request)
|
| 25 |
+
except AssertionError:
|
| 26 |
+
return None
|
| 27 |
+
if oauth_info is None:
|
| 28 |
+
return None
|
| 29 |
+
|
| 30 |
+
user_info = oauth_info.user_info
|
| 31 |
+
username = user_info.preferred_username or user_info.sub or "hf-user"
|
| 32 |
+
display_name = user_info.name or username
|
| 33 |
+
return UserContext(
|
| 34 |
+
id=user_info.sub or username,
|
| 35 |
+
username=username,
|
| 36 |
+
display_name=display_name,
|
| 37 |
+
avatar_url=user_info.picture,
|
| 38 |
+
authenticated=True,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def require_user(request: Request, settings: Settings) -> UserContext:
|
| 43 |
+
user = get_optional_user(request)
|
| 44 |
+
if user is not None:
|
| 45 |
+
return user
|
| 46 |
+
if settings.require_auth:
|
| 47 |
+
raise HTTPException(status_code=401, detail="Sign in with Hugging Face to use this service.")
|
| 48 |
+
return UserContext(
|
| 49 |
+
id="anonymous",
|
| 50 |
+
username="anonymous",
|
| 51 |
+
display_name="Anonymous",
|
| 52 |
+
avatar_url=None,
|
| 53 |
+
authenticated=False,
|
| 54 |
+
)
|
app/api/main.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from contextlib import asynccontextmanager
|
| 5 |
+
from io import StringIO
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from dataclasses import asdict
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 12 |
+
from fastapi.responses import FileResponse
|
| 13 |
+
from fastapi.responses import StreamingResponse
|
| 14 |
+
from fastapi.staticfiles import StaticFiles
|
| 15 |
+
from huggingface_hub import attach_huggingface_oauth
|
| 16 |
+
|
| 17 |
+
from app.api.auth import get_optional_user, require_user
|
| 18 |
+
from app.core.config import get_settings
|
| 19 |
+
from app.core.runtime import load_model_bundle
|
| 20 |
+
from app.core.runtime_pipeline import compute_attribution_analysis
|
| 21 |
+
from app.core.schemas import (
|
| 22 |
+
AnalysisRequest,
|
| 23 |
+
AnalysisResult,
|
| 24 |
+
CurrentUserResponse,
|
| 25 |
+
HealthResponse,
|
| 26 |
+
SessionCreateRequest,
|
| 27 |
+
SessionResponse,
|
| 28 |
+
SessionResultResponse,
|
| 29 |
+
WarmupResponse,
|
| 30 |
+
)
|
| 31 |
+
from app.services.sessions import SessionAccessError, SessionLimitError, SessionService
|
| 32 |
+
from app.storage.repository import SessionRepository
|
| 33 |
+
from app.workers.jobs import build_job_runner
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
FRONTEND_DIR = Path(__file__).resolve().parents[1] / "frontend"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@asynccontextmanager
|
| 40 |
+
async def lifespan(_app: FastAPI):
|
| 41 |
+
settings = get_settings()
|
| 42 |
+
repository = SessionRepository(settings.database_path)
|
| 43 |
+
jobs = build_job_runner(settings.job_workers)
|
| 44 |
+
_app.state.repository = repository
|
| 45 |
+
_app.state.jobs = jobs
|
| 46 |
+
_app.state.session_service = SessionService(settings=settings, repository=repository, jobs=jobs)
|
| 47 |
+
if settings.preload_model:
|
| 48 |
+
logger.info(
|
| 49 |
+
"Preloading model '%s' on device preference '%s'.",
|
| 50 |
+
settings.model_name,
|
| 51 |
+
settings.device_preference,
|
| 52 |
+
)
|
| 53 |
+
load_model_bundle(
|
| 54 |
+
settings.model_name,
|
| 55 |
+
device_preference=settings.device_preference,
|
| 56 |
+
dtype_preference=settings.dtype_preference,
|
| 57 |
+
attn_implementation=settings.attn_implementation,
|
| 58 |
+
trust_remote_code=settings.trust_remote_code,
|
| 59 |
+
low_cpu_mem_usage=settings.low_cpu_mem_usage,
|
| 60 |
+
)
|
| 61 |
+
yield
|
| 62 |
+
jobs.shutdown()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
app = FastAPI(
|
| 66 |
+
title="CoT Attribution Analysis API",
|
| 67 |
+
version="0.1.0",
|
| 68 |
+
lifespan=lifespan,
|
| 69 |
+
)
|
| 70 |
+
if os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"):
|
| 71 |
+
attach_huggingface_oauth(app)
|
| 72 |
+
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_session_service() -> SessionService:
|
| 76 |
+
return app.state.session_service
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _to_session_response(payload: dict) -> SessionResponse:
|
| 80 |
+
return SessionResponse(
|
| 81 |
+
id=payload["id"],
|
| 82 |
+
status=payload["status"],
|
| 83 |
+
question=payload["question"],
|
| 84 |
+
model_name=payload["model_name"],
|
| 85 |
+
error=payload.get("error"),
|
| 86 |
+
created_at=payload["created_at"],
|
| 87 |
+
updated_at=payload["updated_at"],
|
| 88 |
+
answer=payload.get("answer"),
|
| 89 |
+
raw_trace_text=payload.get("raw_trace_text"),
|
| 90 |
+
normalized_trace_text=payload.get("normalized_trace_text"),
|
| 91 |
+
sentences=payload.get("sentences"),
|
| 92 |
+
generation_metadata=payload.get("generation_metadata"),
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@app.get("/", include_in_schema=False)
|
| 97 |
+
def index() -> FileResponse:
|
| 98 |
+
return FileResponse(FRONTEND_DIR / "index.html")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@app.get("/healthz", response_model=HealthResponse)
|
| 102 |
+
def healthz() -> HealthResponse:
|
| 103 |
+
settings = get_settings()
|
| 104 |
+
return HealthResponse(
|
| 105 |
+
status="ok",
|
| 106 |
+
model_name=settings.model_name,
|
| 107 |
+
device_preference=settings.device_preference,
|
| 108 |
+
dtype_preference=settings.dtype_preference,
|
| 109 |
+
preload_model=settings.preload_model,
|
| 110 |
+
cuda_available=torch.cuda.is_available(),
|
| 111 |
+
mps_available=torch.backends.mps.is_available(),
|
| 112 |
+
require_auth=settings.require_auth,
|
| 113 |
+
public_api_enabled=settings.public_api_enabled,
|
| 114 |
+
max_queued_jobs=settings.max_queued_jobs,
|
| 115 |
+
max_active_jobs_per_user=settings.max_active_jobs_per_user,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@app.get("/api/me", response_model=CurrentUserResponse)
|
| 120 |
+
def me(request: Request) -> CurrentUserResponse:
|
| 121 |
+
settings = get_settings()
|
| 122 |
+
user = get_optional_user(request)
|
| 123 |
+
return CurrentUserResponse(
|
| 124 |
+
authenticated=user is not None,
|
| 125 |
+
auth_required=settings.require_auth,
|
| 126 |
+
username=user.username if user else None,
|
| 127 |
+
full_name=user.display_name if user else None,
|
| 128 |
+
avatar_url=user.avatar_url if user else None,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@app.post("/api/warmup", response_model=WarmupResponse)
|
| 133 |
+
def warmup(model_name: str | None = None, device_preference: str | None = None) -> WarmupResponse:
|
| 134 |
+
settings = get_settings()
|
| 135 |
+
bundle = load_model_bundle(
|
| 136 |
+
model_name or settings.model_name,
|
| 137 |
+
device_preference=device_preference or settings.device_preference,
|
| 138 |
+
dtype_preference=settings.dtype_preference,
|
| 139 |
+
attn_implementation=settings.attn_implementation,
|
| 140 |
+
trust_remote_code=settings.trust_remote_code,
|
| 141 |
+
low_cpu_mem_usage=settings.low_cpu_mem_usage,
|
| 142 |
+
)
|
| 143 |
+
return WarmupResponse(
|
| 144 |
+
status="ready",
|
| 145 |
+
model_name=bundle.model_name,
|
| 146 |
+
device=str(bundle.device),
|
| 147 |
+
dtype=str(bundle.dtype),
|
| 148 |
+
capability=asdict(bundle.capability),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@app.post("/api/analyze", response_model=AnalysisResult)
|
| 153 |
+
def analyze(request: AnalysisRequest, http_request: Request) -> AnalysisResult:
|
| 154 |
+
settings = get_settings()
|
| 155 |
+
require_user(http_request, settings)
|
| 156 |
+
try:
|
| 157 |
+
return compute_attribution_analysis(
|
| 158 |
+
question=request.question,
|
| 159 |
+
model_name=request.model_name,
|
| 160 |
+
take_log=request.take_log,
|
| 161 |
+
max_sentences=request.max_sentences,
|
| 162 |
+
max_trace_tokens=request.max_trace_tokens,
|
| 163 |
+
validate_top_k=request.validate_top_k,
|
| 164 |
+
max_new_tokens=request.max_new_tokens,
|
| 165 |
+
temperature=request.temperature,
|
| 166 |
+
top_p=request.top_p,
|
| 167 |
+
device_preference=request.device_preference,
|
| 168 |
+
dtype_preference=request.dtype_preference,
|
| 169 |
+
attn_implementation=request.attn_implementation,
|
| 170 |
+
trust_remote_code=request.trust_remote_code,
|
| 171 |
+
low_cpu_mem_usage=request.low_cpu_mem_usage,
|
| 172 |
+
)
|
| 173 |
+
except Exception as exc: # pragma: no cover - runtime path
|
| 174 |
+
logger.exception("Analysis request failed")
|
| 175 |
+
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@app.get("/api/sessions", response_model=list[SessionResponse])
|
| 179 |
+
def list_sessions(request: Request, limit: int = 20) -> list[SessionResponse]:
|
| 180 |
+
settings = get_settings()
|
| 181 |
+
user = require_user(request, settings)
|
| 182 |
+
service = get_session_service()
|
| 183 |
+
payloads = service.list_sessions(user.id, limit=limit)
|
| 184 |
+
return [_to_session_response(payload) for payload in payloads]
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@app.post("/api/sessions", response_model=SessionResponse)
|
| 188 |
+
def create_session(request: SessionCreateRequest, http_request: Request) -> SessionResponse:
|
| 189 |
+
settings = get_settings()
|
| 190 |
+
user = require_user(http_request, settings)
|
| 191 |
+
service = get_session_service()
|
| 192 |
+
try:
|
| 193 |
+
session = service.create_session(
|
| 194 |
+
AnalysisRequest(
|
| 195 |
+
question=request.question,
|
| 196 |
+
model_name=request.model_name,
|
| 197 |
+
take_log=request.take_log,
|
| 198 |
+
max_sentences=request.max_sentences,
|
| 199 |
+
max_trace_tokens=request.max_trace_tokens,
|
| 200 |
+
validate_top_k=request.validate_top_k,
|
| 201 |
+
max_new_tokens=request.max_new_tokens,
|
| 202 |
+
temperature=request.temperature,
|
| 203 |
+
top_p=request.top_p,
|
| 204 |
+
device_preference=request.device_preference,
|
| 205 |
+
dtype_preference=request.dtype_preference,
|
| 206 |
+
attn_implementation=request.attn_implementation,
|
| 207 |
+
trust_remote_code=request.trust_remote_code,
|
| 208 |
+
low_cpu_mem_usage=request.low_cpu_mem_usage,
|
| 209 |
+
),
|
| 210 |
+
owner_id=user.id,
|
| 211 |
+
owner_name=user.display_name,
|
| 212 |
+
)
|
| 213 |
+
except SessionLimitError as exc:
|
| 214 |
+
raise HTTPException(status_code=429, detail=str(exc)) from exc
|
| 215 |
+
payload = service.get_session_payload(session.id, owner_id=user.id)
|
| 216 |
+
return _to_session_response(payload)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@app.get("/api/sessions/{session_id}", response_model=SessionResponse)
|
| 220 |
+
def get_session(session_id: str, request: Request) -> SessionResponse:
|
| 221 |
+
settings = get_settings()
|
| 222 |
+
user = require_user(request, settings)
|
| 223 |
+
service = get_session_service()
|
| 224 |
+
try:
|
| 225 |
+
payload = service.get_session_payload(session_id, owner_id=user.id)
|
| 226 |
+
except KeyError as exc:
|
| 227 |
+
raise HTTPException(status_code=404, detail="Session not found") from exc
|
| 228 |
+
except SessionAccessError as exc:
|
| 229 |
+
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
| 230 |
+
return _to_session_response(payload)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@app.post("/api/sessions/{session_id}/analyze", response_model=SessionResponse)
|
| 234 |
+
def analyze_session(session_id: str, request: Request) -> SessionResponse:
|
| 235 |
+
settings = get_settings()
|
| 236 |
+
user = require_user(request, settings)
|
| 237 |
+
service = get_session_service()
|
| 238 |
+
try:
|
| 239 |
+
session = service.start_analysis(session_id, owner_id=user.id)
|
| 240 |
+
payload = service.get_session_payload(session.id, owner_id=user.id)
|
| 241 |
+
except KeyError as exc:
|
| 242 |
+
raise HTTPException(status_code=404, detail="Session not found") from exc
|
| 243 |
+
except SessionAccessError as exc:
|
| 244 |
+
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
| 245 |
+
return _to_session_response(payload)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@app.get("/api/sessions/{session_id}/result", response_model=SessionResultResponse)
|
| 249 |
+
def get_session_result(session_id: str, request: Request) -> SessionResultResponse:
|
| 250 |
+
settings = get_settings()
|
| 251 |
+
user = require_user(request, settings)
|
| 252 |
+
service = get_session_service()
|
| 253 |
+
try:
|
| 254 |
+
payload = service.get_session_payload(session_id, owner_id=user.id)
|
| 255 |
+
except KeyError as exc:
|
| 256 |
+
raise HTTPException(status_code=404, detail="Session not found") from exc
|
| 257 |
+
except SessionAccessError as exc:
|
| 258 |
+
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
| 259 |
+
|
| 260 |
+
session_response = _to_session_response(payload)
|
| 261 |
+
analysis_payload = payload.get("analysis")
|
| 262 |
+
return SessionResultResponse(
|
| 263 |
+
session=session_response,
|
| 264 |
+
analysis=AnalysisResult.model_validate(analysis_payload) if analysis_payload else None,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@app.get("/api/sessions/{session_id}/export.json")
|
| 269 |
+
def export_session_json(session_id: str, request: Request) -> StreamingResponse:
|
| 270 |
+
settings = get_settings()
|
| 271 |
+
user = require_user(request, settings)
|
| 272 |
+
service = get_session_service()
|
| 273 |
+
try:
|
| 274 |
+
payload = service.get_session_payload(session_id, owner_id=user.id)
|
| 275 |
+
except KeyError as exc:
|
| 276 |
+
raise HTTPException(status_code=404, detail="Session not found") from exc
|
| 277 |
+
except SessionAccessError as exc:
|
| 278 |
+
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
| 279 |
+
result = SessionResultResponse(
|
| 280 |
+
session=_to_session_response(payload),
|
| 281 |
+
analysis=AnalysisResult.model_validate(payload["analysis"]) if payload.get("analysis") else None,
|
| 282 |
+
)
|
| 283 |
+
return StreamingResponse(
|
| 284 |
+
iter([result.model_dump_json(indent=2)]),
|
| 285 |
+
media_type="application/json",
|
| 286 |
+
headers={"content-disposition": f'attachment; filename="{session_id}.json"'},
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@app.get("/api/sessions/{session_id}/export.csv")
|
| 291 |
+
def export_session_csv(session_id: str, request: Request) -> StreamingResponse:
|
| 292 |
+
settings = get_settings()
|
| 293 |
+
user = require_user(request, settings)
|
| 294 |
+
service = get_session_service()
|
| 295 |
+
try:
|
| 296 |
+
result = service.get_analysis_result(session_id, owner_id=user.id)
|
| 297 |
+
except KeyError as exc:
|
| 298 |
+
raise HTTPException(status_code=404, detail="Analysis result not found") from exc
|
| 299 |
+
except SessionAccessError as exc:
|
| 300 |
+
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
| 301 |
+
|
| 302 |
+
buffer = StringIO()
|
| 303 |
+
buffer.write("source_sentence_idx,target_sentence_idx,score\n")
|
| 304 |
+
for edge in result.top_edges:
|
| 305 |
+
buffer.write(f"{edge.source_sentence_idx},{edge.target_sentence_idx},{edge.score:.6f}\n")
|
| 306 |
+
return StreamingResponse(
|
| 307 |
+
iter([buffer.getvalue()]),
|
| 308 |
+
media_type="text/csv",
|
| 309 |
+
headers={"content-disposition": f'attachment; filename="{session_id}.csv"'},
|
| 310 |
+
)
|
app/cli/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""CLI entrypoints."""
|
app/cli/run_api.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import uvicorn
|
| 4 |
+
|
| 5 |
+
from app.core.config import get_settings
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main() -> None:
|
| 9 |
+
settings = get_settings()
|
| 10 |
+
uvicorn.run(
|
| 11 |
+
"app.api.main:app",
|
| 12 |
+
host=settings.api_host,
|
| 13 |
+
port=settings.api_port,
|
| 14 |
+
reload=False,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if __name__ == "__main__":
|
| 19 |
+
main()
|
app/cli/run_prototype.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from datetime import UTC, datetime
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import typer
|
| 8 |
+
|
| 9 |
+
from app.core.runtime_pipeline import compute_attribution_analysis
|
| 10 |
+
|
| 11 |
+
cli = typer.Typer(add_completion=False)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _write_heatmap(path: Path, matrix: list[list[float]]) -> None:
|
| 15 |
+
try:
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
except ImportError as exc: # pragma: no cover - optional dependency
|
| 18 |
+
raise RuntimeError(
|
| 19 |
+
"Heatmap output requires the optional viz dependency: pip install .[viz]"
|
| 20 |
+
) from exc
|
| 21 |
+
|
| 22 |
+
figure, axis = plt.subplots(figsize=(8, 6))
|
| 23 |
+
image = axis.imshow(matrix, aspect="auto", cmap="viridis")
|
| 24 |
+
axis.set_xlabel("Source sentence")
|
| 25 |
+
axis.set_ylabel("Target sentence")
|
| 26 |
+
axis.set_title("Sentence influence matrix")
|
| 27 |
+
figure.colorbar(image, ax=axis)
|
| 28 |
+
figure.tight_layout()
|
| 29 |
+
figure.savefig(path)
|
| 30 |
+
plt.close(figure)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@cli.command()
|
| 34 |
+
def main(
|
| 35 |
+
question: str,
|
| 36 |
+
model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
| 37 |
+
output_json: Path | None = None,
|
| 38 |
+
output_heatmap: Path | None = None,
|
| 39 |
+
max_new_tokens: int = 512,
|
| 40 |
+
max_trace_tokens: int = 1024,
|
| 41 |
+
max_sentences: int = 40,
|
| 42 |
+
take_log: bool = True,
|
| 43 |
+
validate_top_k: int = 3,
|
| 44 |
+
temperature: float = 0.6,
|
| 45 |
+
top_p: float = 0.95,
|
| 46 |
+
device_preference: str = "auto",
|
| 47 |
+
dtype_preference: str = "auto",
|
| 48 |
+
attn_implementation: str = "eager",
|
| 49 |
+
trust_remote_code: bool = True,
|
| 50 |
+
low_cpu_mem_usage: bool = True,
|
| 51 |
+
) -> None:
|
| 52 |
+
result = compute_attribution_analysis(
|
| 53 |
+
question=question,
|
| 54 |
+
model_name=model_name,
|
| 55 |
+
take_log=take_log,
|
| 56 |
+
max_sentences=max_sentences,
|
| 57 |
+
max_trace_tokens=max_trace_tokens,
|
| 58 |
+
validate_top_k=validate_top_k,
|
| 59 |
+
max_new_tokens=max_new_tokens,
|
| 60 |
+
temperature=temperature,
|
| 61 |
+
top_p=top_p,
|
| 62 |
+
device_preference=device_preference,
|
| 63 |
+
dtype_preference=dtype_preference,
|
| 64 |
+
attn_implementation=attn_implementation,
|
| 65 |
+
trust_remote_code=trust_remote_code,
|
| 66 |
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
output_dir = Path("outputs")
|
| 70 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 71 |
+
timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
|
| 72 |
+
|
| 73 |
+
json_path = output_json or output_dir / f"analysis_{timestamp}.json"
|
| 74 |
+
json_path.write_text(json.dumps(result.model_dump(), indent=2), encoding="utf-8")
|
| 75 |
+
|
| 76 |
+
if output_heatmap is not None:
|
| 77 |
+
_write_heatmap(output_heatmap, result.suppression_matrix)
|
| 78 |
+
|
| 79 |
+
typer.echo(f"Wrote analysis JSON to {json_path}")
|
| 80 |
+
typer.echo(f"Sentences: {len(result.sentences)}")
|
| 81 |
+
typer.echo(f"Top edge count: {len(result.top_edges)}")
|
| 82 |
+
if result.validation_metadata and result.validation_metadata.enabled:
|
| 83 |
+
typer.echo(
|
| 84 |
+
"Validation:"
|
| 85 |
+
f" pearson={result.validation_metadata.pearson}"
|
| 86 |
+
f" spearman={result.validation_metadata.spearman}"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
cli()
|
app/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Core runtime, config, and schema definitions."""
|
app/core/config.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Literal
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(frozen=True, slots=True)
|
| 10 |
+
class Settings:
|
| 11 |
+
model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
| 12 |
+
max_trace_tokens: int = 1024
|
| 13 |
+
max_sentences: int = 40
|
| 14 |
+
take_log: bool = True
|
| 15 |
+
device_preference: Literal["auto", "cuda", "cpu", "mps"] = "auto"
|
| 16 |
+
dtype_preference: Literal["auto", "float32", "float16", "bfloat16"] = "auto"
|
| 17 |
+
attn_implementation: str = "eager"
|
| 18 |
+
trust_remote_code: bool = True
|
| 19 |
+
low_cpu_mem_usage: bool = True
|
| 20 |
+
preload_model: bool = False
|
| 21 |
+
api_host: str = "0.0.0.0"
|
| 22 |
+
api_port: int = 7860
|
| 23 |
+
database_path: str = "data/app.db"
|
| 24 |
+
job_workers: int = 1
|
| 25 |
+
max_queued_jobs: int = 8
|
| 26 |
+
max_active_jobs_per_user: int = 2
|
| 27 |
+
require_auth: bool = True
|
| 28 |
+
public_api_enabled: bool = True
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
DEFAULT_SETTINGS = Settings()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@lru_cache(maxsize=1)
|
| 35 |
+
def get_settings() -> Settings:
|
| 36 |
+
take_log = os.getenv("TAKE_LOG", "true").strip().lower() in {"1", "true", "yes", "on"}
|
| 37 |
+
trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "true").strip().lower() in {"1", "true", "yes", "on"}
|
| 38 |
+
low_cpu_mem_usage = os.getenv("LOW_CPU_MEM_USAGE", "true").strip().lower() in {"1", "true", "yes", "on"}
|
| 39 |
+
require_auth = os.getenv("REQUIRE_AUTH", "true").strip().lower() in {"1", "true", "yes", "on"}
|
| 40 |
+
public_api_enabled = os.getenv("PUBLIC_API_ENABLED", "true").strip().lower() in {"1", "true", "yes", "on"}
|
| 41 |
+
return Settings(
|
| 42 |
+
model_name=os.getenv("MODEL_NAME", DEFAULT_SETTINGS.model_name),
|
| 43 |
+
max_trace_tokens=int(os.getenv("MAX_TRACE_TOKENS", DEFAULT_SETTINGS.max_trace_tokens)),
|
| 44 |
+
max_sentences=int(os.getenv("MAX_SENTENCES", DEFAULT_SETTINGS.max_sentences)),
|
| 45 |
+
take_log=take_log,
|
| 46 |
+
device_preference=os.getenv("DEVICE_PREFERENCE", DEFAULT_SETTINGS.device_preference), # type: ignore[arg-type]
|
| 47 |
+
dtype_preference=os.getenv("DTYPE_PREFERENCE", DEFAULT_SETTINGS.dtype_preference), # type: ignore[arg-type]
|
| 48 |
+
attn_implementation=os.getenv("ATTN_IMPLEMENTATION", DEFAULT_SETTINGS.attn_implementation),
|
| 49 |
+
trust_remote_code=trust_remote_code,
|
| 50 |
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
| 51 |
+
preload_model=os.getenv("PRELOAD_MODEL", "false").strip().lower() in {"1", "true", "yes", "on"},
|
| 52 |
+
api_host=os.getenv("API_HOST", DEFAULT_SETTINGS.api_host),
|
| 53 |
+
api_port=int(os.getenv("API_PORT", DEFAULT_SETTINGS.api_port)),
|
| 54 |
+
database_path=os.getenv("DATABASE_PATH", DEFAULT_SETTINGS.database_path),
|
| 55 |
+
job_workers=int(os.getenv("JOB_WORKERS", DEFAULT_SETTINGS.job_workers)),
|
| 56 |
+
max_queued_jobs=int(os.getenv("MAX_QUEUED_JOBS", DEFAULT_SETTINGS.max_queued_jobs)),
|
| 57 |
+
max_active_jobs_per_user=int(
|
| 58 |
+
os.getenv("MAX_ACTIVE_JOBS_PER_USER", DEFAULT_SETTINGS.max_active_jobs_per_user)
|
| 59 |
+
),
|
| 60 |
+
require_auth=require_auth,
|
| 61 |
+
public_api_enabled=public_api_enabled,
|
| 62 |
+
)
|
app/core/model_support.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True, slots=True)
|
| 8 |
+
class ModelSupport:
|
| 9 |
+
supports_attribution: bool
|
| 10 |
+
reason: str | None
|
| 11 |
+
layer_path: str | None
|
| 12 |
+
attention_attr: str | None
|
| 13 |
+
layer_count: int
|
| 14 |
+
attention_impl: str | None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
_LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = (
|
| 18 |
+
("model", "layers"),
|
| 19 |
+
("model", "model", "layers"),
|
| 20 |
+
("transformer", "h"),
|
| 21 |
+
("gpt_neox", "layers"),
|
| 22 |
+
)
|
| 23 |
+
_ATTENTION_ATTR_CANDIDATES: tuple[str, ...] = ("self_attn", "attn", "attention")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None:
|
| 27 |
+
current = obj
|
| 28 |
+
for segment in path:
|
| 29 |
+
current = getattr(current, segment, None)
|
| 30 |
+
if current is None:
|
| 31 |
+
return None
|
| 32 |
+
return current
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _get_attention_impl(model: Any) -> str | None:
|
| 36 |
+
config = getattr(model, "config", None)
|
| 37 |
+
if config is None:
|
| 38 |
+
return None
|
| 39 |
+
return getattr(config, "_attn_implementation", None) or getattr(config, "attn_implementation", None)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def describe_model_support(model: Any) -> ModelSupport:
|
| 43 |
+
attn_impl = _get_attention_impl(model)
|
| 44 |
+
layers = None
|
| 45 |
+
layer_path = None
|
| 46 |
+
for candidate in _LAYER_PATH_CANDIDATES:
|
| 47 |
+
maybe_layers = _resolve_attr_chain(model, candidate)
|
| 48 |
+
if maybe_layers is not None:
|
| 49 |
+
layers = list(maybe_layers)
|
| 50 |
+
layer_path = ".".join(candidate)
|
| 51 |
+
break
|
| 52 |
+
|
| 53 |
+
if not layers:
|
| 54 |
+
return ModelSupport(
|
| 55 |
+
supports_attribution=False,
|
| 56 |
+
reason="Unsupported model structure: unable to locate decoder layers.",
|
| 57 |
+
layer_path=layer_path,
|
| 58 |
+
attention_attr=None,
|
| 59 |
+
layer_count=0,
|
| 60 |
+
attention_impl=attn_impl,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
for attention_attr in _ATTENTION_ATTR_CANDIDATES:
|
| 64 |
+
if all(getattr(layer, attention_attr, None) is not None for layer in layers):
|
| 65 |
+
if attn_impl != "eager":
|
| 66 |
+
return ModelSupport(
|
| 67 |
+
supports_attribution=False,
|
| 68 |
+
reason="Attention gradients require attn_implementation='eager'.",
|
| 69 |
+
layer_path=layer_path,
|
| 70 |
+
attention_attr=attention_attr,
|
| 71 |
+
layer_count=len(layers),
|
| 72 |
+
attention_impl=attn_impl,
|
| 73 |
+
)
|
| 74 |
+
return ModelSupport(
|
| 75 |
+
supports_attribution=True,
|
| 76 |
+
reason=None,
|
| 77 |
+
layer_path=layer_path,
|
| 78 |
+
attention_attr=attention_attr,
|
| 79 |
+
layer_count=len(layers),
|
| 80 |
+
attention_impl=attn_impl,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
return ModelSupport(
|
| 84 |
+
supports_attribution=False,
|
| 85 |
+
reason="Unsupported attention module layout: no known attention attribute found on decoder layers.",
|
| 86 |
+
layer_path=layer_path,
|
| 87 |
+
attention_attr=None,
|
| 88 |
+
layer_count=len(layers),
|
| 89 |
+
attention_impl=attn_impl,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_decoder_layers(model: Any) -> tuple[list[Any], str, str]:
|
| 94 |
+
support = describe_model_support(model)
|
| 95 |
+
if not support.supports_attribution or support.layer_path is None or support.attention_attr is None:
|
| 96 |
+
reason = support.reason or "Model does not support attribution analysis."
|
| 97 |
+
raise RuntimeError(reason)
|
| 98 |
+
|
| 99 |
+
layers = _resolve_attr_chain(model, tuple(support.layer_path.split(".")))
|
| 100 |
+
if layers is None:
|
| 101 |
+
raise RuntimeError("Model support metadata became inconsistent while resolving layers.")
|
| 102 |
+
return list(layers), support.layer_path, support.attention_attr
|
app/core/runtime.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
|
| 8 |
+
|
| 9 |
+
from app.core.model_support import ModelSupport, describe_model_support
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(slots=True)
|
| 13 |
+
class ModelBundle:
|
| 14 |
+
model_name: str
|
| 15 |
+
model: PreTrainedModel
|
| 16 |
+
tokenizer: PreTrainedTokenizerBase
|
| 17 |
+
device: torch.device
|
| 18 |
+
dtype: torch.dtype
|
| 19 |
+
capability: ModelSupport
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def resolve_dtype(preference: str, device: torch.device) -> torch.dtype:
|
| 23 |
+
if preference == "float32":
|
| 24 |
+
return torch.float32
|
| 25 |
+
if preference == "float16":
|
| 26 |
+
return torch.float16
|
| 27 |
+
if preference == "bfloat16":
|
| 28 |
+
return torch.bfloat16
|
| 29 |
+
if device.type == "cuda":
|
| 30 |
+
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 31 |
+
if device.type == "mps":
|
| 32 |
+
return torch.float16
|
| 33 |
+
return torch.float32
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def resolve_device(preference: str = "auto") -> torch.device:
|
| 37 |
+
if preference == "cuda":
|
| 38 |
+
if not torch.cuda.is_available():
|
| 39 |
+
raise RuntimeError("CUDA requested but not available.")
|
| 40 |
+
return torch.device("cuda")
|
| 41 |
+
if preference == "mps":
|
| 42 |
+
if not torch.backends.mps.is_available():
|
| 43 |
+
raise RuntimeError("MPS requested but not available.")
|
| 44 |
+
return torch.device("mps")
|
| 45 |
+
if preference == "cpu":
|
| 46 |
+
return torch.device("cpu")
|
| 47 |
+
if torch.cuda.is_available():
|
| 48 |
+
return torch.device("cuda")
|
| 49 |
+
if torch.backends.mps.is_available():
|
| 50 |
+
return torch.device("mps")
|
| 51 |
+
return torch.device("cpu")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@lru_cache(maxsize=2)
|
| 55 |
+
def load_model_bundle(
|
| 56 |
+
model_name: str,
|
| 57 |
+
device_preference: str = "auto",
|
| 58 |
+
dtype_preference: str = "auto",
|
| 59 |
+
attn_implementation: str = "eager",
|
| 60 |
+
trust_remote_code: bool = True,
|
| 61 |
+
low_cpu_mem_usage: bool = True,
|
| 62 |
+
) -> ModelBundle:
|
| 63 |
+
device = resolve_device(device_preference)
|
| 64 |
+
dtype = resolve_dtype(dtype_preference, device)
|
| 65 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
| 66 |
+
if tokenizer.pad_token is None:
|
| 67 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 68 |
+
|
| 69 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 70 |
+
model_name,
|
| 71 |
+
trust_remote_code=trust_remote_code,
|
| 72 |
+
attn_implementation=attn_implementation,
|
| 73 |
+
torch_dtype=dtype,
|
| 74 |
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
| 75 |
+
)
|
| 76 |
+
model.to(device)
|
| 77 |
+
model.eval()
|
| 78 |
+
capability = describe_model_support(model)
|
| 79 |
+
|
| 80 |
+
return ModelBundle(
|
| 81 |
+
model_name=model_name,
|
| 82 |
+
model=model,
|
| 83 |
+
tokenizer=tokenizer,
|
| 84 |
+
device=device,
|
| 85 |
+
dtype=dtype,
|
| 86 |
+
capability=capability,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def compute_attribution_analysis(**kwargs):
|
| 91 |
+
from app.core.runtime_pipeline import compute_attribution_analysis as _compute_attribution_analysis
|
| 92 |
+
|
| 93 |
+
return _compute_attribution_analysis(**kwargs)
|
app/core/runtime_pipeline.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from app.analysis.sentence_split import split_sentences
|
| 8 |
+
from app.analysis.summaries import (
|
| 9 |
+
compute_incoming_importance,
|
| 10 |
+
compute_outgoing_importance,
|
| 11 |
+
compute_top_edges,
|
| 12 |
+
)
|
| 13 |
+
from app.analysis.suppression import compute_attribution_matrix
|
| 14 |
+
from app.analysis.token_boundaries import tokenize_with_sentence_ranges, truncate_text_to_token_limit
|
| 15 |
+
from app.analysis.validation import validate_top_edges
|
| 16 |
+
from app.core.config import get_settings
|
| 17 |
+
from app.core.runtime import load_model_bundle
|
| 18 |
+
from app.core.schemas import AnalysisResult, GenerationResult
|
| 19 |
+
from app.generation.service import generate_answer_and_trace
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def compute_attribution_analysis(
|
| 23 |
+
*,
|
| 24 |
+
question: str,
|
| 25 |
+
model_name: str | None = None,
|
| 26 |
+
take_log: bool | None = None,
|
| 27 |
+
max_sentences: int | None = None,
|
| 28 |
+
max_trace_tokens: int | None = None,
|
| 29 |
+
validate_top_k: int = 0,
|
| 30 |
+
max_new_tokens: int = 512,
|
| 31 |
+
temperature: float = 0.6,
|
| 32 |
+
top_p: float = 0.95,
|
| 33 |
+
device_preference: str | None = None,
|
| 34 |
+
dtype_preference: str | None = None,
|
| 35 |
+
attn_implementation: str | None = None,
|
| 36 |
+
trust_remote_code: bool | None = None,
|
| 37 |
+
low_cpu_mem_usage: bool | None = None,
|
| 38 |
+
) -> AnalysisResult:
|
| 39 |
+
generation = None
|
| 40 |
+
return analyze_generation_result(
|
| 41 |
+
question=question,
|
| 42 |
+
generation=generation,
|
| 43 |
+
model_name=model_name,
|
| 44 |
+
take_log=take_log,
|
| 45 |
+
max_sentences=max_sentences,
|
| 46 |
+
max_trace_tokens=max_trace_tokens,
|
| 47 |
+
validate_top_k=validate_top_k,
|
| 48 |
+
max_new_tokens=max_new_tokens,
|
| 49 |
+
temperature=temperature,
|
| 50 |
+
top_p=top_p,
|
| 51 |
+
device_preference=device_preference,
|
| 52 |
+
dtype_preference=dtype_preference,
|
| 53 |
+
attn_implementation=attn_implementation,
|
| 54 |
+
trust_remote_code=trust_remote_code,
|
| 55 |
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def analyze_generation_result(
|
| 60 |
+
*,
|
| 61 |
+
question: str,
|
| 62 |
+
generation: GenerationResult | None = None,
|
| 63 |
+
model_name: str | None = None,
|
| 64 |
+
take_log: bool | None = None,
|
| 65 |
+
max_sentences: int | None = None,
|
| 66 |
+
max_trace_tokens: int | None = None,
|
| 67 |
+
validate_top_k: int = 0,
|
| 68 |
+
max_new_tokens: int = 512,
|
| 69 |
+
temperature: float = 0.6,
|
| 70 |
+
top_p: float = 0.95,
|
| 71 |
+
device_preference: str | None = None,
|
| 72 |
+
dtype_preference: str | None = None,
|
| 73 |
+
attn_implementation: str | None = None,
|
| 74 |
+
trust_remote_code: bool | None = None,
|
| 75 |
+
low_cpu_mem_usage: bool | None = None,
|
| 76 |
+
) -> AnalysisResult:
|
| 77 |
+
settings = get_settings()
|
| 78 |
+
resolved_model_name = model_name or settings.model_name
|
| 79 |
+
resolved_take_log = settings.take_log if take_log is None else take_log
|
| 80 |
+
resolved_max_sentences = max_sentences or settings.max_sentences
|
| 81 |
+
resolved_max_trace_tokens = max_trace_tokens or settings.max_trace_tokens
|
| 82 |
+
resolved_device = device_preference or settings.device_preference
|
| 83 |
+
resolved_dtype = dtype_preference or settings.dtype_preference
|
| 84 |
+
resolved_attn_implementation = attn_implementation or settings.attn_implementation
|
| 85 |
+
resolved_trust_remote_code = settings.trust_remote_code if trust_remote_code is None else trust_remote_code
|
| 86 |
+
resolved_low_cpu_mem_usage = (
|
| 87 |
+
settings.low_cpu_mem_usage if low_cpu_mem_usage is None else low_cpu_mem_usage
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
bundle = load_model_bundle(
|
| 91 |
+
resolved_model_name,
|
| 92 |
+
device_preference=resolved_device,
|
| 93 |
+
dtype_preference=resolved_dtype,
|
| 94 |
+
attn_implementation=resolved_attn_implementation,
|
| 95 |
+
trust_remote_code=resolved_trust_remote_code,
|
| 96 |
+
low_cpu_mem_usage=resolved_low_cpu_mem_usage,
|
| 97 |
+
)
|
| 98 |
+
if not bundle.capability.supports_attribution:
|
| 99 |
+
reason = bundle.capability.reason or "Model does not support attribution analysis."
|
| 100 |
+
raise RuntimeError(reason)
|
| 101 |
+
if generation is None:
|
| 102 |
+
generation = generate_answer_and_trace(
|
| 103 |
+
question=question,
|
| 104 |
+
model_name=resolved_model_name,
|
| 105 |
+
model=bundle.model,
|
| 106 |
+
tokenizer=bundle.tokenizer,
|
| 107 |
+
max_new_tokens=max_new_tokens,
|
| 108 |
+
temperature=temperature,
|
| 109 |
+
top_p=top_p,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
truncated_text = truncate_text_to_token_limit(
|
| 113 |
+
generation.normalized_trace_text,
|
| 114 |
+
bundle.tokenizer,
|
| 115 |
+
resolved_max_trace_tokens,
|
| 116 |
+
)
|
| 117 |
+
sentence_spans = split_sentences(truncated_text)
|
| 118 |
+
if resolved_max_sentences > 0 and len(sentence_spans) > resolved_max_sentences:
|
| 119 |
+
sentence_spans = sentence_spans[:resolved_max_sentences]
|
| 120 |
+
truncated_text = truncated_text[: sentence_spans[-1].end_char]
|
| 121 |
+
sentence_spans = split_sentences(truncated_text)
|
| 122 |
+
|
| 123 |
+
if not sentence_spans:
|
| 124 |
+
raise RuntimeError("Trace normalization produced no analyzable sentences.")
|
| 125 |
+
|
| 126 |
+
mapping = tokenize_with_sentence_ranges(truncated_text, sentence_spans, bundle.tokenizer)
|
| 127 |
+
input_ids = mapping.input_ids.to(bundle.device)
|
| 128 |
+
computation = compute_attribution_matrix(
|
| 129 |
+
input_ids=input_ids,
|
| 130 |
+
token_ranges=mapping.token_ranges,
|
| 131 |
+
model=bundle.model,
|
| 132 |
+
take_log=resolved_take_log,
|
| 133 |
+
max_trace_tokens=resolved_max_trace_tokens,
|
| 134 |
+
max_sentences=resolved_max_sentences,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
outgoing = compute_outgoing_importance(computation.raw_matrix)
|
| 138 |
+
incoming = compute_incoming_importance(computation.raw_matrix)
|
| 139 |
+
top_edges = compute_top_edges(computation.raw_matrix, top_k=10)
|
| 140 |
+
validation = validate_top_edges(
|
| 141 |
+
model=bundle.model,
|
| 142 |
+
input_ids=input_ids,
|
| 143 |
+
token_ranges=mapping.token_ranges,
|
| 144 |
+
top_edges=top_edges,
|
| 145 |
+
baseline_token_nll=computation.token_nll,
|
| 146 |
+
top_k=validate_top_k,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
return AnalysisResult(
|
| 150 |
+
question=question,
|
| 151 |
+
model_name=resolved_model_name,
|
| 152 |
+
answer=generation.answer,
|
| 153 |
+
raw_trace_text=generation.raw_trace_text,
|
| 154 |
+
normalized_trace_text=truncated_text,
|
| 155 |
+
sentences=[span.text for span in sentence_spans],
|
| 156 |
+
sentence_token_ranges=mapping.token_ranges,
|
| 157 |
+
suppression_matrix=computation.matrix.tolist(),
|
| 158 |
+
raw_suppression_matrix=computation.raw_matrix.tolist(),
|
| 159 |
+
outgoing_importance=outgoing,
|
| 160 |
+
incoming_importance=incoming,
|
| 161 |
+
top_edges=top_edges,
|
| 162 |
+
runtime_metadata=computation.runtime_metadata,
|
| 163 |
+
validation_metadata=validation,
|
| 164 |
+
extra_metadata={
|
| 165 |
+
"raw_generation_text": generation.raw_generation_text,
|
| 166 |
+
"generation_metadata": generation.generation_metadata.model_dump(),
|
| 167 |
+
"effective_runtime": {
|
| 168 |
+
"device_preference": resolved_device,
|
| 169 |
+
"dtype_preference": resolved_dtype,
|
| 170 |
+
"attn_implementation": resolved_attn_implementation,
|
| 171 |
+
"trust_remote_code": resolved_trust_remote_code,
|
| 172 |
+
"low_cpu_mem_usage": resolved_low_cpu_mem_usage,
|
| 173 |
+
},
|
| 174 |
+
},
|
| 175 |
+
)
|
app/core/schemas.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TopEdge(BaseModel):
|
| 9 |
+
source_sentence_idx: int
|
| 10 |
+
target_sentence_idx: int
|
| 11 |
+
score: float
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ModelCapability(BaseModel):
|
| 15 |
+
supports_attribution: bool
|
| 16 |
+
reason: str | None = None
|
| 17 |
+
layer_path: str | None = None
|
| 18 |
+
attention_attr: str | None = None
|
| 19 |
+
layer_count: int = 0
|
| 20 |
+
attention_impl: str | None = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class RuntimeMetadata(BaseModel):
|
| 24 |
+
forward_pass_ms: float = 0.0
|
| 25 |
+
backward_pass_ms: float = 0.0
|
| 26 |
+
matrix_computation_ms: float = 0.0
|
| 27 |
+
total_analysis_ms: float = 0.0
|
| 28 |
+
num_layers: int = 0
|
| 29 |
+
num_heads: int = 0
|
| 30 |
+
sequence_length_tokens: int = 0
|
| 31 |
+
sentence_count: int = 0
|
| 32 |
+
device: str = "unknown"
|
| 33 |
+
dtype: str = "unknown"
|
| 34 |
+
attention_impl: str = "unknown"
|
| 35 |
+
max_trace_tokens: int = 0
|
| 36 |
+
max_sentences: int = 0
|
| 37 |
+
capability: ModelCapability = Field(default_factory=lambda: ModelCapability(supports_attribution=False))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ValidationMetadata(BaseModel):
|
| 41 |
+
enabled: bool = False
|
| 42 |
+
top_k: int = 0
|
| 43 |
+
pearson: float | None = None
|
| 44 |
+
spearman: float | None = None
|
| 45 |
+
compared_edges: list[TopEdge] = Field(default_factory=list)
|
| 46 |
+
notes: str | None = None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class GenerationMetadata(BaseModel):
|
| 50 |
+
max_new_tokens: int
|
| 51 |
+
temperature: float
|
| 52 |
+
top_p: float
|
| 53 |
+
do_sample: bool
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class GenerationResult(BaseModel):
|
| 57 |
+
question: str
|
| 58 |
+
model_name: str
|
| 59 |
+
answer: str
|
| 60 |
+
raw_generation_text: str
|
| 61 |
+
raw_trace_text: str
|
| 62 |
+
normalized_trace_text: str
|
| 63 |
+
generation_metadata: GenerationMetadata
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class AnalysisResult(BaseModel):
|
| 67 |
+
question: str
|
| 68 |
+
model_name: str
|
| 69 |
+
answer: str
|
| 70 |
+
raw_trace_text: str
|
| 71 |
+
normalized_trace_text: str
|
| 72 |
+
sentences: list[str]
|
| 73 |
+
sentence_token_ranges: list[tuple[int, int]]
|
| 74 |
+
suppression_matrix: list[list[float]]
|
| 75 |
+
raw_suppression_matrix: list[list[float]] | None = None
|
| 76 |
+
outgoing_importance: list[float]
|
| 77 |
+
incoming_importance: list[float]
|
| 78 |
+
top_edges: list[TopEdge]
|
| 79 |
+
runtime_metadata: RuntimeMetadata
|
| 80 |
+
validation_metadata: ValidationMetadata | None = None
|
| 81 |
+
extra_metadata: dict[str, Any] = Field(default_factory=dict)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class AnalysisRequest(BaseModel):
|
| 85 |
+
question: str
|
| 86 |
+
model_name: str | None = None
|
| 87 |
+
take_log: bool | None = None
|
| 88 |
+
max_sentences: int | None = None
|
| 89 |
+
max_trace_tokens: int | None = None
|
| 90 |
+
validate_top_k: int = 0
|
| 91 |
+
max_new_tokens: int = 256
|
| 92 |
+
temperature: float = 0.6
|
| 93 |
+
top_p: float = 0.95
|
| 94 |
+
device_preference: str | None = None
|
| 95 |
+
dtype_preference: str | None = None
|
| 96 |
+
attn_implementation: str | None = None
|
| 97 |
+
trust_remote_code: bool | None = None
|
| 98 |
+
low_cpu_mem_usage: bool | None = None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class HealthResponse(BaseModel):
|
| 102 |
+
status: str
|
| 103 |
+
model_name: str
|
| 104 |
+
device_preference: str
|
| 105 |
+
dtype_preference: str
|
| 106 |
+
preload_model: bool
|
| 107 |
+
cuda_available: bool
|
| 108 |
+
mps_available: bool
|
| 109 |
+
require_auth: bool
|
| 110 |
+
public_api_enabled: bool
|
| 111 |
+
max_queued_jobs: int
|
| 112 |
+
max_active_jobs_per_user: int
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class WarmupResponse(BaseModel):
|
| 116 |
+
status: str
|
| 117 |
+
model_name: str
|
| 118 |
+
device: str
|
| 119 |
+
dtype: str
|
| 120 |
+
capability: ModelCapability
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class CurrentUserResponse(BaseModel):
|
| 124 |
+
authenticated: bool
|
| 125 |
+
auth_required: bool
|
| 126 |
+
username: str | None = None
|
| 127 |
+
full_name: str | None = None
|
| 128 |
+
avatar_url: str | None = None
|
| 129 |
+
login_url: str = "/oauth/huggingface/login"
|
| 130 |
+
logout_url: str = "/oauth/huggingface/logout"
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class SessionCreateRequest(BaseModel):
|
| 134 |
+
question: str
|
| 135 |
+
model_name: str | None = None
|
| 136 |
+
take_log: bool | None = None
|
| 137 |
+
max_sentences: int | None = None
|
| 138 |
+
max_trace_tokens: int | None = None
|
| 139 |
+
validate_top_k: int = 0
|
| 140 |
+
max_new_tokens: int = 256
|
| 141 |
+
temperature: float = 0.6
|
| 142 |
+
top_p: float = 0.95
|
| 143 |
+
device_preference: str | None = None
|
| 144 |
+
dtype_preference: str | None = None
|
| 145 |
+
attn_implementation: str | None = None
|
| 146 |
+
trust_remote_code: bool | None = None
|
| 147 |
+
low_cpu_mem_usage: bool | None = None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class SessionResponse(BaseModel):
|
| 151 |
+
id: str
|
| 152 |
+
status: str
|
| 153 |
+
question: str
|
| 154 |
+
model_name: str
|
| 155 |
+
error: str | None = None
|
| 156 |
+
created_at: str
|
| 157 |
+
updated_at: str
|
| 158 |
+
answer: str | None = None
|
| 159 |
+
raw_trace_text: str | None = None
|
| 160 |
+
normalized_trace_text: str | None = None
|
| 161 |
+
sentences: list[str] | None = None
|
| 162 |
+
generation_metadata: dict[str, Any] | None = None
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class SessionResultResponse(BaseModel):
|
| 166 |
+
session: SessionResponse
|
| 167 |
+
analysis: AnalysisResult | None = None
|
app/frontend/app.js
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
const form = document.querySelector("#question-form");
|
| 2 |
+
const submitButton = document.querySelector("#submit-button");
|
| 3 |
+
const statusChip = document.querySelector("#status-chip");
|
| 4 |
+
const statusDetail = document.querySelector("#status-detail");
|
| 5 |
+
const authSummary = document.querySelector("#auth-summary");
|
| 6 |
+
const loginLink = document.querySelector("#login-link");
|
| 7 |
+
const logoutLink = document.querySelector("#logout-link");
|
| 8 |
+
const recentSessions = document.querySelector("#recent-sessions");
|
| 9 |
+
const answerOutput = document.querySelector("#answer-output");
|
| 10 |
+
const sessionIdOutput = document.querySelector("#session-id");
|
| 11 |
+
const sentenceList = document.querySelector("#sentence-list");
|
| 12 |
+
const traceCount = document.querySelector("#trace-count");
|
| 13 |
+
const heatmap = document.querySelector("#heatmap");
|
| 14 |
+
const selectionDetails = document.querySelector("#selection-details");
|
| 15 |
+
const metrics = document.querySelector("#metrics");
|
| 16 |
+
const topEdges = document.querySelector("#top-edges");
|
| 17 |
+
const exportJson = document.querySelector("#export-json");
|
| 18 |
+
const exportCsv = document.querySelector("#export-csv");
|
| 19 |
+
|
| 20 |
+
let activeSessionId = null;
|
| 21 |
+
let pollHandle = null;
|
| 22 |
+
let activePayload = null;
|
| 23 |
+
let selectedSentenceIdx = null;
|
| 24 |
+
let selectedCell = null;
|
| 25 |
+
let currentUser = null;
|
| 26 |
+
|
| 27 |
+
function currentSentences() {
|
| 28 |
+
return activePayload?.analysis?.sentences || activePayload?.session?.sentences || [];
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
function setStatus(label, detail) {
|
| 32 |
+
statusChip.textContent = label;
|
| 33 |
+
statusDetail.textContent = detail;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
function clearSelection() {
|
| 37 |
+
selectedSentenceIdx = null;
|
| 38 |
+
selectedCell = null;
|
| 39 |
+
selectionDetails.textContent = "Choose a sentence or heatmap cell.";
|
| 40 |
+
metrics.innerHTML = "";
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
function setExportLinks(sessionId, enabled) {
|
| 44 |
+
exportJson.href = enabled ? `/api/sessions/${sessionId}/export.json` : "#";
|
| 45 |
+
exportCsv.href = enabled ? `/api/sessions/${sessionId}/export.csv` : "#";
|
| 46 |
+
exportJson.classList.toggle("is-disabled", !enabled);
|
| 47 |
+
exportCsv.classList.toggle("is-disabled", !enabled);
|
| 48 |
+
exportJson.setAttribute("aria-disabled", String(!enabled));
|
| 49 |
+
exportCsv.setAttribute("aria-disabled", String(!enabled));
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
function renderRecentSessions(sessions) {
|
| 53 |
+
recentSessions.innerHTML = "";
|
| 54 |
+
if (!currentUser?.authenticated) {
|
| 55 |
+
recentSessions.textContent = currentUser?.auth_required
|
| 56 |
+
? "Sign in to view your jobs."
|
| 57 |
+
: "Anonymous mode is enabled.";
|
| 58 |
+
return;
|
| 59 |
+
}
|
| 60 |
+
if (!sessions.length) {
|
| 61 |
+
recentSessions.textContent = "No jobs yet on this running Space instance.";
|
| 62 |
+
return;
|
| 63 |
+
}
|
| 64 |
+
sessions.forEach((session) => {
|
| 65 |
+
const item = document.createElement("button");
|
| 66 |
+
item.type = "button";
|
| 67 |
+
item.className = "session-card";
|
| 68 |
+
if (session.id === activeSessionId) {
|
| 69 |
+
item.classList.add("is-active");
|
| 70 |
+
}
|
| 71 |
+
item.innerHTML = `<strong>${session.status}</strong>${session.question}`;
|
| 72 |
+
item.addEventListener("click", () => {
|
| 73 |
+
startPolling(session.id);
|
| 74 |
+
});
|
| 75 |
+
recentSessions.appendChild(item);
|
| 76 |
+
});
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
async function loadRecentSessions() {
|
| 80 |
+
if (!currentUser?.authenticated && currentUser?.auth_required) {
|
| 81 |
+
renderRecentSessions([]);
|
| 82 |
+
return;
|
| 83 |
+
}
|
| 84 |
+
const response = await fetch("/api/sessions?limit=8");
|
| 85 |
+
if (!response.ok) {
|
| 86 |
+
renderRecentSessions([]);
|
| 87 |
+
return;
|
| 88 |
+
}
|
| 89 |
+
renderRecentSessions(await response.json());
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
function renderAuth() {
|
| 93 |
+
if (!currentUser) {
|
| 94 |
+
authSummary.textContent = "Checking sign-in status.";
|
| 95 |
+
return;
|
| 96 |
+
}
|
| 97 |
+
loginLink.hidden = currentUser.authenticated;
|
| 98 |
+
logoutLink.hidden = !currentUser.authenticated;
|
| 99 |
+
if (currentUser.authenticated) {
|
| 100 |
+
authSummary.textContent = `Signed in as ${currentUser.full_name || currentUser.username}.`;
|
| 101 |
+
return;
|
| 102 |
+
}
|
| 103 |
+
authSummary.textContent = currentUser.auth_required
|
| 104 |
+
? "Sign in with Hugging Face to run analysis jobs."
|
| 105 |
+
: "Anonymous access is enabled for this Space.";
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
function renderSession(session) {
|
| 109 |
+
sessionIdOutput.textContent = session.id ? `Session ${session.id.slice(0, 8)}` : "";
|
| 110 |
+
answerOutput.textContent = session.answer || "Waiting for generated answer.";
|
| 111 |
+
const sentences = currentSentences();
|
| 112 |
+
traceCount.textContent = sentences.length ? `${sentences.length} sentences` : "";
|
| 113 |
+
|
| 114 |
+
sentenceList.innerHTML = "";
|
| 115 |
+
sentences.forEach((sentence, index) => {
|
| 116 |
+
const item = document.createElement("li");
|
| 117 |
+
item.className = "sentence-item";
|
| 118 |
+
item.innerHTML = `<strong>${index}</strong> ${sentence}`;
|
| 119 |
+
item.addEventListener("click", () => {
|
| 120 |
+
selectedSentenceIdx = index;
|
| 121 |
+
selectedCell = null;
|
| 122 |
+
renderSelection();
|
| 123 |
+
});
|
| 124 |
+
sentenceList.appendChild(item);
|
| 125 |
+
});
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
function colorForValue(value, maxAbs) {
|
| 129 |
+
if (!maxAbs) {
|
| 130 |
+
return "rgba(224, 223, 218, 0.8)";
|
| 131 |
+
}
|
| 132 |
+
const normalized = Math.max(-1, Math.min(1, value / maxAbs));
|
| 133 |
+
if (normalized >= 0) {
|
| 134 |
+
const alpha = 0.12 + normalized * 0.88;
|
| 135 |
+
return `rgba(14, 90, 138, ${alpha.toFixed(3)})`;
|
| 136 |
+
}
|
| 137 |
+
const alpha = 0.12 + Math.abs(normalized) * 0.88;
|
| 138 |
+
return `rgba(215, 106, 52, ${alpha.toFixed(3)})`;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
function renderHeatmap(result) {
|
| 142 |
+
const matrix = result?.suppression_matrix;
|
| 143 |
+
if (!matrix || !matrix.length) {
|
| 144 |
+
heatmap.className = "heatmap placeholder-box";
|
| 145 |
+
heatmap.textContent = "Analysis pending.";
|
| 146 |
+
return;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
const flatValues = matrix.flat();
|
| 150 |
+
const maxAbs = Math.max(...flatValues.map((value) => Math.abs(value)));
|
| 151 |
+
heatmap.className = "heatmap";
|
| 152 |
+
heatmap.innerHTML = "";
|
| 153 |
+
const grid = document.createElement("div");
|
| 154 |
+
grid.className = "heatmap-grid";
|
| 155 |
+
grid.style.gridTemplateColumns = `repeat(${matrix.length}, 32px)`;
|
| 156 |
+
|
| 157 |
+
matrix.forEach((row, rowIndex) => {
|
| 158 |
+
row.forEach((value, colIndex) => {
|
| 159 |
+
const cell = document.createElement("button");
|
| 160 |
+
cell.type = "button";
|
| 161 |
+
cell.className = "heatmap-cell";
|
| 162 |
+
cell.style.background = colorForValue(value, maxAbs);
|
| 163 |
+
cell.title = `target ${rowIndex} ← source ${colIndex}: ${value.toFixed(4)}`;
|
| 164 |
+
if (selectedCell && selectedCell.row === rowIndex && selectedCell.col === colIndex) {
|
| 165 |
+
cell.classList.add("is-selected");
|
| 166 |
+
}
|
| 167 |
+
cell.addEventListener("click", () => {
|
| 168 |
+
selectedSentenceIdx = null;
|
| 169 |
+
selectedCell = { row: rowIndex, col: colIndex, value };
|
| 170 |
+
renderSelection();
|
| 171 |
+
});
|
| 172 |
+
grid.appendChild(cell);
|
| 173 |
+
});
|
| 174 |
+
});
|
| 175 |
+
heatmap.appendChild(grid);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
function renderTopEdges(result) {
|
| 179 |
+
topEdges.innerHTML = "";
|
| 180 |
+
const edges = result?.top_edges || [];
|
| 181 |
+
if (!edges.length) {
|
| 182 |
+
return;
|
| 183 |
+
}
|
| 184 |
+
edges.slice(0, 5).forEach((edge) => {
|
| 185 |
+
const item = document.createElement("div");
|
| 186 |
+
item.className = "edge-card";
|
| 187 |
+
item.innerHTML = `<strong>${edge.source_sentence_idx} → ${edge.target_sentence_idx}</strong>${edge.score.toFixed(4)}`;
|
| 188 |
+
topEdges.appendChild(item);
|
| 189 |
+
});
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
function renderSelection() {
|
| 193 |
+
Array.from(sentenceList.children).forEach((item, index) => {
|
| 194 |
+
item.classList.toggle("is-active", selectedSentenceIdx === index);
|
| 195 |
+
});
|
| 196 |
+
|
| 197 |
+
const result = activePayload?.analysis;
|
| 198 |
+
const session = activePayload?.session;
|
| 199 |
+
if (!session) {
|
| 200 |
+
clearSelection();
|
| 201 |
+
return;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
metrics.innerHTML = "";
|
| 205 |
+
if (selectedSentenceIdx != null && result) {
|
| 206 |
+
const outgoing = result.outgoing_importance[selectedSentenceIdx] ?? 0;
|
| 207 |
+
const incoming = result.incoming_importance[selectedSentenceIdx] ?? 0;
|
| 208 |
+
selectionDetails.innerHTML = `<strong>Sentence ${selectedSentenceIdx}</strong><br>${currentSentences()[selectedSentenceIdx] || ""}`;
|
| 209 |
+
metrics.innerHTML = `
|
| 210 |
+
<div class="metric-card"><strong>Outgoing impact</strong>${outgoing.toFixed(4)}</div>
|
| 211 |
+
<div class="metric-card"><strong>Incoming dependence</strong>${incoming.toFixed(4)}</div>
|
| 212 |
+
`;
|
| 213 |
+
return;
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
if (selectedCell) {
|
| 217 |
+
selectionDetails.innerHTML = `<strong>Edge ${selectedCell.col} → ${selectedCell.row}</strong><br>Influence score ${selectedCell.value.toFixed(4)}`;
|
| 218 |
+
metrics.innerHTML = `
|
| 219 |
+
<div class="metric-card"><strong>Source sentence</strong>${selectedCell.col}</div>
|
| 220 |
+
<div class="metric-card"><strong>Target sentence</strong>${selectedCell.row}</div>
|
| 221 |
+
`;
|
| 222 |
+
return;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
selectionDetails.textContent = "Choose a sentence or heatmap cell.";
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
async function fetchSession(sessionId) {
|
| 229 |
+
const response = await fetch(`/api/sessions/${sessionId}/result`);
|
| 230 |
+
if (!response.ok) {
|
| 231 |
+
throw new Error(`Failed to fetch session ${sessionId}`);
|
| 232 |
+
}
|
| 233 |
+
return response.json();
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
function updateFromPayload(payload) {
|
| 237 |
+
activePayload = payload;
|
| 238 |
+
activeSessionId = payload.session.id;
|
| 239 |
+
renderSession(payload.session);
|
| 240 |
+
renderHeatmap(payload.analysis);
|
| 241 |
+
renderTopEdges(payload.analysis);
|
| 242 |
+
renderSelection();
|
| 243 |
+
setExportLinks(payload.session.id, payload.session.status === "completed");
|
| 244 |
+
|
| 245 |
+
const { status, error } = payload.session;
|
| 246 |
+
if (status === "queued") setStatus("Queued", "Waiting for a worker slot.");
|
| 247 |
+
if (status === "generating") setStatus("Generating", "The model is producing an answer and visible trace.");
|
| 248 |
+
if (status === "answer_ready") setStatus("Analysis pending", "Answer is ready. Attribution analysis is starting.");
|
| 249 |
+
if (status === "analyzing") setStatus("Analyzing", "Running forward and backward passes for sentence influence.");
|
| 250 |
+
if (status === "completed") setStatus("Completed", "Analysis finished.");
|
| 251 |
+
if (status === "failed") setStatus("Failed", error || "The session failed.");
|
| 252 |
+
|
| 253 |
+
if (["completed", "failed"].includes(status) && pollHandle) {
|
| 254 |
+
window.clearInterval(pollHandle);
|
| 255 |
+
pollHandle = null;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
loadRecentSessions().catch(() => {});
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
async function startPolling(sessionId) {
|
| 262 |
+
activeSessionId = sessionId;
|
| 263 |
+
if (pollHandle) {
|
| 264 |
+
window.clearInterval(pollHandle);
|
| 265 |
+
}
|
| 266 |
+
const tick = async () => {
|
| 267 |
+
try {
|
| 268 |
+
const payload = await fetchSession(sessionId);
|
| 269 |
+
updateFromPayload(payload);
|
| 270 |
+
} catch (error) {
|
| 271 |
+
setStatus("Error", String(error));
|
| 272 |
+
window.clearInterval(pollHandle);
|
| 273 |
+
pollHandle = null;
|
| 274 |
+
}
|
| 275 |
+
};
|
| 276 |
+
await tick();
|
| 277 |
+
pollHandle = window.setInterval(tick, 2500);
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
form.addEventListener("submit", async (event) => {
|
| 281 |
+
event.preventDefault();
|
| 282 |
+
if (!currentUser?.authenticated && currentUser?.auth_required) {
|
| 283 |
+
window.location.href = loginLink.href;
|
| 284 |
+
return;
|
| 285 |
+
}
|
| 286 |
+
submitButton.disabled = true;
|
| 287 |
+
clearSelection();
|
| 288 |
+
sentenceList.innerHTML = "";
|
| 289 |
+
heatmap.className = "heatmap placeholder-box";
|
| 290 |
+
heatmap.textContent = "Analysis pending.";
|
| 291 |
+
topEdges.innerHTML = "";
|
| 292 |
+
answerOutput.textContent = "Waiting for generated answer.";
|
| 293 |
+
sessionIdOutput.textContent = "";
|
| 294 |
+
traceCount.textContent = "";
|
| 295 |
+
setExportLinks("", false);
|
| 296 |
+
setStatus("Submitting", "Creating session.");
|
| 297 |
+
|
| 298 |
+
const payload = {
|
| 299 |
+
question: document.querySelector("#question").value,
|
| 300 |
+
max_new_tokens: Number(document.querySelector("#max-new-tokens").value),
|
| 301 |
+
max_trace_tokens: Number(document.querySelector("#max-trace-tokens").value),
|
| 302 |
+
max_sentences: Number(document.querySelector("#max-sentences").value),
|
| 303 |
+
};
|
| 304 |
+
|
| 305 |
+
try {
|
| 306 |
+
const response = await fetch("/api/sessions", {
|
| 307 |
+
method: "POST",
|
| 308 |
+
headers: { "content-type": "application/json" },
|
| 309 |
+
body: JSON.stringify(payload),
|
| 310 |
+
});
|
| 311 |
+
if (!response.ok) {
|
| 312 |
+
throw new Error(await response.text());
|
| 313 |
+
}
|
| 314 |
+
const session = await response.json();
|
| 315 |
+
await startPolling(session.id);
|
| 316 |
+
} catch (error) {
|
| 317 |
+
setStatus("Error", String(error));
|
| 318 |
+
} finally {
|
| 319 |
+
submitButton.disabled = false;
|
| 320 |
+
}
|
| 321 |
+
});
|
| 322 |
+
|
| 323 |
+
async function initialize() {
|
| 324 |
+
setExportLinks("", false);
|
| 325 |
+
try {
|
| 326 |
+
const response = await fetch("/api/me");
|
| 327 |
+
currentUser = await response.json();
|
| 328 |
+
} catch (_error) {
|
| 329 |
+
currentUser = {
|
| 330 |
+
authenticated: false,
|
| 331 |
+
auth_required: true,
|
| 332 |
+
login_url: "/oauth/huggingface/login",
|
| 333 |
+
logout_url: "/oauth/huggingface/logout",
|
| 334 |
+
};
|
| 335 |
+
}
|
| 336 |
+
renderAuth();
|
| 337 |
+
await loadRecentSessions();
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
initialize();
|
app/frontend/index.html
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8" />
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
| 6 |
+
<title>Thought Anchors</title>
|
| 7 |
+
<link rel="stylesheet" href="/static/styles.css" />
|
| 8 |
+
</head>
|
| 9 |
+
<body>
|
| 10 |
+
<main class="page">
|
| 11 |
+
<section class="auth-strip panel">
|
| 12 |
+
<div>
|
| 13 |
+
<p class="eyebrow">Public Hugging Face Space</p>
|
| 14 |
+
<p id="auth-summary" class="lede compact">Checking sign-in status.</p>
|
| 15 |
+
</div>
|
| 16 |
+
<div class="auth-actions">
|
| 17 |
+
<a id="login-link" class="pill-link" href="/oauth/huggingface/login">Sign in with Hugging Face</a>
|
| 18 |
+
<a id="logout-link" class="pill-link secondary" href="/oauth/huggingface/logout">Sign out</a>
|
| 19 |
+
</div>
|
| 20 |
+
</section>
|
| 21 |
+
|
| 22 |
+
<section class="hero">
|
| 23 |
+
<p class="eyebrow">Online Chain-of-Thought Analysis</p>
|
| 24 |
+
<h1>Reasoning traces with white-box sentence influence.</h1>
|
| 25 |
+
<p class="lede">
|
| 26 |
+
Submit a question, watch the answer appear first, then inspect the sentence-to-sentence
|
| 27 |
+
attribution matrix once analysis completes. Results can be exported as JSON or CSV.
|
| 28 |
+
</p>
|
| 29 |
+
</section>
|
| 30 |
+
|
| 31 |
+
<section class="panel composer">
|
| 32 |
+
<form id="question-form">
|
| 33 |
+
<label class="label" for="question">Question</label>
|
| 34 |
+
<textarea id="question" name="question" rows="5" placeholder="Explain why the derivative of x^2 is 2x" required></textarea>
|
| 35 |
+
<div class="controls">
|
| 36 |
+
<label>
|
| 37 |
+
<span>Max new tokens</span>
|
| 38 |
+
<input id="max-new-tokens" type="number" min="16" max="512" value="128" />
|
| 39 |
+
</label>
|
| 40 |
+
<label>
|
| 41 |
+
<span>Max trace tokens</span>
|
| 42 |
+
<input id="max-trace-tokens" type="number" min="64" max="1024" value="256" />
|
| 43 |
+
</label>
|
| 44 |
+
<label>
|
| 45 |
+
<span>Max sentences</span>
|
| 46 |
+
<input id="max-sentences" type="number" min="4" max="40" value="16" />
|
| 47 |
+
</label>
|
| 48 |
+
</div>
|
| 49 |
+
<button id="submit-button" type="submit">Start analysis</button>
|
| 50 |
+
</form>
|
| 51 |
+
</section>
|
| 52 |
+
|
| 53 |
+
<section class="status-row">
|
| 54 |
+
<div class="status-chip" id="status-chip">Idle</div>
|
| 55 |
+
<div class="status-detail" id="status-detail">Submit a question to create a session.</div>
|
| 56 |
+
</section>
|
| 57 |
+
|
| 58 |
+
<section class="grid">
|
| 59 |
+
<article class="panel jobs-panel">
|
| 60 |
+
<div class="panel-heading">
|
| 61 |
+
<h2>Your Sessions</h2>
|
| 62 |
+
<span class="muted">Ephemeral instance history</span>
|
| 63 |
+
</div>
|
| 64 |
+
<div id="recent-sessions" class="recent-sessions placeholder">Sign in to view your jobs.</div>
|
| 65 |
+
</article>
|
| 66 |
+
|
| 67 |
+
<article class="panel answer-panel">
|
| 68 |
+
<div class="panel-heading">
|
| 69 |
+
<h2>Answer</h2>
|
| 70 |
+
<span id="session-id" class="muted"></span>
|
| 71 |
+
</div>
|
| 72 |
+
<p id="answer-output" class="placeholder">No answer yet.</p>
|
| 73 |
+
</article>
|
| 74 |
+
|
| 75 |
+
<article class="panel trace-panel">
|
| 76 |
+
<div class="panel-heading">
|
| 77 |
+
<h2>Reasoning Trace</h2>
|
| 78 |
+
<span class="muted" id="trace-count"></span>
|
| 79 |
+
</div>
|
| 80 |
+
<ol id="sentence-list" class="sentence-list"></ol>
|
| 81 |
+
</article>
|
| 82 |
+
|
| 83 |
+
<article class="panel heatmap-panel">
|
| 84 |
+
<div class="panel-heading">
|
| 85 |
+
<h2>Sentence Influence Matrix</h2>
|
| 86 |
+
<span class="muted">Rows: targets, columns: sources</span>
|
| 87 |
+
</div>
|
| 88 |
+
<div id="heatmap" class="heatmap placeholder-box">Analysis pending.</div>
|
| 89 |
+
</article>
|
| 90 |
+
|
| 91 |
+
<article class="panel details-panel">
|
| 92 |
+
<div class="panel-heading">
|
| 93 |
+
<h2>Selection</h2>
|
| 94 |
+
</div>
|
| 95 |
+
<div id="selection-details" class="placeholder">Choose a sentence or heatmap cell.</div>
|
| 96 |
+
<div class="metrics" id="metrics"></div>
|
| 97 |
+
<div class="edges" id="top-edges"></div>
|
| 98 |
+
<div class="export-actions">
|
| 99 |
+
<a id="export-json" class="pill-link is-disabled" href="#" aria-disabled="true">Export JSON</a>
|
| 100 |
+
<a id="export-csv" class="pill-link secondary is-disabled" href="#" aria-disabled="true">Export CSV</a>
|
| 101 |
+
</div>
|
| 102 |
+
</article>
|
| 103 |
+
</section>
|
| 104 |
+
</main>
|
| 105 |
+
|
| 106 |
+
<script src="/static/app.js" defer></script>
|
| 107 |
+
</body>
|
| 108 |
+
</html>
|
app/frontend/styles.css
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
:root {
|
| 2 |
+
--bg: #f3efe6;
|
| 3 |
+
--panel: rgba(255, 252, 247, 0.9);
|
| 4 |
+
--ink: #1e1c19;
|
| 5 |
+
--muted: #6f655c;
|
| 6 |
+
--accent: #0e5a8a;
|
| 7 |
+
--accent-soft: #d2ecff;
|
| 8 |
+
--line: rgba(30, 28, 25, 0.12);
|
| 9 |
+
--warm: #d76a34;
|
| 10 |
+
--shadow: 0 20px 60px rgba(34, 25, 16, 0.08);
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
* {
|
| 14 |
+
box-sizing: border-box;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
body {
|
| 18 |
+
margin: 0;
|
| 19 |
+
font-family: "Iowan Old Style", "Palatino Linotype", serif;
|
| 20 |
+
color: var(--ink);
|
| 21 |
+
background:
|
| 22 |
+
radial-gradient(circle at top left, rgba(14, 90, 138, 0.12), transparent 32%),
|
| 23 |
+
radial-gradient(circle at top right, rgba(215, 106, 52, 0.18), transparent 26%),
|
| 24 |
+
linear-gradient(180deg, #f8f4ec 0%, var(--bg) 100%);
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
.page {
|
| 28 |
+
max-width: 1440px;
|
| 29 |
+
margin: 0 auto;
|
| 30 |
+
padding: 40px 24px 48px;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
.hero {
|
| 34 |
+
max-width: 760px;
|
| 35 |
+
margin-bottom: 28px;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
.compact {
|
| 39 |
+
margin: 0;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
.auth-strip {
|
| 43 |
+
display: flex;
|
| 44 |
+
align-items: center;
|
| 45 |
+
justify-content: space-between;
|
| 46 |
+
gap: 16px;
|
| 47 |
+
margin-bottom: 18px;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
.auth-actions,
|
| 51 |
+
.export-actions {
|
| 52 |
+
display: flex;
|
| 53 |
+
flex-wrap: wrap;
|
| 54 |
+
gap: 10px;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
.eyebrow {
|
| 58 |
+
margin: 0 0 8px;
|
| 59 |
+
text-transform: uppercase;
|
| 60 |
+
letter-spacing: 0.12em;
|
| 61 |
+
font-size: 0.78rem;
|
| 62 |
+
color: var(--accent);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
h1, h2 {
|
| 66 |
+
margin: 0;
|
| 67 |
+
font-weight: 600;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
h1 {
|
| 71 |
+
font-size: clamp(2.4rem, 5vw, 4.4rem);
|
| 72 |
+
line-height: 0.95;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
.lede {
|
| 76 |
+
margin: 18px 0 0;
|
| 77 |
+
font-size: 1.08rem;
|
| 78 |
+
color: var(--muted);
|
| 79 |
+
max-width: 60ch;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
.panel {
|
| 83 |
+
background: var(--panel);
|
| 84 |
+
border: 1px solid var(--line);
|
| 85 |
+
border-radius: 24px;
|
| 86 |
+
padding: 20px;
|
| 87 |
+
box-shadow: var(--shadow);
|
| 88 |
+
backdrop-filter: blur(18px);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
.composer {
|
| 92 |
+
margin-bottom: 16px;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
.label,
|
| 96 |
+
.controls span {
|
| 97 |
+
display: block;
|
| 98 |
+
margin-bottom: 8px;
|
| 99 |
+
font-size: 0.92rem;
|
| 100 |
+
color: var(--muted);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
textarea,
|
| 104 |
+
input {
|
| 105 |
+
width: 100%;
|
| 106 |
+
border: 1px solid rgba(30, 28, 25, 0.16);
|
| 107 |
+
border-radius: 14px;
|
| 108 |
+
padding: 14px 16px;
|
| 109 |
+
font: inherit;
|
| 110 |
+
background: rgba(255, 255, 255, 0.72);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
textarea {
|
| 114 |
+
resize: vertical;
|
| 115 |
+
min-height: 132px;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
.controls {
|
| 119 |
+
display: grid;
|
| 120 |
+
grid-template-columns: repeat(3, minmax(0, 1fr));
|
| 121 |
+
gap: 12px;
|
| 122 |
+
margin: 16px 0;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
button {
|
| 126 |
+
appearance: none;
|
| 127 |
+
border: none;
|
| 128 |
+
border-radius: 999px;
|
| 129 |
+
padding: 14px 20px;
|
| 130 |
+
font: inherit;
|
| 131 |
+
font-weight: 600;
|
| 132 |
+
color: white;
|
| 133 |
+
background: linear-gradient(135deg, var(--accent), #123954);
|
| 134 |
+
cursor: pointer;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
.pill-link {
|
| 138 |
+
display: inline-flex;
|
| 139 |
+
align-items: center;
|
| 140 |
+
justify-content: center;
|
| 141 |
+
border-radius: 999px;
|
| 142 |
+
padding: 12px 18px;
|
| 143 |
+
text-decoration: none;
|
| 144 |
+
font-weight: 600;
|
| 145 |
+
color: white;
|
| 146 |
+
background: linear-gradient(135deg, var(--accent), #123954);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
.pill-link.secondary {
|
| 150 |
+
color: var(--ink);
|
| 151 |
+
background: rgba(255, 255, 255, 0.78);
|
| 152 |
+
border: 1px solid var(--line);
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
.pill-link.is-disabled {
|
| 156 |
+
pointer-events: none;
|
| 157 |
+
opacity: 0.45;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
button:disabled {
|
| 161 |
+
opacity: 0.5;
|
| 162 |
+
cursor: wait;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
.status-row {
|
| 166 |
+
display: flex;
|
| 167 |
+
align-items: center;
|
| 168 |
+
gap: 12px;
|
| 169 |
+
margin-bottom: 16px;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
.status-chip {
|
| 173 |
+
display: inline-flex;
|
| 174 |
+
align-items: center;
|
| 175 |
+
border-radius: 999px;
|
| 176 |
+
padding: 8px 14px;
|
| 177 |
+
background: var(--accent-soft);
|
| 178 |
+
color: var(--accent);
|
| 179 |
+
font-size: 0.9rem;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
.status-detail,
|
| 183 |
+
.muted,
|
| 184 |
+
.placeholder {
|
| 185 |
+
color: var(--muted);
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
.grid {
|
| 189 |
+
display: grid;
|
| 190 |
+
grid-template-columns: 1.2fr 1fr;
|
| 191 |
+
gap: 16px;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
.jobs-panel {
|
| 195 |
+
grid-column: 1 / -1;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
.answer-panel,
|
| 199 |
+
.trace-panel,
|
| 200 |
+
.heatmap-panel,
|
| 201 |
+
.details-panel {
|
| 202 |
+
min-height: 260px;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
.answer-panel,
|
| 206 |
+
.trace-panel {
|
| 207 |
+
grid-column: span 1;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
.heatmap-panel,
|
| 211 |
+
.details-panel {
|
| 212 |
+
grid-column: span 1;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.panel-heading {
|
| 216 |
+
display: flex;
|
| 217 |
+
justify-content: space-between;
|
| 218 |
+
gap: 12px;
|
| 219 |
+
align-items: baseline;
|
| 220 |
+
margin-bottom: 16px;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
.sentence-list {
|
| 224 |
+
list-style: none;
|
| 225 |
+
margin: 0;
|
| 226 |
+
padding: 0;
|
| 227 |
+
max-height: 420px;
|
| 228 |
+
overflow: auto;
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
.sentence-item {
|
| 232 |
+
padding: 10px 12px;
|
| 233 |
+
border-radius: 14px;
|
| 234 |
+
border: 1px solid transparent;
|
| 235 |
+
margin-bottom: 8px;
|
| 236 |
+
background: rgba(255,255,255,0.45);
|
| 237 |
+
cursor: pointer;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
.sentence-item:hover,
|
| 241 |
+
.sentence-item.is-active {
|
| 242 |
+
border-color: rgba(14, 90, 138, 0.25);
|
| 243 |
+
background: rgba(210, 236, 255, 0.5);
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
.sentence-item strong {
|
| 247 |
+
display: inline-block;
|
| 248 |
+
min-width: 2.5rem;
|
| 249 |
+
color: var(--accent);
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
.heatmap {
|
| 253 |
+
display: grid;
|
| 254 |
+
gap: 4px;
|
| 255 |
+
overflow: auto;
|
| 256 |
+
min-height: 300px;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
.heatmap-grid {
|
| 260 |
+
display: grid;
|
| 261 |
+
gap: 4px;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
.heatmap-cell {
|
| 265 |
+
width: 32px;
|
| 266 |
+
height: 32px;
|
| 267 |
+
border-radius: 8px;
|
| 268 |
+
border: 1px solid rgba(255,255,255,0.4);
|
| 269 |
+
cursor: pointer;
|
| 270 |
+
position: relative;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
.heatmap-cell.is-selected {
|
| 274 |
+
outline: 2px solid var(--ink);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
.placeholder-box {
|
| 278 |
+
display: grid;
|
| 279 |
+
place-items: center;
|
| 280 |
+
border: 1px dashed var(--line);
|
| 281 |
+
border-radius: 18px;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
.metric-card,
|
| 285 |
+
.edge-card {
|
| 286 |
+
padding: 12px 14px;
|
| 287 |
+
border-radius: 16px;
|
| 288 |
+
background: rgba(255,255,255,0.45);
|
| 289 |
+
border: 1px solid var(--line);
|
| 290 |
+
margin-top: 10px;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
.recent-sessions {
|
| 294 |
+
display: grid;
|
| 295 |
+
gap: 10px;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
.session-card {
|
| 299 |
+
color: var(--ink);
|
| 300 |
+
border: 1px solid var(--line);
|
| 301 |
+
border-radius: 16px;
|
| 302 |
+
background: rgba(255, 255, 255, 0.48);
|
| 303 |
+
padding: 14px;
|
| 304 |
+
cursor: pointer;
|
| 305 |
+
text-align: left;
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
.session-card.is-active {
|
| 309 |
+
border-color: rgba(14, 90, 138, 0.25);
|
| 310 |
+
background: rgba(210, 236, 255, 0.42);
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
.session-card strong {
|
| 314 |
+
display: block;
|
| 315 |
+
margin-bottom: 4px;
|
| 316 |
+
color: var(--accent);
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
.metric-card strong,
|
| 320 |
+
.edge-card strong {
|
| 321 |
+
display: block;
|
| 322 |
+
color: var(--accent);
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
@media (max-width: 960px) {
|
| 326 |
+
.auth-strip,
|
| 327 |
+
.grid,
|
| 328 |
+
.controls {
|
| 329 |
+
grid-template-columns: 1fr;
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
.auth-strip {
|
| 333 |
+
align-items: flex-start;
|
| 334 |
+
}
|
| 335 |
+
}
|
app/generation/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Generation utilities for trace-producing model calls."""
|
app/generation/prompting.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
SYSTEM_PROMPT = (
|
| 7 |
+
"You are a careful reasoning assistant. Respond with your full reasoning inside "
|
| 8 |
+
"<think>...</think> and then provide a concise final answer after the closing tag."
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def build_messages(question: str) -> list[dict[str, str]]:
|
| 13 |
+
return [
|
| 14 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 15 |
+
{"role": "user", "content": question},
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def render_prompt(tokenizer: Any, question: str) -> str:
|
| 20 |
+
messages = build_messages(question)
|
| 21 |
+
if hasattr(tokenizer, "apply_chat_template"):
|
| 22 |
+
return tokenizer.apply_chat_template(
|
| 23 |
+
messages,
|
| 24 |
+
tokenize=False,
|
| 25 |
+
add_generation_prompt=True,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
return f"{SYSTEM_PROMPT}\n\nQuestion: {question}\n"
|
app/generation/service.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from app.analysis.sentence_split import normalize_trace_text
|
| 9 |
+
from app.core.schemas import GenerationMetadata, GenerationResult
|
| 10 |
+
from app.generation.prompting import render_prompt
|
| 11 |
+
|
| 12 |
+
THINK_BLOCK_RE = re.compile(r"<think>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
| 13 |
+
ANSWER_MARKER_RE = re.compile(r"(?:^|\n)(?:final answer|answer)\s*:\s*", re.IGNORECASE)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _extract_trace_and_answer(text: str) -> tuple[str, str]:
|
| 17 |
+
match = THINK_BLOCK_RE.search(text)
|
| 18 |
+
if match:
|
| 19 |
+
raw_trace = match.group(0)
|
| 20 |
+
answer = text[match.end() :].strip()
|
| 21 |
+
if not answer:
|
| 22 |
+
answer = match.group(1).strip()
|
| 23 |
+
return raw_trace, answer
|
| 24 |
+
|
| 25 |
+
raw_trace = text.strip()
|
| 26 |
+
answer_match = ANSWER_MARKER_RE.search(text)
|
| 27 |
+
if answer_match:
|
| 28 |
+
answer = text[answer_match.end() :].strip()
|
| 29 |
+
else:
|
| 30 |
+
paragraphs = [part.strip() for part in text.split("\n\n") if part.strip()]
|
| 31 |
+
answer = paragraphs[-1] if paragraphs else raw_trace
|
| 32 |
+
return raw_trace, answer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def generate_answer_and_trace(
|
| 36 |
+
*,
|
| 37 |
+
question: str,
|
| 38 |
+
model_name: str,
|
| 39 |
+
model: Any,
|
| 40 |
+
tokenizer: Any,
|
| 41 |
+
max_new_tokens: int = 512,
|
| 42 |
+
temperature: float = 0.6,
|
| 43 |
+
top_p: float = 0.95,
|
| 44 |
+
) -> GenerationResult:
|
| 45 |
+
prompt_text = render_prompt(tokenizer, question)
|
| 46 |
+
encoded = tokenizer(prompt_text, return_tensors="pt")
|
| 47 |
+
model_device = next(model.parameters()).device
|
| 48 |
+
encoded = {key: value.to(model_device) for key, value in encoded.items()}
|
| 49 |
+
input_length = int(encoded["input_ids"].shape[-1])
|
| 50 |
+
do_sample = temperature > 0.0
|
| 51 |
+
|
| 52 |
+
generation_kwargs: dict[str, Any] = {
|
| 53 |
+
"max_new_tokens": max_new_tokens,
|
| 54 |
+
"do_sample": do_sample,
|
| 55 |
+
"top_p": top_p,
|
| 56 |
+
"pad_token_id": tokenizer.pad_token_id,
|
| 57 |
+
"eos_token_id": tokenizer.eos_token_id,
|
| 58 |
+
}
|
| 59 |
+
if do_sample:
|
| 60 |
+
generation_kwargs["temperature"] = temperature
|
| 61 |
+
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
output_ids = model.generate(**encoded, **generation_kwargs)
|
| 64 |
+
|
| 65 |
+
generated_ids = output_ids[0, input_length:]
|
| 66 |
+
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 67 |
+
raw_trace_text, answer = _extract_trace_and_answer(generated_text)
|
| 68 |
+
normalized_trace_text = normalize_trace_text(raw_trace_text)
|
| 69 |
+
|
| 70 |
+
return GenerationResult(
|
| 71 |
+
question=question,
|
| 72 |
+
model_name=model_name,
|
| 73 |
+
answer=answer,
|
| 74 |
+
raw_generation_text=generated_text,
|
| 75 |
+
raw_trace_text=raw_trace_text,
|
| 76 |
+
normalized_trace_text=normalized_trace_text,
|
| 77 |
+
generation_metadata=GenerationMetadata(
|
| 78 |
+
max_new_tokens=max_new_tokens,
|
| 79 |
+
temperature=temperature,
|
| 80 |
+
top_p=top_p,
|
| 81 |
+
do_sample=do_sample,
|
| 82 |
+
),
|
| 83 |
+
)
|
app/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Application services for session orchestration."""
|
app/services/sessions.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
from app.analysis.sentence_split import split_sentences
|
| 6 |
+
from app.core.config import Settings
|
| 7 |
+
from app.core.runtime import load_model_bundle
|
| 8 |
+
from app.core.runtime_pipeline import analyze_generation_result
|
| 9 |
+
from app.core.schemas import AnalysisRequest, AnalysisResult, GenerationMetadata, GenerationResult
|
| 10 |
+
from app.generation.service import generate_answer_and_trace
|
| 11 |
+
from app.storage.repository import SessionRecord, SessionRepository
|
| 12 |
+
from app.workers.jobs import JobRunner
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SessionLimitError(RuntimeError):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SessionAccessError(PermissionError):
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass(slots=True)
|
| 24 |
+
class SessionService:
|
| 25 |
+
settings: Settings
|
| 26 |
+
repository: SessionRepository
|
| 27 |
+
jobs: JobRunner
|
| 28 |
+
|
| 29 |
+
def create_session(self, request: AnalysisRequest, *, owner_id: str, owner_name: str | None) -> SessionRecord:
|
| 30 |
+
if self.repository.count_incomplete_sessions() >= self.settings.max_queued_jobs:
|
| 31 |
+
raise SessionLimitError("The service queue is full. Try again after a few minutes.")
|
| 32 |
+
if self.repository.count_incomplete_sessions_for_owner(owner_id) >= self.settings.max_active_jobs_per_user:
|
| 33 |
+
raise SessionLimitError("You already have the maximum number of active analysis jobs.")
|
| 34 |
+
model_name = request.model_name or self.settings.model_name
|
| 35 |
+
session = self.repository.create_session(
|
| 36 |
+
question=request.question,
|
| 37 |
+
model_name=model_name,
|
| 38 |
+
owner_id=owner_id,
|
| 39 |
+
owner_name=owner_name,
|
| 40 |
+
)
|
| 41 |
+
self.jobs.submit(self._run_session_pipeline, session.id, request)
|
| 42 |
+
return session
|
| 43 |
+
|
| 44 |
+
def start_analysis(
|
| 45 |
+
self,
|
| 46 |
+
session_id: str,
|
| 47 |
+
*,
|
| 48 |
+
owner_id: str,
|
| 49 |
+
request: AnalysisRequest | None = None,
|
| 50 |
+
) -> SessionRecord:
|
| 51 |
+
session = self.repository.get_session(session_id)
|
| 52 |
+
self._assert_owner(session, owner_id)
|
| 53 |
+
effective_request = request or AnalysisRequest(question=session.question, model_name=session.model_name)
|
| 54 |
+
self.jobs.submit(self._run_analysis_only, session_id, effective_request)
|
| 55 |
+
return session
|
| 56 |
+
|
| 57 |
+
def get_session_payload(self, session_id: str, *, owner_id: str) -> dict:
|
| 58 |
+
session = self.repository.get_session(session_id)
|
| 59 |
+
self._assert_owner(session, owner_id)
|
| 60 |
+
return self.repository.list_session_payload(session_id)
|
| 61 |
+
|
| 62 |
+
def list_sessions(self, owner_id: str, *, limit: int = 20) -> list[dict]:
|
| 63 |
+
return self.repository.list_sessions_for_owner(owner_id, limit=limit)
|
| 64 |
+
|
| 65 |
+
def get_analysis_result(self, session_id: str, *, owner_id: str) -> AnalysisResult:
|
| 66 |
+
payload = self.get_session_payload(session_id, owner_id=owner_id)
|
| 67 |
+
analysis = payload.get("analysis")
|
| 68 |
+
if analysis is None:
|
| 69 |
+
raise KeyError(session_id)
|
| 70 |
+
return AnalysisResult.model_validate(analysis)
|
| 71 |
+
|
| 72 |
+
@staticmethod
|
| 73 |
+
def _assert_owner(session: SessionRecord, owner_id: str) -> None:
|
| 74 |
+
if session.owner_id != owner_id:
|
| 75 |
+
raise SessionAccessError("Session belongs to a different user.")
|
| 76 |
+
|
| 77 |
+
def _run_session_pipeline(self, session_id: str, request: AnalysisRequest) -> None:
|
| 78 |
+
try:
|
| 79 |
+
self.repository.update_status(session_id, status="generating")
|
| 80 |
+
bundle = load_model_bundle(
|
| 81 |
+
request.model_name or self.settings.model_name,
|
| 82 |
+
device_preference=request.device_preference or self.settings.device_preference,
|
| 83 |
+
dtype_preference=request.dtype_preference or self.settings.dtype_preference,
|
| 84 |
+
attn_implementation=request.attn_implementation or self.settings.attn_implementation,
|
| 85 |
+
trust_remote_code=(
|
| 86 |
+
self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
|
| 87 |
+
),
|
| 88 |
+
low_cpu_mem_usage=(
|
| 89 |
+
self.settings.low_cpu_mem_usage
|
| 90 |
+
if request.low_cpu_mem_usage is None
|
| 91 |
+
else request.low_cpu_mem_usage
|
| 92 |
+
),
|
| 93 |
+
)
|
| 94 |
+
generation = generate_answer_and_trace(
|
| 95 |
+
question=request.question,
|
| 96 |
+
model_name=bundle.model_name,
|
| 97 |
+
model=bundle.model,
|
| 98 |
+
tokenizer=bundle.tokenizer,
|
| 99 |
+
max_new_tokens=request.max_new_tokens,
|
| 100 |
+
temperature=request.temperature,
|
| 101 |
+
top_p=request.top_p,
|
| 102 |
+
)
|
| 103 |
+
sentences = [span.text for span in split_sentences(generation.normalized_trace_text)]
|
| 104 |
+
self.repository.save_generation_result(session_id, generation, sentences)
|
| 105 |
+
self.repository.update_status(session_id, status="answer_ready")
|
| 106 |
+
self._run_analysis_only(session_id, request, generation=generation)
|
| 107 |
+
except Exception as exc:
|
| 108 |
+
self.repository.update_status(session_id, status="failed", error=str(exc))
|
| 109 |
+
|
| 110 |
+
def _run_analysis_only(
|
| 111 |
+
self,
|
| 112 |
+
session_id: str,
|
| 113 |
+
request: AnalysisRequest,
|
| 114 |
+
*,
|
| 115 |
+
generation=None,
|
| 116 |
+
) -> None:
|
| 117 |
+
try:
|
| 118 |
+
self.repository.update_status(session_id, status="analyzing")
|
| 119 |
+
if generation is None:
|
| 120 |
+
payload = self.repository.list_session_payload(session_id)
|
| 121 |
+
if payload.get("generation_metadata") is not None:
|
| 122 |
+
generation = GenerationResult(
|
| 123 |
+
question=payload["question"],
|
| 124 |
+
model_name=payload["model_name"],
|
| 125 |
+
answer=payload["answer"],
|
| 126 |
+
raw_generation_text=payload.get("raw_generation_text", ""),
|
| 127 |
+
raw_trace_text=payload["raw_trace_text"],
|
| 128 |
+
normalized_trace_text=payload["normalized_trace_text"],
|
| 129 |
+
generation_metadata=GenerationMetadata.model_validate(payload["generation_metadata"]),
|
| 130 |
+
)
|
| 131 |
+
result = analyze_generation_result(
|
| 132 |
+
question=request.question,
|
| 133 |
+
generation=generation,
|
| 134 |
+
model_name=request.model_name or self.settings.model_name,
|
| 135 |
+
take_log=request.take_log,
|
| 136 |
+
max_sentences=request.max_sentences,
|
| 137 |
+
max_trace_tokens=request.max_trace_tokens,
|
| 138 |
+
validate_top_k=request.validate_top_k,
|
| 139 |
+
device_preference=request.device_preference or self.settings.device_preference,
|
| 140 |
+
dtype_preference=request.dtype_preference or self.settings.dtype_preference,
|
| 141 |
+
attn_implementation=request.attn_implementation or self.settings.attn_implementation,
|
| 142 |
+
trust_remote_code=(
|
| 143 |
+
self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
|
| 144 |
+
),
|
| 145 |
+
low_cpu_mem_usage=(
|
| 146 |
+
self.settings.low_cpu_mem_usage
|
| 147 |
+
if request.low_cpu_mem_usage is None
|
| 148 |
+
else request.low_cpu_mem_usage
|
| 149 |
+
),
|
| 150 |
+
max_new_tokens=request.max_new_tokens,
|
| 151 |
+
temperature=request.temperature,
|
| 152 |
+
top_p=request.top_p,
|
| 153 |
+
)
|
| 154 |
+
self.repository.save_analysis_result(session_id, result)
|
| 155 |
+
self.repository.update_status(session_id, status="completed")
|
| 156 |
+
except Exception as exc:
|
| 157 |
+
self.repository.update_status(session_id, status="failed", error=str(exc))
|
app/storage/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Persistence layer for sessions and analysis results."""
|
app/storage/db.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sqlite3
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def connect(database_path: str) -> sqlite3.Connection:
|
| 8 |
+
path = Path(database_path)
|
| 9 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 10 |
+
connection = sqlite3.connect(path, check_same_thread=False)
|
| 11 |
+
connection.row_factory = sqlite3.Row
|
| 12 |
+
connection.execute("PRAGMA journal_mode=WAL")
|
| 13 |
+
connection.execute("PRAGMA foreign_keys=ON")
|
| 14 |
+
return connection
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def initialize_schema(connection: sqlite3.Connection) -> None:
|
| 18 |
+
connection.executescript(
|
| 19 |
+
"""
|
| 20 |
+
CREATE TABLE IF NOT EXISTS sessions (
|
| 21 |
+
id TEXT PRIMARY KEY,
|
| 22 |
+
status TEXT NOT NULL,
|
| 23 |
+
question TEXT NOT NULL,
|
| 24 |
+
model_name TEXT NOT NULL,
|
| 25 |
+
owner_id TEXT NOT NULL,
|
| 26 |
+
owner_name TEXT,
|
| 27 |
+
error TEXT,
|
| 28 |
+
created_at TEXT NOT NULL,
|
| 29 |
+
updated_at TEXT NOT NULL
|
| 30 |
+
);
|
| 31 |
+
|
| 32 |
+
CREATE TABLE IF NOT EXISTS generation_results (
|
| 33 |
+
session_id TEXT PRIMARY KEY REFERENCES sessions(id) ON DELETE CASCADE,
|
| 34 |
+
answer TEXT NOT NULL,
|
| 35 |
+
raw_generation_text TEXT NOT NULL,
|
| 36 |
+
raw_trace_text TEXT NOT NULL,
|
| 37 |
+
normalized_trace_text TEXT NOT NULL,
|
| 38 |
+
sentences_json TEXT NOT NULL,
|
| 39 |
+
generation_metadata_json TEXT NOT NULL
|
| 40 |
+
);
|
| 41 |
+
|
| 42 |
+
CREATE TABLE IF NOT EXISTS analysis_results (
|
| 43 |
+
session_id TEXT PRIMARY KEY REFERENCES sessions(id) ON DELETE CASCADE,
|
| 44 |
+
result_json TEXT NOT NULL
|
| 45 |
+
);
|
| 46 |
+
"""
|
| 47 |
+
)
|
| 48 |
+
columns = {
|
| 49 |
+
row["name"] for row in connection.execute("PRAGMA table_info(generation_results)").fetchall()
|
| 50 |
+
}
|
| 51 |
+
if "raw_generation_text" not in columns:
|
| 52 |
+
connection.execute(
|
| 53 |
+
"ALTER TABLE generation_results ADD COLUMN raw_generation_text TEXT NOT NULL DEFAULT ''"
|
| 54 |
+
)
|
| 55 |
+
session_columns = {row["name"] for row in connection.execute("PRAGMA table_info(sessions)").fetchall()}
|
| 56 |
+
if "owner_id" not in session_columns:
|
| 57 |
+
connection.execute("ALTER TABLE sessions ADD COLUMN owner_id TEXT NOT NULL DEFAULT 'legacy-user'")
|
| 58 |
+
if "owner_name" not in session_columns:
|
| 59 |
+
connection.execute("ALTER TABLE sessions ADD COLUMN owner_name TEXT")
|
| 60 |
+
connection.commit()
|
app/storage/repository.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import threading
|
| 5 |
+
import uuid
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from datetime import UTC, datetime
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from app.core.schemas import AnalysisResult, GenerationResult
|
| 11 |
+
from app.storage.db import connect, initialize_schema
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _utc_now() -> str:
|
| 15 |
+
return datetime.now(UTC).isoformat()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(slots=True)
|
| 19 |
+
class SessionRecord:
|
| 20 |
+
id: str
|
| 21 |
+
status: str
|
| 22 |
+
question: str
|
| 23 |
+
model_name: str
|
| 24 |
+
owner_id: str
|
| 25 |
+
owner_name: str | None
|
| 26 |
+
error: str | None
|
| 27 |
+
created_at: str
|
| 28 |
+
updated_at: str
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SessionRepository:
|
| 32 |
+
def __init__(self, database_path: str) -> None:
|
| 33 |
+
self.connection = connect(database_path)
|
| 34 |
+
initialize_schema(self.connection)
|
| 35 |
+
self.lock = threading.Lock()
|
| 36 |
+
|
| 37 |
+
def create_session(self, *, question: str, model_name: str, owner_id: str, owner_name: str | None) -> SessionRecord:
|
| 38 |
+
session_id = str(uuid.uuid4())
|
| 39 |
+
now = _utc_now()
|
| 40 |
+
with self.lock:
|
| 41 |
+
self.connection.execute(
|
| 42 |
+
"""
|
| 43 |
+
INSERT INTO sessions (id, status, question, model_name, owner_id, owner_name, error, created_at, updated_at)
|
| 44 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 45 |
+
""",
|
| 46 |
+
(session_id, "queued", question, model_name, owner_id, owner_name, None, now, now),
|
| 47 |
+
)
|
| 48 |
+
self.connection.commit()
|
| 49 |
+
return self.get_session(session_id)
|
| 50 |
+
|
| 51 |
+
def get_session(self, session_id: str) -> SessionRecord:
|
| 52 |
+
row = self.connection.execute(
|
| 53 |
+
"SELECT * FROM sessions WHERE id = ?",
|
| 54 |
+
(session_id,),
|
| 55 |
+
).fetchone()
|
| 56 |
+
if row is None:
|
| 57 |
+
raise KeyError(session_id)
|
| 58 |
+
return SessionRecord(**dict(row))
|
| 59 |
+
|
| 60 |
+
def list_session_payload(self, session_id: str) -> dict[str, Any]:
|
| 61 |
+
session = self.get_session(session_id)
|
| 62 |
+
payload: dict[str, Any] = {
|
| 63 |
+
"id": session.id,
|
| 64 |
+
"status": session.status,
|
| 65 |
+
"question": session.question,
|
| 66 |
+
"model_name": session.model_name,
|
| 67 |
+
"owner_id": session.owner_id,
|
| 68 |
+
"owner_name": session.owner_name,
|
| 69 |
+
"error": session.error,
|
| 70 |
+
"created_at": session.created_at,
|
| 71 |
+
"updated_at": session.updated_at,
|
| 72 |
+
}
|
| 73 |
+
generation_row = self.connection.execute(
|
| 74 |
+
"SELECT * FROM generation_results WHERE session_id = ?",
|
| 75 |
+
(session_id,),
|
| 76 |
+
).fetchone()
|
| 77 |
+
if generation_row is not None:
|
| 78 |
+
payload["answer"] = generation_row["answer"]
|
| 79 |
+
payload["raw_generation_text"] = generation_row["raw_generation_text"]
|
| 80 |
+
payload["raw_trace_text"] = generation_row["raw_trace_text"]
|
| 81 |
+
payload["normalized_trace_text"] = generation_row["normalized_trace_text"]
|
| 82 |
+
payload["sentences"] = json.loads(generation_row["sentences_json"])
|
| 83 |
+
payload["generation_metadata"] = json.loads(generation_row["generation_metadata_json"])
|
| 84 |
+
analysis_row = self.connection.execute(
|
| 85 |
+
"SELECT result_json FROM analysis_results WHERE session_id = ?",
|
| 86 |
+
(session_id,),
|
| 87 |
+
).fetchone()
|
| 88 |
+
if analysis_row is not None:
|
| 89 |
+
payload["analysis"] = json.loads(analysis_row["result_json"])
|
| 90 |
+
return payload
|
| 91 |
+
|
| 92 |
+
def list_sessions_for_owner(self, owner_id: str, *, limit: int = 20) -> list[dict[str, Any]]:
|
| 93 |
+
rows = self.connection.execute(
|
| 94 |
+
"""
|
| 95 |
+
SELECT id
|
| 96 |
+
FROM sessions
|
| 97 |
+
WHERE owner_id = ?
|
| 98 |
+
ORDER BY updated_at DESC
|
| 99 |
+
LIMIT ?
|
| 100 |
+
""",
|
| 101 |
+
(owner_id, limit),
|
| 102 |
+
).fetchall()
|
| 103 |
+
return [self.list_session_payload(row["id"]) for row in rows]
|
| 104 |
+
|
| 105 |
+
def update_status(self, session_id: str, *, status: str, error: str | None = None) -> None:
|
| 106 |
+
now = _utc_now()
|
| 107 |
+
with self.lock:
|
| 108 |
+
self.connection.execute(
|
| 109 |
+
"UPDATE sessions SET status = ?, error = ?, updated_at = ? WHERE id = ?",
|
| 110 |
+
(status, error, now, session_id),
|
| 111 |
+
)
|
| 112 |
+
self.connection.commit()
|
| 113 |
+
|
| 114 |
+
def count_incomplete_sessions(self) -> int:
|
| 115 |
+
row = self.connection.execute(
|
| 116 |
+
"SELECT COUNT(*) AS count FROM sessions WHERE status NOT IN ('completed', 'failed')"
|
| 117 |
+
).fetchone()
|
| 118 |
+
return int(row["count"])
|
| 119 |
+
|
| 120 |
+
def count_incomplete_sessions_for_owner(self, owner_id: str) -> int:
|
| 121 |
+
row = self.connection.execute(
|
| 122 |
+
"""
|
| 123 |
+
SELECT COUNT(*) AS count
|
| 124 |
+
FROM sessions
|
| 125 |
+
WHERE owner_id = ? AND status NOT IN ('completed', 'failed')
|
| 126 |
+
""",
|
| 127 |
+
(owner_id,),
|
| 128 |
+
).fetchone()
|
| 129 |
+
return int(row["count"])
|
| 130 |
+
|
| 131 |
+
def save_generation_result(self, session_id: str, generation: GenerationResult, sentences: list[str]) -> None:
|
| 132 |
+
with self.lock:
|
| 133 |
+
self.connection.execute(
|
| 134 |
+
"""
|
| 135 |
+
INSERT INTO generation_results (
|
| 136 |
+
session_id,
|
| 137 |
+
answer,
|
| 138 |
+
raw_generation_text,
|
| 139 |
+
raw_trace_text,
|
| 140 |
+
normalized_trace_text,
|
| 141 |
+
sentences_json,
|
| 142 |
+
generation_metadata_json
|
| 143 |
+
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
| 144 |
+
ON CONFLICT(session_id) DO UPDATE SET
|
| 145 |
+
answer = excluded.answer,
|
| 146 |
+
raw_generation_text = excluded.raw_generation_text,
|
| 147 |
+
raw_trace_text = excluded.raw_trace_text,
|
| 148 |
+
normalized_trace_text = excluded.normalized_trace_text,
|
| 149 |
+
sentences_json = excluded.sentences_json,
|
| 150 |
+
generation_metadata_json = excluded.generation_metadata_json
|
| 151 |
+
""",
|
| 152 |
+
(
|
| 153 |
+
session_id,
|
| 154 |
+
generation.answer,
|
| 155 |
+
generation.raw_generation_text,
|
| 156 |
+
generation.raw_trace_text,
|
| 157 |
+
generation.normalized_trace_text,
|
| 158 |
+
json.dumps(sentences),
|
| 159 |
+
json.dumps(generation.generation_metadata.model_dump()),
|
| 160 |
+
),
|
| 161 |
+
)
|
| 162 |
+
self.connection.execute(
|
| 163 |
+
"UPDATE sessions SET updated_at = ? WHERE id = ?",
|
| 164 |
+
(_utc_now(), session_id),
|
| 165 |
+
)
|
| 166 |
+
self.connection.commit()
|
| 167 |
+
|
| 168 |
+
def save_analysis_result(self, session_id: str, result: AnalysisResult) -> None:
|
| 169 |
+
with self.lock:
|
| 170 |
+
self.connection.execute(
|
| 171 |
+
"""
|
| 172 |
+
INSERT INTO analysis_results (session_id, result_json)
|
| 173 |
+
VALUES (?, ?)
|
| 174 |
+
ON CONFLICT(session_id) DO UPDATE SET result_json = excluded.result_json
|
| 175 |
+
""",
|
| 176 |
+
(session_id, result.model_dump_json()),
|
| 177 |
+
)
|
| 178 |
+
self.connection.execute(
|
| 179 |
+
"UPDATE sessions SET updated_at = ? WHERE id = ?",
|
| 180 |
+
(_utc_now(), session_id),
|
| 181 |
+
)
|
| 182 |
+
self.connection.commit()
|
app/workers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Background job runner."""
|
app/workers/jobs.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from concurrent.futures import Future, ThreadPoolExecutor
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(slots=True)
|
| 9 |
+
class JobRunner:
|
| 10 |
+
executor: ThreadPoolExecutor
|
| 11 |
+
|
| 12 |
+
def submit(self, fn: Callable[..., None], *args, **kwargs) -> Future:
|
| 13 |
+
return self.executor.submit(fn, *args, **kwargs)
|
| 14 |
+
|
| 15 |
+
def shutdown(self) -> None:
|
| 16 |
+
self.executor.shutdown(wait=False, cancel_futures=True)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_job_runner(max_workers: int) -> JobRunner:
|
| 20 |
+
return JobRunner(executor=ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="cot-anc"))
|
cot_anc.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: cot-anc
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Online chain-of-thought analysis with attribution patching
|
| 5 |
+
Requires-Python: >=3.11
|
| 6 |
+
Description-Content-Type: text/markdown
|
| 7 |
+
Requires-Dist: fastapi>=0.115.0
|
| 8 |
+
Requires-Dist: huggingface_hub[oauth]>=0.33.0
|
| 9 |
+
Requires-Dist: numpy>=2.0.0
|
| 10 |
+
Requires-Dist: pydantic>=2.7.0
|
| 11 |
+
Requires-Dist: scipy>=1.13.0
|
| 12 |
+
Requires-Dist: torch>=2.2.0
|
| 13 |
+
Requires-Dist: transformers>=4.44.0
|
| 14 |
+
Requires-Dist: typer>=0.12.3
|
| 15 |
+
Requires-Dist: uvicorn>=0.30.0
|
| 16 |
+
Provides-Extra: dev
|
| 17 |
+
Requires-Dist: pytest>=8.2.0; extra == "dev"
|
| 18 |
+
Provides-Extra: viz
|
| 19 |
+
Requires-Dist: matplotlib>=3.8.0; extra == "viz"
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
title: Thought Anchors
|
| 23 |
+
emoji: 🧠
|
| 24 |
+
colorFrom: blue
|
| 25 |
+
colorTo: orange
|
| 26 |
+
sdk: docker
|
| 27 |
+
app_port: 7860
|
| 28 |
+
hf_oauth: true
|
| 29 |
+
pinned: false
|
| 30 |
+
models:
|
| 31 |
+
- deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
| 32 |
+
---
|
| 33 |
+
|
| 34 |
+
# Thought Anchors
|
| 35 |
+
|
| 36 |
+
Public-facing FastAPI service for generating a visible reasoning trace and computing
|
| 37 |
+
sentence-to-sentence attribution on open-weight reasoning models.
|
| 38 |
+
|
| 39 |
+
The app is now shaped for deployment as a Hugging Face `Docker Space`:
|
| 40 |
+
|
| 41 |
+
- Hugging Face OAuth sign-in for end users
|
| 42 |
+
- browser UI plus programmatic API
|
| 43 |
+
- per-user ephemeral sessions on the running instance
|
| 44 |
+
- JSON and CSV export for completed analyses
|
| 45 |
+
- adaptive device and dtype loading for CPU, MPS, CUDA, Colab, Kaggle, and cloud GPUs
|
| 46 |
+
|
| 47 |
+
## What It Can Do
|
| 48 |
+
|
| 49 |
+
- generate an answer plus visible `<think>...</think>` trace from a supported causal LM
|
| 50 |
+
- normalize and split the trace into sentences
|
| 51 |
+
- compute a sentence influence matrix with gradient x attention attribution
|
| 52 |
+
- summarize incoming / outgoing importance and top edges
|
| 53 |
+
- expose the workflow through:
|
| 54 |
+
- CLI
|
| 55 |
+
- FastAPI
|
| 56 |
+
- web UI
|
| 57 |
+
- async session queue
|
| 58 |
+
|
| 59 |
+
## Current Deployment Target
|
| 60 |
+
|
| 61 |
+
The primary deployment target is a Hugging Face `Docker Space` running on upgraded GPU
|
| 62 |
+
hardware. The same app can also be run locally or on other cloud GPU hosts.
|
| 63 |
+
|
| 64 |
+
Important runtime constraints:
|
| 65 |
+
|
| 66 |
+
- attribution requires `attn_implementation="eager"`
|
| 67 |
+
- the model must expose usable attention tensors and a supported decoder-layer layout
|
| 68 |
+
- long traces are intentionally capped because the analysis path uses a full backward pass
|
| 69 |
+
|
| 70 |
+
## Local Development
|
| 71 |
+
|
| 72 |
+
Install dependencies:
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
uv sync
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
Run the API:
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
uv run python -m app.cli.run_api
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
Run the CLI:
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
uv run python -m app.cli.run_prototype "Explain why the derivative of x^2 is 2x"
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
## API
|
| 91 |
+
|
| 92 |
+
Main endpoints:
|
| 93 |
+
|
| 94 |
+
- `GET /healthz`
|
| 95 |
+
- `GET /api/me`
|
| 96 |
+
- `POST /api/warmup`
|
| 97 |
+
- `POST /api/analyze`
|
| 98 |
+
- `GET /api/sessions`
|
| 99 |
+
- `POST /api/sessions`
|
| 100 |
+
- `GET /api/sessions/{id}`
|
| 101 |
+
- `GET /api/sessions/{id}/result`
|
| 102 |
+
- `GET /api/sessions/{id}/export.json`
|
| 103 |
+
- `GET /api/sessions/{id}/export.csv`
|
| 104 |
+
|
| 105 |
+
Example:
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
curl -X POST http://localhost:7860/api/analyze \
|
| 109 |
+
-H 'content-type: application/json' \
|
| 110 |
+
-d '{
|
| 111 |
+
"question": "Explain why the derivative of x^2 is 2x",
|
| 112 |
+
"max_new_tokens": 128,
|
| 113 |
+
"max_trace_tokens": 256,
|
| 114 |
+
"max_sentences": 16,
|
| 115 |
+
"validate_top_k": 0
|
| 116 |
+
}'
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
## Hugging Face Space Setup
|
| 120 |
+
|
| 121 |
+
Recommended environment variables:
|
| 122 |
+
|
| 123 |
+
- `MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`
|
| 124 |
+
- `DEVICE_PREFERENCE=auto`
|
| 125 |
+
- `DTYPE_PREFERENCE=auto`
|
| 126 |
+
- `ATTN_IMPLEMENTATION=eager`
|
| 127 |
+
- `LOW_CPU_MEM_USAGE=true`
|
| 128 |
+
- `TRUST_REMOTE_CODE=true`
|
| 129 |
+
- `PRELOAD_MODEL=true`
|
| 130 |
+
- `MAX_TRACE_TOKENS=256`
|
| 131 |
+
- `MAX_SENTENCES=16`
|
| 132 |
+
- `JOB_WORKERS=1`
|
| 133 |
+
- `MAX_QUEUED_JOBS=8`
|
| 134 |
+
- `MAX_ACTIVE_JOBS_PER_USER=2`
|
| 135 |
+
- `REQUIRE_AUTH=true`
|
| 136 |
+
|
| 137 |
+
Notes:
|
| 138 |
+
|
| 139 |
+
- local disk is ephemeral; users should export results they want to keep
|
| 140 |
+
- use upgraded GPU hardware for real attribution runs
|
| 141 |
+
- keep trace limits conservative for public traffic
|
| 142 |
+
|
| 143 |
+
## Notebook
|
| 144 |
+
|
| 145 |
+
Use [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
|
| 146 |
+
for Colab or Kaggle smoke testing. It installs dependencies, warms the model, runs one
|
| 147 |
+
short attribution analysis, prints the top edges, and renders a simple heatmap.
|
cot_anc.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
app/__init__.py
|
| 4 |
+
app/analysis/__init__.py
|
| 5 |
+
app/analysis/hooks.py
|
| 6 |
+
app/analysis/sentence_split.py
|
| 7 |
+
app/analysis/summaries.py
|
| 8 |
+
app/analysis/suppression.py
|
| 9 |
+
app/analysis/token_boundaries.py
|
| 10 |
+
app/analysis/validation.py
|
| 11 |
+
app/api/__init__.py
|
| 12 |
+
app/api/auth.py
|
| 13 |
+
app/api/main.py
|
| 14 |
+
app/cli/__init__.py
|
| 15 |
+
app/cli/run_api.py
|
| 16 |
+
app/cli/run_prototype.py
|
| 17 |
+
app/core/__init__.py
|
| 18 |
+
app/core/config.py
|
| 19 |
+
app/core/model_support.py
|
| 20 |
+
app/core/runtime.py
|
| 21 |
+
app/core/runtime_pipeline.py
|
| 22 |
+
app/core/schemas.py
|
| 23 |
+
app/generation/__init__.py
|
| 24 |
+
app/generation/prompting.py
|
| 25 |
+
app/generation/service.py
|
| 26 |
+
app/services/__init__.py
|
| 27 |
+
app/services/sessions.py
|
| 28 |
+
app/storage/__init__.py
|
| 29 |
+
app/storage/db.py
|
| 30 |
+
app/storage/repository.py
|
| 31 |
+
app/workers/__init__.py
|
| 32 |
+
app/workers/jobs.py
|
| 33 |
+
cot_anc.egg-info/PKG-INFO
|
| 34 |
+
cot_anc.egg-info/SOURCES.txt
|
| 35 |
+
cot_anc.egg-info/dependency_links.txt
|
| 36 |
+
cot_anc.egg-info/requires.txt
|
| 37 |
+
cot_anc.egg-info/top_level.txt
|
| 38 |
+
tests/test_api.py
|
| 39 |
+
tests/test_model_support.py
|
| 40 |
+
tests/test_runtime_pipeline.py
|
| 41 |
+
tests/test_sentence_split.py
|
| 42 |
+
tests/test_summaries.py
|
| 43 |
+
tests/test_suppression_shape.py
|
| 44 |
+
tests/test_token_boundaries.py
|
cot_anc.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
cot_anc.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.115.0
|
| 2 |
+
huggingface_hub[oauth]>=0.33.0
|
| 3 |
+
numpy>=2.0.0
|
| 4 |
+
pydantic>=2.7.0
|
| 5 |
+
scipy>=1.13.0
|
| 6 |
+
torch>=2.2.0
|
| 7 |
+
transformers>=4.44.0
|
| 8 |
+
typer>=0.12.3
|
| 9 |
+
uvicorn>=0.30.0
|
| 10 |
+
|
| 11 |
+
[dev]
|
| 12 |
+
pytest>=8.2.0
|
| 13 |
+
|
| 14 |
+
[viz]
|
| 15 |
+
matplotlib>=3.8.0
|
cot_anc.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
app
|
docs/api.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# API And Product Behavior
|
| 2 |
+
|
| 3 |
+
## Auth Model
|
| 4 |
+
|
| 5 |
+
- If `REQUIRE_AUTH=true`, analysis endpoints require Hugging Face sign-in.
|
| 6 |
+
- Sessions are scoped to user identity.
|
| 7 |
+
- One user cannot fetch another user’s session or export.
|
| 8 |
+
|
| 9 |
+
## Session Model
|
| 10 |
+
|
| 11 |
+
Statuses:
|
| 12 |
+
|
| 13 |
+
- `queued`
|
| 14 |
+
- `generating`
|
| 15 |
+
- `answer_ready`
|
| 16 |
+
- `analyzing`
|
| 17 |
+
- `completed`
|
| 18 |
+
- `failed`
|
| 19 |
+
|
| 20 |
+
Sessions are ephemeral for current running instance.
|
| 21 |
+
|
| 22 |
+
## Endpoints
|
| 23 |
+
|
| 24 |
+
### `GET /healthz`
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
|
| 28 |
+
- model name
|
| 29 |
+
- device preference
|
| 30 |
+
- dtype preference
|
| 31 |
+
- auth requirement
|
| 32 |
+
- queue limits
|
| 33 |
+
- CUDA / MPS availability
|
| 34 |
+
|
| 35 |
+
### `GET /api/me`
|
| 36 |
+
|
| 37 |
+
Returns current auth state plus login/logout URLs.
|
| 38 |
+
|
| 39 |
+
### `POST /api/warmup`
|
| 40 |
+
|
| 41 |
+
Loads model with current runtime policy and returns:
|
| 42 |
+
|
| 43 |
+
- resolved device
|
| 44 |
+
- resolved dtype
|
| 45 |
+
- model attribution capability
|
| 46 |
+
|
| 47 |
+
### `POST /api/analyze`
|
| 48 |
+
|
| 49 |
+
Direct synchronous analysis call. Good for trusted programmatic use, not best path for UI.
|
| 50 |
+
|
| 51 |
+
### `GET /api/sessions`
|
| 52 |
+
|
| 53 |
+
List current user’s recent sessions on current instance.
|
| 54 |
+
|
| 55 |
+
### `POST /api/sessions`
|
| 56 |
+
|
| 57 |
+
Create async session job.
|
| 58 |
+
|
| 59 |
+
Queue protections:
|
| 60 |
+
|
| 61 |
+
- global queue cap
|
| 62 |
+
- per-user active-job cap
|
| 63 |
+
|
| 64 |
+
### `GET /api/sessions/{id}`
|
| 65 |
+
|
| 66 |
+
Return session summary.
|
| 67 |
+
|
| 68 |
+
### `GET /api/sessions/{id}/result`
|
| 69 |
+
|
| 70 |
+
Return session summary + full analysis if ready.
|
| 71 |
+
|
| 72 |
+
### Export
|
| 73 |
+
|
| 74 |
+
- `GET /api/sessions/{id}/export.json`
|
| 75 |
+
- `GET /api/sessions/{id}/export.csv`
|
| 76 |
+
|
| 77 |
+
JSON contains session + analysis payload.
|
| 78 |
+
|
| 79 |
+
CSV contains top edges:
|
| 80 |
+
|
| 81 |
+
- `source_sentence_idx`
|
| 82 |
+
- `target_sentence_idx`
|
| 83 |
+
- `score`
|
docs/deploy-huggingface.md
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Deployment
|
| 2 |
+
|
| 3 |
+
Primary target: Hugging Face `Docker Space` on upgraded GPU hardware.
|
| 4 |
+
|
| 5 |
+
## What Gets Deployed
|
| 6 |
+
|
| 7 |
+
- FastAPI backend
|
| 8 |
+
- static web frontend
|
| 9 |
+
- Hugging Face OAuth routes
|
| 10 |
+
- ephemeral SQLite-backed session queue
|
| 11 |
+
|
| 12 |
+
## Required Space Settings
|
| 13 |
+
|
| 14 |
+
- SDK: `Docker`
|
| 15 |
+
- Port: `7860`
|
| 16 |
+
- OAuth: enabled via README metadata
|
| 17 |
+
- Hardware: upgraded GPU recommended
|
| 18 |
+
|
| 19 |
+
## Recommended Runtime Variables
|
| 20 |
+
|
| 21 |
+
Core:
|
| 22 |
+
|
| 23 |
+
- `MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`
|
| 24 |
+
- `DEVICE_PREFERENCE=auto`
|
| 25 |
+
- `DTYPE_PREFERENCE=auto`
|
| 26 |
+
- `ATTN_IMPLEMENTATION=eager`
|
| 27 |
+
- `LOW_CPU_MEM_USAGE=true`
|
| 28 |
+
- `TRUST_REMOTE_CODE=true`
|
| 29 |
+
- `PRELOAD_MODEL=true`
|
| 30 |
+
|
| 31 |
+
Traffic limits:
|
| 32 |
+
|
| 33 |
+
- `MAX_TRACE_TOKENS=256`
|
| 34 |
+
- `MAX_SENTENCES=16`
|
| 35 |
+
- `JOB_WORKERS=1`
|
| 36 |
+
- `MAX_QUEUED_JOBS=8`
|
| 37 |
+
- `MAX_ACTIVE_JOBS_PER_USER=2`
|
| 38 |
+
- `REQUIRE_AUTH=true`
|
| 39 |
+
|
| 40 |
+
## Deploy Flow
|
| 41 |
+
|
| 42 |
+
1. Create new Hugging Face Space with `Docker` SDK.
|
| 43 |
+
2. Push repo contents.
|
| 44 |
+
3. Set runtime variables in Space settings.
|
| 45 |
+
4. Upgrade hardware.
|
| 46 |
+
5. Wait for build.
|
| 47 |
+
6. Verify:
|
| 48 |
+
- `GET /healthz`
|
| 49 |
+
- sign-in works
|
| 50 |
+
- one short analysis completes
|
| 51 |
+
- JSON / CSV export works
|
| 52 |
+
|
| 53 |
+
## Operational Notes
|
| 54 |
+
|
| 55 |
+
- Local disk is ephemeral. Session history disappears on restart.
|
| 56 |
+
- OAuth helper is mocked locally but real inside Space.
|
| 57 |
+
- Keep public defaults conservative. Long traces can OOM small GPUs.
|
| 58 |
+
- If queue pressure grows, lower token caps before increasing worker count.
|
| 59 |
+
|
| 60 |
+
## Common Failure Modes
|
| 61 |
+
|
| 62 |
+
- `attn_implementation` not eager:
|
| 63 |
+
- attribution disabled for model
|
| 64 |
+
- unsupported model layout:
|
| 65 |
+
- generation may work, attribution fails early with clear error
|
| 66 |
+
- OOM:
|
| 67 |
+
- reduce `MAX_TRACE_TOKENS`, `MAX_SENTENCES`, or choose larger GPU
|
| 68 |
+
- cold start slow:
|
| 69 |
+
- keep `PRELOAD_MODEL=true`
|
docs/notebook.md
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Notebook Usage
|
| 2 |
+
|
| 3 |
+
Notebook path:
|
| 4 |
+
|
| 5 |
+
- [notebooks/hf_space_demo.ipynb](/Users/vibhorkumar/Desktop/codes/cot-anc/notebooks/hf_space_demo.ipynb)
|
| 6 |
+
|
| 7 |
+
## Purpose
|
| 8 |
+
|
| 9 |
+
Short smoke test for:
|
| 10 |
+
|
| 11 |
+
- Colab GPU
|
| 12 |
+
- Kaggle GPU
|
| 13 |
+
- quick local validation
|
| 14 |
+
|
| 15 |
+
## What Notebook Does
|
| 16 |
+
|
| 17 |
+
1. Install runtime deps.
|
| 18 |
+
2. Set conservative env vars.
|
| 19 |
+
3. Import project pipeline.
|
| 20 |
+
4. Run one short attribution job.
|
| 21 |
+
5. Print answer, runtime metadata, top edges.
|
| 22 |
+
6. Render heatmap.
|
| 23 |
+
|
| 24 |
+
## Best Use
|
| 25 |
+
|
| 26 |
+
- validate model availability
|
| 27 |
+
- validate driver / torch / transformers stack
|
| 28 |
+
- sanity-check latency before public deploy
|
| 29 |
+
|
| 30 |
+
## If Notebook Fails
|
| 31 |
+
|
| 32 |
+
Check:
|
| 33 |
+
|
| 34 |
+
- GPU available
|
| 35 |
+
- model access permissions
|
| 36 |
+
- enough VRAM
|
| 37 |
+
- eager attention enabled
|
| 38 |
+
- trace limits not too high
|
docs/runtime.md
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Runtime And Model Support
|
| 2 |
+
|
| 3 |
+
## Execution Model
|
| 4 |
+
|
| 5 |
+
Pipeline:
|
| 6 |
+
|
| 7 |
+
1. Generate visible reasoning trace.
|
| 8 |
+
2. Normalize and split trace into sentences.
|
| 9 |
+
3. Map sentence spans to token spans.
|
| 10 |
+
4. Run forward + backward pass.
|
| 11 |
+
5. Build sentence influence matrix from gradient x attention.
|
| 12 |
+
6. Summarize top edges and importance scores.
|
| 13 |
+
|
| 14 |
+
## Device And Dtype Policy
|
| 15 |
+
|
| 16 |
+
Default policy:
|
| 17 |
+
|
| 18 |
+
- CUDA:
|
| 19 |
+
- `bfloat16` if supported
|
| 20 |
+
- else `float16`
|
| 21 |
+
- MPS:
|
| 22 |
+
- `float16`
|
| 23 |
+
- CPU:
|
| 24 |
+
- `float32`
|
| 25 |
+
|
| 26 |
+
Override with:
|
| 27 |
+
|
| 28 |
+
- `DTYPE_PREFERENCE`
|
| 29 |
+
- request `dtype_preference`
|
| 30 |
+
|
| 31 |
+
## Model Requirements
|
| 32 |
+
|
| 33 |
+
Model must support all of:
|
| 34 |
+
|
| 35 |
+
- causal LM generation
|
| 36 |
+
- `output_attentions=True`
|
| 37 |
+
- eager attention
|
| 38 |
+
- supported decoder layer layout
|
| 39 |
+
- supported attention module attribute
|
| 40 |
+
|
| 41 |
+
Supported layer paths:
|
| 42 |
+
|
| 43 |
+
- `model.layers`
|
| 44 |
+
- `model.model.layers`
|
| 45 |
+
- `transformer.h`
|
| 46 |
+
- `gpt_neox.layers`
|
| 47 |
+
|
| 48 |
+
Supported attention attrs:
|
| 49 |
+
|
| 50 |
+
- `self_attn`
|
| 51 |
+
- `attn`
|
| 52 |
+
- `attention`
|
| 53 |
+
|
| 54 |
+
## Why Trace Limits Exist
|
| 55 |
+
|
| 56 |
+
Attribution path uses full backward pass over attention tensors. Cost grows with:
|
| 57 |
+
|
| 58 |
+
- sequence length
|
| 59 |
+
- layer count
|
| 60 |
+
- head count
|
| 61 |
+
- sentence count
|
| 62 |
+
|
| 63 |
+
Public defaults stay small to protect uptime.
|
| 64 |
+
|
| 65 |
+
## Good First Runtime Settings
|
| 66 |
+
|
| 67 |
+
For public demo:
|
| 68 |
+
|
| 69 |
+
- `max_new_tokens=128`
|
| 70 |
+
- `max_trace_tokens=256`
|
| 71 |
+
- `max_sentences=16`
|
| 72 |
+
- `validate_top_k=0`
|
| 73 |
+
|
| 74 |
+
For deeper analysis on bigger GPU:
|
| 75 |
+
|
| 76 |
+
- raise trace tokens slowly
|
| 77 |
+
- watch latency and memory first
|
notebooks/hf_space_demo.ipynb
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Thought Anchors Demo\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Colab/Kaggle notebook for a short end-to-end attribution smoke test."
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "code",
|
| 14 |
+
"execution_count": null,
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"!pip install -q uv\n",
|
| 19 |
+
"!uv pip install --system fastapi huggingface_hub matplotlib numpy pydantic scipy torch transformers typer uvicorn"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"import os\n",
|
| 29 |
+
"from pathlib import Path\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"if not Path('app').exists():\n",
|
| 32 |
+
" raise RuntimeError('Upload or clone the repository before running the notebook.')\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"os.environ.setdefault('MODEL_NAME', 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B')\n",
|
| 35 |
+
"os.environ.setdefault('DEVICE_PREFERENCE', 'auto')\n",
|
| 36 |
+
"os.environ.setdefault('DTYPE_PREFERENCE', 'auto')\n",
|
| 37 |
+
"os.environ.setdefault('ATTN_IMPLEMENTATION', 'eager')\n",
|
| 38 |
+
"os.environ.setdefault('LOW_CPU_MEM_USAGE', 'true')\n",
|
| 39 |
+
"os.environ.setdefault('TRUST_REMOTE_CODE', 'true')\n",
|
| 40 |
+
"os.environ.setdefault('MAX_TRACE_TOKENS', '256')\n",
|
| 41 |
+
"os.environ.setdefault('MAX_SENTENCES', '16')"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "code",
|
| 46 |
+
"execution_count": null,
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"outputs": [],
|
| 49 |
+
"source": [
|
| 50 |
+
"from app.core.runtime_pipeline import compute_attribution_analysis"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": null,
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"question = 'Explain why the derivative of x^2 is 2x.'\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"result = compute_attribution_analysis(\n",
|
| 62 |
+
" question=question,\n",
|
| 63 |
+
" max_new_tokens=128,\n",
|
| 64 |
+
" max_trace_tokens=256,\n",
|
| 65 |
+
" max_sentences=16,\n",
|
| 66 |
+
" validate_top_k=0,\n",
|
| 67 |
+
" temperature=0.0,\n",
|
| 68 |
+
" top_p=1.0,\n",
|
| 69 |
+
")"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"outputs": [],
|
| 77 |
+
"source": [
|
| 78 |
+
"print('Answer:\\n', result.answer)\n",
|
| 79 |
+
"print('\\nSentences:', len(result.sentences))\n",
|
| 80 |
+
"print('Runtime:', result.runtime_metadata.model_dump())\n",
|
| 81 |
+
"print('\\nTop edges:')\n",
|
| 82 |
+
"for edge in result.top_edges[:10]:\n",
|
| 83 |
+
" print(edge)"
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"cell_type": "code",
|
| 88 |
+
"execution_count": null,
|
| 89 |
+
"metadata": {},
|
| 90 |
+
"outputs": [],
|
| 91 |
+
"source": [
|
| 92 |
+
"import matplotlib.pyplot as plt\n",
|
| 93 |
+
"import numpy as np\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"matrix = np.array(result.suppression_matrix)\n",
|
| 96 |
+
"figure, axis = plt.subplots(figsize=(7, 5))\n",
|
| 97 |
+
"image = axis.imshow(matrix, aspect='auto', cmap='viridis')\n",
|
| 98 |
+
"axis.set_title('Sentence Influence Matrix')\n",
|
| 99 |
+
"axis.set_xlabel('Source sentence')\n",
|
| 100 |
+
"axis.set_ylabel('Target sentence')\n",
|
| 101 |
+
"figure.colorbar(image, ax=axis)\n",
|
| 102 |
+
"plt.show()"
|
| 103 |
+
]
|
| 104 |
+
}
|
| 105 |
+
],
|
| 106 |
+
"metadata": {
|
| 107 |
+
"kernelspec": {
|
| 108 |
+
"display_name": "Python 3",
|
| 109 |
+
"language": "python",
|
| 110 |
+
"name": "python3"
|
| 111 |
+
},
|
| 112 |
+
"language_info": {
|
| 113 |
+
"name": "python",
|
| 114 |
+
"version": "3.11"
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
"nbformat": 4,
|
| 118 |
+
"nbformat_minor": 5
|
| 119 |
+
}
|
pyproject.toml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "cot-anc"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Online chain-of-thought analysis with attribution patching"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"fastapi>=0.115.0",
|
| 9 |
+
"huggingface_hub[oauth]>=0.33.0",
|
| 10 |
+
"numpy>=2.0.0",
|
| 11 |
+
"pydantic>=2.7.0",
|
| 12 |
+
"scipy>=1.13.0",
|
| 13 |
+
"torch>=2.2.0",
|
| 14 |
+
"transformers>=4.44.0",
|
| 15 |
+
"typer>=0.12.3",
|
| 16 |
+
"uvicorn>=0.30.0",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.optional-dependencies]
|
| 20 |
+
dev = [
|
| 21 |
+
"pytest>=8.2.0",
|
| 22 |
+
]
|
| 23 |
+
viz = [
|
| 24 |
+
"matplotlib>=3.8.0",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
[build-system]
|
| 28 |
+
requires = ["setuptools>=68.0"]
|
| 29 |
+
build-backend = "setuptools.build_meta"
|
| 30 |
+
|
| 31 |
+
[tool.setuptools.packages.find]
|
| 32 |
+
include = ["app*"]
|
| 33 |
+
|
| 34 |
+
[tool.pytest.ini_options]
|
| 35 |
+
testpaths = ["tests"]
|
| 36 |
+
pythonpath = ["."]
|
runpod_start.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
export API_HOST="${API_HOST:-0.0.0.0}"
|
| 5 |
+
export API_PORT="${API_PORT:-7860}"
|
| 6 |
+
export DEVICE_PREFERENCE="${DEVICE_PREFERENCE:-auto}"
|
| 7 |
+
export DTYPE_PREFERENCE="${DTYPE_PREFERENCE:-auto}"
|
| 8 |
+
export ATTN_IMPLEMENTATION="${ATTN_IMPLEMENTATION:-eager}"
|
| 9 |
+
export PRELOAD_MODEL="${PRELOAD_MODEL:-true}"
|
| 10 |
+
|
| 11 |
+
exec uv run python -m app.cli.run_api
|