case0 / tests /test_runtime_gen.py
HusseinEid's picture
feat: multi-crime cases, scene+exhibit pixel art, background AI generation
80cd1f2 verified
raw
history blame
2.27 kB
"""Background-generation safety: a player interrupt aborts an in-flight generation call
between tokens and frees the single-flight lock immediately.
"""
from __future__ import annotations
import threading
import pytest
from case_zero.api.runtime import _SharedLockBackend
from case_zero.llm.backend import GenParams, LLMError
class _SlowStreamBackend:
"""Yields tokens one by one; records how many were consumed."""
def __init__(self, tokens: list[str]) -> None:
self.tokens = tokens
self.consumed = 0
def generate(self, prompt: str, params: GenParams) -> str:
return "".join(self.tokens)
def stream(self, prompt: str, params: GenParams):
for t in self.tokens:
self.consumed += 1
yield t
def test_uninterrupted_generate_joins_stream() -> None:
    """With an interrupt event that never fires, ``generate`` returns the
    concatenation of every streamed token and releases the lock."""
    inner = _SlowStreamBackend(["a", "b", "c"])
    lock = threading.Lock()
    wrapped = _SharedLockBackend(inner, lock, threading.Event())
    assert wrapped.generate("p", GenParams()) == "abc"
    # The test name promises the result was assembled from the stream.
    # Without this check, a wrapper that silently fell back to the plain
    # ``generate`` (which never bumps ``consumed``) would still pass.
    assert inner.consumed == len(inner.tokens), "interruptible path must drain the stream"
    assert not lock.locked(), "lock must be released after the call"
def test_interrupt_aborts_between_tokens_and_frees_lock() -> None:
    """An interrupt that fires mid-stream makes ``generate`` raise
    ``LLMError`` ("interrupted") within about one token and leaves the
    single-flight lock free."""
    lock = threading.Lock()
    interrupt = threading.Event()

    class _TripWire(_SlowStreamBackend):
        """Sets ``interrupt`` just before producing the third token."""

        def stream(self, prompt: str, params: GenParams):
            for i, t in enumerate(self.tokens):
                if i == 2:
                    interrupt.set()  # the player shows up mid-stream
                self.consumed += 1
                yield t

    tripwire = _TripWire(["a", "b", "c", "d", "e"])
    wrapped = _SharedLockBackend(tripwire, lock, interrupt)
    with pytest.raises(LLMError, match="interrupted"):
        wrapped.generate("p", GenParams())
    # The interrupt is set while the third token is being produced, so at
    # most three tokens should have been pulled before the abort.
    assert tripwire.consumed <= 3, "must abort within ~a token of the interrupt"
    assert not lock.locked(), "lock must be freed for the player immediately"
def test_no_interrupt_event_uses_plain_generate() -> None:
    """Without an interrupt event the wrapper uses the backend's plain
    ``generate`` (no streaming) and still releases the lock."""
    inner = _SlowStreamBackend(["x", "y"])
    lock = threading.Lock()
    wrapped = _SharedLockBackend(inner, lock, None)
    assert wrapped.generate("p", GenParams()) == "xy"
    assert inner.consumed == 0, "plain path must not stream"
    # Mirrors the sibling tests: a leak on the non-streaming path would
    # otherwise go unnoticed.
    assert not lock.locked(), "lock must be released after the call"