File size: 7,618 Bytes
58af620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3648b5
 
 
 
 
 
 
 
 
 
 
 
58af620
 
 
c3648b5
 
 
 
 
 
 
 
 
 
 
 
58af620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""LLM-backed policy for the Incident Command Center environment.

Wraps any Hugging Face causal-LM (a base model OR a fine-tuned checkpoint)
into a callable that takes an ``IncidentObservation`` and returns a typed
``IncidentAction``. This is what turns a raw language model into an agent
that can act inside the environment.

Usage::

    from llm_policy import LLMPolicy
    policy = LLMPolicy("Qwen/Qwen2.5-0.5B-Instruct")
    action = policy.select_action(observation)

If the model emits invalid JSON, the policy degrades gracefully to a safe
default action (inspect the first log target) so one bad generation never
crashes a whole rollout.
"""

from __future__ import annotations

import json
import logging
import re
from typing import Any, Dict, Optional

from models import IncidentAction, IncidentObservation

_LOG = logging.getLogger("icc.llm_policy")

# Regex for the first balanced-ish JSON object in the model output.
# (Greedy `.*` inside `{...}` keeps nested braces intact for our tiny JSON.)
_JSON_RE = re.compile(r"\{[\s\S]*\}")


class LLMPolicy:
    """Policy that calls a HF causal-LM and parses its JSON action.

    Args:
        model_name_or_path: Hub id or local path passed to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        device: Torch device string. Defaults to ``"cuda"`` when CUDA is
            available, otherwise ``"cpu"``.
        max_new_tokens: Maximum number of tokens generated per action.
        temperature: ``<= 0`` decodes greedily; ``> 0`` samples at this
            temperature with nucleus sampling (see ``top_p``).
        dtype: Name of a ``torch`` dtype attribute such as ``"bfloat16"``.
            Defaults to float16 on CUDA devices and float32 elsewhere.
        label: Display name; defaults to ``model_name_or_path``.
        top_p: Nucleus-sampling threshold, used only when ``temperature > 0``.

    Raises:
        RuntimeError: If ``transformers`` or ``torch`` is not installed.
    """

    def __init__(
        self,
        model_name_or_path: str,
        *,
        device: Optional[str] = None,
        max_new_tokens: int = 160,
        temperature: float = 0.0,
        dtype: Optional[str] = None,
        label: Optional[str] = None,
        top_p: float = 0.9,
    ) -> None:
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError as exc:  # pragma: no cover - runtime dep
            raise RuntimeError(
                "LLMPolicy requires `transformers` and `torch` installed. "
                "Run: pip install transformers torch"
            ) from exc

        self._torch = torch
        self.label = label or model_name_or_path
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p

        resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if dtype is None:
            # Compare the device *type* so "cuda:0" / "cuda:1" also get fp16
            # instead of silently falling back to fp32.
            on_cuda = torch.device(resolved_device).type == "cuda"
            torch_dtype = torch.float16 if on_cuda else torch.float32
        else:
            torch_dtype = getattr(torch, dtype)

        _LOG.info(
            "Loading LLM policy %s on %s (dtype=%s)",
            model_name_or_path,
            resolved_device,
            torch_dtype,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        if self.tokenizer.pad_token is None:
            # generate() needs a pad id; reuse EOS when the tokenizer has none.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # transformers renamed torch_dtype -> dtype; try new kwarg first and
        # fall back for older versions. Works silently on both.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                dtype=torch_dtype,
            ).to(resolved_device)
        except TypeError:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                torch_dtype=torch_dtype,
            ).to(resolved_device)
        self.model.eval()
        self.device = resolved_device

        # Strip sampling-only fields from the shipped generation_config so
        # transformers doesn't warn "these flags will be ignored" when we
        # decode greedily (do_sample=False). Sampling mode passes its own
        # temperature/top_p explicitly in _generate.
        gen_config = getattr(self.model, "generation_config", None)
        if gen_config is not None:
            for attr in ("temperature", "top_p", "top_k"):
                if hasattr(gen_config, attr):
                    try:
                        setattr(gen_config, attr, None)
                    except Exception:
                        # Best-effort cosmetic tweak; never fail model loading.
                        _LOG.debug(
                            "Could not clear generation_config.%s",
                            attr,
                            exc_info=True,
                        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def select_action(self, observation: IncidentObservation) -> IncidentAction:
        """Prompt the model with *observation* and return its parsed action.

        Falls back to ``_safe_fallback`` when the model output contains no
        JSON object that validates as an ``IncidentAction``.

        Raises:
            RuntimeError: If ``release()`` has already been called.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "LLMPolicy has been released; create a new instance to act."
            )
        prompt_text = self._build_prompt_text(observation)
        response_text = self._generate(prompt_text)
        return self._parse_action(response_text, observation)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _build_prompt_text(self, observation: IncidentObservation) -> str:
        """Render *observation* as the full prompt string fed to the model.

        Uses the tokenizer's chat template when one is defined; otherwise a
        plain ``User: ... Assistant:`` transcript.
        """
        # Keep this import here to avoid importing the trainer stack when the
        # module is used for inference only.
        from train_trl import obs_to_prompt

        user_prompt = obs_to_prompt(observation)
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": user_prompt}]
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        return f"User: {user_prompt}\n\nAssistant:"

    def _generate(self, prompt_text: str) -> str:
        """Generate a completion for *prompt_text*; return only the new text.

        Decodes greedily when ``self.temperature <= 0``; otherwise samples
        with ``self.temperature`` and ``self.top_p``.
        """
        torch = self._torch
        # Chat-template output already embeds the template's special tokens
        # (e.g. BOS); letting the tokenizer add them again would duplicate
        # them. The plain-text prompt path still needs them added.
        has_chat_template = bool(getattr(self.tokenizer, "chat_template", None))
        inputs = self.tokenizer(
            prompt_text,
            return_tensors="pt",
            add_special_tokens=not has_chat_template,
        ).to(self.device)
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if self.temperature > 0:
            gen_kwargs.update(
                do_sample=True,
                temperature=self.temperature,
                top_p=self.top_p,
            )
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            output = self.model.generate(**inputs, **gen_kwargs)
        # generate() returns prompt + continuation; drop the prompt tokens.
        generated_ids = output[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    @staticmethod
    def _extract_json_objects(text: str) -> list[Dict[str, Any]]:
        """Return every top-level JSON object embedded in *text*, in order.

        Decoding is attempted at each ``{``. A position that does not start
        valid JSON is skipped. After a successful decode, scanning resumes
        past the decoded object, so nested objects are not reported
        separately. Returns ``[]`` when nothing decodes.
        """
        decoder = json.JSONDecoder()
        found: list[Dict[str, Any]] = []
        start = text.find("{")
        while start != -1:
            try:
                value, consumed = decoder.raw_decode(text[start:])
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
                continue
            found.append(value)
            start = text.find("{", start + consumed)
        return found

    def _parse_action(
        self,
        response_text: str,
        observation: IncidentObservation,
    ) -> IncidentAction:
        """Turn raw model text into an ``IncidentAction``, or fall back safely.

        Each top-level JSON object in the text is tried left to right, and
        the first one that validates as an ``IncidentAction`` wins. Leading
        chatter, trailing commentary (common with Qwen / Llama), or a second
        stray object therefore no longer forces the fallback.
        """
        # _JSON_RE spans the first "{" to the last "}", a region that contains
        # every complete JSON object in the response.
        span = _JSON_RE.search(response_text)
        candidates = self._extract_json_objects(span.group(0)) if span else []
        for data in candidates:
            try:
                return IncidentAction.model_validate(data)
            except Exception as exc:
                # Broad on purpose: one bad generation must never crash a rollout.
                _LOG.debug(
                    "LLM JSON failed validation: %s :: data=%s",
                    exc,
                    str(data)[:200],
                )
        _LOG.debug(
            "LLM output had no valid action JSON :: raw=%s",
            response_text[:200],
        )
        return self._safe_fallback(observation)

    def _safe_fallback(self, observation: IncidentObservation) -> IncidentAction:
        """Return a harmless ``inspect_logs`` action for unparseable output.

        Targets the first entry of ``investigation_targets["logs"]`` when
        present (presumably a mapping of target kind -> list of names; verify
        against ``IncidentObservation``), else the hard-coded ``"payments-api"``.
        """
        logs = (observation.investigation_targets or {}).get("logs", []) or []
        target = logs[0] if logs else "payments-api"
        return IncidentAction(
            actor="triage_agent",
            action_type="inspect_logs",
            target=target,
            reason="LLM output invalid; using safe fallback action.",
        )

    # ------------------------------------------------------------------
    # Resource cleanup
    # ------------------------------------------------------------------

    def release(self) -> None:
        """Free GPU memory so a second model can be loaded after this one.

        Drops the model and tokenizer references, runs the garbage collector
        and empties the CUDA cache when CUDA is available. Cleanup failures are
        logged at DEBUG level rather than raised. After this call,
        ``select_action`` raises ``RuntimeError``.
        """
        import gc

        self.model = None  # type: ignore[assignment]
        self.tokenizer = None  # type: ignore[assignment]
        try:
            gc.collect()
            if self._torch.cuda.is_available():
                self._torch.cuda.empty_cache()
        except Exception:
            _LOG.debug("LLMPolicy.release cleanup failed", exc_info=True)