virustechhacks's picture
Upload folder using huggingface_hub
e18fa06 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Tests for the Adaptive Project Manager Environment.
"""
import pytest
import sys
import os
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models import (
ProjectAction, ProjectObservation, Assignment,
TaskState, EmployeeState, RiskState, ProjectState,
)
from server.hustlers_env_environment import AdaptiveProjectManagerEnv
from graders import grade_easy, grade_medium, grade_hard, compute_final_score
from tasks import get_easy_task, get_medium_task, get_hard_task
class TestModels:
    """Checks that each Pydantic model can be built and exposes its fields."""

    def test_task_state(self):
        """A TaskState keeps the id, priority and effort it was built with."""
        payload = {
            "id": "task_1",
            "name": "Test Task",
            "priority": "high",
            "status": "todo",
            "required_skill": "backend",
            "remaining_effort": 3.0,
            "original_effort": 3.0,
            "dependencies": ["task_0"],
            "is_critical_path": True,
        }
        built = TaskState(**payload)

        assert built.id == "task_1"
        assert built.priority == "high"
        assert built.remaining_effort == 3.0

    def test_employee_state(self):
        """An EmployeeState exposes its id, skill list and burnout level."""
        skill_set = ["frontend", "backend"]
        worker = EmployeeState(
            id="emp_1",
            name="Alice",
            skills=skill_set,
            available=True,
            workload=0.5,
            burnout=0.2,
        )

        assert worker.id == "emp_1"
        assert "frontend" in worker.skills
        assert worker.burnout == 0.2

    def test_project_action(self):
        """A ProjectAction carries its assignments and contingency action."""
        pairs = [
            Assignment(employee_id="emp_1", task_id="task_1"),
            Assignment(employee_id="emp_2", task_id="task_2"),
        ]
        move = ProjectAction(
            assignments=pairs,
            reprioritized_tasks=["task_3"],
            contingency_action="request_overtime",
        )

        assert len(move.assignments) == 2
        assert move.contingency_action == "request_overtime"

    def test_project_observation(self):
        """A ProjectObservation reports the day and completion it was given."""
        snapshot = ProjectObservation(
            day=5,
            days_remaining=10,
            budget_remaining=50000.0,
            project_completion=0.3,
            blocked_tasks=1,
            overdue_tasks=0,
            average_burnout=0.15,
            tasks=[],
            employees=[],
            risks=[],
            message="Test message",
        )

        assert snapshot.day == 5
        assert snapshot.project_completion == 0.3
class TestTaskConfigs:
    """Checks the static configuration returned by each task generator."""

    @staticmethod
    def _assert_config(config, *, task_id, seed, total_days,
                       n_employees, n_tasks, n_events):
        """Assert the identifying fields and collection sizes of *config*.

        The assertions run in the same order for every difficulty:
        identity (task id, seed), deadline, then the employee, task and
        scheduled-event counts.
        """
        assert config["task_id"] == task_id
        assert config["seed"] == seed
        assert config["total_days"] == total_days
        assert len(config["employees"]) == n_employees
        assert len(config["tasks"]) == n_tasks
        assert len(config["scheduled_events"]) == n_events

    def test_easy_task_config(self):
        """Easy: 3 employees, 5 tasks, 12 days and no scheduled events."""
        self._assert_config(
            get_easy_task(),
            task_id="easy",
            seed=42,
            total_days=12,
            n_employees=3,
            n_tasks=5,
            n_events=0,
        )

    def test_medium_task_config(self):
        """Medium: 4 employees, 9 tasks, 18 days and two scheduled events."""
        self._assert_config(
            get_medium_task(),
            task_id="medium",
            seed=1337,
            total_days=18,
            n_employees=4,
            n_tasks=9,
            n_events=2,
        )

    def test_hard_task_config(self):
        """Hard: 5 employees, 14 tasks, 25 days and four scheduled events."""
        self._assert_config(
            get_hard_task(),
            task_id="hard",
            seed=9001,
            total_days=25,
            n_employees=5,
            n_tasks=14,
            n_events=4,
        )
class TestEnvironment:
    """Test the main environment: reset, stepping and contingency actions."""

    def test_reset_easy(self):
        """Resetting to "easy" yields a fresh day-1 observation."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("easy")
        assert isinstance(obs, ProjectObservation)
        assert obs.day == 1
        assert obs.days_remaining == 11  # 12 - 1
        assert len(obs.tasks) == 5
        assert len(obs.employees) == 3
        assert obs.project_completion == 0.0
        assert obs.done is False

    def test_reset_medium(self):
        """Resetting to "medium" exposes 9 tasks and 4 employees on day 1."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("medium")
        assert obs.day == 1
        assert obs.days_remaining == 17  # 18 - 1
        assert len(obs.tasks) == 9
        assert len(obs.employees) == 4

    def test_reset_hard(self):
        """Resetting to "hard" exposes 14 tasks and 5 employees on day 1."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("hard")
        assert obs.day == 1
        assert obs.days_remaining == 24  # 25 - 1
        assert len(obs.tasks) == 14
        assert len(obs.employees) == 5

    def test_step_with_assignment(self):
        """Assigning an employee to a dependency-free task starts that task."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("easy")
        # Find an available employee and a todo task without dependencies
        emp = next(e for e in obs.employees if e.available)
        task = next(t for t in obs.tasks if t.status == "todo" and not t.dependencies)
        action = ProjectAction(
            assignments=[Assignment(employee_id=emp.id, task_id=task.id)]
        )
        obs = env.step(action)
        assert obs.day == 2
        assert obs.days_remaining == 10
        # Task should be in progress
        task_state = next(t for t in obs.tasks if t.id == task.id)
        assert task_state.status == "in_progress"

    def test_step_no_action(self):
        """Stepping with no assignments advances the day without doing work."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("easy")
        effort_before = {t.id: t.remaining_effort for t in obs.tasks}
        action = ProjectAction(assignments=[])
        obs = env.step(action)
        assert obs.day == 2
        # No work should progress: with nobody assigned, no task may have
        # less effort left than it had right after reset.
        for task in obs.tasks:
            baseline = effort_before.get(task.id, task.remaining_effort)
            assert task.remaining_effort >= baseline

    def test_task_completion(self):
        """Repeatedly assigning emp_1 to task_1 eventually finishes the task."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("easy")
        # task_1 (Design Homepage - 2 effort) and emp_1 (Alice - has ui_design)
        assert any(e.id == "emp_1" for e in obs.employees)
        task_1 = next(t for t in obs.tasks if t.id == "task_1")
        assert task_1.remaining_effort == 2.0
        # Assign Alice to Design Homepage (exact skill match)
        action = ProjectAction(
            assignments=[Assignment(employee_id="emp_1", task_id="task_1")]
        )
        # Step multiple times until task is done
        for _ in range(5):  # Should be enough for 2 effort with skill match
            obs = env.step(action)
            task_1 = next(t for t in obs.tasks if t.id == "task_1")
            if task_1.status == "done":
                break
        assert task_1.status == "done"

    def test_burnout_increases(self):
        """Working several days in a row raises the employee's burnout."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("easy")
        # Read the baseline from the same employee (by id) that is measured
        # at the end, rather than relying on emp_1 being first in the list.
        initial_burnout = next(e for e in obs.employees if e.id == "emp_1").burnout
        # Assign employee to work
        action = ProjectAction(
            assignments=[Assignment(employee_id="emp_1", task_id="task_1")]
        )
        # Step multiple times
        for _ in range(3):
            obs = env.step(action)
        # Burnout should have increased
        final_burnout = next(e for e in obs.employees if e.id == "emp_1").burnout
        assert final_burnout > initial_burnout

    def test_overtime_contingency(self):
        """Requesting overtime is reflected in the observation message."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("easy")
        action = ProjectAction(
            assignments=[Assignment(employee_id="emp_1", task_id="task_1")],
            contingency_action="request_overtime",
        )
        obs = env.step(action)
        # Message should mention overtime
        assert "overtime" in obs.message.lower()

    def test_hire_contractor(self):
        """Hiring a contractor adds exactly one employee, id "contractor_1"."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("easy")
        initial_emp_count = len(obs.employees)
        action = ProjectAction(
            assignments=[],
            contingency_action="hire_contractor",
        )
        obs = env.step(action)
        # Should have one more employee
        assert len(obs.employees) == initial_emp_count + 1
        assert any(e.id == "contractor_1" for e in obs.employees)

    def test_defer_low_priority(self):
        """Deferring low-priority work blocks every low-priority task."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("medium")  # Has low priority tasks
        action = ProjectAction(
            assignments=[],
            contingency_action="defer_low_priority_work",
        )
        obs = env.step(action)
        # Low priority tasks should be blocked
        low_tasks = [t for t in obs.tasks if t.priority == "low"]
        # Guard against a vacuous pass: the loop below checks nothing if the
        # scenario has no low-priority task.
        assert low_tasks, "medium scenario is expected to contain low-priority tasks"
        for task in low_tasks:
            assert task.status == "blocked"

    def test_reproducibility(self):
        """Two fresh environments reset to the same task start out identical."""
        env1 = AdaptiveProjectManagerEnv()
        env2 = AdaptiveProjectManagerEnv()
        obs1 = env1.reset("easy")
        obs2 = env2.reset("easy")
        # Should be identical
        assert obs1.day == obs2.day
        assert len(obs1.tasks) == len(obs2.tasks)
        for t1, t2 in zip(obs1.tasks, obs2.tasks):
            assert t1.id == t2.id
            assert t1.remaining_effort == t2.remaining_effort
class TestGraders:
    """Checks the score range and rough ordering produced by the graders."""

    def test_grade_easy_perfect(self):
        """A finished, under-budget, low-burnout project scores above 0.7."""
        finished_tasks = [
            TaskState(
                id=f"task_{idx}",
                required_skill="test",
                status="done",
                is_critical_path=True,
                priority="high",
            )
            for idx in range(5)
        ]
        rested_team = [
            EmployeeState(id=f"emp_{idx}", skills=["test"], burnout=0.1)
            for idx in range(3)
        ]
        project = ProjectState(
            day=10,
            total_days=12,
            budget_total=50000.0,
            budget_spent=30000.0,
            tasks=finished_tasks,
            employees=rested_team,
            risks=[],
            stakeholder_satisfaction=1.0,
            task_id="easy",
        )

        result = grade_easy(project)

        assert 0.0 <= result <= 1.0
        assert result > 0.7  # Should be high for perfect completion

    def test_grade_easy_failed(self):
        """An untouched, over-budget, burnt-out project scores below 0.5."""
        untouched_tasks = [
            TaskState(
                id=f"task_{idx}",
                required_skill="test",
                status="todo",
                is_critical_path=True,
                priority="high",
            )
            for idx in range(5)
        ]
        exhausted_team = [
            EmployeeState(id=f"emp_{idx}", skills=["test"], burnout=0.9)  # High burnout
            for idx in range(3)
        ]
        project = ProjectState(
            day=12,
            total_days=12,
            budget_total=50000.0,
            budget_spent=50000.0,  # Budget exhausted
            tasks=untouched_tasks,
            employees=exhausted_team,
            risks=[],
            stakeholder_satisfaction=0.3,
            task_id="easy",
        )

        result = grade_easy(project)

        assert 0.0 <= result <= 1.0
        assert result < 0.5  # Should be low for failed project

    def test_grader_bounds(self):
        """Every grader stays within [0, 1] on default and completed states."""
        graders_by_level = (
            ("easy", grade_easy),
            ("medium", grade_medium),
            ("hard", grade_hard),
        )
        for level, grade in graders_by_level:
            # State built from defaults only
            default_state = ProjectState(task_id=level)
            default_score = grade(default_state)
            assert 0.0 <= default_score <= 1.0

            # Single finished critical-path task, early and under budget
            completed_state = ProjectState(
                day=5,
                total_days=30,
                budget_total=100000.0,
                budget_spent=10000.0,
                tasks=[
                    TaskState(
                        id="t1",
                        required_skill="test",
                        status="done",
                        is_critical_path=True,
                    )
                ],
                employees=[EmployeeState(id="e1", skills=["test"], burnout=0.0)],
                stakeholder_satisfaction=1.0,
                task_id=level,
            )
            completed_score = grade(completed_state)
            assert 0.0 <= completed_score <= 1.0
class TestEndToEnd:
    """End-to-end tests running complete episodes with a greedy heuristic."""

    @staticmethod
    def _greedy_assignments(employees, tasks):
        """Pair each employee with the first unclaimed task they have the skill for.

        Employees are visited in order; a task handed to one employee is not
        offered to the next. Employees without a matching task get nothing.
        Returns the resulting list of Assignment objects.
        """
        unclaimed = list(tasks)  # copy so the caller's list is left untouched
        assignments = []
        for emp in employees:
            match = next(
                (t for t in unclaimed if t.required_skill in emp.skills), None
            )
            if match is None:
                continue
            assignments.append(Assignment(employee_id=emp.id, task_id=match.id))
            unclaimed.remove(match)
        return assignments

    def test_easy_episode(self):
        """Run a complete easy episode with heuristic policy."""
        env = AdaptiveProjectManagerEnv()
        obs = env.reset("easy")
        steps = 0
        while not obs.done and steps < 20:
            # Simple heuristic: assign available employees to available tasks
            idle_emps = [
                e for e in obs.employees
                if e.available and e.assigned_task_id is None
            ]
            open_tasks = [t for t in obs.tasks if t.status in ("todo", "in_progress")]
            action = ProjectAction(
                assignments=self._greedy_assignments(idle_emps, open_tasks)
            )
            obs = env.step(action)
            steps += 1
        # Should complete within deadline
        assert steps <= 15
        # Get final score
        state = env.get_project_state()
        score = grade_easy(state)
        assert 0.0 <= score <= 1.0

    def test_difficulty_ordering(self):
        """Run the baseline heuristic on every difficulty and check the scores.

        The greedy heuristic is too weak to guarantee easy > medium > hard,
        so no ordering is asserted; each score only has to be a valid grade
        in [0, 1].
        """
        # Imported once here (not per difficulty); the module-level import
        # from graders does not bring in the registry.
        from graders import GRADER_REGISTRY

        scores = {}
        for task_id in ("easy", "medium", "hard"):
            env = AdaptiveProjectManagerEnv()
            obs = env.reset(task_id)
            # Run simple heuristic for at most 30 steps
            for _ in range(30):
                if obs.done:
                    break
                status_by_id = {t.id: t.status for t in obs.tasks}
                idle_emps = [
                    e for e in obs.employees
                    if e.available and e.assigned_task_id is None
                ]
                # A todo task is ready once every dependency that is present
                # in the observation is done; unknown ids are not blocking.
                ready_tasks = [
                    t for t in obs.tasks
                    if t.status == "todo"
                    and all(
                        status_by_id.get(dep_id, "done") == "done"
                        for dep_id in t.dependencies
                    )
                ]
                action = ProjectAction(
                    assignments=self._greedy_assignments(idle_emps, ready_tasks)
                )
                obs = env.step(action)
            state = env.get_project_state()
            scores[task_id] = GRADER_REGISTRY[task_id](state)

        # Due to heuristic limitations, only check for valid grades
        for task_id in ("easy", "medium", "hard"):
            assert 0.0 <= scores[task_id] <= 1.0
class TestScheduledEvents:
    """Checks the scripted events of the medium and hard scenarios."""

    @staticmethod
    def _idle_steps(env, count):
        """Step *env* *count* times with empty actions; return the last observation."""
        latest = None
        for _ in range(count):
            latest = env.step(ProjectAction(assignments=[]))
        return latest

    def test_medium_employee_unavailable(self):
        """Test that employee becomes unavailable on day 6."""
        env = AdaptiveProjectManagerEnv()
        env.reset("medium")

        # Five idle steps (days 1-5) bring the episode to day 6
        obs = self._idle_steps(env, 5)

        # On day 6, Bob should become unavailable
        bob = next(member for member in obs.employees if member.id == "emp_2")
        assert bob.available is False
        note = obs.message.lower()
        assert "sick" in note or "unavailable" in note

    def test_hard_new_task_added(self):
        """Test that compliance task is added on day 9."""
        env = AdaptiveProjectManagerEnv()
        starting_total = len(env.reset("hard").tasks)

        # Eight idle steps (days 1-8) bring the episode to day 9
        obs = self._idle_steps(env, 8)

        # On day 9, compliance task should be added
        assert len(obs.tasks) == starting_total + 1
        added = next(
            (item for item in obs.tasks if "compliance" in item.name.lower()),
            None,
        )
        assert added is not None
if __name__ == "__main__":
    # pytest.main returns the run's exit code; pass it to sys.exit so that
    # running this file directly exits non-zero when a test fails.
    sys.exit(pytest.main([__file__, "-v"]))