# crisis-inbox/server/rewards.py
# eptan: Upload folder using huggingface_hub (revision fdb5700, verified)
"""
CrisisInbox Reward Function — Single Source of Truth.
Both the environment server and training notebooks import from here.
No OpenEnv dependencies — only requires models.py (pydantic).
"""
import re
try:
from ..models import Message, Urgency
except ImportError:
from models import Message, Urgency
# Sender-name fragments (lowercase) that mark a message as personal/family.
_FAMILY_SENDERS = {
    "mom",
    "sister",
    "neighbor dave",
    "emma",
}

# Lowercase substrings counted as empathetic language in a reply.
_FAMILY_TONE_WORDS = {
    "love",
    "safe",
    "worried",
    "sorry",
    "care",
    "okay",
    "miss",
    "hang in there",
}

# Lowercase substrings counted as formal/professional language in a reply.
_FORMAL_TONE_WORDS = {
    "confirm",
    "attached",
    "documentation",
    "regarding",
    "request",
    "please",
    "submit",
}


def tone_multiplier(sender: str, response: str) -> float:
    """
    Return a bonus multiplier when the reply's tone suits the sender.

    The sender is treated as family/personal when its lowercased name contains
    any entry of ``_FAMILY_SENDERS``; every other sender is treated as
    professional/institutional. The reply is lowercased and scanned for the
    matching vocabulary using plain substring checks, so "confirm" also
    counts inside "confirmed".

    Family senders:  2+ hits -> 1.15, exactly 1 hit -> 1.07.
    Other senders:   2+ hits -> 1.1,  exactly 1 hit -> 1.05.
    No hits:         1.0 (the multiplier never drops below 1.0).
    """
    text = response.lower()
    who = sender.lower()

    # Pick the vocabulary and the two bonus tiers for this kind of sender.
    is_family = any(fragment in who for fragment in _FAMILY_SENDERS)
    if is_family:
        vocabulary, strong, mild = _FAMILY_TONE_WORDS, 1.15, 1.07
    else:
        vocabulary, strong, mild = _FORMAL_TONE_WORDS, 1.1, 1.05

    # Each vocabulary entry counts at most once, however often it appears.
    hits = len([word for word in vocabulary if word in text])
    if hits == 0:
        return 1.0
    return mild if hits == 1 else strong
def calculate_reward(
    msg: Message,
    current_hour: float,
    response: str,
    superseded: dict[str, str],
    visible_messages: list[Message] | None = None,
    handled: dict[str, str] | None = None,
) -> float:
    """
    Score the action of answering ``msg`` with ``response`` at ``current_hour``.

    The score starts at a base value chosen by urgency (critical=10, high=5,
    medium=3, low=1; anything else falls back to 1) and is then multiplied,
    in this order, by:

    - Deadline timing: x(1.0 .. 1.5) when answered on or before the deadline,
      scaled by the fraction of time remaining; x0.25 when answered late.
      Skipped when the message has no deadline.
    - Response quality: x0.5 when the stripped response is under 10 characters.
    - Tone awareness: the value of ``tone_multiplier(msg.sender, response)``.
    - Schema-drift bonus: x1.5 when ``msg.drift_flag`` is set.
    - Stale-information penalty: x0.5 when ``msg.id`` is a key of ``superseded``.
    - Conflict-resolution bonus: x1.25 when ``msg.conflicts_with`` is truthy.
    - Priority penalty: x0.3 when ``msg`` is low/medium urgency while another
      visible critical message is not yet in ``handled``. Only evaluated when
      ``visible_messages`` is non-empty and ``handled`` is not None.

    Returns the final score rounded to two decimal places.
    """
    urgency_base = {
        Urgency.CRITICAL: 10.0,
        Urgency.HIGH: 5.0,
        Urgency.MEDIUM: 3.0,
        Urgency.LOW: 1.0,
    }
    score = urgency_base.get(msg.urgency, 1.0)

    # Deadline timing: earlier answers earn a larger share of the +50% bonus.
    deadline = msg.deadline_hours
    if deadline is not None:
        on_time = current_hour <= deadline
        if on_time:
            # Denominator is floored at 1.0 so very small deadlines cannot
            # blow up the fraction (or divide by zero).
            headroom = (deadline - current_hour) / max(deadline, 1.0)
            score *= 1.0 + 0.5 * headroom
        else:
            score *= 0.25

    # Response quality: near-empty replies are worth half.
    if len(response.strip()) < 10:
        score *= 0.5

    # Tone awareness: bonus only, never a penalty.
    score *= tone_multiplier(msg.sender, response)

    # Drift adaptation bonus.
    if msg.drift_flag:
        score *= 1.5

    # Acting on a message that has since been superseded.
    if msg.id in superseded:
        score *= 0.5

    # Handling a message that forces a trade-off with another one.
    if msg.conflicts_with:
        score *= 1.25

    # Priority penalty: low/medium chosen while a critical message waits.
    if visible_messages and handled is not None:
        critical_pending = False
        for other in visible_messages:
            if other.id == msg.id:
                # The message being handled right now does not count.
                continue
            if other.urgency == Urgency.CRITICAL and other.id not in handled:
                critical_pending = True
                break
        if critical_pending and msg.urgency in (Urgency.LOW, Urgency.MEDIUM):
            score *= 0.3

    return round(score, 2)
def score_action(completion: str, prompt_data: dict) -> float:
"""Score a model completion against an inbox snapshot.
Parses the model output for respond_to_message(msg_id, response),
constructs a Message object, and delegates to calculate_reward().
Args:
completion: Raw model output string.
prompt_data: Dict with keys 'messages' (list of message dicts),
'hour' (float), and optionally 'superseded' (dict).
Returns:
Reward float. Negative for unparseable or invalid actions.
"""
messages = prompt_data["messages"]
hour = prompt_data["hour"]
superseded = prompt_data.get("superseded", {})
# Parse the model output for message_id and response text
msg_id = None
response_text = ""
match = re.search(
r'respond_to_message\s*\(\s*["\']?(msg_\d+)["\']?\s*,\s*["\'](.+?)["\']',
completion, re.DOTALL,
)
if match:
msg_id = match.group(1)
response_text = match.group(2)
else:
id_match = re.search(r'(msg_\d+)', completion)
if id_match:
msg_id = id_match.group(1)
response_text = completion
if not msg_id:
return -1.0
# Find the message dict in the inbox
target_dict = None
for msg in messages:
if msg["id"] == msg_id:
target_dict = msg
break
if target_dict is None:
return -0.5
# Construct a Message object for the shared reward function
target_msg = Message(
id=target_dict["id"],
sender=target_dict["sender"],
channel=target_dict.get("channel", "email"),
subject=target_dict.get("subject", ""),
content=target_dict.get("content", ""),
urgency=target_dict["urgency"],
timestamp_hours=target_dict.get("timestamp_hours", 0.0),
deadline_hours=target_dict.get("deadline_hours"),
dependencies=target_dict.get("dependencies", []),
drift_flag=target_dict.get("drift_flag", False),
supersedes=target_dict.get("supersedes"),
)
# Build visible_messages as Message objects for priority penalty
handled_ids = {m["id"]: "" for m in messages if m.get("handled", False)}
visible = [
Message(
id=m["id"],
sender=m["sender"],
channel=m.get("channel", "email"),
subject=m.get("subject", ""),
content=m.get("content", ""),
urgency=m["urgency"],
timestamp_hours=m.get("timestamp_hours", 0.0),
deadline_hours=m.get("deadline_hours"),
dependencies=m.get("dependencies", []),
drift_flag=m.get("drift_flag", False),
supersedes=m.get("supersedes"),
)
for m in messages
]
return calculate_reward(
msg=target_msg,
current_hour=hour,
response=response_text,
superseded=superseded,
visible_messages=visible,
handled=handled_ids,
)