# gov_env/grader.py
# Uploaded by dharunkkk (commit message: "Upload grader.py"; revision a08f890, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Deterministic grader for the Government Service Application Assistant Environment.
"""
from typing import Dict, Any, List, Tuple
try:
from .models import GovAction, GovObservation
from .models import GovServiceState
from .tasks import task_manager
except ImportError:
from models import GovAction, GovObservation
from models import GovServiceState
from tasks import task_manager
class Grader:
    """Deterministic grader for environment tasks.

    ``grade_action`` scores a single action from the action itself, the
    observation paired with it, the episode state and the task description.
    Scoring is a pure function of those inputs plus whatever
    ``task_manager.get_task_requirements`` returns; the grader keeps no
    per-episode state of its own.
    """

    # A partially correct answer earns at most this fraction of the credit
    # that a fully correct answer for the same component would earn.
    _PARTIAL_CREDIT_FACTOR = 0.5

    def __init__(self):
        """Bind the grader to the module-level ``task_manager``."""
        self.task_manager = task_manager

    def grade_action(
        self,
        action: "GovAction",
        observation: "GovObservation",
        state: "GovServiceState",
        task_info: Dict[str, Any],
    ) -> float:
        """Grade an action based on the current task and state.

        Args:
            action: The action being graded; ``action_type`` selects the
                scoring rule and ``service_type`` is read for selections.
            observation: The observation graded alongside the action.
            state: Episode state; only ``step_count`` is read here.
            task_info: Task description. ``"expected_difficulty"`` is
                required; ``"optimal_steps"`` defaults to 5.

        Returns:
            Score between 0.0 and 1.0.

        Raises:
            KeyError: If ``task_info`` has no ``"expected_difficulty"``.
        """
        difficulty_graders = {
            "easy": self._grade_easy_task,
            "medium": self._grade_medium_task,
            "hard": self._grade_hard_task,
        }
        grade = difficulty_graders.get(task_info["expected_difficulty"])
        # An unrecognised difficulty keeps the base score of 0.0.
        score = grade(action, observation, state, task_info) if grade else 0.0

        # NOTE(review): the efficiency bonus is applied even when the base
        # score is 0.0, so an incorrect action taken early still earns 0.1.
        # Kept as-is because callers may depend on it -- confirm intent.
        optimal_steps = task_info.get("optimal_steps", 5)
        if state.step_count <= optimal_steps:
            score = min(1.0, score + 0.1)  # Efficiency bonus
        elif state.step_count > optimal_steps * 2:
            score = max(0.0, score - 0.1)  # Inefficiency penalty

        return self._clamp(score)

    def _grade_easy_task(
        self,
        action: "GovAction",
        observation: "GovObservation",
        state: "GovServiceState",
        task_info: Dict[str, Any],
    ) -> float:
        """Grade easy task: Identify required documents.

        Only ``select_service`` and ``list_required_documents`` actions earn
        credit; every other action type scores 0.0.

        Returns:
            Score between 0.0 and 1.0 (penalties are clamped to 0.0).
        """
        score = 0.0
        if action.action_type == "select_service":
            score = self._score_service_selection(
                action, task_info, reward=0.3, penalty=0.2
            )
        elif action.action_type == "list_required_documents":
            score = self._score_document_listing(
                observation, task_info, full_credit=0.6
            )
        return self._clamp(score)

    def _grade_medium_task(
        self,
        action: "GovAction",
        observation: "GovObservation",
        state: "GovServiceState",
        task_info: Dict[str, Any],
    ) -> float:
        """Grade medium task: Validate application.

        Handles ``select_service``, ``list_required_documents`` and
        ``validate_documents``; every other action type scores 0.0.

        Returns:
            Score between 0.0 and 1.0 (penalties are clamped to 0.0).
        """
        score = 0.0
        if action.action_type == "select_service":
            score = self._score_service_selection(
                action, task_info, reward=0.2, penalty=0.1
            )
        elif action.action_type == "list_required_documents":
            score = self._score_document_listing(
                observation, task_info, full_credit=0.3
            )
        elif action.action_type == "validate_documents":
            score = self._score_validation(
                observation,
                task_info,
                missing_weight=0.3,
                invalid_weight=0.3,
                valid_weight=0.2,
                premature_penalty=0.1,
            )
        return self._clamp(score)

    def _grade_hard_task(
        self,
        action: "GovAction",
        observation: "GovObservation",
        state: "GovServiceState",
        task_info: Dict[str, Any],
    ) -> float:
        """Grade hard task: Fix incorrect application.

        Handles ``select_service``, ``list_required_documents``,
        ``validate_documents``, ``suggest_corrections`` and
        ``submit_application``; every other action type scores 0.0.

        Returns:
            Score between 0.0 and 1.0 (penalties are clamped to 0.0).
        """
        score = 0.0
        if action.action_type == "select_service":
            score = self._score_service_selection(
                action, task_info, reward=0.15, penalty=0.1
            )
        elif action.action_type == "list_required_documents":
            score = self._score_document_listing(
                observation, task_info, full_credit=0.25
            )
        elif action.action_type == "validate_documents":
            score = self._score_validation(
                observation,
                task_info,
                missing_weight=0.25,
                invalid_weight=0.25,
                valid_weight=0.15,
                premature_penalty=0.05,
            )
        elif action.action_type == "suggest_corrections":
            score = self._score_correction_suggestions(observation, task_info)
        elif action.action_type == "submit_application":
            score = self._score_submission(observation)
        return self._clamp(score)

    @staticmethod
    def _clamp(score):
        """Limit ``score`` to the closed interval [0.0, 1.0]."""
        return max(0.0, min(1.0, score))

    @staticmethod
    def _score_service_selection(action, task_info, reward, penalty):
        """Return ``reward`` for the task's service type, else ``-penalty``."""
        if action.service_type == task_info["service_type"]:
            return reward
        return -penalty

    def _score_document_listing(self, observation, task_info, full_credit):
        """Score the observed required-document list against the task's.

        Full credit needs the same number of documents and the same set of
        types. Otherwise the share of expected types that were observed earns
        up to half of ``full_credit``.
        """
        # A missing (None) list is treated as empty rather than crashing.
        observed_docs = observation.required_documents or []
        expected_docs = (
            self.task_manager.get_task_requirements(task_info["service_type"])
            or []
        )
        expected_types = {doc.type for doc in expected_docs}
        observed_types = {doc.get("type", "") for doc in observed_docs}

        if (
            len(observed_docs) == len(expected_docs)
            and expected_types == observed_types
        ):
            return full_credit
        if not expected_types:
            # Nothing to match against, so no partial credit is available.
            return 0.0
        overlap = len(expected_types & observed_types)
        partial_credit = full_credit * self._PARTIAL_CREDIT_FACTOR
        return partial_credit * (overlap / len(expected_types))

    def _score_validation(
        self,
        observation,
        task_info,
        missing_weight,
        invalid_weight,
        valid_weight,
        premature_penalty,
    ):
        """Score validation results against the task's expected feedback.

        Returns ``-premature_penalty`` when the observation carries no
        validation results; otherwise the sum of the missing, invalid and
        valid document component scores.
        """
        results = observation.validation_results
        if not results:
            # No validation results yet: validation was requested too early.
            return -premature_penalty

        expected = task_info.get("validation_feedback") or {}
        missing_score = self._score_name_set(
            expected.get("missing_documents"),
            results.get("missing_documents"),
            missing_weight,
        )
        invalid_score = self._score_invalid_documents(
            expected.get("invalid_documents"),
            results.get("invalid_documents"),
            invalid_weight,
        )
        valid_score = self._score_name_set(
            expected.get("valid_documents"),
            results.get("valid_documents"),
            valid_weight,
        )
        return missing_score + invalid_score + valid_score

    def _score_name_set(self, expected_names, actual_names, weight):
        """Score a reported collection of names against the expected one.

        NOTE(review): when nothing is expected, full credit is awarded no
        matter what was reported -- confirm that spurious entries should not
        cost anything.
        """
        expected = set(expected_names or [])
        actual = set(actual_names or [])
        if not expected or actual == expected:
            return weight
        partial_credit = weight * self._PARTIAL_CREDIT_FACTOR
        return partial_credit * (len(expected & actual) / len(expected))

    def _score_invalid_documents(self, expected_invalid, actual_invalid, weight):
        """Score reported invalid documents against the expected ones.

        With equal counts only the ``"type"`` values are compared; with
        different counts partial credit requires both ``"type"`` and
        ``"reason"`` to match.
        """
        expected_items = list(expected_invalid or [])
        actual_items = list(actual_invalid or [])
        if not expected_items:
            return weight  # No invalid documents expected.

        partial_credit = weight * self._PARTIAL_CREDIT_FACTOR
        if len(expected_items) == len(actual_items):
            expected_types = {item.get("type", "") for item in expected_items}
            actual_types = {item.get("type", "") for item in actual_items}
            if expected_types == actual_types:
                return weight
            overlap = len(expected_types & actual_types)
            return partial_credit * (overlap / len(expected_types))

        expected_pairs = {
            (item.get("type", ""), item.get("reason", ""))
            for item in expected_items
        }
        actual_pairs = {
            (item.get("type", ""), item.get("reason", ""))
            for item in actual_items
        }
        overlap = len(expected_pairs & actual_pairs)
        return partial_credit * (overlap / len(expected_items))

    @staticmethod
    def _collect_issue_terms(expected_invalid, expected_missing):
        """Return the non-empty, lower-cased terms describing expected issues."""
        raw_terms = []
        for item in expected_invalid:
            raw_terms.append(item.get("type"))
            raw_terms.append(item.get("reason"))
        raw_terms.extend(expected_missing)
        # Empty or absent values can never appear in the suggestion text, so
        # they are dropped here instead of inflating the denominator and
        # making full credit unreachable.
        lowered = (str(term).lower() for term in raw_terms if term is not None)
        return [term for term in lowered if term]

    def _score_correction_suggestions(self, observation, task_info):
        """Score correction suggestions by how many expected issues they name.

        Each expected issue term (invalid document type and reason, missing
        document name) counts as addressed when it occurs as a substring of
        the combined, lower-cased ``suggested_action`` and ``issue_type``
        text of the suggestions.
        """
        suggestions = observation.correction_suggestions
        expected = task_info.get("validation_feedback") or {}
        expected_invalid = expected.get("invalid_documents") or []
        expected_missing = expected.get("missing_documents") or []

        if not suggestions:
            # Staying silent is only right when there was nothing to fix.
            if expected_invalid or expected_missing:
                return -0.1
            return 0.1

        issue_terms = self._collect_issue_terms(expected_invalid, expected_missing)
        if not issue_terms:
            return 0.2  # No issues to correct.

        suggestion_text = " ".join(
            f"{item.get('suggested_action', '')} {item.get('issue_type', '')}"
            for item in suggestions
        ).lower()
        matches = sum(1 for term in issue_terms if term in suggestion_text)
        return 0.2 * (matches / len(issue_terms))

    @staticmethod
    def _score_submission(observation):
        """Score a submission from the observation's completeness flags."""
        is_complete = observation.is_complete
        results = observation.validation_results

        if results and results.get("is_valid") and results.get("is_complete"):
            # Validation passed: the observation should agree it is complete.
            return 0.3 if is_complete else -0.1
        if not results or not results.get("is_valid"):
            # Not validated or invalid: it must not be reported as complete.
            return -0.2 if is_complete else 0.1
        return -0.1  # Valid but not complete: unclear state.
# Global grader instance, created once when this module is imported.
grader: Grader = Grader()