# flashtrace/tests/test_tracer.py
# Uploaded by wenbopan: "Sync FlashTrace package from GitHub" (commit 55b60a8)
from flashtrace import FlashTrace, TraceResult
from tests.helpers import make_tiny_qwen2_model_and_tokenizer
def test_flashtrace_trace_returns_public_result():
    """Trace without an explicit ``method`` and check the returned result object.

    The call passes an output span, a reasoning span and ``hops=1``. The test
    expects a ``TraceResult`` whose ``method`` is ``"flashtrace"``, with one
    score per prompt token and both spans reported back unchanged.
    """
    tiny_model, tiny_tokenizer = make_tiny_qwen2_model_and_tokenizer(
        n_layers=2,
        d_model=32,
        n_heads=4,
        n_kv_heads=2,
    )
    flash = FlashTrace(
        tiny_model,
        tiny_tokenizer,
        chunk_tokens=16,
        sink_chunk_tokens=4,
        recompute_attention=True,
    )
    wanted_output_span = (1, 2)
    wanted_reasoning_span = (0, 1)

    traced = flash.trace(
        prompt="t10 t20 t30 t40",
        target="t60 t70 t80",
        output_span=wanted_output_span,
        reasoning_span=wanted_reasoning_span,
        hops=1,
    )

    # Type and method label of the returned object.
    assert isinstance(traced, TraceResult)
    assert traced.method == "flashtrace"

    # There must be prompt tokens, and exactly one score for each of them.
    assert len(traced.prompt_tokens) > 0
    assert len(traced.scores) == len(traced.prompt_tokens)

    # The spans handed to trace() come back unchanged on the result.
    assert traced.output_span == wanted_output_span
    assert traced.reasoning_span == wanted_reasoning_span
def test_ifr_span_method_returns_public_result():
    """Trace with ``method="ifr-span"`` and check the returned result object.

    Only an output span is supplied (no reasoning span, no hops). The test
    expects a ``TraceResult`` labelled ``"ifr-span"`` that carries one score
    per prompt token.
    """
    model, tokenizer = make_tiny_qwen2_model_and_tokenizer(
        n_layers=2, d_model=32, n_heads=4, n_kv_heads=2
    )
    tracer = FlashTrace(
        model,
        tokenizer,
        chunk_tokens=16,
        sink_chunk_tokens=4,
        recompute_attention=True,
    )

    result = tracer.trace(
        prompt="t10 t20 t30 t40",
        target="t60 t70",
        output_span=(0, 1),
        method="ifr-span",
    )

    # The test name promises a public result object, so check the type here
    # just as the default-method test does.
    assert isinstance(result, TraceResult)
    assert result.method == "ifr-span"
    assert len(result.scores) == len(result.prompt_tokens)
def test_flashtrace_default_raw_prompt_does_not_call_chat_template():
    """Check that a plain ``trace`` call never invokes ``apply_chat_template``.

    The tokenizer's ``apply_chat_template`` is replaced with a stub that
    raises ``AssertionError``, so any call to it during tracing fails the
    test. The prompt tokens on the result must be exactly the three
    whitespace-separated pieces of the prompt string.
    """

    def _forbidden_chat_template(*_args, **_kwargs):
        # Reaching this stub means the tracer used the chat template unasked.
        raise AssertionError("apply_chat_template should be opt-in")

    tiny_model, tiny_tokenizer = make_tiny_qwen2_model_and_tokenizer(
        n_layers=2,
        d_model=32,
        n_heads=4,
        n_kv_heads=2,
    )
    tiny_tokenizer.apply_chat_template = _forbidden_chat_template

    flash = FlashTrace(
        tiny_model,
        tiny_tokenizer,
        chunk_tokens=16,
        sink_chunk_tokens=4,
        recompute_attention=True,
    )
    traced = flash.trace(
        prompt="t3 t4 t5",
        target="t6 t7",
        output_span=(0, 1),
        method="ifr-span",
    )

    assert traced.method == "ifr-span"
    expected_prompt_tokens = ["t3", "t4", "t5"]
    assert traced.prompt_tokens == expected_prompt_tokens
def test_flashtrace_target_without_eos_token():
    """Trace with a tokenizer whose ``eos_token`` and ``eos_token_id`` are ``None``.

    Both EOS attributes are cleared before the tracer is built. The test then
    expects the result's ``generation_tokens`` to be exactly the two tokens of
    the target string.
    """
    tiny_model, tiny_tokenizer = make_tiny_qwen2_model_and_tokenizer(
        n_layers=2,
        d_model=32,
        n_heads=4,
        n_kv_heads=2,
    )
    # Clear the token first, then its id (same order as two plain assignments).
    for eos_attr in ("eos_token", "eos_token_id"):
        setattr(tiny_tokenizer, eos_attr, None)

    flash = FlashTrace(
        tiny_model,
        tiny_tokenizer,
        chunk_tokens=16,
        sink_chunk_tokens=4,
        recompute_attention=True,
    )
    traced = flash.trace(
        prompt="t10 t20 t30 t40",
        target="t60 t70",
        output_span=(0, 1),
        method="ifr-span",
    )

    assert traced.method == "ifr-span"
    expected_generation_tokens = ["t60", "t70"]
    assert traced.generation_tokens == expected_generation_tokens