cot-anc / tests /test_runtime_pipeline.py
BART-ender's picture
Deploy Thought Anchors
fda8fb3 verified
from __future__ import annotations
from types import SimpleNamespace
import pytest
pytest.importorskip("torch")
from app.core.schemas import GenerationMetadata, GenerationResult
from app.core.runtime_pipeline import compute_attribution_analysis
def test_runtime_pipeline_with_mocked_generation(monkeypatch, mock_tokenizer) -> None:
    """Exercise ``compute_attribution_analysis`` with loading and generation stubbed.

    ``load_model_bundle`` and ``generate_answer_and_trace`` are replaced on
    ``app.core.runtime_pipeline`` so the pipeline runs against a
    ``FakeCausalLM`` and a canned generation result. The canned trace holds
    three sentences; the test expects three sentence entries, a suppression
    matrix of length three, and validation metadata that is present but
    reports ``enabled`` as False (the call passes ``validate_top_k=0``).
    """
    from tests.conftest import FakeCausalLM

    stub_model = FakeCausalLM()

    # Capability description handed to the pipeline through the stub bundle.
    stub_capability = SimpleNamespace(
        supports_attribution=True,
        reason=None,
        layer_path="model.layers",
        attention_attr="self_attn",
        layer_count=len(stub_model.model.layers),
        attention_impl="eager",
    )

    # Device and dtype are taken from the fake model's first parameter.
    stub_bundle = SimpleNamespace(
        model=stub_model,
        tokenizer=mock_tokenizer,
        device=next(stub_model.parameters()).device,
        dtype=next(stub_model.parameters()).dtype,
        capability=stub_capability,
    )

    def _stub_load_model_bundle(_model_name: str, **_kwargs):
        # Always hand back the same pre-built bundle, whatever was requested.
        return stub_bundle

    def _stub_generate_answer_and_trace(**_kwargs):
        # A fresh canned result is built on every call.
        canned_metadata = GenerationMetadata(
            max_new_tokens=32,
            temperature=0.0,
            top_p=1.0,
            do_sample=False,
        )
        return GenerationResult(
            question="Why?",
            model_name="fake-model",
            answer="Because.",
            raw_generation_text="<think>Alpha beta. Gamma delta. Epsilon zeta.</think> Because.",
            raw_trace_text="<think>Alpha beta. Gamma delta. Epsilon zeta.</think>",
            normalized_trace_text="Alpha beta. Gamma delta. Epsilon zeta.",
            generation_metadata=canned_metadata,
        )

    monkeypatch.setattr(
        "app.core.runtime_pipeline.load_model_bundle",
        _stub_load_model_bundle,
    )
    monkeypatch.setattr(
        "app.core.runtime_pipeline.generate_answer_and_trace",
        _stub_generate_answer_and_trace,
    )

    analysis = compute_attribution_analysis(
        question="Why?",
        model_name="fake-model",
        validate_top_k=0,
        max_trace_tokens=32,
        max_sentences=5,
        take_log=False,
    )

    assert analysis.answer == "Because."
    assert len(analysis.sentences) == 3
    assert len(analysis.suppression_matrix) == 3
    assert analysis.validation_metadata is not None
    assert analysis.validation_metadata.enabled is False