File size: 2,283 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import annotations

from types import SimpleNamespace

import pytest

pytest.importorskip("torch")

from app.core.schemas import GenerationMetadata, GenerationResult
from app.core.runtime_pipeline import compute_attribution_analysis


def test_runtime_pipeline_with_mocked_generation(monkeypatch, mock_tokenizer) -> None:
    """Run ``compute_attribution_analysis`` with model loading and generation stubbed.

    ``load_model_bundle`` and ``generate_answer_and_trace`` inside
    ``app.core.runtime_pipeline`` are replaced by local fakes. The pipeline
    therefore operates on a ``FakeCausalLM`` and a fixed three-sentence
    reasoning trace.

    The test then checks that:

    * the answer text comes back unchanged;
    * three sentences and three suppression-matrix rows are produced;
    * validation metadata is present and reports validation as disabled.
      The call passes ``validate_top_k=0``; presumably that is what disables
      validation.
    """
    from tests.conftest import FakeCausalLM

    model = FakeCausalLM()
    # Device and dtype for the bundle are taken from the model's first parameter.
    first_param = next(model.parameters())
    capability = SimpleNamespace(
        supports_attribution=True,
        reason=None,
        layer_path="model.layers",
        attention_attr="self_attn",
        layer_count=len(model.model.layers),
        attention_impl="eager",
    )
    bundle = SimpleNamespace(
        model=model,
        tokenizer=mock_tokenizer,
        device=first_param.device,
        dtype=first_param.dtype,
        capability=capability,
    )

    def stub_load_model_bundle(_model_name: str, **_kwargs):
        # Ignore the requested model and always hand back the fake bundle.
        return bundle

    thinking = "Alpha beta. Gamma delta. Epsilon zeta."
    tagged_trace = f"<think>{thinking}</think>"

    def stub_generate_answer_and_trace(**_kwargs):
        # Build a new, fully deterministic result on every call.
        return GenerationResult(
            question="Why?",
            model_name="fake-model",
            answer="Because.",
            raw_generation_text=f"{tagged_trace} Because.",
            raw_trace_text=tagged_trace,
            normalized_trace_text=thinking,
            generation_metadata=GenerationMetadata(
                max_new_tokens=32,
                temperature=0.0,
                top_p=1.0,
                do_sample=False,
            ),
        )

    patches = (
        ("app.core.runtime_pipeline.load_model_bundle", stub_load_model_bundle),
        ("app.core.runtime_pipeline.generate_answer_and_trace", stub_generate_answer_and_trace),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    analysis = compute_attribution_analysis(
        question="Why?",
        model_name="fake-model",
        validate_top_k=0,
        max_trace_tokens=32,
        max_sentences=5,
        take_log=False,
    )

    assert analysis.answer == "Because."
    # The stubbed trace contains three sentences, so three of each are expected.
    assert len(analysis.sentences) == 3
    assert len(analysis.suppression_matrix) == 3
    validation = analysis.validation_metadata
    assert validation is not None
    assert validation.enabled is False