mathi3046 commited on
Commit
7acbefe
Β·
1 Parent(s): 562c0a4

fix: rewrite grader with reference scoring pattern and clamp all reward fields

Browse files

- Adopted additive scoring: category(+0.3), empathy(+0.1/0.2), helpfulness(+0.3), resolution(+0.2), efficiency bonus
- Added penalties: angry(-0.25), generic(-0.1), repetition(-0.2), escalation(-0.1)
- CRITICAL: clamped observation.reward to avg instead of raw cumulative (was exceeding 1.0)
- CRITICAL: clamped cumulative_reward in info dict
- Pass action_type, step_count, max_steps to grader for resolution/efficiency scoring

Files changed (2) hide show
  1. grader.py +114 -301
  2. server/environment.py +5 -2
grader.py CHANGED
@@ -1,10 +1,16 @@
1
  """
2
  Deterministic grading engine for the Customer Support Environment.
3
 
4
- Evaluates agent responses on three axes:
5
- - Correctness (keyword / concept matching)
6
- - Tone (positive vs. negative signal detection)
7
- - Completeness (checklist of required response elements)
 
 
 
 
 
 
8
 
9
  Returns a RewardBreakdown with a total score in (0.0, 1.0) β€” strict open interval.
10
 
@@ -27,328 +33,135 @@ def _normalise(text: str) -> str:
27
  return re.sub(r"\s+", " ", text.strip().lower())
28
 
29
 
30
- # ──────────────────────────────────────────────────────────────────
31
- # Correctness scorer
32
- # ──────────────────────────────────────────────────────────────────
33
-
34
- def _score_correctness(
35
- response: str,
36
- rubric: Dict[str, Any],
37
- ) -> float:
38
- """Score based on presence of expected keyword groups.
39
-
40
- Returns a value in (0, 1) β€” never 0.0 or 1.0.
41
- """
42
- norm = _normalise(response)
43
- criteria = rubric.get("criteria", [])
44
- if not criteria:
45
- return safe_score(0.1)
46
-
47
- total = 0.0
48
- for criterion in criteria:
49
- kw_group: List[str] = criterion.get("keyword_group", [])
50
- points: float = criterion.get("points", 0.0)
51
- if any(kw.lower() in norm for kw in kw_group):
52
- total += points
53
-
54
- return safe_score(total)
55
-
56
-
57
- # ──────────────────────────────────────────────────────────────────
58
- # Tone scorer
59
- # ──────────────────────────────────────────────────────────────────
60
-
61
- def _score_tone(
62
- response: str,
63
- rubric: Dict[str, Any],
64
- ) -> float:
65
- """
66
- Score tone based on positive and negative signal presence.
67
- Start at 0.5, boost for positive signals, penalize for negative signals.
68
-
69
- Returns a value in (0, 1) β€” never 0.0 or 1.0.
70
- """
71
- norm = _normalise(response)
72
- criteria = rubric.get("criteria", {})
73
-
74
- positive_signals: List[str] = criteria.get("positive_signals", [])
75
- negative_signals: List[str] = criteria.get("negative_signals", [])
76
-
77
- pos_count = sum(1 for sig in positive_signals if sig.lower() in norm)
78
- neg_count = sum(1 for sig in negative_signals if sig.lower() in norm)
79
-
80
- score = 0.5
81
-
82
- if positive_signals:
83
- pos_ratio = pos_count / len(positive_signals)
84
- score += pos_ratio * 0.4
85
-
86
- if neg_count > 0:
87
- score -= min(neg_count * 0.2, 0.4)
88
-
89
- word_count = len(norm.split())
90
- if word_count < 10:
91
- score -= 0.1
92
-
93
- upper_ratio = sum(1 for c in response if c.isupper()) / max(len(response), 1)
94
- if upper_ratio > 0.4 and len(response) > 20:
95
- score -= 0.05
96
-
97
- return safe_score(score)
98
-
99
-
100
- # ──────────────────────────────────────────────────────────────────
101
- # Completeness scorer
102
- # ──────────────────────────────────────────────────────────────────
103
-
104
- def _score_completeness(
105
- response: str,
106
- rubric: Dict[str, Any],
107
- ticket_info: Dict[str, Any],
108
- conversation_history: List[Dict[str, Any]],
109
- ) -> float:
110
- """Score based on completeness checklist.
111
-
112
- Returns a value in (0, 1) β€” never 0.0 or 1.0.
113
- """
114
- norm = _normalise(response)
115
- criteria = rubric.get("criteria", [])
116
- if not criteria:
117
- return safe_score(0.1)
118
-
119
- total = 0.0
120
- for criterion in criteria:
121
- check = criterion.get("check", "")
122
- points = criterion.get("points", 0.0)
123
-
124
- if check == "addresses_question" or check == "addresses_defect":
125
- subject = _normalise(ticket_info.get("subject", ""))
126
- subject_words = [w for w in subject.split() if len(w) > 3]
127
- if any(w in norm for w in subject_words) or len(norm.split()) > 20:
128
- total += points
129
-
130
- elif check == "provides_next_steps":
131
- step_indicators = [
132
- "will", "can", "please", "next step", "process",
133
- "we'll", "i'll", "going to", "let me", "i can",
134
- "here's what", "here is what", "follow up",
135
- ]
136
- if any(ind in norm for ind in step_indicators):
137
- total += points
138
-
139
- elif check == "references_order":
140
- order_id = ticket_info.get("order_id", "")
141
- if order_id and order_id.lower() in norm:
142
- total += points
143
- elif "order" in norm:
144
- total += points * 0.5
145
-
146
- elif check == "explains_policy":
147
- policy_terms = [
148
- "policy", "within", "days", "eligible", "qualify",
149
- "terms", "condition", "guideline",
150
- ]
151
- if sum(1 for t in policy_terms if t in norm) >= 2:
152
- total += points
153
-
154
- elif check == "provides_process":
155
- process_terms = [
156
- "step", "first", "then", "send", "ship", "return",
157
- "label", "process", "receive", "refund",
158
- ]
159
- if sum(1 for t in process_terms if t in norm) >= 3:
160
- total += points
161
-
162
- elif check == "offers_options":
163
- option_indicators = ["or", "option", "alternative", "either", "choose", "prefer"]
164
- if any(ind in norm for ind in option_indicators):
165
- total += points
166
-
167
- elif check == "acknowledges_all_issues":
168
- issues_to_address = ["wrong", "late", "delay", "rude", "staff", "agent"]
169
- addressed = sum(1 for iss in issues_to_address if iss in norm)
170
- if addressed >= 3:
171
- total += points
172
- elif addressed >= 2:
173
- total += points * 0.6
174
- elif addressed >= 1:
175
- total += points * 0.3
176
-
177
- elif check == "concrete_resolution":
178
- concrete_terms = [
179
- "refund", "replacement", "ship", "send", "credit",
180
- "discount", "expedite", "priority", "immediately",
181
- "right away", "today",
182
- ]
183
- if sum(1 for t in concrete_terms if t in norm) >= 2:
184
- total += points
185
-
186
- elif check == "timeline":
187
- time_patterns = [
188
- r"\d+\s*(hour|day|week|business day)",
189
- r"within\s+\d+",
190
- r"by\s+(end of|tomorrow|today)",
191
- r"immediately",
192
- r"right away",
193
- r"asap",
194
- r"as soon as",
195
- ]
196
- if any(re.search(pat, norm) for pat in time_patterns):
197
- total += points
198
-
199
- elif check == "empathy":
200
- empathy_terms = [
201
- "understand", "frustrat", "sorry", "apologize",
202
- "inconvenience", "disappoint", "concern",
203
- "appreciate your patience", "i hear you",
204
- ]
205
- if sum(1 for t in empathy_terms if t in norm) >= 2:
206
- total += points
207
-
208
- elif check == "follow_up_plan":
209
- follow_up_terms = [
210
- "follow up", "follow-up", "check back", "update you",
211
- "keep you informed", "contact you", "reach out",
212
- "email you", "confirmation",
213
- ]
214
- if any(t in norm for t in follow_up_terms):
215
- total += points
216
-
217
- return safe_score(total)
218
-
219
-
220
- # ──────────────────────────────────────────────────────────────────
221
- # Penalty computation
222
- # ──────────────────────────────────────────────────────────────────
223
-
224
- def _compute_penalties(
225
- response: str,
226
- conversation_history: List[Dict[str, Any]],
227
- ) -> float:
228
- """
229
- Compute penalties for bad behaviours.
230
- Returns a negative value in [-0.5, 0.0].
231
- """
232
- norm = _normalise(response)
233
- penalty = 0.0
234
-
235
- if len(norm.split()) < 5:
236
- penalty -= 0.2
237
-
238
- if conversation_history:
239
- prev_agent_msgs = [
240
- _normalise(m.get("content", ""))
241
- for m in conversation_history
242
- if m.get("role") == "agent"
243
- ]
244
- for prev in prev_agent_msgs:
245
- if prev and norm == prev:
246
- penalty -= 0.2
247
- break
248
- elif prev and len(prev) > 20 and prev in norm:
249
- penalty -= 0.1
250
- break
251
-
252
- harmful_patterns = [
253
- "kill", "die", "hate you", "shut up", "idiot",
254
- "moron", "loser", "go away",
255
- ]
256
- if any(pat in norm for pat in harmful_patterns):
257
- penalty -= 0.3
258
-
259
- irrelevant_signals = [
260
- "weather", "recipe", "joke", "game score",
261
- "political", "stock market",
262
- ]
263
- if sum(1 for s in irrelevant_signals if s in norm) >= 2:
264
- penalty -= 0.3
265
-
266
- return max(-0.5, penalty)
267
-
268
-
269
- # ────────────────────────────────────────────────────────��─────────
270
- # Main grading function
271
- # ──────────────────────────────────────────────────────────────────
272
-
273
  def grade_response(
274
  response: str,
275
  grading_rubric: Dict[str, Any],
276
  ticket_info: Dict[str, Any],
277
  conversation_history: List[Dict[str, Any]],
 
 
 
278
  ) -> RewardBreakdown:
279
  """
280
- Grade an agent response and return a detailed RewardBreakdown.
281
 
282
  Args:
283
  response: The agent's response text
284
  grading_rubric: Task-specific grading criteria
285
  ticket_info: Ticket metadata
286
  conversation_history: Previous messages
 
 
 
287
 
288
  Returns:
289
  RewardBreakdown with ALL scores in strict (0.0, 1.0) open interval.
290
- The RewardBreakdown model auto-clamps all score fields via validators.
291
  """
292
- # Score each axis β€” safe_score guarantees (0, 1)
293
- correctness = safe_score(_score_correctness(
294
- response,
295
- grading_rubric.get("correctness", {}),
296
- ))
297
- tone = safe_score(_score_tone(
298
- response,
299
- grading_rubric.get("tone", {}),
300
- ))
301
- completeness = safe_score(_score_completeness(
302
- response,
303
- grading_rubric.get("completeness", {}),
304
- ticket_info,
305
- conversation_history,
306
- ))
307
 
308
- # Get weights
309
- w_correctness = grading_rubric.get("correctness", {}).get("weight", 0.33)
310
- w_tone = grading_rubric.get("tone", {}).get("weight", 0.33)
311
- w_completeness = grading_rubric.get("completeness", {}).get("weight", 0.34)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
- # Compute penalties (capped at -0.5)
314
- penalties = _compute_penalties(response, conversation_history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
- # Weighted total (before penalties)
317
- weighted = (
318
- correctness * w_correctness
319
- + tone * w_tone
320
- + completeness * w_completeness
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  )
322
-
323
- # Apply penalties β€” safe_score guarantees strict (0, 1)
324
- total = safe_score(weighted + penalties)
325
-
326
- # The efficiency field re-uses the weighted pre-penalty score
327
- efficiency = safe_score(weighted)
328
-
329
- # Debug logging
330
- logger.info(
331
- f"[GRADER] correctness={correctness:.4f} tone={tone:.4f} "
332
- f"completeness={completeness:.4f} weighted={weighted:.4f} "
333
- f"penalties={penalties:.4f} total={total:.4f}"
334
  )
 
 
 
 
335
 
336
  # Build explanation
337
- parts = []
338
- parts.append(f"Correctness: {correctness:.4f} (weight={w_correctness:.2f})")
339
- parts.append(f"Tone: {tone:.4f} (weight={w_tone:.2f})")
340
- parts.append(f"Completeness: {completeness:.4f} (weight={w_completeness:.2f})")
341
- if penalties < 0:
342
- parts.append(f"Penalties: {penalties:.4f}")
343
- parts.append(f"Total: {total:.4f}")
344
 
345
- # RewardBreakdown auto-clamps all score fields via field_validator
346
  return RewardBreakdown(
347
- correctness=correctness,
348
- tone=tone,
349
- completeness=completeness,
350
- efficiency=efficiency,
351
- penalties=round(penalties, 4),
352
- total=total,
353
  explanation=" | ".join(parts),
354
  )
 
1
  """
2
  Deterministic grading engine for the Customer Support Environment.
3
 
4
+ Follows the reference additive scoring pattern:
5
+ - Category/keyword correctness (+0.3)
6
+ - Empathy detection (+0.1 / +0.2)
7
+ - Angry customer strict rule (-0.25)
8
+ - Anti-generic response penalty (-0.1)
9
+ - Helpfulness detection (+0.3)
10
+ - Repetition penalty (-0.2)
11
+ - Escalation penalty (-0.1)
12
+ - Resolution bonus (+0.2)
13
+ - Efficiency bonus (+0.1 * remaining steps)
14
 
15
  Returns a RewardBreakdown with a total score in (0.0, 1.0) β€” strict open interval.
16
 
 
33
  return re.sub(r"\s+", " ", text.strip().lower())
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def grade_response(
37
  response: str,
38
  grading_rubric: Dict[str, Any],
39
  ticket_info: Dict[str, Any],
40
  conversation_history: List[Dict[str, Any]],
41
+ action_type: str = "respond",
42
+ step_count: int = 0,
43
+ max_steps: int = 5,
44
  ) -> RewardBreakdown:
45
  """
46
+ Grade an agent response using the reference additive scoring pattern.
47
 
48
  Args:
49
  response: The agent's response text
50
  grading_rubric: Task-specific grading criteria
51
  ticket_info: Ticket metadata
52
  conversation_history: Previous messages
53
+ action_type: 'respond', 'escalate', or 'resolve'
54
+ step_count: Current step number (1-indexed, already incremented)
55
+ max_steps: Maximum allowed steps for this task
56
 
57
  Returns:
58
  RewardBreakdown with ALL scores in strict (0.0, 1.0) open interval.
 
59
  """
60
+ score = 0.0
61
+ metrics: Dict[str, float] = {}
62
+ norm = _normalise(response)
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ # ── 1. Correct category / keyword extraction (+0.3) ──
65
+ correctness_criteria = grading_rubric.get("correctness", {}).get("criteria", [])
66
+ correctness_hit = False
67
+ for criterion in correctness_criteria:
68
+ kw_group: List[str] = criterion.get("keyword_group", [])
69
+ if any(kw.lower() in norm for kw in kw_group):
70
+ correctness_hit = True
71
+ break
72
+ if correctness_hit:
73
+ score += 0.3
74
+ metrics["category_correct"] = 0.3
75
+
76
+ # ── 2. Empathy check (+0.1 neutral, +0.2 angry/frustrated) ──
77
+ sentiment = ticket_info.get("customer_sentiment", "neutral")
78
+ empathy_words = ["sorry", "apologize", "understand", "help"]
79
+ if any(word in norm for word in empathy_words):
80
+ empathy_score = 0.2 if sentiment in ["angry", "frustrated"] else 0.1
81
+ score += empathy_score
82
+ metrics["empathy"] = empathy_score
83
+
84
+ # ── 3. Angry customer strict rule (-0.25) ──
85
+ if sentiment == "angry" and not any(
86
+ w in norm for w in ["sorry", "apologize", "understand"]
87
+ ):
88
+ score -= 0.25
89
+ metrics["angry_penalty"] = -0.25
90
+
91
+ # ── 4. Anti-generic response penalty (-0.1) ──
92
+ generic_phrases = ["i will help you", "let me help", "i understand your issue"]
93
+ if any(phrase in norm for phrase in generic_phrases) and len(response) < 60:
94
+ score -= 0.1
95
+ metrics["generic_penalty"] = -0.1
96
 
97
+ # ── 5. Helpfulness check (+0.3) ──
98
+ helpful_words = [
99
+ "step", "fix", "update", "here is", "resolved",
100
+ "refund", "replacement", "process", "ship", "send",
101
+ "return", "credit", "track", "label",
102
+ ]
103
+ if any(word in norm for word in helpful_words):
104
+ score += 0.3
105
+ metrics["helpfulness"] = 0.3
106
+
107
+ # ── 6. Repetition penalty (-0.2) ──
108
+ past_responses = [
109
+ msg.get("content", "").lower()
110
+ for msg in conversation_history
111
+ if msg.get("role") == "agent"
112
+ ]
113
+ if norm in past_responses:
114
+ score -= 0.2
115
+ metrics["repetition_penalty"] = -0.2
116
 
117
+ # ── 7. Escalation penalty (-0.1) ──
118
+ if action_type == "escalate":
119
+ score -= 0.1
120
+ metrics["escalation_penalty"] = -0.1
121
+
122
+ # ── 8. Resolution bonus (+0.2) & Efficiency bonus ──
123
+ if action_type == "resolve":
124
+ score += 0.2
125
+ metrics["resolution_bonus"] = 0.2
126
+
127
+ # Efficiency bonus: reward resolving in fewer steps
128
+ if step_count < max_steps:
129
+ efficiency_bonus = round(0.1 * (max_steps - step_count), 4)
130
+ score += efficiency_bonus
131
+ metrics["efficiency_bonus"] = efficiency_bonus
132
+
133
+ # ── Final score β€” STRICT (0, 1) via safe_score ──
134
+ final_score = safe_score(score)
135
+
136
+ # Map metrics to RewardBreakdown fields
137
+ correctness_val = safe_score(metrics.get("category_correct", 0.0))
138
+ tone_val = safe_score(
139
+ metrics.get("empathy", 0.0)
140
+ + metrics.get("angry_penalty", 0.0)
141
+ + metrics.get("generic_penalty", 0.0)
142
+ + 0.3 # base tone
143
  )
144
+ completeness_val = safe_score(
145
+ metrics.get("helpfulness", 0.0)
146
+ + metrics.get("resolution_bonus", 0.0)
 
 
 
 
 
 
 
 
 
147
  )
148
+ efficiency_val = safe_score(
149
+ metrics.get("efficiency_bonus", 0.0) + 0.2
150
+ )
151
+ penalties_total = sum(v for v in metrics.values() if v < 0)
152
 
153
  # Build explanation
154
+ parts = [f"{k}: {v:.4f}" for k, v in sorted(metrics.items())]
155
+ parts.append(f"Total: {final_score:.4f}")
156
+
157
+ logger.info(f"[GRADER] score={final_score:.4f} metrics={metrics}")
 
 
 
158
 
 
159
  return RewardBreakdown(
160
+ correctness=correctness_val,
161
+ tone=tone_val,
162
+ completeness=completeness_val,
163
+ efficiency=efficiency_val,
164
+ penalties=round(max(-1.0, min(0.0, penalties_total)), 4),
165
+ total=final_score,
166
  explanation=" | ".join(parts),
167
  )
server/environment.py CHANGED
@@ -160,6 +160,9 @@ class CustomerSupportEnvironment:
160
  grading_rubric=self._task["grading_rubric"],
161
  ticket_info=self._task["ticket"],
162
  conversation_history=[m.model_dump() for m in self._conversation],
 
 
 
163
  )
164
 
165
  # Clamp step reward to strict (0, 1) β€” safe_score guarantees this
@@ -217,7 +220,7 @@ class CustomerSupportEnvironment:
217
  info = {
218
  "reward_breakdown": rb_dict,
219
  "step_reward": step_reward,
220
- "cumulative_reward": self._cumulative_reward,
221
  "average_reward": avg_reward,
222
  "steps_taken": self._state.step_count,
223
  "task_id": self._state.task_id,
@@ -264,7 +267,7 @@ class CustomerSupportEnvironment:
264
  max_steps=self._state.max_steps,
265
  steps_remaining=self._state.max_steps - self._state.step_count,
266
  done=self._state.done,
267
- reward=self._cumulative_reward,
268
  )
269
 
270
  def _generate_contextual_reply(self, action: SupportAction) -> str:
 
160
  grading_rubric=self._task["grading_rubric"],
161
  ticket_info=self._task["ticket"],
162
  conversation_history=[m.model_dump() for m in self._conversation],
163
+ action_type=action.action_type,
164
+ step_count=self._state.step_count,
165
+ max_steps=self._state.max_steps,
166
  )
167
 
168
  # Clamp step reward to strict (0, 1) β€” safe_score guarantees this
 
220
  info = {
221
  "reward_breakdown": rb_dict,
222
  "step_reward": step_reward,
223
+ "cumulative_reward": safe_score(self._cumulative_reward / self._state.step_count),
224
  "average_reward": avg_reward,
225
  "steps_taken": self._state.step_count,
226
  "task_id": self._state.task_id,
 
267
  max_steps=self._state.max_steps,
268
  steps_remaining=self._state.max_steps - self._state.step_count,
269
  done=self._state.done,
270
+ reward=safe_score(self._cumulative_reward / max(self._state.step_count, 1)),
271
  )
272
 
273
  def _generate_contextual_reply(self, action: SupportAction) -> str: