Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from types import SimpleNamespace | |
| import pytest | |
| pytest.importorskip("torch") | |
| from app.core.schemas import GenerationMetadata, GenerationResult | |
| from app.core.runtime_pipeline import compute_attribution_analysis | |
def test_runtime_pipeline_with_mocked_generation(monkeypatch, mock_tokenizer) -> None:
    """Smoke-test ``compute_attribution_analysis`` with model loading and generation stubbed.

    ``load_model_bundle`` is replaced so the pipeline receives a ``FakeCausalLM``
    bundle, and ``generate_answer_and_trace`` is replaced so it receives a canned
    result whose reasoning trace contains three period-terminated sentences.
    The test then checks the answer, that three sentences and three suppression-matrix
    rows come back, and that validation is reported as disabled (``validate_top_k=0``).
    """
    from tests.conftest import FakeCausalLM

    model = FakeCausalLM()

    # Capability flags the pipeline reads from the bundle; layer count mirrors the fake model.
    capability = SimpleNamespace(
        supports_attribution=True,
        reason=None,
        layer_path="model.layers",
        attention_attr="self_attn",
        layer_count=len(model.model.layers),
        attention_impl="eager",
    )
    bundle = SimpleNamespace(
        model=model,
        tokenizer=mock_tokenizer,
        device=next(model.parameters()).device,
        dtype=next(model.parameters()).dtype,
        capability=capability,
    )

    def _stub_load_model_bundle(_model_name: str, **_kwargs):
        # Ignore the requested model name/options and hand back the prebuilt fake bundle.
        return bundle

    def _stub_generate_answer_and_trace(**_kwargs):
        # A fresh canned result per call; the trace holds three sentences.
        metadata = GenerationMetadata(
            max_new_tokens=32,
            temperature=0.0,
            top_p=1.0,
            do_sample=False,
        )
        return GenerationResult(
            question="Why?",
            model_name="fake-model",
            answer="Because.",
            raw_generation_text="<think>Alpha beta. Gamma delta. Epsilon zeta.</think> Because.",
            raw_trace_text="<think>Alpha beta. Gamma delta. Epsilon zeta.</think>",
            normalized_trace_text="Alpha beta. Gamma delta. Epsilon zeta.",
            generation_metadata=metadata,
        )

    # Patch the names as looked up inside runtime_pipeline (loader first, then generator).
    stubs = {
        "app.core.runtime_pipeline.load_model_bundle": _stub_load_model_bundle,
        "app.core.runtime_pipeline.generate_answer_and_trace": _stub_generate_answer_and_trace,
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)

    analysis_kwargs = {
        "question": "Why?",
        "model_name": "fake-model",
        "validate_top_k": 0,
        "max_trace_tokens": 32,
        "max_sentences": 5,
        "take_log": False,
    }
    result = compute_attribution_analysis(**analysis_kwargs)

    assert result.answer == "Because."
    assert len(result.sentences) == 3
    assert len(result.suppression_matrix) == 3

    validation = result.validation_metadata
    assert validation is not None
    assert validation.enabled is False