File size: 4,631 Bytes
ec17c6d
e181764
ec17c6d
e181764
ec17c6d
 
 
 
 
 
e181764
ec17c6d
 
e181764
 
 
 
 
 
 
 
 
e1e46c2
e181764
e1e46c2
 
 
 
ec17c6d
 
e181764
ec17c6d
e181764
 
 
 
 
ec17c6d
 
e181764
 
 
 
 
 
 
 
 
ec17c6d
e1e46c2
 
e181764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126c21b
e181764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1e46c2
 
ec17c6d
e181764
ec17c6d
 
 
 
 
6c8a204
ec17c6d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
FastAPI application for the HR Onboarding/Offboarding Environment.

Serves both the OpenEnv API endpoints and an interactive web playground UI.
"""

try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    raise ImportError(
        "openenv is required. Install with: uv sync"
    ) from e

from models import HROnboardingAction, HROnboardingObservation
from .hr_onboarding_environment import HROnboardingEnvironment
from .tools import TOOL_DEFINITIONS
from .rubrics import RubricEvaluator

from fastapi import Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from pathlib import Path
import os
import json


# Required for OpenEnv to mount the HF-style web UI at /web.
# setdefault() only writes the variable when it is not already present, so a
# value supplied by the deployment environment is left untouched.
# NOTE(review): presumably this must run before create_app() below so the
# flag is visible when the app is built -- confirm against openenv.
os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")


# Create the OpenEnv app. The environment class and its action/observation
# models are handed to the OpenEnv factory, which returns the application
# object that the playground routes below are registered on.
app = create_app(
    HROnboardingEnvironment,
    HROnboardingAction,
    HROnboardingObservation,
    env_name="hr_onboarding_env",
    max_concurrent_envs=4,
)

# Mount static files: everything in the "static" directory that sits next to
# this module is served under the /static URL prefix.
STATIC_DIR = Path(__file__).parent / "static"
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

# Shared environment instance for the playground. It is a single module-level
# object used by every /api/* playground handler in this file, so all
# playground clients observe and mutate the same episode state.
_playground_env = HROnboardingEnvironment(seed=42, max_steps=15)


# --- Playground API endpoints ---

@app.get("/", include_in_schema=False)
def root_redirect():
    """Redirect requests for the site root to the playground page."""
    playground_url = "/playground"
    return RedirectResponse(url=playground_url)


@app.get("/playground", include_in_schema=False)
def playground():
    """Serve the interactive playground HTML.

    Reads ``index.html`` from ``STATIC_DIR`` on each request and returns its
    contents as an HTML response.

    Returns:
        HTMLResponse: The playground page markup.

    Raises:
        FileNotFoundError: If ``index.html`` is missing from ``STATIC_DIR``.
    """
    html_path = STATIC_DIR / "index.html"
    # Decode explicitly as UTF-8. Without an encoding argument read_text()
    # uses the platform locale encoding, which can garble or reject non-ASCII
    # markup on hosts whose locale is not UTF-8, while HTMLResponse sends the
    # text back out declared as UTF-8.
    return HTMLResponse(html_path.read_text(encoding="utf-8"))


@app.get("/api/tasks")
def get_tasks():
    """Return all tasks with metadata for the task picker.

    Returns:
        list[dict]: One entry per task held by the shared playground
        environment, in order. Each entry carries the task's position
        (``index``), its ``task_id``, ``instruction``, ``difficulty``,
        ``category``, ``expected_tools``, ``rubric_criteria`` and the number
        of rubric criteria (``num_criteria``).
    """
    # Listing tasks only reads the task objects, so the environment's current
    # task index is deliberately left alone. Saving and re-assigning it here
    # would be a no-op at best and could overwrite an index set by a
    # concurrent /api/reset request.
    return [
        {
            "index": index,
            "task_id": task.task_id,
            "instruction": task.instruction,
            "difficulty": task.difficulty,
            "category": task.category,
            "expected_tools": task.expected_tools,
            "rubric_criteria": task.rubric_criteria,
            "num_criteria": len(task.rubric_criteria),
        }
        for index, task in enumerate(_playground_env._tasks)
    ]


@app.get("/api/tool_definitions")
def get_tool_definitions():
    """Expose the tool catalogue (``TOOL_DEFINITIONS``) to the playground UI, unmodified."""
    tool_catalogue = TOOL_DEFINITIONS
    return tool_catalogue


@app.post("/api/reset")
async def playground_reset(request: Request):
    """Reset the shared playground environment to a specific task.

    The request body is an optional JSON object. Its ``task_idx`` field
    (default ``0``) selects which task the next episode starts on.

    Args:
        request: Incoming request whose body may contain ``{"task_idx": <int>}``.

    Returns:
        dict: The initial observation of the new episode (``task_id``,
        ``instruction``, ``step``, ``max_steps``, ``available_tools``,
        ``done``, ``reward``, ``metadata``).

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or ``task_idx`` is not an integer.
    """
    from fastapi import HTTPException

    # An empty body means "use the defaults" rather than a server error.
    raw_body = await request.body()
    if raw_body.strip():
        try:
            body = json.loads(raw_body)
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
            raise HTTPException(
                status_code=400, detail="Request body must be valid JSON"
            ) from exc
    else:
        body = {}
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    task_idx = body.get("task_idx", 0)
    # Reject non-integers (including JSON booleans) before touching the shared
    # environment, so a bad request cannot leave a bogus index behind.
    # NOTE(review): the valid range is not checked here because how reset()
    # treats out-of-range indices is not visible from this module.
    if isinstance(task_idx, bool) or not isinstance(task_idx, int):
        raise HTTPException(status_code=400, detail="task_idx must be an integer")

    env = _playground_env
    # Set task index so reset() picks the right task
    env._task_idx = task_idx
    obs = env.reset()

    return {
        "task_id": obs.task_id,
        "instruction": obs.instruction,
        "step": obs.step,
        "max_steps": obs.max_steps,
        "available_tools": obs.available_tools,
        "done": obs.done,
        "reward": obs.reward,
        "metadata": obs.metadata,
    }


@app.post("/api/step")
async def playground_step(request: Request):
    """Execute one tool call step on the shared playground environment.

    The request body is a JSON object with ``tool_name`` (default ``""``) and
    ``arguments`` (default ``{}``); both are forwarded to the environment as
    an ``HROnboardingAction``.

    Args:
        request: Incoming request carrying the tool call as JSON.

    Returns:
        dict: The observation produced by the step (``task_id``,
        ``instruction``, ``tool_name``, ``tool_result``, ``step``,
        ``max_steps``, ``done``, ``reward``, ``metadata``).

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON
            object.
    """
    from fastapi import HTTPException

    # An empty body falls back to the defaults instead of a server error.
    raw_body = await request.body()
    if raw_body.strip():
        try:
            body = json.loads(raw_body)
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
            raise HTTPException(
                status_code=400, detail="Request body must be valid JSON"
            ) from exc
    else:
        body = {}
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    tool_name = body.get("tool_name", "")
    arguments = body.get("arguments", {})

    env = _playground_env
    action = HROnboardingAction(tool_name=tool_name, arguments=arguments)
    obs = env.step(action)

    return {
        "task_id": obs.task_id,
        "instruction": obs.instruction,
        "tool_name": obs.tool_name,
        "tool_result": obs.tool_result,
        "step": obs.step,
        "max_steps": obs.max_steps,
        "done": obs.done,
        "reward": obs.reward,
        "metadata": obs.metadata,
    }


@app.post("/api/evaluate")
async def playground_evaluate():
    """Score the playground's current episode on demand.

    Runs a fresh ``RubricEvaluator`` over the current task and the action log
    recorded so far. When no task is active, an all-zero result is returned.
    """
    rubric = RubricEvaluator()
    active_task = _playground_env._current_task

    # Guard clause: nothing to evaluate yet.
    if not active_task:
        return {
            "score": 0,
            "passed": False,
            "criteria_results": [],
            "passed_count": 0,
            "total_criteria": 0,
        }

    return rubric.evaluate(active_task, _playground_env.world.action_log)


def main():
    """Parse ``--host`` / ``--port`` from the command line and run the app under uvicorn."""
    import argparse
    import uvicorn

    cli = argparse.ArgumentParser()
    # (flag, value type, default) for each supported option.
    option_specs = (
        ("--host", str, "0.0.0.0"),
        ("--port", int, 7860),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    options = cli.parse_args()

    uvicorn.run(app, host=options.host, port=options.port)


if __name__ == "__main__":
    main()