multi-agent-email-env2 / reply_grader.py
Vansh04092003's picture
Upload folder using huggingface_hub
e2a36ed verified
"""
reply_grader.py — Grades the quality of a drafted email reply.
This is the key Round 2 addition. Instead of only classifying emails,
the agent must also DRAFT A REPLY. This is much harder for LLMs and
far more useful for training — it tests tone, resolution, professionalism,
and context understanding simultaneously.
Score breakdown (all 0.0–1.0, weighted to produce final reply_score):
- resolution_score   (0.40): Does the reply actually address the issue?
- tone_score         (0.25): Is the tone appropriate for the category/priority?
- completeness_score (0.20): Does it include required elements (greeting, sign-off, next steps)?
- length_score       (0.10): Is it an appropriate length (not too short, not a wall of text)?
- safety_score       (0.05): No promises the company can't keep, no admissions of liability?
"""
from __future__ import annotations
import re
from typing import Any, Dict, Optional, Tuple
# ── Tone profiles per category ────────────────────────────────────────────────
# One entry per email category. Every profile carries three fields:
#   required  - signal words a reply in this category is checked for
#   forbidden - signal words that count against the reply
#   style     - label naming the expected register of the reply
TONE_PROFILES: Dict[str, Dict[str, Any]] = dict(
    spam_phishing=dict(
        required=[],
        forbidden=["congratulations", "prize", "winner", "claim", "urgent action"],
        style="none",  # no reply needed - agent should skip/trash
    ),
    customer_complaint=dict(
        required=["apologize", "sorry", "understand", "resolve", "look into"],
        forbidden=["your fault", "you should have", "not our problem", "impossible"],
        style="empathetic",
    ),
    billing_inquiry=dict(
        required=["account", "payment", "invoice", "amount", "charge"],
        forbidden=["ignore", "not our fault"],
        style="professional",
    ),
    technical_support=dict(
        required=["issue", "problem", "help", "team", "investigate"],
        forbidden=["cannot help", "not possible"],
        style="helpful",
    ),
    legal_compliance=dict(
        required=["received", "review", "team", "matter", "respond"],
        forbidden=["admit", "liable", "fault", "guilty", "settlement"],
        style="formal",
    ),
    internal_hr=dict(
        required=["team", "hr", "discuss", "meeting"],
        forbidden=[],
        style="professional",
    ),
    sales_lead=dict(
        required=["thank", "interest", "team", "discuss", "schedule"],
        forbidden=[],
        style="enthusiastic",
    ),
    general_inquiry=dict(
        required=["thank", "help", "answer", "question"],
        forbidden=[],
        style="friendly",
    ),
)
# Structural elements every professional reply should have.
# Maps element name -> compiled, case-insensitive pattern; an element counts
# as present when its pattern is found anywhere in the reply.
REQUIRED_ELEMENTS = {
    element: re.compile(expression, re.IGNORECASE)
    for element, expression in (
        ("greeting", r"\b(hi|hello|dear|good morning|good afternoon|thank you for)\b"),
        ("sign_off", r"\b(regards|sincerely|best|thanks|thank you|warm regards)\b"),
        (
            "next_steps",
            r"\b(will|shall|team|we'll|contact|reach out|follow up|review|look into|investigate|schedule|call)\b",
        ),
    )
}
# Safety: phrases that create legal/financial liability.
# Order is preserved from the original list; the grouping is for readers only.
UNSAFE_PHRASES = [
    # promises about outcomes
    "we guarantee", "we promise", "you will receive",
    # admissions of responsibility
    "we admit", "our fault", "we are liable",
    # financial commitments
    "full refund", "immediately refund", "we will pay", "no questions asked",
]
# Ideal reply length per priority, as an inclusive (min_words, max_words) pair.
LENGTH_RANGE: Dict[str, Tuple[int, int]] = dict(
    urgent=(50, 200),
    high=(40, 250),
    medium=(30, 300),
    low=(20, 200),
    spam=(0, 30),  # ideally no reply or very short rejection
)
def _term_hits(terms, text: str, whole_word: bool = False) -> int:
    """Count the entries of *terms* (an iterable of str) found in *text*.

    A term only matches when it starts at the beginning of a word, so
    "liable" does not fire inside "reliable" and "our fault" does not fire
    inside "your fault". Trailing suffixes are still accepted ("apologize"
    matches "apologized") unless *whole_word* is true, in which case the
    term must also end at the end of a word ("hey" will not match "heyday").
    """
    hits = 0
    for term in terms:
        pattern = r"(?<!\w)" + re.escape(term)
        if whole_word:
            pattern += r"(?!\w)"
        if re.search(pattern, text):
            hits += 1
    return hits


def grade_reply(
    reply: str,
    category: str,
    priority: str,
    email_body: str,
    email_subject: str,
) -> Tuple[float, Dict[str, Any]]:
    """
    Grade the quality of a drafted reply.

    Args:
        reply: The agent's drafted reply text. ``None`` is treated as "".
        category: Predicted/correct category (selects the tone profile;
            unknown categories fall back to "general_inquiry").
        priority: Predicted/correct priority (selects the word-count range;
            unknown priorities fall back to 30-300 words).
        email_body: Original email body (for the keyword overlap check).
            ``None`` is treated as "".
        email_subject: Original email subject. Not used by the scoring;
            kept so the call signature stays stable.

    Returns:
        (reply_score, detail_dict) where reply_score is in [0.0, 1.0] and
        detail_dict holds the per-component scores plus diagnostics.
    """
    # A missing reply/body is graded as empty text rather than crashing.
    reply = reply or ""
    email_body = email_body or ""
    word_count = len(reply.split())

    # Spam should not be answered: staying (almost) silent scores 0.8,
    # writing an actual reply scores 0.3. No further grading is done.
    if category == "spam_phishing" or priority == "spam":
        skip_ok = len(reply.strip()) < 20
        score = 0.8 if skip_ok else 0.3
        return score, {
            "resolution_score": score,
            "tone_score": score,
            "completeness_score": score,
            "length_score": score,
            "safety_score": 1.0,
            "word_count": word_count,
            "reply_score": score,
            "note": "spam — no reply expected",
        }

    # An empty reply to a real email must not collect the tone/safety credit
    # that "contains nothing forbidden" would otherwise hand out for free.
    if word_count == 0:
        return 0.0, {
            "resolution_score": 0.0,
            "tone_score": 0.0,
            "completeness_score": 0.0,
            "length_score": 0.0,
            "safety_score": 1.0,
            "word_count": 0,
            "keyword_overlap": 0.0,
            "required_kw_hits": 0,
            "forbidden_kw_hits": 0,
            "element_scores": {elem: 0.0 for elem in REQUIRED_ELEMENTS},
            "reply_score": 0.0,
            "note": "empty reply — scored 0 regardless of components",
        }

    # Lower-case and collapse whitespace so multi-word phrases still match
    # when the reply wraps a line in the middle of them.
    reply_text = " ".join(reply.lower().split())

    # ── 1. Resolution score (0.40) ────────────────────────────────────────────
    # Half from sharing 5+ letter words with the email body, half from
    # containing the category's required signal words.
    body_keywords = set(re.findall(r"\b\w{5,}\b", email_body.lower()))
    reply_keywords = set(re.findall(r"\b\w{5,}\b", reply_text))
    overlap_ratio = len(body_keywords & reply_keywords) / max(len(body_keywords), 1)

    profile = TONE_PROFILES.get(category, TONE_PROFILES["general_inquiry"])
    required_hits = _term_hits(profile["required"], reply_text)
    required_ratio = required_hits / max(len(profile["required"]), 1)
    resolution_score = min(1.0, overlap_ratio * 0.5 + required_ratio * 0.5)

    # ── 2. Tone score (0.25) ──────────────────────────────────────────────────
    forbidden_hits = _term_hits(profile["forbidden"], reply_text)
    tone_score = max(0.0, 1.0 - forbidden_hits * 0.3)

    # Empathetic categories need softening words.
    if profile["style"] == "empathetic":
        empathy_words = ["sorry", "apologize", "understand", "frustrating", "concern"]
        if not _term_hits(empathy_words, reply_text):
            tone_score *= 0.7

    # Formal categories penalise casual language. Whole-word matching keeps
    # "hey" from firing on "they" and "lol" on longer words.
    if profile["style"] == "formal":
        casual_words = ["hey", "gonna", "wanna", "asap", "fyi", "lol"]
        casual_hits = _term_hits(casual_words, reply_text, whole_word=True)
        tone_score = max(0.0, tone_score - casual_hits * 0.15)

    # ── 3. Completeness score (0.20) ─────────────────────────────────────────
    element_scores = {
        elem: 1.0 if pattern.search(reply_text) else 0.0
        for elem, pattern in REQUIRED_ELEMENTS.items()
    }
    completeness_score = sum(element_scores.values()) / max(len(element_scores), 1)

    # ── 4. Length score (0.10) ────────────────────────────────────────────────
    lo, hi = LENGTH_RANGE.get(priority, (30, 300))
    if lo <= word_count <= hi:
        length_score = 1.0
    elif word_count < lo:
        length_score = max(0.0, word_count / lo)  # too short: linear ramp up to lo
    else:
        length_score = max(0.0, 1.0 - (word_count - hi) / hi)  # too long: 0 at 2*hi

    # ── 5. Safety score (0.05) ────────────────────────────────────────────────
    unsafe_hits = _term_hits(UNSAFE_PHRASES, reply_text)
    safety_score = max(0.0, 1.0 - unsafe_hits * 0.5)

    # ── Final weighted score ──────────────────────────────────────────────────
    reply_score = (
        resolution_score * 0.40
        + tone_score * 0.25
        + completeness_score * 0.20
        + length_score * 0.10
        + safety_score * 0.05
    )
    reply_score = round(min(1.0, max(0.0, reply_score)), 4)

    detail = {
        "resolution_score": round(resolution_score, 3),
        "tone_score": round(tone_score, 3),
        "completeness_score": round(completeness_score, 3),
        "length_score": round(length_score, 3),
        "safety_score": round(safety_score, 3),
        "word_count": word_count,
        "keyword_overlap": round(overlap_ratio, 3),
        "required_kw_hits": required_hits,
        "forbidden_kw_hits": forbidden_hits,
        "element_scores": element_scores,
        "reply_score": reply_score,
    }
    return reply_score, detail