Spaces:
Running
Running
| """Background-generation safety: a player interrupt aborts an in-flight generation call | |
| between tokens and frees the single-flight lock immediately. | |
| """ | |
| from __future__ import annotations | |
| import threading | |
| import pytest | |
| from case_zero.api.runtime import _SharedLockBackend | |
| from case_zero.llm.backend import GenParams, LLMError | |
| class _SlowStreamBackend: | |
| """Yields tokens one by one; records how many were consumed.""" | |
| def __init__(self, tokens: list[str]) -> None: | |
| self.tokens = tokens | |
| self.consumed = 0 | |
| def generate(self, prompt: str, params: GenParams) -> str: | |
| return "".join(self.tokens) | |
| def stream(self, prompt: str, params: GenParams): | |
| for t in self.tokens: | |
| self.consumed += 1 | |
| yield t | |
def test_uninterrupted_generate_joins_stream() -> None:
    """With an interrupt event wired in but never set, generate() returns the
    concatenation of every streamed token and leaves the lock free."""
    inner = _SlowStreamBackend(["a", "b", "c"])
    lock = threading.Lock()
    wrapped = _SharedLockBackend(inner, lock, threading.Event())
    assert wrapped.generate("p", GenParams()) == "abc"
    # Only stream() bumps ``consumed``; checking it proves the interruptible
    # (streaming) path was taken rather than a silent fallback to generate().
    assert inner.consumed == len(inner.tokens), "must join the full stream"
    assert not lock.locked(), "lock must be released after the call"
def test_interrupt_aborts_between_tokens_and_frees_lock() -> None:
    """Setting the interrupt mid-stream makes generate() raise LLMError
    ("interrupted") within about one token and releases the lock at once."""
    lock = threading.Lock()
    interrupt = threading.Event()

    class _TripWire(_SlowStreamBackend):
        """Stream that sets ``interrupt`` just before producing token index 2."""

        def stream(self, prompt: str, params: GenParams):
            for i, t in enumerate(self.tokens):
                if i == 2:
                    interrupt.set()  # the player shows up mid-stream
                self.consumed += 1
                yield t

    tripwire = _TripWire(["a", "b", "c", "d", "e"])
    wrapped = _SharedLockBackend(tripwire, lock, interrupt)
    with pytest.raises(LLMError, match="interrupted"):
        wrapped.generate("p", GenParams())
    # The event is set while the 3rd token is being produced, so an abort
    # "between tokens" can have pulled at most 3 of the 5 tokens.
    assert tripwire.consumed <= 3, "must abort within ~a token of the interrupt"
    assert not lock.locked(), "lock must be freed for the player immediately"
def test_no_interrupt_event_uses_plain_generate() -> None:
    """Without an interrupt event the wrapper takes the plain (non-streaming)
    generate() path and still leaves the lock free afterwards."""
    inner = _SlowStreamBackend(["x", "y"])
    lock = threading.Lock()
    wrapped = _SharedLockBackend(inner, lock, None)
    assert wrapped.generate("p", GenParams()) == "xy"
    # generate() on the fake never bumps ``consumed``; only stream() does.
    assert inner.consumed == 0, "plain path must not stream"
    assert not lock.locked(), "lock must be released after the call"