from utils.embeddings import get_embedding, cos_sim
import logging

logger = logging.getLogger("nexus.reward_engine")

# Words suggesting the message cites concrete evidence. They are matched as
# plain substrings of the lower-cased message, so "config" also matches
# "configuration" -- and, less desirably, "log" also matches "logic".
_SPECIFICITY_INDICATORS = (
    "shows", "value", "config", "log", "found",
    "confirmed", "set to", "equals", "returns", "indicates",
)

# Number of most recent prior messages compared against for novelty.
_NOVELTY_WINDOW = 4

# A novelty score below this threshold is treated as circular reasoning.
_CIRCULAR_NOVELTY_THRESHOLD = 0.02
_CIRCULAR_PENALTY = 0.15


def _clamp(value: float, low: float, high: float) -> float:
    """Return *value* limited to the closed interval [low, high]."""
    return max(low, min(high, value))


def compute_reward(message: str, tool_calls: list, tool_results: list,
                   episode_state, scenario: dict) -> tuple[float, dict]:
    """Score one agent message and record the score on the episode state.

    The reward is the sum of five components, each clamped to its range:

    * ``specificity`` (0.0-0.25): 0.025 per evidence keyword found.
    * ``partner_engagement`` (0.0-0.20): similarity to the partner's last
      message; a flat 0.10 when there is no partner message.
    * ``progress`` (0.0-0.30): similarity to the scenario's root-cause
      description; a flat 0.15 when no description is available.
    * ``tool_usage`` (0.0-0.15): 0.08 per distinct tool-call signature not
      already in ``episode_state.previous_tool_calls``.
    * ``novelty`` (0.0-0.10): dissimilarity from the most recent prior
      messages; a flat 0.10 when there is no history.

    A penalty of 0.15 is subtracted when novelty falls below 0.02. The
    total is then clamped to [0, 1] and rounded to 4 decimal places.

    Args:
        message: The agent message being scored.
        tool_calls: Tool calls made this turn. Each item must expose
            ``tool_name`` and ``params``. May be empty or ``None``.
        tool_results: Currently unused; kept for interface compatibility.
        episode_state: Reads ``last_partner_message``, ``all_messages``
            and ``previous_tool_calls``. Mutates ``reward_history`` (the
            score is appended) and ``cumulative_reward`` (the score is added).
        scenario: Optional ``scenario['root_cause']['description']`` text.

    Returns:
        ``(final_score, breakdown)``, where ``breakdown`` maps each
        component name to its (unrounded) contribution.
    """
    breakdown = {}
    msg_lower = message.lower()

    partner_msg = episode_state.last_partner_message
    # ``or {}`` also tolerates an explicit ``root_cause: None`` entry.
    root_cause_desc = (scenario.get('root_cause') or {}).get('description', '')
    recent_history = (
        episode_state.all_messages[-_NOVELTY_WINDOW:]
        if episode_state.all_messages else []
    )
    # NOTE(review): if all_messages already contains the current message,
    # novelty will always be ~0 and the circular penalty always fires --
    # confirm with the caller that it is appended after scoring.

    # Embed the message once and reuse it for every similarity below.
    # Skip embedding entirely when no similarity-based component needs it.
    msg_emb = None
    if partner_msg or root_cause_desc or recent_history:
        msg_emb = get_embedding(message)

    # 1. HYPOTHESIS SPECIFICITY (0.0-0.25)
    breakdown['specificity'] = min(
        0.25,
        sum(0.025 for word in _SPECIFICITY_INDICATORS if word in msg_lower),
    )

    # 2. PARTNER ENGAGEMENT (0.0-0.20)
    # Assumes cos_sim yields a scalar in [-1, 1] -- TODO confirm. The result
    # is clamped at 0 so an anti-correlated reply cannot go negative.
    if partner_msg:
        sim = float(cos_sim(msg_emb, get_embedding(partner_msg)))
        breakdown['partner_engagement'] = _clamp(sim * 0.25, 0.0, 0.20)
    else:
        breakdown['partner_engagement'] = 0.10

    # 3. PROGRESS TOWARD ROOT CAUSE (0.0-0.30)
    if root_cause_desc:
        root_cause_sim = float(cos_sim(msg_emb, get_embedding(root_cause_desc)))
        breakdown['progress'] = _clamp(root_cause_sim * 0.40, 0.0, 0.30)
    else:
        breakdown['progress'] = 0.15

    # 4. TOOL USAGE (0.0-0.15)
    # Collect signatures in a set so repeating an identical call within the
    # same turn is rewarded only once.
    new_signatures = set()
    for call in tool_calls or ():
        sig = f"{call.tool_name}:{str(call.params)}"
        if sig not in episode_state.previous_tool_calls:
            new_signatures.add(sig)
    breakdown['tool_usage'] = min(0.15, len(new_signatures) * 0.08)

    # 5. NOVELTY (0.0-0.10)
    if recent_history:
        max_sim_to_history = max(
            float(cos_sim(msg_emb, get_embedding(prev)))
            for prev in recent_history
        )
        breakdown['novelty'] = _clamp(0.10 * (1 - max_sim_to_history), 0.0, 0.10)
    else:
        breakdown['novelty'] = 0.10

    # PENALTIES: near-duplicate of recent history is treated as circular
    # reasoning.
    penalty = (
        _CIRCULAR_PENALTY
        if breakdown['novelty'] < _CIRCULAR_NOVELTY_THRESHOLD
        else 0.0
    )

    total = sum(breakdown.values()) - penalty
    final_score = round(_clamp(total, 0.0, 1.0), 4)

    # Store history
    episode_state.reward_history.append(final_score)
    episode_state.cumulative_reward += final_score

    logger.debug("reward=%s penalty=%s breakdown=%s", final_score, penalty, breakdown)
    return final_score, breakdown