"""
Baseline inference runner for the incident operations HTTP environment.
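
Environment variables:
    RAG_ENV_URL              Base URL of the environment server (default http://localhost:7860).
    RAG_ENV_TASK             Task to run first (default refund_triage_easy).
    MODEL_NAME               Chat model name (default Qwen/Qwen2.5-72B-Instruct).
    API_BASE_URL / API_KEY   OpenAI-compatible endpoint and key for model calls.
    HF_TOKEN                 Legacy credential, used only when API_KEY is unset.
    ALLOW_BASELINE_FALLBACK  Set to 1/true/yes to allow the deterministic fallback policy.

Each task run emits [START]/[STEP]/[END] log lines on stdout.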
"""

from __future__ import annotations

import asyncio
import json
import os
import re
import sys
from typing import Any

import httpx
from openai import OpenAI

from rag_optimizer_env.models import RagAction

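# Runtime configuration; ALLOW_BASELINE_FALLBACK, RAG_ENV_TASK, and RAG_ENV_URL
# are read from the environment at import time.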
ENV_NAME = "incident-ops-env"
ALLOW_BASELINE_FALLBACK = os.getenv("ALLOW_BASELINE_FALLBACK", "").strip().lower() in {"1", "true", "yes"}
RAG_ENV_TASK = os.getenv("RAG_ENV_TASK", "refund_triage_easy")
RAG_ENV_URL = os.getenv("RAG_ENV_URL", "http://localhost:7860").rstrip("/")
TASK_SEQUENCE = [
    "refund_triage_easy",
    "cross_function_brief_medium",
    "executive_escalation_hard",
]

SYSTEM_PROMPT = """You are a baseline incident operations agent.
You are handling an enterprise case through staged workflow actions.
Your job is to inspect the right artifacts, prioritize the evidence that belongs in the working set,
draft a short operational plan, summarize heavy artifacts when needed, and finally submit a grounded report.
Return only valid JSON matching one of these forms:
{"action_type":"inspect_artifact","artifact_id":"support_003"}
{"action_type":"prioritize_artifact","artifact_id":"support_003"}
{"action_type":"summarize_artifact","artifact_id":"support_003","compression_ratio":0.55}
{"action_type":"set_resolution_plan","plan":"Verify outage evidence, confirm the billing ledger, and route manual exceptions to finance review."}
{"action_type":"submit_report","answer":"Proceed to refund review only after outage and billing evidence are confirmed. [support_001] [support_003]"}
Legacy aliases like select_chunk, compress_chunk, and submit_answer are also accepted, but prefer the new workflow actions."""

DEFAULT_LEGACY_BASE_URL = "https://router.huggingface.co/v1"
DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"


def _model_name() -> str:
    return os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME)


def _resolve_llm_credentials() -> tuple[str | None, str | None, str | None]:
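    """Resolve the OpenAI-compatible endpoint and credential.

    Returns (base_url, api_key, auth_mode), where auth_mode is "proxy" when
    API_KEY is set, "legacy" when only HF_TOKEN is set, and None otherwise.
    """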
    api_base_url = os.getenv("API_BASE_URL", DEFAULT_LEGACY_BASE_URL)
    api_key = os.getenv("API_KEY")
    legacy_token = os.getenv("HF_TOKEN")
    if api_key:
        return api_base_url, api_key, "proxy"
    if legacy_token:
        return api_base_url, legacy_token, "legacy"
    return None, None, None


def _format_bool(value: bool) -> str:
    return "true" if value else "false"


def _format_reward(value: float | None) -> str:
    return "0.00" if value is None else f"{value:.2f}"


def _format_error(error: str | None) -> str:
    return "null" if not error else error.replace("\n", " ").strip()


def _clamp_score(value: float) -> float:
    return max(0.0, min(1.0, value))


def _format_rewards(rewards: list[float]) -> str:
    return ",".join(f"{reward:.2f}" for reward in rewards)


def _format_action(action: dict[str, Any]) -> str:
    return json.dumps(action, ensure_ascii=True, separators=(",", ":"))


def _extract_json_object(text: str) -> dict[str, Any]:
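    """Parse the model reply as JSON, falling back to the first {...} span."""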
    payload = text.strip()
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", payload, re.DOTALL)
        if not match:
            raise
        return json.loads(match.group(0))


def _tokenize(text: str) -> set[str]:
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def _keyword_overlap(query: str, chunk: dict[str, Any]) -> float:
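    """Jaccard overlap between query tokens and a chunk's keyword tokens."""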
    query_terms = _tokenize(query)
    keyword_terms = _tokenize(" ".join(chunk.get("keywords", [])))
    if not query_terms or not keyword_terms:
        return 0.0
    union = query_terms | keyword_terms
    return len(query_terms & keyword_terms) / len(union)


def _fallback_report(observation: dict[str, Any]) -> str:
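    """Build a short, citation-bearing report from the prioritized artifacts."""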
    prioritized = set(observation.get("prioritized_artifacts") or observation.get("selected_chunks", []))
    snippets: list[str] = []
    for chunk in observation.get("available_artifacts") or observation.get("available_chunks", []):
        if chunk.get("chunk_id") in prioritized:
            keywords = ", ".join(chunk.get("keywords", [])[:3])
            snippets.append(f"[{chunk['chunk_id']}] covers {keywords}")
    if not snippets:
        return "The case needs a defensible operational recommendation grounded in reviewed incident artifacts."
    return "; ".join(snippets[:3]) + "."


def _fallback_plan(observation: dict[str, Any]) -> str:
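    """Return a canned resolution plan matching the current task."""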
    task_name = observation.get("task_name", "")
    if task_name == "refund_triage_easy":
        return "Verify outage evidence, confirm the billing ledger, and route manual exceptions to finance review."
    if task_name == "cross_function_brief_medium":
        return "Align the incident timeline, customer communications, and rollback guardrails before publishing the brief."
    return "Revoke active risk, protect customers, preserve evidence, and keep change freeze safeguards in place."


def _fallback_action(observation: dict[str, Any]) -> dict[str, Any]:
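    """Deterministic policy used when no model call is available.

    Walks the staged workflow in order: prioritize reviewed evidence, inspect
    unseen artifacts, draft a plan, summarize heavy artifacts near the budget
    limit, and finally submit the report.
    """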
    reviewed = set(observation.get("reviewed_artifacts", []))
    prioritized = set(observation.get("prioritized_artifacts") or observation.get("selected_chunks", []))
    available = list(observation.get("available_artifacts") or observation.get("available_chunks", []))
    token_budget = observation["token_budget"]
    total_tokens_used = observation["total_tokens_used"]
    remaining_budget = token_budget - total_tokens_used

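    # Rank artifacts by query-keyword overlap (best first), breaking ties by
    # token cost and then by id so the policy stays deterministic.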
    ranked = sorted(
        available,
        key=lambda chunk: (-_keyword_overlap(observation["query"], chunk), chunk["tokens"], chunk["chunk_id"]),
    )

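    # Stage 1: prioritize reviewed artifacts that still fit the remaining budget.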
    unprioritized_reviewed = [chunk for chunk in ranked if chunk["chunk_id"] in reviewed and chunk["chunk_id"] not in prioritized]
    for chunk in unprioritized_reviewed:
        if chunk["tokens"] <= remaining_budget:
            return {"action_type": "prioritize_artifact", "artifact_id": chunk["chunk_id"]}

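    # Stage 2: inspect the best-ranked artifact that has not been reviewed yet.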
    unseen = [chunk for chunk in ranked if chunk["chunk_id"] not in reviewed]
    if unseen:
        return {"action_type": "inspect_artifact", "artifact_id": unseen[0]["chunk_id"]}

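    # Stage 3: draft a resolution plan once evidence has been prioritized.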
    if prioritized and not observation.get("plan_draft"):
        return {"action_type": "set_resolution_plan", "plan": _fallback_plan(observation)}

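    # Stage 4: summarize a heavy prioritized artifact once most of the budget is spent.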
    heavy_prioritized = [chunk for chunk in ranked if chunk["chunk_id"] in prioritized and chunk["tokens"] >= max(120, token_budget // 4)]
    if heavy_prioritized and total_tokens_used >= int(token_budget * 0.7):
        return {"action_type": "summarize_artifact", "artifact_id": heavy_prioritized[0]["chunk_id"], "compression_ratio": 0.55}

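    # Stage 5: nothing left to gather, so submit the grounded report.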
    return {"action_type": "submit_report", "answer": _fallback_report(observation)}


def _build_user_prompt(observation: dict[str, Any]) -> str:
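    """Serialize the observation into the compact JSON payload sent to the model."""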
    payload = {
        "case_id": observation.get("case_id"),
        "case_summary": observation.get("case_summary"),
        "objective": observation.get("objective") or observation.get("query"),
        "workflow_stage": observation.get("workflow_stage"),
        "customer_tier": observation.get("customer_tier"),
        "incident_severity": observation.get("incident_severity"),
        "reviewed_artifacts": observation.get("reviewed_artifacts", []),
        "prioritized_artifacts": observation.get("prioritized_artifacts") or observation.get("selected_chunks", []),
        "plan_draft": observation.get("plan_draft"),
        "report_requirements": observation.get("report_requirements", []),
        "progress_signals": observation.get("progress_signals", {}),
        "total_tokens_used": observation["total_tokens_used"],
        "token_budget": observation["token_budget"],
        "step_number": observation["step_number"],
        "task_name": observation["task_name"],
        "last_action_feedback": observation.get("last_action_feedback"),
        "available_artifacts": [
            {
                "chunk_id": chunk["chunk_id"],
                "domain": chunk["domain"],
                "tokens": chunk["tokens"],
                "keywords": chunk["keywords"],
            }
            for chunk in (observation.get("available_artifacts") or observation.get("available_chunks", []))
        ],
    }
    return json.dumps(payload, ensure_ascii=True)


async def _llm_action(client: OpenAI, observation: dict[str, Any]) -> dict[str, Any]:
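    """Ask the model for the next action, running the sync client in a worker thread."""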
    prompt = _build_user_prompt(observation)
    model_name = _model_name()

    def _call() -> Any:
        return client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
            temperature=0,
        )

    response = await asyncio.to_thread(_call)
    content = response.choices[0].message.content or "{}"
    return _extract_json_object(content)


async def _post_json(http_client: httpx.AsyncClient, path: str, payload: dict[str, Any]) -> dict[str, Any]:
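    """POST a JSON payload to the environment and return the decoded response."""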
    response = await http_client.post(f"{RAG_ENV_URL}{path}", json=payload)
    if response.status_code >= 400:
        raise RuntimeError(f"{path} -> {response.status_code}: {response.text}")
    return response.json()


async def _run_task_http(task_name: str) -> tuple[float, list[float], int, bool]:
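    """Run a single task end to end over HTTP.

    Returns (score, per-step rewards, step count, success flag).
    """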
    rewards: list[float] = []
    steps = 0
    success = False
    score = 0.0
    terminal_error: str | None = None
    fallback_reason: str | None = None
    model_name = _model_name()

    print(f"[START] task={task_name} env={ENV_NAME} model={model_name}")

    api_base_url, client_api_key, auth_mode = _resolve_llm_credentials()
    llm_required = auth_mode in {"proxy", "legacy"}
    openai_client: Any | None = None

    if llm_required:
        openai_client = OpenAI(base_url=api_base_url, api_key=client_api_key)
    elif ALLOW_BASELINE_FALLBACK:
        fallback_reason = "missing_llm_credentials"
        print(
            f"[warn] Missing API_BASE_URL/API_KEY credentials; using deterministic fallback policy for {task_name}.",
            file=sys.stderr,
            flush=True,
        )
    else:
        print(
            "[warn] Missing API_BASE_URL/API_KEY credentials; aborting model-backed run. "
            "Set ALLOW_BASELINE_FALLBACK=1 only for offline smoke testing.",
            file=sys.stderr,
            flush=True,
        )
        print("[END] success=false steps=0 score=0.000 rewards=")
        return 0.0, [], 0, False

    try:
        async with httpx.AsyncClient(timeout=30.0) as http_client:
            reset_payload = await _post_json(http_client, "/reset", {"task_name": task_name})
            observation = reset_payload["observation"]

            while True:
                step_error: str | None = None
                try:
                    if openai_client is None:
                        raise RuntimeError("llm_unavailable")
                    llm_payload = await _llm_action(openai_client, observation)
                    action_payload = RagAction.model_validate(llm_payload).model_dump(exclude_none=True)
                except Exception as exc:
                    fallback_reason = fallback_reason or type(exc).__name__
                    if llm_required or not ALLOW_BASELINE_FALLBACK:
                        terminal_error = f"model_unavailable:{fallback_reason}"
                        print(f"[warn] {terminal_error}; aborting {task_name}.", file=sys.stderr, flush=True)
                        print(f"[END] success=false steps={steps} score={_clamp_score(score):.3f} rewards={_format_rewards(rewards)}")
                        return score, rewards, steps, False
                    action_payload = RagAction.model_validate(_fallback_action(observation)).model_dump(exclude_none=True)

                try:
                    step_response = await _post_json(http_client, "/step", action_payload)
                except Exception as exc:
                    steps += 1
                    rewards.append(0.0)
                    terminal_error = str(exc)
                    print(f"[STEP] step={steps} action={_format_action(action_payload)} reward=0.00 done=true error={_format_error(terminal_error)}")
                    break

                steps += 1
                reward_value = step_response.get("reward")
                reward_float = float(reward_value) if reward_value is not None else 0.0
                rewards.append(reward_float)
                done = bool(step_response["done"])
                print(
                    f"[STEP] step={steps} action={_format_action(action_payload)} "
                    f"reward={_format_reward(reward_float)} done={_format_bool(done)} error={_format_error(step_error)}"
                )
                observation = step_response["observation"]
                if done:
                    score = _clamp_score(reward_float)
                    success = terminal_error is None
                    break

            print(f"[END] success={_format_bool(success)} steps={steps} score={score:.3f} rewards={_format_rewards(rewards)}")
            return score, rewards, steps, success
    except Exception as exc:
        print(f"[warn] task {task_name} failed: {exc}", file=sys.stderr, flush=True)
        print(f"[END] success=false steps={steps} score={_clamp_score(score):.3f} rewards={_format_rewards(rewards)}")
        return score, rewards, steps, False


def run_task(task_name: str) -> tuple[float, list[float], int, bool]:
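    """Synchronous wrapper around the async HTTP runner."""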
    return asyncio.run(_run_task_http(task_name))


def main() -> int:
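    """Run every task once, starting with RAG_ENV_TASK when it is in the sequence."""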
    if RAG_ENV_TASK in TASK_SEQUENCE:
        tasks = [RAG_ENV_TASK] + [task for task in TASK_SEQUENCE if task != RAG_ENV_TASK]
    else:
        tasks = list(TASK_SEQUENCE)
    for task_name in tasks:
        try:
            run_task(task_name)
        except Exception:
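            # Still emit a START/END pair so log parsers see a complete record.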
            print(f"[START] task={task_name} env={ENV_NAME} model={_model_name()}")
            print("[END] success=false steps=0 score=0.000 rewards=")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())