diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..e3547b7c994f25f7043895495003e1c136c5ffff --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +.git +.venv +.pytest_cache +__pycache__ +*.pyc +*.pyo +*.pyd +outputs +tests +cot_anc.egg-info diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..81b48eb2a2dfadbc477417d63ed25755c0d0cbe2 --- /dev/null +++ b/.env.example @@ -0,0 +1,13 @@ +MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B +DEVICE_PREFERENCE=auto +DTYPE_PREFERENCE=auto +ATTN_IMPLEMENTATION=eager +TRUST_REMOTE_CODE=true +LOW_CPU_MEM_USAGE=true +MAX_TRACE_TOKENS=256 +MAX_SENTENCES=16 +TAKE_LOG=true +PRELOAD_MODEL=true +REQUIRE_AUTH=true +MAX_QUEUED_JOBS=8 +MAX_ACTIVE_JOBS_PER_USER=2 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9e2cf02e23713e8c4e41182df287739c27083ff0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +/.venv +/.pytest_cache +/__pycache__ +/data/app.db +.DS_Store +*.pyc +*.pyo +*.pyd +cot_anc.egg-info/ +outputs/ +uv.lock diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..c4a0c8f9234795d6d4ad54151b1c46149caddb83 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,45 @@ +FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 +ENV UV_LINK_MODE=copy +ENV HOME=/home/user +ENV PATH=/home/user/.local/bin:$PATH +ENV HF_HOME=/home/user/.cache/huggingface +ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface +ENV API_HOST=0.0.0.0 +ENV API_PORT=7860 +ENV DEVICE_PREFERENCE=auto +ENV DTYPE_PREFERENCE=auto +ENV ATTN_IMPLEMENTATION=eager +ENV LOW_CPU_MEM_USAGE=true +ENV TRUST_REMOTE_CODE=true +ENV PRELOAD_MODEL=true +ENV REQUIRE_AUTH=true + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + git \ + python3 
\ + python3-pip \ + python3-venv \ + && rm -rf /var/lib/apt/lists/* + +RUN useradd -m -u 1000 user +USER user +WORKDIR $HOME/app + +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +COPY --chown=user pyproject.toml uv.lock README.md .env.example ./ +COPY --chown=user app ./app +COPY --chown=user notebooks ./notebooks +COPY --chown=user tests ./tests + +RUN uv sync --frozen + +EXPOSE 7860 + +CMD ["uv", "run", "python", "-m", "app.cli.run_api"] diff --git a/README.md b/README.md index a67b9f869e759e6a2d88942ddb32892e8a96e719..83ba1c28c3145af95059e8d725922b3cba4717d1 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,85 @@ --- -title: Cot Anc -emoji: šŸ‘ -colorFrom: indigo -colorTo: green +title: Thought Anchors +emoji: 🧠 +colorFrom: blue +colorTo: red sdk: docker +app_port: 7860 +hf_oauth: true pinned: false +models: + - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Thought Anchors + +Thought Anchors generates visible reasoning traces from open-weight models and +computes sentence-to-sentence influence with gradient x attention attribution. 
+ +Current product shape: + +- Hugging Face `Docker Space` first +- Hugging Face OAuth sign-in +- web UI + API +- per-user ephemeral sessions +- JSON / CSV export +- adaptive CPU / MPS / CUDA loading + +## Quick Start + +Install deps: + +```bash +uv sync --extra dev +``` + +Run API: + +```bash +uv run python -m app.cli.run_api +``` + +Run CLI: + +```bash +uv run python -m app.cli.run_prototype "Explain why the derivative of x^2 is 2x" +``` + +Run tests: + +```bash +uv run python -m pytest -q +``` + +## Main Endpoints + +- `GET /healthz` +- `GET /api/me` +- `POST /api/warmup` +- `POST /api/analyze` +- `GET /api/sessions` +- `POST /api/sessions` +- `GET /api/sessions/{id}` +- `GET /api/sessions/{id}/result` +- `GET /api/sessions/{id}/export.json` +- `GET /api/sessions/{id}/export.csv` + +## Docs + +- [Hugging Face deployment](./docs/deploy-huggingface.md) +- [Runtime and model support](./docs/runtime.md) +- [API and product behavior](./docs/api.md) +- [Notebook usage](./docs/notebook.md) + +## Notebook + +Colab / Kaggle smoke-test notebook: + +- [notebooks/hf_space_demo.ipynb](./notebooks/hf_space_demo.ipynb) + +## Key Constraints + +- Attribution needs `attn_implementation="eager"`. +- Model must expose supported decoder layers and attention modules. +- Long traces stay capped because analysis uses full backward pass. +- Space disk is ephemeral; export results you want to keep. 
diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4814867a127241148d25839faacb615d6a89c1ba --- /dev/null +++ b/app/__init__.py @@ -0,0 +1 @@ +"""Application package for chain-of-thought attribution analysis.""" diff --git a/app/analysis/__init__.py b/app/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0277200e78b87d2e8d336dbb5fd5b171a3114bdb --- /dev/null +++ b/app/analysis/__init__.py @@ -0,0 +1 @@ +"""Analysis utilities for attribution patching.""" diff --git a/app/analysis/hooks.py b/app/analysis/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..d0580c700238bf4a239c64fc6aeb0ed805f7187d --- /dev/null +++ b/app/analysis/hooks.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import torch + +from app.core.model_support import get_decoder_layers + +_ATTENTION_STORE: dict[int, torch.Tensor] = {} + + +def clear_stored_attentions() -> None: + _ATTENTION_STORE.clear() + + +def get_stored_attentions() -> dict[int, torch.Tensor]: + return dict(_ATTENTION_STORE) + + +def _extract_attention_tensor(output: Any) -> torch.Tensor | None: + if isinstance(output, torch.Tensor): + return output if output.dim() == 4 else None + + if isinstance(output, dict): + for value in output.values(): + if isinstance(value, torch.Tensor) and value.dim() == 4: + return value + + if isinstance(output, Iterable) and not isinstance(output, (str, bytes)): + for item in output: + if isinstance(item, torch.Tensor) and item.dim() == 4: + return item + + return None + + +def _get_attention_impl(model: Any) -> str | None: + config = getattr(model, "config", None) + if config is None: + return None + return getattr(config, "_attn_implementation", None) or getattr( + config, + "attn_implementation", + None, + ) + + +def make_attn_hook(layer_idx: int): + def hook(_module: Any, _inputs: Any, output: 
Any) -> None: + attn = _extract_attention_tensor(output) + if attn is None: + return + if attn.dim() != 4: + raise RuntimeError(f"Expected 4D attention tensor at layer {layer_idx}, got {attn.shape}.") + attn.retain_grad() + _ATTENTION_STORE[layer_idx] = attn + + return hook + + +def register_hooks(model: Any) -> list[Any]: + clear_stored_attentions() + layers, _layer_path, attention_attr = get_decoder_layers(model) + + handles: list[Any] = [] + for layer_idx, layer in enumerate(layers): + self_attn = getattr(layer, attention_attr, None) + if self_attn is None: + raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.") + handles.append(self_attn.register_forward_hook(make_attn_hook(layer_idx))) + return handles + + +def remove_hooks(handles: list[Any]) -> None: + for handle in handles: + handle.remove() + clear_stored_attentions() diff --git a/app/analysis/sentence_split.py b/app/analysis/sentence_split.py new file mode 100644 index 0000000000000000000000000000000000000000..c34f234eda78e1eb728f3d3a65c747224865023c --- /dev/null +++ b/app/analysis/sentence_split.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass + +THINK_TAG_RE = re.compile(r"", re.IGNORECASE) + + +@dataclass(slots=True) +class SentenceSpan: + text: str + start_char: int + end_char: int + + +def normalize_trace_text(raw_trace_text: str) -> str: + return THINK_TAG_RE.sub("", raw_trace_text) + + +def _non_whitespace_token_count(text: str) -> int: + return len(re.findall(r"\S+", text)) + + +def split_sentences( + text: str, + *, + min_token_like_units: int = 2, +) -> list[SentenceSpan]: + if not text: + return [] + + raw_spans: list[tuple[int, int]] = [] + start = 0 + index = 0 + text_length = len(text) + + while index < text_length: + if text[index : index + 2] == "\n\n": + end = index + 2 + raw_spans.append((start, end)) + start = end + index = end + continue + + if text[index] in ".!?": + end = index + 1 + while end < text_length 
and text[end] in "\"')]}": + end += 1 + while end < text_length and text[end].isspace() and text[end : end + 2] != "\n\n": + end += 1 + raw_spans.append((start, end)) + start = end + index = end + continue + + index += 1 + + if start < text_length: + raw_spans.append((start, text_length)) + + merged: list[tuple[int, int]] = [] + for span_start, span_end in raw_spans: + fragment = text[span_start:span_end] + if not fragment: + continue + if merged and _non_whitespace_token_count(fragment) < min_token_like_units: + previous_start, _ = merged[-1] + merged[-1] = (previous_start, span_end) + continue + merged.append((span_start, span_end)) + + if len(merged) > 1: + last_start, last_end = merged[-1] + if _non_whitespace_token_count(text[last_start:last_end]) < min_token_like_units: + prev_start, _ = merged[-2] + merged[-2] = (prev_start, last_end) + merged.pop() + + return [ + SentenceSpan( + text=text[span_start:span_end], + start_char=span_start, + end_char=span_end, + ) + for span_start, span_end in merged + if text[span_start:span_end] + ] diff --git a/app/analysis/summaries.py b/app/analysis/summaries.py new file mode 100644 index 0000000000000000000000000000000000000000..0af4693dea6e613917bd04df19902bf0d8815799 --- /dev/null +++ b/app/analysis/summaries.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import numpy as np + +from app.core.schemas import TopEdge + + +def compute_outgoing_importance(matrix: np.ndarray) -> list[float]: + sentence_count = matrix.shape[0] + scores: list[float] = [] + for source_idx in range(sentence_count): + column = matrix[source_idx + 1 :, source_idx] + scores.append(float(column.mean()) if column.size else 0.0) + return scores + + +def compute_incoming_importance(matrix: np.ndarray) -> list[float]: + sentence_count = matrix.shape[0] + scores: list[float] = [] + for target_idx in range(sentence_count): + row = matrix[target_idx, :target_idx] + scores.append(float(row.mean()) if row.size else 0.0) + return scores + + +def 
compute_top_edges(matrix: np.ndarray, top_k: int = 10) -> list[TopEdge]: + sentence_count = matrix.shape[0] + candidates: list[TopEdge] = [] + for target_idx in range(sentence_count): + for source_idx in range(target_idx): + candidates.append( + TopEdge( + source_sentence_idx=source_idx, + target_sentence_idx=target_idx, + score=float(matrix[target_idx, source_idx]), + ) + ) + candidates.sort(key=lambda edge: edge.score, reverse=True) + return candidates[:top_k] diff --git a/app/analysis/suppression.py b/app/analysis/suppression.py new file mode 100644 index 0000000000000000000000000000000000000000..43c60e801dcfd41a829d16f2c32e8b4bd33e2545 --- /dev/null +++ b/app/analysis/suppression.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import time +from dataclasses import asdict +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks +from app.core.model_support import describe_model_support +from app.core.schemas import ModelCapability, RuntimeMetadata + + +@dataclass(slots=True) +class AttributionMatrixComputation: + matrix: np.ndarray + raw_matrix: np.ndarray + token_nll: np.ndarray + runtime_metadata: RuntimeMetadata + + +def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + if logits.ndim != 3 or input_ids.ndim != 2: + raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].") + if logits.shape[0] != 1 or input_ids.shape[0] != 1: + raise ValueError("Only batch size 1 is supported for the prototype.") + if input_ids.shape[1] < 2: + raise ValueError("Need at least two tokens to compute next-token loss.") + + shifted_logits = logits[:, :-1, :] + shifted_targets = input_ids[:, 1:] + log_probs = torch.log_softmax(shifted_logits, dim=-1) + gathered = log_probs.gather(-1, shifted_targets.unsqueeze(-1)).squeeze(-1) + return -gathered[0] + + +def _current_memory_mb(device: 
torch.device) -> float | None: + if device.type != "cuda": + return None + return float(torch.cuda.memory_allocated(device) / (1024 * 1024)) + + +def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray: + if not take_log: + return raw_matrix.copy() + presentation = np.zeros_like(raw_matrix) + positive = raw_matrix > 0 + presentation[positive] = np.log(raw_matrix[positive] + 1e-9) + return presentation + + +def compute_attribution_matrix( + input_ids: torch.Tensor, + token_ranges: list[tuple[int, int]], + model: Any, + take_log: bool = True, + max_trace_tokens: int = 0, + max_sentences: int = 0, +) -> AttributionMatrixComputation: + device = input_ids.device + handles = register_hooks(model) + model.zero_grad(set_to_none=True) + forward_start = time.perf_counter() + memory_before_mb = _current_memory_mb(device) + + try: + with torch.enable_grad(): + outputs = model( + input_ids=input_ids, + output_attentions=True, + return_dict=True, + ) + forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0 + + logits = outputs.logits + token_nll = compute_self_token_nll(logits, input_ids) + loss = token_nll.sum() + + backward_start = time.perf_counter() + loss.backward() + backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0 + + attentions = get_stored_attentions() + if not attentions: + raise RuntimeError("No attention tensors were captured. 
Check eager attention mode.") + + matrix_start = time.perf_counter() + sentence_count = len(token_ranges) + raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32) + + ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)] + first_attention = ordered_layers[0] + num_layers = len(ordered_layers) + num_heads = int(first_attention.shape[1]) + + for source_idx, (source_start, source_end) in enumerate(token_ranges): + for target_idx, (target_start, target_end) in enumerate(token_ranges): + if target_idx <= source_idx: + continue + + total = 0.0 + for attention in ordered_layers: + grad = attention.grad + if grad is None: + raise RuntimeError("Attention gradient was not retained for one or more layers.") + total += -( + grad[0, :, target_start:target_end, source_start:source_end] + * attention[0, :, target_start:target_end, source_start:source_end] + ).sum().item() + + denominator = max(1, target_end - target_start) + raw_matrix[target_idx, source_idx] = total / denominator + + matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0 + total_analysis_ms = ( + forward_pass_ms + backward_pass_ms + matrix_computation_ms + ) + presentation_matrix = _build_presentation_matrix(raw_matrix, take_log) + + attention_impl = getattr(model.config, "_attn_implementation", "unknown") + capability = describe_model_support(model) + runtime_metadata = RuntimeMetadata( + forward_pass_ms=forward_pass_ms, + backward_pass_ms=backward_pass_ms, + matrix_computation_ms=matrix_computation_ms, + total_analysis_ms=total_analysis_ms, + num_layers=num_layers, + num_heads=num_heads, + sequence_length_tokens=int(input_ids.shape[1]), + sentence_count=sentence_count, + device=str(device), + dtype=str(first_attention.dtype), + attention_impl=str(attention_impl), + max_trace_tokens=max_trace_tokens, + max_sentences=max_sentences, + capability=ModelCapability.model_validate(asdict(capability)), + ) + + memory_after_mb = _current_memory_mb(device) + if 
memory_before_mb is not None and memory_after_mb is not None: + runtime_metadata = runtime_metadata.model_copy( + update={ + "device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)" + } + ) + + return AttributionMatrixComputation( + matrix=presentation_matrix, + raw_matrix=raw_matrix, + token_nll=token_nll.detach().cpu().numpy(), + runtime_metadata=runtime_metadata, + ) + finally: + for attention in get_stored_attentions().values(): + attention.grad = None + remove_hooks(handles) + model.zero_grad(set_to_none=True) + if device.type == "cuda": + torch.cuda.empty_cache() diff --git a/app/analysis/token_boundaries.py b/app/analysis/token_boundaries.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc31a75675744f4dfce0b1b311e335a0ecd5d50 --- /dev/null +++ b/app/analysis/token_boundaries.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + +from app.analysis.sentence_split import SentenceSpan + + +@dataclass(slots=True) +class TokenizedSentenceMapping: + input_ids: torch.Tensor + token_ranges: list[tuple[int, int]] + offsets: list[tuple[int, int]] + text: str + + +def truncate_text_to_token_limit(text: str, tokenizer: Any, max_tokens: int) -> str: + if max_tokens <= 0: + raise ValueError("max_tokens must be positive.") + encoded = tokenizer( + text, + add_special_tokens=False, + return_offsets_mapping=True, + ) + offsets = encoded["offset_mapping"] + if len(offsets) <= max_tokens: + return text + end_char = offsets[max_tokens - 1][1] + return text[:end_char] + + +def tokenize_with_sentence_ranges( + text: str, + sentence_spans: list[SentenceSpan], + tokenizer: Any, +) -> TokenizedSentenceMapping: + encoded = tokenizer( + text, + add_special_tokens=False, + return_offsets_mapping=True, + return_tensors="pt", + ) + + input_ids = encoded["input_ids"] + raw_offsets = encoded["offset_mapping"][0].tolist() + offsets = [(int(start), 
int(end)) for start, end in raw_offsets] + token_ranges: list[tuple[int, int]] = [] + + for span in sentence_spans: + overlapping = [ + token_index + for token_index, (token_start, token_end) in enumerate(offsets) + if token_end > span.start_char and token_start < span.end_char + ] + if not overlapping: + raise ValueError( + f"Sentence span {span.start_char}:{span.end_char} mapped to zero tokens." + ) + token_ranges.append((overlapping[0], overlapping[-1] + 1)) + + if token_ranges: + if token_ranges[0][0] != 0 or token_ranges[-1][1] != len(offsets): + raise ValueError("Sentence token ranges do not cover the full analyzed sequence.") + for previous, current in zip(token_ranges, token_ranges[1:]): + if previous[1] != current[0]: + raise ValueError("Sentence token ranges are not contiguous.") + + return TokenizedSentenceMapping( + input_ids=input_ids, + token_ranges=token_ranges, + offsets=offsets, + text=text, + ) diff --git a/app/analysis/validation.py b/app/analysis/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..b47ac29965bb8f4a07e9c31d4bde42a70a10516b --- /dev/null +++ b/app/analysis/validation.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch +from scipy.stats import pearsonr, spearmanr + +from app.analysis.suppression import compute_self_token_nll +from app.core.schemas import TopEdge, ValidationMetadata + + +def _nll_slice_for_token_range(token_range: tuple[int, int]) -> slice: + start, end = token_range + return slice(max(0, start - 1), max(0, end - 1)) + + +def build_exact_suppression_mask( + *, + sequence_length: int, + source_range: tuple[int, int], + target_range: tuple[int, int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + fill_value = torch.finfo(dtype).min + mask = torch.zeros((sequence_length, sequence_length), device=device, dtype=dtype) + future_positions = torch.triu( + torch.ones((sequence_length, sequence_length), 
device=device, dtype=torch.bool), + diagonal=1, + ) + mask = mask.masked_fill(future_positions, fill_value) + source_start, source_end = source_range + target_start, target_end = target_range + mask[target_start:target_end, source_start:source_end] = fill_value + return mask.unsqueeze(0).unsqueeze(0) + + +def compute_exact_edge_score( + *, + model: Any, + input_ids: torch.Tensor, + source_range: tuple[int, int], + target_range: tuple[int, int], + baseline_token_nll: np.ndarray, +) -> float: + model_dtype = next(model.parameters()).dtype + attention_mask = build_exact_suppression_mask( + sequence_length=int(input_ids.shape[1]), + source_range=source_range, + target_range=target_range, + device=input_ids.device, + dtype=model_dtype, + ) + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=False, + return_dict=True, + ) + suppressed_nll = compute_self_token_nll(outputs.logits, input_ids).detach().cpu().numpy() + nll_slice = _nll_slice_for_token_range(target_range) + return float(suppressed_nll[nll_slice].sum() - baseline_token_nll[nll_slice].sum()) + + +def validate_top_edges( + *, + model: Any, + input_ids: torch.Tensor, + token_ranges: list[tuple[int, int]], + top_edges: list[TopEdge], + baseline_token_nll: np.ndarray, + top_k: int, +) -> ValidationMetadata: + if top_k <= 0 or not top_edges: + return ValidationMetadata(enabled=False, top_k=0) + + selected_edges = top_edges[:top_k] + exact_scores: list[float] = [] + attributed_scores: list[float] = [] + compared_edges: list[TopEdge] = [] + + try: + for edge in selected_edges: + exact_score = compute_exact_edge_score( + model=model, + input_ids=input_ids, + source_range=token_ranges[edge.source_sentence_idx], + target_range=token_ranges[edge.target_sentence_idx], + baseline_token_nll=baseline_token_nll, + ) + exact_scores.append(exact_score) + attributed_scores.append(edge.score) + compared_edges.append( + TopEdge( + 
source_sentence_idx=edge.source_sentence_idx, + target_sentence_idx=edge.target_sentence_idx, + score=exact_score, + ) + ) + except Exception as exc: # pragma: no cover - environment/model dependent + return ValidationMetadata( + enabled=True, + top_k=top_k, + compared_edges=[], + notes=f"Exact suppression validation failed: {exc}", + ) + + pearson = float(pearsonr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None + spearman = float(spearmanr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None + + return ValidationMetadata( + enabled=True, + top_k=top_k, + pearson=pearson, + spearman=spearman, + compared_edges=compared_edges, + notes="Exact suppression compares sentence-level NLL deltas for selected edges.", + ) diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1abc8addac74151fc1a80fea35ceb9836e9145ce --- /dev/null +++ b/app/api/__init__.py @@ -0,0 +1 @@ +"""API package for GPU-hosted analysis service.""" diff --git a/app/api/auth.py b/app/api/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..717879f41d752e903626c810d17d334a311dacd6 --- /dev/null +++ b/app/api/auth.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from fastapi import HTTPException, Request +from huggingface_hub import parse_huggingface_oauth + +from app.core.config import Settings + + +@dataclass(frozen=True, slots=True) +class UserContext: + id: str + username: str + display_name: str | None + avatar_url: str | None + authenticated: bool + + +def get_optional_user(request: Request) -> UserContext | None: + if "session" not in request.scope: + return None + try: + oauth_info = parse_huggingface_oauth(request) + except AssertionError: + return None + if oauth_info is None: + return None + + user_info = oauth_info.user_info + username = user_info.preferred_username or user_info.sub or "hf-user" + 
display_name = user_info.name or username + return UserContext( + id=user_info.sub or username, + username=username, + display_name=display_name, + avatar_url=user_info.picture, + authenticated=True, + ) + + +def require_user(request: Request, settings: Settings) -> UserContext: + user = get_optional_user(request) + if user is not None: + return user + if settings.require_auth: + raise HTTPException(status_code=401, detail="Sign in with Hugging Face to use this service.") + return UserContext( + id="anonymous", + username="anonymous", + display_name="Anonymous", + avatar_url=None, + authenticated=False, + ) diff --git a/app/api/main.py b/app/api/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e94e4946f3b41d44976e5267dbfb4f8737bb3540 --- /dev/null +++ b/app/api/main.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import logging +from contextlib import asynccontextmanager +from io import StringIO +from pathlib import Path +from dataclasses import asdict +import os + +import torch +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import FileResponse +from fastapi.responses import StreamingResponse +from fastapi.staticfiles import StaticFiles +from huggingface_hub import attach_huggingface_oauth + +from app.api.auth import get_optional_user, require_user +from app.core.config import get_settings +from app.core.runtime import load_model_bundle +from app.core.runtime_pipeline import compute_attribution_analysis +from app.core.schemas import ( + AnalysisRequest, + AnalysisResult, + CurrentUserResponse, + HealthResponse, + SessionCreateRequest, + SessionResponse, + SessionResultResponse, + WarmupResponse, +) +from app.services.sessions import SessionAccessError, SessionLimitError, SessionService +from app.storage.repository import SessionRepository +from app.workers.jobs import build_job_runner + +logger = logging.getLogger(__name__) +FRONTEND_DIR = Path(__file__).resolve().parents[1] / "frontend" + + 
+@asynccontextmanager +async def lifespan(_app: FastAPI): + settings = get_settings() + repository = SessionRepository(settings.database_path) + jobs = build_job_runner(settings.job_workers) + _app.state.repository = repository + _app.state.jobs = jobs + _app.state.session_service = SessionService(settings=settings, repository=repository, jobs=jobs) + if settings.preload_model: + logger.info( + "Preloading model '%s' on device preference '%s'.", + settings.model_name, + settings.device_preference, + ) + load_model_bundle( + settings.model_name, + device_preference=settings.device_preference, + dtype_preference=settings.dtype_preference, + attn_implementation=settings.attn_implementation, + trust_remote_code=settings.trust_remote_code, + low_cpu_mem_usage=settings.low_cpu_mem_usage, + ) + yield + jobs.shutdown() + + +app = FastAPI( + title="CoT Attribution Analysis API", + version="0.1.0", + lifespan=lifespan, +) +if os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"): + attach_huggingface_oauth(app) +app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static") + + +def get_session_service() -> SessionService: + return app.state.session_service + + +def _to_session_response(payload: dict) -> SessionResponse: + return SessionResponse( + id=payload["id"], + status=payload["status"], + question=payload["question"], + model_name=payload["model_name"], + error=payload.get("error"), + created_at=payload["created_at"], + updated_at=payload["updated_at"], + answer=payload.get("answer"), + raw_trace_text=payload.get("raw_trace_text"), + normalized_trace_text=payload.get("normalized_trace_text"), + sentences=payload.get("sentences"), + generation_metadata=payload.get("generation_metadata"), + ) + + +@app.get("/", include_in_schema=False) +def index() -> FileResponse: + return FileResponse(FRONTEND_DIR / "index.html") + + +@app.get("/healthz", response_model=HealthResponse) +def healthz() -> HealthResponse: + settings = get_settings() + return HealthResponse( + 
status="ok", + model_name=settings.model_name, + device_preference=settings.device_preference, + dtype_preference=settings.dtype_preference, + preload_model=settings.preload_model, + cuda_available=torch.cuda.is_available(), + mps_available=torch.backends.mps.is_available(), + require_auth=settings.require_auth, + public_api_enabled=settings.public_api_enabled, + max_queued_jobs=settings.max_queued_jobs, + max_active_jobs_per_user=settings.max_active_jobs_per_user, + ) + + +@app.get("/api/me", response_model=CurrentUserResponse) +def me(request: Request) -> CurrentUserResponse: + settings = get_settings() + user = get_optional_user(request) + return CurrentUserResponse( + authenticated=user is not None, + auth_required=settings.require_auth, + username=user.username if user else None, + full_name=user.display_name if user else None, + avatar_url=user.avatar_url if user else None, + ) + + +@app.post("/api/warmup", response_model=WarmupResponse) +def warmup(model_name: str | None = None, device_preference: str | None = None) -> WarmupResponse: + settings = get_settings() + bundle = load_model_bundle( + model_name or settings.model_name, + device_preference=device_preference or settings.device_preference, + dtype_preference=settings.dtype_preference, + attn_implementation=settings.attn_implementation, + trust_remote_code=settings.trust_remote_code, + low_cpu_mem_usage=settings.low_cpu_mem_usage, + ) + return WarmupResponse( + status="ready", + model_name=bundle.model_name, + device=str(bundle.device), + dtype=str(bundle.dtype), + capability=asdict(bundle.capability), + ) + + +@app.post("/api/analyze", response_model=AnalysisResult) +def analyze(request: AnalysisRequest, http_request: Request) -> AnalysisResult: + settings = get_settings() + require_user(http_request, settings) + try: + return compute_attribution_analysis( + question=request.question, + model_name=request.model_name, + take_log=request.take_log, + max_sentences=request.max_sentences, + 
max_trace_tokens=request.max_trace_tokens, + validate_top_k=request.validate_top_k, + max_new_tokens=request.max_new_tokens, + temperature=request.temperature, + top_p=request.top_p, + device_preference=request.device_preference, + dtype_preference=request.dtype_preference, + attn_implementation=request.attn_implementation, + trust_remote_code=request.trust_remote_code, + low_cpu_mem_usage=request.low_cpu_mem_usage, + ) + except Exception as exc: # pragma: no cover - runtime path + logger.exception("Analysis request failed") + raise HTTPException(status_code=500, detail=str(exc)) from exc + + +@app.get("/api/sessions", response_model=list[SessionResponse]) +def list_sessions(request: Request, limit: int = 20) -> list[SessionResponse]: + settings = get_settings() + user = require_user(request, settings) + service = get_session_service() + payloads = service.list_sessions(user.id, limit=limit) + return [_to_session_response(payload) for payload in payloads] + + +@app.post("/api/sessions", response_model=SessionResponse) +def create_session(request: SessionCreateRequest, http_request: Request) -> SessionResponse: + settings = get_settings() + user = require_user(http_request, settings) + service = get_session_service() + try: + session = service.create_session( + AnalysisRequest( + question=request.question, + model_name=request.model_name, + take_log=request.take_log, + max_sentences=request.max_sentences, + max_trace_tokens=request.max_trace_tokens, + validate_top_k=request.validate_top_k, + max_new_tokens=request.max_new_tokens, + temperature=request.temperature, + top_p=request.top_p, + device_preference=request.device_preference, + dtype_preference=request.dtype_preference, + attn_implementation=request.attn_implementation, + trust_remote_code=request.trust_remote_code, + low_cpu_mem_usage=request.low_cpu_mem_usage, + ), + owner_id=user.id, + owner_name=user.display_name, + ) + except SessionLimitError as exc: + raise HTTPException(status_code=429, 
detail=str(exc)) from exc + payload = service.get_session_payload(session.id, owner_id=user.id) + return _to_session_response(payload) + + +@app.get("/api/sessions/{session_id}", response_model=SessionResponse) +def get_session(session_id: str, request: Request) -> SessionResponse: + settings = get_settings() + user = require_user(request, settings) + service = get_session_service() + try: + payload = service.get_session_payload(session_id, owner_id=user.id) + except KeyError as exc: + raise HTTPException(status_code=404, detail="Session not found") from exc + except SessionAccessError as exc: + raise HTTPException(status_code=403, detail=str(exc)) from exc + return _to_session_response(payload) + + +@app.post("/api/sessions/{session_id}/analyze", response_model=SessionResponse) +def analyze_session(session_id: str, request: Request) -> SessionResponse: + settings = get_settings() + user = require_user(request, settings) + service = get_session_service() + try: + session = service.start_analysis(session_id, owner_id=user.id) + payload = service.get_session_payload(session.id, owner_id=user.id) + except KeyError as exc: + raise HTTPException(status_code=404, detail="Session not found") from exc + except SessionAccessError as exc: + raise HTTPException(status_code=403, detail=str(exc)) from exc + return _to_session_response(payload) + + +@app.get("/api/sessions/{session_id}/result", response_model=SessionResultResponse) +def get_session_result(session_id: str, request: Request) -> SessionResultResponse: + settings = get_settings() + user = require_user(request, settings) + service = get_session_service() + try: + payload = service.get_session_payload(session_id, owner_id=user.id) + except KeyError as exc: + raise HTTPException(status_code=404, detail="Session not found") from exc + except SessionAccessError as exc: + raise HTTPException(status_code=403, detail=str(exc)) from exc + + session_response = _to_session_response(payload) + analysis_payload = 
@app.get("/api/sessions/{session_id}/export.csv")
def export_session_csv(session_id: str, request: Request) -> StreamingResponse:
    """Return the session's top edges as a downloadable CSV attachment.

    The CSV has a fixed header row followed by one row per entry of
    ``result.top_edges``, with the score rendered to six decimal places.

    Raises:
        HTTPException: 404 when the service raises ``KeyError``; 403 when it
            raises ``SessionAccessError``.
    """
    settings = get_settings()
    caller = require_user(request, settings)
    sessions = get_session_service()
    try:
        analysis = sessions.get_analysis_result(session_id, owner_id=caller.id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Analysis result not found") from exc
    except SessionAccessError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc

    # Assemble the whole document up front; it is streamed as a single chunk.
    rows = ["source_sentence_idx,target_sentence_idx,score\n"]
    rows.extend(
        f"{edge.source_sentence_idx},{edge.target_sentence_idx},{edge.score:.6f}\n"
        for edge in analysis.top_edges
    )
    csv_text = "".join(rows)

    disposition = f'attachment; filename="{session_id}.csv"'
    return StreamingResponse(
        iter([csv_text]),
        media_type="text/csv",
        headers={"content-disposition": disposition},
    )
@cli.command()
def main(
    question: str,
    model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    output_json: Path | None = None,
    output_heatmap: Path | None = None,
    max_new_tokens: int = 512,
    max_trace_tokens: int = 1024,
    max_sentences: int = 40,
    take_log: bool = True,
    validate_top_k: int = 3,
    temperature: float = 0.6,
    top_p: float = 0.95,
    device_preference: str = "auto",
    dtype_preference: str = "auto",
    attn_implementation: str = "eager",
    trust_remote_code: bool = True,
    low_cpu_mem_usage: bool = True,
) -> None:
    """Run the attribution analysis for one question and write the results to disk.

    The analysis JSON goes to ``output_json`` when given, otherwise to
    ``outputs/analysis_<UTC timestamp>.json``. When ``output_heatmap`` is given,
    the suppression matrix is also rendered there. A short summary is echoed
    to the terminal.
    """
    result = compute_attribution_analysis(
        question=question,
        model_name=model_name,
        take_log=take_log,
        max_sentences=max_sentences,
        max_trace_tokens=max_trace_tokens,
        validate_top_k=validate_top_k,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        device_preference=device_preference,
        dtype_preference=dtype_preference,
        attn_implementation=attn_implementation,
        trust_remote_code=trust_remote_code,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )

    timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
    json_path = output_json or Path("outputs") / f"analysis_{timestamp}.json"
    # Create the directory that will actually receive the file. A user-supplied
    # path in a missing directory would otherwise fail only after the analysis
    # has already been computed.
    json_path.parent.mkdir(parents=True, exist_ok=True)
    json_path.write_text(json.dumps(result.model_dump(), indent=2), encoding="utf-8")

    if output_heatmap is not None:
        output_heatmap.parent.mkdir(parents=True, exist_ok=True)
        _write_heatmap(output_heatmap, result.suppression_matrix)

    typer.echo(f"Wrote analysis JSON to {json_path}")
    typer.echo(f"Sentences: {len(result.sentences)}")
    typer.echo(f"Top edge count: {len(result.top_edges)}")
    if result.validation_metadata and result.validation_metadata.enabled:
        typer.echo(
            "Validation:"
            f" pearson={result.validation_metadata.pearson}"
            f" spearman={result.validation_metadata.spearman}"
        )
a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a95c2e534fa165a8c4f0d91d134756c52165f201 --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import os +from functools import lru_cache +from dataclasses import dataclass +from typing import Literal + + +@dataclass(frozen=True, slots=True) +class Settings: + model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + max_trace_tokens: int = 1024 + max_sentences: int = 40 + take_log: bool = True + device_preference: Literal["auto", "cuda", "cpu", "mps"] = "auto" + dtype_preference: Literal["auto", "float32", "float16", "bfloat16"] = "auto" + attn_implementation: str = "eager" + trust_remote_code: bool = True + low_cpu_mem_usage: bool = True + preload_model: bool = False + api_host: str = "0.0.0.0" + api_port: int = 7860 + database_path: str = "data/app.db" + job_workers: int = 1 + max_queued_jobs: int = 8 + max_active_jobs_per_user: int = 2 + require_auth: bool = True + public_api_enabled: bool = True + + +DEFAULT_SETTINGS = Settings() + + +@lru_cache(maxsize=1) +def get_settings() -> Settings: + take_log = os.getenv("TAKE_LOG", "true").strip().lower() in {"1", "true", "yes", "on"} + trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "true").strip().lower() in {"1", "true", "yes", "on"} + low_cpu_mem_usage = os.getenv("LOW_CPU_MEM_USAGE", "true").strip().lower() in {"1", "true", "yes", "on"} + require_auth = os.getenv("REQUIRE_AUTH", "true").strip().lower() in {"1", "true", "yes", "on"} + public_api_enabled = os.getenv("PUBLIC_API_ENABLED", "true").strip().lower() in {"1", "true", "yes", "on"} + return Settings( + model_name=os.getenv("MODEL_NAME", DEFAULT_SETTINGS.model_name), + max_trace_tokens=int(os.getenv("MAX_TRACE_TOKENS", DEFAULT_SETTINGS.max_trace_tokens)), + max_sentences=int(os.getenv("MAX_SENTENCES", DEFAULT_SETTINGS.max_sentences)), + take_log=take_log, + 
device_preference=os.getenv("DEVICE_PREFERENCE", DEFAULT_SETTINGS.device_preference), # type: ignore[arg-type] + dtype_preference=os.getenv("DTYPE_PREFERENCE", DEFAULT_SETTINGS.dtype_preference), # type: ignore[arg-type] + attn_implementation=os.getenv("ATTN_IMPLEMENTATION", DEFAULT_SETTINGS.attn_implementation), + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=low_cpu_mem_usage, + preload_model=os.getenv("PRELOAD_MODEL", "false").strip().lower() in {"1", "true", "yes", "on"}, + api_host=os.getenv("API_HOST", DEFAULT_SETTINGS.api_host), + api_port=int(os.getenv("API_PORT", DEFAULT_SETTINGS.api_port)), + database_path=os.getenv("DATABASE_PATH", DEFAULT_SETTINGS.database_path), + job_workers=int(os.getenv("JOB_WORKERS", DEFAULT_SETTINGS.job_workers)), + max_queued_jobs=int(os.getenv("MAX_QUEUED_JOBS", DEFAULT_SETTINGS.max_queued_jobs)), + max_active_jobs_per_user=int( + os.getenv("MAX_ACTIVE_JOBS_PER_USER", DEFAULT_SETTINGS.max_active_jobs_per_user) + ), + require_auth=require_auth, + public_api_enabled=public_api_enabled, + ) diff --git a/app/core/model_support.py b/app/core/model_support.py new file mode 100644 index 0000000000000000000000000000000000000000..df0769cb13f23a1af10c4891c224b3f2a0e6c268 --- /dev/null +++ b/app/core/model_support.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class ModelSupport: + supports_attribution: bool + reason: str | None + layer_path: str | None + attention_attr: str | None + layer_count: int + attention_impl: str | None + + +_LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = ( + ("model", "layers"), + ("model", "model", "layers"), + ("transformer", "h"), + ("gpt_neox", "layers"), +) +_ATTENTION_ATTR_CANDIDATES: tuple[str, ...] 
def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None:
    """Follow *path* attribute by attribute, starting from *obj*.

    Returns the object reached at the end of the chain, or None as soon as a
    segment is missing or resolves to None. An empty path returns *obj* itself.
    """
    if not path:
        return obj
    head, *rest = path
    child = getattr(obj, head, None)
    if child is None:
        return None
    return _resolve_attr_chain(child, tuple(rest))
def resolve_dtype(preference: str, device: torch.device) -> torch.dtype:
    """Choose the torch dtype for *preference* on *device*.

    An explicit ``"float32"``, ``"float16"`` or ``"bfloat16"`` is honoured as
    given. Any other value is resolved from the device type: CUDA uses bfloat16
    when the hardware reports support and float16 otherwise, MPS uses float16,
    and everything else uses float32.
    """
    if preference in ("float32", "float16", "bfloat16"):
        # The accepted names match the torch dtype attribute names exactly.
        return getattr(torch, preference)

    kind = device.type
    if kind == "cuda":
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    return torch.float16 if kind == "mps" else torch.float32
@lru_cache(maxsize=2)
def load_model_bundle(
    model_name: str,
    device_preference: str = "auto",
    dtype_preference: str = "auto",
    attn_implementation: str = "eager",
    trust_remote_code: bool = True,
    low_cpu_mem_usage: bool = True,
) -> ModelBundle:
    """Load the tokenizer and causal LM for *model_name* and wrap them in a bundle.

    Results are memoised per argument combination (at most two entries). The
    model is moved to the resolved device, put in eval mode, and probed with
    ``describe_model_support`` before being returned.
    """
    target_device = resolve_device(device_preference)
    target_dtype = resolve_dtype(dtype_preference, target_device)

    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    # Fall back to the EOS token when the tokenizer has no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    load_options = {
        "trust_remote_code": trust_remote_code,
        "attn_implementation": attn_implementation,
        "torch_dtype": target_dtype,
        "low_cpu_mem_usage": low_cpu_mem_usage,
    }
    net = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    net.to(target_device)
    net.eval()
    support = describe_model_support(net)

    return ModelBundle(
        model_name=model_name,
        model=net,
        tokenizer=tok,
        device=target_device,
        dtype=target_dtype,
        capability=support,
    )
def analyze_generation_result(
    *,
    question: str,
    generation: GenerationResult | None = None,
    model_name: str | None = None,
    take_log: bool | None = None,
    max_sentences: int | None = None,
    max_trace_tokens: int | None = None,
    validate_top_k: int = 0,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
    device_preference: str | None = None,
    dtype_preference: str | None = None,
    attn_implementation: str | None = None,
    trust_remote_code: bool | None = None,
    low_cpu_mem_usage: bool | None = None,
) -> AnalysisResult:
    """Run sentence-level attribution analysis over a generated trace.

    Unset arguments fall back to the values from ``get_settings()``. When
    ``generation`` is None, the answer and trace are generated first with the
    loaded model; otherwise the supplied generation is analyzed as-is.

    Args:
        question: The question being analyzed; echoed into the result.
        generation: A previously produced generation to analyze, if any.
        model_name, device_preference, dtype_preference, attn_implementation,
            trust_remote_code, low_cpu_mem_usage: Model loading options passed
            to ``load_model_bundle``.
        take_log, max_sentences, max_trace_tokens: Analysis options passed to
            the truncation and attribution steps.
        validate_top_k: Passed to ``validate_top_edges`` as ``top_k``.
        max_new_tokens, temperature, top_p: Generation options, used only when
            ``generation`` is None.

    Returns:
        An ``AnalysisResult`` whose ``normalized_trace_text`` is the truncated
        text that was actually analyzed.

    Raises:
        RuntimeError: If the loaded model does not support attribution, or if
            no sentences remain after truncation and splitting.
    """
    settings = get_settings()
    resolved_model_name = model_name or settings.model_name
    resolved_take_log = settings.take_log if take_log is None else take_log
    # NOTE(review): `or` makes an explicit 0 fall back to the configured value,
    # unlike the `is None` checks used for the boolean options. The `> 0` guard
    # further down is therefore only bypassed when the configured value itself
    # is 0 -- confirm this is intended.
    resolved_max_sentences = max_sentences or settings.max_sentences
    resolved_max_trace_tokens = max_trace_tokens or settings.max_trace_tokens
    resolved_device = device_preference or settings.device_preference
    resolved_dtype = dtype_preference or settings.dtype_preference
    resolved_attn_implementation = attn_implementation or settings.attn_implementation
    resolved_trust_remote_code = settings.trust_remote_code if trust_remote_code is None else trust_remote_code
    resolved_low_cpu_mem_usage = (
        settings.low_cpu_mem_usage if low_cpu_mem_usage is None else low_cpu_mem_usage
    )

    bundle = load_model_bundle(
        resolved_model_name,
        device_preference=resolved_device,
        dtype_preference=resolved_dtype,
        attn_implementation=resolved_attn_implementation,
        trust_remote_code=resolved_trust_remote_code,
        low_cpu_mem_usage=resolved_low_cpu_mem_usage,
    )
    # Refuse before generating anything if the model cannot be analyzed.
    if not bundle.capability.supports_attribution:
        reason = bundle.capability.reason or "Model does not support attribution analysis."
        raise RuntimeError(reason)
    if generation is None:
        generation = generate_answer_and_trace(
            question=question,
            model_name=resolved_model_name,
            model=bundle.model,
            tokenizer=bundle.tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

    # First cap the trace by token count, then by sentence count.
    truncated_text = truncate_text_to_token_limit(
        generation.normalized_trace_text,
        bundle.tokenizer,
        resolved_max_trace_tokens,
    )
    sentence_spans = split_sentences(truncated_text)
    if resolved_max_sentences > 0 and len(sentence_spans) > resolved_max_sentences:
        sentence_spans = sentence_spans[:resolved_max_sentences]
        # Cut the text at the end of the last kept sentence, then re-split so
        # the spans refer to the shortened text.
        truncated_text = truncated_text[: sentence_spans[-1].end_char]
        sentence_spans = split_sentences(truncated_text)

    if not sentence_spans:
        raise RuntimeError("Trace normalization produced no analyzable sentences.")

    mapping = tokenize_with_sentence_ranges(truncated_text, sentence_spans, bundle.tokenizer)
    input_ids = mapping.input_ids.to(bundle.device)
    computation = compute_attribution_matrix(
        input_ids=input_ids,
        token_ranges=mapping.token_ranges,
        model=bundle.model,
        take_log=resolved_take_log,
        max_trace_tokens=resolved_max_trace_tokens,
        max_sentences=resolved_max_sentences,
    )

    # Summaries and validation are computed from the raw matrix, not from
    # `computation.matrix`.
    outgoing = compute_outgoing_importance(computation.raw_matrix)
    incoming = compute_incoming_importance(computation.raw_matrix)
    top_edges = compute_top_edges(computation.raw_matrix, top_k=10)
    validation = validate_top_edges(
        model=bundle.model,
        input_ids=input_ids,
        token_ranges=mapping.token_ranges,
        top_edges=top_edges,
        baseline_token_nll=computation.token_nll,
        top_k=validate_top_k,
    )

    return AnalysisResult(
        question=question,
        model_name=resolved_model_name,
        answer=generation.answer,
        raw_trace_text=generation.raw_trace_text,
        normalized_trace_text=truncated_text,
        sentences=[span.text for span in sentence_spans],
        sentence_token_ranges=mapping.token_ranges,
        suppression_matrix=computation.matrix.tolist(),
        raw_suppression_matrix=computation.raw_matrix.tolist(),
        outgoing_importance=outgoing,
        incoming_importance=incoming,
        top_edges=top_edges,
        runtime_metadata=computation.runtime_metadata,
        validation_metadata=validation,
        extra_metadata={
            "raw_generation_text": generation.raw_generation_text,
            "generation_metadata": generation.generation_metadata.model_dump(),
            # Record the options that were actually used after fallback.
            "effective_runtime": {
                "device_preference": resolved_device,
                "dtype_preference": resolved_dtype,
                "attn_implementation": resolved_attn_implementation,
                "trust_remote_code": resolved_trust_remote_code,
                "low_cpu_mem_usage": resolved_low_cpu_mem_usage,
            },
        },
    )
Field(default_factory=list) + notes: str | None = None + + +class GenerationMetadata(BaseModel): + max_new_tokens: int + temperature: float + top_p: float + do_sample: bool + + +class GenerationResult(BaseModel): + question: str + model_name: str + answer: str + raw_generation_text: str + raw_trace_text: str + normalized_trace_text: str + generation_metadata: GenerationMetadata + + +class AnalysisResult(BaseModel): + question: str + model_name: str + answer: str + raw_trace_text: str + normalized_trace_text: str + sentences: list[str] + sentence_token_ranges: list[tuple[int, int]] + suppression_matrix: list[list[float]] + raw_suppression_matrix: list[list[float]] | None = None + outgoing_importance: list[float] + incoming_importance: list[float] + top_edges: list[TopEdge] + runtime_metadata: RuntimeMetadata + validation_metadata: ValidationMetadata | None = None + extra_metadata: dict[str, Any] = Field(default_factory=dict) + + +class AnalysisRequest(BaseModel): + question: str + model_name: str | None = None + take_log: bool | None = None + max_sentences: int | None = None + max_trace_tokens: int | None = None + validate_top_k: int = 0 + max_new_tokens: int = 256 + temperature: float = 0.6 + top_p: float = 0.95 + device_preference: str | None = None + dtype_preference: str | None = None + attn_implementation: str | None = None + trust_remote_code: bool | None = None + low_cpu_mem_usage: bool | None = None + + +class HealthResponse(BaseModel): + status: str + model_name: str + device_preference: str + dtype_preference: str + preload_model: bool + cuda_available: bool + mps_available: bool + require_auth: bool + public_api_enabled: bool + max_queued_jobs: int + max_active_jobs_per_user: int + + +class WarmupResponse(BaseModel): + status: str + model_name: str + device: str + dtype: str + capability: ModelCapability + + +class CurrentUserResponse(BaseModel): + authenticated: bool + auth_required: bool + username: str | None = None + full_name: str | None = None + 
avatar_url: str | None = None + login_url: str = "/oauth/huggingface/login" + logout_url: str = "/oauth/huggingface/logout" + + +class SessionCreateRequest(BaseModel): + question: str + model_name: str | None = None + take_log: bool | None = None + max_sentences: int | None = None + max_trace_tokens: int | None = None + validate_top_k: int = 0 + max_new_tokens: int = 256 + temperature: float = 0.6 + top_p: float = 0.95 + device_preference: str | None = None + dtype_preference: str | None = None + attn_implementation: str | None = None + trust_remote_code: bool | None = None + low_cpu_mem_usage: bool | None = None + + +class SessionResponse(BaseModel): + id: str + status: str + question: str + model_name: str + error: str | None = None + created_at: str + updated_at: str + answer: str | None = None + raw_trace_text: str | None = None + normalized_trace_text: str | None = None + sentences: list[str] | None = None + generation_metadata: dict[str, Any] | None = None + + +class SessionResultResponse(BaseModel): + session: SessionResponse + analysis: AnalysisResult | None = None diff --git a/app/frontend/app.js b/app/frontend/app.js new file mode 100644 index 0000000000000000000000000000000000000000..473ef1254a1d56a84a9caeb970281c9b4206cf06 --- /dev/null +++ b/app/frontend/app.js @@ -0,0 +1,340 @@ +const form = document.querySelector("#question-form"); +const submitButton = document.querySelector("#submit-button"); +const statusChip = document.querySelector("#status-chip"); +const statusDetail = document.querySelector("#status-detail"); +const authSummary = document.querySelector("#auth-summary"); +const loginLink = document.querySelector("#login-link"); +const logoutLink = document.querySelector("#logout-link"); +const recentSessions = document.querySelector("#recent-sessions"); +const answerOutput = document.querySelector("#answer-output"); +const sessionIdOutput = document.querySelector("#session-id"); +const sentenceList = document.querySelector("#sentence-list"); 
/**
 * Point the JSON and CSV export links at the given session, or disable them.
 *
 * When `enabled` is falsy the links target "#", receive the "is-disabled"
 * class, and are marked aria-disabled="true".
 */
function setExportLinks(sessionId, enabled) {
  const disabled = !enabled;
  const targets = [
    [exportJson, "json"],
    [exportCsv, "csv"],
  ];
  for (const [link, extension] of targets) {
    link.href = enabled ? `/api/sessions/${sessionId}/export.${extension}` : "#";
    link.classList.toggle("is-disabled", disabled);
    link.setAttribute("aria-disabled", String(disabled));
  }
}
/**
 * Fetch the caller's most recent sessions and render them.
 *
 * Renders an empty list without calling the API when sign-in is required but
 * the user is not authenticated, and also when the request returns a non-OK
 * response.
 */
async function loadRecentSessions() {
  const signInNeeded = Boolean(currentUser?.auth_required) && !currentUser?.authenticated;
  let sessions = [];
  if (!signInNeeded) {
    const reply = await fetch("/api/sessions?limit=8");
    if (reply.ok) {
      sessions = await reply.json();
    }
  }
  renderRecentSessions(sessions);
}
/**
 * Map a matrix value to an rgba() colour string for the heatmap.
 *
 * The value is scaled by `maxAbs` and clamped to [-1, 1]. Non-negative values
 * use the blue tone and negative values the orange tone, with opacity growing
 * from 0.12 to 1.0 with magnitude. A falsy `maxAbs` yields a neutral grey.
 */
function colorForValue(value, maxAbs) {
  if (!maxAbs) {
    return "rgba(224, 223, 218, 0.8)";
  }
  const scaled = Math.min(1, value / maxAbs);
  const clamped = Math.max(-1, scaled);
  const tone = clamped >= 0 ? "14, 90, 138" : "215, 106, 52";
  const opacity = 0.12 + Math.abs(clamped) * 0.88;
  return `rgba(${tone}, ${opacity.toFixed(3)})`;
}
heatmap.appendChild(grid); +} + +function renderTopEdges(result) { + topEdges.innerHTML = ""; + const edges = result?.top_edges || []; + if (!edges.length) { + return; + } + edges.slice(0, 5).forEach((edge) => { + const item = document.createElement("div"); + item.className = "edge-card"; + item.innerHTML = `${edge.source_sentence_idx} → ${edge.target_sentence_idx}${edge.score.toFixed(4)}`; + topEdges.appendChild(item); + }); +} + +function renderSelection() { + Array.from(sentenceList.children).forEach((item, index) => { + item.classList.toggle("is-active", selectedSentenceIdx === index); + }); + + const result = activePayload?.analysis; + const session = activePayload?.session; + if (!session) { + clearSelection(); + return; + } + + metrics.innerHTML = ""; + if (selectedSentenceIdx != null && result) { + const outgoing = result.outgoing_importance[selectedSentenceIdx] ?? 0; + const incoming = result.incoming_importance[selectedSentenceIdx] ?? 0; + selectionDetails.innerHTML = `Sentence ${selectedSentenceIdx}
${currentSentences()[selectedSentenceIdx] || ""}`; + metrics.innerHTML = ` +
Outgoing impact${outgoing.toFixed(4)}
+
Incoming dependence${incoming.toFixed(4)}
+ `; + return; + } + + if (selectedCell) { + selectionDetails.innerHTML = `Edge ${selectedCell.col} → ${selectedCell.row}
Influence score ${selectedCell.value.toFixed(4)}`; + metrics.innerHTML = ` +
Source sentence${selectedCell.col}
+
Target sentence${selectedCell.row}
+ `; + return; + } + + selectionDetails.textContent = "Choose a sentence or heatmap cell."; +} + +async function fetchSession(sessionId) { + const response = await fetch(`/api/sessions/${sessionId}/result`); + if (!response.ok) { + throw new Error(`Failed to fetch session ${sessionId}`); + } + return response.json(); +} + +function updateFromPayload(payload) { + activePayload = payload; + activeSessionId = payload.session.id; + renderSession(payload.session); + renderHeatmap(payload.analysis); + renderTopEdges(payload.analysis); + renderSelection(); + setExportLinks(payload.session.id, payload.session.status === "completed"); + + const { status, error } = payload.session; + if (status === "queued") setStatus("Queued", "Waiting for a worker slot."); + if (status === "generating") setStatus("Generating", "The model is producing an answer and visible trace."); + if (status === "answer_ready") setStatus("Analysis pending", "Answer is ready. Attribution analysis is starting."); + if (status === "analyzing") setStatus("Analyzing", "Running forward and backward passes for sentence influence."); + if (status === "completed") setStatus("Completed", "Analysis finished."); + if (status === "failed") setStatus("Failed", error || "The session failed."); + + if (["completed", "failed"].includes(status) && pollHandle) { + window.clearInterval(pollHandle); + pollHandle = null; + } + + loadRecentSessions().catch(() => {}); +} + +async function startPolling(sessionId) { + activeSessionId = sessionId; + if (pollHandle) { + window.clearInterval(pollHandle); + } + const tick = async () => { + try { + const payload = await fetchSession(sessionId); + updateFromPayload(payload); + } catch (error) { + setStatus("Error", String(error)); + window.clearInterval(pollHandle); + pollHandle = null; + } + }; + await tick(); + pollHandle = window.setInterval(tick, 2500); +} + +form.addEventListener("submit", async (event) => { + event.preventDefault(); + if (!currentUser?.authenticated && 
currentUser?.auth_required) { + window.location.href = loginLink.href; + return; + } + submitButton.disabled = true; + clearSelection(); + sentenceList.innerHTML = ""; + heatmap.className = "heatmap placeholder-box"; + heatmap.textContent = "Analysis pending."; + topEdges.innerHTML = ""; + answerOutput.textContent = "Waiting for generated answer."; + sessionIdOutput.textContent = ""; + traceCount.textContent = ""; + setExportLinks("", false); + setStatus("Submitting", "Creating session."); + + const payload = { + question: document.querySelector("#question").value, + max_new_tokens: Number(document.querySelector("#max-new-tokens").value), + max_trace_tokens: Number(document.querySelector("#max-trace-tokens").value), + max_sentences: Number(document.querySelector("#max-sentences").value), + }; + + try { + const response = await fetch("/api/sessions", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify(payload), + }); + if (!response.ok) { + throw new Error(await response.text()); + } + const session = await response.json(); + await startPolling(session.id); + } catch (error) { + setStatus("Error", String(error)); + } finally { + submitButton.disabled = false; + } +}); + +async function initialize() { + setExportLinks("", false); + try { + const response = await fetch("/api/me"); + currentUser = await response.json(); + } catch (_error) { + currentUser = { + authenticated: false, + auth_required: true, + login_url: "/oauth/huggingface/login", + logout_url: "/oauth/huggingface/logout", + }; + } + renderAuth(); + await loadRecentSessions(); +} + +initialize(); diff --git a/app/frontend/index.html b/app/frontend/index.html new file mode 100644 index 0000000000000000000000000000000000000000..ac4c978cdedd7c8a0c4be401c81bcd1e4294a435 --- /dev/null +++ b/app/frontend/index.html @@ -0,0 +1,108 @@ + + + + + + Thought Anchors + + + +
+
+
+

Public Hugging Face Space

+

Checking sign-in status.

+
+ +
+ +
+

Online Chain-of-Thought Analysis

+

Reasoning traces with white-box sentence influence.

+

+ Submit a question, watch the answer appear first, then inspect the sentence-to-sentence + attribution matrix once analysis completes. Results can be exported as JSON or CSV. +

+
+ +
+
+ + +
+ + + +
+ +
+
+ +
+
Idle
+
Submit a question to create a session.
+
+ +
+
+
+

Your Sessions

+ Ephemeral instance history +
+
Sign in to view your jobs.
+
+ +
+
+

Answer

+ +
+

No answer yet.

+
+ +
+
+

Reasoning Trace

+ +
+
    +
    + +
    +
    +

    Sentence Influence Matrix

    + Rows: targets, columns: sources +
    +
    Analysis pending.
    +
    + + +
    +
    + + + + diff --git a/app/frontend/styles.css b/app/frontend/styles.css new file mode 100644 index 0000000000000000000000000000000000000000..c4f8f94e7dcf9e3d1cc2a6d5f34072b5c5a98db5 --- /dev/null +++ b/app/frontend/styles.css @@ -0,0 +1,335 @@ +:root { + --bg: #f3efe6; + --panel: rgba(255, 252, 247, 0.9); + --ink: #1e1c19; + --muted: #6f655c; + --accent: #0e5a8a; + --accent-soft: #d2ecff; + --line: rgba(30, 28, 25, 0.12); + --warm: #d76a34; + --shadow: 0 20px 60px rgba(34, 25, 16, 0.08); +} + +* { + box-sizing: border-box; +} + +body { + margin: 0; + font-family: "Iowan Old Style", "Palatino Linotype", serif; + color: var(--ink); + background: + radial-gradient(circle at top left, rgba(14, 90, 138, 0.12), transparent 32%), + radial-gradient(circle at top right, rgba(215, 106, 52, 0.18), transparent 26%), + linear-gradient(180deg, #f8f4ec 0%, var(--bg) 100%); +} + +.page { + max-width: 1440px; + margin: 0 auto; + padding: 40px 24px 48px; +} + +.hero { + max-width: 760px; + margin-bottom: 28px; +} + +.compact { + margin: 0; +} + +.auth-strip { + display: flex; + align-items: center; + justify-content: space-between; + gap: 16px; + margin-bottom: 18px; +} + +.auth-actions, +.export-actions { + display: flex; + flex-wrap: wrap; + gap: 10px; +} + +.eyebrow { + margin: 0 0 8px; + text-transform: uppercase; + letter-spacing: 0.12em; + font-size: 0.78rem; + color: var(--accent); +} + +h1, h2 { + margin: 0; + font-weight: 600; +} + +h1 { + font-size: clamp(2.4rem, 5vw, 4.4rem); + line-height: 0.95; +} + +.lede { + margin: 18px 0 0; + font-size: 1.08rem; + color: var(--muted); + max-width: 60ch; +} + +.panel { + background: var(--panel); + border: 1px solid var(--line); + border-radius: 24px; + padding: 20px; + box-shadow: var(--shadow); + backdrop-filter: blur(18px); +} + +.composer { + margin-bottom: 16px; +} + +.label, +.controls span { + display: block; + margin-bottom: 8px; + font-size: 0.92rem; + color: var(--muted); +} + +textarea, +input { + width: 100%; + border: 
1px solid rgba(30, 28, 25, 0.16); + border-radius: 14px; + padding: 14px 16px; + font: inherit; + background: rgba(255, 255, 255, 0.72); +} + +textarea { + resize: vertical; + min-height: 132px; +} + +.controls { + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 12px; + margin: 16px 0; +} + +button { + appearance: none; + border: none; + border-radius: 999px; + padding: 14px 20px; + font: inherit; + font-weight: 600; + color: white; + background: linear-gradient(135deg, var(--accent), #123954); + cursor: pointer; +} + +.pill-link { + display: inline-flex; + align-items: center; + justify-content: center; + border-radius: 999px; + padding: 12px 18px; + text-decoration: none; + font-weight: 600; + color: white; + background: linear-gradient(135deg, var(--accent), #123954); +} + +.pill-link.secondary { + color: var(--ink); + background: rgba(255, 255, 255, 0.78); + border: 1px solid var(--line); +} + +.pill-link.is-disabled { + pointer-events: none; + opacity: 0.45; +} + +button:disabled { + opacity: 0.5; + cursor: wait; +} + +.status-row { + display: flex; + align-items: center; + gap: 12px; + margin-bottom: 16px; +} + +.status-chip { + display: inline-flex; + align-items: center; + border-radius: 999px; + padding: 8px 14px; + background: var(--accent-soft); + color: var(--accent); + font-size: 0.9rem; +} + +.status-detail, +.muted, +.placeholder { + color: var(--muted); +} + +.grid { + display: grid; + grid-template-columns: 1.2fr 1fr; + gap: 16px; +} + +.jobs-panel { + grid-column: 1 / -1; +} + +.answer-panel, +.trace-panel, +.heatmap-panel, +.details-panel { + min-height: 260px; +} + +.answer-panel, +.trace-panel { + grid-column: span 1; +} + +.heatmap-panel, +.details-panel { + grid-column: span 1; +} + +.panel-heading { + display: flex; + justify-content: space-between; + gap: 12px; + align-items: baseline; + margin-bottom: 16px; +} + +.sentence-list { + list-style: none; + margin: 0; + padding: 0; + max-height: 420px; + overflow: auto; 
+} + +.sentence-item { + padding: 10px 12px; + border-radius: 14px; + border: 1px solid transparent; + margin-bottom: 8px; + background: rgba(255,255,255,0.45); + cursor: pointer; +} + +.sentence-item:hover, +.sentence-item.is-active { + border-color: rgba(14, 90, 138, 0.25); + background: rgba(210, 236, 255, 0.5); +} + +.sentence-item strong { + display: inline-block; + min-width: 2.5rem; + color: var(--accent); +} + +.heatmap { + display: grid; + gap: 4px; + overflow: auto; + min-height: 300px; +} + +.heatmap-grid { + display: grid; + gap: 4px; +} + +.heatmap-cell { + width: 32px; + height: 32px; + border-radius: 8px; + border: 1px solid rgba(255,255,255,0.4); + cursor: pointer; + position: relative; +} + +.heatmap-cell.is-selected { + outline: 2px solid var(--ink); +} + +.placeholder-box { + display: grid; + place-items: center; + border: 1px dashed var(--line); + border-radius: 18px; +} + +.metric-card, +.edge-card { + padding: 12px 14px; + border-radius: 16px; + background: rgba(255,255,255,0.45); + border: 1px solid var(--line); + margin-top: 10px; +} + +.recent-sessions { + display: grid; + gap: 10px; +} + +.session-card { + color: var(--ink); + border: 1px solid var(--line); + border-radius: 16px; + background: rgba(255, 255, 255, 0.48); + padding: 14px; + cursor: pointer; + text-align: left; +} + +.session-card.is-active { + border-color: rgba(14, 90, 138, 0.25); + background: rgba(210, 236, 255, 0.42); +} + +.session-card strong { + display: block; + margin-bottom: 4px; + color: var(--accent); +} + +.metric-card strong, +.edge-card strong { + display: block; + color: var(--accent); +} + +@media (max-width: 960px) { + .auth-strip, + .grid, + .controls { + grid-template-columns: 1fr; + } + + .auth-strip { + align-items: flex-start; + } +} diff --git a/app/generation/__init__.py b/app/generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e50e4a879de95703f25268152e15662957f5050b --- /dev/null +++ 
b/app/generation/__init__.py @@ -0,0 +1 @@ +"""Generation utilities for trace-producing model calls.""" diff --git a/app/generation/prompting.py b/app/generation/prompting.py new file mode 100644 index 0000000000000000000000000000000000000000..29e12964c333740fddad3b0ced214f01961a028e --- /dev/null +++ b/app/generation/prompting.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any + + +SYSTEM_PROMPT = ( + "You are a careful reasoning assistant. Respond with your full reasoning inside " + "... and then provide a concise final answer after the closing tag." +) + + +def build_messages(question: str) -> list[dict[str, str]]: + return [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": question}, + ] + + +def render_prompt(tokenizer: Any, question: str) -> str: + messages = build_messages(question) + if hasattr(tokenizer, "apply_chat_template"): + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + return f"{SYSTEM_PROMPT}\n\nQuestion: {question}\n" diff --git a/app/generation/service.py b/app/generation/service.py new file mode 100644 index 0000000000000000000000000000000000000000..923e48e7a1cb59e82348e6b28b8a61d9b2eda509 --- /dev/null +++ b/app/generation/service.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import re +from typing import Any + +import torch + +from app.analysis.sentence_split import normalize_trace_text +from app.core.schemas import GenerationMetadata, GenerationResult +from app.generation.prompting import render_prompt + +THINK_BLOCK_RE = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) +ANSWER_MARKER_RE = re.compile(r"(?:^|\n)(?:final answer|answer)\s*:\s*", re.IGNORECASE) + + +def _extract_trace_and_answer(text: str) -> tuple[str, str]: + match = THINK_BLOCK_RE.search(text) + if match: + raw_trace = match.group(0) + answer = text[match.end() :].strip() + if not answer: + answer = match.group(1).strip() + return raw_trace, answer 
+ + raw_trace = text.strip() + answer_match = ANSWER_MARKER_RE.search(text) + if answer_match: + answer = text[answer_match.end() :].strip() + else: + paragraphs = [part.strip() for part in text.split("\n\n") if part.strip()] + answer = paragraphs[-1] if paragraphs else raw_trace + return raw_trace, answer + + +def generate_answer_and_trace( + *, + question: str, + model_name: str, + model: Any, + tokenizer: Any, + max_new_tokens: int = 512, + temperature: float = 0.6, + top_p: float = 0.95, +) -> GenerationResult: + prompt_text = render_prompt(tokenizer, question) + encoded = tokenizer(prompt_text, return_tensors="pt") + model_device = next(model.parameters()).device + encoded = {key: value.to(model_device) for key, value in encoded.items()} + input_length = int(encoded["input_ids"].shape[-1]) + do_sample = temperature > 0.0 + + generation_kwargs: dict[str, Any] = { + "max_new_tokens": max_new_tokens, + "do_sample": do_sample, + "top_p": top_p, + "pad_token_id": tokenizer.pad_token_id, + "eos_token_id": tokenizer.eos_token_id, + } + if do_sample: + generation_kwargs["temperature"] = temperature + + with torch.no_grad(): + output_ids = model.generate(**encoded, **generation_kwargs) + + generated_ids = output_ids[0, input_length:] + generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False) + raw_trace_text, answer = _extract_trace_and_answer(generated_text) + normalized_trace_text = normalize_trace_text(raw_trace_text) + + return GenerationResult( + question=question, + model_name=model_name, + answer=answer, + raw_generation_text=generated_text, + raw_trace_text=raw_trace_text, + normalized_trace_text=normalized_trace_text, + generation_metadata=GenerationMetadata( + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + do_sample=do_sample, + ), + ) diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df6f31315514e5aac192dc34aaeb4e48a44f8c4d --- 
/dev/null +++ b/app/services/__init__.py @@ -0,0 +1 @@ +"""Application services for session orchestration.""" diff --git a/app/services/sessions.py b/app/services/sessions.py new file mode 100644 index 0000000000000000000000000000000000000000..1e06155db33880dbdb5a568016f5695c95949679 --- /dev/null +++ b/app/services/sessions.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from app.analysis.sentence_split import split_sentences +from app.core.config import Settings +from app.core.runtime import load_model_bundle +from app.core.runtime_pipeline import analyze_generation_result +from app.core.schemas import AnalysisRequest, AnalysisResult, GenerationMetadata, GenerationResult +from app.generation.service import generate_answer_and_trace +from app.storage.repository import SessionRecord, SessionRepository +from app.workers.jobs import JobRunner + + +class SessionLimitError(RuntimeError): + pass + + +class SessionAccessError(PermissionError): + pass + + +@dataclass(slots=True) +class SessionService: + settings: Settings + repository: SessionRepository + jobs: JobRunner + + def create_session(self, request: AnalysisRequest, *, owner_id: str, owner_name: str | None) -> SessionRecord: + if self.repository.count_incomplete_sessions() >= self.settings.max_queued_jobs: + raise SessionLimitError("The service queue is full. 
Try again after a few minutes.") + if self.repository.count_incomplete_sessions_for_owner(owner_id) >= self.settings.max_active_jobs_per_user: + raise SessionLimitError("You already have the maximum number of active analysis jobs.") + model_name = request.model_name or self.settings.model_name + session = self.repository.create_session( + question=request.question, + model_name=model_name, + owner_id=owner_id, + owner_name=owner_name, + ) + self.jobs.submit(self._run_session_pipeline, session.id, request) + return session + + def start_analysis( + self, + session_id: str, + *, + owner_id: str, + request: AnalysisRequest | None = None, + ) -> SessionRecord: + session = self.repository.get_session(session_id) + self._assert_owner(session, owner_id) + effective_request = request or AnalysisRequest(question=session.question, model_name=session.model_name) + self.jobs.submit(self._run_analysis_only, session_id, effective_request) + return session + + def get_session_payload(self, session_id: str, *, owner_id: str) -> dict: + session = self.repository.get_session(session_id) + self._assert_owner(session, owner_id) + return self.repository.list_session_payload(session_id) + + def list_sessions(self, owner_id: str, *, limit: int = 20) -> list[dict]: + return self.repository.list_sessions_for_owner(owner_id, limit=limit) + + def get_analysis_result(self, session_id: str, *, owner_id: str) -> AnalysisResult: + payload = self.get_session_payload(session_id, owner_id=owner_id) + analysis = payload.get("analysis") + if analysis is None: + raise KeyError(session_id) + return AnalysisResult.model_validate(analysis) + + @staticmethod + def _assert_owner(session: SessionRecord, owner_id: str) -> None: + if session.owner_id != owner_id: + raise SessionAccessError("Session belongs to a different user.") + + def _run_session_pipeline(self, session_id: str, request: AnalysisRequest) -> None: + try: + self.repository.update_status(session_id, status="generating") + bundle = 
load_model_bundle( + request.model_name or self.settings.model_name, + device_preference=request.device_preference or self.settings.device_preference, + dtype_preference=request.dtype_preference or self.settings.dtype_preference, + attn_implementation=request.attn_implementation or self.settings.attn_implementation, + trust_remote_code=( + self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code + ), + low_cpu_mem_usage=( + self.settings.low_cpu_mem_usage + if request.low_cpu_mem_usage is None + else request.low_cpu_mem_usage + ), + ) + generation = generate_answer_and_trace( + question=request.question, + model_name=bundle.model_name, + model=bundle.model, + tokenizer=bundle.tokenizer, + max_new_tokens=request.max_new_tokens, + temperature=request.temperature, + top_p=request.top_p, + ) + sentences = [span.text for span in split_sentences(generation.normalized_trace_text)] + self.repository.save_generation_result(session_id, generation, sentences) + self.repository.update_status(session_id, status="answer_ready") + self._run_analysis_only(session_id, request, generation=generation) + except Exception as exc: + self.repository.update_status(session_id, status="failed", error=str(exc)) + + def _run_analysis_only( + self, + session_id: str, + request: AnalysisRequest, + *, + generation=None, + ) -> None: + try: + self.repository.update_status(session_id, status="analyzing") + if generation is None: + payload = self.repository.list_session_payload(session_id) + if payload.get("generation_metadata") is not None: + generation = GenerationResult( + question=payload["question"], + model_name=payload["model_name"], + answer=payload["answer"], + raw_generation_text=payload.get("raw_generation_text", ""), + raw_trace_text=payload["raw_trace_text"], + normalized_trace_text=payload["normalized_trace_text"], + generation_metadata=GenerationMetadata.model_validate(payload["generation_metadata"]), + ) + result = 
analyze_generation_result( + question=request.question, + generation=generation, + model_name=request.model_name or self.settings.model_name, + take_log=request.take_log, + max_sentences=request.max_sentences, + max_trace_tokens=request.max_trace_tokens, + validate_top_k=request.validate_top_k, + device_preference=request.device_preference or self.settings.device_preference, + dtype_preference=request.dtype_preference or self.settings.dtype_preference, + attn_implementation=request.attn_implementation or self.settings.attn_implementation, + trust_remote_code=( + self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code + ), + low_cpu_mem_usage=( + self.settings.low_cpu_mem_usage + if request.low_cpu_mem_usage is None + else request.low_cpu_mem_usage + ), + max_new_tokens=request.max_new_tokens, + temperature=request.temperature, + top_p=request.top_p, + ) + self.repository.save_analysis_result(session_id, result) + self.repository.update_status(session_id, status="completed") + except Exception as exc: + self.repository.update_status(session_id, status="failed", error=str(exc)) diff --git a/app/storage/__init__.py b/app/storage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc5cba5510dc993738cdcd0e74c464a43755fbb --- /dev/null +++ b/app/storage/__init__.py @@ -0,0 +1 @@ +"""Persistence layer for sessions and analysis results.""" diff --git a/app/storage/db.py b/app/storage/db.py new file mode 100644 index 0000000000000000000000000000000000000000..312ca13d854196e94d0d72f6ac30dc3868111dea --- /dev/null +++ b/app/storage/db.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import sqlite3 +from pathlib import Path + + +def connect(database_path: str) -> sqlite3.Connection: + path = Path(database_path) + path.parent.mkdir(parents=True, exist_ok=True) + connection = sqlite3.connect(path, check_same_thread=False) + connection.row_factory = sqlite3.Row + connection.execute("PRAGMA 
journal_mode=WAL") + connection.execute("PRAGMA foreign_keys=ON") + return connection + + +def initialize_schema(connection: sqlite3.Connection) -> None: + connection.executescript( + """ + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + status TEXT NOT NULL, + question TEXT NOT NULL, + model_name TEXT NOT NULL, + owner_id TEXT NOT NULL, + owner_name TEXT, + error TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS generation_results ( + session_id TEXT PRIMARY KEY REFERENCES sessions(id) ON DELETE CASCADE, + answer TEXT NOT NULL, + raw_generation_text TEXT NOT NULL, + raw_trace_text TEXT NOT NULL, + normalized_trace_text TEXT NOT NULL, + sentences_json TEXT NOT NULL, + generation_metadata_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS analysis_results ( + session_id TEXT PRIMARY KEY REFERENCES sessions(id) ON DELETE CASCADE, + result_json TEXT NOT NULL + ); + """ + ) + columns = { + row["name"] for row in connection.execute("PRAGMA table_info(generation_results)").fetchall() + } + if "raw_generation_text" not in columns: + connection.execute( + "ALTER TABLE generation_results ADD COLUMN raw_generation_text TEXT NOT NULL DEFAULT ''" + ) + session_columns = {row["name"] for row in connection.execute("PRAGMA table_info(sessions)").fetchall()} + if "owner_id" not in session_columns: + connection.execute("ALTER TABLE sessions ADD COLUMN owner_id TEXT NOT NULL DEFAULT 'legacy-user'") + if "owner_name" not in session_columns: + connection.execute("ALTER TABLE sessions ADD COLUMN owner_name TEXT") + connection.commit() diff --git a/app/storage/repository.py b/app/storage/repository.py new file mode 100644 index 0000000000000000000000000000000000000000..53e20d809e4a937fef38f0d3c50257b180069cf0 --- /dev/null +++ b/app/storage/repository.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import json +import threading +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime 
+from typing import Any + +from app.core.schemas import AnalysisResult, GenerationResult +from app.storage.db import connect, initialize_schema + + +def _utc_now() -> str: + return datetime.now(UTC).isoformat() + + +@dataclass(slots=True) +class SessionRecord: + id: str + status: str + question: str + model_name: str + owner_id: str + owner_name: str | None + error: str | None + created_at: str + updated_at: str + + +class SessionRepository: + def __init__(self, database_path: str) -> None: + self.connection = connect(database_path) + initialize_schema(self.connection) + self.lock = threading.Lock() + + def create_session(self, *, question: str, model_name: str, owner_id: str, owner_name: str | None) -> SessionRecord: + session_id = str(uuid.uuid4()) + now = _utc_now() + with self.lock: + self.connection.execute( + """ + INSERT INTO sessions (id, status, question, model_name, owner_id, owner_name, error, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (session_id, "queued", question, model_name, owner_id, owner_name, None, now, now), + ) + self.connection.commit() + return self.get_session(session_id) + + def get_session(self, session_id: str) -> SessionRecord: + row = self.connection.execute( + "SELECT * FROM sessions WHERE id = ?", + (session_id,), + ).fetchone() + if row is None: + raise KeyError(session_id) + return SessionRecord(**dict(row)) + + def list_session_payload(self, session_id: str) -> dict[str, Any]: + session = self.get_session(session_id) + payload: dict[str, Any] = { + "id": session.id, + "status": session.status, + "question": session.question, + "model_name": session.model_name, + "owner_id": session.owner_id, + "owner_name": session.owner_name, + "error": session.error, + "created_at": session.created_at, + "updated_at": session.updated_at, + } + generation_row = self.connection.execute( + "SELECT * FROM generation_results WHERE session_id = ?", + (session_id,), + ).fetchone() + if generation_row is not None: + 
payload["answer"] = generation_row["answer"] + payload["raw_generation_text"] = generation_row["raw_generation_text"] + payload["raw_trace_text"] = generation_row["raw_trace_text"] + payload["normalized_trace_text"] = generation_row["normalized_trace_text"] + payload["sentences"] = json.loads(generation_row["sentences_json"]) + payload["generation_metadata"] = json.loads(generation_row["generation_metadata_json"]) + analysis_row = self.connection.execute( + "SELECT result_json FROM analysis_results WHERE session_id = ?", + (session_id,), + ).fetchone() + if analysis_row is not None: + payload["analysis"] = json.loads(analysis_row["result_json"]) + return payload + + def list_sessions_for_owner(self, owner_id: str, *, limit: int = 20) -> list[dict[str, Any]]: + rows = self.connection.execute( + """ + SELECT id + FROM sessions + WHERE owner_id = ? + ORDER BY updated_at DESC + LIMIT ? + """, + (owner_id, limit), + ).fetchall() + return [self.list_session_payload(row["id"]) for row in rows] + + def update_status(self, session_id: str, *, status: str, error: str | None = None) -> None: + now = _utc_now() + with self.lock: + self.connection.execute( + "UPDATE sessions SET status = ?, error = ?, updated_at = ? WHERE id = ?", + (status, error, now, session_id), + ) + self.connection.commit() + + def count_incomplete_sessions(self) -> int: + row = self.connection.execute( + "SELECT COUNT(*) AS count FROM sessions WHERE status NOT IN ('completed', 'failed')" + ).fetchone() + return int(row["count"]) + + def count_incomplete_sessions_for_owner(self, owner_id: str) -> int: + row = self.connection.execute( + """ + SELECT COUNT(*) AS count + FROM sessions + WHERE owner_id = ? 
AND status NOT IN ('completed', 'failed') + """, + (owner_id,), + ).fetchone() + return int(row["count"]) + + def save_generation_result(self, session_id: str, generation: GenerationResult, sentences: list[str]) -> None: + with self.lock: + self.connection.execute( + """ + INSERT INTO generation_results ( + session_id, + answer, + raw_generation_text, + raw_trace_text, + normalized_trace_text, + sentences_json, + generation_metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(session_id) DO UPDATE SET + answer = excluded.answer, + raw_generation_text = excluded.raw_generation_text, + raw_trace_text = excluded.raw_trace_text, + normalized_trace_text = excluded.normalized_trace_text, + sentences_json = excluded.sentences_json, + generation_metadata_json = excluded.generation_metadata_json + """, + ( + session_id, + generation.answer, + generation.raw_generation_text, + generation.raw_trace_text, + generation.normalized_trace_text, + json.dumps(sentences), + json.dumps(generation.generation_metadata.model_dump()), + ), + ) + self.connection.execute( + "UPDATE sessions SET updated_at = ? WHERE id = ?", + (_utc_now(), session_id), + ) + self.connection.commit() + + def save_analysis_result(self, session_id: str, result: AnalysisResult) -> None: + with self.lock: + self.connection.execute( + """ + INSERT INTO analysis_results (session_id, result_json) + VALUES (?, ?) + ON CONFLICT(session_id) DO UPDATE SET result_json = excluded.result_json + """, + (session_id, result.model_dump_json()), + ) + self.connection.execute( + "UPDATE sessions SET updated_at = ? 
WHERE id = ?", + (_utc_now(), session_id), + ) + self.connection.commit() diff --git a/app/workers/__init__.py b/app/workers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c667ea692b3eb86f2af714fd0215f23805d3fd50 --- /dev/null +++ b/app/workers/__init__.py @@ -0,0 +1 @@ +"""Background job runner.""" diff --git a/app/workers/jobs.py b/app/workers/jobs.py new file mode 100644 index 0000000000000000000000000000000000000000..df0fb3cc64026cedcee7481b0af6bd116e0cf24a --- /dev/null +++ b/app/workers/jobs.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Callable + + +@dataclass(slots=True) +class JobRunner: + executor: ThreadPoolExecutor + + def submit(self, fn: Callable[..., None], *args, **kwargs) -> Future: + return self.executor.submit(fn, *args, **kwargs) + + def shutdown(self) -> None: + self.executor.shutdown(wait=False, cancel_futures=True) + + +def build_job_runner(max_workers: int) -> JobRunner: + return JobRunner(executor=ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="cot-anc")) diff --git a/cot_anc.egg-info/PKG-INFO b/cot_anc.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..0495e7caeedf7cd9f6d462fa0628ccc46b0d2ce8 --- /dev/null +++ b/cot_anc.egg-info/PKG-INFO @@ -0,0 +1,147 @@ +Metadata-Version: 2.4 +Name: cot-anc +Version: 0.1.0 +Summary: Online chain-of-thought analysis with attribution patching +Requires-Python: >=3.11 +Description-Content-Type: text/markdown +Requires-Dist: fastapi>=0.115.0 +Requires-Dist: huggingface_hub[oauth]>=0.33.0 +Requires-Dist: numpy>=2.0.0 +Requires-Dist: pydantic>=2.7.0 +Requires-Dist: scipy>=1.13.0 +Requires-Dist: torch>=2.2.0 +Requires-Dist: transformers>=4.44.0 +Requires-Dist: typer>=0.12.3 +Requires-Dist: uvicorn>=0.30.0 +Provides-Extra: dev +Requires-Dist: pytest>=8.2.0; extra == "dev" +Provides-Extra: viz 
+Requires-Dist: matplotlib>=3.8.0; extra == "viz" + +--- +title: Thought Anchors +emoji: 🧠 +colorFrom: blue +colorTo: orange +sdk: docker +app_port: 7860 +hf_oauth: true +pinned: false +models: + - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B +--- + +# Thought Anchors + +Public-facing FastAPI service for generating a visible reasoning trace and computing +sentence-to-sentence attribution on open-weight reasoning models. + +The app is now shaped for deployment as a Hugging Face `Docker Space`: + +- Hugging Face OAuth sign-in for end users +- browser UI plus programmatic API +- per-user ephemeral sessions on the running instance +- JSON and CSV export for completed analyses +- adaptive device and dtype loading for CPU, MPS, CUDA, Colab, Kaggle, and cloud GPUs + +## What It Can Do + +- generate an answer plus visible `...` trace from a supported causal LM +- normalize and split the trace into sentences +- compute a sentence influence matrix with gradient x attention attribution +- summarize incoming / outgoing importance and top edges +- expose the workflow through: + - CLI + - FastAPI + - web UI + - async session queue + +## Current Deployment Target + +The primary deployment target is a Hugging Face `Docker Space` running on upgraded GPU +hardware. The same app can also be run locally or on other cloud GPU hosts. 
+ +Important runtime constraints: + +- attribution requires `attn_implementation="eager"` +- the model must expose usable attention tensors and a supported decoder-layer layout +- long traces are intentionally capped because the analysis path uses a full backward pass + +## Local Development + +Install dependencies: + +```bash +uv sync +``` + +Run the API: + +```bash +uv run python -m app.cli.run_api +``` + +Run the CLI: + +```bash +uv run python -m app.cli.run_prototype "Explain why the derivative of x^2 is 2x" +``` + +## API + +Main endpoints: + +- `GET /healthz` +- `GET /api/me` +- `POST /api/warmup` +- `POST /api/analyze` +- `GET /api/sessions` +- `POST /api/sessions` +- `GET /api/sessions/{id}` +- `GET /api/sessions/{id}/result` +- `GET /api/sessions/{id}/export.json` +- `GET /api/sessions/{id}/export.csv` + +Example: + +```bash +curl -X POST http://localhost:7860/api/analyze \ + -H 'content-type: application/json' \ + -d '{ + "question": "Explain why the derivative of x^2 is 2x", + "max_new_tokens": 128, + "max_trace_tokens": 256, + "max_sentences": 16, + "validate_top_k": 0 + }' +``` + +## Hugging Face Space Setup + +Recommended environment variables: + +- `MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` +- `DEVICE_PREFERENCE=auto` +- `DTYPE_PREFERENCE=auto` +- `ATTN_IMPLEMENTATION=eager` +- `LOW_CPU_MEM_USAGE=true` +- `TRUST_REMOTE_CODE=true` +- `PRELOAD_MODEL=true` +- `MAX_TRACE_TOKENS=256` +- `MAX_SENTENCES=16` +- `JOB_WORKERS=1` +- `MAX_QUEUED_JOBS=8` +- `MAX_ACTIVE_JOBS_PER_USER=2` +- `REQUIRE_AUTH=true` + +Notes: + +- local disk is ephemeral; users should export results they want to keep +- use upgraded GPU hardware for real attribution runs +- keep trace limits conservative for public traffic + +## Notebook + +Use [notebooks/hf_space_demo.ipynb](notebooks/hf_space_demo.ipynb) +for Colab or Kaggle smoke testing. 
It installs dependencies, warms the model, runs one +short attribution analysis, prints the top edges, and renders a simple heatmap. diff --git a/cot_anc.egg-info/SOURCES.txt b/cot_anc.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..df4de4b428753501cbfab92d1cd5500a55564e64 --- /dev/null +++ b/cot_anc.egg-info/SOURCES.txt @@ -0,0 +1,44 @@ +README.md +pyproject.toml +app/__init__.py +app/analysis/__init__.py +app/analysis/hooks.py +app/analysis/sentence_split.py +app/analysis/summaries.py +app/analysis/suppression.py +app/analysis/token_boundaries.py +app/analysis/validation.py +app/api/__init__.py +app/api/auth.py +app/api/main.py +app/cli/__init__.py +app/cli/run_api.py +app/cli/run_prototype.py +app/core/__init__.py +app/core/config.py +app/core/model_support.py +app/core/runtime.py +app/core/runtime_pipeline.py +app/core/schemas.py +app/generation/__init__.py +app/generation/prompting.py +app/generation/service.py +app/services/__init__.py +app/services/sessions.py +app/storage/__init__.py +app/storage/db.py +app/storage/repository.py +app/workers/__init__.py +app/workers/jobs.py +cot_anc.egg-info/PKG-INFO +cot_anc.egg-info/SOURCES.txt +cot_anc.egg-info/dependency_links.txt +cot_anc.egg-info/requires.txt +cot_anc.egg-info/top_level.txt +tests/test_api.py +tests/test_model_support.py +tests/test_runtime_pipeline.py +tests/test_sentence_split.py +tests/test_summaries.py +tests/test_suppression_shape.py +tests/test_token_boundaries.py \ No newline at end of file diff --git a/cot_anc.egg-info/dependency_links.txt b/cot_anc.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/cot_anc.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/cot_anc.egg-info/requires.txt b/cot_anc.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..9d7d47c9cfcdbea0a30307aae3d9c1fb1d6b8f49 --- 
/dev/null +++ b/cot_anc.egg-info/requires.txt @@ -0,0 +1,15 @@ +fastapi>=0.115.0 +huggingface_hub[oauth]>=0.33.0 +numpy>=2.0.0 +pydantic>=2.7.0 +scipy>=1.13.0 +torch>=2.2.0 +transformers>=4.44.0 +typer>=0.12.3 +uvicorn>=0.30.0 + +[dev] +pytest>=8.2.0 + +[viz] +matplotlib>=3.8.0 diff --git a/cot_anc.egg-info/top_level.txt b/cot_anc.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..b80f0bd60822d4fa4893de455958ef32f6c521bf --- /dev/null +++ b/cot_anc.egg-info/top_level.txt @@ -0,0 +1 @@ +app diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000000000000000000000000000000000000..387813ee7d50df1ada46c496aa861d0d92398283 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,83 @@ +# API And Product Behavior + +## Auth Model + +- If `REQUIRE_AUTH=true`, analysis endpoints require Hugging Face sign-in. +- Sessions are scoped to user identity. +- One user cannot fetch another user’s session or export. + +## Session Model + +Statuses: + +- `queued` +- `generating` +- `answer_ready` +- `analyzing` +- `completed` +- `failed` + +Sessions are ephemeral for current running instance. + +## Endpoints + +### `GET /healthz` + +Returns: + +- model name +- device preference +- dtype preference +- auth requirement +- queue limits +- CUDA / MPS availability + +### `GET /api/me` + +Returns current auth state plus login/logout URLs. + +### `POST /api/warmup` + +Loads model with current runtime policy and returns: + +- resolved device +- resolved dtype +- model attribution capability + +### `POST /api/analyze` + +Direct synchronous analysis call. Good for trusted programmatic use, not best path for UI. + +### `GET /api/sessions` + +List current user’s recent sessions on current instance. + +### `POST /api/sessions` + +Create async session job. + +Queue protections: + +- global queue cap +- per-user active-job cap + +### `GET /api/sessions/{id}` + +Return session summary. 
+ +### `GET /api/sessions/{id}/result` + +Return session summary + full analysis if ready. + +### Export + +- `GET /api/sessions/{id}/export.json` +- `GET /api/sessions/{id}/export.csv` + +JSON contains session + analysis payload. + +CSV contains top edges: + +- `source_sentence_idx` +- `target_sentence_idx` +- `score` diff --git a/docs/deploy-huggingface.md b/docs/deploy-huggingface.md new file mode 100644 index 0000000000000000000000000000000000000000..ad342b5483eb281849f96816dae6de6cc3ee606e --- /dev/null +++ b/docs/deploy-huggingface.md @@ -0,0 +1,69 @@ +# Hugging Face Deployment + +Primary target: Hugging Face `Docker Space` on upgraded GPU hardware. + +## What Gets Deployed + +- FastAPI backend +- static web frontend +- Hugging Face OAuth routes +- ephemeral SQLite-backed session queue + +## Required Space Settings + +- SDK: `Docker` +- Port: `7860` +- OAuth: enabled via README metadata +- Hardware: upgraded GPU recommended + +## Recommended Runtime Variables + +Core: + +- `MODEL_NAME=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` +- `DEVICE_PREFERENCE=auto` +- `DTYPE_PREFERENCE=auto` +- `ATTN_IMPLEMENTATION=eager` +- `LOW_CPU_MEM_USAGE=true` +- `TRUST_REMOTE_CODE=true` +- `PRELOAD_MODEL=true` + +Traffic limits: + +- `MAX_TRACE_TOKENS=256` +- `MAX_SENTENCES=16` +- `JOB_WORKERS=1` +- `MAX_QUEUED_JOBS=8` +- `MAX_ACTIVE_JOBS_PER_USER=2` +- `REQUIRE_AUTH=true` + +## Deploy Flow + +1. Create new Hugging Face Space with `Docker` SDK. +2. Push repo contents. +3. Set runtime variables in Space settings. +4. Upgrade hardware. +5. Wait for build. +6. Verify: + - `GET /healthz` + - sign-in works + - one short analysis completes + - JSON / CSV export works + +## Operational Notes + +- Local disk is ephemeral. Session history disappears on restart. +- OAuth helper is mocked locally but real inside Space. +- Keep public defaults conservative. Long traces can OOM small GPUs. +- If queue pressure grows, lower token caps before increasing worker count. 
+ +## Common Failure Modes + +- `attn_implementation` not eager: + - attribution disabled for model +- unsupported model layout: + - generation may work, attribution fails early with clear error +- OOM: + - reduce `MAX_TRACE_TOKENS`, `MAX_SENTENCES`, or choose larger GPU +- cold start slow: + - keep `PRELOAD_MODEL=true` diff --git a/docs/notebook.md b/docs/notebook.md new file mode 100644 index 0000000000000000000000000000000000000000..c9df71bf6c96f78f1b3f1c08fbaa593b652c7244 --- /dev/null +++ b/docs/notebook.md @@ -0,0 +1,38 @@ +# Notebook Usage + +Notebook path: + +- [notebooks/hf_space_demo.ipynb](../notebooks/hf_space_demo.ipynb) + +## Purpose + +Short smoke test for: + +- Colab GPU +- Kaggle GPU +- quick local validation + +## What Notebook Does + +1. Install runtime deps. +2. Set conservative env vars. +3. Import project pipeline. +4. Run one short attribution job. +5. Print answer, runtime metadata, top edges. +6. Render heatmap. + +## Best Use + +- validate model availability +- validate driver / torch / transformers stack +- sanity-check latency before public deploy + +## If Notebook Fails + +Check: + +- GPU available +- model access permissions +- enough VRAM +- eager attention enabled +- trace limits not too high diff --git a/docs/runtime.md b/docs/runtime.md new file mode 100644 index 0000000000000000000000000000000000000000..b55208ebc23888f8fd266d62e89869ed9cca9c00 --- /dev/null +++ b/docs/runtime.md @@ -0,0 +1,77 @@ +# Runtime And Model Support + +## Execution Model + +Pipeline: + +1. Generate visible reasoning trace. +2. Normalize and split trace into sentences. +3. Map sentence spans to token spans. +4. Run forward + backward pass. +5. Build sentence influence matrix from gradient x attention. +6. Summarize top edges and importance scores. 
+ +## Device And Dtype Policy + +Default policy: + +- CUDA: + - `bfloat16` if supported + - else `float16` +- MPS: + - `float16` +- CPU: + - `float32` + +Override with: + +- `DTYPE_PREFERENCE` +- request `dtype_preference` + +## Model Requirements + +Model must support all of: + +- causal LM generation +- `output_attentions=True` +- eager attention +- supported decoder layer layout +- supported attention module attribute + +Supported layer paths: + +- `model.layers` +- `model.model.layers` +- `transformer.h` +- `gpt_neox.layers` + +Supported attention attrs: + +- `self_attn` +- `attn` +- `attention` + +## Why Trace Limits Exist + +Attribution path uses full backward pass over attention tensors. Cost grows with: + +- sequence length +- layer count +- head count +- sentence count + +Public defaults stay small to protect uptime. + +## Good First Runtime Settings + +For public demo: + +- `max_new_tokens=128` +- `max_trace_tokens=256` +- `max_sentences=16` +- `validate_top_k=0` + +For deeper analysis on bigger GPU: + +- raise trace tokens slowly +- watch latency and memory first diff --git a/notebooks/hf_space_demo.ipynb b/notebooks/hf_space_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..104b532bf3b4df279509ef862509efe2066e8845 --- /dev/null +++ b/notebooks/hf_space_demo.ipynb @@ -0,0 +1,119 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Thought Anchors Demo\n", + "\n", + "Colab/Kaggle notebook for a short end-to-end attribution smoke test." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q uv\n", + "!uv pip install --system fastapi huggingface_hub matplotlib numpy pydantic scipy torch transformers typer uvicorn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "if not Path('app').exists():\n", + " raise RuntimeError('Upload or clone the repository before running the notebook.')\n", + "\n", + "os.environ.setdefault('MODEL_NAME', 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B')\n", + "os.environ.setdefault('DEVICE_PREFERENCE', 'auto')\n", + "os.environ.setdefault('DTYPE_PREFERENCE', 'auto')\n", + "os.environ.setdefault('ATTN_IMPLEMENTATION', 'eager')\n", + "os.environ.setdefault('LOW_CPU_MEM_USAGE', 'true')\n", + "os.environ.setdefault('TRUST_REMOTE_CODE', 'true')\n", + "os.environ.setdefault('MAX_TRACE_TOKENS', '256')\n", + "os.environ.setdefault('MAX_SENTENCES', '16')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from app.core.runtime_pipeline import compute_attribution_analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "question = 'Explain why the derivative of x^2 is 2x.'\n", + "\n", + "result = compute_attribution_analysis(\n", + " question=question,\n", + " max_new_tokens=128,\n", + " max_trace_tokens=256,\n", + " max_sentences=16,\n", + " validate_top_k=0,\n", + " temperature=0.0,\n", + " top_p=1.0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print('Answer:\\n', result.answer)\n", + "print('\\nSentences:', len(result.sentences))\n", + "print('Runtime:', result.runtime_metadata.model_dump())\n", + "print('\\nTop edges:')\n", + "for edge in result.top_edges[:10]:\n", + " print(edge)" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "matrix = np.array(result.suppression_matrix)\n", + "figure, axis = plt.subplots(figsize=(7, 5))\n", + "image = axis.imshow(matrix, aspect='auto', cmap='viridis')\n", + "axis.set_title('Sentence Influence Matrix')\n", + "axis.set_xlabel('Source sentence')\n", + "axis.set_ylabel('Target sentence')\n", + "figure.colorbar(image, ax=axis)\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..86da7f3af083f7169c8bb511a1a420f487f1d937 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "cot-anc" +version = "0.1.0" +description = "Online chain-of-thought analysis with attribution patching" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "fastapi>=0.115.0", + "huggingface_hub[oauth]>=0.33.0", + "numpy>=2.0.0", + "pydantic>=2.7.0", + "scipy>=1.13.0", + "torch>=2.2.0", + "transformers>=4.44.0", + "typer>=0.12.3", + "uvicorn>=0.30.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.2.0", +] +viz = [ + "matplotlib>=3.8.0", +] + +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +include = ["app*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["."] diff --git a/runpod_start.sh b/runpod_start.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0ee8e47a5ac27f90fbfcb1cd24f577f2a118c8a --- /dev/null +++ b/runpod_start.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +set -euo pipefail + +export API_HOST="${API_HOST:-0.0.0.0}" 
+export API_PORT="${API_PORT:-7860}" +export DEVICE_PREFERENCE="${DEVICE_PREFERENCE:-auto}" +export DTYPE_PREFERENCE="${DTYPE_PREFERENCE:-auto}" +export ATTN_IMPLEMENTATION="${ATTN_IMPLEMENTATION:-eager}" +export PRELOAD_MODEL="${PRELOAD_MODEL:-true}" + +exec uv run python -m app.cli.run_api diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f73d112e061885503c9912f69061645e398801 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import importlib.util +import re +from types import SimpleNamespace + +import pytest + +HAS_TORCH = importlib.util.find_spec("torch") is not None +if HAS_TORCH: + import torch +else: # pragma: no cover - environment dependent + torch = None + + +class MockTokenizer: + def __call__( + self, + text: str, + *, + add_special_tokens: bool = False, + return_offsets_mapping: bool = False, + return_tensors: str | None = None, + ): + if torch is None: # pragma: no cover - environment dependent + raise RuntimeError("MockTokenizer requires torch.") + matches = list(re.finditer(r"\S+", text)) + input_ids = [index + 1 for index, _match in enumerate(matches)] + result = { + "input_ids": torch.tensor([input_ids], dtype=torch.long) + if return_tensors == "pt" + else [input_ids] + } + if return_offsets_mapping: + offsets = [[match.start(), match.end()] for match in matches] + result["offset_mapping"] = ( + torch.tensor([offsets], dtype=torch.long) + if return_tensors == "pt" + else offsets + ) + return result + + +if torch is not None: + class FakeSelfAttention(torch.nn.Module): + def __init__(self, hidden_size: int, heads: int) -> None: + super().__init__() + self.heads = heads + self.scale = hidden_size ** -0.5 + + def forward(self, hidden_states, attention_mask=None, output_attentions=False, **_kwargs): + scores = torch.einsum("bqd,bkd->bqk", hidden_states, hidden_states) * self.scale + scores = scores.unsqueeze(1).repeat(1, self.heads, 1, 
1) + causal_mask = torch.triu( + torch.ones( + scores.shape[-2:], + dtype=torch.bool, + device=hidden_states.device, + ), + diagonal=1, + ) + scores = scores.masked_fill( + causal_mask.unsqueeze(0).unsqueeze(0), + torch.finfo(scores.dtype).min, + ) + if attention_mask is not None and attention_mask.dim() == 4: + scores = scores + attention_mask + attention = torch.softmax(scores, dim=-1) + context = torch.einsum("bhqk,bkd->bhqd", attention, hidden_states).mean(dim=1) + return hidden_states + context, attention + + + class FakeLayer(torch.nn.Module): + def __init__(self, hidden_size: int, heads: int) -> None: + super().__init__() + self.self_attn = FakeSelfAttention(hidden_size=hidden_size, heads=heads) + + + class FakeCausalLM(torch.nn.Module): + def __init__(self, vocab_size: int = 32, hidden_size: int = 8, num_layers: int = 2, heads: int = 2): + super().__init__() + self.config = SimpleNamespace(_attn_implementation="eager") + self.embed = torch.nn.Embedding(vocab_size, hidden_size) + self.model = SimpleNamespace( + layers=torch.nn.ModuleList( + [FakeLayer(hidden_size=hidden_size, heads=heads) for _ in range(num_layers)] + ) + ) + self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False) + + def forward(self, input_ids, attention_mask=None, output_attentions=False, return_dict=True, **_kwargs): + hidden_states = self.embed(input_ids) + attentions = [] + for layer in self.model.layers: + hidden_states, attention = layer.self_attn( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + attentions.append(attention) + logits = self.lm_head(hidden_states) + if return_dict: + return SimpleNamespace(logits=logits, attentions=tuple(attentions)) + return logits, tuple(attentions) + + +@pytest.fixture() +def mock_tokenizer() -> MockTokenizer: + if torch is None: # pragma: no cover - environment dependent + pytest.skip("torch not installed") + return MockTokenizer() diff --git a/tests/test_api.py b/tests/test_api.py new file 
mode 100644 index 0000000000000000000000000000000000000000..44af1161c76d6d393f2a4c876ec3741f4c11aa48 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from fastapi.testclient import TestClient +import pytest + +from app.api.main import app +from app.core.schemas import ( + AnalysisResult, + CurrentUserResponse, + ModelCapability, + RuntimeMetadata, + SessionResponse, + ValidationMetadata, +) + + +@pytest.fixture(autouse=True) +def auth_override(monkeypatch): + monkeypatch.setattr( + "app.api.main.require_user", + lambda _request, _settings: type( + "User", + (), + {"id": "user-123", "display_name": "Test User", "authenticated": True}, + )(), + ) + monkeypatch.setattr( + "app.api.main.get_optional_user", + lambda _request: type( + "User", + (), + { + "id": "user-123", + "username": "tester", + "display_name": "Test User", + "avatar_url": "https://example.com/avatar.png", + "authenticated": True, + }, + )(), + ) + + +def test_healthz_returns_runtime_flags() -> None: + with TestClient(app) as client: + response = client.get("/healthz") + + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "ok" + assert "cuda_available" in payload + assert "dtype_preference" in payload + + +def test_analyze_delegates_to_runtime(monkeypatch) -> None: + def fake_compute_attribution_analysis(**_kwargs): + return AnalysisResult( + question="Why?", + model_name="fake-model", + answer="Because.", + raw_trace_text="Alpha.", + normalized_trace_text="Alpha.", + sentences=["Alpha."], + sentence_token_ranges=[(0, 1)], + suppression_matrix=[[0.0]], + raw_suppression_matrix=[[0.0]], + outgoing_importance=[0.0], + incoming_importance=[0.0], + top_edges=[], + runtime_metadata=RuntimeMetadata( + device="cpu", + capability=ModelCapability(supports_attribution=True, layer_count=2, attention_impl="eager"), + ), + validation_metadata=ValidationMetadata(enabled=False, top_k=0), + ) + + 
monkeypatch.setattr("app.api.main.compute_attribution_analysis", fake_compute_attribution_analysis) + + with TestClient(app) as client: + response = client.post( + "/api/analyze", + json={ + "question": "Why?", + "max_new_tokens": 8, + "validate_top_k": 0, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["answer"] == "Because." + assert payload["model_name"] == "fake-model" + + +def test_me_reports_current_user() -> None: + with TestClient(app) as client: + response = client.get("/api/me") + + assert response.status_code == 200 + payload = CurrentUserResponse.model_validate(response.json()) + assert payload.authenticated is True + assert payload.username == "tester" + + +def test_root_serves_frontend() -> None: + with TestClient(app) as client: + response = client.get("/") + + assert response.status_code == 200 + assert "Thought Anchors" in response.text + + +def test_session_routes_use_service(monkeypatch) -> None: + class FakeSessionService: + def __init__(self) -> None: + self.payload = { + "id": "session-123", + "status": "completed", + "question": "Why?", + "model_name": "fake-model", + "error": None, + "created_at": "2026-04-06T00:00:00+00:00", + "updated_at": "2026-04-06T00:00:05+00:00", + "answer": "Because.", + "raw_trace_text": "Alpha.", + "normalized_trace_text": "Alpha.", + "sentences": ["Alpha."], + "generation_metadata": {"max_new_tokens": 8}, + "analysis": AnalysisResult( + question="Why?", + model_name="fake-model", + answer="Because.", + raw_trace_text="Alpha.", + normalized_trace_text="Alpha.", + sentences=["Alpha."], + sentence_token_ranges=[(0, 1)], + suppression_matrix=[[0.0]], + raw_suppression_matrix=[[0.0]], + outgoing_importance=[0.0], + incoming_importance=[0.0], + top_edges=[], + runtime_metadata=RuntimeMetadata( + device="cpu", + capability=ModelCapability( + supports_attribution=True, + layer_count=2, + attention_impl="eager", + ), + ), + validation_metadata=ValidationMetadata(enabled=False, 
top_k=0), + ).model_dump(), + } + + def create_session(self, _request, **_kwargs): + return SessionResponse.model_validate(self.payload) + + def get_session_payload(self, _session_id: str, **_kwargs): + return self.payload + + def start_analysis(self, _session_id: str, **_kwargs): + return SessionResponse.model_validate(self.payload) + + def list_sessions(self, _owner_id: str, **_kwargs): + return [self.payload] + + def get_analysis_result(self, _session_id: str, **_kwargs): + return AnalysisResult.model_validate(self.payload["analysis"]) + + monkeypatch.setattr("app.api.main.get_session_service", lambda: FakeSessionService()) + + with TestClient(app) as client: + listing = client.get("/api/sessions") + created = client.post("/api/sessions", json={"question": "Why?"}) + session = client.get("/api/sessions/session-123") + result = client.get("/api/sessions/session-123/result") + exported_json = client.get("/api/sessions/session-123/export.json") + exported_csv = client.get("/api/sessions/session-123/export.csv") + + assert listing.status_code == 200 + assert created.status_code == 200 + assert session.status_code == 200 + assert result.status_code == 200 + assert exported_json.status_code == 200 + assert exported_csv.status_code == 200 + assert created.json()["id"] == "session-123" + assert session.json()["answer"] == "Because." 
+ assert result.json()["analysis"]["model_name"] == "fake-model" diff --git a/tests/test_model_support.py b/tests/test_model_support.py new file mode 100644 index 0000000000000000000000000000000000000000..b08e85992d90a7f5eeba3457b5942e75c50411ce --- /dev/null +++ b/tests/test_model_support.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +pytest.importorskip("torch") + +from app.core.model_support import describe_model_support + + +def test_model_support_accepts_transformer_h_layout() -> None: + layer = SimpleNamespace(attn=object()) + model = SimpleNamespace( + transformer=SimpleNamespace(h=[layer, layer]), + config=SimpleNamespace(_attn_implementation="eager"), + ) + + support = describe_model_support(model) + + assert support.supports_attribution is True + assert support.layer_path == "transformer.h" + assert support.attention_attr == "attn" + + +def test_model_support_rejects_non_eager_attention() -> None: + layer = SimpleNamespace(self_attn=object()) + model = SimpleNamespace( + model=SimpleNamespace(layers=[layer]), + config=SimpleNamespace(_attn_implementation="sdpa"), + ) + + support = describe_model_support(model) + + assert support.supports_attribution is False + assert "eager" in (support.reason or "") diff --git a/tests/test_runtime_pipeline.py b/tests/test_runtime_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..c492ee6c9c246012af7af7b3f37f8c2f02c5d689 --- /dev/null +++ b/tests/test_runtime_pipeline.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +pytest.importorskip("torch") + +from app.core.schemas import GenerationMetadata, GenerationResult +from app.core.runtime_pipeline import compute_attribution_analysis + + +def test_runtime_pipeline_with_mocked_generation(monkeypatch, mock_tokenizer) -> None: + from tests.conftest import FakeCausalLM + + fake_model = FakeCausalLM() + fake_bundle = 
SimpleNamespace( + model=fake_model, + tokenizer=mock_tokenizer, + device=next(fake_model.parameters()).device, + dtype=next(fake_model.parameters()).dtype, + capability=SimpleNamespace( + supports_attribution=True, + reason=None, + layer_path="model.layers", + attention_attr="self_attn", + layer_count=len(fake_model.model.layers), + attention_impl="eager", + ), + ) + + def fake_load_model_bundle(_model_name: str, **_kwargs): + return fake_bundle + + def fake_generate_answer_and_trace(**_kwargs): + return GenerationResult( + question="Why?", + model_name="fake-model", + answer="Because.", + raw_generation_text="Alpha beta. Gamma delta. Epsilon zeta. Because.", + raw_trace_text="Alpha beta. Gamma delta. Epsilon zeta.", + normalized_trace_text="Alpha beta. Gamma delta. Epsilon zeta.", + generation_metadata=GenerationMetadata( + max_new_tokens=32, + temperature=0.0, + top_p=1.0, + do_sample=False, + ), + ) + + monkeypatch.setattr("app.core.runtime_pipeline.load_model_bundle", fake_load_model_bundle) + monkeypatch.setattr("app.core.runtime_pipeline.generate_answer_and_trace", fake_generate_answer_and_trace) + + result = compute_attribution_analysis( + question="Why?", + model_name="fake-model", + validate_top_k=0, + max_trace_tokens=32, + max_sentences=5, + take_log=False, + ) + + assert result.answer == "Because." + assert len(result.sentences) == 3 + assert len(result.suppression_matrix) == 3 + assert result.validation_metadata is not None + assert result.validation_metadata.enabled is False diff --git a/tests/test_sentence_split.py b/tests/test_sentence_split.py new file mode 100644 index 0000000000000000000000000000000000000000..f23e563bcff4c3b326b2a9344ade5c92ead76b2d --- /dev/null +++ b/tests/test_sentence_split.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from app.analysis.sentence_split import normalize_trace_text, split_sentences + + +def test_sentence_split_preserves_spacing_and_newlines() -> None: + raw = "First sentence. 
Second sentence?\n\nThird sentence!" + normalized = normalize_trace_text(raw) + spans = split_sentences(normalized) + + assert normalized == "First sentence. Second sentence?\n\nThird sentence!" + assert [span.text for span in spans] == [ + "First sentence. ", + "Second sentence?\n\n", + "Third sentence!", + ] + assert spans[1].start_char == len("First sentence. ") + diff --git a/tests/test_summaries.py b/tests/test_summaries.py new file mode 100644 index 0000000000000000000000000000000000000000..cf76bb61cb746138403e22111c885ee4b2049d9d --- /dev/null +++ b/tests/test_summaries.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import numpy as np + +from app.analysis.summaries import ( + compute_incoming_importance, + compute_outgoing_importance, + compute_top_edges, +) + + +def test_summary_metrics_follow_upper_triangle() -> None: + matrix = np.array( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [2.0, 3.0, 0.0], + ], + dtype=np.float32, + ) + + outgoing = compute_outgoing_importance(matrix) + incoming = compute_incoming_importance(matrix) + top_edges = compute_top_edges(matrix, top_k=2) + + assert outgoing == [1.5, 3.0, 0.0] + assert incoming == [0.0, 1.0, 2.5] + assert [(edge.source_sentence_idx, edge.target_sentence_idx) for edge in top_edges] == [ + (1, 2), + (0, 2), + ] + diff --git a/tests/test_suppression_shape.py b/tests/test_suppression_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..34b829bb053fd1e98872d20fb1821dc66df6d03b --- /dev/null +++ b/tests/test_suppression_shape.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +from app.analysis.hooks import get_stored_attentions +from app.analysis.suppression import compute_attribution_matrix + + +def test_attribution_matrix_shape_and_cleanup() -> None: + from tests.conftest import FakeCausalLM + + model = FakeCausalLM() + input_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + 
token_ranges = [(0, 2), (2, 4), (4, 5)] + + result = compute_attribution_matrix( + input_ids=input_ids, + token_ranges=token_ranges, + model=model, + take_log=False, + ) + + assert result.matrix.shape == (3, 3) + assert np.allclose(np.diag(result.raw_matrix), 0.0) + assert np.allclose(np.triu(result.raw_matrix), 0.0) + assert np.isfinite(result.matrix).all() + assert get_stored_attentions() == {} diff --git a/tests/test_token_boundaries.py b/tests/test_token_boundaries.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d4a4f227a084c4c3ebf5cb04bc6dc40d3e1427 --- /dev/null +++ b/tests/test_token_boundaries.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import pytest + +pytest.importorskip("torch") + +from app.analysis.sentence_split import split_sentences +from app.analysis.token_boundaries import tokenize_with_sentence_ranges + + +def test_token_boundaries_cover_full_sequence(mock_tokenizer) -> None: + text = "Alpha beta. Gamma delta." + spans = split_sentences(text) + mapping = tokenize_with_sentence_ranges(text, spans, mock_tokenizer) + + assert mapping.token_ranges == [(0, 2), (2, 4)] + assert mapping.input_ids.shape == (1, 4)