NagaNithin-V
Deploy GraphForge OpenEnv — AST-parsed KG code-editing environment
7952f32
"""Policy interface and stub policies.
A *policy* is anything that, given a list of messages, returns a single
completion string. The rollout doesn't care whether that string came from
a 0.5B Qwen sample, a hand-written script, or random noise β€” only that
the contract is honored.
This file provides:
* :class:`Policy` β€” a runtime-checkable Protocol.
* :class:`ScriptedPolicy` β€” yields a fixed list of completions in order.
Useful for tests and for building oracle trajectories during rejection-
sampling SFT (PROPOSAL.md Β§7.4 plan B).
* :class:`HfPolicy` β€” wraps an HF causal LM + tokenizer; the real thing.
Defined here so consumers can swap it in once we hook up Qwen, but
deliberately not imported at module-load time.
"""
from __future__ import annotations
from typing import Iterator, Protocol, runtime_checkable
from graphforge.training.prompt import Message
@runtime_checkable
class Policy(Protocol):
    """Structural interface for anything that can produce a completion.

    Any object exposing a ``sample`` method satisfies this protocol.
    Because the class is ``runtime_checkable``, ``isinstance(obj, Policy)``
    works, but it only verifies that the method exists — not its signature.
    """

    def sample(self, messages: list[Message]) -> str:
        """Return a single completion string for ``messages``."""
        ...
# ---- scripted -------------------------------------------------------
class ScriptedPolicy:
    """Returns each item of ``completions`` in order.

    If the rollout asks for more turns than there are scripted completions,
    raises :class:`StopIteration` — that's a test bug, not an env bug. The
    exception message states how many completions were scripted so the
    offending test is easy to diagnose.

    Args:
        completions: The completions to replay, one per ``sample`` call.
            The sequence is snapshotted at construction time, so mutating
            the caller's list afterwards does not alter the script.
    """

    def __init__(self, completions: list[str]) -> None:
        # Snapshot so the iterator and the recorded length always agree,
        # even if the caller keeps appending to / clearing their list.
        script = tuple(completions)
        self._iter: Iterator[str] = iter(script)
        self._n = len(script)

    def sample(self, _messages: list[Message]) -> str:
        """Return the next scripted completion; the messages are ignored.

        Raises:
            StopIteration: when every scripted completion has already
                been returned.
        """
        try:
            return next(self._iter)
        except StopIteration:
            # Same exception type as before, but with a diagnosable message.
            raise StopIteration(
                f"ScriptedPolicy exhausted: all {self._n} scripted "
                "completion(s) have already been consumed"
            ) from None
# ---- HF (lazy) ------------------------------------------------------
class HfPolicy:
    """A real LM-backed policy. Imports torch / transformers lazily.

    Constructor args::

        model              — a HF AutoModelForCausalLM
        tokenizer          — the matching tokenizer
        max_new_tokens     — generation cap per turn (PROPOSAL.md §7.1: 384)
        temperature, top_p — sampling knobs; ``temperature <= 0`` selects
                             greedy (deterministic) decoding
    """

    def __init__(
        self,
        model: object,
        tokenizer: object,
        *,
        max_new_tokens: int = 384,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p

    def sample(self, messages: list[Message]) -> str:
        """Generate one completion for ``messages`` and return it as text.

        The messages are rendered with the tokenizer's chat template,
        tokenized, moved to the model's device, and passed to
        ``model.generate``. Only the newly generated tokens (everything
        after the prompt) are decoded, with special tokens stripped.

        Side effects: puts the model in eval mode and, when the model has
        a ``config``, sets ``config.use_cache = True``.
        """
        # Defer heavy imports.
        import torch

        # Critical for trained-eval correctness: ensure the model is in
        # eval mode (no dropout) and that KV-cache is enabled (post-SFT,
        # gradient checkpointing may have set use_cache=False).
        self.model.eval()  # type: ignore[attr-defined]
        if hasattr(self.model, "config"):
            self.model.config.use_cache = True  # type: ignore[attr-defined]

        tok = self.tokenizer
        # Render to text first, then tokenize. ``apply_chat_template``'s
        # return type drifted across transformers versions (sometimes a raw
        # tensor, sometimes a BatchEncoding); going through ``tok(text)`` is
        # the canonical pattern and works on all of them.
        text = tok.apply_chat_template(  # type: ignore[attr-defined]
            messages, add_generation_prompt=True, tokenize=False
        )
        # The rendered template already contains every special token the
        # model expects; letting the tokenizer add its own would duplicate
        # BOS (and friends) on tokenizers that prepend them.
        inputs = tok(  # type: ignore[operator]
            text, return_tensors="pt", add_special_tokens=False
        )
        device = self.model.device  # type: ignore[attr-defined]
        inputs = {k: v.to(device) for k, v in inputs.items()}

        gen_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "pad_token_id": tok.eos_token_id,  # type: ignore[attr-defined]
            "use_cache": True,
        }
        if self.temperature > 0:
            gen_kwargs.update(
                do_sample=True,
                temperature=self.temperature,
                top_p=self.top_p,
            )
        else:
            # Sampling requires a strictly positive temperature; treat a
            # non-positive value as a request for greedy decoding and omit
            # the sampling-only knobs.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            out_ids = self.model.generate(  # type: ignore[attr-defined]
                **inputs, **gen_kwargs
            )

        # ``generate`` returns prompt + continuation; keep the continuation.
        prompt_len = inputs["input_ids"].shape[-1]
        gen = out_ids[0, prompt_len:]
        return tok.decode(gen, skip_special_tokens=True)  # type: ignore[attr-defined]