"""Background-generation safety: a player interrupt aborts an in-flight generation call between tokens and frees the single-flight lock immediately. """ from __future__ import annotations import threading import pytest from case_zero.api.runtime import _SharedLockBackend from case_zero.llm.backend import GenParams, LLMError class _SlowStreamBackend: """Yields tokens one by one; records how many were consumed.""" def __init__(self, tokens: list[str]) -> None: self.tokens = tokens self.consumed = 0 def generate(self, prompt: str, params: GenParams) -> str: return "".join(self.tokens) def stream(self, prompt: str, params: GenParams): for t in self.tokens: self.consumed += 1 yield t def test_uninterrupted_generate_joins_stream() -> None: inner = _SlowStreamBackend(["a", "b", "c"]) lock = threading.Lock() wrapped = _SharedLockBackend(inner, lock, threading.Event()) assert wrapped.generate("p", GenParams()) == "abc" assert not lock.locked(), "lock must be released after the call" def test_interrupt_aborts_between_tokens_and_frees_lock() -> None: inner = _SlowStreamBackend(["a", "b", "c", "d", "e"]) lock = threading.Lock() interrupt = threading.Event() class _TripWire(_SlowStreamBackend): def stream(self, prompt: str, params: GenParams): for i, t in enumerate(self.tokens): if i == 2: interrupt.set() # the player shows up mid-stream self.consumed += 1 yield t tripwire = _TripWire(["a", "b", "c", "d", "e"]) wrapped = _SharedLockBackend(tripwire, lock, interrupt) with pytest.raises(LLMError, match="interrupted"): wrapped.generate("p", GenParams()) assert tripwire.consumed <= 3, "must abort within ~a token of the interrupt" assert not lock.locked(), "lock must be freed for the player immediately" assert inner.consumed == 0 def test_no_interrupt_event_uses_plain_generate() -> None: inner = _SlowStreamBackend(["x", "y"]) wrapped = _SharedLockBackend(inner, threading.Lock(), None) assert wrapped.generate("p", GenParams()) == "xy" assert inner.consumed == 0, "plain path must not stream"