File size: 7,618 Bytes
58af620 c3648b5 58af620 c3648b5 58af620 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 | """LLM-backed policy for the Incident Command Center environment.
Wraps any Hugging Face causal-LM (a base model OR a fine-tuned checkpoint)
into a callable that takes an ``IncidentObservation`` and returns a typed
``IncidentAction``. This is what turns a raw language model into an agent
that can act inside the environment.
Usage::

    from llm_policy import LLMPolicy

    policy = LLMPolicy("Qwen/Qwen2.5-0.5B-Instruct")
    action = policy.select_action(observation)

If the model emits invalid JSON, the policy degrades gracefully to a safe
default action (inspect the first log target) so one bad generation never
crashes a whole rollout.
"""
from __future__ import annotations
import json
import logging
import re
from typing import Any, Dict, Optional
from models import IncidentAction, IncidentObservation
# Module-level logger for this policy (logger name "icc.llm_policy").
_LOG = logging.getLogger("icc.llm_policy")
# Regex that spans from the FIRST "{" to the LAST "}" in a string.
# `[\s\S]*` is greedy and also matches newlines, so nested braces and
# multi-line objects stay inside one match. Caveat: any brace-bearing text
# that follows the JSON object is swallowed into the match as well.
_JSON_RE = re.compile(r"\{[\s\S]*\}")
class LLMPolicy:
    """Policy that calls a HF causal-LM and parses its JSON action.

    The policy renders an ``IncidentObservation`` into a prompt, generates a
    completion with the wrapped model, extracts the first JSON object from
    that completion and validates it into an ``IncidentAction``. Any failure
    to find or validate an action results in a fixed fallback action instead
    of an exception.
    """

    def __init__(
        self,
        model_name_or_path: str,
        *,
        device: Optional[str] = None,
        max_new_tokens: int = 160,
        temperature: float = 0.0,
        dtype: Optional[str] = None,
        label: Optional[str] = None,
    ) -> None:
        """Load the tokenizer and model and prepare them for inference.

        Args:
            model_name_or_path: Identifier or path handed to
                ``from_pretrained`` for both tokenizer and model.
            device: Device string to move the model to. Defaults to
                ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
            max_new_tokens: Upper bound on generated tokens per action.
            temperature: ``0`` (or below) selects greedy decoding; any
                positive value enables sampling at that temperature.
            dtype: Name of a ``torch`` attribute such as ``"bfloat16"``.
                Defaults to float16 on CUDA and float32 elsewhere.
            label: Display name for this policy; defaults to
                ``model_name_or_path``.

        Raises:
            RuntimeError: If ``torch`` or ``transformers`` cannot be imported.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError as exc:  # pragma: no cover - runtime dep
            raise RuntimeError(
                "LLMPolicy requires `transformers` and `torch` installed. "
                "Run: pip install transformers torch"
            ) from exc

        self._torch = torch
        self.label = label or model_name_or_path
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

        resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if dtype is None:
            torch_dtype = torch.float16 if resolved_device == "cuda" else torch.float32
        else:
            torch_dtype = getattr(torch, dtype)

        _LOG.info(
            "Loading LLM policy %s on %s (dtype=%s)",
            model_name_or_path,
            resolved_device,
            torch_dtype,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        if self.tokenizer.pad_token is None:
            # generate() is given a pad_token_id below; reuse EOS when the
            # tokenizer ships without a dedicated pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # transformers renamed torch_dtype -> dtype; try new kwarg first and
        # fall back for older versions. Works silently on both.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                dtype=torch_dtype,
            ).to(resolved_device)
        except TypeError:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                torch_dtype=torch_dtype,
            ).to(resolved_device)
        self.model.eval()
        self.device = resolved_device

        # Strip sampling-only fields from the shipped generation_config so
        # transformers doesn't warn "these flags will be ignored" when we
        # decode greedily (do_sample=False). When sampling is requested,
        # temperature and top_p are passed explicitly in _generate().
        gen_config = getattr(self.model, "generation_config", None)
        if gen_config is not None:
            for attr in ("temperature", "top_p", "top_k"):
                if not hasattr(gen_config, attr):
                    continue
                try:
                    setattr(gen_config, attr, None)
                except Exception as exc:
                    # Best effort only: the worst outcome is a warning.
                    _LOG.debug(
                        "Could not clear generation_config.%s: %s", attr, exc
                    )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def select_action(self, observation: IncidentObservation) -> IncidentAction:
        """Return the action chosen by the model for ``observation``.

        Builds the prompt, generates a completion and parses it. Unusable
        model output yields the fallback action rather than an exception.
        """
        prompt_text = self._build_prompt_text(observation)
        response_text = self._generate(prompt_text)
        return self._parse_action(response_text, observation)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _build_prompt_text(self, observation: IncidentObservation) -> str:
        """Render ``observation`` into the text fed to the model.

        Uses the tokenizer's chat template (single user turn plus generation
        prompt) when one is defined, otherwise a plain
        ``User: ... Assistant:`` layout.
        """
        # Keep this import here to avoid importing the trainer stack when the
        # module is used for inference only.
        from train_trl import obs_to_prompt

        user_prompt = obs_to_prompt(observation)
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": user_prompt}]
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        return f"User: {user_prompt}\n\nAssistant:"

    def _generate(self, prompt_text: str) -> str:
        """Generate a completion for ``prompt_text``.

        Returns only the newly generated text (prompt tokens are sliced off),
        decoded without special tokens and stripped of surrounding
        whitespace. Decoding is greedy unless ``self.temperature`` is
        positive, in which case sampling with ``top_p=0.9`` is used.
        """
        torch = self._torch
        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if self.temperature > 0:
            gen_kwargs.update(
                do_sample=True,
                temperature=self.temperature,
                top_p=0.9,
            )
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            output = self.model.generate(**inputs, **gen_kwargs)

        # The returned sequence starts with the prompt; keep the tail only.
        generated_ids = output[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    @staticmethod
    def _extract_first_json_object(text: str) -> Optional[Dict[str, Any]]:
        """Return the first decodable JSON object embedded in ``text``.

        Every ``{`` is tried in order as a potential start of an object and
        decoding stops at the end of that object, so prose or stray braces
        before or after the JSON do not break parsing. Returns ``None`` when
        no position yields a valid object.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                # raw_decode parses one JSON value beginning at `start` and
                # ignores whatever follows it.
                obj, _end = decoder.raw_decode(text, start)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                start = text.find("{", start + 1)
                continue
            return obj
        return None

    def _parse_action(
        self,
        response_text: str,
        observation: IncidentObservation,
    ) -> IncidentAction:
        """Turn the model's raw text into an ``IncidentAction``.

        The first JSON object found in ``response_text`` is validated with
        ``IncidentAction.model_validate``. If no object is found, or
        validation fails, the fallback action for ``observation`` is returned.
        """
        data = self._extract_first_json_object(response_text)
        if data is None:
            _LOG.debug(
                "LLM output contained no JSON object :: raw=%s",
                response_text[:200],
            )
            return self._safe_fallback(observation)

        try:
            return IncidentAction.model_validate(data)
        except Exception as exc:
            # Deliberately broad: one bad generation must never crash a
            # whole rollout, whatever the validator raises.
            _LOG.debug(
                "LLM JSON parse failed: %s :: raw=%s",
                exc,
                response_text[:200],
            )
        return self._safe_fallback(observation)

    def _safe_fallback(self, observation: IncidentObservation) -> IncidentAction:
        """Build the default action used when the model output is unusable.

        Inspects the first entry of ``investigation_targets["logs"]``, or
        ``"payments-api"`` when the observation lists no log targets.
        """
        logs = (observation.investigation_targets or {}).get("logs", []) or []
        target = logs[0] if logs else "payments-api"
        return IncidentAction(
            actor="triage_agent",
            action_type="inspect_logs",
            target=target,
            reason="LLM output invalid; using safe fallback action.",
        )

    # ------------------------------------------------------------------
    # Resource cleanup
    # ------------------------------------------------------------------
    def release(self) -> None:
        """Free GPU memory so a second model can be loaded after this one.

        Drops the model and tokenizer references, runs the garbage collector
        and, when CUDA is available, empties the CUDA cache. Cleanup is best
        effort: failures are logged at debug level and never raised.
        """
        import gc

        # Drop the references first so they go away even if the cache flush
        # below fails.
        self.model = None  # type: ignore[assignment]
        self.tokenizer = None  # type: ignore[assignment]
        try:
            gc.collect()
            if self._torch.cuda.is_available():
                self._torch.cuda.empty_cache()
        except Exception as exc:
            _LOG.debug("LLMPolicy.release cleanup failed: %s", exc)
|