DevikaJ2005 committed on
Commit
45cf526
·
1 Parent(s): f056f9f

Align local policy prompts and canonical action targets

Browse files
Files changed (3) hide show
  1. environment.py +45 -27
  2. llm_agent_openai.py +4 -20
  3. train.py +2 -2
environment.py CHANGED
@@ -12,6 +12,14 @@ from models import ActionTypeEnum, FraudCheckAction, ResolutionEnum
12
  from reward import RewardBreakdown, build_reward_breakdown
13
  from utils import approximate_token_count, extract_json_object
14
 
 
 
 
 
 
 
 
 
15
  INVESTIGATION_ALIAS_TO_ACTION = {
16
  "merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
17
  "fetch_merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
@@ -27,6 +35,42 @@ INVESTIGATION_ALIAS_TO_ACTION = {
27
  "check_policy": ActionTypeEnum.CHECK_POLICY,
28
  }
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  @dataclass
32
  class TextStepResult:
@@ -70,33 +114,7 @@ class FraudShieldTextEnvironment:
70
 
71
  def build_prompt(self, observation) -> str:
72
  """Build the prompt shown to an LLM policy."""
73
-
74
- payload = {
75
- "case_id": observation.case_id,
76
- "task_name": observation.task_name.value,
77
- "visible_panels": observation.visible_panels,
78
- "revealed_evidence": observation.revealed_evidence,
79
- "linked_case_ids": observation.linked_case_ids,
80
- "remaining_steps": observation.remaining_steps,
81
- "remaining_sla": observation.remaining_sla,
82
- "note_required": observation.note_required,
83
- "allowed_actions": [action.value for action in observation.allowed_actions],
84
- "case_summary": observation.case_summary.model_dump(mode="json"),
85
- "app_context": observation.app_context,
86
- }
87
- available = observation.app_context.get(
88
- "available_investigations",
89
- ["merchant_profile", "customer_profile", "network_graph", "payment_trace", "policy_review"],
90
- )
91
- return (
92
- "You are a fraud analyst in a multi-step training environment. "
93
- "Return JSON only. Use visible evidence, investigation budget, and prior evidence carefully.\n\n"
94
- f"Visible observation:\n{json.dumps(payload, sort_keys=True)}\n\n"
95
- f"Valid investigation aliases: {available}\n"
96
- "JSON schema: "
97
- '{"action_type":"investigate|decide","investigation_target":"alias_or_null",'
98
- '"decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
99
- )
100
 
101
  def parse_response(self, response_text: str) -> tuple[FraudCheckAction, dict[str, Any], bool, bool]:
102
  """Convert model output into a typed environment action."""
 
12
  from reward import RewardBreakdown, build_reward_breakdown
13
  from utils import approximate_token_count, extract_json_object
14
 
15
+ CANONICAL_INVESTIGATION_ALIASES = [
16
+ "merchant_profile",
17
+ "customer_profile",
18
+ "network_graph",
19
+ "payment_trace",
20
+ "policy_review",
21
+ ]
22
+
23
  INVESTIGATION_ALIAS_TO_ACTION = {
24
  "merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
25
  "fetch_merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
 
35
  "check_policy": ActionTypeEnum.CHECK_POLICY,
36
  }
37
 
38
+ ACTION_TYPE_TO_CANONICAL_ALIAS = {
39
+ ActionTypeEnum.FETCH_MERCHANT_PROFILE: "merchant_profile",
40
+ ActionTypeEnum.FETCH_CUSTOMER_PROFILE: "customer_profile",
41
+ ActionTypeEnum.FETCH_NETWORK_GRAPH: "network_graph",
42
+ ActionTypeEnum.REVIEW_TRANSACTION: "payment_trace",
43
+ ActionTypeEnum.CHECK_POLICY: "policy_review",
44
+ }
45
+
46
+
47
+ def build_fraudshield_prompt(observation) -> str:
48
+ """Build the canonical prompt used for both training and inference."""
49
+
50
+ payload = {
51
+ "case_id": observation.case_id,
52
+ "task_name": observation.task_name.value,
53
+ "visible_panels": observation.visible_panels,
54
+ "revealed_evidence": observation.revealed_evidence,
55
+ "linked_case_ids": observation.linked_case_ids,
56
+ "remaining_steps": observation.remaining_steps,
57
+ "remaining_sla": observation.remaining_sla,
58
+ "note_required": observation.note_required,
59
+ "allowed_actions": [action.value for action in observation.allowed_actions],
60
+ "case_summary": observation.case_summary.model_dump(mode="json"),
61
+ "app_context": observation.app_context,
62
+ }
63
+ available = observation.app_context.get("available_investigations", CANONICAL_INVESTIGATION_ALIASES)
64
+ return (
65
+ "You are a fraud analyst in a multi-step training environment. "
66
+ "Return JSON only. Use visible evidence, investigation budget, and prior evidence carefully.\n\n"
67
+ f"Visible observation:\n{json.dumps(payload, sort_keys=True)}\n\n"
68
+ f"Valid investigation aliases: {available}\n"
69
+ "JSON schema: "
70
+ '{"action_type":"investigate|decide","investigation_target":"alias_or_null",'
71
+ '"decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
72
+ )
73
+
74
 
75
  @dataclass
76
  class TextStepResult:
 
114
 
115
  def build_prompt(self, observation) -> str:
116
  """Build the prompt shown to an LLM policy."""
117
+ return build_fraudshield_prompt(observation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def parse_response(self, response_text: str) -> tuple[FraudCheckAction, dict[str, Any], bool, bool]:
120
  """Convert model output into a typed environment action."""
llm_agent_openai.py CHANGED
@@ -8,6 +8,7 @@ import re
8
  from pathlib import Path
9
  from typing import Any, Dict, Optional
10
 
 
11
  from models import ActionTypeEnum, FraudCheckAction, ResolutionEnum, TaskDifficulty
12
 
13
  try: # pragma: no cover - optional in local smoke tests
@@ -95,21 +96,7 @@ class LLMFraudDetectionAgent:
95
 
96
  def _build_messages(self, observation) -> list[Dict[str, str]]:
97
  available_aliases = self._available_investigation_aliases(observation)
98
- observation_payload = {
99
- "case_id": observation.case_id,
100
- "task_name": observation.task_name.value,
101
- "current_screen": observation.current_screen.value,
102
- "visible_panels": observation.visible_panels,
103
- "case_summary": observation.case_summary.model_dump(mode="json"),
104
- "revealed_evidence": observation.revealed_evidence,
105
- "linked_case_ids": observation.linked_case_ids,
106
- "remaining_steps": observation.remaining_steps,
107
- "remaining_sla": observation.remaining_sla,
108
- "note_required": observation.note_required,
109
- "allowed_public_actions": [action.value for action in observation.allowed_actions],
110
- "available_investigation_aliases": available_aliases,
111
- "app_context": observation.app_context,
112
- }
113
  system_prompt = (
114
  "You are a fraud analyst operating inside a simulated investigation workflow. "
115
  "Only use the visible evidence shown to you. Choose either one investigation alias or one final "
@@ -121,7 +108,7 @@ class LLMFraudDetectionAgent:
121
  )
122
  return [
123
  {"role": "system", "content": system_prompt},
124
- {"role": "user", "content": json.dumps(observation_payload, separators=(",", ":"))},
125
  ]
126
 
127
  def _payload_to_action(self, payload: Dict[str, Any], observation) -> FraudCheckAction:
@@ -316,10 +303,7 @@ class LocalModelFraudDetectionAgent(LLMFraudDetectionAgent):
316
  return self._fallback(observation, exc)
317
 
318
  def _build_local_prompt(self, observation) -> str:
319
- messages = self._build_messages(observation)
320
- if hasattr(self.tokenizer, "apply_chat_template"):
321
- return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
322
- return "\n".join(f"{message['role'].upper()}: {message['content']}" for message in messages)
323
 
324
  def _load_model(self) -> None:
325
  try:
 
8
  from pathlib import Path
9
  from typing import Any, Dict, Optional
10
 
11
+ from environment import ACTION_TYPE_TO_CANONICAL_ALIAS, build_fraudshield_prompt
12
  from models import ActionTypeEnum, FraudCheckAction, ResolutionEnum, TaskDifficulty
13
 
14
  try: # pragma: no cover - optional in local smoke tests
 
96
 
97
  def _build_messages(self, observation) -> list[Dict[str, str]]:
98
  available_aliases = self._available_investigation_aliases(observation)
99
+ prompt = build_fraudshield_prompt(observation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  system_prompt = (
101
  "You are a fraud analyst operating inside a simulated investigation workflow. "
102
  "Only use the visible evidence shown to you. Choose either one investigation alias or one final "
 
108
  )
109
  return [
110
  {"role": "system", "content": system_prompt},
111
+ {"role": "user", "content": prompt},
112
  ]
113
 
114
  def _payload_to_action(self, payload: Dict[str, Any], observation) -> FraudCheckAction:
 
303
  return self._fallback(observation, exc)
304
 
305
  def _build_local_prompt(self, observation) -> str:
306
+ return build_fraudshield_prompt(observation) + "\n"
 
 
 
307
 
308
  def _load_model(self) -> None:
309
  try:
train.py CHANGED
@@ -13,7 +13,7 @@ import pandas as pd
13
  from datasets import Dataset, load_dataset
14
 
15
  from config import ExperimentConfig
16
- from environment import FraudShieldTextEnvironment
17
  from models import ActionTypeEnum, FraudCheckAction
18
  from utils import ensure_dir, save_json, seed_everything
19
 
@@ -171,7 +171,7 @@ def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
171
  action = agent.decide(text_env)
172
  payload = {
173
  "action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
174
- "investigation_target": action.action_type.value,
175
  "decision": "fraud" if getattr(action, "resolution", None) and action.resolution.value in {"block", "hold", "escalate"} else "legitimate",
176
  "confidence": 0.8,
177
  "reasoning": action.reasoning or "Training rollout step.",
 
13
  from datasets import Dataset, load_dataset
14
 
15
  from config import ExperimentConfig
16
+ from environment import ACTION_TYPE_TO_CANONICAL_ALIAS, FraudShieldTextEnvironment
17
  from models import ActionTypeEnum, FraudCheckAction
18
  from utils import ensure_dir, save_json, seed_everything
19
 
 
171
  action = agent.decide(text_env)
172
  payload = {
173
  "action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
174
+ "investigation_target": ACTION_TYPE_TO_CANONICAL_ALIAS.get(action.action_type),
175
  "decision": "fraud" if getattr(action, "resolution", None) and action.resolution.value in {"block", "hold", "escalate"} else "legitimate",
176
  "confidence": 0.8,
177
  "reasoning": action.reasoning or "Training rollout step.",