# cot-anc / app/core/runtime_pipeline.py
# BART-ender's picture
# Deploy Thought Anchors
# fda8fb3 verified
from __future__ import annotations
from typing import Any
import torch
from app.analysis.sentence_split import split_sentences
from app.analysis.summaries import (
compute_incoming_importance,
compute_outgoing_importance,
compute_top_edges,
)
from app.analysis.suppression import compute_attribution_matrix
from app.analysis.token_boundaries import tokenize_with_sentence_ranges, truncate_text_to_token_limit
from app.analysis.validation import validate_top_edges
from app.core.config import get_settings
from app.core.runtime import load_model_bundle
from app.core.schemas import AnalysisResult, GenerationResult
from app.generation.service import generate_answer_and_trace
def compute_attribution_analysis(
    *,
    question: str,
    model_name: str | None = None,
    take_log: bool | None = None,
    max_sentences: int | None = None,
    max_trace_tokens: int | None = None,
    validate_top_k: int = 0,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
    device_preference: str | None = None,
    dtype_preference: str | None = None,
    attn_implementation: str | None = None,
    trust_remote_code: bool | None = None,
    low_cpu_mem_usage: bool | None = None,
) -> AnalysisResult:
    """Generate a fresh answer/trace for *question* and analyze it.

    Convenience entry point around ``analyze_generation_result``. It never
    supplies a pre-computed generation, so the downstream pipeline samples a
    new one using *max_new_tokens*, *temperature* and *top_p*. Every other
    argument is forwarded unchanged; ``None`` values are resolved against the
    application settings by ``analyze_generation_result``.

    Returns:
        The ``AnalysisResult`` produced by ``analyze_generation_result``.
    """
    return analyze_generation_result(
        question=question,
        # Passing no generation makes the pipeline sample one itself.
        generation=None,
        model_name=model_name,
        # Sampling parameters for the generation step.
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        # Analysis limits and options.
        take_log=take_log,
        max_sentences=max_sentences,
        max_trace_tokens=max_trace_tokens,
        validate_top_k=validate_top_k,
        # Model-loading options.
        device_preference=device_preference,
        dtype_preference=dtype_preference,
        attn_implementation=attn_implementation,
        trust_remote_code=trust_remote_code,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
def analyze_generation_result(
    *,
    question: str,
    generation: GenerationResult | None = None,
    model_name: str | None = None,
    take_log: bool | None = None,
    max_sentences: int | None = None,
    max_trace_tokens: int | None = None,
    validate_top_k: int = 0,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
    device_preference: str | None = None,
    dtype_preference: str | None = None,
    attn_implementation: str | None = None,
    trust_remote_code: bool | None = None,
    low_cpu_mem_usage: bool | None = None,
    top_edges_k: int = 10,
) -> AnalysisResult:
    """Run sentence-level attribution analysis on a reasoning trace.

    The pipeline runs these steps in order:

    1. Resolve runtime options against ``get_settings()``.
    2. Load the model bundle.
    3. Generate an answer/trace if *generation* is ``None``.
    4. Truncate the normalized trace to the token and sentence limits.
    5. Compute the sentence attribution matrix.
    6. Summarize the matrix (outgoing/incoming importance, top edges).
    7. Validate the top edges.

    Args:
        question: The prompt. It is used for generation and echoed in the
            result.
        generation: A previously computed generation. When ``None``, a new
            one is produced with *max_new_tokens*, *temperature* and *top_p*.
            Those three arguments are unused otherwise.
        model_name: Model to load; ``None``/empty uses
            ``settings.model_name``.
        take_log: Forwarded to ``compute_attribution_matrix``; ``None`` uses
            ``settings.take_log``.
        max_sentences: Maximum number of trace sentences to analyze. ``None``
            uses ``settings.max_sentences``. A value ``<= 0`` disables the
            sentence cap applied in this function. The resolved value is
            also forwarded to ``compute_attribution_matrix``.
        max_trace_tokens: Token budget for the trace; ``None`` (or ``0``)
            uses ``settings.max_trace_tokens``.
        validate_top_k: Forwarded to ``validate_top_edges`` as ``top_k``.
        max_new_tokens: Generation budget, used only when *generation* is
            ``None``.
        temperature: Sampling temperature, used only when *generation* is
            ``None``.
        top_p: Nucleus-sampling threshold, used only when *generation* is
            ``None``.
        device_preference: Model-loading option; ``None``/empty falls back to
            the setting of the same name.
        dtype_preference: Model-loading option; ``None``/empty falls back to
            the setting of the same name.
        attn_implementation: Model-loading option; ``None``/empty falls back
            to the setting of the same name.
        trust_remote_code: Model-loading option; ``None`` falls back to the
            setting of the same name.
        low_cpu_mem_usage: Model-loading option; ``None`` falls back to the
            setting of the same name.
        top_edges_k: Number of edges requested from ``compute_top_edges``.
            Defaults to 10, the value previously hard-coded here.

    Returns:
        An ``AnalysisResult`` containing the analyzed (truncated) trace, its
        sentences and token ranges, both suppression matrices, the importance
        summaries, the top edges, runtime/validation metadata and the
        effective runtime options.

    Raises:
        RuntimeError: If the loaded model reports no attribution support, or
            if the truncated trace yields no sentences.
    """
    settings = get_settings()
    resolved_model_name = model_name or settings.model_name
    resolved_take_log = settings.take_log if take_log is None else take_log
    # An explicit 0 is meaningful here (it disables the sentence cap; see the
    # ``> 0`` guard below). A plain ``or`` would silently replace it with the
    # configured default, so only ``None`` falls back to the settings.
    resolved_max_sentences = (
        settings.max_sentences if max_sentences is None else max_sentences
    )
    resolved_max_trace_tokens = max_trace_tokens or settings.max_trace_tokens
    resolved_device = device_preference or settings.device_preference
    resolved_dtype = dtype_preference or settings.dtype_preference
    resolved_attn_implementation = attn_implementation or settings.attn_implementation
    resolved_trust_remote_code = (
        settings.trust_remote_code if trust_remote_code is None else trust_remote_code
    )
    resolved_low_cpu_mem_usage = (
        settings.low_cpu_mem_usage if low_cpu_mem_usage is None else low_cpu_mem_usage
    )

    bundle = load_model_bundle(
        resolved_model_name,
        device_preference=resolved_device,
        dtype_preference=resolved_dtype,
        attn_implementation=resolved_attn_implementation,
        trust_remote_code=resolved_trust_remote_code,
        low_cpu_mem_usage=resolved_low_cpu_mem_usage,
    )
    # Fail before any (potentially expensive) generation if the model cannot
    # be analyzed.
    if not bundle.capability.supports_attribution:
        reason = bundle.capability.reason or "Model does not support attribution analysis."
        raise RuntimeError(reason)

    if generation is None:
        generation = generate_answer_and_trace(
            question=question,
            model_name=resolved_model_name,
            model=bundle.model,
            tokenizer=bundle.tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

    truncated_text = truncate_text_to_token_limit(
        generation.normalized_trace_text,
        bundle.tokenizer,
        resolved_max_trace_tokens,
    )
    sentence_spans = split_sentences(truncated_text)
    if resolved_max_sentences > 0 and len(sentence_spans) > resolved_max_sentences:
        # Keep only the first N sentences. Cut the text at the end of the last
        # kept sentence and re-split, so the spans describe exactly the text
        # that is analyzed and reported.
        sentence_spans = sentence_spans[:resolved_max_sentences]
        truncated_text = truncated_text[: sentence_spans[-1].end_char]
        sentence_spans = split_sentences(truncated_text)
    if not sentence_spans:
        raise RuntimeError("Trace normalization produced no analyzable sentences.")

    mapping = tokenize_with_sentence_ranges(truncated_text, sentence_spans, bundle.tokenizer)
    input_ids = mapping.input_ids.to(bundle.device)
    computation = compute_attribution_matrix(
        input_ids=input_ids,
        token_ranges=mapping.token_ranges,
        model=bundle.model,
        take_log=resolved_take_log,
        max_trace_tokens=resolved_max_trace_tokens,
        max_sentences=resolved_max_sentences,
    )

    # Summaries are computed on the raw matrix; the (possibly transformed)
    # ``computation.matrix`` is only reported.
    outgoing = compute_outgoing_importance(computation.raw_matrix)
    incoming = compute_incoming_importance(computation.raw_matrix)
    top_edges = compute_top_edges(computation.raw_matrix, top_k=top_edges_k)
    validation = validate_top_edges(
        model=bundle.model,
        input_ids=input_ids,
        token_ranges=mapping.token_ranges,
        top_edges=top_edges,
        baseline_token_nll=computation.token_nll,
        top_k=validate_top_k,
    )

    return AnalysisResult(
        question=question,
        model_name=resolved_model_name,
        answer=generation.answer,
        raw_trace_text=generation.raw_trace_text,
        normalized_trace_text=truncated_text,
        sentences=[span.text for span in sentence_spans],
        sentence_token_ranges=mapping.token_ranges,
        suppression_matrix=computation.matrix.tolist(),
        raw_suppression_matrix=computation.raw_matrix.tolist(),
        outgoing_importance=outgoing,
        incoming_importance=incoming,
        top_edges=top_edges,
        runtime_metadata=computation.runtime_metadata,
        validation_metadata=validation,
        extra_metadata={
            "raw_generation_text": generation.raw_generation_text,
            "generation_metadata": generation.generation_metadata.model_dump(),
            # The options actually used after settings fallback, for
            # reproducibility.
            "effective_runtime": {
                "device_preference": resolved_device,
                "dtype_preference": resolved_dtype,
                "attn_implementation": resolved_attn_implementation,
                "trust_remote_code": resolved_trust_remote_code,
                "low_cpu_mem_usage": resolved_low_cpu_mem_usage,
            },
        },
    )