Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- README.md +118 -5
- __init__.py +16 -0
- baseline_llm.py +202 -0
- baseline_random.py +94 -0
- client.py +58 -0
- demo/index.html +724 -0
- eval_trained.py +209 -0
- models.py +67 -0
- openenv.yaml +7 -0
- openenv_recruitopenenv.egg-info/PKG-INFO +9 -0
- openenv_recruitopenenv.egg-info/SOURCES.txt +17 -0
- openenv_recruitopenenv.egg-info/dependency_links.txt +1 -0
- openenv_recruitopenenv.egg-info/entry_points.txt +2 -0
- openenv_recruitopenenv.egg-info/requires.txt +5 -0
- openenv_recruitopenenv.egg-info/top_level.txt +1 -0
- play.py +172 -0
- pyproject.toml +45 -0
- server/__init__.py +11 -0
- server/app.py +102 -0
- server/recruitopenenv_environment.py +1422 -0
- server/requirements.txt +6 -0
- train_colab.ipynb +558 -0
- train_grpo.py +431 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Multi-stage build using openenv-base
# This Dockerfile is flexible and works for both:
# - In-repo environments (with local OpenEnv sources)
# - Standalone environments (with openenv from PyPI/Git)
# The build script (openenv build) handles context detection and sets appropriate build args.

ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# Ensure git is available (required for installing dependencies from VCS)
RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

# Build argument to control whether we're building standalone or in-repo
ARG BUILD_MODE=in-repo
ARG ENV_NAME=recruitopenenv

# Copy environment code (always at root of build context)
COPY . /app/env

# For in-repo builds, openenv is already vendored in the build context
# For standalone builds, openenv will be installed via pyproject.toml
WORKDIR /app/env

# Ensure uv is available (for local builds where base image lacks it)
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | sh && \
        mv /root/.local/bin/uv /usr/local/bin/uv && \
        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Install dependencies using uv sync
# If uv.lock exists, use it; otherwise resolve on the fly
# First pass: dependencies only (--no-install-project) so this layer is
# cached independently of source-code changes.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

# Second pass: installs the project itself on top of the cached deps layer.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

# Final runtime stage
FROM ${BASE_IMAGE}

WORKDIR /app

# Copy the virtual environment from builder
COPY --from=builder /app/env/.venv /app/.venv

# Copy the environment code
COPY --from=builder /app/env /app/env

# Set PATH to use the virtual environment
ENV PATH="/app/.venv/bin:$PATH"

# Set PYTHONPATH so imports work correctly
ENV PYTHONPATH="/app/env:$PYTHONPATH"

# Health check
# NOTE(review): assumes curl is present in the runtime base image — confirm,
# otherwise the healthcheck always fails.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the FastAPI server
# The module path is constructed to work with the /app/env structure
ENV ENABLE_WEB_INTERFACE=true
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,123 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Driver Recruit Environment
|
| 3 |
+
emoji: 🚛
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
- reinforcement-learning
|
| 13 |
+
- recruiting
|
| 14 |
+
- multi-turn
|
| 15 |
---
|
| 16 |
|
| 17 |
+
# 🚛 Driver Recruit Environment
|
| 18 |
+
|
| 19 |
+
A **multi-turn, tool-based RL environment** for training LLMs to recruit truck drivers through a CRM system. Built on [OpenEnv 0.2.1](https://github.com/meta-pytorch/OpenEnv).
|
| 20 |
+
|
| 21 |
+
The agent must discover driver qualifications through conversation, record info in the CRM, get management approval, and hire — all using structured tool calls across 15-40+ step episodes.
|
| 22 |
+
|
| 23 |
+
## Pipeline
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
lead → contacted → interested → approval_pending → offer_sent → hired
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Tools
|
| 30 |
+
|
| 31 |
+
| Tool | Actions | Purpose |
|
| 32 |
+
|------|---------|---------|
|
| 33 |
+
| **crm** | `read_candidate`, `update_stage`, `update_field`, `add_note` | Manage pipeline & record info |
|
| 34 |
+
| **messaging** | `send_message`, `read_reply` | Screen driver (18 topics) |
|
| 35 |
+
| **approval** | `request_approval`, `check_approval` | Get management sign-off |
|
| 36 |
+
| **workflow** | `wait` | Advance time for approval processing |
|
| 37 |
+
|
| 38 |
+
## Reward Signal
|
| 39 |
+
|
| 40 |
+
- **Successful hire** (good job fit): **+10** to **+15** (base + CRM bonus)
|
| 41 |
+
- **Bad hire** (poor match): **-5**
|
| 42 |
+
- **Ghosted** (trust runs out): **-4**
|
| 43 |
+
- **Per-step**: Small rewards/penalties for correct/incorrect actions
|
| 44 |
+
|
| 45 |
+
## What Makes This Hard
|
| 46 |
+
|
| 47 |
+
- **Long horizon**: 15-40+ tool calls per episode
|
| 48 |
+
- **Information gathering**: Must ask the right screening questions to match driver to the right job
|
| 49 |
+
- **Trust dynamics**: Each message costs trust — ask too many questions and the driver ghosts
|
| 50 |
+
- **Job matching**: 6 jobs per episode (1-2 good, 1-2 traps with deal-breakers, 2-3 partial)
|
| 51 |
+
- **Procedural correctness**: Must follow stage order, read replies before messaging, get approval before offering
|
| 52 |
+
|
| 53 |
+
## Quick Start
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
|
| 57 |
+
|
| 58 |
+
env = RecruitopenenvEnv(base_url="YOUR_SPACE_URL")
|
| 59 |
+
|
| 60 |
+
result = env.reset(seed=42)
|
| 61 |
+
obs = result.observation
|
| 62 |
+
print(f"Driver: {obs.driver_name}, Stage: {obs.stage}")
|
| 63 |
+
|
| 64 |
+
# Read CRM
|
| 65 |
+
result = env.step(RecruitopenenvAction(tool="crm", action="read_candidate"))
|
| 66 |
+
print(result.observation.jobs_summary)
|
| 67 |
+
|
| 68 |
+
# Greet driver
|
| 69 |
+
result = env.step(RecruitopenenvAction(tool="messaging", action="send_message", topic="greeting"))
|
| 70 |
+
print(f"Reward: {result.reward}")
|
| 71 |
+
|
| 72 |
+
# Read reply
|
| 73 |
+
result = env.step(RecruitopenenvAction(tool="messaging", action="read_reply"))
|
| 74 |
+
print(result.observation.discovered_info)
|
| 75 |
+
|
| 76 |
+
env.close()
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## Training
|
| 80 |
+
|
| 81 |
+
We train using GRPO/REINFORCE with the model choosing screening topics. See `train_grpo.py` for the full training script.
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python train_grpo.py --model Qwen/Qwen2.5-3B-Instruct
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## Deploying
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
# From the recruitopenenv/ directory
|
| 91 |
+
openenv push
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## Action Format
|
| 95 |
+
|
| 96 |
+
```json
|
| 97 |
+
{"tool": "crm", "action": "read_candidate"}
|
| 98 |
+
{"tool": "messaging", "action": "send_message", "topic": "experience"}
|
| 99 |
+
{"tool": "messaging", "action": "read_reply"}
|
| 100 |
+
{"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
|
| 101 |
+
{"tool": "crm", "action": "update_stage", "stage": "contacted"}
|
| 102 |
+
{"tool": "approval", "action": "request_approval", "job_id": 2}
|
| 103 |
+
{"tool": "workflow", "action": "wait"}
|
| 104 |
+
{"tool": "approval", "action": "check_approval"}
|
| 105 |
+
{"tool": "messaging", "action": "send_message", "topic": "offer", "job_id": 2}
|
| 106 |
+
{"tool": "crm", "action": "update_stage", "stage": "hired"}
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## Observation Fields
|
| 110 |
+
|
| 111 |
+
| Field | Description |
|
| 112 |
+
|-------|-------------|
|
| 113 |
+
| `driver_name` | Driver's name |
|
| 114 |
+
| `crm_summary` | Full CRM record (empty until `read_candidate`) |
|
| 115 |
+
| `jobs_summary` | 6 available job listings |
|
| 116 |
+
| `discovered_info` | Info from screening conversations |
|
| 117 |
+
| `stage` | Current pipeline stage |
|
| 118 |
+
| `feedback` | API response from last action |
|
| 119 |
+
| `pending_reply` | Whether driver has unread message |
|
| 120 |
+
|
| 121 |
+
## Screening Topics
|
| 122 |
+
|
| 123 |
+
`greeting`, `call`, `experience`, `home_time`, `pay`, `equipment`, `route`, `deal_breakers`, `availability`, `violations`, `medical_card`, `references`, `pitch`, `offer`, `negotiate_pay`, `negotiate_home_time`, `signing_bonus`, `address_concern`
|
__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Recruitopenenv Environment.

Public package surface for the Driver Recruit environment: the HTTP
client plus its action/observation dataclasses.
"""

from .client import RecruitopenenvEnv
from .models import RecruitopenenvAction, RecruitopenenvObservation

# Explicit public API — everything else in the package is internal.
__all__ = [
    "RecruitopenenvAction",
    "RecruitopenenvObservation",
    "RecruitopenenvEnv",
]
|
baseline_llm.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM agent baseline — test how well a base model performs without RL training."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import requests
|
| 6 |
+
from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
|
| 7 |
+
|
| 8 |
+
# System prompt handed to the baseline LLM. The "Tips" section previously
# referenced topics with an "ask_" prefix (ask_experience, ask_deal_breakers,
# ask_violations, ask_medical_card) that do not exist in the send_message
# topic vocabulary listed above it — fixed to use the real topic names so
# the model is not steered toward invalid actions.
SYSTEM_PROMPT = """You are a truck driver recruiter using a CRM system. You only know the driver's name. You must discover their qualifications through conversation, record info in the CRM, get approval, and hire them.

You have 4 tools:

## crm
- read_candidate: Read the current CRM record
- update_stage: Advance pipeline (contacted → interested → approval_pending → offer_sent → hired)
- update_field: Record info (field + value)
- add_note: Add a free-text note

## messaging
- send_message: Send a message (topic: greeting, call, experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references, pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern)
- read_reply: Read the driver's response

## approval
- request_approval: Request approval for a job (needs job_id)
- check_approval: Check approval status

## workflow
- wait: Advance time (needed for approval processing)

## Rules
- Must read CRM before messaging
- Must read_reply before sending another message
- Must request_approval and wait before sending offer
- Must follow stage order: lead → contacted → interested → approval_pending → offer_sent → hired
- Record important info in CRM with update_field
- Too many messages hurt trust

## Strategy
1. crm.read_candidate → see the lead
2. messaging.send_message(greeting or call) → messaging.read_reply → crm.update_stage(contacted)
3. Screen: send_message(experience) → read_reply → update_field(cdl_class, value) ... repeat for key questions
4. crm.update_stage(interested)
5. approval.request_approval(job_id) → workflow.wait → approval.check_approval
6. crm.update_stage(approval_pending)
7. messaging.send_message(offer) → messaging.read_reply
8. crm.update_stage(offer_sent) → crm.update_stage(hired)

Tips:
- the experience topic is critical (CDL class filters jobs)
- the deal_breakers topic helps avoid trap jobs
- the violations and medical_card topics reveal fatal blockers
- If driver has concerns about offer, use negotiate_pay/negotiate_home_time/address_concern
- If no good match exists, update_stage to lost

Respond with ONLY JSON:
{"tool": "crm", "action": "read_candidate"}
{"tool": "messaging", "action": "send_message", "topic": "experience"}
{"tool": "messaging", "action": "read_reply"}
{"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
{"tool": "approval", "action": "request_approval", "job_id": 2}
{"tool": "crm", "action": "update_stage", "stage": "hired"}"""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def format_observation(obs):
    """Render an observation as the user-turn text shown to the LLM."""
    lines = [f"Driver: {obs.driver_name}"]

    # Optional sections appear only once the agent has uncovered them.
    for label, body in (
        ("CRM", obs.crm_summary),
        ("Jobs", obs.jobs_summary),
        ("Discovered", obs.discovered_info),
    ):
        if body:
            lines.append(f"{label}:\n{body}")

    stage_line = f"Stage: {obs.stage}"
    if obs.pending_reply:
        stage_line = stage_line + " | PENDING REPLY"
    lines.append(stage_line)

    if obs.feedback:
        lines.append(f"Result: {obs.feedback}")

    return "\n".join(lines)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def ask_llm(messages, llm_url, model, timeout=60.0):
    """Query an OpenAI-compatible chat-completions endpoint.

    Args:
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
        llm_url: Full URL of the ``/v1/chat/completions`` endpoint.
        model: Model identifier to request.
        timeout: Per-request timeout in seconds. Previously no timeout was
            set, so an unresponsive server hung the whole baseline run.

    Returns:
        The assistant message content as a string.

    Raises:
        requests.HTTPError: If the server returns an error status.
    """
    resp = requests.post(
        llm_url,
        json={
            "model": model,
            "messages": messages,
            "temperature": 0.1,
            "max_tokens": 150,
        },
        timeout=timeout,
    )
    # Surface HTTP errors explicitly instead of failing later with an
    # opaque KeyError while indexing the error body's JSON.
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def parse_action(text):
    """Extract a structured action from an LLM response.

    Parsing is attempted in order of decreasing strictness:
    1. JSON inside a markdown code fence.
    2. The whole response as JSON.
    3. The first ``{...}`` span embedded in surrounding prose (LLMs often
       wrap the JSON in explanations; previously such responses fell
       through to the keyword fallback, where e.g. the word "wait" in
       prose mis-triggered workflow.wait).
    4. Keyword fallback for the argument-free actions.

    Always returns a RecruitopenenvAction; defaults to a harmless
    crm.read_candidate when nothing can be parsed.
    """
    text = text.strip()

    # Remove markdown code fences
    if "```" in text:
        parts = text.split("```")
        for part in parts:
            part = part.strip()
            if part.startswith("json"):
                part = part[4:].strip()
            if part.startswith("{"):
                text = part
                break

    # Candidate strings to JSON-parse: the whole text first, then the
    # outermost {...} span if the JSON is surrounded by prose.
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(data, dict) and "tool" in data and "action" in data:
            return RecruitopenenvAction(
                tool=data["tool"],
                action=data["action"],
                topic=data.get("topic", ""),
                job_id=data.get("job_id", -1),
                stage=data.get("stage", ""),
                field=data.get("field", ""),
                value=data.get("value", ""),
            )

    # Keyword fallback: only the argument-free actions can be recovered
    # reliably from free text.
    text_lower = text.lower()
    if "read_candidate" in text_lower:
        return RecruitopenenvAction(tool="crm", action="read_candidate")
    if "read_reply" in text_lower:
        return RecruitopenenvAction(tool="messaging", action="read_reply")
    if "check_approval" in text_lower:
        return RecruitopenenvAction(tool="approval", action="check_approval")
    if "wait" in text_lower:
        return RecruitopenenvAction(tool="workflow", action="wait")

    # Safe default: reading the CRM never advances or damages the episode.
    return RecruitopenenvAction(tool="crm", action="read_candidate")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def run_baseline(env_url, llm_url, model, num_episodes):
    """Run the untrained-LLM baseline and print per-step / summary stats.

    For each episode: feed formatted observations to the LLM, parse its
    JSON tool call, step the environment, and accumulate reward until the
    episode ends or a hard 100-step cap is hit.
    """
    rewards = []
    successes = 0
    total_steps = 0

    env = RecruitopenenvEnv(base_url=env_url)

    for ep in range(num_episodes):
        result = env.reset()
        obs = result.observation
        ep_reward = 0.0
        steps = 0

        # Fresh chat history per episode; system prompt carries the rules.
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        while not result.done and steps < 100:
            obs_text = format_observation(obs)
            messages.append({"role": "user", "content": obs_text})

            llm_response = ask_llm(messages, llm_url, model)
            messages.append({"role": "assistant", "content": llm_response})

            action = parse_action(llm_response)
            result = env.step(action)
            obs = result.observation
            ep_reward += result.reward
            steps += 1

            # Per-step trace: tool.action(topic)[job=N] -> reward
            print(f" Step {steps}: {action.tool}.{action.action}"
                  f"{'(' + action.topic + ')' if action.topic else ''}"
                  f"{'[job=' + str(action.job_id) + ']' if action.job_id >= 0 else ''}"
                  f" -> reward={result.reward:.1f}")

        rewards.append(ep_reward)
        total_steps += steps
        # Success = episode ended in the "hired" pipeline stage.
        if obs.stage == "hired":
            successes += 1

        print(f"Episode {ep+1}: total_reward={ep_reward:.1f}, steps={steps}, "
              f"{'HIRED' if obs.stage == 'hired' else 'FAIL (' + obs.stage + ')'}")
        print()

    env.close()

    # NOTE(review): num_episodes == 0 would divide by zero here — callers
    # currently always pass a positive count via argparse default.
    avg_reward = sum(rewards) / len(rewards)
    avg_steps = total_steps / num_episodes

    print("\n========== LLM BASELINE (no RL) ==========")
    print(f"Model: {model}")
    print(f"Episodes: {num_episodes}")
    print(f"Avg reward: {avg_reward:.2f}")
    print(f"Min reward: {min(rewards):.2f}")
    print(f"Max reward: {max(rewards):.2f}")
    print(f"Hire rate: {successes}/{num_episodes} ({100*successes/num_episodes:.1f}%)")
    print(f"Avg steps/episode: {avg_steps:.1f}")
    print("==========================================")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# CLI entry point: point the baseline at a running environment server and
# an OpenAI-compatible LLM endpoint.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLM baseline for Driver Recruit Environment")
    parser.add_argument("--env-url", default="http://localhost:8001", help="Environment server URL")
    parser.add_argument("--llm-url", default="http://localhost:8033/v1/chat/completions", help="LLM API URL")
    parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct", help="Model name")
    parser.add_argument("--episodes", type=int, default=20, help="Number of episodes")
    args = parser.parse_args()

    run_baseline(args.env_url, args.llm_url, args.model, args.episodes)
|
baseline_random.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Random agent baseline — establishes the floor for reward."""
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
|
| 5 |
+
|
| 6 |
+
TOOLS_ACTIONS = {
|
| 7 |
+
"crm": ["read_candidate", "update_stage", "update_field", "add_note"],
|
| 8 |
+
"messaging": ["send_message", "read_reply"],
|
| 9 |
+
"approval": ["request_approval", "check_approval"],
|
| 10 |
+
"workflow": ["wait"],
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
TOPICS = [
|
| 14 |
+
"greeting", "call", "experience", "home_time", "pay", "equipment",
|
| 15 |
+
"route", "deal_breakers", "availability", "violations", "medical_card",
|
| 16 |
+
"references", "pitch", "offer", "negotiate_pay", "negotiate_home_time",
|
| 17 |
+
"signing_bonus", "address_concern",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
STAGES = ["contacted", "interested", "approval_pending", "offer_sent", "hired", "lost"]
|
| 21 |
+
|
| 22 |
+
NUM_EPISODES = 100
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def random_action():
    """Sample a uniformly random tool call from the environment's action space."""
    tool_name = random.choice(list(TOOLS_ACTIONS.keys()))
    verb = random.choice(TOOLS_ACTIONS[tool_name])

    # Optional arguments start unset; only the combinations below fill them.
    extras = {"topic": "", "job_id": -1, "stage": "", "field": "", "value": ""}

    if (tool_name, verb) == ("messaging", "send_message"):
        extras["topic"] = random.choice(TOPICS)
        # Pitches and offers target one of the 6 job listings.
        if extras["topic"] in ("pitch", "offer"):
            extras["job_id"] = random.randint(0, 5)
    elif (tool_name, verb) == ("crm", "update_stage"):
        extras["stage"] = random.choice(STAGES)
    elif (tool_name, verb) == ("crm", "update_field"):
        extras["field"] = random.choice(["cdl_class", "years_exp", "home_time_pref"])
        extras["value"] = "A"
    elif (tool_name, verb) == ("approval", "request_approval"):
        extras["job_id"] = random.randint(0, 5)

    return RecruitopenenvAction(tool=tool_name, action=verb, **extras)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def run_baseline():
    """Run NUM_EPISODES episodes with uniformly random actions and print stats.

    Establishes the reward floor an RL-trained policy must beat.
    """
    rewards = []
    successes = 0
    total_steps = 0

    # NOTE(review): this file uses the .sync() context-manager API while
    # baseline_llm.py constructs and closes the client directly — confirm
    # both entry points exist on this client version.
    with RecruitopenenvEnv(base_url="http://localhost:8000").sync() as env:
        for ep in range(NUM_EPISODES):
            result = env.reset()
            ep_reward = 0.0
            steps = 0

            # Hard 100-step cap mirrors the LLM baseline's episode limit.
            while not result.done and steps < 100:
                action = random_action()
                result = env.step(action)
                ep_reward += result.reward
                steps += 1

            rewards.append(ep_reward)
            total_steps += steps

            # Success = episode ended in the "hired" pipeline stage.
            if result.observation.stage == "hired":
                successes += 1

            # Progress line every 10 episodes.
            if (ep + 1) % 10 == 0:
                avg_so_far = sum(rewards) / len(rewards)
                print(f" Episode {ep+1}: reward={ep_reward:.1f}, running avg={avg_so_far:.2f}")

    avg_reward = sum(rewards) / len(rewards)
    avg_steps = total_steps / NUM_EPISODES

    print("\n========== RANDOM BASELINE ==========")
    print(f"Episodes: {NUM_EPISODES}")
    print(f"Avg reward: {avg_reward:.2f}")
    print(f"Min reward: {min(rewards):.2f}")
    print(f"Max reward: {max(rewards):.2f}")
    print(f"Hire rate: {successes}/{NUM_EPISODES} ({100*successes/NUM_EPISODES:.1f}%)")
    print(f"Avg steps/episode: {avg_steps:.1f}")
    print("======================================")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# Script entry point — requires the environment server on localhost:8000.
if __name__ == "__main__":
    run_baseline()
|
client.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Recruitopenenv Environment Client."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from openenv.core.client_types import StepResult
|
| 6 |
+
from openenv.core.env_server.types import State
|
| 7 |
+
from openenv.core import EnvClient
|
| 8 |
+
|
| 9 |
+
from .models import RecruitopenenvAction, RecruitopenenvObservation
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RecruitopenenvEnv(
    EnvClient[RecruitopenenvAction, RecruitopenenvObservation, State]
):
    """Client for the Driver Recruit Environment.

    Serializes actions into step payloads and deserializes server
    responses into typed observations and episode state.
    """

    def _step_payload(self, action: RecruitopenenvAction) -> Dict:
        """Build the JSON body for a step request, omitting unset fields."""
        body: Dict = {"tool": action.tool, "action": action.action}
        # Optional fields are only sent when they carry a real value
        # ("" and job_id < 0 mean "not set").
        if action.topic:
            body["topic"] = action.topic
        if action.job_id >= 0:
            body["job_id"] = action.job_id
        if action.stage:
            body["stage"] = action.stage
        if action.field:
            body["field"] = action.field
        if action.value:
            body["value"] = action.value
        return body

    def _parse_result(self, payload: Dict) -> StepResult[RecruitopenenvObservation]:
        """Convert a raw server response into a typed StepResult."""
        raw_obs = payload.get("observation", {})
        done = payload.get("done", False)
        reward = payload.get("reward", 0.0)

        observation = RecruitopenenvObservation(
            driver_name=raw_obs.get("driver_name", ""),
            crm_summary=raw_obs.get("crm_summary", ""),
            jobs_summary=raw_obs.get("jobs_summary", ""),
            discovered_info=raw_obs.get("discovered_info", ""),
            stage=raw_obs.get("stage", "lead"),
            feedback=raw_obs.get("feedback", ""),
            pending_reply=raw_obs.get("pending_reply", False),
            done=done,
            reward=reward,
        )

        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Convert a raw server response into an episode State snapshot."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
demo/index.html
ADDED
|
@@ -0,0 +1,724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Driver Recruit Environment</title>
|
| 7 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 8 |
+
<link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500;600&family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
|
| 9 |
+
<style>
|
| 10 |
+
*{margin:0;padding:0;box-sizing:border-box}
|
| 11 |
+
:root{
|
| 12 |
+
--bg:#09090b;--s1:#111113;--s2:#18181b;--s3:#27272a;
|
| 13 |
+
--b1:#27272a;--b2:#3f3f46;
|
| 14 |
+
--t1:#fafafa;--t2:#a1a1aa;--t3:#71717a;
|
| 15 |
+
--green:#22c55e;--red:#ef4444;--amber:#f59e0b;--blue:#3b82f6;--violet:#8b5cf6;--rose:#f43f5e;--cyan:#06b6d4;--orange:#f97316;
|
| 16 |
+
}
|
| 17 |
+
body{font-family:'Inter',system-ui,sans-serif;background:var(--bg);color:var(--t1);-webkit-font-smoothing:antialiased}
|
| 18 |
+
.mono{font-family:'IBM Plex Mono',monospace}
|
| 19 |
+
|
| 20 |
+
/* ─── HERO ─── */
|
| 21 |
+
.hero{min-height:100vh;display:flex;flex-direction:column;align-items:center;justify-content:center;padding:40px 24px;position:relative;overflow:hidden}
|
| 22 |
+
.hero::before{content:'';position:absolute;top:-200px;left:50%;transform:translateX(-50%);width:800px;height:800px;background:radial-gradient(circle,rgba(139,92,246,0.06) 0%,transparent 70%);pointer-events:none}
|
| 23 |
+
.hero-eyebrow{font-size:13px;font-weight:500;color:var(--violet);letter-spacing:0.08em;text-transform:uppercase;margin-bottom:20px}
|
| 24 |
+
.hero h1{font-size:clamp(2rem,5vw,3.5rem);font-weight:700;letter-spacing:-0.03em;line-height:1.1;text-align:center;max-width:700px;margin-bottom:16px}
|
| 25 |
+
.hero-sub{color:var(--t2);font-size:17px;line-height:1.6;text-align:center;max-width:560px;margin-bottom:48px}
|
| 26 |
+
|
| 27 |
+
.cards{display:grid;grid-template-columns:repeat(4,1fr);gap:1px;background:var(--b1);border:1px solid var(--b1);border-radius:12px;overflow:hidden;max-width:900px;width:100%;margin-bottom:48px}
|
| 28 |
+
.card{background:var(--s1);padding:24px 20px}
|
| 29 |
+
.card-num{font-family:'IBM Plex Mono',monospace;font-size:12px;color:var(--t3);margin-bottom:10px}
|
| 30 |
+
.card h3{font-size:14px;font-weight:600;margin-bottom:6px}
|
| 31 |
+
.card p{font-size:13px;color:var(--t2);line-height:1.55}
|
| 32 |
+
@media(max-width:800px){.cards{grid-template-columns:1fr 1fr}}
|
| 33 |
+
@media(max-width:500px){.cards{grid-template-columns:1fr}}
|
| 34 |
+
|
| 35 |
+
.btn{display:inline-flex;align-items:center;gap:8px;padding:12px 28px;border-radius:8px;font-size:14px;font-weight:600;cursor:pointer;border:none;transition:all .15s}
|
| 36 |
+
.btn-white{background:var(--t1);color:var(--bg)}
|
| 37 |
+
.btn-white:hover{opacity:.9}
|
| 38 |
+
|
| 39 |
+
.env-input{background:var(--s2);border:1px solid var(--b1);color:var(--t2);padding:8px 12px;border-radius:6px;font-size:13px;font-family:'IBM Plex Mono',monospace;width:240px;margin-top:16px;text-align:center}
|
| 40 |
+
.env-input:focus{outline:none;border-color:var(--b2)}
|
| 41 |
+
|
| 42 |
+
/* ─── GAME ─── */
|
| 43 |
+
.game{display:none;max-width:1280px;margin:0 auto;padding:16px 20px 40px}
|
| 44 |
+
.game.on{display:block}
|
| 45 |
+
|
| 46 |
+
/* Top bar */
|
| 47 |
+
.topbar{display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:1px solid var(--b1);margin-bottom:16px;flex-wrap:wrap;gap:12px}
|
| 48 |
+
.topbar-left{display:flex;align-items:center;gap:16px}
|
| 49 |
+
.avatar{width:36px;height:36px;border-radius:50%;background:var(--s3);display:flex;align-items:center;justify-content:center;font-weight:600;font-size:14px;flex-shrink:0}
|
| 50 |
+
.driver-meta h2{font-size:15px;font-weight:600;line-height:1}
|
| 51 |
+
.pill{display:inline-block;font-size:11px;font-weight:500;padding:2px 8px;border-radius:4px;margin-top:4px}
|
| 52 |
+
.pill-chatty{background:rgba(34,197,94,.12);color:var(--green)}
|
| 53 |
+
.pill-professional{background:rgba(59,130,246,.12);color:var(--blue)}
|
| 54 |
+
.pill-impatient{background:rgba(239,68,68,.12);color:var(--red)}
|
| 55 |
+
.pill-suspicious{background:rgba(244,63,94,.12);color:var(--rose)}
|
| 56 |
+
|
| 57 |
+
.topbar-stats{display:flex;gap:24px;align-items:center}
|
| 58 |
+
.ts{text-align:right}
|
| 59 |
+
.ts-label{font-size:11px;color:var(--t3);text-transform:uppercase;letter-spacing:.05em}
|
| 60 |
+
.ts-val{font-family:'IBM Plex Mono',monospace;font-size:16px;font-weight:600}
|
| 61 |
+
|
| 62 |
+
.trust-wrap{width:140px}
|
| 63 |
+
.trust-track{width:100%;height:4px;background:var(--s3);border-radius:2px;margin-top:4px;overflow:hidden}
|
| 64 |
+
.trust-fill{height:100%;border-radius:2px;transition:width .4s ease}
|
| 65 |
+
.trust-num{display:flex;justify-content:space-between;align-items:center;margin-top:3px}
|
| 66 |
+
.trust-num span{font-family:'IBM Plex Mono',monospace;font-size:11px;color:var(--t3)}
|
| 67 |
+
.trust-num .delta{font-weight:600}
|
| 68 |
+
.delta-up{color:var(--green)}
|
| 69 |
+
.delta-down{color:var(--red)}
|
| 70 |
+
|
| 71 |
+
/* Layout */
|
| 72 |
+
.layout{display:grid;grid-template-columns:260px 1fr 260px;gap:16px;min-height:calc(100vh - 120px)}
|
| 73 |
+
@media(max-width:1000px){.layout{grid-template-columns:1fr}}
|
| 74 |
+
.right-col{display:flex;flex-direction:column;gap:12px}
|
| 75 |
+
|
| 76 |
+
/* Sidebar */
|
| 77 |
+
.sidebar{display:flex;flex-direction:column;gap:12px}
|
| 78 |
+
.pane{background:var(--s1);border:1px solid var(--b1);border-radius:10px;overflow:hidden}
|
| 79 |
+
.pane-head{font-size:11px;font-weight:600;color:var(--t3);text-transform:uppercase;letter-spacing:.06em;padding:12px 14px 8px;display:flex;align-items:center;justify-content:space-between}
|
| 80 |
+
|
| 81 |
+
.job{padding:10px 14px;border-bottom:1px solid var(--b1);font-size:13px;cursor:default;transition:background .1s}
|
| 82 |
+
.job:last-child{border-bottom:none}
|
| 83 |
+
.job:hover{background:var(--s2)}
|
| 84 |
+
.job-id{font-family:'IBM Plex Mono',monospace;color:var(--t3);font-size:11px;font-weight:500}
|
| 85 |
+
.job-co{font-weight:500;margin-left:6px}
|
| 86 |
+
.job-det{color:var(--t2);font-size:12px;margin-top:3px;line-height:1.5}
|
| 87 |
+
.job-warn{color:var(--amber);font-size:11px;margin-top:2px}
|
| 88 |
+
|
| 89 |
+
.info-item{padding:8px 14px;border-bottom:1px solid var(--b1);font-size:13px;line-height:1.5}
|
| 90 |
+
.info-item:last-child{border-bottom:none}
|
| 91 |
+
.info-cat{font-family:'IBM Plex Mono',monospace;font-size:10px;font-weight:600;color:var(--violet);text-transform:uppercase;letter-spacing:.04em}
|
| 92 |
+
.info-empty{padding:14px;color:var(--t3);font-size:13px;font-style:italic}
|
| 93 |
+
|
| 94 |
+
.crm-field{padding:5px 14px;border-bottom:1px solid var(--b1);font-size:12px;display:flex;justify-content:space-between}
|
| 95 |
+
.crm-field:last-child{border-bottom:none}
|
| 96 |
+
.crm-key{color:var(--t3);font-family:'IBM Plex Mono',monospace;font-size:11px}
|
| 97 |
+
.crm-val{color:var(--t1)}
|
| 98 |
+
.crm-empty{padding:14px;color:var(--t3);font-size:13px;font-style:italic}
|
| 99 |
+
|
| 100 |
+
/* Main area */
|
| 101 |
+
.main{display:flex;flex-direction:column;gap:12px}
|
| 102 |
+
|
| 103 |
+
/* Timeline */
|
| 104 |
+
.timeline{flex:1;background:var(--s1);border:1px solid var(--b1);border-radius:10px;padding:16px;overflow-y:auto;max-height:calc(100vh - 320px);min-height:320px}
|
| 105 |
+
.tl-entry{display:flex;gap:12px;margin-bottom:16px;animation:slideIn .25s ease}
|
| 106 |
+
@keyframes slideIn{from{opacity:0;transform:translateY(6px)}to{opacity:1;transform:none}}
|
| 107 |
+
.tl-dot-col{display:flex;flex-direction:column;align-items:center;padding-top:4px}
|
| 108 |
+
.tl-dot{width:8px;height:8px;border-radius:50%;flex-shrink:0}
|
| 109 |
+
.tl-line{width:1px;flex:1;background:var(--b1);margin-top:4px}
|
| 110 |
+
.tl-content{flex:1;min-width:0}
|
| 111 |
+
.tl-head{display:flex;align-items:center;gap:8px;margin-bottom:4px;flex-wrap:wrap}
|
| 112 |
+
.tl-action{font-size:13px;font-weight:600}
|
| 113 |
+
.tl-reward{font-family:'IBM Plex Mono',monospace;font-size:12px;font-weight:500}
|
| 114 |
+
.tl-reward.pos{color:var(--green)}.tl-reward.neg{color:var(--red)}.tl-reward.zero{color:var(--t3)}
|
| 115 |
+
.tl-tool-badge{font-size:10px;font-weight:600;padding:1px 6px;border-radius:3px;text-transform:uppercase;letter-spacing:.04em}
|
| 116 |
+
.badge-crm{background:rgba(6,182,212,.12);color:var(--cyan)}
|
| 117 |
+
.badge-messaging{background:rgba(139,92,246,.12);color:var(--violet)}
|
| 118 |
+
.badge-approval{background:rgba(249,115,22,.12);color:var(--orange)}
|
| 119 |
+
.badge-workflow{background:rgba(113,113,122,.15);color:var(--t3)}
|
| 120 |
+
.tl-body{font-size:13px;color:var(--t2);line-height:1.55;padding:8px 12px;background:var(--s2);border-radius:6px;border-left:2px solid var(--b1)}
|
| 121 |
+
.tl-body.good{border-left-color:var(--green)}
|
| 122 |
+
.tl-body.bad{border-left-color:var(--red)}
|
| 123 |
+
.tl-step{font-family:'IBM Plex Mono',monospace;font-size:10px;color:var(--t3)}
|
| 124 |
+
|
| 125 |
+
/* Tool sections */
|
| 126 |
+
.tool-section{background:var(--s1);border:1px solid var(--b1);border-radius:10px;padding:10px 14px}
|
| 127 |
+
.tool-section-head{font-size:11px;font-weight:600;text-transform:uppercase;letter-spacing:.05em;margin-bottom:8px;display:flex;align-items:center;gap:6px}
|
| 128 |
+
.tool-section-head .dot{width:6px;height:6px;border-radius:50%;display:inline-block}
|
| 129 |
+
|
| 130 |
+
.act-grid{display:flex;flex-wrap:wrap;gap:6px}
|
| 131 |
+
.act{padding:6px 12px;border-radius:6px;font-size:12px;font-weight:500;cursor:pointer;border:1px solid var(--b1);background:var(--s2);color:var(--t2);transition:all .12s;white-space:nowrap}
|
| 132 |
+
.act:hover{background:var(--s3);color:var(--t1);border-color:var(--b2)}
|
| 133 |
+
.act:disabled{opacity:.3;cursor:not-allowed}
|
| 134 |
+
.act-go{border-color:rgba(34,197,94,.3);color:var(--green)}
|
| 135 |
+
.act-go:hover{background:rgba(34,197,94,.1);border-color:var(--green)}
|
| 136 |
+
.act-no{border-color:rgba(239,68,68,.2);color:var(--red)}
|
| 137 |
+
.act-no:hover{background:rgba(239,68,68,.08);border-color:var(--red)}
|
| 138 |
+
.act-warn{border-color:rgba(245,158,11,.2);color:var(--amber)}
|
| 139 |
+
.act-warn:hover{background:rgba(245,158,11,.08);border-color:var(--amber)}
|
| 140 |
+
|
| 141 |
+
/* Pipeline */
|
| 142 |
+
.pipeline{background:var(--s1);border:1px solid var(--b1);border-radius:10px;padding:10px 14px}
|
| 143 |
+
.pipe-stages{display:flex;gap:2px;align-items:center}
|
| 144 |
+
.pipe-stage{font-size:10px;font-weight:600;padding:4px 8px;border-radius:4px;text-transform:uppercase;letter-spacing:.04em;background:var(--s2);color:var(--t3);border:1px solid var(--b1)}
|
| 145 |
+
.pipe-stage.active{background:rgba(139,92,246,.15);color:var(--violet);border-color:var(--violet)}
|
| 146 |
+
.pipe-stage.done{background:rgba(34,197,94,.1);color:var(--green);border-color:rgba(34,197,94,.3)}
|
| 147 |
+
.pipe-stage.fail{background:rgba(239,68,68,.1);color:var(--red);border-color:rgba(239,68,68,.3)}
|
| 148 |
+
.pipe-arrow{color:var(--t3);font-size:10px}
|
| 149 |
+
|
| 150 |
+
/* Pending reply indicator */
|
| 151 |
+
.pending-badge{font-size:11px;font-weight:500;padding:3px 10px;border-radius:4px;background:rgba(139,92,246,.12);color:var(--violet);animation:pulse 1.5s ease infinite}
|
| 152 |
+
@keyframes pulse{0%,100%{opacity:1}50%{opacity:.5}}
|
| 153 |
+
|
| 154 |
+
/* Modal */
|
| 155 |
+
.modal-bg{display:none;position:fixed;inset:0;background:rgba(0,0,0,.6);z-index:100;align-items:center;justify-content:center;backdrop-filter:blur(4px)}
|
| 156 |
+
.modal-bg.on{display:flex}
|
| 157 |
+
.modal{background:var(--s1);border:1px solid var(--b1);border-radius:12px;padding:20px;width:420px;max-width:90vw}
|
| 158 |
+
.modal h3{font-size:14px;font-weight:600;margin-bottom:14px}
|
| 159 |
+
.modal-job{display:block;width:100%;text-align:left;background:var(--s2);border:1px solid var(--b1);color:var(--t1);padding:10px 12px;border-radius:6px;margin-bottom:6px;cursor:pointer;font-size:13px;transition:border-color .12s}
|
| 160 |
+
.modal-job:hover{border-color:var(--violet)}
|
| 161 |
+
.modal-cancel{display:block;width:100%;text-align:center;background:transparent;border:1px solid var(--b1);color:var(--t3);padding:8px;border-radius:6px;cursor:pointer;font-size:13px;margin-top:8px}
|
| 162 |
+
.modal-cancel:hover{color:var(--t2)}
|
| 163 |
+
|
| 164 |
+
/* Input modal */
|
| 165 |
+
.modal input[type="text"]{width:100%;background:var(--s2);border:1px solid var(--b1);color:var(--t1);padding:8px 12px;border-radius:6px;font-size:13px;font-family:'IBM Plex Mono',monospace;margin-bottom:8px}
|
| 166 |
+
.modal input[type="text"]:focus{outline:none;border-color:var(--violet)}
|
| 167 |
+
.modal-row{display:flex;gap:8px;margin-top:8px}
|
| 168 |
+
.modal-row .btn{flex:1;justify-content:center;padding:8px;font-size:13px}
|
| 169 |
+
.modal-btn-go{background:var(--violet);color:white;border:none;padding:8px 16px;border-radius:6px;font-size:13px;font-weight:500;cursor:pointer}
|
| 170 |
+
.modal-btn-go:hover{opacity:.9}
|
| 171 |
+
|
| 172 |
+
/* End screen */
|
| 173 |
+
.endscreen{display:none;position:fixed;inset:0;background:rgba(0,0,0,.75);z-index:200;align-items:center;justify-content:center;backdrop-filter:blur(8px)}
|
| 174 |
+
.endscreen.on{display:flex}
|
| 175 |
+
.end-card{background:var(--s1);border:1px solid var(--b1);border-radius:14px;padding:40px 36px;text-align:center;width:440px;max-width:90vw}
|
| 176 |
+
.end-label{font-size:11px;font-weight:600;text-transform:uppercase;letter-spacing:.1em;margin-bottom:8px}
|
| 177 |
+
.end-title{font-size:28px;font-weight:700;letter-spacing:-0.02em;margin-bottom:6px}
|
| 178 |
+
.end-sub{font-size:13px;color:var(--t2);margin-bottom:28px;line-height:1.5}
|
| 179 |
+
.end-grid{display:grid;grid-template-columns:1fr 1fr 1fr;gap:1px;background:var(--b1);border:1px solid var(--b1);border-radius:8px;overflow:hidden;margin-bottom:24px}
|
| 180 |
+
.end-stat{background:var(--s1);padding:14px 8px}
|
| 181 |
+
.end-stat-val{font-family:'IBM Plex Mono',monospace;font-size:18px;font-weight:600}
|
| 182 |
+
.end-stat-lbl{font-size:11px;color:var(--t3);margin-top:2px}
|
| 183 |
+
|
| 184 |
+
/* Hidden info toggle */
|
| 185 |
+
.hidden-info.off{display:none}
|
| 186 |
+
.toggle-hidden{background:var(--s2);border:1px solid var(--b1);color:var(--t3);padding:4px 10px;border-radius:4px;font-size:11px;cursor:pointer;margin-left:12px}
|
| 187 |
+
.toggle-hidden:hover{color:var(--t2);border-color:var(--b2)}
|
| 188 |
+
|
| 189 |
+
/* Stage select */
|
| 190 |
+
.stage-select{background:var(--s2);border:1px solid var(--b1);color:var(--t1);padding:4px 8px;border-radius:4px;font-size:12px;font-family:'IBM Plex Mono',monospace}
|
| 191 |
+
.stage-select:focus{outline:none;border-color:var(--violet)}
|
| 192 |
+
</style>
|
| 193 |
+
</head>
|
| 194 |
+
<body>
|
| 195 |
+
|
| 196 |
+
<div class="hero" id="hero">
|
| 197 |
+
<div class="hero-eyebrow">OpenEnv Hackathon — Long-Horizon RL</div>
|
| 198 |
+
<h1>Train an AI to recruit truck drivers through tool calls</h1>
|
| 199 |
+
<p class="hero-sub">A multi-turn RL environment where agents use CRM, messaging, approval, and workflow tools across 40-70 step episodes to screen candidates, avoid trap jobs, and close hires.</p>
|
| 200 |
+
|
| 201 |
+
<div class="cards">
|
| 202 |
+
<div class="card">
|
| 203 |
+
<div class="card-num mono">01</div>
|
| 204 |
+
<h3>Tool calling</h3>
|
| 205 |
+
<p>4 tools — CRM, messaging, approval, workflow. The agent must call the right tool with the right action at each step.</p>
|
| 206 |
+
</div>
|
| 207 |
+
<div class="card">
|
| 208 |
+
<div class="card-num mono">02</div>
|
| 209 |
+
<h3>Long horizon</h3>
|
| 210 |
+
<p>Episodes span 40-70 steps through a full recruiting pipeline: lead → contacted → interested → approval → offer → hired.</p>
|
| 211 |
+
</div>
|
| 212 |
+
<div class="card">
|
| 213 |
+
<div class="card-num mono">03</div>
|
| 214 |
+
<h3>Hidden information</h3>
|
| 215 |
+
<p>Driver preferences, deal breakers, and personality are hidden. Must be discovered through screening messages.</p>
|
| 216 |
+
</div>
|
| 217 |
+
<div class="card">
|
| 218 |
+
<div class="card-num mono">04</div>
|
| 219 |
+
<h3>Trap jobs</h3>
|
| 220 |
+
<p>Jobs that look perfect but violate deal breakers. Skip screening and you'll hire for the wrong one — big negative reward.</p>
|
| 221 |
+
</div>
|
| 222 |
+
</div>
|
| 223 |
+
|
| 224 |
+
<button class="btn btn-white" onclick="startGame()">Play the environment</button>
|
| 225 |
+
<input class="env-input" id="envUrl" value="http://localhost:8000" spellcheck="false">
|
| 226 |
+
</div>
|
| 227 |
+
|
| 228 |
+
<div class="game" id="game">
|
| 229 |
+
<div class="topbar">
|
| 230 |
+
<div class="topbar-left">
|
| 231 |
+
<div class="avatar" id="av">?</div>
|
| 232 |
+
<div class="driver-meta">
|
| 233 |
+
<h2 id="dName">---</h2>
|
| 234 |
+
<span class="pill" id="dPers"></span>
|
| 235 |
+
</div>
|
| 236 |
+
</div>
|
| 237 |
+
<div class="topbar-stats">
|
| 238 |
+
<div class="ts">
|
| 239 |
+
<div class="ts-label">Stage</div>
|
| 240 |
+
<div class="ts-val" id="uiStage">lead</div>
|
| 241 |
+
</div>
|
| 242 |
+
<div class="ts">
|
| 243 |
+
<div class="ts-label">Step</div>
|
| 244 |
+
<div class="ts-val"><span id="uiStep">0</span><span style="color:var(--t3)"> / 100</span></div>
|
| 245 |
+
</div>
|
| 246 |
+
<div class="ts">
|
| 247 |
+
<div class="ts-label">Reward</div>
|
| 248 |
+
<div class="ts-val" id="uiRew">0.0</div>
|
| 249 |
+
</div>
|
| 250 |
+
<div id="pendingBadge" style="display:none" class="pending-badge">Unread reply</div>
|
| 251 |
+
<button class="toggle-hidden" onclick="toggleHidden()">Show hidden</button>
|
| 252 |
+
</div>
|
| 253 |
+
</div>
|
| 254 |
+
|
| 255 |
+
<!-- Pipeline -->
|
| 256 |
+
<div class="pipeline">
|
| 257 |
+
<div class="pipe-stages" id="pipeStages"></div>
|
| 258 |
+
</div>
|
| 259 |
+
|
| 260 |
+
<div class="layout" style="margin-top:12px">
|
| 261 |
+
<div class="sidebar">
|
| 262 |
+
<div class="pane" id="jobsPane">
|
| 263 |
+
<div class="pane-head">Jobs</div>
|
| 264 |
+
<div id="jobsList"></div>
|
| 265 |
+
</div>
|
| 266 |
+
<div class="pane">
|
| 267 |
+
<div class="pane-head">CRM Record</div>
|
| 268 |
+
<div id="crmList"><div class="crm-empty">Not loaded — use crm.read_candidate</div></div>
|
| 269 |
+
</div>
|
| 270 |
+
</div>
|
| 271 |
+
<div class="main">
|
| 272 |
+
<div class="timeline" id="tl"></div>
|
| 273 |
+
|
| 274 |
+
<!-- Tool: CRM -->
|
| 275 |
+
<div class="tool-section">
|
| 276 |
+
<div class="tool-section-head"><span class="dot" style="background:var(--cyan)"></span><span style="color:var(--cyan)">CRM</span></div>
|
| 277 |
+
<div class="act-grid">
|
| 278 |
+
<button class="act" onclick="doTool('crm','read_candidate')">read_candidate</button>
|
| 279 |
+
<button class="act" onclick="showStageModal()">update_stage</button>
|
| 280 |
+
<button class="act" onclick="showFieldModal()">update_field</button>
|
| 281 |
+
<button class="act" onclick="showNoteModal()">add_note</button>
|
| 282 |
+
</div>
|
| 283 |
+
</div>
|
| 284 |
+
|
| 285 |
+
<!-- Tool: Messaging -->
|
| 286 |
+
<div class="tool-section">
|
| 287 |
+
<div class="tool-section-head"><span class="dot" style="background:var(--violet)"></span><span style="color:var(--violet)">Messaging</span></div>
|
| 288 |
+
<div class="act-grid" id="msgGrid">
|
| 289 |
+
<button class="act" onclick="doMsg('greeting')">greeting</button>
|
| 290 |
+
<button class="act" onclick="doMsg('call')">call</button>
|
| 291 |
+
<span style="width:1px;height:24px;background:var(--b1)"></span>
|
| 292 |
+
<button class="act" onclick="doMsg('experience')">experience</button>
|
| 293 |
+
<button class="act" onclick="doMsg('home_time')">home time</button>
|
| 294 |
+
<button class="act" onclick="doMsg('pay')">pay</button>
|
| 295 |
+
<button class="act" onclick="doMsg('equipment')">equipment</button>
|
| 296 |
+
<button class="act" onclick="doMsg('route')">route</button>
|
| 297 |
+
<button class="act" onclick="doMsg('deal_breakers')">deal breakers</button>
|
| 298 |
+
<button class="act" onclick="doMsg('availability')">availability</button>
|
| 299 |
+
<button class="act" onclick="doMsg('violations')">violations</button>
|
| 300 |
+
<button class="act" onclick="doMsg('medical_card')">medical card</button>
|
| 301 |
+
<button class="act" onclick="doMsg('references')">references</button>
|
| 302 |
+
<span style="width:1px;height:24px;background:var(--b1)"></span>
|
| 303 |
+
<button class="act act-warn" onclick="showJobModal('pitch')">pitch job</button>
|
| 304 |
+
<button class="act act-warn" onclick="showJobModal('offer')">send offer</button>
|
| 305 |
+
<span style="width:1px;height:24px;background:var(--b1)"></span>
|
| 306 |
+
<button class="act" onclick="doMsg('negotiate_pay')">negotiate pay</button>
|
| 307 |
+
<button class="act" onclick="doMsg('negotiate_home_time')">negotiate home</button>
|
| 308 |
+
<button class="act" onclick="doMsg('signing_bonus')">signing bonus</button>
|
| 309 |
+
<button class="act" onclick="doMsg('address_concern')">address concern</button>
|
| 310 |
+
<span style="width:1px;height:24px;background:var(--b1)"></span>
|
| 311 |
+
<button class="act act-go" onclick="doTool('messaging','read_reply')">read_reply</button>
|
| 312 |
+
</div>
|
| 313 |
+
</div>
|
| 314 |
+
|
| 315 |
+
<!-- Tool: Approval + Workflow -->
|
| 316 |
+
<div style="display:flex;gap:12px">
|
| 317 |
+
<div class="tool-section" style="flex:1">
|
| 318 |
+
<div class="tool-section-head"><span class="dot" style="background:var(--orange)"></span><span style="color:var(--orange)">Approval</span></div>
|
| 319 |
+
<div class="act-grid">
|
| 320 |
+
<button class="act" onclick="showJobModal('request_approval')">request_approval</button>
|
| 321 |
+
<button class="act" onclick="doTool('approval','check_approval')">check_approval</button>
|
| 322 |
+
</div>
|
| 323 |
+
</div>
|
| 324 |
+
<div class="tool-section" style="flex:1">
|
| 325 |
+
<div class="tool-section-head"><span class="dot" style="background:var(--t3)"></span><span style="color:var(--t3)">Workflow</span></div>
|
| 326 |
+
<div class="act-grid">
|
| 327 |
+
<button class="act" onclick="doTool('workflow','wait')">wait</button>
|
| 328 |
+
<button class="act act-go" onclick="showStageModal('hired')">hire (finish)</button>
|
| 329 |
+
<button class="act act-no" onclick="doStage('lost')">reject (lost)</button>
|
| 330 |
+
</div>
|
| 331 |
+
</div>
|
| 332 |
+
</div>
|
| 333 |
+
</div>
|
| 334 |
+
<div class="right-col">
|
| 335 |
+
<div class="pane">
|
| 336 |
+
<div class="pane-head">Discovered Info</div>
|
| 337 |
+
<div id="infoList"><div class="info-empty">No info yet — send messages and read replies</div></div>
|
| 338 |
+
</div>
|
| 339 |
+
</div>
|
| 340 |
+
</div>
|
| 341 |
+
</div>
|
| 342 |
+
|
| 343 |
+
<!-- Job picker modal -->
|
| 344 |
+
<div class="modal-bg" id="modalBg">
|
| 345 |
+
<div class="modal">
|
| 346 |
+
<h3 id="modalTitle">Select job</h3>
|
| 347 |
+
<div id="modalJobs"></div>
|
| 348 |
+
<button class="modal-cancel" onclick="closeModal()">Cancel</button>
|
| 349 |
+
</div>
|
| 350 |
+
</div>
|
| 351 |
+
|
| 352 |
+
<!-- Stage modal -->
|
| 353 |
+
<div class="modal-bg" id="stageModalBg">
|
| 354 |
+
<div class="modal">
|
| 355 |
+
<h3>Update Pipeline Stage</h3>
|
| 356 |
+
<div id="stageModalBtns"></div>
|
| 357 |
+
<button class="modal-cancel" onclick="closeStageModal()">Cancel</button>
|
| 358 |
+
</div>
|
| 359 |
+
</div>
|
| 360 |
+
|
| 361 |
+
<!-- Field modal -->
|
| 362 |
+
<div class="modal-bg" id="fieldModalBg">
|
| 363 |
+
<div class="modal">
|
| 364 |
+
<h3>Update CRM Field</h3>
|
| 365 |
+
<select id="fieldSelect" class="stage-select" style="width:100%;margin-bottom:8px;padding:8px">
|
| 366 |
+
<option value="cdl_class">cdl_class</option>
|
| 367 |
+
<option value="years_experience">years_experience</option>
|
| 368 |
+
<option value="endorsements">endorsements</option>
|
| 369 |
+
<option value="location">location</option>
|
| 370 |
+
<option value="home_time_pref">home_time_pref</option>
|
| 371 |
+
<option value="pay_expectation">pay_expectation</option>
|
| 372 |
+
<option value="equipment_pref">equipment_pref</option>
|
| 373 |
+
<option value="route_pref">route_pref</option>
|
| 374 |
+
<option value="deal_breakers">deal_breakers</option>
|
| 375 |
+
<option value="availability">availability</option>
|
| 376 |
+
<option value="violations">violations</option>
|
| 377 |
+
<option value="medical_card">medical_card</option>
|
| 378 |
+
<option value="references">references</option>
|
| 379 |
+
<option value="matched_job">matched_job</option>
|
| 380 |
+
</select>
|
| 381 |
+
<input type="text" id="fieldValue" placeholder="Value..." />
|
| 382 |
+
<div class="modal-row">
|
| 383 |
+
<button class="modal-btn-go" onclick="submitField()">Save</button>
|
| 384 |
+
<button class="modal-cancel" onclick="closeFieldModal()" style="margin-top:0">Cancel</button>
|
| 385 |
+
</div>
|
| 386 |
+
</div>
|
| 387 |
+
</div>
|
| 388 |
+
|
| 389 |
+
<!-- Note modal -->
|
| 390 |
+
<div class="modal-bg" id="noteModalBg">
|
| 391 |
+
<div class="modal">
|
| 392 |
+
<h3>Add CRM Note</h3>
|
| 393 |
+
<input type="text" id="noteValue" placeholder="Note text..." />
|
| 394 |
+
<div class="modal-row">
|
| 395 |
+
<button class="modal-btn-go" onclick="submitNote()">Add</button>
|
| 396 |
+
<button class="modal-cancel" onclick="closeNoteModal()" style="margin-top:0">Cancel</button>
|
| 397 |
+
</div>
|
| 398 |
+
</div>
|
| 399 |
+
</div>
|
| 400 |
+
|
| 401 |
+
<!-- End screen -->
|
| 402 |
+
<div class="endscreen" id="endscreen">
|
| 403 |
+
<div class="end-card">
|
| 404 |
+
<div class="end-label" id="endLabel"></div>
|
| 405 |
+
<div class="end-title" id="endTitle"></div>
|
| 406 |
+
<div class="end-sub" id="endSub"></div>
|
| 407 |
+
<div class="end-grid">
|
| 408 |
+
<div class="end-stat"><div class="end-stat-val" id="erRew"></div><div class="end-stat-lbl">Reward</div></div>
|
| 409 |
+
<div class="end-stat"><div class="end-stat-val" id="erStep"></div><div class="end-stat-lbl">Steps</div></div>
|
| 410 |
+
<div class="end-stat"><div class="end-stat-val" id="erStage"></div><div class="end-stat-lbl">Final Stage</div></div>
|
| 411 |
+
</div>
|
| 412 |
+
<button class="btn btn-white" onclick="startGame()">Play again</button>
|
| 413 |
+
</div>
|
| 414 |
+
</div>
|
| 415 |
+
|
| 416 |
+
<script>
|
| 417 |
+
let ENV='',WS=null;
|
| 418 |
+
let S={obs:null,rew:0,done:false,jobs:[],stepCount:0};
|
| 419 |
+
let showHidden=false;
|
| 420 |
+
|
| 421 |
+
const STAGES=['lead','contacted','interested','approval_pending','offer_sent','hired'];
|
| 422 |
+
const FAIL_STAGES=['lost','ghosted'];
|
| 423 |
+
|
| 424 |
+
// Flip the global hidden-info visibility flag and sync every
// .hidden-info panel plus the toggle button's label to match.
function toggleHidden(){
  showHidden = !showHidden;
  const panels = document.querySelectorAll('.hidden-info');
  for (const panel of panels) {
    panel.classList.toggle('off', !showHidden);
  }
  const btn = document.querySelector('.toggle-hidden');
  btn.textContent = showHidden ? 'Hide hidden' : 'Show hidden';
}
|
| 429 |
+
|
| 430 |
+
// Derive the WebSocket endpoint from the env URL input:
// strip one trailing slash, swap the http(s) scheme for ws(s),
// and append the '/ws' path.
function wsUrl(){
  let base = document.getElementById('envUrl').value;
  if (base.endsWith('/')) {
    base = base.slice(0, -1);
  }
  return base.replace(/^http/, 'ws') + '/ws';
}
|
| 434 |
+
|
| 435 |
+
// Open (or reuse) the WebSocket connection to the environment server.
// Resolves once the socket is OPEN; rejects only if the initial connect
// fails. All incoming messages are demultiplexed here: each 'observation'
// payload settles the single in-flight request promise parked in the
// module-level `pendingResolve` slot (see wsSend).
function connectWS(){
  return new Promise((resolve,reject)=>{
    // Fast path: an already-open socket is reused as-is.
    if(WS&&WS.readyState===WebSocket.OPEN){resolve();return}
    if(WS)WS.close();
    WS=new WebSocket(wsUrl());
    WS.onopen=()=>resolve();
    WS.onerror=()=>reject(new Error('WebSocket connection failed'));
    WS.onmessage=(ev)=>{
      const msg=JSON.parse(ev.data);
      if(msg.type==='error'){
        console.error('WS error:',msg.data);
        // NOTE(review): clearing pendingResolve without settling it leaves
        // the caller's `await wsSend(...)` hung forever on a server error —
        // confirm whether an error message should resolve/reject the waiter
        // instead (would require wsSend to also park a reject callback).
        if(pendingResolve){pendingResolve=null;}
        return;
      }
      if(msg.type==='observation'&&pendingResolve){
        // Clear the slot before invoking the callback so a re-entrant
        // wsSend inside the callback can safely register a new waiter.
        const cb=pendingResolve;pendingResolve=null;
        cb(msg.data);
      }
    };
    // Drop the handle on close so the next connectWS builds a fresh socket.
    WS.onclose=()=>{WS=null};
  });
}
|
| 457 |
+
|
| 458 |
+
let pendingResolve=null;
|
| 459 |
+
// Serialize and send one message over the open socket, returning a promise
// that settles when the server's next 'observation' arrives (it is resolved
// by the connectWS message handler). Only one request may be in flight at a
// time: the resolver is parked in the module-level `pendingResolve` slot.
function wsSend(msg){
  const payload = JSON.stringify(msg);
  return new Promise((resolve) => {
    pendingResolve = resolve;
    WS.send(payload);
  });
}
|
| 465 |
+
|
| 466 |
+
// Begin a fresh episode: swap the hero screen for the game UI, clear the
// timeline and per-episode state, then connect and reset the environment.
// On any connection failure, alert the user and restore the hero screen.
async function startGame(){
  const el = (id) => document.getElementById(id);
  el('hero').style.display = 'none';
  el('game').classList.add('on');
  el('endscreen').classList.remove('on');
  el('tl').innerHTML = '';
  S = {obs: null, rew: 0, done: false, jobs: [], stepCount: 0};
  try {
    await connectWS();
    const first = await wsSend({type: 'reset'});
    handle(first, null);
  } catch (err) {
    alert('Cannot reach server: ' + err.message);
    el('hero').style.display = '';
    el('game').classList.remove('on');
  }
}
|
| 482 |
+
|
| 483 |
+
// --- Tool actions ---
|
| 484 |
+
// Execute one environment step for the given tool/action, merging any
// optional extra payload fields into the step data. No-op once the
// episode is done or when no socket exists.
async function doTool(tool,action,extra){
  if (S.done) return;
  if (!WS) return;
  const data = Object.assign({tool, action}, extra || {});
  const result = await wsSend({type: 'step', data});
  handle(result, `${tool}.${action}`, data);
}
|
| 490 |
+
|
| 491 |
+
// Shorthand for a messaging.send_message step on `topic`, optionally
// scoped to a specific job id when one is supplied.
async function doMsg(topic,jobId){
  const extra = jobId === undefined ? {topic} : {topic, job_id: jobId};
  await doTool('messaging', 'send_message', extra);
}
|
| 496 |
+
|
| 497 |
+
// Shorthand for a crm.update_stage step to the given pipeline stage.
async function doStage(stage){
  const payload = {stage: stage};
  await doTool('crm', 'update_stage', payload);
}
|
| 500 |
+
|
| 501 |
+
// --- Handle response ---
|
| 502 |
+
// Ingest one server response: accumulate reward into S, sync the step
// count (server-authoritative when `steps_taken` is present, otherwise
// counted locally per labelled action), re-render the UI, and schedule
// the end screen shortly after the episode finishes.
function handle(d,label,actionData){
  const obs = d.observation;
  const reward = d.reward || 0;
  S.obs = obs;
  S.rew += reward;
  S.done = d.done;
  if (obs.steps_taken !== undefined) {
    S.stepCount = obs.steps_taken;
  } else if (label) {
    S.stepCount += 1;
  }
  render(obs, reward, label, actionData);
  if (d.done) {
    setTimeout(() => showEnd(obs), 500);
  }
}
|
| 510 |
+
|
| 511 |
+
function render(o,rw,label,actionData){
|
| 512 |
+
// Driver info
|
| 513 |
+
document.getElementById('dName').textContent=o.driver_name;
|
| 514 |
+
document.getElementById('av').textContent=o.driver_name?o.driver_name[0]:'?';
|
| 515 |
+
|
| 516 |
+
// Stage
|
| 517 |
+
document.getElementById('uiStage').textContent=o.stage;
|
| 518 |
+
document.getElementById('uiStep').textContent=S.stepCount;
|
| 519 |
+
|
| 520 |
+
// Reward
|
| 521 |
+
const re=document.getElementById('uiRew');
|
| 522 |
+
re.textContent=(S.rew>=0?'+':'')+S.rew.toFixed(1);
|
| 523 |
+
re.style.color=S.rew>=0?'var(--green)':'var(--red)';
|
| 524 |
+
|
| 525 |
+
// Pending reply
|
| 526 |
+
document.getElementById('pendingBadge').style.display=o.pending_reply?'':'none';
|
| 527 |
+
|
| 528 |
+
// Pipeline
|
| 529 |
+
renderPipeline(o.stage);
|
| 530 |
+
|
| 531 |
+
// Jobs
|
| 532 |
+
if(o.jobs_summary){
|
| 533 |
+
const lines=o.jobs_summary.split('\n');
|
| 534 |
+
document.getElementById('jobsList').innerHTML=lines.map(l=>{
|
| 535 |
+
const fm=l.match(/\[(.+?)\]/);
|
| 536 |
+
const warn=fm?'<div class="job-warn">'+fm[1]+'</div>':'';
|
| 537 |
+
const parts=l.split(' \u2014 ');
|
| 538 |
+
const hd=parts[0]||'';
|
| 539 |
+
const det=parts[1]||'';
|
| 540 |
+
const im=hd.match(/^Job (\d+): (.+)/);
|
| 541 |
+
return '<div class="job"><span class="job-id">#'+((im&&im[1])||'?')+'</span><span class="job-co">'+((im&&im[2])||hd)+'</span><div class="job-det">'+det+'</div>'+warn+'</div>';
|
| 542 |
+
}).join('');
|
| 543 |
+
S.jobs=lines.map(l=>{const m=l.match(/^Job (\d+): (.+?) \u2014/);return m?{id:+m[1],label:'#'+m[1]+' '+m[2]}:null}).filter(Boolean);
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
// CRM
|
| 547 |
+
if(o.crm_summary){
|
| 548 |
+
const lines=o.crm_summary.split('\n');
|
| 549 |
+
let html='';
|
| 550 |
+
lines.forEach(l=>{
|
| 551 |
+
const fieldMatch=l.match(/^\s{2}(\w+):\s*(.+)/);
|
| 552 |
+
if(fieldMatch){
|
| 553 |
+
html+='<div class="crm-field"><span class="crm-key">'+fieldMatch[1]+'</span><span class="crm-val">'+fieldMatch[2]+'</span></div>';
|
| 554 |
+
} else if(l.startsWith('Name:')||l.startsWith('Stage:')){
|
| 555 |
+
html+='<div class="crm-field"><span class="crm-key">'+l.split(':')[0]+'</span><span class="crm-val">'+l.split(':').slice(1).join(':').trim()+'</span></div>';
|
| 556 |
+
} else if(l.trim()==='Fields: (none recorded)'){
|
| 557 |
+
html+='<div class="crm-field"><span class="crm-key" style="color:var(--t3)">no fields recorded</span></div>';
|
| 558 |
+
} else if(l.match(/^\s{2}-\s(.+)/)){
|
| 559 |
+
html+='<div class="crm-field"><span class="crm-key">note</span><span class="crm-val" style="font-style:italic">'+l.match(/^\s{2}-\s(.+)/)[1]+'</span></div>';
|
| 560 |
+
}
|
| 561 |
+
});
|
| 562 |
+
document.getElementById('crmList').innerHTML=html||'<div class="crm-empty">Empty CRM</div>';
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
// Discovered info
|
| 566 |
+
if(o.discovered_info){
|
| 567 |
+
const items=o.discovered_info.split('\n').filter(l=>l.trim());
|
| 568 |
+
document.getElementById('infoList').innerHTML=items.map(l=>{
|
| 569 |
+
const m=l.match(/^\[(.+?)\]\s*(.*)/);
|
| 570 |
+
if(m)return '<div class="info-item"><span class="info-cat">'+m[1]+'</span><br>'+m[2]+'</div>';
|
| 571 |
+
return '<div class="info-item">'+l+'</div>';
|
| 572 |
+
}).join('');
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
// Timeline
|
| 576 |
+
if(label){
|
| 577 |
+
const tl=document.getElementById('tl');
|
| 578 |
+
const rwClass=rw>0?'pos':rw<0?'neg':'zero';
|
| 579 |
+
const rwStr=rw>=0?'+'+rw.toFixed(1):rw.toFixed(1);
|
| 580 |
+
const dotColor=rw>0?'var(--green)':rw<0?'var(--red)':'var(--b2)';
|
| 581 |
+
const bodyClass=rw>0?'good':rw<0?'bad':'';
|
| 582 |
+
|
| 583 |
+
// Tool badge
|
| 584 |
+
let toolName='';
|
| 585 |
+
if(actionData&&actionData.tool)toolName=actionData.tool;
|
| 586 |
+
else if(label.includes('.'))toolName=label.split('.')[0];
|
| 587 |
+
const badgeClass={'crm':'badge-crm','messaging':'badge-messaging','approval':'badge-approval','workflow':'badge-workflow'}[toolName]||'badge-workflow';
|
| 588 |
+
const badge=toolName?'<span class="tl-tool-badge '+badgeClass+'">'+toolName+'</span>':'';
|
| 589 |
+
|
| 590 |
+
// Parse feedback for display
|
| 591 |
+
let feedbackText='';
|
| 592 |
+
if(o.feedback){
|
| 593 |
+
try{
|
| 594 |
+
const fb=JSON.parse(o.feedback);
|
| 595 |
+
if(fb.reply)feedbackText=fb.reply;
|
| 596 |
+
else if(fb.message)feedbackText=fb.message;
|
| 597 |
+
else if(fb.error)feedbackText='Error: '+fb.error;
|
| 598 |
+
else if(fb.result)feedbackText='Result: '+fb.result+(fb.reason?' ('+fb.reason+')':'');
|
| 599 |
+
else if(fb.approval_status)feedbackText='Approval: '+fb.approval_status;
|
| 600 |
+
else if(fb.stage)feedbackText='Stage updated: '+fb.stage;
|
| 601 |
+
else if(fb.field)feedbackText=fb.field+' = '+fb.value;
|
| 602 |
+
else if(fb.elapsed)feedbackText='Time elapsed: '+fb.elapsed;
|
| 603 |
+
else feedbackText=o.feedback;
|
| 604 |
+
}catch(e){feedbackText=o.feedback}
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
let html='<div class="tl-entry"><div class="tl-dot-col"><div class="tl-dot" style="background:'+dotColor+'"></div><div class="tl-line"></div></div><div class="tl-content"><div class="tl-head"><span class="tl-step mono">'+S.stepCount+'</span>'+badge+'<span class="tl-action">'+label+'</span><span class="tl-reward '+rwClass+'">'+rwStr+'</span></div>';
|
| 608 |
+
if(feedbackText)html+='<div class="tl-body '+bodyClass+'">'+feedbackText+'</div>';
|
| 609 |
+
html+='</div></div>';
|
| 610 |
+
tl.innerHTML+=html;
|
| 611 |
+
tl.scrollTop=tl.scrollHeight;
|
| 612 |
+
} else if(o.feedback){
|
| 613 |
+
const tl=document.getElementById('tl');
|
| 614 |
+
let feedbackText='';
|
| 615 |
+
try{
|
| 616 |
+
const fb=JSON.parse(o.feedback);
|
| 617 |
+
feedbackText='New episode: '+o.driver_name+' — '+fb.jobs+' jobs available';
|
| 618 |
+
}catch(e){feedbackText=o.feedback}
|
| 619 |
+
tl.innerHTML+='<div class="tl-entry"><div class="tl-dot-col"><div class="tl-dot" style="background:var(--violet)"></div><div class="tl-line"></div></div><div class="tl-content"><div class="tl-head"><span class="tl-step mono">0</span><span class="tl-action">Episode start</span></div><div class="tl-body">'+feedbackText+'</div></div></div>';
|
| 620 |
+
}
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
// Render the pipeline stage strip. Stages before the current one are marked
// done, the current one active; if the episode landed on a failure stage
// (per FAIL_STAGES) every stage is marked fail and the failure stage is
// appended at the end of the strip.
function renderPipeline(currentStage){
  const container=document.getElementById('pipeStages');
  const failStage=FAIL_STAGES.includes(currentStage)?currentStage:null;
  const activeIdx=STAGES.indexOf(currentStage);
  const pieces=[];
  STAGES.forEach((stage,idx)=>{
    let cls='pipe-stage';
    if(failStage){cls+=' fail'}
    else if(idx<activeIdx)cls+=' done';
    else if(idx===activeIdx)cls+=' active';
    pieces.push('<span class="'+cls+'">'+stage.replace('_',' ')+'</span>');
    if(idx<STAGES.length-1)pieces.push('<span class="pipe-arrow">→</span>');
  });
  if(failStage){
    pieces.push('<span class="pipe-arrow">→</span><span class="pipe-stage fail">'+failStage+'</span>');
  }
  container.innerHTML=pieces.join('');
}
|
| 641 |
+
|
| 642 |
+
// --- Modals ---
|
| 643 |
+
let pendingModalAction='';
|
| 644 |
+
|
| 645 |
+
// Open the job-selection modal for a pending action ('pitch', 'offer' or
// 'request_approval'); remembers the action so selJob can dispatch it later.
function showJobModal(action){
  const titles={'pitch':'Pitch which job?','offer':'Send offer for which job?','request_approval':'Request approval for which job?'};
  pendingModalAction=action;
  const jobButtons=S.jobs.map(j=>'<button class="modal-job" onclick="selJob('+j.id+')">'+j.label+'</button>');
  document.getElementById('modalTitle').textContent=titles[action]||'Select job';
  document.getElementById('modalJobs').innerHTML=jobButtons.join('');
  document.getElementById('modalBg').classList.add('on');
}
|
| 652 |
+
// Callback for a job button: close the modal, then dispatch the action that
// was pending when the modal was opened.
function selJob(id){
  closeModal();
  switch(pendingModalAction){
    case 'pitch':
    case 'offer':
      doMsg(pendingModalAction,id);
      break;
    case 'request_approval':
      doTool('approval','request_approval',{job_id:id});
      break;
  }
}
|
| 660 |
+
// Hide the job-selection modal overlay.
function closeModal(){document.getElementById('modalBg').classList.remove('on')}
|
| 661 |
+
|
| 662 |
+
// Open the stage-picker modal with one button per pipeline stage.
// `preselect` is accepted but currently unused.
// NOTE(review): the hired/lost `cls` strings deliberately break out of the
// class="" attribute to inject an inline style= — fragile, do not "fix" the
// quoting without updating the template below.
function showStageModal(preselect){
  const stages=['contacted','interested','approval_pending','offer_sent','hired','lost'];
  document.getElementById('stageModalBtns').innerHTML=stages.map(s=>{
    const cls=s==='hired'?'modal-job" style="border-color:var(--green);color:var(--green)':s==='lost'?'modal-job" style="border-color:var(--red);color:var(--red)':'modal-job';
    return '<button class="'+cls+'" onclick="doStage(\''+s+'\');closeStageModal()">'+s.replace('_',' ')+'</button>';
  }).join('');
  document.getElementById('stageModalBg').classList.add('on');
}
|
| 670 |
+
// Hide the stage-picker modal overlay.
function closeStageModal(){document.getElementById('stageModalBg').classList.remove('on')}
|
| 671 |
+
|
| 672 |
+
// Open the CRM field editor: clear the value input and give it focus.
function showFieldModal(){document.getElementById('fieldModalBg').classList.add('on');document.getElementById('fieldValue').value='';document.getElementById('fieldValue').focus()}
|
| 673 |
+
// Hide the CRM field editor overlay.
function closeFieldModal(){document.getElementById('fieldModalBg').classList.remove('on')}
|
| 674 |
+
// Submit the CRM field editor: a blank value is ignored; otherwise close the
// modal and fire a crm.update_field action with the selected field name.
function submitField(){
  const fieldName=document.getElementById('fieldSelect').value;
  const fieldValue=document.getElementById('fieldValue').value;
  if(fieldValue){
    closeFieldModal();
    doTool('crm','update_field',{field:fieldName,value:fieldValue});
  }
}
|
| 681 |
+
|
| 682 |
+
// Open the add-note modal: clear the note input and give it focus.
function showNoteModal(){document.getElementById('noteModalBg').classList.add('on');document.getElementById('noteValue').value='';document.getElementById('noteValue').focus()}
|
| 683 |
+
// Hide the add-note modal overlay.
function closeNoteModal(){document.getElementById('noteModalBg').classList.remove('on')}
|
| 684 |
+
// Submit the add-note modal: blank notes are ignored; otherwise close the
// modal and fire a crm.add_note action.
function submitNote(){
  const noteText=document.getElementById('noteValue').value;
  if(noteText){
    closeNoteModal();
    doTool('crm','add_note',{value:noteText});
  }
}
|
| 690 |
+
|
| 691 |
+
// --- End screen ---
|
| 692 |
+
// Show the end-of-episode overlay. `o` is the final observation; a stage of
// 'hired' is the win condition and drives both the label text and colors.
// The subtitle is assembled from the (JSON) feedback payload when parseable.
function showEnd(o){
  const e=document.getElementById('endscreen');e.classList.add('on');
  const win=o.stage==='hired';
  document.getElementById('endLabel').textContent=win?'DRIVER HIRED':'EPISODE ENDED';
  document.getElementById('endLabel').style.color=win?'var(--green)':'var(--red)';
  document.getElementById('endTitle').textContent=win?'Placement complete':o.stage==='ghosted'?'Driver ghosted':'Failed';
  document.getElementById('endTitle').style.color=win?'var(--green)':'var(--t1)';

  // Build the subtitle from feedback fields: result/reason (underscores made
  // readable), then optional fit score and CRM bonus. Falls back to the raw
  // feedback string when it is not JSON.
  let subText='';
  if(o.feedback){
    try{
      const fb=JSON.parse(o.feedback);
      if(fb.reason)subText=fb.reason.replace(/_/g,' ');
      if(fb.result)subText=fb.result.replace(/_/g,' ')+(subText?' — '+subText:'');
      if(fb.score)subText+=' (fit score: '+fb.score+')';
      if(fb.crm_bonus)subText+=' CRM bonus: +'+fb.crm_bonus;
    }catch(e){subText=o.feedback} // NOTE: shadows the outer `e` element binding
  }
  // NOTE(review): feedback is inserted via innerHTML unescaped — fine for a
  // trusted local env server, not for untrusted input.
  document.getElementById('endSub').innerHTML=subText;

  // Final scoreboard: cumulative reward (signed, colored), step count, stage.
  const rv=document.getElementById('erRew');
  rv.textContent=(S.rew>=0?'+':'')+S.rew.toFixed(1);
  rv.style.color=S.rew>=0?'var(--green)':'var(--red)';
  document.getElementById('erStep').textContent=S.stepCount;
  document.getElementById('erStage').textContent=o.stage;
}
|
| 718 |
+
|
| 719 |
+
// Enter to submit in modals: pressing Enter in the field/note inputs triggers
// the same handlers as the submit buttons.
document.getElementById('fieldValue').addEventListener('keydown',e=>{if(e.key==='Enter')submitField()});
document.getElementById('noteValue').addEventListener('keydown',e=>{if(e.key==='Enter')submitNote()});
|
| 722 |
+
</script>
|
| 723 |
+
</body>
|
| 724 |
+
</html>
|
eval_trained.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluate a trained model against the recruiting environment."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
+
import torch
|
| 7 |
+
from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
|
| 8 |
+
|
| 9 |
+
SYSTEM_PROMPT = """You are a truck driver recruiter using a CRM system. You only know the driver's name. You must discover their qualifications through conversation, record info in the CRM, get approval, and hire them.
|
| 10 |
+
|
| 11 |
+
You have 4 tools:
|
| 12 |
+
|
| 13 |
+
## crm
|
| 14 |
+
- read_candidate: Read the current CRM record
|
| 15 |
+
- update_stage: Advance pipeline (contacted → interested → approval_pending → offer_sent → hired)
|
| 16 |
+
- update_field: Record info (field + value)
|
| 17 |
+
- add_note: Add a free-text note
|
| 18 |
+
|
| 19 |
+
## messaging
|
| 20 |
+
- send_message: Send a message (topic: greeting, call, experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references, pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern)
|
| 21 |
+
- read_reply: Read the driver's response
|
| 22 |
+
|
| 23 |
+
## approval
|
| 24 |
+
- request_approval: Request approval for a job (needs job_id)
|
| 25 |
+
- check_approval: Check approval status
|
| 26 |
+
|
| 27 |
+
## workflow
|
| 28 |
+
- wait: Advance time (needed for approval processing)
|
| 29 |
+
|
| 30 |
+
## Rules
|
| 31 |
+
- Must read CRM before messaging
|
| 32 |
+
- Must read_reply before sending another message
|
| 33 |
+
- Must request_approval and wait before sending offer
|
| 34 |
+
- Must follow stage order: lead → contacted → interested → approval_pending → offer_sent → hired
|
| 35 |
+
- Record important info in CRM with update_field
|
| 36 |
+
|
| 37 |
+
Respond with ONLY JSON:
|
| 38 |
+
{"tool": "crm", "action": "read_candidate"}
|
| 39 |
+
{"tool": "messaging", "action": "send_message", "topic": "experience"}
|
| 40 |
+
{"tool": "messaging", "action": "read_reply"}
|
| 41 |
+
{"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
|
| 42 |
+
{"tool": "crm", "action": "update_stage", "stage": "contacted"}
|
| 43 |
+
{"tool": "approval", "action": "request_approval", "job_id": 2}
|
| 44 |
+
{"tool": "workflow", "action": "wait"}
|
| 45 |
+
{"tool": "approval", "action": "check_approval"}
|
| 46 |
+
{"tool": "messaging", "action": "send_message", "topic": "offer", "job_id": 2}
|
| 47 |
+
{"tool": "crm", "action": "update_stage", "stage": "hired"}"""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def format_observation(obs):
    """Render an observation as the plain-text user turn shown to the model.

    Always includes the driver name and stage; the CRM, jobs, discovered-info
    and feedback sections appear only when non-empty, and a PENDING REPLY
    marker is appended to the stage line when a message is waiting.
    """
    lines = [f"Driver: {obs.driver_name}"]
    # Optional sections, in fixed display order; empty ones are omitted.
    for label, text in (
        ("CRM", obs.crm_summary),
        ("Jobs", obs.jobs_summary),
        ("Discovered", obs.discovered_info),
    ):
        if text:
            lines.append(f"{label}:\n{text}")
    stage_line = f"Stage: {obs.stage}"
    if obs.pending_reply:
        stage_line += " | PENDING REPLY"
    lines.append(stage_line)
    if obs.feedback:
        lines.append(f"Result: {obs.feedback}")
    return "\n".join(lines)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def parse_action(text):
    """Parse a model completion into a RecruitopenenvAction.

    Accepts a bare JSON object, a JSON object inside a ``` fenced block
    (optionally tagged "json"), or a JSON list (first element used). When no
    valid JSON action is found, falls back to keyword matching on the text,
    and finally defaults to crm.read_candidate.
    """
    text = text.strip()
    # Pull the JSON object out of a fenced code block, if present.
    if "```" in text:
        for segment in text.split("```"):
            segment = segment.strip()
            if segment.startswith("json"):
                segment = segment[4:].strip()
            if segment.startswith("{"):
                text = segment
                break
    try:
        payload = json.loads(text)
        if isinstance(payload, list):
            payload = payload[0] if payload else {}
        if isinstance(payload, dict) and "tool" in payload and "action" in payload:
            return RecruitopenenvAction(
                tool=payload["tool"],
                action=payload["action"],
                topic=payload.get("topic", ""),
                job_id=payload.get("job_id", -1),
                stage=payload.get("stage", ""),
                field=payload.get("field", ""),
                value=payload.get("value", ""),
            )
    except (json.JSONDecodeError, KeyError, IndexError):
        pass

    # Keyword fallback for non-JSON completions; order matters ("wait" last,
    # since it is the loosest match).
    lowered = text.lower()
    for keyword, (tool, action) in (
        ("read_candidate", ("crm", "read_candidate")),
        ("read_reply", ("messaging", "read_reply")),
        ("check_approval", ("approval", "check_approval")),
        ("wait", ("workflow", "wait")),
    ):
        if keyword in lowered:
            return RecruitopenenvAction(tool=tool, action=action)

    # Nothing recognizable: safe no-op default.
    return RecruitopenenvAction(tool="crm", action="read_candidate")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def generate(model, tokenizer, messages, device):
    """Generate one assistant turn for the given chat history.

    Applies the tokenizer's chat template, runs greedy-ish sampling
    (temperature 0.1, up to 128 new tokens) under no_grad, and returns only
    the newly generated text (prompt tokens stripped), decoded without
    special tokens.
    """
    prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.1,  # near-deterministic but still sampled
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt: everything past the input length is new output.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def main():
    """Roll out episodes against the env server and report eval metrics.

    Loads the trained model (and optionally the base model with --compare),
    runs --num-episodes episodes of at most 100 steps each, and prints
    per-step actions, per-episode outcomes, and summary stats (avg/min/max
    reward, hire rate, avg steps).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="./recruit-grpo-output", help="Path to trained model")
    parser.add_argument("--base-model", default="Qwen/Qwen2.5-1.5B-Instruct", help="Base model for comparison")
    parser.add_argument("--env-url", default="http://localhost:8001")
    parser.add_argument("--num-episodes", type=int, default=20)
    parser.add_argument("--compare", action="store_true", help="Also run base model for comparison")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Evaluate the trained checkpoint first, then (optionally) the base model.
    models_to_eval = [("TRAINED", args.model)]
    if args.compare:
        models_to_eval.append(("BASE", args.base_model))

    for label, model_path in models_to_eval:
        print(f"\n{'='*50}")
        print(f"Evaluating: {label} ({model_path})")
        print(f"{'='*50}")

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto"
        )

        rewards = []      # per-episode cumulative reward
        successes = 0     # episodes ending in stage == "hired"
        total_steps = 0

        with RecruitopenenvEnv(base_url=args.env_url) as env:
            for ep in range(args.num_episodes):
                result = env.reset()
                obs = result.observation
                ep_reward = 0.0
                steps = 0
                # Fresh chat history per episode; grows with every turn.
                messages = [{"role": "system", "content": SYSTEM_PROMPT}]

                while not result.done and steps < 100:  # hard step cap
                    obs_text = format_observation(obs)
                    messages.append({"role": "user", "content": obs_text})

                    response = generate(model, tokenizer, messages, device)
                    messages.append({"role": "assistant", "content": response})

                    action = parse_action(response)
                    result = env.step(action)
                    obs = result.observation
                    ep_reward += result.reward
                    steps += 1

                    print(f" Step {steps}: {action.tool}.{action.action}"
                          f"{'(' + action.topic + ')' if action.topic else ''}"
                          f"{'[job=' + str(action.job_id) + ']' if action.job_id >= 0 else ''}"
                          f" -> reward={result.reward:.1f}")

                rewards.append(ep_reward)
                total_steps += steps
                hired = obs.stage == "hired"
                if hired:
                    successes += 1

                print(f"Episode {ep+1}: reward={ep_reward:.1f}, steps={steps}, "
                      f"{'HIRED' if hired else 'FAIL (' + obs.stage + ')'}")
                print()

        avg_reward = sum(rewards) / len(rewards)
        avg_steps = total_steps / args.num_episodes

        print(f"\n{'='*40}")
        print(f" {label} RESULTS")
        print(f"{'='*40}")
        print(f"Model: {model_path}")
        print(f"Episodes: {args.num_episodes}")
        print(f"Avg reward: {avg_reward:.2f}")
        print(f"Min reward: {min(rewards):.2f}")
        print(f"Max reward: {max(rewards):.2f}")
        print(f"Hire rate: {successes}/{args.num_episodes} ({100*successes/args.num_episodes:.1f}%)")
        print(f"Avg steps/episode: {avg_steps:.1f}")
        print(f"{'='*40}")

        # Free GPU memory before loading the next model in the comparison.
        del model
        torch.cuda.empty_cache()
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
models.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data models for the Driver Recruit Environment.
|
| 3 |
+
|
| 4 |
+
Tool-based action interface for long-horizon recruiting pipeline.
|
| 5 |
+
Agent uses CRM, messaging, approval, and workflow tools.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from pydantic import Field
|
| 9 |
+
|
| 10 |
+
from openenv.core.env_server.types import Action, Observation
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RecruitopenenvAction(Action):
    """Tool-based action the agent takes.

    Exactly `tool` and `action` are required; the remaining fields are
    action-specific arguments with neutral defaults ("" / -1 meaning
    "not provided"). Field descriptions double as the schema shown to agents.
    """

    # Which tool to invoke.
    tool: str = Field(
        ...,
        description="Tool: crm, messaging, approval, workflow",
    )
    # Operation within the chosen tool.
    action: str = Field(
        ...,
        description=(
            "Action within tool. "
            "crm: read_candidate, update_stage, update_field, add_note. "
            "messaging: send_message, read_reply. "
            "approval: request_approval, check_approval. "
            "workflow: wait."
        ),
    )
    # Argument for messaging.send_message only.
    topic: str = Field(
        default="",
        description=(
            "Message topic for messaging.send_message: "
            "greeting, call, experience, home_time, pay, equipment, route, "
            "deal_breakers, availability, violations, medical_card, references, "
            "pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern"
        ),
    )
    # Argument for pitch/offer messages and approval requests; -1 = unset.
    job_id: int = Field(
        default=-1,
        description="Job index (0-5). Used with pitch, offer, request_approval.",
    )
    # Argument for crm.update_stage.
    stage: str = Field(
        default="",
        description="Target stage for crm.update_stage: contacted, interested, approval_pending, offer_sent, hired, lost",
    )
    # Arguments for crm.update_field (field + value) and crm.add_note (value).
    field: str = Field(
        default="",
        description="CRM field for crm.update_field",
    )
    value: str = Field(
        default="",
        description="Value for crm.update_field or text for crm.add_note",
    )
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class RecruitopenenvObservation(Observation):
    """What the agent sees after each action.

    Text summaries start empty and are filled in as the agent uses the tools;
    only driver_name and stage are always populated.
    """

    # Always present; the only fact known at episode start.
    driver_name: str = Field(default="", description="Driver's name")
    # Populated after crm.read_candidate.
    crm_summary: str = Field(default="", description="CRM record (empty until read_candidate)")
    jobs_summary: str = Field(default="", description="Available job listings")
    # Accumulates facts learned via messaging.
    discovered_info: str = Field(default="", description="Info discovered through conversation")

    # Pipeline stage; starts at "lead".
    stage: str = Field(default="lead", description="Current pipeline stage")
    # Raw (usually JSON) response from the last tool call.
    feedback: str = Field(default="", description="API response from last action")
    # True when a driver message is waiting for messaging.read_reply.
    pending_reply: bool = Field(default=False, description="Whether an unread message is waiting")
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: recruitopenenv
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
openenv_recruitopenenv.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-recruitopenenv
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Recruitopenenv environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.0
|
| 7 |
+
Provides-Extra: dev
|
| 8 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 9 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_recruitopenenv.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
__init__.py
|
| 3 |
+
client.py
|
| 4 |
+
models.py
|
| 5 |
+
pyproject.toml
|
| 6 |
+
./__init__.py
|
| 7 |
+
./client.py
|
| 8 |
+
./models.py
|
| 9 |
+
openenv_recruitopenenv.egg-info/PKG-INFO
|
| 10 |
+
openenv_recruitopenenv.egg-info/SOURCES.txt
|
| 11 |
+
openenv_recruitopenenv.egg-info/dependency_links.txt
|
| 12 |
+
openenv_recruitopenenv.egg-info/entry_points.txt
|
| 13 |
+
openenv_recruitopenenv.egg-info/requires.txt
|
| 14 |
+
openenv_recruitopenenv.egg-info/top_level.txt
|
| 15 |
+
server/__init__.py
|
| 16 |
+
server/app.py
|
| 17 |
+
server/recruitopenenv_environment.py
|
openenv_recruitopenenv.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_recruitopenenv.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = recruitopenenv.server.app:main
|
openenv_recruitopenenv.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.0
|
| 2 |
+
|
| 3 |
+
[dev]
|
| 4 |
+
pytest>=8.0.0
|
| 5 |
+
pytest-cov>=4.0.0
|
openenv_recruitopenenv.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
recruitopenenv
|
play.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Interactive CLI to play the recruiting environment manually."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import requests
|
| 5 |
+
|
| 6 |
+
# Local env server the CLI talks to (see server/app.py).
BASE_URL = "http://localhost:8000"

# One-token CLI shortcuts mapping directly to canned action JSON.
SHORTCUTS = {
    "r": '{"tool":"crm","action":"read_candidate"}',
    "rr": '{"tool":"messaging","action":"read_reply"}',
    "w": '{"tool":"workflow","action":"wait"}',
    "ca": '{"tool":"approval","action":"check_approval"}',
    "hi": '{"tool":"crm","action":"update_stage","stage":"hired"}',
    "lost": '{"tool":"crm","action":"update_stage","stage":"lost"}',
}

# Abbreviations for the `s <topic>` command; unknown keys pass through as-is.
TOPIC_SHORTCUTS = {
    "g": "greeting", "c": "call", "exp": "experience", "ht": "home_time",
    "pay": "pay", "eq": "equipment", "rt": "route", "db": "deal_breakers",
    "av": "availability", "vio": "violations", "med": "medical_card",
    "ref": "references", "pitch": "pitch", "offer": "offer",
    "np": "negotiate_pay", "nht": "negotiate_home_time",
    "sb": "signing_bonus", "ac": "address_concern",
}
|
| 25 |
+
|
| 26 |
+
def print_obs(obs, reward):
    """Pretty-print one observation dict plus the step reward.

    Optional sections (CRM, jobs, discovered info, feedback) are shown only
    when present; JSON feedback is re-indented for readability.
    """
    print(f"\n{'='*60}")
    print(f"Driver: {obs['driver_name']}")
    if obs.get('crm_summary'):
        print(f"\nCRM:\n{obs['crm_summary']}")
    if obs.get('jobs_summary'):
        print(f"\nJobs:\n{obs['jobs_summary']}")
    if obs.get('discovered_info'):
        print(f"\nDiscovered:\n{obs['discovered_info']}")
    status = f"Stage: {obs['stage']}"
    if obs.get('pending_reply'):
        status += " | PENDING REPLY"
    print(f"\n{status}")
    print(f"Reward this step: {reward}")
    if obs.get('feedback'):
        # Feedback is usually JSON; fall back to raw text when it isn't.
        try:
            fb = json.loads(obs['feedback'])
            print(f"Response: {json.dumps(fb, indent=2)}")
        except (json.JSONDecodeError, TypeError):
            print(f"Response: {obs['feedback']}")
|
| 46 |
+
|
| 47 |
+
def print_help():
    """Print the CLI cheat sheet (shortcuts, commands, raw-JSON escape hatch)."""
    print("\nShortcuts:")
    print("  r = read CRM")
    print("  rr = read reply")
    print("  w = wait")
    print("  ca = check approval")
    print("  hi = update stage to hired")
    print("  lost = update stage to lost")
    print("\nSend message: s <topic> e.g. s g, s exp, s offer")
    print("  Topics: g=greeting c=call exp=experience ht=home_time pay eq=equipment")
    print("  rt=route db=deal_breakers av=availability vio=violations med=medical_card")
    print("  ref=references pitch offer np=negotiate_pay nht=negotiate_home_time")
    print("  sb=signing_bonus ac=address_concern")
    print("\nWith job_id: s pitch 2 s offer 3")
    print("\nUpdate stage: st <stage> e.g. st contacted")
    print("Update field: f <field> <value> e.g. f cdl_class A")
    print("Add note: n <text> e.g. n Driver prefers OTR")
    print("Request approval: ra <job_id> e.g. ra 2")
    print("\nOr paste raw JSON: {\"tool\":\"crm\",\"action\":\"read_candidate\"}")
    print("  q = quit, h = help, reset = new episode")
|
| 67 |
+
|
| 68 |
+
def parse_input(user_input):
    """Translate one line of CLI input into an action dict, or None.

    Handles, in order: empty input (None), one-token SHORTCUTS, raw JSON
    starting with '{', and the short commands s/st/f/n/ra. Unknown input
    prints a hint and returns None.
    """
    user_input = user_input.strip()
    if not user_input:
        return None

    # Single-token shortcuts map straight to canned JSON actions.
    if user_input in SHORTCUTS:
        return json.loads(SHORTCUTS[user_input])

    # Raw JSON passthrough.
    if user_input.startswith("{"):
        return json.loads(user_input)

    # Short commands: first token is the command, up to two more args
    # (maxsplit=2 keeps spaces inside the final argument intact).
    parts = user_input.split(None, 2)
    cmd, rest = parts[0], parts[1:]

    # Send message: s <topic> [job_id]
    if cmd == "s" and rest:
        message = {"tool": "messaging", "action": "send_message",
                   "topic": TOPIC_SHORTCUTS.get(rest[0], rest[0])}
        if len(rest) > 1:
            message["job_id"] = int(rest[1])
        return message

    # Update stage: st <stage>
    if cmd == "st" and rest:
        return {"tool": "crm", "action": "update_stage", "stage": rest[0]}

    # Update field: f <field> <value>
    if cmd == "f" and len(rest) >= 2:
        return {"tool": "crm", "action": "update_field", "field": rest[0], "value": rest[1]}

    # Add note: n <text>
    if cmd == "n" and rest:
        return {"tool": "crm", "action": "add_note", "value": " ".join(rest)}

    # Request approval: ra <job_id>
    if cmd == "ra" and rest:
        return {"tool": "approval", "action": "request_approval", "job_id": int(rest[0])}

    print(f"Unknown command: {user_input}. Type 'h' for help.")
    return None
|
| 110 |
+
|
| 111 |
+
def main():
    """Interactive REPL: read commands, POST them to the env server, print results.

    Maintains a cumulative reward for the current episode; 'reset' starts a
    new episode, 'q' (or EOF/Ctrl-C) quits.
    """
    session = requests.Session()
    total_reward = 0.0

    print("\n🚛 DRIVER RECRUITING ENVIRONMENT — INTERACTIVE MODE")
    print_help()

    # Reset: start the first episode.
    resp = session.post(f"{BASE_URL}/reset", json={})
    data = resp.json()
    obs = data["observation"]
    print_obs(obs, 0)

    while True:
        try:
            user_input = input("\n> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        # Meta commands handled locally (no server round-trip except reset).
        if user_input == "q":
            break
        if user_input == "h":
            print_help()
            continue
        if user_input == "reset":
            resp = session.post(f"{BASE_URL}/reset", json={})
            data = resp.json()
            obs = data["observation"]
            total_reward = 0.0
            print_obs(obs, 0)
            continue

        action = parse_input(user_input)
        if action is None:
            continue

        # Echo the parsed action with whichever optional args it carries.
        print(f"→ {action['tool']}.{action['action']}"
              + (f"({action.get('topic', '')})" if action.get('topic') else "")
              + (f"[job={action['job_id']}]" if action.get('job_id', -1) >= 0 else "")
              + (f"({action.get('stage', '')})" if action.get('stage') else "")
              + (f"({action.get('field', '')}={action.get('value', '')})" if action.get('field') else ""))

        resp = session.post(f"{BASE_URL}/step", json=action)
        data = resp.json()
        obs = data["observation"]
        reward = data["reward"]
        done = data["done"]
        total_reward += reward

        print_obs(obs, reward)
        print(f"Total reward: {total_reward:.1f}")

        if done:
            print(f"\n{'='*60}")
            print(f"EPISODE OVER — Final stage: {obs['stage']} | Total reward: {total_reward:.1f}")
            print(f"{'='*60}")
            print("Type 'reset' for a new episode or 'q' to quit.")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-recruitopenenv"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Recruitopenenv environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]==0.2.1",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m recruitopenenv.server.app
|
| 40 |
+
server = "recruitopenenv.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["recruitopenenv", "recruitopenenv.server"]
|
| 45 |
+
package-dir = { "recruitopenenv" = ".", "recruitopenenv.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Recruitopenenv environment server components."""
|
| 8 |
+
|
| 9 |
+
from .recruitopenenv_environment import RecruitopenenvEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["RecruitopenenvEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the Recruitopenenv Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes the RecruitopenenvEnvironment
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Endpoints:
|
| 14 |
+
- POST /reset: Reset the environment
|
| 15 |
+
- POST /step: Execute an action
|
| 16 |
+
- GET /state: Get current environment state
|
| 17 |
+
- GET /schema: Get action/observation schemas
|
| 18 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
# Development (with auto-reload):
|
| 22 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 23 |
+
|
| 24 |
+
# Production:
|
| 25 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
|
| 26 |
+
|
| 27 |
+
# Or run directly:
|
| 28 |
+
python -m server.app
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import os
|
| 32 |
+
|
| 33 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 34 |
+
from fastapi.responses import FileResponse
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from openenv.core.env_server.http_server import create_app
|
| 38 |
+
except Exception as e: # pragma: no cover
|
| 39 |
+
raise ImportError(
|
| 40 |
+
"openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
|
| 41 |
+
) from e
|
| 42 |
+
|
| 43 |
+
# Import from local models.py (PYTHONPATH includes /app/env in Docker)
|
| 44 |
+
from models import RecruitopenenvAction, RecruitopenenvObservation
|
| 45 |
+
from .recruitopenenv_environment import RecruitopenenvEnvironment
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Create the FastAPI app via openenv's factory: it wires up the /reset, /step,
# /state, /schema endpoints and the /ws WebSocket endpoint for this environment.
app = create_app(
    RecruitopenenvEnvironment,
    RecruitopenenvAction,
    RecruitopenenvObservation,
    env_name="recruitopenenv",
    max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
)

# CORS for demo page: the demo HTML may be opened from another origin, so all
# origins/methods/headers are allowed. NOTE(review): "*" is permissive —
# consider restricting origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the demo page: demo/ lives one directory above this server/ package.
_DEMO_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "demo")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@app.get("/demo", include_in_schema=False)
async def demo_page():
    """Serve the static demo UI (demo/index.html) at GET /demo."""
    return FileResponse(os.path.join(_DEMO_DIR, "index.html"))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
    """Run the environment server directly, without Docker.

    Supported invocations:
        uv run --project . server
        uv run --project . server --port 8001
        python -m recruitopenenv.server.app

    Args:
        host: Interface to bind (default "0.0.0.0").
        port: TCP port to listen on (default 8000).

    For production deployments, prefer launching uvicorn yourself with
    multiple workers:
        uvicorn recruitopenenv.server.app:app --workers 4
    """
    # Imported lazily so importing this module never requires uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Script entry point. The CLI now mirrors main()'s full signature: previously
# only --port was exposed even though main() also accepts a host, so --host is
# configurable too (defaults unchanged — fully backward compatible).
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the recruitopenenv server.")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    main(host=args.host, port=args.port)
|
server/recruitopenenv_environment.py
ADDED
|
@@ -0,0 +1,1422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Driver Recruit Environment — Tool-based Long-Horizon.
|
| 3 |
+
|
| 4 |
+
Agent interacts through 4 tools: CRM, messaging, approval, workflow.
|
| 5 |
+
Each recruiting interaction requires multiple tool calls, creating
|
| 6 |
+
naturally long episodes (40-70 steps).
|
| 7 |
+
|
| 8 |
+
Pipeline: lead → contacted → interested → approval_pending → offer_sent → hired
|
| 9 |
+
Terminal failures: lost, ghosted
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import random
|
| 14 |
+
from uuid import uuid4
|
| 15 |
+
|
| 16 |
+
from openenv.core.env_server.interfaces import Environment
|
| 17 |
+
from openenv.core.env_server.types import State
|
| 18 |
+
|
| 19 |
+
from models import RecruitopenenvAction, RecruitopenenvObservation
|
| 20 |
+
|
| 21 |
+
# --- Constants ---
|
| 22 |
+
|
| 23 |
+
FIRST_NAMES = [
|
| 24 |
+
"Mike", "James", "Robert", "John", "David", "Carlos", "Marcus",
|
| 25 |
+
"Sarah", "Maria", "Linda", "Patricia", "Jessica", "Angela", "Rosa",
|
| 26 |
+
"Travis", "Derek", "Kevin", "Brandon", "Tyler", "Dustin", "Ray",
|
| 27 |
+
]
|
| 28 |
+
LAST_NAMES = [
|
| 29 |
+
"Johnson", "Smith", "Williams", "Garcia", "Martinez", "Brown",
|
| 30 |
+
"Davis", "Rodriguez", "Wilson", "Taylor", "Thomas", "Moore",
|
| 31 |
+
"Jackson", "White", "Harris", "Clark", "Lewis", "Young",
|
| 32 |
+
]
|
| 33 |
+
LOCATIONS = [
|
| 34 |
+
"Dallas TX", "Atlanta GA", "Chicago IL", "Denver CO", "Phoenix AZ",
|
| 35 |
+
"Memphis TN", "Louisville KY", "Nashville TN", "Indianapolis IN",
|
| 36 |
+
"Columbus OH", "Jacksonville FL", "Charlotte NC", "Kansas City MO",
|
| 37 |
+
]
|
| 38 |
+
COMPANIES = [
|
| 39 |
+
"Werner Enterprises", "Swift Transport", "Schneider National",
|
| 40 |
+
"J.B. Hunt", "KLLM Transport", "Heartland Express",
|
| 41 |
+
"Covenant Logistics", "USA Truck", "Marten Transport",
|
| 42 |
+
"Prime Inc", "CR England", "Western Express",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
CDL_CLASSES = ["A", "B"]
|
| 46 |
+
ENDORSEMENTS_ALL = ["H", "N", "T", "TWIC"]
|
| 47 |
+
HOME_TIMES = ["daily", "weekends", "weekly", "biweekly"]
|
| 48 |
+
ROUTE_TYPES = ["OTR", "regional", "local", "dedicated"]
|
| 49 |
+
EQUIPMENT_TYPES = ["dry_van", "flatbed", "reefer", "tanker"]
|
| 50 |
+
CONTACT_METHODS = ["text", "call"]
|
| 51 |
+
DEAL_BREAKERS_ALL = [
|
| 52 |
+
"touch_freight", "forced_dispatch", "team_driving",
|
| 53 |
+
"northeast", "hazmat_no_premium", "no_benefits",
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
PERSONALITY_PARAMS = {
|
| 57 |
+
"chatty": {"initial_trust": 0.80, "decay": 0.02, "reveal_breakers": "all"},
|
| 58 |
+
"professional": {"initial_trust": 0.70, "decay": 0.025, "reveal_breakers": "all"},
|
| 59 |
+
"impatient": {"initial_trust": 0.60, "decay": 0.04, "reveal_breakers": "partial"},
|
| 60 |
+
"suspicious": {"initial_trust": 0.55, "decay": 0.03, "reveal_breakers": "all_if_trusted"},
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
AVAILABILITIES = ["immediately", "2_weeks", "1_month", "negotiable"]
|
| 64 |
+
VIOLATION_LEVELS = ["clean", "minor", "major"]
|
| 65 |
+
MEDICAL_CARD_STATUS = ["valid", "expiring_soon", "expired"]
|
| 66 |
+
REFERENCE_QUALITY = ["strong", "mixed", "none"]
|
| 67 |
+
|
| 68 |
+
MAX_STEPS = 100
|
| 69 |
+
|
| 70 |
+
VALID_TOOL_ACTIONS = {
|
| 71 |
+
"crm": {"read_candidate", "update_stage", "update_field", "add_note"},
|
| 72 |
+
"messaging": {"send_message", "read_reply"},
|
| 73 |
+
"approval": {"request_approval", "check_approval"},
|
| 74 |
+
"workflow": {"wait"},
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
VALID_TOPICS = {
|
| 78 |
+
"greeting", "call",
|
| 79 |
+
"experience", "home_time", "pay", "equipment", "route", "deal_breakers",
|
| 80 |
+
"availability", "violations", "medical_card", "references",
|
| 81 |
+
"pitch", "offer",
|
| 82 |
+
"negotiate_pay", "negotiate_home_time", "signing_bonus", "address_concern",
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
STAGE_ORDER = ["lead", "contacted", "interested", "approval_pending", "offer_sent", "hired"]
|
| 86 |
+
ALL_STAGES = set(STAGE_ORDER) | {"lost", "ghosted"}
|
| 87 |
+
|
| 88 |
+
SCREENING_TOPICS = {
|
| 89 |
+
"experience", "home_time", "pay", "equipment", "route", "deal_breakers",
|
| 90 |
+
"availability", "violations", "medical_card", "references",
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
VALID_CRM_FIELDS = {
|
| 94 |
+
"cdl_class", "years_experience", "endorsements", "location",
|
| 95 |
+
"home_time_pref", "pay_expectation", "equipment_pref", "route_pref",
|
| 96 |
+
"deal_breakers", "availability", "violations", "medical_card", "references",
|
| 97 |
+
"matched_job",
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# --- Data generation ---
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def generate_driver():
    """Randomly generate a driver candidate profile.

    Returns a dict mixing public facts (CDL class, experience, location) with
    hidden attributes (minimum pay, deal breakers, trust dynamics) that the
    recruiting agent is expected to uncover through screening.
    """
    # Personality selects trust dynamics from PERSONALITY_PARAMS.
    personality = random.choices(
        ["chatty", "professional", "impatient", "suspicious"],
        weights=[25, 35, 20, 20],
    )[0]
    params = PERSONALITY_PARAMS[personality]

    cdl = random.choices(CDL_CLASSES, weights=[75, 25])[0]  # class A most common
    exp = random.randint(1, 20)

    # More experienced drivers are more likely to hold each endorsement.
    endorsements = [e for e in ENDORSEMENTS_ALL if random.random() < 0.10 + exp * 0.02]

    # Tanker preference is only possible with the N (tank) endorsement.
    equip_opts = ["dry_van", "flatbed", "reefer"]
    if "N" in endorsements:
        equip_opts.append("tanker")
    equipment_pref = random.choice(equip_opts)

    # 1-3 hard deal breakers; a job hitting one is a fatal mismatch.
    n_breakers = random.choices([1, 2, 3], weights=[30, 50, 20])[0]
    deal_breakers = random.sample(DEAL_BREAKERS_ALL, n_breakers)

    return {
        "name": f"{random.choice(FIRST_NAMES)} {random.choice(LAST_NAMES)}",
        "cdl_class": cdl,
        "endorsements": endorsements,
        "experience_years": exp,
        "location": random.choice(LOCATIONS),
        "preferred_contact": random.choice(CONTACT_METHODS),
        "personality": personality,
        "trust": params["initial_trust"],  # current trust level
        "decay": params["decay"],  # trust decay rate from personality params
        "home_time_pref": random.choices(HOME_TIMES, weights=[15, 30, 30, 25])[0],
        "min_cpm": round(random.uniform(0.48, 0.78), 2),  # minimum acceptable $/mile
        "equipment_pref": equipment_pref,
        "route_pref": random.choices(ROUTE_TYPES, weights=[20, 30, 30, 20])[0],
        "deal_breakers": deal_breakers,
        "availability": random.choices(AVAILABILITIES, weights=[30, 35, 25, 10])[0],
        "violations": random.choices(VIOLATION_LEVELS, weights=[60, 30, 10])[0],
        "medical_card": random.choices(MEDICAL_CARD_STATUS, weights=[70, 20, 10])[0],
        "references": random.choices(REFERENCE_QUALITY, weights=[40, 40, 20])[0],
    }
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def generate_jobs(driver):
    """Generate 6 jobs: 1-2 good, 1-2 traps, 2-3 bad.

    Job mix relative to *driver*:
      - slot 0: usually a good fit (80%), otherwise a trap
      - slot 1: a trap (looks good but violates a deal breaker)
      - slot 2: a partial fit (one soft mismatch: pay or home time)
      - slot 3: hard-disqualified (wrong CDL class)
      - slot 4: hard-disqualified (requires more experience than the driver has,
        plus H+T endorsements)
      - slot 5: 50/50 another trap or a tempting wrong-CDL job
    Jobs are shuffled and job_ids reassigned so position carries no signal.
    """
    jobs = []
    if random.random() > 0.2:
        jobs.append(_make_good_job(driver, 0))
    else:
        jobs.append(_make_trap_job(driver, 0))
    jobs.append(_make_trap_job(driver, 1))
    jobs.append(_make_partial_job(driver, 2))

    # Wrong CDL class guarantees a fatal mismatch in score_job_fit.
    bad_cdl = "B" if driver["cdl_class"] == "A" else "A"
    jobs.append({
        "job_id": 3, "company": random.choice(COMPANIES),
        "required_cdl": bad_cdl, "required_endorsements": [],
        "min_experience": random.randint(1, 5),
        "route_type": random.choice(ROUTE_TYPES),
        "home_time": random.choice(HOME_TIMES),
        "pay_cpm": round(random.uniform(0.50, 0.85), 2),
        "equipment": random.choice(EQUIPMENT_TYPES),
        "has_touch_freight": random.random() < 0.3,
        "forced_dispatch": random.random() < 0.3,
        "team_driving": False, "northeast_routes": False,
        "hazmat_premium": False,
        "benefits": random.choice(["none", "basic", "good"]),
        "location": random.choice(LOCATIONS),
        "start_urgency": random.choice(["immediate", "flexible"]),
        "requires_clean_record": random.random() < 0.3,
        "requires_medical": True,
    })

    # Over-qualified posting: demands H+T endorsements and 5-10 more years of
    # experience than the driver has, so the experience check always fails.
    jobs.append({
        "job_id": 4, "company": random.choice(COMPANIES),
        "required_cdl": driver["cdl_class"],
        "required_endorsements": ["H", "T"],
        "min_experience": driver["experience_years"] + random.randint(5, 10),
        "route_type": random.choice(ROUTE_TYPES),
        "home_time": random.choice(HOME_TIMES),
        "pay_cpm": round(random.uniform(0.70, 0.90), 2),
        "equipment": random.choice(EQUIPMENT_TYPES),
        "has_touch_freight": False, "forced_dispatch": False,
        "team_driving": False, "northeast_routes": False,
        "hazmat_premium": True, "benefits": "excellent",
        "location": random.choice(LOCATIONS),
        "start_urgency": "flexible",
        "requires_clean_record": True,
        "requires_medical": True,
    })

    if random.random() < 0.5:
        jobs.append(_make_trap_job(driver, 5))
    else:
        # Tempting wrong-CDL job: matches home time, pay, and equipment
        # preferences but is hard-disqualified by CDL class.
        jobs.append({
            "job_id": 5, "company": random.choice(COMPANIES),
            "required_cdl": bad_cdl, "required_endorsements": [],
            "min_experience": random.randint(1, 8),
            "route_type": random.choice(ROUTE_TYPES),
            "home_time": driver["home_time_pref"],
            "pay_cpm": round(driver["min_cpm"] + random.uniform(0.05, 0.15), 2),
            "equipment": driver["equipment_pref"],
            "has_touch_freight": False, "forced_dispatch": False,
            "team_driving": False, "northeast_routes": False,
            "hazmat_premium": False, "benefits": "good",
            "location": random.choice(LOCATIONS),
            "start_urgency": random.choice(["immediate", "flexible"]),
            "requires_clean_record": random.random() < 0.3,
            "requires_medical": True,
        })

    # Shuffle so the agent can't infer job quality from position, then renumber
    # job_ids to match the shuffled order.
    random.shuffle(jobs)
    for i, j in enumerate(jobs):
        j["job_id"] = i
    return jobs
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _make_good_job(driver, job_id):
    """Build a job tailored to *driver*: meets CDL/experience requirements,
    matches route/home-time/equipment preferences, pays above the driver's
    minimum, and avoids every hard deal-breaker flag."""
    return {
        "job_id": job_id, "company": random.choice(COMPANIES),
        "required_cdl": driver["cdl_class"],
        # Require only endorsements the driver already holds (each with 30% odds).
        "required_endorsements": [e for e in driver["endorsements"] if random.random() < 0.3],
        "min_experience": max(1, driver["experience_years"] - random.randint(1, 3)),
        "route_type": driver["route_pref"],
        "home_time": driver["home_time_pref"],
        # Pay lands $0.03-$0.12/mi above the driver's minimum.
        "pay_cpm": round(driver["min_cpm"] + random.uniform(0.03, 0.12), 2),
        "equipment": driver["equipment_pref"],
        "has_touch_freight": False, "forced_dispatch": False,
        "team_driving": False, "northeast_routes": False,
        "hazmat_premium": "H" in driver.get("endorsements", []),
        "benefits": random.choice(["good", "excellent"]),
        "location": random.choice(LOCATIONS),
        "start_urgency": random.choice(["immediate", "flexible"]),
        "requires_clean_record": random.random() < 0.3,
        "requires_medical": True,
    }
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _make_trap_job(driver, job_id):
    """Build a job that looks good but violates one of the driver's deal
    breakers — a trap the agent should screen out before pitching."""
    trap = _make_good_job(driver, job_id)
    breaker = random.choice(driver["deal_breakers"])
    if breaker == "touch_freight":
        trap["has_touch_freight"] = True
    elif breaker == "forced_dispatch":
        trap["forced_dispatch"] = True
    elif breaker == "team_driving":
        trap["team_driving"] = True
    elif breaker == "northeast":
        trap["northeast_routes"] = True
    elif breaker == "hazmat_no_premium":
        # Require hazmat but strip the premium pay that would justify it.
        trap["required_endorsements"] = ["H"]
        trap["hazmat_premium"] = False
    elif breaker == "no_benefits":
        trap["benefits"] = "none"
    return trap
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _make_partial_job(driver, job_id):
    """Build a near-fit job with exactly one soft mismatch: either pay slightly
    below the driver's minimum or a non-preferred home-time schedule."""
    job = _make_good_job(driver, job_id)
    if random.random() < 0.5:
        # Undercut minimum pay by $0.01-$0.06/mi (soft penalty, not fatal).
        job["pay_cpm"] = round(driver["min_cpm"] - random.uniform(0.01, 0.06), 2)
    else:
        # Swap home time to any schedule other than the driver's preference.
        others = [h for h in HOME_TIMES if h != driver["home_time_pref"]]
        job["home_time"] = random.choice(others)
    return job
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def format_jobs(jobs):
    """Render each job posting as a one-line human-readable summary,
    one summary per line."""

    def _one_line(job):
        required = job["required_endorsements"]
        endorse_text = ", ".join(required) if required else "none"

        # Red-flag attributes that get surfaced in square brackets.
        flag_labels = [
            label
            for key, label in (
                ("has_touch_freight", "touch freight"),
                ("forced_dispatch", "forced dispatch"),
                ("team_driving", "team driving"),
                ("northeast_routes", "northeast routes"),
            )
            if job[key]
        ]
        flag_text = f" [{', '.join(flag_labels)}]" if flag_labels else ""

        # Hard requirements rendered in parentheses after the start urgency.
        requirement_parts = []
        if job.get("requires_clean_record"):
            requirement_parts.append("clean record required")
        if job.get("requires_medical"):
            requirement_parts.append("DOT medical required")
        req_text = f" ({', '.join(requirement_parts)})" if requirement_parts else ""

        start = job.get("start_urgency", "flexible")
        return (
            f"Job {job['job_id']}: {job['company']} — CDL-{job['required_cdl']}, "
            f"{job['min_experience']}+ yrs, {job['route_type']}, "
            f"${job['pay_cpm']}/mi, {job['home_time']} home, "
            f"{job['equipment']}, endorsements: {endorse_text}, "
            f"benefits: {job['benefits']}, start: {start}{req_text}{flag_text}"
        )

    return "\n".join(_one_line(j) for j in jobs)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def trust_label(trust):
    """Bucket a numeric trust score into 'high' (>= 0.7), 'medium' (>= 0.4),
    or 'low' (anything below)."""
    for threshold, label in ((0.7, "high"), (0.4, "medium")):
        if trust >= threshold:
            return label
    return "low"
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# --- Job fit scoring ---
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def score_job_fit(driver, job):
    """Score how well *job* fits *driver*.

    Returns:
        (score, issues, fatal) — score is 0-100, issues is a list of
        human-readable mismatch descriptions, and fatal is True when a hard
        disqualifier is hit (score 0, issues holds the single fatal reason).

    Checks run in order — hard qualifications, deal breakers, pay, soft
    preferences, compliance — and the first fatal condition short-circuits,
    so the reported fatal reason depends on this ordering.
    """
    score = 100
    issues = []

    # --- Hard qualification checks (fatal) ---
    if driver["cdl_class"] != job["required_cdl"]:
        return 0, ["CDL class mismatch"], True
    if driver["experience_years"] < job["min_experience"]:
        return 0, [f"Needs {job['min_experience']} yrs, driver has {driver['experience_years']}"], True
    for e in job["required_endorsements"]:
        if e not in driver["endorsements"]:
            return 0, [f"Missing {e} endorsement"], True

    # --- Deal breakers (fatal when the job hits one of the driver's) ---
    if job["has_touch_freight"] and "touch_freight" in driver["deal_breakers"]:
        return 0, ["Touch freight is a deal breaker"], True
    if job["forced_dispatch"] and "forced_dispatch" in driver["deal_breakers"]:
        return 0, ["Forced dispatch is a deal breaker"], True
    if job["team_driving"] and "team_driving" in driver["deal_breakers"]:
        return 0, ["Team driving is a deal breaker"], True
    if job["northeast_routes"] and "northeast" in driver["deal_breakers"]:
        return 0, ["Northeast routes is a deal breaker"], True
    if ("H" in job["required_endorsements"] and not job["hazmat_premium"]
            and "hazmat_no_premium" in driver["deal_breakers"]):
        return 0, ["Hazmat without premium pay"], True
    if job["benefits"] == "none" and "no_benefits" in driver["deal_breakers"]:
        return 0, ["No benefits is a deal breaker"], True

    # --- Pay: fatal if more than $0.10/mi under minimum, else scaled penalty ---
    if job["pay_cpm"] < driver["min_cpm"]:
        diff = driver["min_cpm"] - job["pay_cpm"]
        if diff > 0.10:
            return 0, [f"Pay ${job['pay_cpm']}/mi way below min ${driver['min_cpm']}/mi"], True
        score -= int(diff * 400)  # up to -40 points for a $0.10 shortfall
        issues.append(f"Pay is ${diff:.2f}/mi below minimum")

    # --- Soft preference mismatches (point deductions only) ---
    if job["home_time"] != driver["home_time_pref"]:
        score -= 25
        issues.append(f"Home time: job={job['home_time']}, wants={driver['home_time_pref']}")

    if job["route_type"] != driver["route_pref"]:
        score -= 15
        issues.append(f"Route: job={job['route_type']}, wants={driver['route_pref']}")

    if job["equipment"] != driver["equipment_pref"]:
        score -= 10
        issues.append(f"Equipment: job={job['equipment']}, prefers={driver['equipment_pref']}")

    # --- Compliance: major violations / expired medical are fatal ---
    if job.get("requires_clean_record") and driver.get("violations") == "major":
        return 0, ["Major violations disqualify for this position"], True
    if job.get("requires_medical") and driver.get("medical_card") == "expired":
        return 0, ["Expired DOT medical card"], True

    # --- Remaining soft concerns ---
    if job.get("requires_clean_record") and driver.get("violations") == "minor":
        score -= 15
        issues.append("Minor violations may be a concern for clean-record position")
    if driver.get("medical_card") == "expiring_soon":
        score -= 5
        issues.append("DOT medical card expiring soon, needs renewal")
    if job.get("start_urgency") == "immediate" and driver.get("availability") == "1_month":
        score -= 20
        issues.append("Driver can't start for a month, job needs immediate start")
    if driver.get("references") == "none":
        score -= 10
        issues.append("No references available")
    elif driver.get("references") == "mixed":
        score -= 5
        issues.append("Mixed references from previous employers")

    # Clamp so accumulated penalties never push the score below zero.
    return max(0, score), issues, False
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# --- Natural language response templates ---
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def _respond_experience(driver):
|
| 384 |
+
p = driver["personality"]
|
| 385 |
+
cdl = driver["cdl_class"]
|
| 386 |
+
yrs = driver["experience_years"]
|
| 387 |
+
endorse = driver["endorsements"]
|
| 388 |
+
loc = driver["location"]
|
| 389 |
+
endorse_str = ", ".join(endorse) if endorse else "none"
|
| 390 |
+
|
| 391 |
+
if p == "chatty":
|
| 392 |
+
return (
|
| 393 |
+
f"Oh yeah, I've been driving for {yrs} years now! Got my CDL-{cdl} "
|
| 394 |
+
f"right out of school. "
|
| 395 |
+
f"{'I picked up my ' + endorse_str + ' endorsements along the way.' if endorse else 'No special endorsements yet but been thinking about it.'} "
|
| 396 |
+
f"Based out of {loc}, been here my whole life."
|
| 397 |
+
)
|
| 398 |
+
elif p == "impatient":
|
| 399 |
+
return f"CDL-{cdl}, {yrs} years. Endorsements: {endorse_str}. {loc}."
|
| 400 |
+
elif p == "suspicious":
|
| 401 |
+
if driver["trust"] < 0.5:
|
| 402 |
+
return f"I've got a CDL-{cdl}. Been driving a while, out of {loc}."
|
| 403 |
+
return f"CDL-{cdl}, {yrs} years experience. Endorsements: {endorse_str}. Based in {loc}."
|
| 404 |
+
else:
|
| 405 |
+
return (
|
| 406 |
+
f"I hold a CDL-{cdl} with {yrs} years of commercial driving experience. "
|
| 407 |
+
f"Endorsements: {endorse_str}. I'm located in {loc}."
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def _respond_home_time(driver):
|
| 412 |
+
p = driver["personality"]
|
| 413 |
+
pref = driver["home_time_pref"]
|
| 414 |
+
templates = {
|
| 415 |
+
"chatty": {
|
| 416 |
+
"daily": "Oh yeah, I gotta be home every night. My wife would kill me otherwise! We got three kids and I help with homework every evening.",
|
| 417 |
+
"weekends": "I need my weekends, you know? My kids have soccer on Saturdays and church on Sundays. Weekday runs are fine though.",
|
| 418 |
+
"weekly": "I like to be home at least once a week. I can do a few days out but need to get back regularly.",
|
| 419 |
+
"biweekly": "I can do longer runs, two weeks out is fine. My buddy and I go fishing every other weekend so that works out.",
|
| 420 |
+
},
|
| 421 |
+
"impatient": {
|
| 422 |
+
"daily": "Home daily. Non-negotiable.",
|
| 423 |
+
"weekends": "Home on weekends.",
|
| 424 |
+
"weekly": "Home weekly.",
|
| 425 |
+
"biweekly": "Two weeks out is fine.",
|
| 426 |
+
},
|
| 427 |
+
"suspicious": {
|
| 428 |
+
"daily": "I need to be home... regularly." if driver["trust"] < 0.5 else "I need to be home every night, that's firm.",
|
| 429 |
+
"weekends": "I need my time off." if driver["trust"] < 0.5 else "I need to be home on weekends for my family.",
|
| 430 |
+
"weekly": "Can't be gone too long." if driver["trust"] < 0.5 else "I need to get home at least once a week.",
|
| 431 |
+
"biweekly": "I'm flexible on time out." if driver["trust"] < 0.5 else "Two weeks out, two days home works for me.",
|
| 432 |
+
},
|
| 433 |
+
"professional": {
|
| 434 |
+
"daily": "I'm looking for local routes that get me home every evening.",
|
| 435 |
+
"weekends": "I'd like to be home on weekends. Weekday runs are fine.",
|
| 436 |
+
"weekly": "I prefer weekly home time. A few days out, then home for a reset.",
|
| 437 |
+
"biweekly": "I'm comfortable with biweekly home time. I've done OTR for years.",
|
| 438 |
+
},
|
| 439 |
+
}
|
| 440 |
+
return templates[p][pref]
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _respond_pay(driver):
|
| 444 |
+
p = driver["personality"]
|
| 445 |
+
cpm = driver["min_cpm"]
|
| 446 |
+
if p == "chatty":
|
| 447 |
+
return f"I'm making ${cpm}/mile right now and honestly I won't move for less. If you can beat that by a few cents and throw in a decent sign-on bonus, I'm listening."
|
| 448 |
+
elif p == "impatient":
|
| 449 |
+
return f"${cpm}/mile minimum. Don't lowball me."
|
| 450 |
+
elif p == "suspicious":
|
| 451 |
+
if driver["trust"] < 0.5:
|
| 452 |
+
return "I need to be paid fair, you know what I'm saying? What are you offering?"
|
| 453 |
+
return f"Look, I need at least ${cpm}/mile. I know what I'm worth."
|
| 454 |
+
else:
|
| 455 |
+
return f"My minimum is ${cpm} per mile. I'm open to discussing total compensation including benefits."
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def _respond_equipment(driver):
|
| 459 |
+
p = driver["personality"]
|
| 460 |
+
pref = driver["equipment_pref"]
|
| 461 |
+
pretty = pref.replace("_", " ")
|
| 462 |
+
if p == "chatty":
|
| 463 |
+
extra = " Got my tanker endorsement too so I can do that." if "N" in driver["endorsements"] else ""
|
| 464 |
+
return f"I've been running {pretty} mostly. Love it, got the hang of it.{extra} Wouldn't mind sticking with what I know."
|
| 465 |
+
elif p == "impatient":
|
| 466 |
+
return f"{pretty.title()}. That's what I run."
|
| 467 |
+
elif p == "suspicious":
|
| 468 |
+
if driver["trust"] < 0.5:
|
| 469 |
+
return "I've got experience with different trailers."
|
| 470 |
+
return f"I prefer {pretty}. That's where most of my experience is."
|
| 471 |
+
else:
|
| 472 |
+
return f"My primary experience is with {pretty} equipment. I'd prefer to stay in that lane."
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def _respond_route(driver):
|
| 476 |
+
p = driver["personality"]
|
| 477 |
+
pref = driver["route_pref"]
|
| 478 |
+
routes = {
|
| 479 |
+
"chatty": {
|
| 480 |
+
"OTR": "I like the open road, OTR is my thing. See the country, you know?",
|
| 481 |
+
"regional": "Regional is my sweet spot. Good miles but still get home.",
|
| 482 |
+
"local": "Local runs for me. I know every road in this city!",
|
| 483 |
+
"dedicated": "Dedicated routes are great. Same customer, same lanes, no surprises.",
|
| 484 |
+
},
|
| 485 |
+
"impatient": {"OTR": "OTR.", "regional": "Regional.", "local": "Local.", "dedicated": "Dedicated."},
|
| 486 |
+
"suspicious": {
|
| 487 |
+
"OTR": ("Depends on the route." if driver["trust"] < 0.5 else "I'm looking for OTR work."),
|
| 488 |
+
"regional": ("Depends on the area." if driver["trust"] < 0.5 else "I'm looking for regional work."),
|
| 489 |
+
"local": ("I want to stay close to home." if driver["trust"] < 0.5 else "Local is what I want."),
|
| 490 |
+
"dedicated": ("Depends on the lanes." if driver["trust"] < 0.5 else "I prefer dedicated routes."),
|
| 491 |
+
},
|
| 492 |
+
"professional": {
|
| 493 |
+
"OTR": "I'm interested in OTR positions.",
|
| 494 |
+
"regional": "I'm looking for regional opportunities.",
|
| 495 |
+
"local": "I'd prefer local routes.",
|
| 496 |
+
"dedicated": "Dedicated lanes would be ideal.",
|
| 497 |
+
},
|
| 498 |
+
}
|
| 499 |
+
return routes[p][pref]
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def _respond_deal_breakers(driver):
|
| 503 |
+
p = driver["personality"]
|
| 504 |
+
breakers = driver["deal_breakers"]
|
| 505 |
+
labels = {
|
| 506 |
+
"touch_freight": "touch freight",
|
| 507 |
+
"forced_dispatch": "forced dispatch",
|
| 508 |
+
"team_driving": "team driving",
|
| 509 |
+
"northeast": "northeast/NYC routes",
|
| 510 |
+
"hazmat_no_premium": "hazmat without extra pay",
|
| 511 |
+
"no_benefits": "no health benefits",
|
| 512 |
+
}
|
| 513 |
+
if p == "chatty":
|
| 514 |
+
items = [labels[b] for b in breakers]
|
| 515 |
+
return f"Oh man, don't even get me started. I will NOT do {', '.join(items)}. Had bad experiences with all of that."
|
| 516 |
+
elif p == "impatient":
|
| 517 |
+
return f"No {labels[breakers[0]]}. That's my line."
|
| 518 |
+
elif p == "suspicious":
|
| 519 |
+
if driver["trust"] < 0.5:
|
| 520 |
+
return "I've got my limits. What kind of freight are we talking about?"
|
| 521 |
+
items = [labels[b] for b in breakers]
|
| 522 |
+
return f"I won't do {', '.join(items)}. Those are hard stops for me."
|
| 523 |
+
else:
|
| 524 |
+
items = [labels[b] for b in breakers]
|
| 525 |
+
return f"My non-negotiables: no {', no '.join(items)}."
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def _respond_availability(driver):
|
| 529 |
+
p = driver["personality"]
|
| 530 |
+
avail = driver["availability"]
|
| 531 |
+
labels = {"immediately": "right away", "2_weeks": "in about two weeks", "1_month": "in about a month", "negotiable": "depends on the offer"}
|
| 532 |
+
if p == "chatty":
|
| 533 |
+
if avail == "immediately":
|
| 534 |
+
return "I'm ready to go! Just left my last company, sitting at home going crazy. Can start tomorrow if you need me."
|
| 535 |
+
elif avail == "2_weeks":
|
| 536 |
+
return "I need to give my current place two weeks notice. They've been good to me, wanna leave right."
|
| 537 |
+
elif avail == "1_month":
|
| 538 |
+
return "It'll be about a month. I'm finishing up a contract and need to wrap some things up at home too."
|
| 539 |
+
else:
|
| 540 |
+
return "Depends on what you've got. For the right job I could move quick, otherwise I'm okay where I am."
|
| 541 |
+
elif p == "impatient":
|
| 542 |
+
return f"Can start {labels[avail]}."
|
| 543 |
+
elif p == "suspicious":
|
| 544 |
+
if driver["trust"] < 0.5:
|
| 545 |
+
return "Why do you need to know that already? I'll be available when I'm available."
|
| 546 |
+
return f"I can start {labels[avail]}."
|
| 547 |
+
else:
|
| 548 |
+
return f"I'm available to start {labels[avail]}. I can be flexible depending on the opportunity."
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def _respond_violations(driver):
|
| 552 |
+
p = driver["personality"]
|
| 553 |
+
violations = driver["violations"]
|
| 554 |
+
if p == "chatty":
|
| 555 |
+
if violations == "clean":
|
| 556 |
+
return "Clean record, twenty years no accidents! Well, one close call in '09 but that wasn't my fault. Nothing on the record though."
|
| 557 |
+
elif violations == "minor":
|
| 558 |
+
return "I had a minor thing a while back, nothing serious. A speeding ticket in a construction zone. Learned my lesson."
|
| 559 |
+
else:
|
| 560 |
+
return "Look, I had an incident a few years ago. It was a bad situation but I've cleaned up since then. I'm a different driver now."
|
| 561 |
+
elif p == "impatient":
|
| 562 |
+
if violations == "clean":
|
| 563 |
+
return "Clean record."
|
| 564 |
+
elif violations == "minor":
|
| 565 |
+
return "Minor stuff, nothing serious."
|
| 566 |
+
else:
|
| 567 |
+
return "I've had some issues. It's in the past."
|
| 568 |
+
elif p == "suspicious":
|
| 569 |
+
if driver["trust"] < 0.5:
|
| 570 |
+
return "Why are you asking about that? My record is my business."
|
| 571 |
+
if violations == "clean":
|
| 572 |
+
return "My record is clean. You can check."
|
| 573 |
+
elif violations == "minor":
|
| 574 |
+
return "There's a minor thing on there but nothing that should matter."
|
| 575 |
+
else:
|
| 576 |
+
return "I've had some trouble before. But I've been clean for two years now."
|
| 577 |
+
else:
|
| 578 |
+
if violations == "clean":
|
| 579 |
+
return "I have a clean driving record with no violations or incidents."
|
| 580 |
+
elif violations == "minor":
|
| 581 |
+
return "I have a minor violation on record. I'm happy to discuss the details."
|
| 582 |
+
else:
|
| 583 |
+
return "I do have a violation on my record. I've taken corrective steps since then."
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def _respond_medical_card(driver):
|
| 587 |
+
p = driver["personality"]
|
| 588 |
+
status = driver["medical_card"]
|
| 589 |
+
if p == "chatty":
|
| 590 |
+
if status == "valid":
|
| 591 |
+
return "Yep, DOT medical is all good! Just renewed it last month actually. Passed with flying colors."
|
| 592 |
+
elif status == "expiring_soon":
|
| 593 |
+
return "Oh yeah, I need to renew that soon actually. Thanks for reminding me. It's coming up in a few weeks."
|
| 594 |
+
else:
|
| 595 |
+
return "Ugh, yeah, it expired. I've been meaning to get that renewed. Can I still apply while I'm working on it?"
|
| 596 |
+
elif p == "impatient":
|
| 597 |
+
if status == "valid":
|
| 598 |
+
return "DOT medical is current."
|
| 599 |
+
elif status == "expiring_soon":
|
| 600 |
+
return "Expires soon. I'll renew it."
|
| 601 |
+
else:
|
| 602 |
+
return "It's expired. I'll get it done."
|
| 603 |
+
elif p == "suspicious":
|
| 604 |
+
if driver["trust"] < 0.5:
|
| 605 |
+
return "My medical stuff is between me and my doctor."
|
| 606 |
+
if status == "valid":
|
| 607 |
+
return "My DOT medical is current and valid."
|
| 608 |
+
elif status == "expiring_soon":
|
| 609 |
+
return "It's expiring soon but I've got an appointment scheduled."
|
| 610 |
+
else:
|
| 611 |
+
return "It lapsed. I can get it renewed if there's a real opportunity here."
|
| 612 |
+
else:
|
| 613 |
+
if status == "valid":
|
| 614 |
+
return "My DOT medical certificate is current and valid."
|
| 615 |
+
elif status == "expiring_soon":
|
| 616 |
+
return "My medical card is expiring soon. I plan to renew it promptly."
|
| 617 |
+
else:
|
| 618 |
+
return "My DOT medical has expired. I'm prepared to renew it for the right position."
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def _respond_references(driver):
|
| 622 |
+
p = driver["personality"]
|
| 623 |
+
refs = driver["references"]
|
| 624 |
+
if p == "chatty":
|
| 625 |
+
if refs == "strong":
|
| 626 |
+
return "Oh yeah, my last dispatcher loved me! You can call anyone I've worked for. They'll all say good things."
|
| 627 |
+
elif refs == "mixed":
|
| 628 |
+
return "Most of my old bosses would say good things... I had a rough patch at one place but we parted okay."
|
| 629 |
+
else:
|
| 630 |
+
return "I've mostly done owner-operator stuff, so I don't really have traditional references. But I can show you my load history!"
|
| 631 |
+
elif p == "impatient":
|
| 632 |
+
if refs == "strong":
|
| 633 |
+
return "References are solid. Call whoever you want."
|
| 634 |
+
elif refs == "mixed":
|
| 635 |
+
return "Some are better than others."
|
| 636 |
+
else:
|
| 637 |
+
return "Don't have references. I work for myself."
|
| 638 |
+
elif p == "suspicious":
|
| 639 |
+
if driver["trust"] < 0.5:
|
| 640 |
+
return "I'm not giving you names until I know this is serious."
|
| 641 |
+
if refs == "strong":
|
| 642 |
+
return "I've got good references. I'll provide them when we're further along."
|
| 643 |
+
elif refs == "mixed":
|
| 644 |
+
return "I have some references. It depends on who you talk to."
|
| 645 |
+
else:
|
| 646 |
+
return "I don't have traditional references."
|
| 647 |
+
else:
|
| 648 |
+
if refs == "strong":
|
| 649 |
+
return "I have strong references from my previous employers. Happy to provide contact information."
|
| 650 |
+
elif refs == "mixed":
|
| 651 |
+
return "I can provide references. My track record has been generally positive."
|
| 652 |
+
else:
|
| 653 |
+
return "I don't have employer references available, though I can provide other professional contacts."
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def _respond_pitch(driver, job):
    """Return the driver's reaction to a pitched job: rejection, enthusiasm, or hesitation."""
    score, issues, fatal = score_job_fit(driver, job)
    persona = driver["personality"]

    if fatal:
        # A deal-breaker hit: hard no, citing the first issue.
        reason = issues[0] if issues else "not a fit"
        rejections = {
            "chatty": f"Nah, that's not gonna work for me. {reason}. Got anything else?",
            "impatient": f"No. {reason}.",
            "suspicious": f"Why would you pitch me that? {reason}.",
        }
        return rejections.get(persona, f"I'll have to pass. {reason}.")

    if score >= 80:
        # Strong fit: enthusiastic acceptance signal.
        excited = {
            "chatty": "Now THAT sounds interesting! The pay is right, the home time works... I could see myself there.",
            "impatient": "That could work. What's next?",
            "suspicious": "Hmm, that actually doesn't sound bad. What's the catch?",
        }
        return excited.get(persona, "That aligns well with what I'm looking for. I'd like to move forward.")

    # Middling fit: voice the top concern.
    concern = issues[0] if issues else "something's off"
    hesitant = {
        "chatty": f"It's close but I'm not sure... {concern}. Maybe if they could adjust something?",
        "impatient": f"Ehh. {concern}.",
        "suspicious": f"I don't know... {concern}. What else you got?",
    }
    return hesitant.get(persona, f"It's interesting but I have a concern: {concern}.")
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
# --- Contact response templates ---
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def _respond_contact_good(driver, topic):
|
| 696 |
+
p = driver["personality"]
|
| 697 |
+
method = "text" if topic == "greeting" else "call"
|
| 698 |
+
if p == "chatty":
|
| 699 |
+
if method == "text":
|
| 700 |
+
return "Hey! Yeah I got your text. I've been looking for something new actually. What do you have for me?"
|
| 701 |
+
return "Hello? Oh hey, yeah I was hoping someone would reach out. I'm definitely interested in hearing about opportunities."
|
| 702 |
+
elif p == "impatient":
|
| 703 |
+
if method == "text":
|
| 704 |
+
return "Got your text. What do you have?"
|
| 705 |
+
return "Yeah, I'm listening. What's the job?"
|
| 706 |
+
elif p == "suspicious":
|
| 707 |
+
if method == "text":
|
| 708 |
+
return "Hey. How'd you get my number? ...Okay, I'm listening I guess."
|
| 709 |
+
return "Who is this? ...A recruiter? Alright, what are you offering?"
|
| 710 |
+
else:
|
| 711 |
+
if method == "text":
|
| 712 |
+
return "Thanks for reaching out. I'm open to new opportunities. What positions do you have available?"
|
| 713 |
+
return "Hello, thanks for the call. I'm currently exploring new opportunities. What do you have?"
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def _respond_contact_wrong(driver, topic):
|
| 717 |
+
p = driver["personality"]
|
| 718 |
+
if topic == "greeting": # texted a caller
|
| 719 |
+
if p == "chatty":
|
| 720 |
+
return "Oh hey, got your text. I usually prefer a phone call but no worries, what's up?"
|
| 721 |
+
elif p == "impatient":
|
| 722 |
+
return "Text is fine I guess. What do you want?"
|
| 723 |
+
elif p == "suspicious":
|
| 724 |
+
return "...Who is this? I don't usually respond to random texts."
|
| 725 |
+
else:
|
| 726 |
+
return "I received your message. I generally prefer a phone call, but I'm happy to chat."
|
| 727 |
+
else: # called a texter
|
| 728 |
+
if p == "chatty":
|
| 729 |
+
return "Oh, uh, hey. I wasn't expecting a call. I'm kinda busy, could you text me instead? ...Fine, what is it?"
|
| 730 |
+
elif p == "impatient":
|
| 731 |
+
return "I don't pick up unknown numbers usually. Should've texted. What do you want?"
|
| 732 |
+
elif p == "suspicious":
|
| 733 |
+
return "Who is this? I don't answer calls from numbers I don't know."
|
| 734 |
+
else:
|
| 735 |
+
return "Hello. I prefer to communicate via text if possible. But go ahead, what do you have?"
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
def _respond_contact_repeat(driver):
|
| 739 |
+
p = driver["personality"]
|
| 740 |
+
if p == "chatty":
|
| 741 |
+
return "You already reached out to me! What else do you need?"
|
| 742 |
+
elif p == "impatient":
|
| 743 |
+
return "You already contacted me. What now?"
|
| 744 |
+
elif p == "suspicious":
|
| 745 |
+
return "Why are you contacting me again? We already talked."
|
| 746 |
+
else:
|
| 747 |
+
return "We've already been in touch. What's the next step?"
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def _respond_repeat_question(driver, topic):
|
| 751 |
+
p = driver["personality"]
|
| 752 |
+
if p == "chatty":
|
| 753 |
+
return f"Didn't I already tell you about my {topic}? I feel like we covered that!"
|
| 754 |
+
elif p == "impatient":
|
| 755 |
+
return f"I already answered that. Pay attention."
|
| 756 |
+
elif p == "suspicious":
|
| 757 |
+
return f"You already asked me about {topic}. Why are you asking again?"
|
| 758 |
+
else:
|
| 759 |
+
return f"I believe I already shared my {topic} preferences with you."
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
# --- Offer/submit response templates ---
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def _respond_offer_accept(driver, job):
|
| 766 |
+
p = driver["personality"]
|
| 767 |
+
company = job["company"]
|
| 768 |
+
if p == "chatty":
|
| 769 |
+
return f"Awesome! {company} sounds great, I'm excited to get started. Thanks for finding this for me!"
|
| 770 |
+
elif p == "impatient":
|
| 771 |
+
return f"Good. {company}. When do I start?"
|
| 772 |
+
elif p == "suspicious":
|
| 773 |
+
return f"Alright, {company} it is. I hope this works out. Thanks."
|
| 774 |
+
else:
|
| 775 |
+
return f"Thank you for the placement at {company}. I'm looking forward to getting started."
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def _respond_offer_concerns(driver, job, concern):
|
| 779 |
+
p = driver["personality"]
|
| 780 |
+
company = job["company"]
|
| 781 |
+
if p == "chatty":
|
| 782 |
+
return f"I mean, {company} is okay I guess. {concern} bugs me a little but maybe we can work something out?"
|
| 783 |
+
elif p == "impatient":
|
| 784 |
+
return f"Ehh. {concern}. Can you fix that?"
|
| 785 |
+
elif p == "suspicious":
|
| 786 |
+
return f"I'm not fully sold on {company}. {concern}. What are you going to do about it?"
|
| 787 |
+
else:
|
| 788 |
+
return f"I have a concern about the {company} position: {concern}. Can we discuss?"
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
def _respond_offer_reject(driver, reason):
|
| 792 |
+
p = driver["personality"]
|
| 793 |
+
if p == "chatty":
|
| 794 |
+
return f"Yeah no, I can't do that. {reason}. I thought we talked about this?"
|
| 795 |
+
elif p == "impatient":
|
| 796 |
+
return f"No. {reason}. I'm done here."
|
| 797 |
+
elif p == "suspicious":
|
| 798 |
+
return f"Are you serious? {reason}. I knew this was a waste of my time."
|
| 799 |
+
else:
|
| 800 |
+
return f"I'm going to have to withdraw. {reason}. This isn't what we discussed."
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
def _respond_ghosted(driver):
|
| 804 |
+
p = driver["personality"]
|
| 805 |
+
name = driver["name"].split()[0]
|
| 806 |
+
if p == "chatty":
|
| 807 |
+
return f"{name} stopped responding to your messages. Last seen: 'idk man this isn't working out...'"
|
| 808 |
+
elif p == "impatient":
|
| 809 |
+
return f"{name} blocked your number."
|
| 810 |
+
elif p == "suspicious":
|
| 811 |
+
return f"{name} stopped responding. They were never fully comfortable with the process."
|
| 812 |
+
else:
|
| 813 |
+
return f"{name} sent a polite message saying they've decided to go with another recruiter."
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
# --- Negotiation helpers ---
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
def _get_negotiation_concerns(driver, job):
    """Return the fit issues for *job* as the driver's open negotiation concerns."""
    fit = score_job_fit(driver, job)
    # score_job_fit returns (score, issues, fatal); only the issues matter here.
    return fit[1]
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
def _respond_negotiation(driver, action, job, concerns):
|
| 825 |
+
p = driver["personality"]
|
| 826 |
+
|
| 827 |
+
if action == "negotiate_pay":
|
| 828 |
+
if any("pay" in c.lower() for c in concerns):
|
| 829 |
+
if p == "chatty":
|
| 830 |
+
return "Well, if you can get them to bump it up a few cents, I'd feel a lot better about this."
|
| 831 |
+
elif p == "impatient":
|
| 832 |
+
return "More money would help. Get it done."
|
| 833 |
+
elif p == "suspicious":
|
| 834 |
+
return "I'll believe a pay bump when I see it in writing."
|
| 835 |
+
else:
|
| 836 |
+
return "I'd appreciate if you could negotiate a higher rate."
|
| 837 |
+
else:
|
| 838 |
+
return "Pay isn't really my concern here."
|
| 839 |
+
|
| 840 |
+
elif action == "negotiate_home_time":
|
| 841 |
+
if any("home time" in c.lower() for c in concerns):
|
| 842 |
+
if p == "chatty":
|
| 843 |
+
return "Yeah, if they could work with my schedule that would change everything. Talk to them?"
|
| 844 |
+
elif p == "impatient":
|
| 845 |
+
return "Fix the home time and we'll talk."
|
| 846 |
+
elif p == "suspicious":
|
| 847 |
+
return "They always say they'll adjust the schedule. Will they actually?"
|
| 848 |
+
else:
|
| 849 |
+
return "If the home time can be adjusted, I'd be much more interested."
|
| 850 |
+
else:
|
| 851 |
+
return "Home time isn't really my issue here."
|
| 852 |
+
|
| 853 |
+
elif action == "signing_bonus":
|
| 854 |
+
if p == "chatty":
|
| 855 |
+
return "A signing bonus? Hey, that's nice! Doesn't fix everything but it helps."
|
| 856 |
+
elif p == "impatient":
|
| 857 |
+
return "Bonus is fine. What about the real issues?"
|
| 858 |
+
elif p == "suspicious":
|
| 859 |
+
return "Bonuses are nice but they don't solve long-term problems."
|
| 860 |
+
else:
|
| 861 |
+
return "I appreciate the signing bonus offer. It's a positive gesture."
|
| 862 |
+
|
| 863 |
+
elif action == "address_concern":
|
| 864 |
+
if concerns:
|
| 865 |
+
if p == "chatty":
|
| 866 |
+
return f"Yeah, my big thing is: {concerns[0]}. If you can work that out, I'm in."
|
| 867 |
+
elif p == "impatient":
|
| 868 |
+
return f"{concerns[0]}. Fix it."
|
| 869 |
+
elif p == "suspicious":
|
| 870 |
+
if driver["trust"] < 0.4:
|
| 871 |
+
return "I've told you my concerns. Are you actually going to do something about them?"
|
| 872 |
+
return f"Fine, here's what bothers me: {concerns[0]}."
|
| 873 |
+
else:
|
| 874 |
+
return f"My primary concern is: {concerns[0]}. I'd need that resolved."
|
| 875 |
+
else:
|
| 876 |
+
return "I don't really have any major concerns. I think we're good."
|
| 877 |
+
|
| 878 |
+
return "I'm not sure what you mean."
|
| 879 |
+
|
| 880 |
+
|
| 881 |
+
# --- CRM formatting ---
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def _api(code, **kwargs):
|
| 885 |
+
"""Format a JSON API response with status code."""
|
| 886 |
+
return json.dumps({"code": code, **kwargs})
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
def format_crm(crm):
    """Render a CRM record dict as a multi-line human-readable summary."""
    out = [f"Name: {crm['name']}", f"Stage: {crm['stage']}"]
    fields = crm["fields"]
    if fields:
        out.append("Fields:")
        # Sorted so the rendering is stable regardless of insertion order.
        out.extend(f"  {key}: {value}" for key, value in sorted(fields.items()))
    else:
        out.append("Fields: (none recorded)")
    notes = crm["notes"]
    if notes:
        out.append("Notes:")
        out.extend(f"  - {note}" for note in notes)
    return "\n".join(out)
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
# --- Environment ---
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
class RecruitopenenvEnvironment(Environment):
|
| 909 |
+
"""Driver recruiting environment with tool-based long-horizon interaction."""
|
| 910 |
+
|
| 911 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 912 |
+
|
| 913 |
+
def __init__(self):
|
| 914 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 915 |
+
self._driver = {}
|
| 916 |
+
self._jobs = []
|
| 917 |
+
# CRM state
|
| 918 |
+
self._crm = {"name": "", "stage": "lead", "fields": {}, "notes": []}
|
| 919 |
+
self._has_read_crm = False
|
| 920 |
+
self._crm_read_count = 0
|
| 921 |
+
# Messaging state
|
| 922 |
+
self._pending_reply = None # (response_text, topic)
|
| 923 |
+
self._contacted = False
|
| 924 |
+
self._asked = set()
|
| 925 |
+
self._discovered_info = []
|
| 926 |
+
# Approval state
|
| 927 |
+
self._approval_status = "none"
|
| 928 |
+
self._approval_job_id = -1
|
| 929 |
+
# Negotiation state
|
| 930 |
+
self._matched_job_id = -1
|
| 931 |
+
self._negotiation_round = 0
|
| 932 |
+
self._negotiation_score_bonus = 0
|
| 933 |
+
self._negotiation_concerns = []
|
| 934 |
+
# Interaction tracking
|
| 935 |
+
self._last_contact_step = 0
|
| 936 |
+
|
| 937 |
+
def _make_obs(self, reward=0.0, done=False, feedback=""):
|
| 938 |
+
return RecruitopenenvObservation(
|
| 939 |
+
driver_name=self._driver.get("name", ""),
|
| 940 |
+
crm_summary=format_crm(self._crm) if self._has_read_crm else "",
|
| 941 |
+
jobs_summary=format_jobs(self._jobs) if self._jobs else "",
|
| 942 |
+
discovered_info="\n".join(self._discovered_info),
|
| 943 |
+
stage=self._crm["stage"],
|
| 944 |
+
feedback=feedback,
|
| 945 |
+
pending_reply=self._pending_reply is not None,
|
| 946 |
+
done=done,
|
| 947 |
+
reward=reward,
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
def reset(self, seed: int = None) -> RecruitopenenvObservation:
|
| 951 |
+
if seed is not None:
|
| 952 |
+
random.seed(seed)
|
| 953 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 954 |
+
self._driver = generate_driver()
|
| 955 |
+
self._jobs = generate_jobs(self._driver)
|
| 956 |
+
self._crm = {"name": self._driver["name"], "stage": "lead", "fields": {}, "notes": []}
|
| 957 |
+
self._has_read_crm = False
|
| 958 |
+
self._crm_read_count = 0
|
| 959 |
+
self._pending_reply = None
|
| 960 |
+
self._contacted = False
|
| 961 |
+
self._asked = set()
|
| 962 |
+
self._discovered_info = []
|
| 963 |
+
self._approval_status = "none"
|
| 964 |
+
self._approval_job_id = -1
|
| 965 |
+
self._matched_job_id = -1
|
| 966 |
+
self._negotiation_round = 0
|
| 967 |
+
self._negotiation_score_bonus = 0
|
| 968 |
+
self._negotiation_concerns = []
|
| 969 |
+
self._last_contact_step = 0
|
| 970 |
+
|
| 971 |
+
return self._make_obs(
|
| 972 |
+
feedback=_api(200, driver=self._driver["name"], jobs=len(self._jobs))
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
    def step(self, action: RecruitopenenvAction) -> RecruitopenenvObservation:
        """Advance the episode by one tool call.

        Order is significant: validate the action, refuse terminal episodes,
        count the step, enforce the step budget, apply idle trust decay, and
        only then dispatch to the per-tool handler.
        """
        # No driver means reset() was never called for this episode.
        if not self._driver:
            return self._make_obs(reward=0.0, done=True, feedback=_api(400, error="no_episode"))

        tool = action.tool
        act = action.action

        # Validate tool+action (invalid calls cost -1 but do not end the episode)
        if tool not in VALID_TOOL_ACTIONS:
            return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_tool", tool=tool))
        if act not in VALID_TOOL_ACTIONS[tool]:
            return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", tool=tool, action=act))

        # Check terminal
        if self._crm["stage"] in ("hired", "lost", "ghosted"):
            return self._make_obs(reward=0.0, done=True, feedback=_api(400, error="episode_ended"))

        self._state.step_count += 1

        # Hard step budget: exceeding MAX_STEPS ends the episode as a ghost.
        if self._state.step_count >= MAX_STEPS:
            self._crm["stage"] = "ghosted"
            return self._make_obs(reward=-3.0, done=True, feedback=_api(200, result="ghosted", reason="timeout"))

        # Passive trust decay — driver loses patience while recruiter isn't talking to them
        idle_gap = self._state.step_count - self._last_contact_step
        if idle_gap > 2:
            # Accelerating decay: longer silence = faster trust loss
            idle_decay = 0.01 * (idle_gap - 2)
            self._driver["trust"] = max(0.0, self._driver["trust"] - idle_decay)
            # Trust bottoming out (<= 0.1) means the driver ghosts immediately.
            if self._driver["trust"] <= 0.1:
                self._crm["stage"] = "ghosted"
                return self._make_obs(reward=-4.0, done=True, feedback=_api(200, result="ghosted", message=_respond_ghosted(self._driver)))

        # Route to handler
        if tool == "crm":
            return self._handle_crm(act, action)
        elif tool == "messaging":
            return self._handle_messaging(act, action)
        elif tool == "approval":
            return self._handle_approval(act, action)
        elif tool == "workflow":
            return self._handle_workflow(act, action)

        # Unreachable given the VALID_TOOL_ACTIONS check above; defensive fallback.
        return self._make_obs(reward=-1.0, feedback=_api(500, error="internal_error"))
|
| 1019 |
+
|
| 1020 |
+
# --- CRM tool ---
|
| 1021 |
+
|
| 1022 |
+
def _handle_crm(self, act, action):
    """CRM tool dispatcher: read the candidate record, move the pipeline
    stage, set a profile field, or append a free-text note."""
    if act == "read_candidate":
        # First read is free; every repeat read costs -0.1 (discourages spam).
        self._has_read_crm = True
        self._crm_read_count += 1
        read_reward = -0.1 if self._crm_read_count > 1 else 0.0
        return self._make_obs(reward=read_reward, feedback=_api(200, data=self._crm))

    if act == "update_stage":
        target = action.stage
        if target not in ALL_STAGES:
            return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_stage", stage=target))

        prior = self._crm["stage"]

        # Penalize non-ideal transitions along the ordered pipeline.
        # Terminal moves to lost/ghosted are exempt from ordering penalties.
        penalty = 0.0
        if target not in ("lost", "ghosted"):
            try:
                delta = STAGE_ORDER.index(target) - STAGE_ORDER.index(prior)
            except ValueError:
                delta = None  # either stage sits outside the ordered pipeline
            if delta is not None:
                if delta == 0:
                    penalty = -0.1                 # same stage — wasted action
                elif delta > 1:
                    penalty = -0.5 * (delta - 1)   # skipping ahead, per skip
                elif delta < 0:
                    penalty = -1.0 * abs(delta)    # going backwards — heavier
                # delta == 1 is the ideal forward step: no penalty

        self._crm["stage"] = target
        if target == "hired":
            return self._finalize_hire(penalty)
        if target == "lost":
            return self._finalize_lost(penalty)
        return self._make_obs(reward=penalty, feedback=_api(200, stage=target))

    if act == "update_field":
        field = action.field
        if field not in VALID_CRM_FIELDS:
            return self._make_obs(reward=-0.5, feedback=_api(400, error="unknown_field", field=field))
        self._crm["fields"][field] = action.value
        return self._make_obs(reward=0.0, feedback=_api(200, field=field, value=action.value))

    if act == "add_note":
        if not action.value:
            return self._make_obs(reward=-0.5, feedback=_api(400, error="empty_note"))
        self._crm["notes"].append(action.value)
        return self._make_obs(reward=0.0, feedback=_api(200, notes=len(self._crm["notes"])))

    return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", action=act))
|
| 1077 |
+
|
| 1078 |
+
# --- Messaging tool ---
|
| 1079 |
+
|
| 1080 |
+
def _handle_messaging(self, act, action):
    """Messaging tool: send a topic-tagged message to the driver, or read
    the driver's pending reply.

    Sending always marks contact and decays trust; a trust level at or
    below 0.1 ends the episode as 'ghosted' (-4.0, done).
    """
    if act == "send_message":
        topic = action.topic

        # Invalid topic — message still reaches driver, they're confused
        if topic not in VALID_TOPICS:
            self._last_contact_step = self._state.step_count
            # Confusing messages erode trust at double the normal decay rate.
            self._driver["trust"] = max(0.0, self._driver["trust"] - self._driver["decay"] * 2)
            if self._driver["trust"] <= 0.1:
                self._crm["stage"] = "ghosted"
                return self._make_obs(reward=-4.0, done=True, feedback=_api(200, result="ghosted", message=_respond_ghosted(self._driver)))
            self._pending_reply = ("I'm not sure what you're asking about.", topic)
            return self._make_obs(reward=-1.0, feedback=_api(200, topic=topic, warning="driver_confused"))

        # Penalty for skipping CRM read, but still send
        penalty = 0.0
        if not self._has_read_crm:
            penalty -= 1.0
        # Penalty for ignoring pending reply (overwrite it), but still send
        if self._pending_reply is not None:
            penalty -= 1.0
            self._pending_reply = None

        self._last_contact_step = self._state.step_count

        # Trust decay on each message
        self._driver["trust"] = max(0.0, self._driver["trust"] - self._driver["decay"])

        # Trust dropout check
        if self._driver["trust"] <= 0.1:
            self._crm["stage"] = "ghosted"
            return self._make_obs(reward=-4.0, done=True, feedback=_api(200, result="ghosted", message=_respond_ghosted(self._driver)))

        # Generate response based on topic
        response, reward = self._generate_message_response(topic, action.job_id)
        if response is None:
            # No valid target (e.g. pitch/offer referencing an unknown job_id).
            return self._make_obs(reward=reward + penalty, feedback=_api(404, error="no_valid_target", topic=topic))
        if response == "NEGOTIATION_EXHAUSTED":
            # Sentinel from _generate_message_response: driver walks away.
            self._crm["stage"] = "lost"
            return self._make_obs(reward=reward + penalty, done=True, feedback=_api(200, result="lost", reason="negotiation_exhausted"))
        self._pending_reply = (response, topic)
        return self._make_obs(reward=reward + penalty, feedback=_api(200, topic=topic))

    elif act == "read_reply":
        if self._pending_reply is None:
            # Nothing to read — mild penalty for the wasted action.
            return self._make_obs(reward=-0.5, feedback=_api(200, reply=None))
        self._last_contact_step = self._state.step_count

        response, topic = self._pending_reply
        self._pending_reply = None

        # Auto-add to discovered info for screening topics
        if topic in SCREENING_TOPICS:
            self._discovered_info.append(f"[{topic.upper().replace('_', ' ')}] {response}")
            self._asked.add(f"ask_{topic}")
        elif topic == "pitch":
            self._discovered_info.append(f"[PITCH] {response}")
        elif topic in ("negotiate_pay", "negotiate_home_time", "signing_bonus", "address_concern"):
            self._discovered_info.append(f"[NEGOTIATE: {topic.replace('_', ' ')}] {response}")
        elif topic == "offer":
            self._discovered_info.append(f"[OFFER] {response}")

        return self._make_obs(reward=0.0, feedback=_api(200, topic=topic, reply=response))

    return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", action=act))
|
| 1145 |
+
|
| 1146 |
+
def _generate_message_response(self, topic, job_id):
    """Generate driver's response to a message. Returns (response, reward).

    A ``None`` response means there was no valid target (the caller maps
    it to a 404); the string ``"NEGOTIATION_EXHAUSTED"`` is a sentinel for
    a negotiation the driver has walked away from. Side effects: mutates
    trust, contact/asked tracking, the matched job, and negotiation
    bookkeeping.
    """
    reward = -0.1  # base step cost (NOTE(review): never used — every branch returns its own reward)

    # --- Contact topics ---
    if topic in ("greeting", "call"):
        if self._contacted:
            return _respond_contact_repeat(self._driver), -1.0
        self._contacted = True
        pref = self._driver["preferred_contact"]
        # Matching the driver's preferred channel (text vs phone) builds trust.
        matches = (topic == "greeting" and pref == "text") or (topic == "call" and pref == "call")
        if matches:
            self._driver["trust"] = min(1.0, self._driver["trust"] + 0.15)
            return _respond_contact_good(self._driver, topic), 1.0
        else:
            self._driver["trust"] = max(0.0, self._driver["trust"] - 0.10)
            return _respond_contact_wrong(self._driver, topic), -0.3

    # --- Screening topics ---
    if topic in SCREENING_TOPICS:
        if not self._contacted:
            # Still works but driver is cold — penalty
            self._driver["trust"] = max(0.0, self._driver["trust"] - 0.15)
        ask_key = f"ask_{topic}"
        if ask_key in self._asked:
            # Asking the same thing twice annoys the driver.
            return _respond_repeat_question(self._driver, topic.replace("_", " ")), -0.5

        respond_map = {
            "experience": _respond_experience,
            "home_time": _respond_home_time,
            "pay": _respond_pay,
            "equipment": _respond_equipment,
            "route": _respond_route,
            "deal_breakers": _respond_deal_breakers,
            "availability": _respond_availability,
            "violations": _respond_violations,
            "medical_card": _respond_medical_card,
            "references": _respond_references,
        }
        response = respond_map[topic](self._driver)
        penalty = -1.0 if not self._contacted else -0.1
        return response, penalty

    # --- Pitch ---
    if topic == "pitch":
        if not self._contacted:
            self._driver["trust"] = max(0.0, self._driver["trust"] - 0.15)
        matching = [j for j in self._jobs if j["job_id"] == job_id]
        if not matching:
            # No match — pick nothing, return None (will be caught by handler)
            return None, -1.0
        penalty = -1.0 if not self._contacted else -0.1
        return _respond_pitch(self._driver, matching[0]), penalty

    # --- Offer ---
    if topic == "offer":
        penalty = 0.0
        if self._approval_status != "approved":
            # Allowed but heavy penalty — driver gets confused
            self._driver["trust"] = max(0.0, self._driver["trust"] - 0.2)
            penalty = -2.0
        # job_id < 0 means "use whichever job approval was requested for".
        job_id_to_use = self._approval_job_id if job_id < 0 else job_id
        matching = [j for j in self._jobs if j["job_id"] == job_id_to_use]
        if not matching:
            return None, -1.0 + penalty
        job = matching[0]
        self._matched_job_id = job_id_to_use
        score, issues, fatal = score_job_fit(self._driver, job)
        if not fatal:
            # Concessions earned during negotiation lift the fit score.
            score = min(100, score + self._negotiation_score_bonus)
        if fatal:
            return _respond_offer_reject(self._driver, issues[0]), -0.5 + penalty
        elif score >= 70:
            return _respond_offer_accept(self._driver, job), 0.0 + penalty
        elif score >= 50:
            concern = issues[0] if issues else "minor concerns"
            self._negotiation_concerns = issues
            return _respond_offer_concerns(self._driver, job, concern), 0.0 + penalty
        else:
            return _respond_offer_reject(self._driver, issues[0] if issues else "not a fit"), -0.5 + penalty

    # --- Negotiation topics ---
    if topic in ("negotiate_pay", "negotiate_home_time", "signing_bonus", "address_concern"):
        # Fall back to the approval job when no offer has pinned a match yet.
        if self._matched_job_id < 0 and self._approval_job_id >= 0:
            self._matched_job_id = self._approval_job_id
        if self._matched_job_id < 0:
            return None, -1.0
        if self._negotiation_round >= 5:
            # Driver's patience runs out after 5 rounds.
            return "NEGOTIATION_EXHAUSTED", -2.0

        self._negotiation_round += 1
        matches = [j for j in self._jobs if j["job_id"] == self._matched_job_id]
        if not matches:
            return None, -1.0
        job = matches[0]
        if not self._negotiation_concerns:
            self._negotiation_concerns = _get_negotiation_concerns(self._driver, job)
        response = _respond_negotiation(self._driver, topic, job, self._negotiation_concerns)

        # Score bonus — addressing a concern directly is worth the most;
        # targeted pay/home-time negotiation clears the matching concerns.
        if topic == "address_concern" and self._negotiation_concerns:
            self._negotiation_score_bonus += 15
            self._negotiation_concerns.pop(0)
        elif topic == "negotiate_pay" and any("pay" in c.lower() for c in self._negotiation_concerns):
            self._negotiation_score_bonus += 10
            self._negotiation_concerns = [c for c in self._negotiation_concerns if "pay" not in c.lower()]
        elif topic == "negotiate_home_time" and any("home time" in c.lower() for c in self._negotiation_concerns):
            self._negotiation_score_bonus += 10
            self._negotiation_concerns = [c for c in self._negotiation_concerns if "home time" not in c.lower()]
        elif topic == "signing_bonus":
            self._negotiation_score_bonus += 5
        else:
            self._negotiation_score_bonus += 2

        # Extra trust decay during negotiation
        self._driver["trust"] = max(0.0, self._driver["trust"] - 0.01)
        return response, -0.1

    return None, -1.0
|
| 1265 |
+
|
| 1266 |
+
# --- Approval tool ---
|
| 1267 |
+
|
| 1268 |
+
def _handle_approval(self, act, action):
    """Approval tool: request a manager sign-off on a specific job, or
    poll the sign-off's current status."""
    if act == "request_approval":
        jid = action.job_id
        if jid < 0:
            return self._make_obs(reward=-1.0, feedback=_api(400, error="job_id_required"))
        if not any(j["job_id"] == jid for j in self._jobs):
            return self._make_obs(reward=-1.0, feedback=_api(404, error="job_not_found", job_id=jid))
        # Re-requesting is permitted but resets any existing approval — small penalty.
        redo_penalty = -0.5 if self._approval_status in ("pending", "approved") else 0.0
        self._approval_status = "pending"
        self._approval_job_id = jid
        return self._make_obs(reward=redo_penalty, feedback=_api(202, approval_status="pending", job_id=jid))

    if act == "check_approval":
        status = self._approval_status
        if status == "none":
            # Nothing was ever requested — wasted check.
            return self._make_obs(reward=-0.5, feedback=_api(200, approval_status="none"))
        if status == "pending":
            return self._make_obs(reward=-0.1, feedback=_api(202, approval_status="pending"))
        # Resolved: small reward for confirming an approval, penalty otherwise.
        resolved_reward = 0.5 if status == "approved" else -0.5
        return self._make_obs(
            reward=resolved_reward,
            feedback=_api(200, approval_status=status, job_id=self._approval_job_id)
        )

    return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", action=act))
|
| 1292 |
+
|
| 1293 |
+
# --- Workflow tool ---
|
| 1294 |
+
|
| 1295 |
+
def _handle_workflow(self, act, action):
    """Workflow tool: 'wait' advances wall-clock time, resolving a pending
    approval if one exists; otherwise idling just erodes driver trust."""
    if act != "wait":
        return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", action=act))

    if self._approval_status == "pending":
        # The manager decides during the wait: deny on a fatal fit
        # mismatch (or a missing job), approve otherwise.
        candidates = [j for j in self._jobs if j["job_id"] == self._approval_job_id]
        if candidates:
            _, _, fatal = score_job_fit(self._driver, candidates[0])
            self._approval_status = "denied" if fatal else "approved"
        else:
            self._approval_status = "denied"
        return self._make_obs(reward=0.0, feedback=_api(200, elapsed="1h"))

    # Generic wait — trust decay plus a penalty for wasting time.
    self._driver["trust"] = max(0.0, self._driver["trust"] - 0.02)
    return self._make_obs(reward=-0.5, feedback=_api(200, elapsed="1h"))
|
| 1315 |
+
|
| 1316 |
+
# --- Terminal handlers ---
|
| 1317 |
+
|
| 1318 |
+
def _score_crm(self):
|
| 1319 |
+
"""Score CRM accuracy — compare recorded fields to ground truth."""
|
| 1320 |
+
ground_truth = {
|
| 1321 |
+
"cdl_class": self._driver["cdl_class"],
|
| 1322 |
+
"years_experience": str(self._driver["experience_years"]),
|
| 1323 |
+
"location": self._driver["location"],
|
| 1324 |
+
"home_time_pref": self._driver["home_time_pref"],
|
| 1325 |
+
"pay_expectation": str(self._driver["min_cpm"]),
|
| 1326 |
+
"equipment_pref": self._driver["equipment_pref"],
|
| 1327 |
+
"route_pref": self._driver["route_pref"],
|
| 1328 |
+
"availability": self._driver["availability"],
|
| 1329 |
+
"violations": self._driver["violations"],
|
| 1330 |
+
"medical_card": self._driver["medical_card"],
|
| 1331 |
+
"references": self._driver["references"],
|
| 1332 |
+
}
|
| 1333 |
+
# Endorsements and deal_breakers are lists — normalize
|
| 1334 |
+
ground_truth["endorsements"] = ", ".join(sorted(self._driver["endorsements"])) if self._driver["endorsements"] else "none"
|
| 1335 |
+
ground_truth["deal_breakers"] = ", ".join(sorted(self._driver["deal_breakers"]))
|
| 1336 |
+
|
| 1337 |
+
score = 0.0
|
| 1338 |
+
for field, truth in ground_truth.items():
|
| 1339 |
+
recorded = self._crm["fields"].get(field, "")
|
| 1340 |
+
if not recorded:
|
| 1341 |
+
continue
|
| 1342 |
+
# Exact match (case-insensitive)
|
| 1343 |
+
if recorded.strip().lower() == truth.lower():
|
| 1344 |
+
score += 0.4
|
| 1345 |
+
# Partial match — truth appears in recorded or vice versa
|
| 1346 |
+
elif truth.lower() in recorded.strip().lower() or recorded.strip().lower() in truth.lower():
|
| 1347 |
+
score += 0.2
|
| 1348 |
+
else:
|
| 1349 |
+
# Wrong value recorded — small penalty
|
| 1350 |
+
score -= 0.1
|
| 1351 |
+
|
| 1352 |
+
# Small bonus for notes (shows diligence)
|
| 1353 |
+
score += min(0.5, len(self._crm["notes"]) * 0.1)
|
| 1354 |
+
|
| 1355 |
+
# Cap: up to 5.0 bonus for perfect CRM (13 fields × 0.4 = 5.2)
|
| 1356 |
+
return max(0.0, min(5.0, score))
|
| 1357 |
+
|
| 1358 |
+
def _finalize_hire(self, stage_penalty=0.0):
    """Handle stage transition to hired — compute final reward.

    Hiring without an approved sign-off, with a missing job, or against a
    fatal fit issue ends the episode as lost/rejected at -5.0. Otherwise
    the outcome tiers on the negotiation-adjusted fit score: >= 70 is a
    full hire (+10 plus CRM bonus), >= 50 hires with reservations (+4),
    below 50 is rejected (-5). ``stage_penalty`` carries over any stage-
    transition penalty from the caller.
    """
    crm_bonus = self._score_crm()

    if self._approval_status != "approved":
        # Hiring without manager sign-off is never allowed.
        self._crm["stage"] = "lost"
        return self._make_obs(
            reward=-5.0 + stage_penalty, done=True,
            feedback=_api(200, result="lost", reason="no_approval")
        )

    job_id = self._approval_job_id
    matching = [j for j in self._jobs if j["job_id"] == job_id]
    if not matching:
        self._crm["stage"] = "lost"
        return self._make_obs(
            reward=-5.0 + stage_penalty, done=True,
            feedback=_api(200, result="lost", reason="no_job")
        )

    job = matching[0]
    score, issues, fatal = score_job_fit(self._driver, job)
    if not fatal:
        # Negotiation concessions can lift a marginal fit over the bar.
        score = min(100, score + self._negotiation_score_bonus)

    if fatal:
        self._crm["stage"] = "lost"
        return self._make_obs(
            reward=-5.0 + stage_penalty, done=True,
            feedback=_api(200, result="rejected", reason=issues[0], job_id=job_id)
        )
    elif score >= 70:
        # Clean hire: large reward plus the CRM-accuracy bonus.
        return self._make_obs(
            reward=10.0 + crm_bonus + stage_penalty, done=True,
            feedback=_api(200, result="hired", job_id=job_id, score=score, crm_bonus=round(crm_bonus, 1))
        )
    elif score >= 50:
        return self._make_obs(
            reward=4.0 + crm_bonus + stage_penalty, done=True,
            feedback=_api(200, result="hired_with_reservations", job_id=job_id, score=score, concern=issues[0] if issues else "minor")
        )
    else:
        self._crm["stage"] = "lost"
        return self._make_obs(
            reward=-5.0 + stage_penalty, done=True,
            feedback=_api(200, result="rejected", reason=issues[0] if issues else "poor_fit", job_id=job_id)
        )
|
| 1405 |
+
|
| 1406 |
+
def _finalize_lost(self, stage_penalty=0.0):
    """Terminal 'lost' outcome: penalized only when a well-fitting job
    (score >= 70) was actually available to offer."""
    good_match = any(score_job_fit(self._driver, j)[0] >= 70 for j in self._jobs)
    base_reward = -3.0 if good_match else 1.0
    return self._make_obs(
        reward=base_reward + stage_penalty, done=True,
        feedback=_api(200, result="lost", good_match_existed=good_match)
    )
|
| 1419 |
+
|
| 1420 |
+
@property
def state(self) -> State:
    """The environment's current episode state object (step counter etc.)."""
    return self._state
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
train_colab.ipynb
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 🚛 Driver Recruit Environment — RL Training with TRL\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Train a 3B LLM to recruit truck drivers using REINFORCE with TRL.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"The model learns to choose the right screening topics to ask drivers,\n",
|
| 12 |
+
"then the auto-pilot handles CRM updates, approval, and hiring.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"**Environment**: [OpenEnv 0.2.1](https://github.com/meta-pytorch/OpenEnv) deployed on HF Spaces\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"**Model**: Qwen/Qwen2.5-3B-Instruct\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"**Algorithm**: REINFORCE with batch-level advantage normalization"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "markdown",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"source": [
|
| 25 |
+
"## 1. Install Dependencies"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "code",
|
| 30 |
+
"execution_count": null,
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"!pip install -q openenv-core[core]==0.2.1 trl transformers torch accelerate"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "markdown",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"source": [
|
| 41 |
+
"## 2. Connect to the Environment\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"The recruiting environment is deployed on HF Spaces. Replace the URL below with your Space URL."
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "code",
|
| 48 |
+
"execution_count": null,
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"outputs": [],
|
| 51 |
+
"source": [
|
| 52 |
+
"import json\n",
|
| 53 |
+
"import random\n",
|
| 54 |
+
"import re\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"import torch\n",
|
| 57 |
+
"import torch.nn.functional as F\n",
|
| 58 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"# --- Connect to environment ---\n",
|
| 61 |
+
"# Replace with your HF Space URL\n",
|
| 62 |
+
"ENV_URL = \"https://YOUR-USERNAME-recruitopenenv.hf.space\" # <-- CHANGE THIS\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"from openenv.client import EnvClient\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"# Quick test: reset and check the env is alive\n",
|
| 67 |
+
"import requests\n",
|
| 68 |
+
"resp = requests.post(f\"{ENV_URL}/reset\", json={\"seed\": 42})\n",
|
| 69 |
+
"data = resp.json()\n",
|
| 70 |
+
"print(f\"Driver: {data['observation']['driver_name']}\")\n",
|
| 71 |
+
"print(f\"Stage: {data['observation']['stage']}\")\n",
|
| 72 |
+
"print(\"Environment connected!\")"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "markdown",
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"source": [
|
| 79 |
+
"## 3. Environment Helper Functions"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"cell_type": "code",
|
| 84 |
+
"execution_count": null,
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"source": [
|
| 88 |
+
"def env_reset(seed=None):\n",
|
| 89 |
+
" \"\"\"Reset environment via HTTP.\"\"\"\n",
|
| 90 |
+
" payload = {\"seed\": seed} if seed else {}\n",
|
| 91 |
+
" resp = requests.post(f\"{ENV_URL}/reset\", json=payload)\n",
|
| 92 |
+
" return resp.json()\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"def env_step(tool, action, **kwargs):\n",
|
| 95 |
+
" \"\"\"Step environment via HTTP.\"\"\"\n",
|
| 96 |
+
" payload = {\"tool\": tool, \"action\": action, **kwargs}\n",
|
| 97 |
+
" resp = requests.post(f\"{ENV_URL}/step\", json=payload)\n",
|
| 98 |
+
" return resp.json()\n",
|
| 99 |
+
"\n",
|
| 100 |
+
"# --- Topic-based auto-pilot ---\n",
|
| 101 |
+
"SCREENING_TOPICS = [\n",
|
| 102 |
+
" \"experience\", \"home_time\", \"pay\", \"equipment\", \"route\",\n",
|
| 103 |
+
" \"deal_breakers\", \"availability\", \"violations\", \"medical_card\", \"references\",\n",
|
| 104 |
+
"]\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"SYSTEM_PROMPT = \"\"\"You are a truck driver recruiter screening a candidate. Choose the next topic to discuss.\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"Topics for first contact: greeting (text), call (phone)\n",
|
| 109 |
+
"Screening topics: experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references\n",
|
| 110 |
+
"Say \"done\" when you have enough info to proceed with hiring.\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"Respond with ONLY the topic name, nothing else.\"\"\"\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"ALL_TOPICS = [\"greeting\", \"call\"] + SCREENING_TOPICS + [\"done\"]\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"def parse_topic(text):\n",
|
| 117 |
+
" \"\"\"Extract topic name from model output.\"\"\"\n",
|
| 118 |
+
" text = text.strip().lower().replace('\"', '').replace(\"'\", \"\")\n",
|
| 119 |
+
" text = text.split(\"\\n\")[0].strip().split(\".\")[0].strip()\n",
|
| 120 |
+
" for topic in ALL_TOPICS:\n",
|
| 121 |
+
" if topic in text or topic.replace(\"_\", \" \") in text:\n",
|
| 122 |
+
" return topic\n",
|
| 123 |
+
" if \"deal\" in text: return \"deal_breakers\"\n",
|
| 124 |
+
" if \"home\" in text: return \"home_time\"\n",
|
| 125 |
+
" if \"medical\" in text: return \"medical_card\"\n",
|
| 126 |
+
" return \"done\"\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"def build_prompt(obs, asked):\n",
|
| 129 |
+
" \"\"\"Build prompt showing state and available topics.\"\"\"\n",
|
| 130 |
+
" parts = [f\"Driver: {obs['driver_name']}\"]\n",
|
| 131 |
+
" if obs.get('jobs_summary'):\n",
|
| 132 |
+
" parts.append(f\"Jobs:\\n{obs['jobs_summary']}\")\n",
|
| 133 |
+
" if obs.get('discovered_info'):\n",
|
| 134 |
+
" parts.append(f\"Discovered:\\n{obs['discovered_info']}\")\n",
|
| 135 |
+
" parts.append(f\"Stage: {obs['stage']}\")\n",
|
| 136 |
+
" if asked:\n",
|
| 137 |
+
" parts.append(f\"Already asked: {', '.join(asked)}\")\n",
|
| 138 |
+
" available = [t for t in ALL_TOPICS if t not in asked]\n",
|
| 139 |
+
" parts.append(f\"Available: {', '.join(available)}\")\n",
|
| 140 |
+
" return \"\\n\".join(parts)\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"print(\"Helpers loaded!\")"
|
| 143 |
+
]
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"cell_type": "markdown",
|
| 147 |
+
"metadata": {},
|
| 148 |
+
"source": [
|
| 149 |
+
"## 4. Run a Demo Episode\n",
|
| 150 |
+
"\n",
|
| 151 |
+
"Watch the auto-pilot run a full recruiting episode."
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "code",
|
| 156 |
+
"execution_count": null,
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"outputs": [],
|
| 159 |
+
"source": [
|
| 160 |
+
"def run_demo_episode(seed=42):\n",
|
| 161 |
+
" \"\"\"Run one full episode with scripted topic choices.\"\"\"\n",
|
| 162 |
+
" state = env_reset(seed=seed)\n",
|
| 163 |
+
" obs = state[\"observation\"]\n",
|
| 164 |
+
" total_reward = 0.0\n",
|
| 165 |
+
" print(f\"=== Driver: {obs['driver_name']} ===\")\n",
|
| 166 |
+
"\n",
|
| 167 |
+
" # Read CRM\n",
|
| 168 |
+
" state = env_step(\"crm\", \"read_candidate\")\n",
|
| 169 |
+
" total_reward += state[\"reward\"]\n",
|
| 170 |
+
" obs = state[\"observation\"]\n",
|
| 171 |
+
" print(f\"\\nJobs available:\\n{obs['jobs_summary'][:200]}...\")\n",
|
| 172 |
+
"\n",
|
| 173 |
+
" # Greet\n",
|
| 174 |
+
" state = env_step(\"messaging\", \"send_message\", topic=\"greeting\")\n",
|
| 175 |
+
" total_reward += state[\"reward\"]\n",
|
| 176 |
+
" print(f\"\\nGreeting reward: {state['reward']}\")\n",
|
| 177 |
+
"\n",
|
| 178 |
+
" state = env_step(\"messaging\", \"read_reply\")\n",
|
| 179 |
+
" total_reward += state[\"reward\"]\n",
|
| 180 |
+
" obs = state[\"observation\"]\n",
|
| 181 |
+
"\n",
|
| 182 |
+
" state = env_step(\"crm\", \"update_stage\", stage=\"contacted\")\n",
|
| 183 |
+
" total_reward += state[\"reward\"]\n",
|
| 184 |
+
"\n",
|
| 185 |
+
" # Screen\n",
|
| 186 |
+
" for topic in [\"experience\", \"deal_breakers\", \"pay\", \"home_time\"]:\n",
|
| 187 |
+
" if state.get(\"done\"): break\n",
|
| 188 |
+
" state = env_step(\"messaging\", \"send_message\", topic=topic)\n",
|
| 189 |
+
" total_reward += state[\"reward\"]\n",
|
| 190 |
+
" state = env_step(\"messaging\", \"read_reply\")\n",
|
| 191 |
+
" total_reward += state[\"reward\"]\n",
|
| 192 |
+
" obs = state[\"observation\"]\n",
|
| 193 |
+
" print(f\" {topic}: reward={state['reward']:.1f}\")\n",
|
| 194 |
+
"\n",
|
| 195 |
+
" print(f\"\\nDiscovered:\\n{obs.get('discovered_info', 'none')[:300]}\")\n",
|
| 196 |
+
"\n",
|
| 197 |
+
" # Approval + hire\n",
|
| 198 |
+
" state = env_step(\"crm\", \"update_stage\", stage=\"interested\")\n",
|
| 199 |
+
" total_reward += state[\"reward\"]\n",
|
| 200 |
+
" state = env_step(\"approval\", \"request_approval\", job_id=0)\n",
|
| 201 |
+
" total_reward += state[\"reward\"]\n",
|
| 202 |
+
" state = env_step(\"workflow\", \"wait\")\n",
|
| 203 |
+
" total_reward += state[\"reward\"]\n",
|
| 204 |
+
" state = env_step(\"approval\", \"check_approval\")\n",
|
| 205 |
+
" total_reward += state[\"reward\"]\n",
|
| 206 |
+
" state = env_step(\"crm\", \"update_stage\", stage=\"approval_pending\")\n",
|
| 207 |
+
" total_reward += state[\"reward\"]\n",
|
| 208 |
+
" state = env_step(\"messaging\", \"send_message\", topic=\"offer\", job_id=0)\n",
|
| 209 |
+
" total_reward += state[\"reward\"]\n",
|
| 210 |
+
" state = env_step(\"messaging\", \"read_reply\")\n",
|
| 211 |
+
" total_reward += state[\"reward\"]\n",
|
| 212 |
+
" state = env_step(\"crm\", \"update_stage\", stage=\"offer_sent\")\n",
|
| 213 |
+
" total_reward += state[\"reward\"]\n",
|
| 214 |
+
" state = env_step(\"crm\", \"update_stage\", stage=\"hired\")\n",
|
| 215 |
+
" total_reward += state[\"reward\"]\n",
|
| 216 |
+
"\n",
|
| 217 |
+
" obs = state[\"observation\"]\n",
|
| 218 |
+
" print(f\"\\nFinal stage: {obs['stage']}\")\n",
|
| 219 |
+
" print(f\"Total reward: {total_reward:.1f}\")\n",
|
| 220 |
+
" print(f\"Done: {state.get('done')}\")\n",
|
| 221 |
+
" return total_reward\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"run_demo_episode()"
|
| 224 |
+
]
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"cell_type": "markdown",
|
| 228 |
+
"metadata": {},
|
| 229 |
+
"source": [
|
| 230 |
+
"## 5. Load Model"
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"cell_type": "code",
|
| 235 |
+
"execution_count": null,
|
| 236 |
+
"metadata": {},
|
| 237 |
+
"outputs": [],
|
| 238 |
+
"source": [
|
| 239 |
+
"MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
|
| 240 |
+
"TEMPERATURE = 1.5\n",
|
| 241 |
+
"MAX_NEW_TOKENS = 32\n",
|
| 242 |
+
"MAX_TOPICS = 8\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 245 |
+
"if tokenizer.pad_token_id is None:\n",
|
| 246 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 249 |
+
" MODEL_NAME,\n",
|
| 250 |
+
" torch_dtype=torch.bfloat16,\n",
|
| 251 |
+
" device_map=\"auto\",\n",
|
| 252 |
+
")\n",
|
| 253 |
+
"model.gradient_checkpointing_enable()\n",
|
| 254 |
+
"\n",
|
| 255 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)\n",
|
| 256 |
+
"device = next(model.parameters()).device\n",
|
| 257 |
+
"print(f\"Model loaded on {device}\")"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"cell_type": "markdown",
|
| 262 |
+
"metadata": {},
|
| 263 |
+
"source": [
|
| 264 |
+
"## 6. Training Loop — REINFORCE with Auto-Pilot\n",
|
| 265 |
+
"\n",
|
| 266 |
+
"The model only picks screening topics (1-5 tokens per decision).\n",
|
| 267 |
+
"The auto-pilot handles CRM, stages, approval, and hiring.\n",
|
| 268 |
+
"Rewards come from the full episode outcome."
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"cell_type": "code",
|
| 273 |
+
"execution_count": null,
|
| 274 |
+
"metadata": {},
|
| 275 |
+
"outputs": [],
|
| 276 |
+
"source": [
|
| 277 |
+
"def rollout_episode(model, tokenizer, device, seed=None):\n",
|
| 278 |
+
" \"\"\"Run one auto-piloted episode. Model picks topics, wrapper does the rest.\"\"\"\n",
|
| 279 |
+
" if seed is None:\n",
|
| 280 |
+
" seed = random.randint(0, 2**31 - 1)\n",
|
| 281 |
+
"\n",
|
| 282 |
+
" state = env_reset(seed=seed)\n",
|
| 283 |
+
" obs = state[\"observation\"]\n",
|
| 284 |
+
" total_reward = 0.0\n",
|
| 285 |
+
"\n",
|
| 286 |
+
" # Auto: read CRM\n",
|
| 287 |
+
" state = env_step(\"crm\", \"read_candidate\")\n",
|
| 288 |
+
" total_reward += state[\"reward\"]\n",
|
| 289 |
+
" obs = state[\"observation\"]\n",
|
| 290 |
+
"\n",
|
| 291 |
+
" if state.get(\"done\"):\n",
|
| 292 |
+
" return None\n",
|
| 293 |
+
"\n",
|
| 294 |
+
" turn_data = []\n",
|
| 295 |
+
" asked = []\n",
|
| 296 |
+
" contacted = False\n",
|
| 297 |
+
"\n",
|
| 298 |
+
" for _ in range(MAX_TOPICS):\n",
|
| 299 |
+
" if state.get(\"done\"):\n",
|
| 300 |
+
" break\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" # Build prompt\n",
|
| 303 |
+
" obs_text = build_prompt(obs, asked)\n",
|
| 304 |
+
" messages = [\n",
|
| 305 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 306 |
+
" {\"role\": \"user\", \"content\": obs_text},\n",
|
| 307 |
+
" ]\n",
|
| 308 |
+
" prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
|
| 309 |
+
" input_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
|
| 310 |
+
"\n",
|
| 311 |
+
" # Model generates topic\n",
|
| 312 |
+
" with torch.no_grad():\n",
|
| 313 |
+
" output = model.generate(\n",
|
| 314 |
+
" input_ids, max_new_tokens=MAX_NEW_TOKENS,\n",
|
| 315 |
+
" do_sample=True, temperature=TEMPERATURE,\n",
|
| 316 |
+
" pad_token_id=tokenizer.pad_token_id,\n",
|
| 317 |
+
" )\n",
|
| 318 |
+
" gen_ids = output[0, input_ids.shape[1]:].tolist()\n",
|
| 319 |
+
" response = tokenizer.decode(gen_ids, skip_special_tokens=True)\n",
|
| 320 |
+
" topic = parse_topic(response)\n",
|
| 321 |
+
"\n",
|
| 322 |
+
" turn_data.append({\n",
|
| 323 |
+
" \"prompt_ids\": input_ids[0].tolist(),\n",
|
| 324 |
+
" \"gen_ids\": gen_ids,\n",
|
| 325 |
+
" \"topic\": topic,\n",
|
| 326 |
+
" \"turn_reward\": 0.0,\n",
|
| 327 |
+
" })\n",
|
| 328 |
+
"\n",
|
| 329 |
+
" if topic == \"done\":\n",
|
| 330 |
+
" break\n",
|
| 331 |
+
" if topic in asked:\n",
|
| 332 |
+
" total_reward -= 0.5\n",
|
| 333 |
+
" turn_data[-1][\"turn_reward\"] = -0.5\n",
|
| 334 |
+
" asked.append(topic)\n",
|
| 335 |
+
" continue\n",
|
| 336 |
+
"\n",
|
| 337 |
+
" asked.append(topic)\n",
|
| 338 |
+
"\n",
|
| 339 |
+
" # Auto: send_message + read_reply\n",
|
| 340 |
+
" state = env_step(\"messaging\", \"send_message\", topic=topic)\n",
|
| 341 |
+
" turn_reward = state[\"reward\"]\n",
|
| 342 |
+
" total_reward += state[\"reward\"]\n",
|
| 343 |
+
" obs = state[\"observation\"]\n",
|
| 344 |
+
"\n",
|
| 345 |
+
" if not state.get(\"done\"):\n",
|
| 346 |
+
" state = env_step(\"messaging\", \"read_reply\")\n",
|
| 347 |
+
" turn_reward += state[\"reward\"]\n",
|
| 348 |
+
" total_reward += state[\"reward\"]\n",
|
| 349 |
+
" obs = state[\"observation\"]\n",
|
| 350 |
+
"\n",
|
| 351 |
+
" # Auto: update stage after contact\n",
|
| 352 |
+
" if topic in (\"greeting\", \"call\") and not contacted and not state.get(\"done\"):\n",
|
| 353 |
+
" contacted = True\n",
|
| 354 |
+
" state = env_step(\"crm\", \"update_stage\", stage=\"contacted\")\n",
|
| 355 |
+
" turn_reward += state[\"reward\"]\n",
|
| 356 |
+
" total_reward += state[\"reward\"]\n",
|
| 357 |
+
" obs = state[\"observation\"]\n",
|
| 358 |
+
"\n",
|
| 359 |
+
" turn_data[-1][\"turn_reward\"] = turn_reward\n",
|
| 360 |
+
"\n",
|
| 361 |
+
" if not turn_data:\n",
|
| 362 |
+
" return None\n",
|
| 363 |
+
"\n",
|
| 364 |
+
" # Auto: approval + offer + hire\n",
|
| 365 |
+
" if not state.get(\"done\") and contacted:\n",
|
| 366 |
+
" for action_spec in [\n",
|
| 367 |
+
" (\"crm\", \"update_stage\", {\"stage\": \"interested\"}),\n",
|
| 368 |
+
" (\"approval\", \"request_approval\", {\"job_id\": 0}),\n",
|
| 369 |
+
" (\"workflow\", \"wait\", {}),\n",
|
| 370 |
+
" (\"approval\", \"check_approval\", {}),\n",
|
| 371 |
+
" (\"crm\", \"update_stage\", {\"stage\": \"approval_pending\"}),\n",
|
| 372 |
+
" (\"messaging\", \"send_message\", {\"topic\": \"offer\", \"job_id\": 0}),\n",
|
| 373 |
+
" (\"messaging\", \"read_reply\", {}),\n",
|
| 374 |
+
" (\"crm\", \"update_stage\", {\"stage\": \"offer_sent\"}),\n",
|
| 375 |
+
" (\"crm\", \"update_stage\", {\"stage\": \"hired\"}),\n",
|
| 376 |
+
" ]:\n",
|
| 377 |
+
" if state.get(\"done\"): break\n",
|
| 378 |
+
" state = env_step(action_spec[0], action_spec[1], **action_spec[2])\n",
|
| 379 |
+
" total_reward += state[\"reward\"]\n",
|
| 380 |
+
"\n",
|
| 381 |
+
" # Sample one turn for training\n",
|
| 382 |
+
" t = random.randrange(len(turn_data))\n",
|
| 383 |
+
" td = turn_data[t]\n",
|
| 384 |
+
"\n",
|
| 385 |
+
" return {\n",
|
| 386 |
+
" \"prompt_ids\": td[\"prompt_ids\"],\n",
|
| 387 |
+
" \"gen_ids\": td[\"gen_ids\"][:MAX_NEW_TOKENS],\n",
|
| 388 |
+
" \"reward\": total_reward,\n",
|
| 389 |
+
" \"stage\": obs.get(\"stage\", \"unknown\"),\n",
|
| 390 |
+
" \"topic\": td[\"topic\"],\n",
|
| 391 |
+
" \"num_topics\": len(asked),\n",
|
| 392 |
+
" }\n",
|
| 393 |
+
"\n",
|
| 394 |
+
"# Quick test\n",
|
| 395 |
+
"ep = rollout_episode(model, tokenizer, device)\n",
|
| 396 |
+
"if ep:\n",
|
| 397 |
+
" print(f\"Topic chosen: {ep['topic']}, Reward: {ep['reward']:.1f}, Stage: {ep['stage']}, Topics asked: {ep['num_topics']}\")"
|
| 398 |
+
]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"cell_type": "code",
|
| 402 |
+
"execution_count": null,
|
| 403 |
+
"metadata": {},
|
| 404 |
+
"outputs": [],
|
| 405 |
+
"source": [
|
| 406 |
+
"# --- REINFORCE Training Loop ---\n",
|
| 407 |
+
"BATCH_SIZE = 4\n",
|
| 408 |
+
"NUM_STEPS = 50\n",
|
| 409 |
+
"\n",
|
| 410 |
+
"print(f\"Training for {NUM_STEPS} steps, batch size {BATCH_SIZE}\")\n",
|
| 411 |
+
"print(\"=\" * 60)\n",
|
| 412 |
+
"\n",
|
| 413 |
+
"history = {\"loss\": [], \"reward_mean\": [], \"reward_std\": [], \"grad_norm\": []}\n",
|
| 414 |
+
"\n",
|
| 415 |
+
"for step in range(1, NUM_STEPS + 1):\n",
|
| 416 |
+
" # --- Rollout (no gradients) ---\n",
|
| 417 |
+
" model.eval()\n",
|
| 418 |
+
" episodes = []\n",
|
| 419 |
+
" for i in range(BATCH_SIZE):\n",
|
| 420 |
+
" ep = rollout_episode(model, tokenizer, device)\n",
|
| 421 |
+
" if ep and ep[\"gen_ids\"]:\n",
|
| 422 |
+
" episodes.append(ep)\n",
|
| 423 |
+
"\n",
|
| 424 |
+
" if len(episodes) < 2:\n",
|
| 425 |
+
" print(f\"Step {step}: not enough episodes, skipping\")\n",
|
| 426 |
+
" continue\n",
|
| 427 |
+
"\n",
|
| 428 |
+
" # --- Batch-level advantages ---\n",
|
| 429 |
+
" rewards = [ep[\"reward\"] for ep in episodes]\n",
|
| 430 |
+
" mean_r = sum(rewards) / len(rewards)\n",
|
| 431 |
+
" std_r = max(torch.tensor(rewards).std().item(), 1e-4)\n",
|
| 432 |
+
" advantages = [(r - mean_r) / std_r for r in rewards]\n",
|
| 433 |
+
"\n",
|
| 434 |
+
" # --- REINFORCE update ---\n",
|
| 435 |
+
" model.train()\n",
|
| 436 |
+
" optimizer.zero_grad()\n",
|
| 437 |
+
" total_loss = 0.0\n",
|
| 438 |
+
"\n",
|
| 439 |
+
" for ep, adv in zip(episodes, advantages):\n",
|
| 440 |
+
" input_ids = torch.tensor(\n",
|
| 441 |
+
" [ep[\"prompt_ids\"] + ep[\"gen_ids\"]], device=device\n",
|
| 442 |
+
" )\n",
|
| 443 |
+
" prompt_len = len(ep[\"prompt_ids\"])\n",
|
| 444 |
+
" comp_len = len(ep[\"gen_ids\"])\n",
|
| 445 |
+
" if comp_len == 0:\n",
|
| 446 |
+
" continue\n",
|
| 447 |
+
"\n",
|
| 448 |
+
" outputs = model(input_ids)\n",
|
| 449 |
+
" logits = outputs.logits[0, prompt_len - 1 : prompt_len + comp_len - 1]\n",
|
| 450 |
+
" targets = input_ids[0, prompt_len : prompt_len + comp_len]\n",
|
| 451 |
+
" log_probs = F.log_softmax(logits, dim=-1)\n",
|
| 452 |
+
" token_lps = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)\n",
|
| 453 |
+
"\n",
|
| 454 |
+
" loss = -(adv * token_lps.sum()) / len(episodes)\n",
|
| 455 |
+
" loss.backward()\n",
|
| 456 |
+
" total_loss += loss.item()\n",
|
| 457 |
+
"\n",
|
| 458 |
+
" grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()\n",
|
| 459 |
+
" optimizer.step()\n",
|
| 460 |
+
"\n",
|
| 461 |
+
" # --- Log ---\n",
|
| 462 |
+
" history[\"loss\"].append(total_loss)\n",
|
| 463 |
+
" history[\"reward_mean\"].append(mean_r)\n",
|
| 464 |
+
" history[\"reward_std\"].append(std_r)\n",
|
| 465 |
+
" history[\"grad_norm\"].append(grad_norm)\n",
|
| 466 |
+
"\n",
|
| 467 |
+
" topics = [ep[\"topic\"] for ep in episodes]\n",
|
| 468 |
+
" print(f\"Step {step:3d} | loss={total_loss:+.3f} | reward={mean_r:+.1f}±{std_r:.1f} | \"\n",
|
| 469 |
+
" f\"grad={grad_norm:.3f} | topics={topics}\")\n",
|
| 470 |
+
"\n",
|
| 471 |
+
"print(\"\\nTraining complete!\")"
|
| 472 |
+
]
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"cell_type": "markdown",
|
| 476 |
+
"metadata": {},
|
| 477 |
+
"source": [
|
| 478 |
+
"## 7. Plot Training Curves"
|
| 479 |
+
]
|
| 480 |
+
},
|
| 481 |
+
{
|
| 482 |
+
"cell_type": "code",
|
| 483 |
+
"execution_count": null,
|
| 484 |
+
"metadata": {},
|
| 485 |
+
"outputs": [],
|
| 486 |
+
"source": [
|
| 487 |
+
"import matplotlib.pyplot as plt\n",
|
| 488 |
+
"\n",
|
| 489 |
+
"fig, axes = plt.subplots(2, 2, figsize=(12, 8))\n",
|
| 490 |
+
"fig.suptitle(\"Driver Recruit RL Training\", fontsize=14)\n",
|
| 491 |
+
"\n",
|
| 492 |
+
"axes[0, 0].plot(history[\"loss\"])\n",
|
| 493 |
+
"axes[0, 0].set_title(\"Loss\")\n",
|
| 494 |
+
"axes[0, 0].set_xlabel(\"Step\")\n",
|
| 495 |
+
"\n",
|
| 496 |
+
"axes[0, 1].plot(history[\"reward_mean\"])\n",
|
| 497 |
+
"axes[0, 1].set_title(\"Mean Reward\")\n",
|
| 498 |
+
"axes[0, 1].set_xlabel(\"Step\")\n",
|
| 499 |
+
"\n",
|
| 500 |
+
"axes[1, 0].plot(history[\"reward_std\"])\n",
|
| 501 |
+
"axes[1, 0].set_title(\"Reward Std\")\n",
|
| 502 |
+
"axes[1, 0].set_xlabel(\"Step\")\n",
|
| 503 |
+
"\n",
|
| 504 |
+
"axes[1, 1].plot(history[\"grad_norm\"])\n",
|
| 505 |
+
"axes[1, 1].set_title(\"Gradient Norm\")\n",
|
| 506 |
+
"axes[1, 1].set_xlabel(\"Step\")\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"plt.tight_layout()\n",
|
| 509 |
+
"plt.show()"
|
| 510 |
+
]
|
| 511 |
+
},
|
| 512 |
+
{
|
| 513 |
+
"cell_type": "markdown",
|
| 514 |
+
"metadata": {},
|
| 515 |
+
"source": [
|
| 516 |
+
"## 8. Test the Trained Model"
|
| 517 |
+
]
|
| 518 |
+
},
|
| 519 |
+
{
|
| 520 |
+
"cell_type": "code",
|
| 521 |
+
"execution_count": null,
|
| 522 |
+
"metadata": {},
|
| 523 |
+
"outputs": [],
|
| 524 |
+
"source": [
|
| 525 |
+
"print(\"=== Testing trained model ===\")\n",
|
| 526 |
+
"model.eval()\n",
|
| 527 |
+
"test_rewards = []\n",
|
| 528 |
+
"for i in range(5):\n",
|
| 529 |
+
" ep = rollout_episode(model, tokenizer, device)\n",
|
| 530 |
+
" if ep:\n",
|
| 531 |
+
" test_rewards.append(ep[\"reward\"])\n",
|
| 532 |
+
" print(f\" Episode {i+1}: reward={ep['reward']:.1f}, stage={ep['stage']}, \"\n",
|
| 533 |
+
" f\"topics={ep['num_topics']}, chose={ep['topic']}\")\n",
|
| 534 |
+
"\n",
|
| 535 |
+
"if test_rewards:\n",
|
| 536 |
+
" print(f\"\\nMean test reward: {sum(test_rewards)/len(test_rewards):.1f}\")"
|
| 537 |
+
]
|
| 538 |
+
}
|
| 539 |
+
],
|
| 540 |
+
"metadata": {
|
| 541 |
+
"kernelspec": {
|
| 542 |
+
"display_name": "Python 3",
|
| 543 |
+
"language": "python",
|
| 544 |
+
"name": "python3"
|
| 545 |
+
},
|
| 546 |
+
"language_info": {
|
| 547 |
+
"name": "python",
|
| 548 |
+
"version": "3.10.0"
|
| 549 |
+
},
|
| 550 |
+
"accelerator": "GPU",
|
| 551 |
+
"colab": {
|
| 552 |
+
"gpuType": "T4",
|
| 553 |
+
"provenance": []
|
| 554 |
+
}
|
| 555 |
+
},
|
| 556 |
+
"nbformat": 4,
|
| 557 |
+
"nbformat_minor": 4
|
| 558 |
+
}
|
train_grpo.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GRPO training script for the Driver Recruit Environment.
|
| 3 |
+
|
| 4 |
+
Uses TRL's GRPOTrainer with rollout_func for multi-turn episodes.
|
| 5 |
+
The model controls EVERY action in the episode via tool calls.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python train_grpo.py --model Qwen/Qwen2.5-3B-Instruct --use-qlora
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import random
|
| 14 |
+
|
| 15 |
+
from datasets import Dataset
|
| 16 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
|
| 20 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 21 |
+
from trl.experimental.openenv import generate_rollout_completions
|
| 22 |
+
|
| 23 |
+
# --- Prompt templates ---
|
| 24 |
+
|
| 25 |
+
SYSTEM_PROMPT = """You are a truck driver recruiter using a CRM system. You only know the driver's name. You must discover their qualifications through conversation, record info in the CRM, get approval, and hire them.
|
| 26 |
+
|
| 27 |
+
You have 4 tools:
|
| 28 |
+
|
| 29 |
+
## crm
|
| 30 |
+
- read_candidate: Read the current CRM record
|
| 31 |
+
- update_stage: Advance pipeline (contacted → interested → approval_pending → offer_sent → hired)
|
| 32 |
+
- update_field: Record info (field + value)
|
| 33 |
+
- add_note: Add a free-text note
|
| 34 |
+
|
| 35 |
+
## messaging
|
| 36 |
+
- send_message: Send a message (topic: greeting, call, experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references, pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern)
|
| 37 |
+
- read_reply: Read the driver's response
|
| 38 |
+
|
| 39 |
+
## approval
|
| 40 |
+
- request_approval: Request approval for a job (needs job_id)
|
| 41 |
+
- check_approval: Check approval status
|
| 42 |
+
|
| 43 |
+
## workflow
|
| 44 |
+
- wait: Advance time (needed for approval processing)
|
| 45 |
+
|
| 46 |
+
## Rules
|
| 47 |
+
- Must read CRM before messaging
|
| 48 |
+
- Must read_reply before sending another message
|
| 49 |
+
- Must request_approval and wait before sending offer
|
| 50 |
+
- Must follow stage order: lead → contacted → interested → approval_pending → offer_sent → hired
|
| 51 |
+
- Record important info in CRM with update_field
|
| 52 |
+
- Too many messages hurt trust
|
| 53 |
+
|
| 54 |
+
## Workflow
|
| 55 |
+
1. crm.read_candidate
|
| 56 |
+
2. messaging.send_message (greeting/call) → read_reply → update_stage(contacted)
|
| 57 |
+
3. messaging.send_message (screening topics) → read_reply → crm.update_field
|
| 58 |
+
4. crm.update_stage(interested)
|
| 59 |
+
5. approval.request_approval → workflow.wait → approval.check_approval
|
| 60 |
+
6. crm.update_stage(approval_pending)
|
| 61 |
+
7. messaging.send_message(offer) → read_reply
|
| 62 |
+
8. crm.update_stage(offer_sent) → crm.update_stage(hired)
|
| 63 |
+
|
| 64 |
+
Respond with ONLY JSON:
|
| 65 |
+
{"tool": "crm", "action": "read_candidate"}
|
| 66 |
+
{"tool": "messaging", "action": "send_message", "topic": "experience"}
|
| 67 |
+
{"tool": "messaging", "action": "read_reply"}
|
| 68 |
+
{"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
|
| 69 |
+
{"tool": "crm", "action": "update_stage", "stage": "contacted"}
|
| 70 |
+
{"tool": "approval", "action": "request_approval", "job_id": 2}
|
| 71 |
+
{"tool": "workflow", "action": "wait"}
|
| 72 |
+
{"tool": "approval", "action": "check_approval"}
|
| 73 |
+
{"tool": "messaging", "action": "send_message", "topic": "offer", "job_id": 2}
|
| 74 |
+
{"tool": "crm", "action": "update_stage", "stage": "hired"}"""
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def format_observation(obs):
|
| 78 |
+
"""Format observation into a user prompt for the LLM."""
|
| 79 |
+
parts = [f"Driver: {obs.driver_name}"]
|
| 80 |
+
if obs.crm_summary:
|
| 81 |
+
parts.append(f"CRM:\n{obs.crm_summary}")
|
| 82 |
+
if obs.jobs_summary:
|
| 83 |
+
parts.append(f"Jobs:\n{obs.jobs_summary}")
|
| 84 |
+
if obs.discovered_info:
|
| 85 |
+
parts.append(f"Discovered:\n{obs.discovered_info}")
|
| 86 |
+
status = f"Stage: {obs.stage}"
|
| 87 |
+
if obs.pending_reply:
|
| 88 |
+
status += " | PENDING REPLY"
|
| 89 |
+
parts.append(status)
|
| 90 |
+
if obs.feedback:
|
| 91 |
+
parts.append(f"Result: {obs.feedback}")
|
| 92 |
+
return "\n".join(parts)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def format_observation_compact(obs):
|
| 96 |
+
"""Compact observation for embedding in completion_ids (~30-60 tokens)."""
|
| 97 |
+
parts = [f"Stage: {obs.stage}"]
|
| 98 |
+
if obs.pending_reply:
|
| 99 |
+
parts.append("PENDING REPLY")
|
| 100 |
+
if obs.feedback:
|
| 101 |
+
parts.append(obs.feedback[:200])
|
| 102 |
+
if obs.discovered_info:
|
| 103 |
+
parts.append(obs.discovered_info[:200])
|
| 104 |
+
return "\n".join(parts)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def parse_action(text):
|
| 108 |
+
"""Parse LLM output into a RecruitopenenvAction."""
|
| 109 |
+
text = text.strip()
|
| 110 |
+
|
| 111 |
+
# Remove markdown fences
|
| 112 |
+
if "```" in text:
|
| 113 |
+
for part in text.split("```"):
|
| 114 |
+
part = part.strip()
|
| 115 |
+
if part.startswith("json"):
|
| 116 |
+
part = part[4:].strip()
|
| 117 |
+
if part.startswith("{"):
|
| 118 |
+
text = part
|
| 119 |
+
break
|
| 120 |
+
|
| 121 |
+
# Try JSON
|
| 122 |
+
try:
|
| 123 |
+
data = json.loads(text)
|
| 124 |
+
if isinstance(data, list):
|
| 125 |
+
data = data[0] if data else {}
|
| 126 |
+
if isinstance(data, dict) and "tool" in data and "action" in data:
|
| 127 |
+
return RecruitopenenvAction(
|
| 128 |
+
tool=data["tool"],
|
| 129 |
+
action=data["action"],
|
| 130 |
+
topic=data.get("topic", ""),
|
| 131 |
+
job_id=int(data.get("job_id", -1)),
|
| 132 |
+
stage=str(data.get("stage", "")),
|
| 133 |
+
field=str(data.get("field", "")),
|
| 134 |
+
value=str(data.get("value", "")),
|
| 135 |
+
)
|
| 136 |
+
except (json.JSONDecodeError, KeyError, IndexError, ValueError, TypeError):
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
# Fallback: try to detect intent
|
| 140 |
+
text_lower = text.lower()
|
| 141 |
+
if "read_candidate" in text_lower:
|
| 142 |
+
return RecruitopenenvAction(tool="crm", action="read_candidate")
|
| 143 |
+
if "read_reply" in text_lower:
|
| 144 |
+
return RecruitopenenvAction(tool="messaging", action="read_reply")
|
| 145 |
+
if "check_approval" in text_lower:
|
| 146 |
+
return RecruitopenenvAction(tool="approval", action="check_approval")
|
| 147 |
+
if "wait" in text_lower:
|
| 148 |
+
return RecruitopenenvAction(tool="workflow", action="wait")
|
| 149 |
+
|
| 150 |
+
# Default to reading CRM
|
| 151 |
+
return RecruitopenenvAction(tool="crm", action="read_candidate")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# --- Multi-turn rollout ---
|
| 155 |
+
|
| 156 |
+
ENV_URL = "http://localhost:8001"
|
| 157 |
+
MAX_COMPLETION_TOKENS = 1536
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _build_chat_transition(tokenizer, obs_text):
|
| 161 |
+
"""Build chat-formatted transition tokens: end assistant turn, user obs, start assistant.
|
| 162 |
+
|
| 163 |
+
Result: <|im_end|>\n<|im_start|>user\n{obs}<|im_end|>\n<|im_start|>assistant\n
|
| 164 |
+
This ensures the model sees proper chat structure during the forward pass.
|
| 165 |
+
"""
|
| 166 |
+
im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
|
| 167 |
+
im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 168 |
+
|
| 169 |
+
# Encode role tags and newlines
|
| 170 |
+
nl = tokenizer.encode("\n", add_special_tokens=False)
|
| 171 |
+
user_tag = tokenizer.encode("user", add_special_tokens=False)
|
| 172 |
+
asst_tag = tokenizer.encode("assistant", add_special_tokens=False)
|
| 173 |
+
obs_ids = tokenizer.encode(obs_text, add_special_tokens=False)[:60]
|
| 174 |
+
|
| 175 |
+
# <|im_end|>\n<|im_start|>user\n{obs}<|im_end|>\n<|im_start|>assistant\n
|
| 176 |
+
return (
|
| 177 |
+
[im_end] + nl +
|
| 178 |
+
[im_start] + user_tag + nl +
|
| 179 |
+
obs_ids +
|
| 180 |
+
[im_end] + nl +
|
| 181 |
+
[im_start] + asst_tag + nl
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def rollout_once(trainer, env, tokenizer, prompt_text, system_prompt, max_turns=15):
|
| 186 |
+
"""Run one multi-turn episode with chat-formatted transitions.
|
| 187 |
+
|
| 188 |
+
completion_ids: [action1, <|im_end|>user obs<|im_start|>assistant, action2, ...]
|
| 189 |
+
The chat template structure lets the forward pass assign proper logprobs.
|
| 190 |
+
"""
|
| 191 |
+
seed = random.randint(0, 2**31 - 1)
|
| 192 |
+
result = env.reset(seed=seed)
|
| 193 |
+
obs = result.observation
|
| 194 |
+
|
| 195 |
+
prompt_ids = []
|
| 196 |
+
completion_ids = []
|
| 197 |
+
logprobs = []
|
| 198 |
+
env_mask = []
|
| 199 |
+
total_reward = 0.0
|
| 200 |
+
steps = 0
|
| 201 |
+
|
| 202 |
+
messages = [
|
| 203 |
+
{"role": "system", "content": system_prompt},
|
| 204 |
+
{"role": "user", "content": format_observation(obs)},
|
| 205 |
+
]
|
| 206 |
+
|
| 207 |
+
while not result.done and steps < max_turns:
|
| 208 |
+
# Check if we're near the token budget (need room for action + transition)
|
| 209 |
+
if len(completion_ids) > MAX_COMPLETION_TOKENS - 60:
|
| 210 |
+
break
|
| 211 |
+
|
| 212 |
+
current_prompt = tokenizer.apply_chat_template(
|
| 213 |
+
messages, add_generation_prompt=True, tokenize=False
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
rollout_outputs = generate_rollout_completions(trainer, [current_prompt])[0]
|
| 217 |
+
|
| 218 |
+
if steps == 0:
|
| 219 |
+
prompt_ids = list(rollout_outputs["prompt_ids"])
|
| 220 |
+
|
| 221 |
+
action_ids = list(rollout_outputs["completion_ids"])
|
| 222 |
+
action_logprobs = list(rollout_outputs["logprobs"])
|
| 223 |
+
|
| 224 |
+
# Add action tokens (these get gradients)
|
| 225 |
+
completion_ids.extend(action_ids)
|
| 226 |
+
logprobs.extend(action_logprobs)
|
| 227 |
+
env_mask.extend([1] * len(action_ids))
|
| 228 |
+
|
| 229 |
+
response = rollout_outputs.get("text") or tokenizer.decode(
|
| 230 |
+
action_ids, skip_special_tokens=True
|
| 231 |
+
)
|
| 232 |
+
messages.append({"role": "assistant", "content": response})
|
| 233 |
+
|
| 234 |
+
action = parse_action(response)
|
| 235 |
+
result = env.step(action)
|
| 236 |
+
obs = result.observation
|
| 237 |
+
total_reward += result.reward
|
| 238 |
+
steps += 1
|
| 239 |
+
|
| 240 |
+
if not result.done:
|
| 241 |
+
# Build chat-formatted transition so forward pass sees proper structure
|
| 242 |
+
obs_text = format_observation_compact(obs)
|
| 243 |
+
transition_ids = _build_chat_transition(tokenizer, obs_text)
|
| 244 |
+
|
| 245 |
+
completion_ids.extend(transition_ids)
|
| 246 |
+
logprobs.extend([0.0] * len(transition_ids))
|
| 247 |
+
env_mask.extend([0] * len(transition_ids))
|
| 248 |
+
|
| 249 |
+
messages.append({"role": "user", "content": format_observation(obs)})
|
| 250 |
+
|
| 251 |
+
# Truncate to fit max_completion_length
|
| 252 |
+
completion_ids = completion_ids[:MAX_COMPLETION_TOKENS]
|
| 253 |
+
logprobs = logprobs[:MAX_COMPLETION_TOKENS]
|
| 254 |
+
env_mask = env_mask[:MAX_COMPLETION_TOKENS]
|
| 255 |
+
|
| 256 |
+
return {
|
| 257 |
+
"prompt_ids": prompt_ids,
|
| 258 |
+
"completion_ids": completion_ids,
|
| 259 |
+
"logprobs": logprobs,
|
| 260 |
+
"env_mask": env_mask,
|
| 261 |
+
"env_reward": total_reward,
|
| 262 |
+
"steps": steps,
|
| 263 |
+
"final_stage": obs.stage,
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def rollout_func(prompts, trainer):
|
| 268 |
+
"""Multi-turn rollout: model controls every action in the episode."""
|
| 269 |
+
tokenizer = trainer.processing_class
|
| 270 |
+
env = RecruitopenenvEnv(base_url=ENV_URL)
|
| 271 |
+
|
| 272 |
+
all_prompt_ids = []
|
| 273 |
+
all_completion_ids = []
|
| 274 |
+
all_logprobs = []
|
| 275 |
+
all_env_rewards = []
|
| 276 |
+
all_env_mask = []
|
| 277 |
+
|
| 278 |
+
for prompt_text in prompts:
|
| 279 |
+
episode = rollout_once(trainer, env, tokenizer, prompt_text, SYSTEM_PROMPT)
|
| 280 |
+
|
| 281 |
+
if episode["completion_ids"]:
|
| 282 |
+
all_prompt_ids.append(episode["prompt_ids"])
|
| 283 |
+
all_completion_ids.append(episode["completion_ids"])
|
| 284 |
+
all_logprobs.append(episode["logprobs"])
|
| 285 |
+
all_env_mask.append(episode["env_mask"])
|
| 286 |
+
else:
|
| 287 |
+
tok_ids = tokenizer.encode("wait", add_special_tokens=False)
|
| 288 |
+
all_prompt_ids.append(episode["prompt_ids"] or tok_ids)
|
| 289 |
+
all_completion_ids.append(tok_ids)
|
| 290 |
+
all_logprobs.append([0.0] * len(tok_ids))
|
| 291 |
+
all_env_mask.append([1] * len(tok_ids))
|
| 292 |
+
|
| 293 |
+
all_env_rewards.append(episode["env_reward"])
|
| 294 |
+
print(f" Episode {len(all_env_rewards)}: reward={episode['env_reward']:.1f}, "
|
| 295 |
+
f"steps={episode['steps']}, stage={episode['final_stage']}")
|
| 296 |
+
|
| 297 |
+
env.close()
|
| 298 |
+
|
| 299 |
+
mean_r = sum(all_env_rewards) / len(all_env_rewards)
|
| 300 |
+
std_r = torch.tensor(all_env_rewards).std().item()
|
| 301 |
+
print(f"Rollout done: {len(all_env_rewards)} episodes, mean_reward={mean_r:.2f}, std={std_r:.2f}")
|
| 302 |
+
|
| 303 |
+
return {
|
| 304 |
+
"prompt_ids": all_prompt_ids,
|
| 305 |
+
"completion_ids": all_completion_ids,
|
| 306 |
+
"logprobs": [[(lp,) for lp in seq] for seq in all_logprobs],
|
| 307 |
+
"env_reward": all_env_rewards,
|
| 308 |
+
"env_mask": all_env_mask,
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# --- Reward function (fallback, rewards come from rollout) ---
|
| 313 |
+
|
| 314 |
+
def reward_total(completions, **kwargs):
|
| 315 |
+
"""Extract environment rewards passed via rollout_func kwargs."""
|
| 316 |
+
env_rewards = kwargs.get("env_reward", [])
|
| 317 |
+
if env_rewards:
|
| 318 |
+
return [float(r) for r in env_rewards]
|
| 319 |
+
return [0.0] * len(completions)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# --- Main ---
|
| 323 |
+
|
| 324 |
+
def _build_parser():
    """Build the CLI argument parser (one flag per training knob)."""
    parser = argparse.ArgumentParser(description="GRPO training for Driver Recruit Environment")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct", help="Model to train")
    parser.add_argument("--env-url", default="http://localhost:8001", help="Environment server URL")
    parser.add_argument("--num-episodes", type=int, default=16, help="Number of training episodes (dataset size)")
    parser.add_argument("--num-generations", type=int, default=4, help="GRPO generations per prompt")
    parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--output-dir", default="./recruit-grpo-output", help="Output directory")
    parser.add_argument("--vllm-mode", default="colocate", choices=["colocate", "server"],
                        help="vLLM mode: colocate (1 GPU) or server (2+ GPUs)")
    parser.add_argument("--use-qlora", action="store_true", help="Use QLoRA (4-bit) for memory efficiency")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
    return parser


def _collect_prompts(args, tokenizer):
    """Reset the environment once per episode and render each initial
    observation into a chat-template-formatted prompt string.

    Returns a list of ``args.num_episodes`` prompt strings.
    """
    prompts = []
    env = RecruitopenenvEnv(base_url=args.env_url)
    try:
        # Index is unused: every reset() draws a fresh episode from the server.
        for _ in range(args.num_episodes):
            result = env.reset()
            user_prompt = format_observation(result.observation)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ]
            prompts.append(tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=False
            ))
    finally:
        # Robustness: release the env client even if a reset() call raises.
        env.close()
    return prompts


def _qlora_setup(args):
    """Return ``(peft_config, model_init_kwargs)`` for optional 4-bit QLoRA.

    When ``--use-qlora`` is off, returns ``(None, {})`` so the caller can
    skip both the PEFT adapter and the quantization config.
    """
    if not args.use_qlora:
        return None, {}
    # Local import: peft is only required when QLoRA is actually requested.
    from peft import LoraConfig
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    model_kwargs = {
        "quantization_config": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
    }
    print(f"Using QLoRA: r={args.lora_r}, alpha={args.lora_alpha}, 4-bit")
    return peft_config, model_kwargs


def main():
    """CLI entry point: configure and launch GRPO training on the recruit env.

    Collects one initial prompt per episode from the environment server,
    wires the custom rollout function and the pass-through reward into
    TRL's GRPOTrainer, optionally enables QLoRA, trains, and saves the
    final model to ``--output-dir``.
    """
    args = _build_parser().parse_args()

    # rollout_func reads the environment URL from this module-level global.
    global ENV_URL
    ENV_URL = args.env_url

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    prompts = _collect_prompts(args, tokenizer)
    dataset = Dataset.from_dict({"prompt": prompts})

    peft_config, model_kwargs = _qlora_setup(args)

    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        use_vllm=True,
        vllm_mode=args.vllm_mode,
        num_train_epochs=args.epochs,
        num_generations=args.num_generations,
        max_completion_length=1536,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=args.lr,
        temperature=0.7,
        logging_steps=1,
        save_steps=50,
        bf16=True,
        report_to="wandb",
        run_name="recruit-grpo-tools",
        model_init_kwargs=model_kwargs if model_kwargs else None,
    )

    trainer_kwargs = dict(
        model=args.model,
        processing_class=tokenizer,
        reward_funcs=[reward_total],
        train_dataset=dataset,
        args=grpo_config,
        rollout_func=rollout_func,
    )
    if peft_config is not None:
        trainer_kwargs["peft_config"] = peft_config

    trainer = GRPOTrainer(**trainer_kwargs)

    print("=" * 50)
    print(f"Training {args.model} (TOOL-BASED MULTI-TURN)")
    print(f"Environment: {args.env_url}")
    print(f"QLoRA: {args.use_qlora}")
    print(f"Episodes: {args.num_episodes}")
    print(f"Epochs: {args.epochs}")
    print(f"Generations per prompt: {args.num_generations}")
    print("=" * 50)

    trainer.train()
    trainer.save_model(args.output_dir)
    print(f"\nModel saved to {args.output_dir}")


if __name__ == "__main__":
    main()
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|