File size: 4,122 Bytes
7952f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Policy interface and stub policies.

A *policy* is anything that, given a list of messages, returns a single
completion string. The rollout doesn't care whether that string came from
a 0.5B Qwen sample, a hand-written script, or random noise — only that
the contract is honored.

This file provides:

  * :class:`Policy` — a runtime-checkable Protocol.
  * :class:`ScriptedPolicy` — yields a fixed list of completions in order.
    Useful for tests and for building oracle trajectories during rejection-
    sampling SFT (PROPOSAL.md §7.4 plan B).
  * :class:`HfPolicy` — wraps an HF causal LM + tokenizer; the real thing.
    Defined here so consumers can swap it in once we hook up Qwen; its
    heavy dependency (torch) is imported lazily inside ``sample`` rather
    than at module-load time.
"""

from __future__ import annotations

from typing import Iterator, Protocol, runtime_checkable

from graphforge.training.prompt import Message


@runtime_checkable
class Policy(Protocol):
    """Structural interface for anything that turns a chat into one completion.

    Being runtime-checkable, ``isinstance(obj, Policy)`` is true for any
    object that exposes a ``sample`` attribute; the argument and return
    types are only enforced by static type checkers, not at runtime.
    """

    def sample(self, messages: list[Message]) -> str:
        """Return a single completion string for the conversation ``messages``."""
        ...


# ---- scripted -------------------------------------------------------


class ScriptedPolicy:
    """Replays a fixed sequence of completions, one per :meth:`sample` call.

    The completions are copied at construction time, so mutating the
    caller's list afterwards does not change what gets replayed. The
    ``messages`` passed to :meth:`sample` are ignored.

    If the rollout asks for more turns than there are scripted completions,
    raises :class:`StopIteration` — that's a test bug, not an env bug. The
    exception message states how many completions were scripted.
    """

    def __init__(self, completions: list[str]) -> None:
        # Snapshot the input: iterating the caller's list directly would let
        # later appends/removals leak into the replay. ``tuple`` also lets any
        # iterable of strings be passed.
        snapshot = tuple(completions)
        self._iter: Iterator[str] = iter(snapshot)
        self._n = len(snapshot)

    def sample(self, _messages: list[Message]) -> str:
        """Return the next scripted completion; ``_messages`` is ignored.

        Raises:
            StopIteration: once all scripted completions have been returned.
        """
        try:
            return next(self._iter)
        except StopIteration:
            # Keep the exception type (tests may catch it) but report what
            # ran out instead of propagating a message-less StopIteration.
            raise StopIteration(
                f"ScriptedPolicy exhausted: all {self._n} scripted "
                "completion(s) were already returned"
            ) from None


# ---- HF (lazy) ------------------------------------------------------


class HfPolicy:
    """A real LM-backed policy. Imports torch lazily inside :meth:`sample`.

    Constructor args::

        model           — a HF AutoModelForCausalLM
        tokenizer       — the matching tokenizer
        max_new_tokens  — generation cap per turn (PROPOSAL.md §7.1: 384)
        temperature, top_p — sampling knobs; ``temperature <= 0`` selects
                          greedy decoding (``top_p`` is then unused)
    """

    def __init__(
        self,
        model: object,
        tokenizer: object,
        *,
        max_new_tokens: int = 384,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p

    def _generation_kwargs(self) -> dict[str, object]:
        """Build the keyword arguments passed to ``model.generate``."""
        kwargs: dict[str, object] = {
            "max_new_tokens": self.max_new_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,  # type: ignore[attr-defined]
            "use_cache": True,
        }
        if self.temperature > 0:
            kwargs.update(
                do_sample=True,
                temperature=self.temperature,
                top_p=self.top_p,
            )
        else:
            # HF requires a strictly positive temperature when sampling, so
            # temperature <= 0 is treated as a request for greedy decoding.
            kwargs["do_sample"] = False
        return kwargs

    def sample(self, messages: list[Message]) -> str:
        """Generate one assistant completion for ``messages``.

        The chat template is rendered with a generation prompt, tokenized,
        moved to ``model.device``, and only the newly generated tokens are
        decoded (special tokens skipped).

        Generation runs in eval mode with the KV-cache enabled. The model's
        previous training mode and ``config.use_cache`` value are restored
        afterwards (even if generation raises). This means sampling between
        training steps neither silently disables dropout nor leaves the cache
        on under gradient checkpointing.
        """
        # Deferred so importing this module does not pull in torch.
        import torch

        model = self.model
        tok = self.tokenizer

        # Snapshot the state we are about to change so it can be restored.
        was_training = bool(getattr(model, "training", False))
        config = getattr(model, "config", None)
        had_use_cache = config is not None and hasattr(config, "use_cache")
        prev_use_cache = config.use_cache if had_use_cache else None  # type: ignore[union-attr]

        # Critical for trained-eval correctness: generate in eval mode (no
        # dropout) with KV-cache enabled (post-SFT, gradient checkpointing
        # may have set use_cache=False).
        model.eval()  # type: ignore[attr-defined]
        if config is not None:
            config.use_cache = True
        try:
            # Render to text first, then tokenize. ``apply_chat_template``'s
            # return type drifted across transformers versions (sometimes a
            # raw tensor, sometimes a BatchEncoding); going through
            # ``tok(text)`` is the canonical pattern and works on all of them.
            text = tok.apply_chat_template(  # type: ignore[attr-defined]
                messages, add_generation_prompt=True, tokenize=False
            )
            inputs = tok(text, return_tensors="pt")  # type: ignore[operator]
            inputs = {k: v.to(model.device) for k, v in inputs.items()}  # type: ignore[attr-defined]

            with torch.no_grad():
                out_ids = model.generate(  # type: ignore[attr-defined]
                    **inputs, **self._generation_kwargs()
                )
        finally:
            if was_training:
                model.train(True)  # type: ignore[attr-defined]
            if had_use_cache:
                config.use_cache = prev_use_cache  # type: ignore[union-attr]

        # generate() returns prompt + continuation; keep only the new tokens
        # of the single (batch index 0) sequence.
        prompt_len = inputs["input_ids"].shape[-1]
        gen = out_ids[0, prompt_len:]
        return tok.decode(gen, skip_special_tokens=True)  # type: ignore[attr-defined]