File size: 2,606 Bytes
55b60a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | from flashtrace import FlashTrace, TraceResult
from tests.helpers import make_tiny_qwen2_model_and_tokenizer
def test_flashtrace_trace_returns_public_result():
model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
tracer = FlashTrace(model, tokenizer, chunk_tokens=16, sink_chunk_tokens=4, recompute_attention=True)
result = tracer.trace(
prompt="t10 t20 t30 t40",
target="t60 t70 t80",
output_span=(1, 2),
reasoning_span=(0, 1),
hops=1,
)
assert isinstance(result, TraceResult)
assert result.method == "flashtrace"
assert len(result.prompt_tokens) > 0
assert len(result.scores) == len(result.prompt_tokens)
assert result.output_span == (1, 2)
assert result.reasoning_span == (0, 1)
def test_ifr_span_method_returns_public_result():
model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
tracer = FlashTrace(model, tokenizer, chunk_tokens=16, sink_chunk_tokens=4, recompute_attention=True)
result = tracer.trace(
prompt="t10 t20 t30 t40",
target="t60 t70",
output_span=(0, 1),
method="ifr-span",
)
assert result.method == "ifr-span"
assert len(result.scores) == len(result.prompt_tokens)
def test_flashtrace_default_raw_prompt_does_not_call_chat_template():
model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
def fail_apply_chat_template(*args, **kwargs):
raise AssertionError("apply_chat_template should be opt-in")
tokenizer.apply_chat_template = fail_apply_chat_template
tracer = FlashTrace(model, tokenizer, chunk_tokens=16, sink_chunk_tokens=4, recompute_attention=True)
result = tracer.trace(
prompt="t3 t4 t5",
target="t6 t7",
output_span=(0, 1),
method="ifr-span",
)
assert result.method == "ifr-span"
assert result.prompt_tokens == ["t3", "t4", "t5"]
def test_flashtrace_target_without_eos_token():
model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
tokenizer.eos_token = None
tokenizer.eos_token_id = None
tracer = FlashTrace(model, tokenizer, chunk_tokens=16, sink_chunk_tokens=4, recompute_attention=True)
result = tracer.trace(
prompt="t10 t20 t30 t40",
target="t60 t70",
output_span=(0, 1),
method="ifr-span",
)
assert result.method == "ifr-span"
assert result.generation_tokens == ["t60", "t70"]
|