| from flashtrace import FlashTrace, TraceResult |
| from tests.helpers import make_tiny_qwen2_model_and_tokenizer |
|
|
|
|
def test_flashtrace_trace_returns_public_result():
    """Trace without an explicit method and check the public result fields.

    The result must be a ``TraceResult`` reporting the ``"flashtrace"``
    method, carry one score per prompt token, and echo back the output and
    reasoning spans that were requested.
    """
    model_kwargs = {"n_layers": 2, "d_model": 32, "n_heads": 4, "n_kv_heads": 2}
    model, tokenizer = make_tiny_qwen2_model_and_tokenizer(**model_kwargs)
    flash = FlashTrace(
        model,
        tokenizer,
        chunk_tokens=16,
        sink_chunk_tokens=4,
        recompute_attention=True,
    )
    requested_output_span = (1, 2)
    requested_reasoning_span = (0, 1)

    traced = flash.trace(
        prompt="t10 t20 t30 t40",
        target="t60 t70 t80",
        output_span=requested_output_span,
        reasoning_span=requested_reasoning_span,
        hops=1,
    )

    assert isinstance(traced, TraceResult)
    assert traced.method == "flashtrace"
    # The prompt must tokenize to something, and every token gets a score.
    assert len(traced.prompt_tokens) > 0
    assert len(traced.scores) == len(traced.prompt_tokens)
    # Spans are reported exactly as they were passed in.
    assert traced.output_span == requested_output_span
    assert traced.reasoning_span == requested_reasoning_span
|
|
|
|
def test_ifr_span_method_returns_public_result():
    """Selecting ``method="ifr-span"`` is reflected in the returned result.

    Also checks that the result carries one score per prompt token.
    """
    model_kwargs = {"n_layers": 2, "d_model": 32, "n_heads": 4, "n_kv_heads": 2}
    model, tokenizer = make_tiny_qwen2_model_and_tokenizer(**model_kwargs)
    flash = FlashTrace(
        model,
        tokenizer,
        chunk_tokens=16,
        sink_chunk_tokens=4,
        recompute_attention=True,
    )
    chosen_method = "ifr-span"

    traced = flash.trace(
        prompt="t10 t20 t30 t40",
        target="t60 t70",
        output_span=(0, 1),
        method=chosen_method,
    )

    assert traced.method == chosen_method
    assert len(traced.scores) == len(traced.prompt_tokens)
|
|
|
|
def test_flashtrace_default_raw_prompt_does_not_call_chat_template():
    """A default trace must not invoke ``tokenizer.apply_chat_template``.

    The tokenizer's ``apply_chat_template`` is replaced with a stub that
    raises, so any call to it during ``trace`` fails the test. The prompt
    tokens must come back as the three raw prompt tokens.
    """
    model_kwargs = {"n_layers": 2, "d_model": 32, "n_heads": 4, "n_kv_heads": 2}
    model, tokenizer = make_tiny_qwen2_model_and_tokenizer(**model_kwargs)

    def _reject_chat_template(*args, **kwargs):
        # Tripwire: reaching this means the chat template was applied.
        raise AssertionError("apply_chat_template should be opt-in")

    # Install the tripwire before the tracer is constructed.
    tokenizer.apply_chat_template = _reject_chat_template
    flash = FlashTrace(
        model,
        tokenizer,
        chunk_tokens=16,
        sink_chunk_tokens=4,
        recompute_attention=True,
    )

    traced = flash.trace(
        prompt="t3 t4 t5",
        target="t6 t7",
        output_span=(0, 1),
        method="ifr-span",
    )

    assert traced.method == "ifr-span"
    assert traced.prompt_tokens == ["t3", "t4", "t5"]
|
|
|
|
def test_flashtrace_target_without_eos_token():
    """Tracing works when the tokenizer has no EOS token configured.

    Both ``eos_token`` and ``eos_token_id`` are cleared before the tracer is
    built; the generation tokens must then be exactly the two target tokens.
    """
    model_kwargs = {"n_layers": 2, "d_model": 32, "n_heads": 4, "n_kv_heads": 2}
    model, tokenizer = make_tiny_qwen2_model_and_tokenizer(**model_kwargs)

    # Strip the EOS configuration from the tokenizer.
    tokenizer.eos_token = None
    tokenizer.eos_token_id = None

    flash = FlashTrace(
        model,
        tokenizer,
        chunk_tokens=16,
        sink_chunk_tokens=4,
        recompute_attention=True,
    )

    traced = flash.trace(
        prompt="t10 t20 t30 t40",
        target="t60 t70",
        output_span=(0, 1),
        method="ifr-span",
    )

    assert traced.method == "ifr-span"
    assert traced.generation_tokens == ["t60", "t70"]
|
|