File size: 4,113 Bytes
414dc55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80cd1f2
414dc55
 
80cd1f2
414dc55
 
 
80cd1f2
 
 
 
 
 
414dc55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80cd1f2
414dc55
 
80cd1f2
414dc55
 
 
 
 
 
 
80cd1f2
414dc55
80cd1f2
414dc55
 
 
 
 
 
 
 
 
 
 
80cd1f2
 
 
414dc55
80cd1f2
414dc55
 
 
 
80cd1f2
414dc55
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Case generation pipeline: two creative LLM calls + structural scaffold + solver gate.

On a solvability failure the pipeline bumps the seed and regenerates, up to a small
cap. Generation is decomposed so each model call stays small and reliable.
"""

from __future__ import annotations

from dataclasses import dataclass

from ..llm.backend import LLMBackend, LLMError
from ..llm.decoding import generate_model
from ..schemas.case import CaseFile, GenerationKnobs
from ..schemas.timeline import TimeWindow
from ..solver.checker import CheckReport, check
from .assemble import assemble_case
from .crime_profiles import kind_for_seed, profile_for
from .stages import MysteryOut, WorldCastOut, mystery_prompt, world_cast_prompt

# A fixed incident window keeps alibi reasoning simple; the model still invents the rest.
# Both windows are module-level constants shared by every generated case. The values are
# written as `hours * 60 + minutes`, so they read as 21:00-22:00 and 21:20-21:50
# (presumably minutes since midnight -- confirm against TimeWindow). The time-of-death
# window lies strictly inside the incident window.
MURDER_WINDOW = TimeWindow(start_min=21 * 60, end_min=22 * 60)
TIME_OF_DEATH = TimeWindow(start_min=21 * 60 + 20, end_min=21 * 60 + 50)

# Output caps sized to healthy generations (~600-1100 / ~200-400 tokens). The caps only
# bite on runaway outputs, where on 2 vCPU an uncapped 4096-token decode wastes minutes
# before the retry can even start.
# _WORLD_MAX_TOKENS bounds the world/cast call; _MYSTERY_MAX_TOKENS bounds the mystery call.
_WORLD_MAX_TOKENS = 1800
_MYSTERY_MAX_TOKENS = 900


@dataclass(frozen=True)
class GenerationResult:
    """Immutable outcome of one ``generate_case`` run.

    Holds the last case that was assembled together with the checker report for it.
    The case is not guaranteed to have passed the checker: when every attempt fails,
    ``generate_case`` still returns the final attempt, so callers should inspect
    ``report.ok``.
    """

    # The assembled case from the attempt that was returned.
    case: CaseFile
    # Result of running the solver check on ``case``.
    report: CheckReport
    # 1-based number of the accepted attempt, or ``max_attempts`` when none was accepted.
    attempts: int


def _clamp(index: int, count: int) -> int:
    return max(0, min(index, count - 1))


def _ensure_female(world: WorldCastOut, seed: int) -> WorldCastOut:
    """Guarantee every case has at least one woman in the cast (variety), flipping one
    suspect deterministically if the model produced an all-male cast."""
    suspects = list(world.suspects)
    if any((s.gender or "").lower().startswith("f") for s in suspects):
        return world
    idx = seed % len(suspects)
    suspects[idx] = suspects[idx].model_copy(update={"gender": "female"})
    return world.model_copy(update={"suspects": suspects})


def generate_case(
    backend: LLMBackend,
    *,
    seed: int,
    knobs: GenerationKnobs | None = None,
    max_attempts: int = 2,
) -> GenerationResult:
    """Generate a case, bumping the seed and regenerating until the checker accepts it.

    Each attempt uses ``seed + attempt`` and runs four steps: a world/cast model call,
    a mystery model call for the chosen culprit, ``assemble_case``, and ``check``. The
    first attempt whose report is ``ok`` is returned immediately.

    Args:
        backend: LLM backend handed to ``generate_model`` for both calls.
        seed: Base seed. It also names the case (``gen-<seed>``), so the id stays the
            same across retries while the per-attempt seed changes.
        knobs: Generation settings; a default ``GenerationKnobs()`` is used when omitted.
        max_attempts: Upper bound on the number of attempts.

    Returns:
        The first passing attempt with its 1-based attempt number. If no attempt
        passes, the last assembled case and its failing report are returned with
        ``attempts=max_attempts``, so callers must inspect ``report.ok``.

    Raises:
        LLMError: If no case was assembled at all (``max_attempts < 1``, or every
            attempt produced a world without suspects or locations). Exceptions
            raised by ``generate_model`` are not caught here and propagate.
    """
    knobs = knobs or GenerationKnobs()
    # An explicit crime kind wins; otherwise the kind is derived from the base seed.
    profile = profile_for(knobs.crime_kind or kind_for_seed(seed))
    case: CaseFile | None = None
    report: CheckReport | None = None

    for attempt in range(max_attempts):
        attempt_seed = seed + attempt
        world = generate_model(
            backend,
            world_cast_prompt(profile, knobs.setting_hint, knobs.era_hint, knobs.tone_hint,
                              knobs.n_suspects, MURDER_WINDOW.start_min, MURDER_WINDOW.end_min),
            WorldCastOut, temperature=0.85, max_tokens=_WORLD_MAX_TOKENS, seed=attempt_seed,
        )
        n = len(world.suspects)
        n_loc = len(world.locations)
        if n == 0 or n_loc == 0:
            # A world with no suspects or no locations cannot be scaffolded (the index
            # arithmetic below would divide by zero). Treat it as a failed attempt and
            # move on to the next seed instead of crashing.
            continue
        world = _ensure_female(world, attempt_seed)

        # Structural choices are derived from the seed and the world, not left to the model.
        culprit_idx = attempt_seed % n
        crime_idx = _clamp(world.found_at_index, n_loc)
        # The claimed location is the next one after the crime scene, wrapping around.
        # NOTE(review): with a single location this equals crime_idx -- confirm that
        # downstream stages tolerate it.
        claimed_idx = (crime_idx + 1) % n_loc
        culprit = world.suspects[culprit_idx]

        mystery = generate_model(
            backend,
            mystery_prompt(profile, culprit.name, culprit.role, world.victim_name,
                           world.weapon_name, world.locations[crime_idx],
                           world.locations[claimed_idx],
                           MURDER_WINDOW.start_min, MURDER_WINDOW.end_min),
            MysteryOut, temperature=0.6, max_tokens=_MYSTERY_MAX_TOKENS, seed=attempt_seed,
        )

        case = assemble_case(
            case_id=f"gen-{seed:06d}", seed=attempt_seed, knobs=knobs, world=world, mystery=mystery,
            profile=profile, window=MURDER_WINDOW, tod=TIME_OF_DEATH,
            culprit_idx=culprit_idx, crime_idx=crime_idx, claimed_idx=claimed_idx,
        )
        report = check(case)
        if report.ok:
            return GenerationResult(case=case, report=report, attempts=attempt + 1)

    if case is None or report is None:
        raise LLMError("generate_case produced no case after all attempts")
    # Every attempt failed the check: hand back the last one with its failing report.
    return GenerationResult(case=case, report=report, attempts=max_attempts)