# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Deterministic grader for the Government Service Application Assistant Environment.
"""

from typing import Dict, Any, List, Tuple

try:
    from .models import GovAction, GovObservation
    from .models import GovServiceState
    from .tasks import task_manager
except ImportError:
    from models import GovAction, GovObservation
    from models import GovServiceState
    from tasks import task_manager


class Grader:
    """Deterministic grader for environment tasks.

    Each call grades a single action. The per-difficulty graders look at
    ``action.action_type`` and score exactly one aspect of the observation
    (service selection, document listing, validation, corrections or
    submission). ``grade_action`` then applies a step-count adjustment and
    clamps the result to ``[0.0, 1.0]``.

    The scoring sub-steps shared between difficulties live in the private
    ``_score_*`` helpers; difficulties differ only in the credit they pass in.
    """

    def __init__(self):
        # Task registry used to look up the expected documents of a service.
        self.task_manager = task_manager

    def grade_action(
        self,
        action: GovAction,
        observation: GovObservation,
        state: GovServiceState,
        task_info: Dict[str, Any]
    ) -> float:
        """
        Grade an action based on the current task and state.

        Args:
            action: The action that was taken; only ``action_type`` and
                ``service_type`` are read.
            observation: The observation produced by the action.
            state: Environment state; only ``step_count`` is read.
            task_info: Task description. Must contain ``expected_difficulty``;
                ``optimal_steps`` defaults to 5 when absent.

        Returns:
            Score between 0.0 and 1.0

        Raises:
            KeyError: If ``task_info`` has no ``expected_difficulty`` entry.
        """
        difficulty = task_info["expected_difficulty"]

        # An unrecognised difficulty keeps the base score of 0.0.
        score = 0.0
        if difficulty == "easy":
            score = self._grade_easy_task(action, observation, state, task_info)
        elif difficulty == "medium":
            score = self._grade_medium_task(action, observation, state, task_info)
        elif difficulty == "hard":
            score = self._grade_hard_task(action, observation, state, task_info)

        # Efficiency adjustment: +0.1 while within the optimal step budget,
        # -0.1 once more than twice the budget has been used, nothing between.
        optimal_steps = task_info.get("optimal_steps", 5)
        if state.step_count <= optimal_steps:
            score = min(1.0, score + 0.1)
        elif state.step_count > optimal_steps * 2:
            score = max(0.0, score - 0.1)

        # Ensure score is in valid range
        return self._clamp(score)

    def _grade_easy_task(
        self,
        action: GovAction,
        observation: GovObservation,
        state: GovServiceState,
        task_info: Dict[str, Any]
    ) -> float:
        """Grade easy task: Identify required documents.

        Scores ``select_service`` (up to 0.3) and ``list_required_documents``
        (up to 0.6); every other action type scores 0.0.

        Returns:
            Score between 0.0 and 1.0. Negative adjustments are floored at
            0.0, so a penalty yields 0.0 rather than a negative score.
        """
        score = 0.0
        if action.action_type == "select_service":
            score = self._score_service_selection(action, task_info, 0.3, 0.2)
        elif action.action_type == "list_required_documents":
            score = self._score_document_listing(observation, task_info, 0.6)
        return self._clamp(score)

    def _grade_medium_task(
        self,
        action: GovAction,
        observation: GovObservation,
        state: GovServiceState,
        task_info: Dict[str, Any]
    ) -> float:
        """Grade medium task: Validate application.

        Scores ``select_service`` (up to 0.2), ``list_required_documents``
        (up to 0.3) and ``validate_documents`` (up to 0.3 + 0.3 + 0.2 for the
        missing / invalid / valid categories).

        Returns:
            Score between 0.0 and 1.0. Negative adjustments are floored at
            0.0, so a penalty yields 0.0 rather than a negative score.
        """
        score = 0.0
        if action.action_type == "select_service":
            score = self._score_service_selection(action, task_info, 0.2, 0.1)
        elif action.action_type == "list_required_documents":
            score = self._score_document_listing(observation, task_info, 0.3)
        elif action.action_type == "validate_documents":
            score = self._score_validation(
                observation,
                task_info,
                missing_credit=0.3,
                invalid_credit=0.3,
                valid_credit=0.2,
                premature_penalty=0.1,
            )
        return self._clamp(score)

    def _grade_hard_task(
        self,
        action: GovAction,
        observation: GovObservation,
        state: GovServiceState,
        task_info: Dict[str, Any]
    ) -> float:
        """Grade hard task: Fix incorrect application.

        Scores ``select_service`` (up to 0.15), ``list_required_documents``
        (up to 0.25), ``validate_documents`` (up to 0.25 + 0.25 + 0.15),
        ``suggest_corrections`` (up to 0.2) and ``submit_application``
        (up to 0.3).

        Returns:
            Score between 0.0 and 1.0. Negative adjustments are floored at
            0.0, so a penalty yields 0.0 rather than a negative score.
        """
        score = 0.0
        if action.action_type == "select_service":
            score = self._score_service_selection(action, task_info, 0.15, 0.1)
        elif action.action_type == "list_required_documents":
            score = self._score_document_listing(observation, task_info, 0.25)
        elif action.action_type == "validate_documents":
            score = self._score_validation(
                observation,
                task_info,
                missing_credit=0.25,
                invalid_credit=0.25,
                valid_credit=0.15,
                premature_penalty=0.05,
            )
        elif action.action_type == "suggest_corrections":
            score = self._score_corrections(observation, task_info)
        elif action.action_type == "submit_application":
            score = self._score_submission(observation)
        return self._clamp(score)

    @staticmethod
    def _clamp(score: float) -> float:
        """Restrict ``score`` to the closed interval ``[0.0, 1.0]``."""
        return max(0.0, min(1.0, score))

    @staticmethod
    def _score_service_selection(
        action: GovAction,
        task_info: Dict[str, Any],
        reward: float,
        penalty: float
    ) -> float:
        """Return ``reward`` for the task's service, ``-penalty`` otherwise."""
        if action.service_type == task_info["service_type"]:
            return reward
        return -penalty

    def _score_document_listing(
        self,
        observation: GovObservation,
        task_info: Dict[str, Any],
        full_credit: float
    ) -> float:
        """Score the documents listed in the observation against the task.

        Full credit requires the same number of documents and the same set of
        types. Any other non-empty listing earns up to half of ``full_credit``,
        proportional to the share of expected types that were listed.
        """
        # A missing (None) listing is graded like an empty one.
        listed_docs = observation.required_documents or []
        expected_docs = self.task_manager.get_task_requirements(task_info["service_type"])

        expected_types = {doc.type for doc in expected_docs}
        listed_types = {doc.get("type", "") for doc in listed_docs}

        if len(listed_docs) == len(expected_docs) and expected_types == listed_types:
            return full_credit
        if not listed_docs or not expected_types:
            # Nothing listed, or nothing to compare against: no partial credit.
            return 0.0

        overlap = len(expected_types & listed_types)
        # Halving a float is exact, so this equals the literal half-weights.
        return (full_credit / 2) * (overlap / len(expected_types))

    @staticmethod
    def _score_name_set(expected: set, actual: set, full_credit: float) -> float:
        """Score a reported set of document names against the expected set.

        An empty expectation or an exact match earns ``full_credit``; anything
        else earns up to half of it, proportional to the expected names found.
        """
        if not expected or actual == expected:
            return full_credit
        return (full_credit / 2) * (len(expected & actual) / len(expected))

    @staticmethod
    def _score_invalid_documents(
        expected_invalid: List[Dict[str, Any]],
        actual_invalid: List[Dict[str, Any]],
        full_credit: float
    ) -> float:
        """Score the reported invalid documents against the expected ones.

        With equal counts only the document types are compared. With differing
        counts, partial credit is based on exact ``(type, reason)`` pairs.
        """
        if not expected_invalid:
            return full_credit  # No invalid expected is good

        half_credit = full_credit / 2
        if len(expected_invalid) == len(actual_invalid):
            expected_types = {item.get("type", "") for item in expected_invalid}
            actual_types = {item.get("type", "") for item in actual_invalid}
            if expected_types == actual_types:
                return full_credit
            return half_credit * (len(expected_types & actual_types) / len(expected_types))

        expected_pairs = {
            (item.get("type", ""), item.get("reason", "")) for item in expected_invalid
        }
        actual_pairs = {
            (item.get("type", ""), item.get("reason", "")) for item in actual_invalid
        }
        return half_credit * (len(expected_pairs & actual_pairs) / len(expected_invalid))

    def _score_validation(
        self,
        observation: GovObservation,
        task_info: Dict[str, Any],
        missing_credit: float,
        invalid_credit: float,
        valid_credit: float,
        premature_penalty: float
    ) -> float:
        """Score ``observation.validation_results`` against the expected feedback.

        Returns the sum of the missing / invalid / valid category scores, or
        ``-premature_penalty`` when there are no validation results yet.
        """
        results = observation.validation_results
        if not results:
            # No validation results yet - small penalty for premature validation
            return -premature_penalty

        # "or" fallbacks also cover keys that are present but set to None.
        expected = task_info.get("validation_feedback") or {}

        missing_score = self._score_name_set(
            set(expected.get("missing_documents") or []),
            set(results.get("missing_documents") or []),
            missing_credit,
        )
        invalid_score = self._score_invalid_documents(
            expected.get("invalid_documents") or [],
            results.get("invalid_documents") or [],
            invalid_credit,
        )
        valid_score = self._score_name_set(
            set(expected.get("valid_documents") or []),
            set(results.get("valid_documents") or []),
            valid_credit,
        )
        return missing_score + invalid_score + valid_score

    @staticmethod
    def _score_corrections(
        observation: GovObservation,
        task_info: Dict[str, Any]
    ) -> float:
        """Score correction suggestions by how many expected issues they mention.

        Issue terms are the lower-cased types and reasons of the expected
        invalid documents plus the names of the expected missing documents.
        The score is 0.2 times the share of terms found in the suggestion text.
        """
        expected = task_info.get("validation_feedback") or {}
        expected_invalid = expected.get("invalid_documents") or []
        expected_missing = expected.get("missing_documents") or []
        has_issues = bool(expected_invalid or expected_missing)

        suggestions = observation.correction_suggestions
        if not suggestions:
            # Staying silent is only right when there was nothing to correct.
            return -0.1 if has_issues else 0.1
        if not has_issues:
            return 0.2  # No issues to correct

        suggestion_text = " ".join(
            str(sugg.get("suggested_action", "")) + " " + str(sugg.get("issue_type", ""))
            for sugg in suggestions
        ).lower()

        raw_terms: List[Any] = []
        for item in expected_invalid:
            raw_terms.append(item.get("type"))
            raw_terms.append(item.get("reason"))
        raw_terms.extend(expected_missing)

        # Drop empty/None terms: they can never match, so leaving them in the
        # denominator would cap the score below 0.2 even for perfect suggestions.
        issue_terms = [str(term).lower() for term in raw_terms if term]
        if not issue_terms:
            return 0.2

        matches = sum(1 for term in issue_terms if term in suggestion_text)
        return 0.2 * (matches / len(issue_terms))

    @staticmethod
    def _score_submission(observation: GovObservation) -> float:
        """Score a submission against the validation state in the observation."""
        is_complete = observation.is_complete
        results = observation.validation_results
        is_valid = bool(results) and bool(results.get("is_valid"))

        if is_valid and results.get("is_complete"):
            # Correctly submitted a valid application, unless the observation
            # itself still reports the application as incomplete.
            return 0.3 if is_complete else -0.1
        if not is_valid:
            # Invalid (or unvalidated) application: credit only when the
            # observation does not claim the application is complete.
            return -0.2 if is_complete else 0.1
        # Valid but validation reports it as not complete: unclear state.
        return -0.1


# Global grader instance
grader = Grader()