bathientran committed
Commit be37527 · verified · Parent: e8d8505

Upload folder using huggingface_hub
Dockerfile ADDED
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Multi-stage build using openenv-base.
# This Dockerfile is flexible and works for both:
# - In-repo environments (with local OpenEnv sources)
# - Standalone environments (with openenv from PyPI/Git)
# The build script (openenv build) handles context detection and sets appropriate build args.

ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# Ensure git is available (required for installing dependencies from VCS)
RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

# Build argument to control whether we're building standalone or in-repo
ARG BUILD_MODE=in-repo
ARG ENV_NAME=recruitopenenv

# Copy environment code (always at root of build context)
COPY . /app/env

# For in-repo builds, openenv is already vendored in the build context.
# For standalone builds, openenv is installed via pyproject.toml.
WORKDIR /app/env

# Ensure uv is available (for local builds where the base image lacks it)
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | sh && \
        mv /root/.local/bin/uv /usr/local/bin/uv && \
        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Install dependencies using uv sync.
# If uv.lock exists, use it; otherwise resolve on the fly.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

# Final runtime stage
FROM ${BASE_IMAGE}

WORKDIR /app

# Copy the virtual environment from the builder
COPY --from=builder /app/env/.venv /app/.venv

# Copy the environment code
COPY --from=builder /app/env /app/env

# Set PATH to use the virtual environment
ENV PATH="/app/.venv/bin:$PATH"

# Set PYTHONPATH so imports work correctly
ENV PYTHONPATH="/app/env:$PYTHONPATH"

# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the FastAPI server.
# The module path is constructed to work with the /app/env structure.
ENV ENABLE_WEB_INTERFACE=true
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
Previous content:

---
title: Recruitopenenv
emoji: 🏢
colorFrom: gray
colorTo: yellow
sdk: docker
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

Updated content:
---
title: Driver Recruit Environment
emoji: 🚛
colorFrom: blue
colorTo: green
sdk: docker
pinned: false
app_port: 8000
base_path: /web
tags:
  - openenv
  - reinforcement-learning
  - recruiting
  - multi-turn
---
# 🚛 Driver Recruit Environment

A **multi-turn, tool-based RL environment** for training LLMs to recruit truck drivers through a CRM system. Built on [OpenEnv 0.2.1](https://github.com/meta-pytorch/OpenEnv).

The agent must discover the driver's qualifications through conversation, record the info in the CRM, get management approval, and hire — all using structured tool calls across 15-40+ step episodes.

## Pipeline

```
lead → contacted → interested → approval_pending → offer_sent → hired
```

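The environment enforces this stage order. As a rough illustration only (the helper name and strictness are our assumptions, not the environment's API), a one-step-forward check could look like:

```python
# Pipeline stages in order, as listed above.
STAGES = ["lead", "contacted", "interested", "approval_pending", "offer_sent", "hired"]

def is_valid_transition(current: str, nxt: str) -> bool:
    """Hypothetical check: a stage update must advance exactly one step."""
    if current not in STAGES or nxt not in STAGES:
        return False
    return STAGES.index(nxt) == STAGES.index(current) + 1
```

Skipping ahead (e.g. `lead` straight to `hired`) would fail such a check, which matches the "must follow stage order" rule below.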
## Tools

| Tool | Actions | Purpose |
|------|---------|---------|
| **crm** | `read_candidate`, `update_stage`, `update_field`, `add_note` | Manage pipeline & record info |
| **messaging** | `send_message`, `read_reply` | Screen the driver (18 topics) |
| **approval** | `request_approval`, `check_approval` | Get management sign-off |
| **workflow** | `wait` | Advance time for approval processing |

## Reward Signal

- **Successful hire** (good job fit): **+10** to **+15** (base + CRM bonus)
- **Bad hire** (poor match): **-5**
- **Ghosted** (trust runs out): **-4**
- **Per-step**: small rewards/penalties for correct/incorrect actions

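As a minimal sketch of the terminal rewards above (the outcome labels and the bonus cap are our assumptions; per-step shaping is omitted):

```python
def terminal_reward(outcome: str, crm_bonus: float = 0.0) -> float:
    """Hypothetical terminal-reward table mirroring the list above."""
    if outcome == "good_hire":
        # +10 base, up to +15 with the CRM bonus
        return 10.0 + min(max(crm_bonus, 0.0), 5.0)
    if outcome == "bad_hire":
        return -5.0
    if outcome == "ghosted":
        return -4.0
    return 0.0
```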
## What Makes This Hard

- **Long horizon**: 15-40+ tool calls per episode
- **Information gathering**: must ask the right screening questions to match the driver to the right job
- **Trust dynamics**: each message costs trust — ask too many questions and the driver ghosts
- **Job matching**: 6 jobs per episode (1-2 good, 1-2 traps with deal-breakers, 2-3 partial)
- **Procedural correctness**: must follow the stage order, read replies before messaging again, and get approval before making an offer

## Quick Start

```python
from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction

env = RecruitopenenvEnv(base_url="YOUR_SPACE_URL")

result = env.reset(seed=42)
obs = result.observation
print(f"Driver: {obs.driver_name}, Stage: {obs.stage}")

# Read CRM
result = env.step(RecruitopenenvAction(tool="crm", action="read_candidate"))
print(result.observation.jobs_summary)

# Greet driver
result = env.step(RecruitopenenvAction(tool="messaging", action="send_message", topic="greeting"))
print(f"Reward: {result.reward}")

# Read reply
result = env.step(RecruitopenenvAction(tool="messaging", action="read_reply"))
print(result.observation.discovered_info)

env.close()
```

## Training

We train with GRPO/REINFORCE, with the model choosing screening topics. See `train_grpo.py` for the full training script.

```bash
python train_grpo.py --model Qwen/Qwen2.5-3B-Instruct
```

## Deploying

```bash
# From the recruitopenenv/ directory
openenv push
```

## Action Format

```json
{"tool": "crm", "action": "read_candidate"}
{"tool": "messaging", "action": "send_message", "topic": "experience"}
{"tool": "messaging", "action": "read_reply"}
{"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
{"tool": "crm", "action": "update_stage", "stage": "contacted"}
{"tool": "approval", "action": "request_approval", "job_id": 2}
{"tool": "workflow", "action": "wait"}
{"tool": "approval", "action": "check_approval"}
{"tool": "messaging", "action": "send_message", "topic": "offer", "job_id": 2}
{"tool": "crm", "action": "update_stage", "stage": "hired"}
```

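Before sending actions like these, it can help to sanity-check them client-side. A hedged sketch (the required-key table is our reading of the examples, not a documented schema):

```python
import json

# Extra keys required by specific (tool, action) pairs,
# inferred from the action examples above. This is an assumption.
REQUIRED_KEYS = {
    ("messaging", "send_message"): {"topic"},
    ("approval", "request_approval"): {"job_id"},
    ("crm", "update_stage"): {"stage"},
    ("crm", "update_field"): {"field", "value"},
}

def parse_action_line(line: str) -> dict:
    """Parse one JSON action line and verify the required keys are present."""
    data = json.loads(line)
    for key in ("tool", "action"):
        if key not in data:
            raise ValueError(f"missing '{key}'")
    missing = REQUIRED_KEYS.get((data["tool"], data["action"]), set()) - data.keys()
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    return data
```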
## Observation Fields

| Field | Description |
|-------|-------------|
| `driver_name` | Driver's name |
| `crm_summary` | Full CRM record (empty until `read_candidate`) |
| `jobs_summary` | 6 available job listings |
| `discovered_info` | Info from screening conversations |
| `stage` | Current pipeline stage |
| `feedback` | API response from the last action |
| `pending_reply` | Whether the driver has an unread message |

## Screening Topics

`greeting`, `call`, `experience`, `home_time`, `pay`, `equipment`, `route`, `deal_breakers`, `availability`, `violations`, `medical_card`, `references`, `pitch`, `offer`, `negotiate_pay`, `negotiate_home_time`, `signing_bonus`, `address_concern`
__init__.py ADDED
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Recruitopenenv Environment."""

from .client import RecruitopenenvEnv
from .models import RecruitopenenvAction, RecruitopenenvObservation

__all__ = [
    "RecruitopenenvAction",
    "RecruitopenenvObservation",
    "RecruitopenenvEnv",
]
baseline_llm.py ADDED
"""LLM agent baseline — test how well a base model performs without RL training."""

import argparse
import json

import requests

from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction

SYSTEM_PROMPT = """You are a truck driver recruiter using a CRM system. You only know the driver's name. You must discover their qualifications through conversation, record info in the CRM, get approval, and hire them.

You have 4 tools:

## crm
- read_candidate: Read the current CRM record
- update_stage: Advance pipeline (contacted → interested → approval_pending → offer_sent → hired)
- update_field: Record info (field + value)
- add_note: Add a free-text note

## messaging
- send_message: Send a message (topic: greeting, call, experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references, pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern)
- read_reply: Read the driver's response

## approval
- request_approval: Request approval for a job (needs job_id)
- check_approval: Check approval status

## workflow
- wait: Advance time (needed for approval processing)

## Rules
- Must read CRM before messaging
- Must read_reply before sending another message
- Must request_approval and wait before sending offer
- Must follow stage order: lead → contacted → interested → approval_pending → offer_sent → hired
- Record important info in CRM with update_field
- Too many messages hurt trust

## Strategy
1. crm.read_candidate → see the lead
2. messaging.send_message(greeting or call) → messaging.read_reply → crm.update_stage(contacted)
3. Screen: send_message(experience) → read_reply → update_field(cdl_class, value) ... repeat for key questions
4. crm.update_stage(interested)
5. approval.request_approval(job_id) → workflow.wait → approval.check_approval
6. crm.update_stage(approval_pending)
7. messaging.send_message(offer) → messaging.read_reply
8. crm.update_stage(offer_sent) → crm.update_stage(hired)

Tips:
- The experience topic is critical (CDL class filters jobs)
- The deal_breakers topic helps avoid trap jobs
- The violations and medical_card topics reveal fatal blockers
- If the driver has concerns about the offer, use negotiate_pay/negotiate_home_time/address_concern
- If no good match exists, update_stage to lost

Respond with ONLY JSON:
{"tool": "crm", "action": "read_candidate"}
{"tool": "messaging", "action": "send_message", "topic": "experience"}
{"tool": "messaging", "action": "read_reply"}
{"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
{"tool": "approval", "action": "request_approval", "job_id": 2}
{"tool": "crm", "action": "update_stage", "stage": "hired"}"""


def format_observation(obs):
    parts = [f"Driver: {obs.driver_name}"]
    if obs.crm_summary:
        parts.append(f"CRM:\n{obs.crm_summary}")
    if obs.jobs_summary:
        parts.append(f"Jobs:\n{obs.jobs_summary}")
    if obs.discovered_info:
        parts.append(f"Discovered:\n{obs.discovered_info}")
    status = f"Stage: {obs.stage}"
    if obs.pending_reply:
        status += " | PENDING REPLY"
    parts.append(status)
    if obs.feedback:
        parts.append(f"Result: {obs.feedback}")
    return "\n".join(parts)


def ask_llm(messages, llm_url, model):
    resp = requests.post(llm_url, json={
        "model": model,
        "messages": messages,
        "temperature": 0.1,
        "max_tokens": 150,
    })
    return resp.json()["choices"][0]["message"]["content"]


def parse_action(text):
    """Try to extract an action from the LLM response."""
    text = text.strip()

    # Remove markdown code fences
    if "```" in text:
        parts = text.split("```")
        for part in parts:
            part = part.strip()
            if part.startswith("json"):
                part = part[4:].strip()
            if part.startswith("{"):
                text = part
                break

    # Try JSON parse
    try:
        data = json.loads(text)
        if isinstance(data, dict) and "tool" in data and "action" in data:
            return RecruitopenenvAction(
                tool=data["tool"],
                action=data["action"],
                topic=data.get("topic", ""),
                job_id=data.get("job_id", -1),
                stage=data.get("stage", ""),
                field=data.get("field", ""),
                value=data.get("value", ""),
            )
    except (json.JSONDecodeError, KeyError):
        pass

    # Fallback: keyword matching on parameter-free actions
    text_lower = text.lower()
    if "read_candidate" in text_lower:
        return RecruitopenenvAction(tool="crm", action="read_candidate")
    if "read_reply" in text_lower:
        return RecruitopenenvAction(tool="messaging", action="read_reply")
    if "check_approval" in text_lower:
        return RecruitopenenvAction(tool="approval", action="check_approval")
    if "wait" in text_lower:
        return RecruitopenenvAction(tool="workflow", action="wait")

    return RecruitopenenvAction(tool="crm", action="read_candidate")


def run_baseline(env_url, llm_url, model, num_episodes):
    rewards = []
    successes = 0
    total_steps = 0

    env = RecruitopenenvEnv(base_url=env_url)

    for ep in range(num_episodes):
        result = env.reset()
        obs = result.observation
        ep_reward = 0.0
        steps = 0

        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        while not result.done and steps < 100:
            obs_text = format_observation(obs)
            messages.append({"role": "user", "content": obs_text})

            llm_response = ask_llm(messages, llm_url, model)
            messages.append({"role": "assistant", "content": llm_response})

            action = parse_action(llm_response)
            result = env.step(action)
            obs = result.observation
            ep_reward += result.reward
            steps += 1

            print(f"  Step {steps}: {action.tool}.{action.action}"
                  f"{'(' + action.topic + ')' if action.topic else ''}"
                  f"{'[job=' + str(action.job_id) + ']' if action.job_id >= 0 else ''}"
                  f" -> reward={result.reward:.1f}")

        rewards.append(ep_reward)
        total_steps += steps
        if obs.stage == "hired":
            successes += 1

        print(f"Episode {ep+1}: total_reward={ep_reward:.1f}, steps={steps}, "
              f"{'HIRED' if obs.stage == 'hired' else 'FAIL (' + obs.stage + ')'}")
        print()

    env.close()

    avg_reward = sum(rewards) / len(rewards)
    avg_steps = total_steps / num_episodes

    print("\n========== LLM BASELINE (no RL) ==========")
    print(f"Model: {model}")
    print(f"Episodes: {num_episodes}")
    print(f"Avg reward: {avg_reward:.2f}")
    print(f"Min reward: {min(rewards):.2f}")
    print(f"Max reward: {max(rewards):.2f}")
    print(f"Hire rate: {successes}/{num_episodes} ({100*successes/num_episodes:.1f}%)")
    print(f"Avg steps/episode: {avg_steps:.1f}")
    print("==========================================")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLM baseline for Driver Recruit Environment")
    parser.add_argument("--env-url", default="http://localhost:8001", help="Environment server URL")
    parser.add_argument("--llm-url", default="http://localhost:8033/v1/chat/completions", help="LLM API URL")
    parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct", help="Model name")
    parser.add_argument("--episodes", type=int, default=20, help="Number of episodes")
    args = parser.parse_args()

    run_baseline(args.env_url, args.llm_url, args.model, args.episodes)
baseline_random.py ADDED
"""Random agent baseline — establishes the floor for reward."""

import random

from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction

TOOLS_ACTIONS = {
    "crm": ["read_candidate", "update_stage", "update_field", "add_note"],
    "messaging": ["send_message", "read_reply"],
    "approval": ["request_approval", "check_approval"],
    "workflow": ["wait"],
}

TOPICS = [
    "greeting", "call", "experience", "home_time", "pay", "equipment",
    "route", "deal_breakers", "availability", "violations", "medical_card",
    "references", "pitch", "offer", "negotiate_pay", "negotiate_home_time",
    "signing_bonus", "address_concern",
]

STAGES = ["contacted", "interested", "approval_pending", "offer_sent", "hired", "lost"]

NUM_EPISODES = 100


def random_action():
    tool = random.choice(list(TOOLS_ACTIONS.keys()))
    action = random.choice(TOOLS_ACTIONS[tool])

    topic = ""
    job_id = -1
    stage = ""
    field = ""
    value = ""

    if tool == "messaging" and action == "send_message":
        topic = random.choice(TOPICS)
        if topic in ("pitch", "offer"):
            job_id = random.randint(0, 5)
    elif tool == "crm" and action == "update_stage":
        stage = random.choice(STAGES)
    elif tool == "crm" and action == "update_field":
        field = random.choice(["cdl_class", "years_exp", "home_time_pref"])
        value = "A"
    elif tool == "approval" and action == "request_approval":
        job_id = random.randint(0, 5)

    return RecruitopenenvAction(
        tool=tool, action=action, topic=topic,
        job_id=job_id, stage=stage, field=field, value=value,
    )


def run_baseline():
    rewards = []
    successes = 0
    total_steps = 0

    with RecruitopenenvEnv(base_url="http://localhost:8000").sync() as env:
        for ep in range(NUM_EPISODES):
            result = env.reset()
            ep_reward = 0.0
            steps = 0

            while not result.done and steps < 100:
                action = random_action()
                result = env.step(action)
                ep_reward += result.reward
                steps += 1

            rewards.append(ep_reward)
            total_steps += steps

            if result.observation.stage == "hired":
                successes += 1

            if (ep + 1) % 10 == 0:
                avg_so_far = sum(rewards) / len(rewards)
                print(f"  Episode {ep+1}: reward={ep_reward:.1f}, running avg={avg_so_far:.2f}")

    avg_reward = sum(rewards) / len(rewards)
    avg_steps = total_steps / NUM_EPISODES

    print("\n========== RANDOM BASELINE ==========")
    print(f"Episodes: {NUM_EPISODES}")
    print(f"Avg reward: {avg_reward:.2f}")
    print(f"Min reward: {min(rewards):.2f}")
    print(f"Max reward: {max(rewards):.2f}")
    print(f"Hire rate: {successes}/{NUM_EPISODES} ({100*successes/NUM_EPISODES:.1f}%)")
    print(f"Avg steps/episode: {avg_steps:.1f}")
    print("======================================")


if __name__ == "__main__":
    run_baseline()
client.py ADDED
"""Recruitopenenv Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import RecruitopenenvAction, RecruitopenenvObservation


class RecruitopenenvEnv(
    EnvClient[RecruitopenenvAction, RecruitopenenvObservation, State]
):
    """Client for the Driver Recruit Environment."""

    def _step_payload(self, action: RecruitopenenvAction) -> Dict:
        payload = {
            "tool": action.tool,
            "action": action.action,
        }
        if action.topic:
            payload["topic"] = action.topic
        if action.job_id >= 0:
            payload["job_id"] = action.job_id
        if action.stage:
            payload["stage"] = action.stage
        if action.field:
            payload["field"] = action.field
        if action.value:
            payload["value"] = action.value
        return payload

    def _parse_result(self, payload: Dict) -> StepResult[RecruitopenenvObservation]:
        obs_data = payload.get("observation", {})
        observation = RecruitopenenvObservation(
            driver_name=obs_data.get("driver_name", ""),
            crm_summary=obs_data.get("crm_summary", ""),
            jobs_summary=obs_data.get("jobs_summary", ""),
            discovered_info=obs_data.get("discovered_info", ""),
            stage=obs_data.get("stage", "lead"),
            feedback=obs_data.get("feedback", ""),
            pending_reply=obs_data.get("pending_reply", False),
            done=payload.get("done", False),
            reward=payload.get("reward", 0.0),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward", 0.0),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
demo/index.html ADDED
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Driver Recruit Environment</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500;600&family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
<style>
*{margin:0;padding:0;box-sizing:border-box}
:root{
  --bg:#09090b;--s1:#111113;--s2:#18181b;--s3:#27272a;
  --b1:#27272a;--b2:#3f3f46;
  --t1:#fafafa;--t2:#a1a1aa;--t3:#71717a;
  --green:#22c55e;--red:#ef4444;--amber:#f59e0b;--blue:#3b82f6;--violet:#8b5cf6;--rose:#f43f5e;--cyan:#06b6d4;--orange:#f97316;
}
body{font-family:'Inter',system-ui,sans-serif;background:var(--bg);color:var(--t1);-webkit-font-smoothing:antialiased}
.mono{font-family:'IBM Plex Mono',monospace}

/* ─── HERO ─── */
.hero{min-height:100vh;display:flex;flex-direction:column;align-items:center;justify-content:center;padding:40px 24px;position:relative;overflow:hidden}
.hero::before{content:'';position:absolute;top:-200px;left:50%;transform:translateX(-50%);width:800px;height:800px;background:radial-gradient(circle,rgba(139,92,246,0.06) 0%,transparent 70%);pointer-events:none}
.hero-eyebrow{font-size:13px;font-weight:500;color:var(--violet);letter-spacing:0.08em;text-transform:uppercase;margin-bottom:20px}
.hero h1{font-size:clamp(2rem,5vw,3.5rem);font-weight:700;letter-spacing:-0.03em;line-height:1.1;text-align:center;max-width:700px;margin-bottom:16px}
.hero-sub{color:var(--t2);font-size:17px;line-height:1.6;text-align:center;max-width:560px;margin-bottom:48px}

.cards{display:grid;grid-template-columns:repeat(4,1fr);gap:1px;background:var(--b1);border:1px solid var(--b1);border-radius:12px;overflow:hidden;max-width:900px;width:100%;margin-bottom:48px}
.card{background:var(--s1);padding:24px 20px}
.card-num{font-family:'IBM Plex Mono',monospace;font-size:12px;color:var(--t3);margin-bottom:10px}
.card h3{font-size:14px;font-weight:600;margin-bottom:6px}
.card p{font-size:13px;color:var(--t2);line-height:1.55}
@media(max-width:800px){.cards{grid-template-columns:1fr 1fr}}
@media(max-width:500px){.cards{grid-template-columns:1fr}}

.btn{display:inline-flex;align-items:center;gap:8px;padding:12px 28px;border-radius:8px;font-size:14px;font-weight:600;cursor:pointer;border:none;transition:all .15s}
.btn-white{background:var(--t1);color:var(--bg)}
.btn-white:hover{opacity:.9}

.env-input{background:var(--s2);border:1px solid var(--b1);color:var(--t2);padding:8px 12px;border-radius:6px;font-size:13px;font-family:'IBM Plex Mono',monospace;width:240px;margin-top:16px;text-align:center}
.env-input:focus{outline:none;border-color:var(--b2)}

/* ─── GAME ─── */
.game{display:none;max-width:1280px;margin:0 auto;padding:16px 20px 40px}
.game.on{display:block}

/* Top bar */
.topbar{display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:1px solid var(--b1);margin-bottom:16px;flex-wrap:wrap;gap:12px}
.topbar-left{display:flex;align-items:center;gap:16px}
.avatar{width:36px;height:36px;border-radius:50%;background:var(--s3);display:flex;align-items:center;justify-content:center;font-weight:600;font-size:14px;flex-shrink:0}
.driver-meta h2{font-size:15px;font-weight:600;line-height:1}
.pill{display:inline-block;font-size:11px;font-weight:500;padding:2px 8px;border-radius:4px;margin-top:4px}
.pill-chatty{background:rgba(34,197,94,.12);color:var(--green)}
.pill-professional{background:rgba(59,130,246,.12);color:var(--blue)}
.pill-impatient{background:rgba(239,68,68,.12);color:var(--red)}
.pill-suspicious{background:rgba(244,63,94,.12);color:var(--rose)}

.topbar-stats{display:flex;gap:24px;align-items:center}
.ts{text-align:right}
.ts-label{font-size:11px;color:var(--t3);text-transform:uppercase;letter-spacing:.05em}
.ts-val{font-family:'IBM Plex Mono',monospace;font-size:16px;font-weight:600}

.trust-wrap{width:140px}
.trust-track{width:100%;height:4px;background:var(--s3);border-radius:2px;margin-top:4px;overflow:hidden}
.trust-fill{height:100%;border-radius:2px;transition:width .4s ease}
.trust-num{display:flex;justify-content:space-between;align-items:center;margin-top:3px}
.trust-num span{font-family:'IBM Plex Mono',monospace;font-size:11px;color:var(--t3)}
.trust-num .delta{font-weight:600}
.delta-up{color:var(--green)}
.delta-down{color:var(--red)}

/* Layout */
.layout{display:grid;grid-template-columns:260px 1fr 260px;gap:16px;min-height:calc(100vh - 120px)}
@media(max-width:1000px){.layout{grid-template-columns:1fr}}
.right-col{display:flex;flex-direction:column;gap:12px}

/* Sidebar */
.sidebar{display:flex;flex-direction:column;gap:12px}
.pane{background:var(--s1);border:1px solid var(--b1);border-radius:10px;overflow:hidden}
.pane-head{font-size:11px;font-weight:600;color:var(--t3);text-transform:uppercase;letter-spacing:.06em;padding:12px 14px 8px;display:flex;align-items:center;justify-content:space-between}

.job{padding:10px 14px;border-bottom:1px solid var(--b1);font-size:13px;cursor:default;transition:background .1s}
.job:last-child{border-bottom:none}
.job:hover{background:var(--s2)}
.job-id{font-family:'IBM Plex Mono',monospace;color:var(--t3);font-size:11px;font-weight:500}
.job-co{font-weight:500;margin-left:6px}
.job-det{color:var(--t2);font-size:12px;margin-top:3px;line-height:1.5}
.job-warn{color:var(--amber);font-size:11px;margin-top:2px}

.info-item{padding:8px 14px;border-bottom:1px solid var(--b1);font-size:13px;line-height:1.5}
.info-item:last-child{border-bottom:none}
.info-cat{font-family:'IBM Plex Mono',monospace;font-size:10px;font-weight:600;color:var(--violet);text-transform:uppercase;letter-spacing:.04em}
.info-empty{padding:14px;color:var(--t3);font-size:13px;font-style:italic}

.crm-field{padding:5px 14px;border-bottom:1px solid var(--b1);font-size:12px;display:flex;justify-content:space-between}
.crm-field:last-child{border-bottom:none}
.crm-key{color:var(--t3);font-family:'IBM Plex Mono',monospace;font-size:11px}
.crm-val{color:var(--t1)}
.crm-empty{padding:14px;color:var(--t3);font-size:13px;font-style:italic}

/* Main area */
.main{display:flex;flex-direction:column;gap:12px}

/* Timeline */
.timeline{flex:1;background:var(--s1);border:1px solid var(--b1);border-radius:10px;padding:16px;overflow-y:auto;max-height:calc(100vh - 320px);min-height:320px}
.tl-entry{display:flex;gap:12px;margin-bottom:16px;animation:slideIn .25s ease}
@keyframes slideIn{from{opacity:0;transform:translateY(6px)}to{opacity:1;transform:none}}
.tl-dot-col{display:flex;flex-direction:column;align-items:center;padding-top:4px}
.tl-dot{width:8px;height:8px;border-radius:50%;flex-shrink:0}
.tl-line{width:1px;flex:1;background:var(--b1);margin-top:4px}
.tl-content{flex:1;min-width:0}
.tl-head{display:flex;align-items:center;gap:8px;margin-bottom:4px;flex-wrap:wrap}
.tl-action{font-size:13px;font-weight:600}
.tl-reward{font-family:'IBM Plex Mono',monospace;font-size:12px;font-weight:500}
.tl-reward.pos{color:var(--green)}.tl-reward.neg{color:var(--red)}.tl-reward.zero{color:var(--t3)}
.tl-tool-badge{font-size:10px;font-weight:600;padding:1px 6px;border-radius:3px;text-transform:uppercase;letter-spacing:.04em}
.badge-crm{background:rgba(6,182,212,.12);color:var(--cyan)}
.badge-messaging{background:rgba(139,92,246,.12);color:var(--violet)}
.badge-approval{background:rgba(249,115,22,.12);color:var(--orange)}
.badge-workflow{background:rgba(113,113,122,.15);color:var(--t3)}
.tl-body{font-size:13px;color:var(--t2);line-height:1.55;padding:8px 12px;background:var(--s2);border-radius:6px;border-left:2px solid var(--b1)}
.tl-body.good{border-left-color:var(--green)}
.tl-body.bad{border-left-color:var(--red)}
.tl-step{font-family:'IBM Plex Mono',monospace;font-size:10px;color:var(--t3)}

/* Tool sections */
.tool-section{background:var(--s1);border:1px solid var(--b1);border-radius:10px;padding:10px 14px}
.tool-section-head{font-size:11px;font-weight:600;text-transform:uppercase;letter-spacing:.05em;margin-bottom:8px;display:flex;align-items:center;gap:6px}
.tool-section-head .dot{width:6px;height:6px;border-radius:50%;display:inline-block}

.act-grid{display:flex;flex-wrap:wrap;gap:6px}
.act{padding:6px 12px;border-radius:6px;font-size:12px;font-weight:500;cursor:pointer;border:1px solid var(--b1);background:var(--s2);color:var(--t2);transition:all .12s;white-space:nowrap}
.act:hover{background:var(--s3);color:var(--t1);border-color:var(--b2)}
.act:disabled{opacity:.3;cursor:not-allowed}
.act-go{border-color:rgba(34,197,94,.3);color:var(--green)}
.act-go:hover{background:rgba(34,197,94,.1);border-color:var(--green)}
.act-no{border-color:rgba(239,68,68,.2);color:var(--red)}
.act-no:hover{background:rgba(239,68,68,.08);border-color:var(--red)}
.act-warn{border-color:rgba(245,158,11,.2);color:var(--amber)}
.act-warn:hover{background:rgba(245,158,11,.08);border-color:var(--amber)}

/* Pipeline */
.pipeline{background:var(--s1);border:1px solid var(--b1);border-radius:10px;padding:10px 14px}
.pipe-stages{display:flex;gap:2px;align-items:center}
.pipe-stage{font-size:10px;font-weight:600;padding:4px 8px;border-radius:4px;text-transform:uppercase;letter-spacing:.04em;background:var(--s2);color:var(--t3);border:1px solid var(--b1)}
145
+ .pipe-stage.active{background:rgba(139,92,246,.15);color:var(--violet);border-color:var(--violet)}
146
+ .pipe-stage.done{background:rgba(34,197,94,.1);color:var(--green);border-color:rgba(34,197,94,.3)}
147
+ .pipe-stage.fail{background:rgba(239,68,68,.1);color:var(--red);border-color:rgba(239,68,68,.3)}
148
+ .pipe-arrow{color:var(--t3);font-size:10px}
149
+
150
+ /* Pending reply indicator */
151
+ .pending-badge{font-size:11px;font-weight:500;padding:3px 10px;border-radius:4px;background:rgba(139,92,246,.12);color:var(--violet);animation:pulse 1.5s ease infinite}
152
+ @keyframes pulse{0%,100%{opacity:1}50%{opacity:.5}}
153
+
154
+ /* Modal */
155
+ .modal-bg{display:none;position:fixed;inset:0;background:rgba(0,0,0,.6);z-index:100;align-items:center;justify-content:center;backdrop-filter:blur(4px)}
156
+ .modal-bg.on{display:flex}
157
+ .modal{background:var(--s1);border:1px solid var(--b1);border-radius:12px;padding:20px;width:420px;max-width:90vw}
158
+ .modal h3{font-size:14px;font-weight:600;margin-bottom:14px}
159
+ .modal-job{display:block;width:100%;text-align:left;background:var(--s2);border:1px solid var(--b1);color:var(--t1);padding:10px 12px;border-radius:6px;margin-bottom:6px;cursor:pointer;font-size:13px;transition:border-color .12s}
160
+ .modal-job:hover{border-color:var(--violet)}
161
+ .modal-cancel{display:block;width:100%;text-align:center;background:transparent;border:1px solid var(--b1);color:var(--t3);padding:8px;border-radius:6px;cursor:pointer;font-size:13px;margin-top:8px}
162
+ .modal-cancel:hover{color:var(--t2)}
163
+
164
+ /* Input modal */
165
+ .modal input[type="text"]{width:100%;background:var(--s2);border:1px solid var(--b1);color:var(--t1);padding:8px 12px;border-radius:6px;font-size:13px;font-family:'IBM Plex Mono',monospace;margin-bottom:8px}
166
+ .modal input[type="text"]:focus{outline:none;border-color:var(--violet)}
167
+ .modal-row{display:flex;gap:8px;margin-top:8px}
168
+ .modal-row .btn{flex:1;justify-content:center;padding:8px;font-size:13px}
169
+ .modal-btn-go{background:var(--violet);color:white;border:none;padding:8px 16px;border-radius:6px;font-size:13px;font-weight:500;cursor:pointer}
170
+ .modal-btn-go:hover{opacity:.9}
171
+
172
+ /* End screen */
173
+ .endscreen{display:none;position:fixed;inset:0;background:rgba(0,0,0,.75);z-index:200;align-items:center;justify-content:center;backdrop-filter:blur(8px)}
174
+ .endscreen.on{display:flex}
175
+ .end-card{background:var(--s1);border:1px solid var(--b1);border-radius:14px;padding:40px 36px;text-align:center;width:440px;max-width:90vw}
176
+ .end-label{font-size:11px;font-weight:600;text-transform:uppercase;letter-spacing:.1em;margin-bottom:8px}
177
+ .end-title{font-size:28px;font-weight:700;letter-spacing:-0.02em;margin-bottom:6px}
178
+ .end-sub{font-size:13px;color:var(--t2);margin-bottom:28px;line-height:1.5}
179
+ .end-grid{display:grid;grid-template-columns:1fr 1fr 1fr;gap:1px;background:var(--b1);border:1px solid var(--b1);border-radius:8px;overflow:hidden;margin-bottom:24px}
180
+ .end-stat{background:var(--s1);padding:14px 8px}
181
+ .end-stat-val{font-family:'IBM Plex Mono',monospace;font-size:18px;font-weight:600}
182
+ .end-stat-lbl{font-size:11px;color:var(--t3);margin-top:2px}
183
+
184
+ /* Hidden info toggle */
185
+ .hidden-info.off{display:none}
186
+ .toggle-hidden{background:var(--s2);border:1px solid var(--b1);color:var(--t3);padding:4px 10px;border-radius:4px;font-size:11px;cursor:pointer;margin-left:12px}
187
+ .toggle-hidden:hover{color:var(--t2);border-color:var(--b2)}
188
+
189
+ /* Stage select */
190
+ .stage-select{background:var(--s2);border:1px solid var(--b1);color:var(--t1);padding:4px 8px;border-radius:4px;font-size:12px;font-family:'IBM Plex Mono',monospace}
191
+ .stage-select:focus{outline:none;border-color:var(--violet)}
192
+ </style>
193
+ </head>
194
+ <body>
195
+
196
+ <div class="hero" id="hero">
197
+ <div class="hero-eyebrow">OpenEnv Hackathon &mdash; Long-Horizon RL</div>
198
+ <h1>Train an AI to recruit truck drivers through tool calls</h1>
199
+ <p class="hero-sub">A multi-turn RL environment where agents use CRM, messaging, approval, and workflow tools across 40-70 step episodes to screen candidates, avoid trap jobs, and close hires.</p>
200
+
201
+ <div class="cards">
202
+ <div class="card">
203
+ <div class="card-num mono">01</div>
204
+ <h3>Tool calling</h3>
205
+ <p>4 tools &mdash; CRM, messaging, approval, workflow. The agent must call the right tool with the right action at each step.</p>
206
+ </div>
207
+ <div class="card">
208
+ <div class="card-num mono">02</div>
209
+ <h3>Long horizon</h3>
210
+ <p>Episodes span 40-70 steps through a full recruiting pipeline: lead &rarr; contacted &rarr; interested &rarr; approval &rarr; offer &rarr; hired.</p>
211
+ </div>
212
+ <div class="card">
213
+ <div class="card-num mono">03</div>
214
+ <h3>Hidden information</h3>
215
+ <p>Driver preferences, deal breakers, and personality are hidden. They must be discovered through screening messages.</p>
216
+ </div>
217
+ <div class="card">
218
+ <div class="card-num mono">04</div>
219
+ <h3>Trap jobs</h3>
220
+ <p>Jobs that look perfect but violate deal breakers. Skip screening and you'll hire for the wrong one &mdash; big negative reward.</p>
221
+ </div>
222
+ </div>
223
+
224
+ <button class="btn btn-white" onclick="startGame()">Play the environment</button>
225
+ <input class="env-input" id="envUrl" value="http://localhost:8000" spellcheck="false">
226
+ </div>
227
+
228
+ <div class="game" id="game">
229
+ <div class="topbar">
230
+ <div class="topbar-left">
231
+ <div class="avatar" id="av">?</div>
232
+ <div class="driver-meta">
233
+ <h2 id="dName">---</h2>
234
+ <span class="pill" id="dPers"></span>
235
+ </div>
236
+ </div>
237
+ <div class="topbar-stats">
238
+ <div class="ts">
239
+ <div class="ts-label">Stage</div>
240
+ <div class="ts-val" id="uiStage">lead</div>
241
+ </div>
242
+ <div class="ts">
243
+ <div class="ts-label">Step</div>
244
+ <div class="ts-val"><span id="uiStep">0</span><span style="color:var(--t3)"> / 100</span></div>
245
+ </div>
246
+ <div class="ts">
247
+ <div class="ts-label">Reward</div>
248
+ <div class="ts-val" id="uiRew">0.0</div>
249
+ </div>
250
+ <div id="pendingBadge" style="display:none" class="pending-badge">Unread reply</div>
251
+ <button class="toggle-hidden" onclick="toggleHidden()">Show hidden</button>
252
+ </div>
253
+ </div>
254
+
255
+ <!-- Pipeline -->
256
+ <div class="pipeline">
257
+ <div class="pipe-stages" id="pipeStages"></div>
258
+ </div>
259
+
260
+ <div class="layout" style="margin-top:12px">
261
+ <div class="sidebar">
262
+ <div class="pane" id="jobsPane">
263
+ <div class="pane-head">Jobs</div>
264
+ <div id="jobsList"></div>
265
+ </div>
266
+ <div class="pane">
267
+ <div class="pane-head">CRM Record</div>
268
+ <div id="crmList"><div class="crm-empty">Not loaded &mdash; use crm.read_candidate</div></div>
269
+ </div>
270
+ </div>
271
+ <div class="main">
272
+ <div class="timeline" id="tl"></div>
273
+
274
+ <!-- Tool: CRM -->
275
+ <div class="tool-section">
276
+ <div class="tool-section-head"><span class="dot" style="background:var(--cyan)"></span><span style="color:var(--cyan)">CRM</span></div>
277
+ <div class="act-grid">
278
+ <button class="act" onclick="doTool('crm','read_candidate')">read_candidate</button>
279
+ <button class="act" onclick="showStageModal()">update_stage</button>
280
+ <button class="act" onclick="showFieldModal()">update_field</button>
281
+ <button class="act" onclick="showNoteModal()">add_note</button>
282
+ </div>
283
+ </div>
284
+
285
+ <!-- Tool: Messaging -->
286
+ <div class="tool-section">
287
+ <div class="tool-section-head"><span class="dot" style="background:var(--violet)"></span><span style="color:var(--violet)">Messaging</span></div>
288
+ <div class="act-grid" id="msgGrid">
289
+ <button class="act" onclick="doMsg('greeting')">greeting</button>
290
+ <button class="act" onclick="doMsg('call')">call</button>
291
+ <span style="width:1px;height:24px;background:var(--b1)"></span>
292
+ <button class="act" onclick="doMsg('experience')">experience</button>
293
+ <button class="act" onclick="doMsg('home_time')">home time</button>
294
+ <button class="act" onclick="doMsg('pay')">pay</button>
295
+ <button class="act" onclick="doMsg('equipment')">equipment</button>
296
+ <button class="act" onclick="doMsg('route')">route</button>
297
+ <button class="act" onclick="doMsg('deal_breakers')">deal breakers</button>
298
+ <button class="act" onclick="doMsg('availability')">availability</button>
299
+ <button class="act" onclick="doMsg('violations')">violations</button>
300
+ <button class="act" onclick="doMsg('medical_card')">medical card</button>
301
+ <button class="act" onclick="doMsg('references')">references</button>
302
+ <span style="width:1px;height:24px;background:var(--b1)"></span>
303
+ <button class="act act-warn" onclick="showJobModal('pitch')">pitch job</button>
304
+ <button class="act act-warn" onclick="showJobModal('offer')">send offer</button>
305
+ <span style="width:1px;height:24px;background:var(--b1)"></span>
306
+ <button class="act" onclick="doMsg('negotiate_pay')">negotiate pay</button>
307
+ <button class="act" onclick="doMsg('negotiate_home_time')">negotiate home</button>
308
+ <button class="act" onclick="doMsg('signing_bonus')">signing bonus</button>
309
+ <button class="act" onclick="doMsg('address_concern')">address concern</button>
310
+ <span style="width:1px;height:24px;background:var(--b1)"></span>
311
+ <button class="act act-go" onclick="doTool('messaging','read_reply')">read_reply</button>
312
+ </div>
313
+ </div>
314
+
315
+ <!-- Tool: Approval + Workflow -->
316
+ <div style="display:flex;gap:12px">
317
+ <div class="tool-section" style="flex:1">
318
+ <div class="tool-section-head"><span class="dot" style="background:var(--orange)"></span><span style="color:var(--orange)">Approval</span></div>
319
+ <div class="act-grid">
320
+ <button class="act" onclick="showJobModal('request_approval')">request_approval</button>
321
+ <button class="act" onclick="doTool('approval','check_approval')">check_approval</button>
322
+ </div>
323
+ </div>
324
+ <div class="tool-section" style="flex:1">
325
+ <div class="tool-section-head"><span class="dot" style="background:var(--t3)"></span><span style="color:var(--t3)">Workflow</span></div>
326
+ <div class="act-grid">
327
+ <button class="act" onclick="doTool('workflow','wait')">wait</button>
328
+ <button class="act act-go" onclick="showStageModal('hired')">hire (finish)</button>
329
+ <button class="act act-no" onclick="doStage('lost')">reject (lost)</button>
330
+ </div>
331
+ </div>
332
+ </div>
333
+ </div>
334
+ <div class="right-col">
335
+ <div class="pane">
336
+ <div class="pane-head">Discovered Info</div>
337
+ <div id="infoList"><div class="info-empty">No info yet &mdash; send messages and read replies</div></div>
338
+ </div>
339
+ </div>
340
+ </div>
341
+ </div>
342
+
343
+ <!-- Job picker modal -->
344
+ <div class="modal-bg" id="modalBg">
345
+ <div class="modal">
346
+ <h3 id="modalTitle">Select job</h3>
347
+ <div id="modalJobs"></div>
348
+ <button class="modal-cancel" onclick="closeModal()">Cancel</button>
349
+ </div>
350
+ </div>
351
+
352
+ <!-- Stage modal -->
353
+ <div class="modal-bg" id="stageModalBg">
354
+ <div class="modal">
355
+ <h3>Update Pipeline Stage</h3>
356
+ <div id="stageModalBtns"></div>
357
+ <button class="modal-cancel" onclick="closeStageModal()">Cancel</button>
358
+ </div>
359
+ </div>
360
+
361
+ <!-- Field modal -->
362
+ <div class="modal-bg" id="fieldModalBg">
363
+ <div class="modal">
364
+ <h3>Update CRM Field</h3>
365
+ <select id="fieldSelect" class="stage-select" style="width:100%;margin-bottom:8px;padding:8px">
366
+ <option value="cdl_class">cdl_class</option>
367
+ <option value="years_experience">years_experience</option>
368
+ <option value="endorsements">endorsements</option>
369
+ <option value="location">location</option>
370
+ <option value="home_time_pref">home_time_pref</option>
371
+ <option value="pay_expectation">pay_expectation</option>
372
+ <option value="equipment_pref">equipment_pref</option>
373
+ <option value="route_pref">route_pref</option>
374
+ <option value="deal_breakers">deal_breakers</option>
375
+ <option value="availability">availability</option>
376
+ <option value="violations">violations</option>
377
+ <option value="medical_card">medical_card</option>
378
+ <option value="references">references</option>
379
+ <option value="matched_job">matched_job</option>
380
+ </select>
381
+ <input type="text" id="fieldValue" placeholder="Value..." />
382
+ <div class="modal-row">
383
+ <button class="modal-btn-go" onclick="submitField()">Save</button>
384
+ <button class="modal-cancel" onclick="closeFieldModal()" style="margin-top:0">Cancel</button>
385
+ </div>
386
+ </div>
387
+ </div>
388
+
389
+ <!-- Note modal -->
390
+ <div class="modal-bg" id="noteModalBg">
391
+ <div class="modal">
392
+ <h3>Add CRM Note</h3>
393
+ <input type="text" id="noteValue" placeholder="Note text..." />
394
+ <div class="modal-row">
395
+ <button class="modal-btn-go" onclick="submitNote()">Add</button>
396
+ <button class="modal-cancel" onclick="closeNoteModal()" style="margin-top:0">Cancel</button>
397
+ </div>
398
+ </div>
399
+ </div>
400
+
401
+ <!-- End screen -->
402
+ <div class="endscreen" id="endscreen">
403
+ <div class="end-card">
404
+ <div class="end-label" id="endLabel"></div>
405
+ <div class="end-title" id="endTitle"></div>
406
+ <div class="end-sub" id="endSub"></div>
407
+ <div class="end-grid">
408
+ <div class="end-stat"><div class="end-stat-val" id="erRew"></div><div class="end-stat-lbl">Reward</div></div>
409
+ <div class="end-stat"><div class="end-stat-val" id="erStep"></div><div class="end-stat-lbl">Steps</div></div>
410
+ <div class="end-stat"><div class="end-stat-val" id="erStage"></div><div class="end-stat-lbl">Final Stage</div></div>
411
+ </div>
412
+ <button class="btn btn-white" onclick="startGame()">Play again</button>
413
+ </div>
414
+ </div>
415
+
416
+ <script>
417
+ let ENV='',WS=null;
418
+ let S={obs:null,rew:0,done:false,jobs:[],stepCount:0};
419
+ let showHidden=false;
420
+
421
+ const STAGES=['lead','contacted','interested','approval_pending','offer_sent','hired'];
422
+ const FAIL_STAGES=['lost','ghosted'];
423
+
424
+ function toggleHidden(){
425
+ showHidden=!showHidden;
426
+ document.querySelectorAll('.hidden-info').forEach(el=>el.classList.toggle('off',!showHidden));
427
+ document.querySelector('.toggle-hidden').textContent=showHidden?'Hide hidden':'Show hidden';
428
+ }
429
+
430
+ function wsUrl(){
431
+ const base=document.getElementById('envUrl').value.replace(/\/$/,'');
432
+ return base.replace(/^http/,'ws')+'/ws';
433
+ }
434
+
435
+ function connectWS(){
436
+ return new Promise((resolve,reject)=>{
437
+ if(WS&&WS.readyState===WebSocket.OPEN){resolve();return}
438
+ if(WS)WS.close();
439
+ WS=new WebSocket(wsUrl());
440
+ WS.onopen=()=>resolve();
441
+ WS.onerror=()=>reject(new Error('WebSocket connection failed'));
442
+ WS.onmessage=(ev)=>{
+ const msg=JSON.parse(ev.data);
+ if(msg.type==='error'){
+ console.error('WS error:',msg.data);
+ // Reject the in-flight request so awaiting callers don't hang forever
+ if(pendingReject){const rj=pendingReject;pendingResolve=null;pendingReject=null;rj(new Error(msg.data));}
+ return;
+ }
+ if(msg.type==='observation'&&pendingResolve){
+ const cb=pendingResolve;pendingResolve=null;pendingReject=null;
+ cb(msg.data);
+ }
+ };
+ WS.onclose=()=>{WS=null};
+ });
+ }
+
+ let pendingResolve=null,pendingReject=null;
+ function wsSend(msg){
+ return new Promise((resolve,reject)=>{
+ pendingResolve=resolve;pendingReject=reject;
+ WS.send(JSON.stringify(msg));
+ });
+ }
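+
+ // Wire protocol sketch (an assumption, inferred from the handlers above):
+ //   client -> {"type":"reset"}                                          start a new episode
+ //   client -> {"type":"step","data":{"tool":"messaging","action":"read_reply"}}
+ //   server -> {"type":"observation","data":{"observation":...,"reward":...,"done":...}}
+ //   server -> {"type":"error","data":"..."} on a bad request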
465
+
466
+ async function startGame(){
467
+ document.getElementById('hero').style.display='none';
468
+ document.getElementById('game').classList.add('on');
469
+ document.getElementById('endscreen').classList.remove('on');
470
+ document.getElementById('tl').innerHTML='';
471
+ S={obs:null,rew:0,done:false,jobs:[],stepCount:0};
472
+ try{
473
+ await connectWS();
474
+ const d=await wsSend({type:'reset'});
475
+ handle(d,null);
476
+ }catch(e){
477
+ alert('Cannot reach server: '+e.message);
478
+ document.getElementById('hero').style.display='';
479
+ document.getElementById('game').classList.remove('on');
480
+ }
481
+ }
482
+
483
+ // --- Tool actions ---
484
+ async function doTool(tool,action,extra){
485
+ if(S.done||!WS)return;
486
+ const data={tool,action,...(extra||{})};
487
+ const d=await wsSend({type:'step',data});
488
+ handle(d,tool+'.'+action,data);
489
+ }
490
+
491
+ async function doMsg(topic,jobId){
492
+ const extra={topic};
493
+ if(jobId!==undefined)extra.job_id=jobId;
494
+ await doTool('messaging','send_message',extra);
495
+ }
496
+
497
+ async function doStage(stage){
498
+ await doTool('crm','update_stage',{stage});
499
+ }
500
+
501
+ // --- Handle response ---
502
+ function handle(d,label,actionData){
503
+ const o=d.observation,rw=d.reward||0;
504
+ S.obs=o; S.rew+=rw; S.done=d.done;
505
+ if(o.steps_taken!==undefined)S.stepCount=o.steps_taken;
506
+ else if(label)S.stepCount++;
507
+ render(o,rw,label,actionData);
508
+ if(d.done)setTimeout(()=>showEnd(o),500);
509
+ }
510
+
511
+ function render(o,rw,label,actionData){
512
+ // Driver info
513
+ document.getElementById('dName').textContent=o.driver_name;
514
+ document.getElementById('av').textContent=o.driver_name?o.driver_name[0]:'?';
515
+
516
+ // Stage
517
+ document.getElementById('uiStage').textContent=o.stage;
518
+ document.getElementById('uiStep').textContent=S.stepCount;
519
+
520
+ // Reward
521
+ const re=document.getElementById('uiRew');
522
+ re.textContent=(S.rew>=0?'+':'')+S.rew.toFixed(1);
523
+ re.style.color=S.rew>=0?'var(--green)':'var(--red)';
524
+
525
+ // Pending reply
526
+ document.getElementById('pendingBadge').style.display=o.pending_reply?'':'none';
527
+
528
+ // Pipeline
529
+ renderPipeline(o.stage);
530
+
531
+ // Jobs
532
+ if(o.jobs_summary){
533
+ const lines=o.jobs_summary.split('\n');
534
+ document.getElementById('jobsList').innerHTML=lines.map(l=>{
535
+ const fm=l.match(/\[(.+?)\]/);
536
+ const warn=fm?'<div class="job-warn">'+fm[1]+'</div>':'';
537
+ const parts=l.split(' \u2014 ');
538
+ const hd=parts[0]||'';
539
+ const det=parts[1]||'';
540
+ const im=hd.match(/^Job (\d+): (.+)/);
541
+ return '<div class="job"><span class="job-id">#'+((im&&im[1])||'?')+'</span><span class="job-co">'+((im&&im[2])||hd)+'</span><div class="job-det">'+det+'</div>'+warn+'</div>';
542
+ }).join('');
543
+ S.jobs=lines.map(l=>{const m=l.match(/^Job (\d+): (.+?) \u2014/);return m?{id:+m[1],label:'#'+m[1]+' '+m[2]}:null}).filter(Boolean);
544
+ }
545
+
546
+ // CRM
547
+ if(o.crm_summary){
548
+ const lines=o.crm_summary.split('\n');
549
+ let html='';
550
+ lines.forEach(l=>{
551
+ const fieldMatch=l.match(/^\s{2}(\w+):\s*(.+)/);
552
+ if(fieldMatch){
553
+ html+='<div class="crm-field"><span class="crm-key">'+fieldMatch[1]+'</span><span class="crm-val">'+fieldMatch[2]+'</span></div>';
554
+ } else if(l.startsWith('Name:')||l.startsWith('Stage:')){
555
+ html+='<div class="crm-field"><span class="crm-key">'+l.split(':')[0]+'</span><span class="crm-val">'+l.split(':').slice(1).join(':').trim()+'</span></div>';
556
+ } else if(l.trim()==='Fields: (none recorded)'){
557
+ html+='<div class="crm-field"><span class="crm-key" style="color:var(--t3)">no fields recorded</span></div>';
558
+ } else if(l.match(/^\s{2}-\s(.+)/)){
559
+ html+='<div class="crm-field"><span class="crm-key">note</span><span class="crm-val" style="font-style:italic">'+l.match(/^\s{2}-\s(.+)/)[1]+'</span></div>';
560
+ }
561
+ });
562
+ document.getElementById('crmList').innerHTML=html||'<div class="crm-empty">Empty CRM</div>';
563
+ }
564
+
565
+ // Discovered info
566
+ if(o.discovered_info){
567
+ const items=o.discovered_info.split('\n').filter(l=>l.trim());
568
+ document.getElementById('infoList').innerHTML=items.map(l=>{
569
+ const m=l.match(/^\[(.+?)\]\s*(.*)/);
570
+ if(m)return '<div class="info-item"><span class="info-cat">'+m[1]+'</span><br>'+m[2]+'</div>';
571
+ return '<div class="info-item">'+l+'</div>';
572
+ }).join('');
573
+ }
574
+
575
+ // Timeline
576
+ if(label){
577
+ const tl=document.getElementById('tl');
578
+ const rwClass=rw>0?'pos':rw<0?'neg':'zero';
579
+ const rwStr=rw>=0?'+'+rw.toFixed(1):rw.toFixed(1);
580
+ const dotColor=rw>0?'var(--green)':rw<0?'var(--red)':'var(--b2)';
581
+ const bodyClass=rw>0?'good':rw<0?'bad':'';
582
+
583
+ // Tool badge
584
+ let toolName='';
585
+ if(actionData&&actionData.tool)toolName=actionData.tool;
586
+ else if(label.includes('.'))toolName=label.split('.')[0];
587
+ const badgeClass={'crm':'badge-crm','messaging':'badge-messaging','approval':'badge-approval','workflow':'badge-workflow'}[toolName]||'badge-workflow';
588
+ const badge=toolName?'<span class="tl-tool-badge '+badgeClass+'">'+toolName+'</span>':'';
589
+
590
+ // Parse feedback for display
591
+ let feedbackText='';
592
+ if(o.feedback){
593
+ try{
594
+ const fb=JSON.parse(o.feedback);
595
+ if(fb.reply)feedbackText=fb.reply;
596
+ else if(fb.message)feedbackText=fb.message;
597
+ else if(fb.error)feedbackText='Error: '+fb.error;
598
+ else if(fb.result)feedbackText='Result: '+fb.result+(fb.reason?' ('+fb.reason+')':'');
599
+ else if(fb.approval_status)feedbackText='Approval: '+fb.approval_status;
600
+ else if(fb.stage)feedbackText='Stage updated: '+fb.stage;
601
+ else if(fb.field)feedbackText=fb.field+' = '+fb.value;
602
+ else if(fb.elapsed)feedbackText='Time elapsed: '+fb.elapsed;
603
+ else feedbackText=o.feedback;
604
+ }catch(e){feedbackText=o.feedback}
605
+ }
606
+
607
+ let html='<div class="tl-entry"><div class="tl-dot-col"><div class="tl-dot" style="background:'+dotColor+'"></div><div class="tl-line"></div></div><div class="tl-content"><div class="tl-head"><span class="tl-step mono">'+S.stepCount+'</span>'+badge+'<span class="tl-action">'+label+'</span><span class="tl-reward '+rwClass+'">'+rwStr+'</span></div>';
608
+ if(feedbackText)html+='<div class="tl-body '+bodyClass+'">'+feedbackText+'</div>';
609
+ html+='</div></div>';
610
+ tl.innerHTML+=html;
611
+ tl.scrollTop=tl.scrollHeight;
612
+ } else if(o.feedback){
613
+ const tl=document.getElementById('tl');
614
+ let feedbackText='';
615
+ try{
616
+ const fb=JSON.parse(o.feedback);
617
+ feedbackText='New episode: '+o.driver_name+' &mdash; '+fb.jobs+' jobs available';
618
+ }catch(e){feedbackText=o.feedback}
619
+ tl.innerHTML+='<div class="tl-entry"><div class="tl-dot-col"><div class="tl-dot" style="background:var(--violet)"></div><div class="tl-line"></div></div><div class="tl-content"><div class="tl-head"><span class="tl-step mono">0</span><span class="tl-action">Episode start</span></div><div class="tl-body">'+feedbackText+'</div></div></div>';
620
+ }
621
+ }
622
+
623
+ function renderPipeline(currentStage){
624
+ const el=document.getElementById('pipeStages');
625
+ const failStage=FAIL_STAGES.includes(currentStage)?currentStage:null;
626
+ const curIdx=STAGES.indexOf(currentStage);
627
+ let html='';
628
+ STAGES.forEach((s,i)=>{
629
+ let cls='pipe-stage';
630
+ if(failStage){cls+=' fail'}
631
+ else if(i<curIdx)cls+=' done';
632
+ else if(i===curIdx)cls+=' active';
633
+ html+='<span class="'+cls+'">'+s.replace('_',' ')+'</span>';
634
+ if(i<STAGES.length-1)html+='<span class="pipe-arrow">&rarr;</span>';
635
+ });
636
+ if(failStage){
637
+ html+='<span class="pipe-arrow">&rarr;</span><span class="pipe-stage fail">'+failStage+'</span>';
638
+ }
639
+ el.innerHTML=html;
640
+ }
641
+
642
+ // --- Modals ---
643
+ let pendingModalAction='';
644
+
645
+ function showJobModal(action){
646
+ pendingModalAction=action;
647
+ const titles={'pitch':'Pitch which job?','offer':'Send offer for which job?','request_approval':'Request approval for which job?'};
648
+ document.getElementById('modalTitle').textContent=titles[action]||'Select job';
649
+ document.getElementById('modalJobs').innerHTML=S.jobs.map(j=>'<button class="modal-job" onclick="selJob('+j.id+')">'+j.label+'</button>').join('');
650
+ document.getElementById('modalBg').classList.add('on');
651
+ }
652
+ function selJob(id){
653
+ closeModal();
654
+ if(pendingModalAction==='pitch'||pendingModalAction==='offer'){
655
+ doMsg(pendingModalAction,id);
656
+ } else if(pendingModalAction==='request_approval'){
657
+ doTool('approval','request_approval',{job_id:id});
658
+ }
659
+ }
660
+ function closeModal(){document.getElementById('modalBg').classList.remove('on')}
661
+
662
+ function showStageModal(preselect){
+ // A preselected stage (e.g. from the "hire (finish)" button) is applied directly
+ if(preselect){doStage(preselect);return}
663
+ const stages=['contacted','interested','approval_pending','offer_sent','hired','lost'];
664
+ document.getElementById('stageModalBtns').innerHTML=stages.map(s=>{
665
+ const style=s==='hired'?' style="border-color:var(--green);color:var(--green)"':s==='lost'?' style="border-color:var(--red);color:var(--red)"':'';
+ return '<button class="modal-job"'+style+' onclick="doStage(\''+s+'\');closeStageModal()">'+s.replace('_',' ')+'</button>';
667
+ }).join('');
668
+ document.getElementById('stageModalBg').classList.add('on');
669
+ }
670
+ function closeStageModal(){document.getElementById('stageModalBg').classList.remove('on')}
671
+
672
+ function showFieldModal(){document.getElementById('fieldModalBg').classList.add('on');document.getElementById('fieldValue').value='';document.getElementById('fieldValue').focus()}
673
+ function closeFieldModal(){document.getElementById('fieldModalBg').classList.remove('on')}
674
+ function submitField(){
675
+ const f=document.getElementById('fieldSelect').value;
676
+ const v=document.getElementById('fieldValue').value;
677
+ if(!v)return;
678
+ closeFieldModal();
679
+ doTool('crm','update_field',{field:f,value:v});
680
+ }
681
+
682
+ function showNoteModal(){document.getElementById('noteModalBg').classList.add('on');document.getElementById('noteValue').value='';document.getElementById('noteValue').focus()}
683
+ function closeNoteModal(){document.getElementById('noteModalBg').classList.remove('on')}
684
+ function submitNote(){
685
+ const v=document.getElementById('noteValue').value;
686
+ if(!v)return;
687
+ closeNoteModal();
688
+ doTool('crm','add_note',{value:v});
689
+ }
690
+
691
+ // --- End screen ---
692
+ function showEnd(o){
693
+ const e=document.getElementById('endscreen');e.classList.add('on');
694
+ const win=o.stage==='hired';
695
+ document.getElementById('endLabel').textContent=win?'DRIVER HIRED':'EPISODE ENDED';
696
+ document.getElementById('endLabel').style.color=win?'var(--green)':'var(--red)';
697
+ document.getElementById('endTitle').textContent=win?'Placement complete':o.stage==='ghosted'?'Driver ghosted':'Failed';
698
+ document.getElementById('endTitle').style.color=win?'var(--green)':'var(--t1)';
699
+
700
+ let subText='';
701
+ if(o.feedback){
702
+ try{
703
+ const fb=JSON.parse(o.feedback);
704
+ if(fb.reason)subText=fb.reason.replace(/_/g,' ');
705
+ if(fb.result)subText=fb.result.replace(/_/g,' ')+(subText?' &mdash; '+subText:'');
706
+ if(fb.score)subText+=' (fit score: '+fb.score+')';
707
+ if(fb.crm_bonus)subText+=' CRM bonus: +'+fb.crm_bonus;
708
+ }catch(e){subText=o.feedback}
709
+ }
710
+ document.getElementById('endSub').innerHTML=subText;
711
+
712
+ const rv=document.getElementById('erRew');
713
+ rv.textContent=(S.rew>=0?'+':'')+S.rew.toFixed(1);
714
+ rv.style.color=S.rew>=0?'var(--green)':'var(--red)';
715
+ document.getElementById('erStep').textContent=S.stepCount;
716
+ document.getElementById('erStage').textContent=o.stage;
717
+ }
718
+
719
+ // Enter to submit in modals
720
+ document.getElementById('fieldValue').addEventListener('keydown',e=>{if(e.key==='Enter')submitField()});
721
+ document.getElementById('noteValue').addEventListener('keydown',e=>{if(e.key==='Enter')submitNote()});
722
+ </script>
723
+ </body>
724
+ </html>
eval_trained.py ADDED
@@ -0,0 +1,209 @@
+ """Evaluate a trained model against the recruiting environment."""
2
+
3
+ import argparse
4
+ import json
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import torch
7
+ from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
8
+
9
+ SYSTEM_PROMPT = """You are a truck driver recruiter using a CRM system. You only know the driver's name. You must discover their qualifications through conversation, record info in the CRM, get approval, and hire them.
10
+
11
+ You have 4 tools:
12
+
13
+ ## crm
14
+ - read_candidate: Read the current CRM record
15
+ - update_stage: Advance pipeline (contacted → interested → approval_pending → offer_sent → hired)
16
+ - update_field: Record info (field + value)
17
+ - add_note: Add a free-text note
18
+
19
+ ## messaging
20
+ - send_message: Send a message (topic: greeting, call, experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references, pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern)
21
+ - read_reply: Read the driver's response
22
+
23
+ ## approval
24
+ - request_approval: Request approval for a job (needs job_id)
25
+ - check_approval: Check approval status
26
+
27
+ ## workflow
28
+ - wait: Advance time (needed for approval processing)
29
+
30
+ ## Rules
31
+ - Must read CRM before messaging
32
+ - Must read_reply before sending another message
33
+ - Must request_approval and wait before sending offer
34
+ - Must follow stage order: lead → contacted → interested → approval_pending → offer_sent → hired
35
+ - Record important info in CRM with update_field
36
+
37
+ Respond with ONLY JSON:
38
+ {"tool": "crm", "action": "read_candidate"}
39
+ {"tool": "messaging", "action": "send_message", "topic": "experience"}
40
+ {"tool": "messaging", "action": "read_reply"}
41
+ {"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
42
+ {"tool": "crm", "action": "update_stage", "stage": "contacted"}
43
+ {"tool": "approval", "action": "request_approval", "job_id": 2}
44
+ {"tool": "workflow", "action": "wait"}
45
+ {"tool": "approval", "action": "check_approval"}
46
+ {"tool": "messaging", "action": "send_message", "topic": "offer", "job_id": 2}
47
+ {"tool": "crm", "action": "update_stage", "stage": "hired"}"""
48
+
49
+
50
+ def format_observation(obs):
51
+ parts = [f"Driver: {obs.driver_name}"]
52
+ if obs.crm_summary:
53
+ parts.append(f"CRM:\n{obs.crm_summary}")
54
+ if obs.jobs_summary:
55
+ parts.append(f"Jobs:\n{obs.jobs_summary}")
56
+ if obs.discovered_info:
57
+ parts.append(f"Discovered:\n{obs.discovered_info}")
58
+ status = f"Stage: {obs.stage}"
59
+ if obs.pending_reply:
60
+ status += " | PENDING REPLY"
61
+ parts.append(status)
62
+ if obs.feedback:
63
+ parts.append(f"Result: {obs.feedback}")
64
+ return "\n".join(parts)
65
+
66
+
67
+ def parse_action(text):
68
+ text = text.strip()
69
+ if "```" in text:
70
+ for part in text.split("```"):
71
+ part = part.strip()
72
+ if part.startswith("json"):
73
+ part = part[4:].strip()
74
+ if part.startswith("{"):
75
+ text = part
76
+ break
77
+ try:
78
+ data = json.loads(text)
79
+ if isinstance(data, list):
80
+ data = data[0] if data else {}
81
+ if isinstance(data, dict) and "tool" in data and "action" in data:
82
+ return RecruitopenenvAction(
83
+ tool=data["tool"],
84
+ action=data["action"],
85
+ topic=data.get("topic", ""),
86
+ job_id=data.get("job_id", -1),
87
+ stage=data.get("stage", ""),
88
+ field=data.get("field", ""),
89
+ value=data.get("value", ""),
90
+ )
91
+ except (json.JSONDecodeError, KeyError, IndexError):
92
+ pass
93
+
94
+ text_lower = text.lower()
95
+ if "read_candidate" in text_lower:
96
+ return RecruitopenenvAction(tool="crm", action="read_candidate")
97
+ if "read_reply" in text_lower:
98
+ return RecruitopenenvAction(tool="messaging", action="read_reply")
99
+ if "check_approval" in text_lower:
100
+ return RecruitopenenvAction(tool="approval", action="check_approval")
101
+ if "wait" in text_lower:
102
+ return RecruitopenenvAction(tool="workflow", action="wait")
103
+
104
+ return RecruitopenenvAction(tool="crm", action="read_candidate")
105
+
106
+
107
+ def generate(model, tokenizer, messages, device):
108
+ prompt = tokenizer.apply_chat_template(
109
+ messages, add_generation_prompt=True, tokenize=False
110
+ )
111
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
112
+ with torch.no_grad():
113
+ outputs = model.generate(
114
+ **inputs,
115
+ max_new_tokens=128,
116
+ temperature=0.1,
117
+ do_sample=True,
118
+ pad_token_id=tokenizer.eos_token_id,
119
+ )
120
+ new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
121
+ return tokenizer.decode(new_tokens, skip_special_tokens=True)
122
+
123
+
124
+ def main():
125
+ parser = argparse.ArgumentParser()
126
+ parser.add_argument("--model", default="./recruit-grpo-output", help="Path to trained model")
127
+ parser.add_argument("--base-model", default="Qwen/Qwen2.5-1.5B-Instruct", help="Base model for comparison")
128
+ parser.add_argument("--env-url", default="http://localhost:8001")
129
+ parser.add_argument("--num-episodes", type=int, default=20)
130
+ parser.add_argument("--compare", action="store_true", help="Also run base model for comparison")
131
+ args = parser.parse_args()
132
+
133
+ device = "cuda" if torch.cuda.is_available() else "cpu"
134
+
135
+ models_to_eval = [("TRAINED", args.model)]
136
+ if args.compare:
137
+ models_to_eval.append(("BASE", args.base_model))
138
+
139
+ for label, model_path in models_to_eval:
140
+ print(f"\n{'='*50}")
141
+ print(f"Evaluating: {label} ({model_path})")
142
+ print(f"{'='*50}")
143
+
144
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
145
+ model = AutoModelForCausalLM.from_pretrained(
146
+ model_path, torch_dtype=torch.float16, device_map="auto"
147
+ )
148
+
149
+ rewards = []
150
+ successes = 0
151
+ total_steps = 0
152
+
153
+ with RecruitopenenvEnv(base_url=args.env_url) as env:
154
+ for ep in range(args.num_episodes):
155
+ result = env.reset()
156
+ obs = result.observation
157
+ ep_reward = 0.0
158
+ steps = 0
159
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
160
+
161
+ while not result.done and steps < 100:
162
+ obs_text = format_observation(obs)
163
+ messages.append({"role": "user", "content": obs_text})
164
+
165
+ response = generate(model, tokenizer, messages, device)
166
+ messages.append({"role": "assistant", "content": response})
167
+
168
+ action = parse_action(response)
169
+ result = env.step(action)
170
+ obs = result.observation
171
+ ep_reward += result.reward
172
+ steps += 1
173
+
174
+ print(f" Step {steps}: {action.tool}.{action.action}"
175
+ f"{'(' + action.topic + ')' if action.topic else ''}"
176
+ f"{'[job=' + str(action.job_id) + ']' if action.job_id >= 0 else ''}"
177
+ f" -> reward={result.reward:.1f}")
178
+
179
+ rewards.append(ep_reward)
180
+ total_steps += steps
181
+ hired = obs.stage == "hired"
182
+ if hired:
183
+ successes += 1
184
+
185
+ print(f"Episode {ep+1}: reward={ep_reward:.1f}, steps={steps}, "
186
+ f"{'HIRED' if hired else 'FAIL (' + obs.stage + ')'}")
187
+ print()
188
+
189
+ avg_reward = sum(rewards) / len(rewards)
190
+ avg_steps = total_steps / args.num_episodes
191
+
192
+ print(f"\n{'='*40}")
193
+ print(f" {label} RESULTS")
194
+ print(f"{'='*40}")
195
+ print(f"Model: {model_path}")
196
+ print(f"Episodes: {args.num_episodes}")
197
+ print(f"Avg reward: {avg_reward:.2f}")
198
+ print(f"Min reward: {min(rewards):.2f}")
199
+ print(f"Max reward: {max(rewards):.2f}")
200
+ print(f"Hire rate: {successes}/{args.num_episodes} ({100*successes/args.num_episodes:.1f}%)")
201
+ print(f"Avg steps/episode: {avg_steps:.1f}")
202
+ print(f"{'='*40}")
203
+
204
+ del model
205
+ torch.cuda.empty_cache()
206
+
207
+
208
+ if __name__ == "__main__":
209
+ main()
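The `parse_action` helper above strips an optional markdown fence before calling `json.loads`. A minimal standalone sketch of that fence-stripping step, assuming a typical fenced model reply (the `extract_json` name is illustrative, not part of the script):

```python
import json

def extract_json(text: str) -> dict:
    """Strip an optional ```json fence and parse the first JSON object found."""
    text = text.strip()
    if "```" in text:
        for part in text.split("```"):
            part = part.strip()
            if part.startswith("json"):
                # Drop the language tag that follows the opening fence
                part = part[4:].strip()
            if part.startswith("{"):
                text = part
                break
    return json.loads(text)

reply = '```json\n{"tool": "messaging", "action": "send_message", "topic": "experience"}\n```'
action = extract_json(reply)
```

This mirrors why the eval loop tolerates models that wrap their tool call in a code fence instead of emitting bare JSON.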
models.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ Data models for the Driver Recruit Environment.
3
+
4
+ Tool-based action interface for long-horizon recruiting pipeline.
5
+ Agent uses CRM, messaging, approval, and workflow tools.
6
+ """
7
+
8
+ from pydantic import Field
9
+
10
+ from openenv.core.env_server.types import Action, Observation
11
+
12
+
13
+ class RecruitopenenvAction(Action):
14
+ """Tool-based action the agent takes."""
15
+
16
+ tool: str = Field(
17
+ ...,
18
+ description="Tool: crm, messaging, approval, workflow",
19
+ )
20
+ action: str = Field(
21
+ ...,
22
+ description=(
23
+ "Action within tool. "
24
+ "crm: read_candidate, update_stage, update_field, add_note. "
25
+ "messaging: send_message, read_reply. "
26
+ "approval: request_approval, check_approval. "
27
+ "workflow: wait."
28
+ ),
29
+ )
30
+ topic: str = Field(
31
+ default="",
32
+ description=(
33
+ "Message topic for messaging.send_message: "
34
+ "greeting, call, experience, home_time, pay, equipment, route, "
35
+ "deal_breakers, availability, violations, medical_card, references, "
36
+ "pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern"
37
+ ),
38
+ )
39
+ job_id: int = Field(
40
+ default=-1,
41
+ description="Job index (0-5). Used with pitch, offer, request_approval.",
42
+ )
43
+ stage: str = Field(
44
+ default="",
45
+ description="Target stage for crm.update_stage: contacted, interested, approval_pending, offer_sent, hired, lost",
46
+ )
47
+ field: str = Field(
48
+ default="",
49
+ description="CRM field for crm.update_field",
50
+ )
51
+ value: str = Field(
52
+ default="",
53
+ description="Value for crm.update_field or text for crm.add_note",
54
+ )
55
+
56
+
57
+ class RecruitopenenvObservation(Observation):
58
+ """What the agent sees after each action."""
59
+
60
+ driver_name: str = Field(default="", description="Driver's name")
61
+ crm_summary: str = Field(default="", description="CRM record (empty until read_candidate)")
62
+ jobs_summary: str = Field(default="", description="Available job listings")
63
+ discovered_info: str = Field(default="", description="Info discovered through conversation")
64
+
65
+ stage: str = Field(default="lead", description="Current pipeline stage")
66
+ feedback: str = Field(default="", description="API response from last action")
67
+ pending_reply: bool = Field(default=False, description="Whether an unread message is waiting")
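The action model routes on a `tool`/`action` pair. A minimal sketch of the validity table that the field descriptions above document, written as a standalone check (the `is_valid` helper is hypothetical; the server applies its own validation):

```python
# Tool -> allowed actions, mirroring the RecruitopenenvAction field descriptions.
VALID_TOOL_ACTIONS = {
    "crm": {"read_candidate", "update_stage", "update_field", "add_note"},
    "messaging": {"send_message", "read_reply"},
    "approval": {"request_approval", "check_approval"},
    "workflow": {"wait"},
}

def is_valid(tool: str, action: str) -> bool:
    """Return True if the action name is defined for the given tool."""
    return action in VALID_TOOL_ACTIONS.get(tool, set())
```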
openenv.yaml ADDED
@@ -0,0 +1,7 @@
1
+ spec_version: 1
2
+ name: recruitopenenv
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
openenv_recruitopenenv.egg-info/PKG-INFO ADDED
@@ -0,0 +1,9 @@
1
+ Metadata-Version: 2.4
2
+ Name: openenv-recruitopenenv
3
+ Version: 0.1.0
4
+ Summary: Recruitopenenv environment for OpenEnv
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core[core]>=0.2.0
7
+ Provides-Extra: dev
8
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
9
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_recruitopenenv.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,17 @@
1
+ README.md
2
+ __init__.py
3
+ client.py
4
+ models.py
5
+ pyproject.toml
6
+ ./__init__.py
7
+ ./client.py
8
+ ./models.py
9
+ openenv_recruitopenenv.egg-info/PKG-INFO
10
+ openenv_recruitopenenv.egg-info/SOURCES.txt
11
+ openenv_recruitopenenv.egg-info/dependency_links.txt
12
+ openenv_recruitopenenv.egg-info/entry_points.txt
13
+ openenv_recruitopenenv.egg-info/requires.txt
14
+ openenv_recruitopenenv.egg-info/top_level.txt
15
+ server/__init__.py
16
+ server/app.py
17
+ server/recruitopenenv_environment.py
openenv_recruitopenenv.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
openenv_recruitopenenv.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ server = recruitopenenv.server.app:main
openenv_recruitopenenv.egg-info/requires.txt ADDED
@@ -0,0 +1,5 @@
1
+ openenv-core[core]>=0.2.0
2
+
3
+ [dev]
4
+ pytest>=8.0.0
5
+ pytest-cov>=4.0.0
openenv_recruitopenenv.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ recruitopenenv
play.py ADDED
@@ -0,0 +1,172 @@
1
+ """Interactive CLI to play the recruiting environment manually."""
2
+
3
+ import json
4
+ import requests
5
+
6
+ BASE_URL = "http://localhost:8000"
7
+
8
+ SHORTCUTS = {
9
+ "r": '{"tool":"crm","action":"read_candidate"}',
10
+ "rr": '{"tool":"messaging","action":"read_reply"}',
11
+ "w": '{"tool":"workflow","action":"wait"}',
12
+ "ca": '{"tool":"approval","action":"check_approval"}',
13
+ "hi": '{"tool":"crm","action":"update_stage","stage":"hired"}',
14
+ "lost": '{"tool":"crm","action":"update_stage","stage":"lost"}',
15
+ }
16
+
17
+ TOPIC_SHORTCUTS = {
18
+ "g": "greeting", "c": "call", "exp": "experience", "ht": "home_time",
19
+ "pay": "pay", "eq": "equipment", "rt": "route", "db": "deal_breakers",
20
+ "av": "availability", "vio": "violations", "med": "medical_card",
21
+ "ref": "references", "pitch": "pitch", "offer": "offer",
22
+ "np": "negotiate_pay", "nht": "negotiate_home_time",
23
+ "sb": "signing_bonus", "ac": "address_concern",
24
+ }
25
+
26
+ def print_obs(obs, reward):
27
+ print(f"\n{'='*60}")
28
+ print(f"Driver: {obs['driver_name']}")
29
+ if obs.get('crm_summary'):
30
+ print(f"\nCRM:\n{obs['crm_summary']}")
31
+ if obs.get('jobs_summary'):
32
+ print(f"\nJobs:\n{obs['jobs_summary']}")
33
+ if obs.get('discovered_info'):
34
+ print(f"\nDiscovered:\n{obs['discovered_info']}")
35
+ status = f"Stage: {obs['stage']}"
36
+ if obs.get('pending_reply'):
37
+ status += " | PENDING REPLY"
38
+ print(f"\n{status}")
39
+ print(f"Reward this step: {reward}")
40
+ if obs.get('feedback'):
41
+ try:
42
+ fb = json.loads(obs['feedback'])
43
+ print(f"Response: {json.dumps(fb, indent=2)}")
44
+ except (json.JSONDecodeError, TypeError):
45
+ print(f"Response: {obs['feedback']}")
46
+
47
+ def print_help():
48
+ print("\nShortcuts:")
49
+ print(" r = read CRM")
50
+ print(" rr = read reply")
51
+ print(" w = wait")
52
+ print(" ca = check approval")
53
+ print(" hi = update stage to hired")
54
+ print(" lost = update stage to lost")
55
+ print("\nSend message: s <topic> e.g. s g, s exp, s offer")
56
+ print(" Topics: g=greeting c=call exp=experience ht=home_time pay eq=equipment")
57
+ print(" rt=route db=deal_breakers av=availability vio=violations med=medical_card")
58
+ print(" ref=references pitch offer np=negotiate_pay nht=negotiate_home_time")
59
+ print(" sb=signing_bonus ac=address_concern")
60
+ print("\nWith job_id: s pitch 2, s offer 3")
61
+ print("\nUpdate stage: st <stage> e.g. st contacted")
62
+ print("Update field: f <field> <value> e.g. f cdl_class A")
63
+ print("Add note: n <text> e.g. n Driver prefers OTR")
64
+ print("Request approval: ra <job_id> e.g. ra 2")
65
+ print("\nOr paste raw JSON: {\"tool\":\"crm\",\"action\":\"read_candidate\"}")
66
+ print(" q = quit, h = help, reset = new episode")
67
+
68
+ def parse_input(user_input):
69
+ user_input = user_input.strip()
70
+ if not user_input:
71
+ return None
72
+
73
+ # Shortcuts
74
+ if user_input in SHORTCUTS:
75
+ return json.loads(SHORTCUTS[user_input])
76
+
77
+ # Raw JSON
78
+ if user_input.startswith("{"):
79
+ return json.loads(user_input)
80
+
81
+ parts = user_input.split(None, 2)
82
+ cmd = parts[0]
83
+
84
+ # Send message: s <topic> [job_id]
85
+ if cmd == "s" and len(parts) >= 2:
86
+ topic = TOPIC_SHORTCUTS.get(parts[1], parts[1])
87
+ action = {"tool": "messaging", "action": "send_message", "topic": topic}
88
+ if len(parts) >= 3:
89
+ action["job_id"] = int(parts[2])
90
+ return action
91
+
92
+ # Update stage: st <stage>
93
+ if cmd == "st" and len(parts) >= 2:
94
+ return {"tool": "crm", "action": "update_stage", "stage": parts[1]}
95
+
96
+ # Update field: f <field> <value>
97
+ if cmd == "f" and len(parts) >= 3:
98
+ return {"tool": "crm", "action": "update_field", "field": parts[1], "value": parts[2]}
99
+
100
+ # Add note: n <text>
101
+ if cmd == "n" and len(parts) >= 2:
102
+ return {"tool": "crm", "action": "add_note", "value": " ".join(parts[1:])}
103
+
104
+ # Request approval: ra <job_id>
105
+ if cmd == "ra" and len(parts) >= 2:
106
+ return {"tool": "approval", "action": "request_approval", "job_id": int(parts[1])}
107
+
108
+ print(f"Unknown command: {user_input}. Type 'h' for help.")
109
+ return None
110
+
111
+ def main():
112
+ session = requests.Session()
113
+ total_reward = 0.0
114
+
115
+ print("\n🚛 DRIVER RECRUITING ENVIRONMENT — INTERACTIVE MODE")
116
+ print_help()
117
+
118
+ # Reset
119
+ resp = session.post(f"{BASE_URL}/reset", json={})
120
+ data = resp.json()
121
+ obs = data["observation"]
122
+ print_obs(obs, 0)
123
+
124
+ while True:
125
+ try:
126
+ user_input = input("\n> ").strip()
127
+ except (EOFError, KeyboardInterrupt):
128
+ print("\nBye!")
129
+ break
130
+
131
+ if user_input == "q":
132
+ break
133
+ if user_input == "h":
134
+ print_help()
135
+ continue
136
+ if user_input == "reset":
137
+ resp = session.post(f"{BASE_URL}/reset", json={})
138
+ data = resp.json()
139
+ obs = data["observation"]
140
+ total_reward = 0.0
141
+ print_obs(obs, 0)
142
+ continue
143
+
144
+ action = parse_input(user_input)
145
+ if action is None:
146
+ continue
147
+
148
+ print(f"→ {action['tool']}.{action['action']}"
149
+ + (f"({action.get('topic', '')})" if action.get('topic') else "")
150
+ + (f"[job={action['job_id']}]" if action.get('job_id', -1) >= 0 else "")
151
+ + (f"({action.get('stage', '')})" if action.get('stage') else "")
152
+ + (f"({action.get('field', '')}={action.get('value', '')})" if action.get('field') else ""))
153
+
154
+ resp = session.post(f"{BASE_URL}/step", json=action)
155
+ data = resp.json()
156
+ obs = data["observation"]
157
+ reward = data["reward"]
158
+ done = data["done"]
159
+ total_reward += reward
160
+
161
+ print_obs(obs, reward)
162
+ print(f"Total reward: {total_reward:.1f}")
163
+
164
+ if done:
165
+ print(f"\n{'='*60}")
166
+ print(f"EPISODE OVER — Final stage: {obs['stage']} | Total reward: {total_reward:.1f}")
167
+ print(f"{'='*60}")
168
+ print("Type 'reset' for a new episode or 'q' to quit.")
169
+
170
+
171
+ if __name__ == "__main__":
172
+ main()
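The CLI's shorthand commands expand into the same JSON actions the server consumes. A minimal sketch of the `s <topic> [job_id]` expansion, assuming the topic table above (the `expand_send` name is illustrative):

```python
# Minimal sketch of the "s <topic> [job_id]" shortcut expansion used by the CLI.
TOPIC_SHORTCUTS = {"g": "greeting", "exp": "experience", "ht": "home_time"}

def expand_send(user_input: str) -> dict:
    parts = user_input.split(None, 2)
    assert parts[0] == "s" and len(parts) >= 2
    # Unknown shortcuts fall through as literal topic names
    topic = TOPIC_SHORTCUTS.get(parts[1], parts[1])
    action = {"tool": "messaging", "action": "send_message", "topic": topic}
    if len(parts) >= 3:
        action["job_id"] = int(parts[2])
    return action

a = expand_send("s pitch 2")
```

Falling through unknown shortcuts means full topic names like `pitch` or `offer` work without their own table entries.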
pyproject.toml ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-recruitopenenv"
13
+ version = "0.1.0"
14
+ description = "Recruitopenenv environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ # install from github
19
+ # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]==0.2.1",
21
+ # Environment-specific dependencies
22
+ # Add all dependencies needed for your environment here
23
+ # Examples:
24
+ # "numpy>=1.19.0",
25
+ # "torch>=2.0.0",
26
+ # "gymnasium>=0.29.0",
27
+ # "openspiel>=1.0.0",
28
+ # "smolagents>=1.22.0,<2",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=8.0.0",
34
+ "pytest-cov>=4.0.0",
35
+ ]
36
+
37
+ [project.scripts]
38
+ # Server entry point - enables running via: uv run --project . server
39
+ # or: python -m recruitopenenv.server.app
40
+ server = "recruitopenenv.server.app:main"
41
+
42
+ [tool.setuptools]
43
+ include-package-data = true
44
+ packages = ["recruitopenenv", "recruitopenenv.server"]
45
+ package-dir = { "recruitopenenv" = ".", "recruitopenenv.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Recruitopenenv environment server components."""
8
+
9
+ from .recruitopenenv_environment import RecruitopenenvEnvironment
10
+
11
+ __all__ = ["RecruitopenenvEnvironment"]
server/app.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the Recruitopenenv Environment.
9
+
10
+ This module creates an HTTP server that exposes the RecruitopenenvEnvironment
11
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
12
+
13
+ Endpoints:
14
+ - POST /reset: Reset the environment
15
+ - POST /step: Execute an action
16
+ - GET /state: Get current environment state
17
+ - GET /schema: Get action/observation schemas
18
+ - WS /ws: WebSocket endpoint for persistent sessions
19
+
20
+ Usage:
21
+ # Development (with auto-reload):
22
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
23
+
24
+ # Production:
25
+ uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
26
+
27
+ # Or run directly:
28
+ python -m server.app
29
+ """
30
+
31
+ import os
32
+
33
+ from fastapi.middleware.cors import CORSMiddleware
34
+ from fastapi.responses import FileResponse
35
+
36
+ try:
37
+ from openenv.core.env_server.http_server import create_app
38
+ except Exception as e: # pragma: no cover
39
+ raise ImportError(
40
+ "openenv is required to run this server. Install dependencies with 'uv sync'."
41
+ ) from e
42
+
43
+ # Import from local models.py (PYTHONPATH includes /app/env in Docker)
44
+ from models import RecruitopenenvAction, RecruitopenenvObservation
45
+ from .recruitopenenv_environment import RecruitopenenvEnvironment
46
+
47
+
48
+ # Create the app with web interface and README integration
49
+ app = create_app(
50
+ RecruitopenenvEnvironment,
51
+ RecruitopenenvAction,
52
+ RecruitopenenvObservation,
53
+ env_name="recruitopenenv",
54
+ max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
55
+ )
56
+
57
+ # CORS for demo page
58
+ app.add_middleware(
59
+ CORSMiddleware,
60
+ allow_origins=["*"],
61
+ allow_methods=["*"],
62
+ allow_headers=["*"],
63
+ )
64
+
65
+ # Serve the demo page
66
+ _DEMO_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "demo")
67
+
68
+
69
+ @app.get("/demo", include_in_schema=False)
70
+ async def demo_page():
71
+ return FileResponse(os.path.join(_DEMO_DIR, "index.html"))
72
+
73
+
74
+ def main(host: str = "0.0.0.0", port: int = 8000):
75
+ """
76
+ Entry point for direct execution via uv run or python -m.
77
+
78
+ This function enables running the server without Docker:
79
+ uv run --project . server
80
+ uv run --project . server --port 8001
81
+ python -m recruitopenenv.server.app
82
+
83
+ Args:
84
+ host: Host address to bind to (default: "0.0.0.0")
85
+ port: Port number to listen on (default: 8000)
86
+
87
+ For production deployments, consider using uvicorn directly with
88
+ multiple workers:
89
+ uvicorn recruitopenenv.server.app:app --workers 4
90
+ """
91
+ import uvicorn
92
+
93
+ uvicorn.run(app, host=host, port=port)
94
+
95
+
96
+ if __name__ == "__main__":
97
+ import argparse
98
+
99
+ parser = argparse.ArgumentParser()
100
+ parser.add_argument("--port", type=int, default=8000)
101
+ args = parser.parse_args()
102
+ main(port=args.port)
server/recruitopenenv_environment.py ADDED
@@ -0,0 +1,1422 @@
1
+ """
2
+ Driver Recruit Environment — Tool-based Long-Horizon.
3
+
4
+ Agent interacts through 4 tools: CRM, messaging, approval, workflow.
5
+ Each recruiting interaction requires multiple tool calls, creating
6
+ naturally long episodes (40-70 steps).
7
+
8
+ Pipeline: lead → contacted → interested → approval_pending → offer_sent → hired
9
+ Terminal failures: lost, ghosted
10
+ """
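The docstring above pins a strict stage order. A minimal sketch of the advance check it implies, assuming updates must move exactly one step forward along STAGE_ORDER as defined later in this file (the real environment also handles the terminal `lost`/`ghosted` stages, which this sketch omits):

```python
STAGE_ORDER = ["lead", "contacted", "interested", "approval_pending", "offer_sent", "hired"]

def can_advance(current: str, target: str) -> bool:
    """A stage update is valid only if it moves exactly one step forward."""
    if current not in STAGE_ORDER or target not in STAGE_ORDER:
        return False
    return STAGE_ORDER.index(target) == STAGE_ORDER.index(current) + 1
```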
11
+
12
+ import json
13
+ import random
14
+ from uuid import uuid4
15
+
16
+ from openenv.core.env_server.interfaces import Environment
17
+ from openenv.core.env_server.types import State
18
+
19
+ from models import RecruitopenenvAction, RecruitopenenvObservation
20
+
21
+ # --- Constants ---
22
+
23
+ FIRST_NAMES = [
24
+ "Mike", "James", "Robert", "John", "David", "Carlos", "Marcus",
25
+ "Sarah", "Maria", "Linda", "Patricia", "Jessica", "Angela", "Rosa",
26
+ "Travis", "Derek", "Kevin", "Brandon", "Tyler", "Dustin", "Ray",
27
+ ]
28
+ LAST_NAMES = [
29
+ "Johnson", "Smith", "Williams", "Garcia", "Martinez", "Brown",
30
+ "Davis", "Rodriguez", "Wilson", "Taylor", "Thomas", "Moore",
31
+ "Jackson", "White", "Harris", "Clark", "Lewis", "Young",
32
+ ]
33
+ LOCATIONS = [
34
+ "Dallas TX", "Atlanta GA", "Chicago IL", "Denver CO", "Phoenix AZ",
35
+ "Memphis TN", "Louisville KY", "Nashville TN", "Indianapolis IN",
36
+ "Columbus OH", "Jacksonville FL", "Charlotte NC", "Kansas City MO",
37
+ ]
38
+ COMPANIES = [
39
+ "Werner Enterprises", "Swift Transport", "Schneider National",
40
+ "J.B. Hunt", "KLLM Transport", "Heartland Express",
41
+ "Covenant Logistics", "USA Truck", "Marten Transport",
42
+ "Prime Inc", "CR England", "Western Express",
43
+ ]
44
+
45
+ CDL_CLASSES = ["A", "B"]
46
+ ENDORSEMENTS_ALL = ["H", "N", "T", "TWIC"]
47
+ HOME_TIMES = ["daily", "weekends", "weekly", "biweekly"]
48
+ ROUTE_TYPES = ["OTR", "regional", "local", "dedicated"]
49
+ EQUIPMENT_TYPES = ["dry_van", "flatbed", "reefer", "tanker"]
50
+ CONTACT_METHODS = ["text", "call"]
51
+ DEAL_BREAKERS_ALL = [
52
+ "touch_freight", "forced_dispatch", "team_driving",
53
+ "northeast", "hazmat_no_premium", "no_benefits",
54
+ ]
55
+
56
+ PERSONALITY_PARAMS = {
57
+ "chatty": {"initial_trust": 0.80, "decay": 0.02, "reveal_breakers": "all"},
58
+ "professional": {"initial_trust": 0.70, "decay": 0.025, "reveal_breakers": "all"},
59
+ "impatient": {"initial_trust": 0.60, "decay": 0.04, "reveal_breakers": "partial"},
60
+ "suspicious": {"initial_trust": 0.55, "decay": 0.03, "reveal_breakers": "all_if_trusted"},
61
+ }
62
+
63
+ AVAILABILITIES = ["immediately", "2_weeks", "1_month", "negotiable"]
64
+ VIOLATION_LEVELS = ["clean", "minor", "major"]
65
+ MEDICAL_CARD_STATUS = ["valid", "expiring_soon", "expired"]
66
+ REFERENCE_QUALITY = ["strong", "mixed", "none"]
67
+
68
+ MAX_STEPS = 100
69
+
70
+ VALID_TOOL_ACTIONS = {
71
+ "crm": {"read_candidate", "update_stage", "update_field", "add_note"},
72
+ "messaging": {"send_message", "read_reply"},
73
+ "approval": {"request_approval", "check_approval"},
74
+ "workflow": {"wait"},
75
+ }
76
+
77
+ VALID_TOPICS = {
78
+ "greeting", "call",
79
+ "experience", "home_time", "pay", "equipment", "route", "deal_breakers",
80
+ "availability", "violations", "medical_card", "references",
81
+ "pitch", "offer",
82
+ "negotiate_pay", "negotiate_home_time", "signing_bonus", "address_concern",
83
+ }
84
+
85
+ STAGE_ORDER = ["lead", "contacted", "interested", "approval_pending", "offer_sent", "hired"]
86
+ ALL_STAGES = set(STAGE_ORDER) | {"lost", "ghosted"}
87
+
88
+ SCREENING_TOPICS = {
89
+ "experience", "home_time", "pay", "equipment", "route", "deal_breakers",
90
+ "availability", "violations", "medical_card", "references",
91
+ }
92
+
93
+ VALID_CRM_FIELDS = {
94
+ "cdl_class", "years_experience", "endorsements", "location",
95
+ "home_time_pref", "pay_expectation", "equipment_pref", "route_pref",
96
+ "deal_breakers", "availability", "violations", "medical_card", "references",
97
+ "matched_job",
98
+ }
99
+
100
+
101
+ # --- Data generation ---
102
+
103
+
104
+ def generate_driver():
105
+ personality = random.choices(
106
+ ["chatty", "professional", "impatient", "suspicious"],
107
+ weights=[25, 35, 20, 20],
108
+ )[0]
109
+ params = PERSONALITY_PARAMS[personality]
110
+
111
+ cdl = random.choices(CDL_CLASSES, weights=[75, 25])[0]
112
+ exp = random.randint(1, 20)
113
+
114
+ endorsements = [e for e in ENDORSEMENTS_ALL if random.random() < 0.10 + exp * 0.02]
115
+
116
+ equip_opts = ["dry_van", "flatbed", "reefer"]
117
+ if "N" in endorsements:
118
+ equip_opts.append("tanker")
119
+ equipment_pref = random.choice(equip_opts)
120
+
121
+ n_breakers = random.choices([1, 2, 3], weights=[30, 50, 20])[0]
122
+ deal_breakers = random.sample(DEAL_BREAKERS_ALL, n_breakers)
123
+
124
+ return {
125
+ "name": f"{random.choice(FIRST_NAMES)} {random.choice(LAST_NAMES)}",
126
+ "cdl_class": cdl,
127
+ "endorsements": endorsements,
128
+ "experience_years": exp,
129
+ "location": random.choice(LOCATIONS),
130
+ "preferred_contact": random.choice(CONTACT_METHODS),
131
+ "personality": personality,
132
+ "trust": params["initial_trust"],
133
+ "decay": params["decay"],
134
+ "home_time_pref": random.choices(HOME_TIMES, weights=[15, 30, 30, 25])[0],
135
+ "min_cpm": round(random.uniform(0.48, 0.78), 2),
136
+ "equipment_pref": equipment_pref,
137
+ "route_pref": random.choices(ROUTE_TYPES, weights=[20, 30, 30, 20])[0],
138
+ "deal_breakers": deal_breakers,
139
+ "availability": random.choices(AVAILABILITIES, weights=[30, 35, 25, 10])[0],
140
+ "violations": random.choices(VIOLATION_LEVELS, weights=[60, 30, 10])[0],
141
+ "medical_card": random.choices(MEDICAL_CARD_STATUS, weights=[70, 20, 10])[0],
142
+ "references": random.choices(REFERENCE_QUALITY, weights=[40, 40, 20])[0],
143
+ }
144
+
145
+
146
def generate_jobs(driver):
    """Generate 6 jobs: a mix of good, trap, partial-fit, and disqualifying listings."""
    jobs = []
    if random.random() > 0.2:
        jobs.append(_make_good_job(driver, 0))
    else:
        jobs.append(_make_trap_job(driver, 0))
    jobs.append(_make_trap_job(driver, 1))
    jobs.append(_make_partial_job(driver, 2))

    # Disqualifying job: requires the wrong CDL class for this driver.
    bad_cdl = "B" if driver["cdl_class"] == "A" else "A"
    jobs.append({
        "job_id": 3, "company": random.choice(COMPANIES),
        "required_cdl": bad_cdl, "required_endorsements": [],
        "min_experience": random.randint(1, 5),
        "route_type": random.choice(ROUTE_TYPES),
        "home_time": random.choice(HOME_TIMES),
        "pay_cpm": round(random.uniform(0.50, 0.85), 2),
        "equipment": random.choice(EQUIPMENT_TYPES),
        "has_touch_freight": random.random() < 0.3,
        "forced_dispatch": random.random() < 0.3,
        "team_driving": False, "northeast_routes": False,
        "hazmat_premium": False,
        "benefits": random.choice(["none", "basic", "good"]),
        "location": random.choice(LOCATIONS),
        "start_urgency": random.choice(["immediate", "flexible"]),
        "requires_clean_record": random.random() < 0.3,
        "requires_medical": True,
    })

    # Disqualifying job: demands endorsements and experience beyond the driver's.
    jobs.append({
        "job_id": 4, "company": random.choice(COMPANIES),
        "required_cdl": driver["cdl_class"],
        "required_endorsements": ["H", "T"],
        "min_experience": driver["experience_years"] + random.randint(5, 10),
        "route_type": random.choice(ROUTE_TYPES),
        "home_time": random.choice(HOME_TIMES),
        "pay_cpm": round(random.uniform(0.70, 0.90), 2),
        "equipment": random.choice(EQUIPMENT_TYPES),
        "has_touch_freight": False, "forced_dispatch": False,
        "team_driving": False, "northeast_routes": False,
        "hazmat_premium": True, "benefits": "excellent",
        "location": random.choice(LOCATIONS),
        "start_urgency": "flexible",
        "requires_clean_record": True,
        "requires_medical": True,
    })

    if random.random() < 0.5:
        jobs.append(_make_trap_job(driver, 5))
    else:
        jobs.append({
            "job_id": 5, "company": random.choice(COMPANIES),
            "required_cdl": bad_cdl, "required_endorsements": [],
            "min_experience": random.randint(1, 8),
            "route_type": random.choice(ROUTE_TYPES),
            "home_time": driver["home_time_pref"],
            "pay_cpm": round(driver["min_cpm"] + random.uniform(0.05, 0.15), 2),
            "equipment": driver["equipment_pref"],
            "has_touch_freight": False, "forced_dispatch": False,
            "team_driving": False, "northeast_routes": False,
            "hazmat_premium": False, "benefits": "good",
            "location": random.choice(LOCATIONS),
            "start_urgency": random.choice(["immediate", "flexible"]),
            "requires_clean_record": random.random() < 0.3,
            "requires_medical": True,
        })

    # Shuffle and reassign sequential IDs so position doesn't leak job quality.
    random.shuffle(jobs)
    for i, j in enumerate(jobs):
        j["job_id"] = i
    return jobs


def _make_good_job(driver, job_id):
    """Build a job that matches the driver's credentials and preferences."""
    return {
        "job_id": job_id, "company": random.choice(COMPANIES),
        "required_cdl": driver["cdl_class"],
        "required_endorsements": [e for e in driver["endorsements"] if random.random() < 0.3],
        "min_experience": max(1, driver["experience_years"] - random.randint(1, 3)),
        "route_type": driver["route_pref"],
        "home_time": driver["home_time_pref"],
        "pay_cpm": round(driver["min_cpm"] + random.uniform(0.03, 0.12), 2),
        "equipment": driver["equipment_pref"],
        "has_touch_freight": False, "forced_dispatch": False,
        "team_driving": False, "northeast_routes": False,
        "hazmat_premium": "H" in driver.get("endorsements", []),
        "benefits": random.choice(["good", "excellent"]),
        "location": random.choice(LOCATIONS),
        "start_urgency": random.choice(["immediate", "flexible"]),
        "requires_clean_record": random.random() < 0.3,
        "requires_medical": True,
    }


def _make_trap_job(driver, job_id):
    """Start from a good job, then inject one of the driver's deal breakers."""
    trap = _make_good_job(driver, job_id)
    breaker = random.choice(driver["deal_breakers"])
    if breaker == "touch_freight":
        trap["has_touch_freight"] = True
    elif breaker == "forced_dispatch":
        trap["forced_dispatch"] = True
    elif breaker == "team_driving":
        trap["team_driving"] = True
    elif breaker == "northeast":
        trap["northeast_routes"] = True
    elif breaker == "hazmat_no_premium":
        trap["required_endorsements"] = ["H"]
        trap["hazmat_premium"] = False
    elif breaker == "no_benefits":
        trap["benefits"] = "none"
    return trap


def _make_partial_job(driver, job_id):
    """Start from a good job, then degrade either pay or home time."""
    job = _make_good_job(driver, job_id)
    if random.random() < 0.5:
        job["pay_cpm"] = round(driver["min_cpm"] - random.uniform(0.01, 0.06), 2)
    else:
        others = [h for h in HOME_TIMES if h != driver["home_time_pref"]]
        job["home_time"] = random.choice(others)
    return job


def format_jobs(jobs):
    """Format the job list into a readable multi-line string."""
    lines = []
    for j in jobs:
        endorse = ", ".join(j["required_endorsements"]) if j["required_endorsements"] else "none"
        flags = []
        if j["has_touch_freight"]:
            flags.append("touch freight")
        if j["forced_dispatch"]:
            flags.append("forced dispatch")
        if j["team_driving"]:
            flags.append("team driving")
        if j["northeast_routes"]:
            flags.append("northeast routes")
        flag_str = f" [{', '.join(flags)}]" if flags else ""
        urgency = j.get("start_urgency", "flexible")
        clean = "clean record required" if j.get("requires_clean_record") else ""
        medical = "DOT medical required" if j.get("requires_medical") else ""
        reqs = ", ".join(filter(None, [clean, medical]))
        req_str = f" ({reqs})" if reqs else ""
        lines.append(
            f"Job {j['job_id']}: {j['company']} — CDL-{j['required_cdl']}, "
            f"{j['min_experience']}+ yrs, {j['route_type']}, "
            f"${j['pay_cpm']}/mi, {j['home_time']} home, "
            f"{j['equipment']}, endorsements: {endorse}, "
            f"benefits: {j['benefits']}, start: {urgency}{req_str}{flag_str}"
        )
    return "\n".join(lines)


def trust_label(trust):
    """Map a trust value in [0, 1] to a coarse label."""
    if trust >= 0.7:
        return "high"
    elif trust >= 0.4:
        return "medium"
    return "low"


# --- Job fit scoring ---


def score_job_fit(driver, job):
    """Score how well a job fits a driver. Returns (score 0-100, issues list, fatal bool)."""
    score = 100
    issues = []

    # Hard qualification checks: any failure disqualifies outright.
    if driver["cdl_class"] != job["required_cdl"]:
        return 0, ["CDL class mismatch"], True
    if driver["experience_years"] < job["min_experience"]:
        return 0, [f"Needs {job['min_experience']} yrs, driver has {driver['experience_years']}"], True
    for e in job["required_endorsements"]:
        if e not in driver["endorsements"]:
            return 0, [f"Missing {e} endorsement"], True

    # Deal breakers are also fatal.
    if job["has_touch_freight"] and "touch_freight" in driver["deal_breakers"]:
        return 0, ["Touch freight is a deal breaker"], True
    if job["forced_dispatch"] and "forced_dispatch" in driver["deal_breakers"]:
        return 0, ["Forced dispatch is a deal breaker"], True
    if job["team_driving"] and "team_driving" in driver["deal_breakers"]:
        return 0, ["Team driving is a deal breaker"], True
    if job["northeast_routes"] and "northeast" in driver["deal_breakers"]:
        return 0, ["Northeast routes are a deal breaker"], True
    if ("H" in job["required_endorsements"] and not job["hazmat_premium"]
            and "hazmat_no_premium" in driver["deal_breakers"]):
        return 0, ["Hazmat without premium pay"], True
    if job["benefits"] == "none" and "no_benefits" in driver["deal_breakers"]:
        return 0, ["No benefits is a deal breaker"], True

    # Soft mismatches reduce the score.
    if job["pay_cpm"] < driver["min_cpm"]:
        diff = driver["min_cpm"] - job["pay_cpm"]
        if diff > 0.10:
            return 0, [f"Pay ${job['pay_cpm']}/mi way below min ${driver['min_cpm']}/mi"], True
        score -= int(diff * 400)
        issues.append(f"Pay is ${diff:.2f}/mi below minimum")

    if job["home_time"] != driver["home_time_pref"]:
        score -= 25
        issues.append(f"Home time: job={job['home_time']}, wants={driver['home_time_pref']}")

    if job["route_type"] != driver["route_pref"]:
        score -= 15
        issues.append(f"Route: job={job['route_type']}, wants={driver['route_pref']}")

    if job["equipment"] != driver["equipment_pref"]:
        score -= 10
        issues.append(f"Equipment: job={job['equipment']}, prefers={driver['equipment_pref']}")

    if job.get("requires_clean_record") and driver.get("violations") == "major":
        return 0, ["Major violations disqualify for this position"], True
    if job.get("requires_medical") and driver.get("medical_card") == "expired":
        return 0, ["Expired DOT medical card"], True

    if job.get("requires_clean_record") and driver.get("violations") == "minor":
        score -= 15
        issues.append("Minor violations may be a concern for clean-record position")
    if driver.get("medical_card") == "expiring_soon":
        score -= 5
        issues.append("DOT medical card expiring soon, needs renewal")
    if job.get("start_urgency") == "immediate" and driver.get("availability") == "1_month":
        score -= 20
        issues.append("Driver can't start for a month, job needs immediate start")
    if driver.get("references") == "none":
        score -= 10
        issues.append("No references available")
    elif driver.get("references") == "mixed":
        score -= 5
        issues.append("Mixed references from previous employers")

    return max(0, score), issues, False


# --- Natural language response templates ---


def _respond_experience(driver):
    p = driver["personality"]
    cdl = driver["cdl_class"]
    yrs = driver["experience_years"]
    endorse = driver["endorsements"]
    loc = driver["location"]
    endorse_str = ", ".join(endorse) if endorse else "none"

    if p == "chatty":
        return (
            f"Oh yeah, I've been driving for {yrs} years now! Got my CDL-{cdl} "
            f"right out of school. "
            f"{'I picked up my ' + endorse_str + ' endorsements along the way.' if endorse else 'No special endorsements yet but been thinking about it.'} "
            f"Based out of {loc}, been here my whole life."
        )
    elif p == "impatient":
        return f"CDL-{cdl}, {yrs} years. Endorsements: {endorse_str}. {loc}."
    elif p == "suspicious":
        if driver["trust"] < 0.5:
            return f"I've got a CDL-{cdl}. Been driving a while, out of {loc}."
        return f"CDL-{cdl}, {yrs} years experience. Endorsements: {endorse_str}. Based in {loc}."
    else:
        return (
            f"I hold a CDL-{cdl} with {yrs} years of commercial driving experience. "
            f"Endorsements: {endorse_str}. I'm located in {loc}."
        )


def _respond_home_time(driver):
    p = driver["personality"]
    pref = driver["home_time_pref"]
    templates = {
        "chatty": {
            "daily": "Oh yeah, I gotta be home every night. My wife would kill me otherwise! We got three kids and I help with homework every evening.",
            "weekends": "I need my weekends, you know? My kids have soccer on Saturdays and church on Sundays. Weekday runs are fine though.",
            "weekly": "I like to be home at least once a week. I can do a few days out but need to get back regularly.",
            "biweekly": "I can do longer runs, two weeks out is fine. My buddy and I go fishing every other weekend so that works out.",
        },
        "impatient": {
            "daily": "Home daily. Non-negotiable.",
            "weekends": "Home on weekends.",
            "weekly": "Home weekly.",
            "biweekly": "Two weeks out is fine.",
        },
        "suspicious": {
            "daily": "I need to be home... regularly." if driver["trust"] < 0.5 else "I need to be home every night, that's firm.",
            "weekends": "I need my time off." if driver["trust"] < 0.5 else "I need to be home on weekends for my family.",
            "weekly": "Can't be gone too long." if driver["trust"] < 0.5 else "I need to get home at least once a week.",
            "biweekly": "I'm flexible on time out." if driver["trust"] < 0.5 else "Two weeks out, two days home works for me.",
        },
        "professional": {
            "daily": "I'm looking for local routes that get me home every evening.",
            "weekends": "I'd like to be home on weekends. Weekday runs are fine.",
            "weekly": "I prefer weekly home time. A few days out, then home for a reset.",
            "biweekly": "I'm comfortable with biweekly home time. I've done OTR for years.",
        },
    }
    return templates[p][pref]


def _respond_pay(driver):
    p = driver["personality"]
    cpm = driver["min_cpm"]
    if p == "chatty":
        return f"I'm making ${cpm}/mile right now and honestly I won't move for less. If you can beat that by a few cents and throw in a decent sign-on bonus, I'm listening."
    elif p == "impatient":
        return f"${cpm}/mile minimum. Don't lowball me."
    elif p == "suspicious":
        if driver["trust"] < 0.5:
            return "I need to be paid fair, you know what I'm saying? What are you offering?"
        return f"Look, I need at least ${cpm}/mile. I know what I'm worth."
    else:
        return f"My minimum is ${cpm} per mile. I'm open to discussing total compensation including benefits."


def _respond_equipment(driver):
    p = driver["personality"]
    pref = driver["equipment_pref"]
    pretty = pref.replace("_", " ")
    if p == "chatty":
        extra = " Got my tanker endorsement too so I can do that." if "N" in driver["endorsements"] else ""
        return f"I've been running {pretty} mostly. Love it, got the hang of it.{extra} Wouldn't mind sticking with what I know."
    elif p == "impatient":
        return f"{pretty.title()}. That's what I run."
    elif p == "suspicious":
        if driver["trust"] < 0.5:
            return "I've got experience with different trailers."
        return f"I prefer {pretty}. That's where most of my experience is."
    else:
        return f"My primary experience is with {pretty} equipment. I'd prefer to stay in that lane."


def _respond_route(driver):
    p = driver["personality"]
    pref = driver["route_pref"]
    routes = {
        "chatty": {
            "OTR": "I like the open road, OTR is my thing. See the country, you know?",
            "regional": "Regional is my sweet spot. Good miles but still get home.",
            "local": "Local runs for me. I know every road in this city!",
            "dedicated": "Dedicated routes are great. Same customer, same lanes, no surprises.",
        },
        "impatient": {"OTR": "OTR.", "regional": "Regional.", "local": "Local.", "dedicated": "Dedicated."},
        "suspicious": {
            "OTR": ("Depends on the route." if driver["trust"] < 0.5 else "I'm looking for OTR work."),
            "regional": ("Depends on the area." if driver["trust"] < 0.5 else "I'm looking for regional work."),
            "local": ("I want to stay close to home." if driver["trust"] < 0.5 else "Local is what I want."),
            "dedicated": ("Depends on the lanes." if driver["trust"] < 0.5 else "I prefer dedicated routes."),
        },
        "professional": {
            "OTR": "I'm interested in OTR positions.",
            "regional": "I'm looking for regional opportunities.",
            "local": "I'd prefer local routes.",
            "dedicated": "Dedicated lanes would be ideal.",
        },
    }
    return routes[p][pref]


def _respond_deal_breakers(driver):
    p = driver["personality"]
    breakers = driver["deal_breakers"]
    labels = {
        "touch_freight": "touch freight",
        "forced_dispatch": "forced dispatch",
        "team_driving": "team driving",
        "northeast": "northeast/NYC routes",
        "hazmat_no_premium": "hazmat without extra pay",
        "no_benefits": "no health benefits",
    }
    if p == "chatty":
        items = [labels[b] for b in breakers]
        return f"Oh man, don't even get me started. I will NOT do {', '.join(items)}. Had bad experiences with all of that."
    elif p == "impatient":
        return f"No {labels[breakers[0]]}. That's my line."
    elif p == "suspicious":
        if driver["trust"] < 0.5:
            return "I've got my limits. What kind of freight are we talking about?"
        items = [labels[b] for b in breakers]
        return f"I won't do {', '.join(items)}. Those are hard stops for me."
    else:
        items = [labels[b] for b in breakers]
        return f"My non-negotiables: no {', no '.join(items)}."


def _respond_availability(driver):
    p = driver["personality"]
    avail = driver["availability"]
    labels = {"immediately": "right away", "2_weeks": "in about two weeks", "1_month": "in about a month", "negotiable": "depends on the offer"}
    if p == "chatty":
        if avail == "immediately":
            return "I'm ready to go! Just left my last company, sitting at home going crazy. Can start tomorrow if you need me."
        elif avail == "2_weeks":
            return "I need to give my current place two weeks notice. They've been good to me, wanna leave right."
        elif avail == "1_month":
            return "It'll be about a month. I'm finishing up a contract and need to wrap some things up at home too."
        else:
            return "Depends on what you've got. For the right job I could move quick, otherwise I'm okay where I am."
    elif p == "impatient":
        return f"Can start {labels[avail]}."
    elif p == "suspicious":
        if driver["trust"] < 0.5:
            return "Why do you need to know that already? I'll be available when I'm available."
        return f"I can start {labels[avail]}."
    else:
        return f"I'm available to start {labels[avail]}. I can be flexible depending on the opportunity."


def _respond_violations(driver):
    p = driver["personality"]
    violations = driver["violations"]
    if p == "chatty":
        if violations == "clean":
            return "Clean record, twenty years no accidents! Well, one close call in '09 but that wasn't my fault. Nothing on the record though."
        elif violations == "minor":
            return "I had a minor thing a while back, nothing serious. A speeding ticket in a construction zone. Learned my lesson."
        else:
            return "Look, I had an incident a few years ago. It was a bad situation but I've cleaned up since then. I'm a different driver now."
    elif p == "impatient":
        if violations == "clean":
            return "Clean record."
        elif violations == "minor":
            return "Minor stuff, nothing serious."
        else:
            return "I've had some issues. It's in the past."
    elif p == "suspicious":
        if driver["trust"] < 0.5:
            return "Why are you asking about that? My record is my business."
        if violations == "clean":
            return "My record is clean. You can check."
        elif violations == "minor":
            return "There's a minor thing on there but nothing that should matter."
        else:
            return "I've had some trouble before. But I've been clean for two years now."
    else:
        if violations == "clean":
            return "I have a clean driving record with no violations or incidents."
        elif violations == "minor":
            return "I have a minor violation on record. I'm happy to discuss the details."
        else:
            return "I do have a violation on my record. I've taken corrective steps since then."


def _respond_medical_card(driver):
    p = driver["personality"]
    status = driver["medical_card"]
    if p == "chatty":
        if status == "valid":
            return "Yep, DOT medical is all good! Just renewed it last month actually. Passed with flying colors."
        elif status == "expiring_soon":
            return "Oh yeah, I need to renew that soon actually. Thanks for reminding me. It's coming up in a few weeks."
        else:
            return "Ugh, yeah, it expired. I've been meaning to get that renewed. Can I still apply while I'm working on it?"
    elif p == "impatient":
        if status == "valid":
            return "DOT medical is current."
        elif status == "expiring_soon":
            return "Expires soon. I'll renew it."
        else:
            return "It's expired. I'll get it done."
    elif p == "suspicious":
        if driver["trust"] < 0.5:
            return "My medical stuff is between me and my doctor."
        if status == "valid":
            return "My DOT medical is current and valid."
        elif status == "expiring_soon":
            return "It's expiring soon but I've got an appointment scheduled."
        else:
            return "It lapsed. I can get it renewed if there's a real opportunity here."
    else:
        if status == "valid":
            return "My DOT medical certificate is current and valid."
        elif status == "expiring_soon":
            return "My medical card is expiring soon. I plan to renew it promptly."
        else:
            return "My DOT medical has expired. I'm prepared to renew it for the right position."


def _respond_references(driver):
    p = driver["personality"]
    refs = driver["references"]
    if p == "chatty":
        if refs == "strong":
            return "Oh yeah, my last dispatcher loved me! You can call anyone I've worked for. They'll all say good things."
        elif refs == "mixed":
            return "Most of my old bosses would say good things... I had a rough patch at one place but we parted okay."
        else:
            return "I've mostly done owner-operator stuff, so I don't really have traditional references. But I can show you my load history!"
    elif p == "impatient":
        if refs == "strong":
            return "References are solid. Call whoever you want."
        elif refs == "mixed":
            return "Some are better than others."
        else:
            return "Don't have references. I work for myself."
    elif p == "suspicious":
        if driver["trust"] < 0.5:
            return "I'm not giving you names until I know this is serious."
        if refs == "strong":
            return "I've got good references. I'll provide them when we're further along."
        elif refs == "mixed":
            return "I have some references. It depends on who you talk to."
        else:
            return "I don't have traditional references."
    else:
        if refs == "strong":
            return "I have strong references from my previous employers. Happy to provide contact information."
        elif refs == "mixed":
            return "I can provide references. My track record has been generally positive."
        else:
            return "I don't have employer references available, though I can provide other professional contacts."


def _respond_pitch(driver, job):
    score, issues, fatal = score_job_fit(driver, job)
    if fatal:
        reason = issues[0] if issues else "not a fit"
        p = driver["personality"]
        if p == "chatty":
            return f"Nah, that's not gonna work for me. {reason}. Got anything else?"
        elif p == "impatient":
            return f"No. {reason}."
        elif p == "suspicious":
            return f"Why would you pitch me that? {reason}."
        else:
            return f"I'll have to pass. {reason}."
    elif score >= 80:
        p = driver["personality"]
        if p == "chatty":
            return "Now THAT sounds interesting! The pay is right, the home time works... I could see myself there."
        elif p == "impatient":
            return "That could work. What's next?"
        elif p == "suspicious":
            return "Hmm, that actually doesn't sound bad. What's the catch?"
        else:
            return "That aligns well with what I'm looking for. I'd like to move forward."
    else:
        concern = issues[0] if issues else "something's off"
        p = driver["personality"]
        if p == "chatty":
            return f"It's close but I'm not sure... {concern}. Maybe if they could adjust something?"
        elif p == "impatient":
            return f"Ehh. {concern}."
        elif p == "suspicious":
            return f"I don't know... {concern}. What else you got?"
        else:
            return f"It's interesting but I have a concern: {concern}."


# --- Contact response templates ---


def _respond_contact_good(driver, topic):
    p = driver["personality"]
    # A "greeting" topic means the driver was texted; anything else means a call.
    method = "text" if topic == "greeting" else "call"
    if p == "chatty":
        if method == "text":
            return "Hey! Yeah I got your text. I've been looking for something new actually. What do you have for me?"
        return "Hello? Oh hey, yeah I was hoping someone would reach out. I'm definitely interested in hearing about opportunities."
    elif p == "impatient":
        if method == "text":
            return "Got your text. What do you have?"
        return "Yeah, I'm listening. What's the job?"
    elif p == "suspicious":
        if method == "text":
            return "Hey. How'd you get my number? ...Okay, I'm listening I guess."
        return "Who is this? ...A recruiter? Alright, what are you offering?"
    else:
        if method == "text":
            return "Thanks for reaching out. I'm open to new opportunities. What positions do you have available?"
        return "Hello, thanks for the call. I'm currently exploring new opportunities. What do you have?"


def _respond_contact_wrong(driver, topic):
    p = driver["personality"]
    if topic == "greeting":  # texted a caller
        if p == "chatty":
            return "Oh hey, got your text. I usually prefer a phone call but no worries, what's up?"
        elif p == "impatient":
            return "Text is fine I guess. What do you want?"
        elif p == "suspicious":
            return "...Who is this? I don't usually respond to random texts."
        else:
            return "I received your message. I generally prefer a phone call, but I'm happy to chat."
    else:  # called a texter
        if p == "chatty":
            return "Oh, uh, hey. I wasn't expecting a call. I'm kinda busy, could you text me instead? ...Fine, what is it?"
        elif p == "impatient":
            return "I don't pick up unknown numbers usually. Should've texted. What do you want?"
        elif p == "suspicious":
            return "Who is this? I don't answer calls from numbers I don't know."
        else:
            return "Hello. I prefer to communicate via text if possible. But go ahead, what do you have?"


def _respond_contact_repeat(driver):
    p = driver["personality"]
    if p == "chatty":
        return "You already reached out to me! What else do you need?"
    elif p == "impatient":
        return "You already contacted me. What now?"
    elif p == "suspicious":
        return "Why are you contacting me again? We already talked."
    else:
        return "We've already been in touch. What's the next step?"


def _respond_repeat_question(driver, topic):
    p = driver["personality"]
    if p == "chatty":
        return f"Didn't I already tell you about my {topic}? I feel like we covered that!"
    elif p == "impatient":
        return "I already answered that. Pay attention."
    elif p == "suspicious":
        return f"You already asked me about {topic}. Why are you asking again?"
    else:
        return f"I believe I already shared my {topic} preferences with you."


# --- Offer/submit response templates ---


def _respond_offer_accept(driver, job):
    p = driver["personality"]
    company = job["company"]
    if p == "chatty":
        return f"Awesome! {company} sounds great, I'm excited to get started. Thanks for finding this for me!"
    elif p == "impatient":
        return f"Good. {company}. When do I start?"
    elif p == "suspicious":
        return f"Alright, {company} it is. I hope this works out. Thanks."
    else:
        return f"Thank you for the placement at {company}. I'm looking forward to getting started."


def _respond_offer_concerns(driver, job, concern):
    p = driver["personality"]
    company = job["company"]
    if p == "chatty":
        return f"I mean, {company} is okay I guess. {concern} bugs me a little but maybe we can work something out?"
    elif p == "impatient":
        return f"Ehh. {concern}. Can you fix that?"
    elif p == "suspicious":
        return f"I'm not fully sold on {company}. {concern}. What are you going to do about it?"
    else:
        return f"I have a concern about the {company} position: {concern}. Can we discuss?"


def _respond_offer_reject(driver, reason):
    p = driver["personality"]
    if p == "chatty":
        return f"Yeah no, I can't do that. {reason}. I thought we talked about this?"
    elif p == "impatient":
        return f"No. {reason}. I'm done here."
    elif p == "suspicious":
        return f"Are you serious? {reason}. I knew this was a waste of my time."
    else:
        return f"I'm going to have to withdraw. {reason}. This isn't what we discussed."


def _respond_ghosted(driver):
    p = driver["personality"]
    name = driver["name"].split()[0]
    if p == "chatty":
        return f"{name} stopped responding to your messages. Last seen: 'idk man this isn't working out...'"
    elif p == "impatient":
        return f"{name} blocked your number."
    elif p == "suspicious":
        return f"{name} stopped responding. They were never fully comfortable with the process."
    else:
        return f"{name} sent a polite message saying they've decided to go with another recruiter."


# --- Negotiation helpers ---


def _get_negotiation_concerns(driver, job):
    _, issues, _ = score_job_fit(driver, job)
    return issues


def _respond_negotiation(driver, action, job, concerns):
    p = driver["personality"]

    if action == "negotiate_pay":
        if any("pay" in c.lower() for c in concerns):
            if p == "chatty":
                return "Well, if you can get them to bump it up a few cents, I'd feel a lot better about this."
            elif p == "impatient":
                return "More money would help. Get it done."
            elif p == "suspicious":
                return "I'll believe a pay bump when I see it in writing."
            else:
                return "I'd appreciate if you could negotiate a higher rate."
        else:
            return "Pay isn't really my concern here."

    elif action == "negotiate_home_time":
        if any("home time" in c.lower() for c in concerns):
            if p == "chatty":
                return "Yeah, if they could work with my schedule that would change everything. Talk to them?"
            elif p == "impatient":
                return "Fix the home time and we'll talk."
            elif p == "suspicious":
                return "They always say they'll adjust the schedule. Will they actually?"
            else:
                return "If the home time can be adjusted, I'd be much more interested."
        else:
            return "Home time isn't really my issue here."

    elif action == "signing_bonus":
        if p == "chatty":
            return "A signing bonus? Hey, that's nice! Doesn't fix everything but it helps."
        elif p == "impatient":
            return "Bonus is fine. What about the real issues?"
        elif p == "suspicious":
            return "Bonuses are nice but they don't solve long-term problems."
        else:
            return "I appreciate the signing bonus offer. It's a positive gesture."

    elif action == "address_concern":
        if concerns:
            if p == "chatty":
                return f"Yeah, my big thing is: {concerns[0]}. If you can work that out, I'm in."
            elif p == "impatient":
                return f"{concerns[0]}. Fix it."
            elif p == "suspicious":
                if driver["trust"] < 0.4:
                    return "I've told you my concerns. Are you actually going to do something about them?"
                return f"Fine, here's what bothers me: {concerns[0]}."
            else:
                return f"My primary concern is: {concerns[0]}. I'd need that resolved."
        else:
            return "I don't really have any major concerns. I think we're good."

    return "I'm not sure what you mean."


# --- CRM formatting ---


def _api(code, **kwargs):
    """Format a JSON API response with status code."""
    return json.dumps({"code": code, **kwargs})


def format_crm(crm):
    """Format CRM record into a readable string."""
    lines = [f"Name: {crm['name']}", f"Stage: {crm['stage']}"]
    if crm["fields"]:
        lines.append("Fields:")
        for k, v in sorted(crm["fields"].items()):
            lines.append(f"  {k}: {v}")
    else:
        lines.append("Fields: (none recorded)")
    if crm["notes"]:
        lines.append("Notes:")
        for n in crm["notes"]:
            lines.append(f"  - {n}")
    return "\n".join(lines)


# --- Environment ---


class RecruitopenenvEnvironment(Environment):
    """Driver recruiting environment with tool-based long-horizon interaction."""

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._driver = {}
        self._jobs = []
        # CRM state
        self._crm = {"name": "", "stage": "lead", "fields": {}, "notes": []}
        self._has_read_crm = False
        self._crm_read_count = 0
        # Messaging state
        self._pending_reply = None  # (response_text, topic)
        self._contacted = False
        self._asked = set()
        self._discovered_info = []
        # Approval state
        self._approval_status = "none"
        self._approval_job_id = -1
        # Negotiation state
        self._matched_job_id = -1
        self._negotiation_round = 0
        self._negotiation_score_bonus = 0
        self._negotiation_concerns = []
        # Interaction tracking
        self._last_contact_step = 0

    def _make_obs(self, reward=0.0, done=False, feedback=""):
        return RecruitopenenvObservation(
            driver_name=self._driver.get("name", ""),
            crm_summary=format_crm(self._crm) if self._has_read_crm else "",
            jobs_summary=format_jobs(self._jobs) if self._jobs else "",
            discovered_info="\n".join(self._discovered_info),
            stage=self._crm["stage"],
            feedback=feedback,
            pending_reply=self._pending_reply is not None,
            done=done,
            reward=reward,
        )

    def reset(self, seed: int = None) -> RecruitopenenvObservation:
        if seed is not None:
            random.seed(seed)
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._driver = generate_driver()
        self._jobs = generate_jobs(self._driver)
        self._crm = {"name": self._driver["name"], "stage": "lead", "fields": {}, "notes": []}
        self._has_read_crm = False
        self._crm_read_count = 0
        self._pending_reply = None
        self._contacted = False
        self._asked = set()
        self._discovered_info = []
963
+ self._approval_status = "none"
964
+ self._approval_job_id = -1
965
+ self._matched_job_id = -1
966
+ self._negotiation_round = 0
967
+ self._negotiation_score_bonus = 0
968
+ self._negotiation_concerns = []
969
+ self._last_contact_step = 0
970
+
971
+ return self._make_obs(
972
+ feedback=_api(200, driver=self._driver["name"], jobs=len(self._jobs))
973
+ )
974
+
975
+     def step(self, action: RecruitopenenvAction) -> RecruitopenenvObservation:
+         if not self._driver:
+             return self._make_obs(reward=0.0, done=True, feedback=_api(400, error="no_episode"))
+ 
+         tool = action.tool
+         act = action.action
+ 
+         # Validate tool+action
+         if tool not in VALID_TOOL_ACTIONS:
+             return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_tool", tool=tool))
+         if act not in VALID_TOOL_ACTIONS[tool]:
+             return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", tool=tool, action=act))
+ 
+         # Check terminal
+         if self._crm["stage"] in ("hired", "lost", "ghosted"):
+             return self._make_obs(reward=0.0, done=True, feedback=_api(400, error="episode_ended"))
+ 
+         self._state.step_count += 1
+ 
+         if self._state.step_count >= MAX_STEPS:
+             self._crm["stage"] = "ghosted"
+             return self._make_obs(reward=-3.0, done=True, feedback=_api(200, result="ghosted", reason="timeout"))
+ 
+         # Passive trust decay — driver loses patience while recruiter isn't talking to them
+         idle_gap = self._state.step_count - self._last_contact_step
+         if idle_gap > 2:
+             # Accelerating decay: longer silence = faster trust loss
+             idle_decay = 0.01 * (idle_gap - 2)
+             self._driver["trust"] = max(0.0, self._driver["trust"] - idle_decay)
+             if self._driver["trust"] <= 0.1:
+                 self._crm["stage"] = "ghosted"
+                 return self._make_obs(reward=-4.0, done=True, feedback=_api(200, result="ghosted", message=_respond_ghosted(self._driver)))
+ 
+         # Route to handler
+         if tool == "crm":
+             return self._handle_crm(act, action)
+         elif tool == "messaging":
+             return self._handle_messaging(act, action)
+         elif tool == "approval":
+             return self._handle_approval(act, action)
+         elif tool == "workflow":
+             return self._handle_workflow(act, action)
+ 
+         return self._make_obs(reward=-1.0, feedback=_api(500, error="internal_error"))
+ 
+     # --- CRM tool ---
+ 
+     def _handle_crm(self, act, action):
+         if act == "read_candidate":
+             self._has_read_crm = True
+             self._crm_read_count += 1
+             reward = 0.0 if self._crm_read_count <= 1 else -0.1
+             return self._make_obs(reward=reward, feedback=_api(200, data=self._crm))
+ 
+         elif act == "update_stage":
+             new_stage = action.stage
+             current = self._crm["stage"]
+ 
+             if new_stage not in ALL_STAGES:
+                 return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_stage", stage=new_stage))
+ 
+             # Compute penalty for non-ideal transitions
+             penalty = 0.0
+             if new_stage not in ("lost", "ghosted"):
+                 cur_idx = STAGE_ORDER.index(current) if current in STAGE_ORDER else -1
+                 new_idx = STAGE_ORDER.index(new_stage) if new_stage in STAGE_ORDER else -1
+                 if new_idx >= 0 and cur_idx >= 0:
+                     diff = new_idx - cur_idx
+                     if diff == 0:
+                         # Same stage — wasted action
+                         penalty = -0.1
+                     elif diff == 1:
+                         # Correct next stage — no penalty
+                         penalty = 0.0
+                     elif diff > 1:
+                         # Skipping stages forward — penalize per skip
+                         penalty = -0.5 * (diff - 1)
+                     else:
+                         # Going backwards — heavier penalty
+                         penalty = -1.0 * abs(diff)
+ 
+             self._crm["stage"] = new_stage
+             if new_stage == "hired":
+                 return self._finalize_hire(penalty)
+             if new_stage == "lost":
+                 return self._finalize_lost(penalty)
+             return self._make_obs(reward=0.0 + penalty, feedback=_api(200, stage=new_stage))
+ 
+         elif act == "update_field":
+             field = action.field
+             if field not in VALID_CRM_FIELDS:
+                 return self._make_obs(reward=-0.5, feedback=_api(400, error="unknown_field", field=field))
+             self._crm["fields"][field] = action.value
+             return self._make_obs(reward=0.0, feedback=_api(200, field=field, value=action.value))
+ 
+         elif act == "add_note":
+             if not action.value:
+                 return self._make_obs(reward=-0.5, feedback=_api(400, error="empty_note"))
+             self._crm["notes"].append(action.value)
+             return self._make_obs(reward=0.0, feedback=_api(200, notes=len(self._crm["notes"])))
+ 
+         return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", action=act))
+ 
+     # --- Messaging tool ---
+ 
+     def _handle_messaging(self, act, action):
+         if act == "send_message":
+             topic = action.topic
+ 
+             # Invalid topic — message still reaches driver, they're confused
+             if topic not in VALID_TOPICS:
+                 self._last_contact_step = self._state.step_count
+                 self._driver["trust"] = max(0.0, self._driver["trust"] - self._driver["decay"] * 2)
+                 if self._driver["trust"] <= 0.1:
+                     self._crm["stage"] = "ghosted"
+                     return self._make_obs(reward=-4.0, done=True, feedback=_api(200, result="ghosted", message=_respond_ghosted(self._driver)))
+                 self._pending_reply = ("I'm not sure what you're asking about.", topic)
+                 return self._make_obs(reward=-1.0, feedback=_api(200, topic=topic, warning="driver_confused"))
+ 
+             # Penalty for skipping CRM read, but still send
+             penalty = 0.0
+             if not self._has_read_crm:
+                 penalty -= 1.0
+             # Penalty for ignoring pending reply (overwrite it), but still send
+             if self._pending_reply is not None:
+                 penalty -= 1.0
+                 self._pending_reply = None
+ 
+             self._last_contact_step = self._state.step_count
+ 
+             # Trust decay on each message
+             self._driver["trust"] = max(0.0, self._driver["trust"] - self._driver["decay"])
+ 
+             # Trust dropout check
+             if self._driver["trust"] <= 0.1:
+                 self._crm["stage"] = "ghosted"
+                 return self._make_obs(reward=-4.0, done=True, feedback=_api(200, result="ghosted", message=_respond_ghosted(self._driver)))
+ 
+             # Generate response based on topic
+             response, reward = self._generate_message_response(topic, action.job_id)
+             if response is None:
+                 return self._make_obs(reward=reward + penalty, feedback=_api(404, error="no_valid_target", topic=topic))
+             if response == "NEGOTIATION_EXHAUSTED":
+                 self._crm["stage"] = "lost"
+                 return self._make_obs(reward=reward + penalty, done=True, feedback=_api(200, result="lost", reason="negotiation_exhausted"))
+             self._pending_reply = (response, topic)
+             return self._make_obs(reward=reward + penalty, feedback=_api(200, topic=topic))
+ 
+         elif act == "read_reply":
+             if self._pending_reply is None:
+                 return self._make_obs(reward=-0.5, feedback=_api(200, reply=None))
+             self._last_contact_step = self._state.step_count
+ 
+             response, topic = self._pending_reply
+             self._pending_reply = None
+ 
+             # Auto-add to discovered info for screening topics
+             if topic in SCREENING_TOPICS:
+                 self._discovered_info.append(f"[{topic.upper().replace('_', ' ')}] {response}")
+                 self._asked.add(f"ask_{topic}")
+             elif topic == "pitch":
+                 self._discovered_info.append(f"[PITCH] {response}")
+             elif topic in ("negotiate_pay", "negotiate_home_time", "signing_bonus", "address_concern"):
+                 self._discovered_info.append(f"[NEGOTIATE: {topic.replace('_', ' ')}] {response}")
+             elif topic == "offer":
+                 self._discovered_info.append(f"[OFFER] {response}")
+ 
+             return self._make_obs(reward=0.0, feedback=_api(200, topic=topic, reply=response))
+ 
+         return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", action=act))
+ 
+     def _generate_message_response(self, topic, job_id):
+         """Generate driver's response to a message. Returns (response, reward)."""
+         reward = -0.1  # base step cost
+ 
+         # --- Contact topics ---
+         if topic in ("greeting", "call"):
+             if self._contacted:
+                 return _respond_contact_repeat(self._driver), -1.0
+             self._contacted = True
+             pref = self._driver["preferred_contact"]
+             matches = (topic == "greeting" and pref == "text") or (topic == "call" and pref == "call")
+             if matches:
+                 self._driver["trust"] = min(1.0, self._driver["trust"] + 0.15)
+                 return _respond_contact_good(self._driver, topic), 1.0
+             else:
+                 self._driver["trust"] = max(0.0, self._driver["trust"] - 0.10)
+                 return _respond_contact_wrong(self._driver, topic), -0.3
+ 
+         # --- Screening topics ---
+         if topic in SCREENING_TOPICS:
+             if not self._contacted:
+                 # Still works but driver is cold — penalty
+                 self._driver["trust"] = max(0.0, self._driver["trust"] - 0.15)
+             ask_key = f"ask_{topic}"
+             if ask_key in self._asked:
+                 return _respond_repeat_question(self._driver, topic.replace("_", " ")), -0.5
+ 
+             respond_map = {
+                 "experience": _respond_experience,
+                 "home_time": _respond_home_time,
+                 "pay": _respond_pay,
+                 "equipment": _respond_equipment,
+                 "route": _respond_route,
+                 "deal_breakers": _respond_deal_breakers,
+                 "availability": _respond_availability,
+                 "violations": _respond_violations,
+                 "medical_card": _respond_medical_card,
+                 "references": _respond_references,
+             }
+             response = respond_map[topic](self._driver)
+             penalty = -1.0 if not self._contacted else -0.1
+             return response, penalty
+ 
+         # --- Pitch ---
+         if topic == "pitch":
+             if not self._contacted:
+                 self._driver["trust"] = max(0.0, self._driver["trust"] - 0.15)
+             matching = [j for j in self._jobs if j["job_id"] == job_id]
+             if not matching:
+                 # No match — pick nothing, return None (will be caught by handler)
+                 return None, -1.0
+             penalty = -1.0 if not self._contacted else -0.1
+             return _respond_pitch(self._driver, matching[0]), penalty
+ 
+         # --- Offer ---
+         if topic == "offer":
+             penalty = 0.0
+             if self._approval_status != "approved":
+                 # Allowed but heavy penalty — driver gets confused
+                 self._driver["trust"] = max(0.0, self._driver["trust"] - 0.2)
+                 penalty = -2.0
+             job_id_to_use = self._approval_job_id if job_id < 0 else job_id
+             matching = [j for j in self._jobs if j["job_id"] == job_id_to_use]
+             if not matching:
+                 return None, -1.0 + penalty
+             job = matching[0]
+             self._matched_job_id = job_id_to_use
+             score, issues, fatal = score_job_fit(self._driver, job)
+             if not fatal:
+                 score = min(100, score + self._negotiation_score_bonus)
+             if fatal:
+                 return _respond_offer_reject(self._driver, issues[0]), -0.5 + penalty
+             elif score >= 70:
+                 return _respond_offer_accept(self._driver, job), 0.0 + penalty
+             elif score >= 50:
+                 concern = issues[0] if issues else "minor concerns"
+                 self._negotiation_concerns = issues
+                 return _respond_offer_concerns(self._driver, job, concern), 0.0 + penalty
+             else:
+                 return _respond_offer_reject(self._driver, issues[0] if issues else "not a fit"), -0.5 + penalty
+ 
+         # --- Negotiation topics ---
+         if topic in ("negotiate_pay", "negotiate_home_time", "signing_bonus", "address_concern"):
+             if self._matched_job_id < 0 and self._approval_job_id >= 0:
+                 self._matched_job_id = self._approval_job_id
+             if self._matched_job_id < 0:
+                 return None, -1.0
+             if self._negotiation_round >= 5:
+                 return "NEGOTIATION_EXHAUSTED", -2.0
+ 
+             self._negotiation_round += 1
+             matches = [j for j in self._jobs if j["job_id"] == self._matched_job_id]
+             if not matches:
+                 return None, -1.0
+             job = matches[0]
+             if not self._negotiation_concerns:
+                 self._negotiation_concerns = _get_negotiation_concerns(self._driver, job)
+             response = _respond_negotiation(self._driver, topic, job, self._negotiation_concerns)
+ 
+             # Score bonus
+             if topic == "address_concern" and self._negotiation_concerns:
+                 self._negotiation_score_bonus += 15
+                 self._negotiation_concerns.pop(0)
+             elif topic == "negotiate_pay" and any("pay" in c.lower() for c in self._negotiation_concerns):
+                 self._negotiation_score_bonus += 10
+                 self._negotiation_concerns = [c for c in self._negotiation_concerns if "pay" not in c.lower()]
+             elif topic == "negotiate_home_time" and any("home time" in c.lower() for c in self._negotiation_concerns):
+                 self._negotiation_score_bonus += 10
+                 self._negotiation_concerns = [c for c in self._negotiation_concerns if "home time" not in c.lower()]
+             elif topic == "signing_bonus":
+                 self._negotiation_score_bonus += 5
+             else:
+                 self._negotiation_score_bonus += 2
+ 
+             # Extra trust decay during negotiation
+             self._driver["trust"] = max(0.0, self._driver["trust"] - 0.01)
+             return response, -0.1
+ 
+         return None, -1.0
+ 
+     # --- Approval tool ---
+ 
+     def _handle_approval(self, act, action):
+         if act == "request_approval":
+             if action.job_id < 0:
+                 return self._make_obs(reward=-1.0, feedback=_api(400, error="job_id_required"))
+             matching = [j for j in self._jobs if j["job_id"] == action.job_id]
+             if not matching:
+                 return self._make_obs(reward=-1.0, feedback=_api(404, error="job_not_found", job_id=action.job_id))
+             # Allow re-request but penalize — resets approval
+             penalty = -0.5 if self._approval_status in ("pending", "approved") else 0.0
+             self._approval_status = "pending"
+             self._approval_job_id = action.job_id
+             return self._make_obs(reward=0.0 + penalty, feedback=_api(202, approval_status="pending", job_id=action.job_id))
+ 
+         elif act == "check_approval":
+             if self._approval_status == "none":
+                 return self._make_obs(reward=-0.5, feedback=_api(200, approval_status="none"))
+             if self._approval_status == "pending":
+                 return self._make_obs(reward=-0.1, feedback=_api(202, approval_status="pending"))
+             return self._make_obs(
+                 reward=0.5 if self._approval_status == "approved" else -0.5,
+                 feedback=_api(200, approval_status=self._approval_status, job_id=self._approval_job_id)
+             )
+ 
+         return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", action=act))
+ 
+     # --- Workflow tool ---
+ 
+     def _handle_workflow(self, act, action):
+         if act == "wait":
+             if self._approval_status == "pending":
+                 # Process approval based on job quality
+                 job = [j for j in self._jobs if j["job_id"] == self._approval_job_id]
+                 if job:
+                     score, _, fatal = score_job_fit(self._driver, job[0])
+                     if fatal:
+                         self._approval_status = "denied"
+                     else:
+                         self._approval_status = "approved"
+                 else:
+                     self._approval_status = "denied"
+                 return self._make_obs(reward=0.0, feedback=_api(200, elapsed="1h"))
+ 
+             # Generic wait — trust decay + penalty for wasting time
+             self._driver["trust"] = max(0.0, self._driver["trust"] - 0.02)
+             return self._make_obs(reward=-0.5, feedback=_api(200, elapsed="1h"))
+ 
+         return self._make_obs(reward=-1.0, feedback=_api(400, error="unknown_action", action=act))
+ 
+     # --- Terminal handlers ---
+ 
+     def _score_crm(self):
+         """Score CRM accuracy — compare recorded fields to ground truth."""
+         ground_truth = {
+             "cdl_class": self._driver["cdl_class"],
+             "years_experience": str(self._driver["experience_years"]),
+             "location": self._driver["location"],
+             "home_time_pref": self._driver["home_time_pref"],
+             "pay_expectation": str(self._driver["min_cpm"]),
+             "equipment_pref": self._driver["equipment_pref"],
+             "route_pref": self._driver["route_pref"],
+             "availability": self._driver["availability"],
+             "violations": self._driver["violations"],
+             "medical_card": self._driver["medical_card"],
+             "references": self._driver["references"],
+         }
+         # Endorsements and deal_breakers are lists — normalize
+         ground_truth["endorsements"] = ", ".join(sorted(self._driver["endorsements"])) if self._driver["endorsements"] else "none"
+         ground_truth["deal_breakers"] = ", ".join(sorted(self._driver["deal_breakers"]))
+ 
+         score = 0.0
+         for field, truth in ground_truth.items():
+             recorded = self._crm["fields"].get(field, "")
+             if not recorded:
+                 continue
+             # Exact match (case-insensitive)
+             if recorded.strip().lower() == truth.lower():
+                 score += 0.4
+             # Partial match — truth appears in recorded or vice versa
+             elif truth.lower() in recorded.strip().lower() or recorded.strip().lower() in truth.lower():
+                 score += 0.2
+             else:
+                 # Wrong value recorded — small penalty
+                 score -= 0.1
+ 
+         # Small bonus for notes (shows diligence)
+         score += min(0.5, len(self._crm["notes"]) * 0.1)
+ 
+         # Cap: up to 5.0 bonus for perfect CRM (13 fields × 0.4 = 5.2)
+         return max(0.0, min(5.0, score))
+ 
+     def _finalize_hire(self, stage_penalty=0.0):
+         """Handle stage transition to hired — compute final reward."""
+         crm_bonus = self._score_crm()
+ 
+         if self._approval_status != "approved":
+             self._crm["stage"] = "lost"
+             return self._make_obs(
+                 reward=-5.0 + stage_penalty, done=True,
+                 feedback=_api(200, result="lost", reason="no_approval")
+             )
+ 
+         job_id = self._approval_job_id
+         matching = [j for j in self._jobs if j["job_id"] == job_id]
+         if not matching:
+             self._crm["stage"] = "lost"
+             return self._make_obs(
+                 reward=-5.0 + stage_penalty, done=True,
+                 feedback=_api(200, result="lost", reason="no_job")
+             )
+ 
+         job = matching[0]
+         score, issues, fatal = score_job_fit(self._driver, job)
+         if not fatal:
+             score = min(100, score + self._negotiation_score_bonus)
+ 
+         if fatal:
+             self._crm["stage"] = "lost"
+             return self._make_obs(
+                 reward=-5.0 + stage_penalty, done=True,
+                 feedback=_api(200, result="rejected", reason=issues[0], job_id=job_id)
+             )
+         elif score >= 70:
+             return self._make_obs(
+                 reward=10.0 + crm_bonus + stage_penalty, done=True,
+                 feedback=_api(200, result="hired", job_id=job_id, score=score, crm_bonus=round(crm_bonus, 1))
+             )
+         elif score >= 50:
+             return self._make_obs(
+                 reward=4.0 + crm_bonus + stage_penalty, done=True,
+                 feedback=_api(200, result="hired_with_reservations", job_id=job_id, score=score, concern=issues[0] if issues else "minor")
+             )
+         else:
+             self._crm["stage"] = "lost"
+             return self._make_obs(
+                 reward=-5.0 + stage_penalty, done=True,
+                 feedback=_api(200, result="rejected", reason=issues[0] if issues else "poor_fit", job_id=job_id)
+             )
+ 
+     def _finalize_lost(self, stage_penalty=0.0):
+         """Handle stage transition to lost."""
+         has_good = any(score_job_fit(self._driver, j)[0] >= 70 for j in self._jobs)
+         if has_good:
+             return self._make_obs(
+                 reward=-3.0 + stage_penalty, done=True,
+                 feedback=_api(200, result="lost", good_match_existed=True)
+             )
+         else:
+             return self._make_obs(
+                 reward=1.0 + stage_penalty, done=True,
+                 feedback=_api(200, result="lost", good_match_existed=False)
+             )
+ 
+     @property
+     def state(self) -> State:
+         return self._state
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ openenv[core]>=0.2.0
+ fastapi>=0.115.0
+ uvicorn>=0.24.0
+ 
+ 
+ 
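
One piece of the environment above worth seeing in isolation is the stage-transition penalty from `update_stage` in `_handle_crm`. A standalone sketch, assuming the pipeline ordering implied by the demo notebook (lead → contacted → interested → approval_pending → offer_sent → hired; the actual `STAGE_ORDER` constant is defined elsewhere in the module):

```python
# Assumed stage ordering, inferred from the demo episode in train_colab.ipynb.
STAGE_ORDER = ["lead", "contacted", "interested", "approval_pending", "offer_sent", "hired"]


def stage_penalty(current, new_stage):
    """Penalty for a CRM stage transition, mirroring the update_stage logic."""
    cur_idx = STAGE_ORDER.index(current) if current in STAGE_ORDER else -1
    new_idx = STAGE_ORDER.index(new_stage) if new_stage in STAGE_ORDER else -1
    if cur_idx < 0 or new_idx < 0:
        return 0.0
    diff = new_idx - cur_idx
    if diff == 0:
        return -0.1                # same stage: wasted action
    if diff == 1:
        return 0.0                 # correct next stage: free
    if diff > 1:
        return -0.5 * (diff - 1)   # skipping forward: -0.5 per skipped stage
    return -1.0 * abs(diff)        # moving backwards: -1.0 per stage


print(stage_penalty("lead", "contacted"))  # 0.0
print(stage_penalty("lead", "hired"))      # -2.0
```

In other words, advancing one stage at a time is the only free path; jumping straight from "lead" to "hired" skips four stages and costs -2.0 before the terminal reward is applied.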
train_colab.ipynb ADDED
@@ -0,0 +1,558 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# 🚛 Driver Recruit Environment — RL Training with TRL\n",
+     "\n",
+     "Train a 3B LLM to recruit truck drivers using REINFORCE with TRL.\n",
+     "\n",
+     "The model learns to choose the right screening topics to ask drivers,\n",
+     "then the auto-pilot handles CRM updates, approval, and hiring.\n",
+     "\n",
+     "**Environment**: [OpenEnv 0.2.1](https://github.com/meta-pytorch/OpenEnv) deployed on HF Spaces\n",
+     "\n",
+     "**Model**: Qwen/Qwen2.5-3B-Instruct\n",
+     "\n",
+     "**Algorithm**: REINFORCE with batch-level advantage normalization"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 1. Install Dependencies"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "!pip install -q openenv-core[core]==0.2.1 trl transformers torch accelerate"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 2. Connect to the Environment\n",
+     "\n",
+     "The recruiting environment is deployed on HF Spaces. Replace the URL below with your Space URL."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import json\n",
+     "import random\n",
+     "import re\n",
+     "\n",
+     "import torch\n",
+     "import torch.nn.functional as F\n",
+     "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+     "\n",
+     "# --- Connect to environment ---\n",
+     "# Replace with your HF Space URL\n",
+     "ENV_URL = \"https://YOUR-USERNAME-recruitopenenv.hf.space\"  # <-- CHANGE THIS\n",
+     "\n",
+     "from openenv.client import EnvClient\n",
+     "\n",
+     "# Quick test: reset and check the env is alive\n",
+     "import requests\n",
+     "resp = requests.post(f\"{ENV_URL}/reset\", json={\"seed\": 42})\n",
+     "data = resp.json()\n",
+     "print(f\"Driver: {data['observation']['driver_name']}\")\n",
+     "print(f\"Stage: {data['observation']['stage']}\")\n",
+     "print(\"Environment connected!\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 3. Environment Helper Functions"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def env_reset(seed=None):\n",
+     "    \"\"\"Reset environment via HTTP.\"\"\"\n",
+     "    payload = {\"seed\": seed} if seed is not None else {}\n",
+     "    resp = requests.post(f\"{ENV_URL}/reset\", json=payload)\n",
+     "    return resp.json()\n",
+     "\n",
+     "def env_step(tool, action, **kwargs):\n",
+     "    \"\"\"Step environment via HTTP.\"\"\"\n",
+     "    payload = {\"tool\": tool, \"action\": action, **kwargs}\n",
+     "    resp = requests.post(f\"{ENV_URL}/step\", json=payload)\n",
+     "    return resp.json()\n",
+     "\n",
+     "# --- Topic-based auto-pilot ---\n",
+     "SCREENING_TOPICS = [\n",
+     "    \"experience\", \"home_time\", \"pay\", \"equipment\", \"route\",\n",
+     "    \"deal_breakers\", \"availability\", \"violations\", \"medical_card\", \"references\",\n",
+     "]\n",
+     "\n",
+     "SYSTEM_PROMPT = \"\"\"You are a truck driver recruiter screening a candidate. Choose the next topic to discuss.\n",
+     "\n",
+     "Topics for first contact: greeting (text), call (phone)\n",
+     "Screening topics: experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references\n",
+     "Say \"done\" when you have enough info to proceed with hiring.\n",
+     "\n",
+     "Respond with ONLY the topic name, nothing else.\"\"\"\n",
+     "\n",
+     "ALL_TOPICS = [\"greeting\", \"call\"] + SCREENING_TOPICS + [\"done\"]\n",
+     "\n",
+     "def parse_topic(text):\n",
+     "    \"\"\"Extract topic name from model output.\"\"\"\n",
+     "    text = text.strip().lower().replace('\"', '').replace(\"'\", \"\")\n",
+     "    text = text.split(\"\\n\")[0].strip().split(\".\")[0].strip()\n",
+     "    for topic in ALL_TOPICS:\n",
+     "        if topic in text or topic.replace(\"_\", \" \") in text:\n",
+     "            return topic\n",
+     "    if \"deal\" in text: return \"deal_breakers\"\n",
+     "    if \"home\" in text: return \"home_time\"\n",
+     "    if \"medical\" in text: return \"medical_card\"\n",
+     "    return \"done\"\n",
+     "\n",
+     "def build_prompt(obs, asked):\n",
+     "    \"\"\"Build prompt showing state and available topics.\"\"\"\n",
+     "    parts = [f\"Driver: {obs['driver_name']}\"]\n",
+     "    if obs.get('jobs_summary'):\n",
+     "        parts.append(f\"Jobs:\\n{obs['jobs_summary']}\")\n",
+     "    if obs.get('discovered_info'):\n",
+     "        parts.append(f\"Discovered:\\n{obs['discovered_info']}\")\n",
+     "    parts.append(f\"Stage: {obs['stage']}\")\n",
+     "    if asked:\n",
+     "        parts.append(f\"Already asked: {', '.join(asked)}\")\n",
+     "    available = [t for t in ALL_TOPICS if t not in asked]\n",
+     "    parts.append(f\"Available: {', '.join(available)}\")\n",
+     "    return \"\\n\".join(parts)\n",
+     "\n",
+     "print(\"Helpers loaded!\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 4. Run a Demo Episode\n",
+     "\n",
+     "Watch the auto-pilot run a full recruiting episode."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def run_demo_episode(seed=42):\n",
+     "    \"\"\"Run one full episode with scripted topic choices.\"\"\"\n",
+     "    state = env_reset(seed=seed)\n",
+     "    obs = state[\"observation\"]\n",
+     "    total_reward = 0.0\n",
+     "    print(f\"=== Driver: {obs['driver_name']} ===\")\n",
+     "\n",
+     "    # Read CRM\n",
+     "    state = env_step(\"crm\", \"read_candidate\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    obs = state[\"observation\"]\n",
+     "    print(f\"\\nJobs available:\\n{obs['jobs_summary'][:200]}...\")\n",
+     "\n",
+     "    # Greet\n",
+     "    state = env_step(\"messaging\", \"send_message\", topic=\"greeting\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    print(f\"\\nGreeting reward: {state['reward']}\")\n",
+     "\n",
+     "    state = env_step(\"messaging\", \"read_reply\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    obs = state[\"observation\"]\n",
+     "\n",
+     "    state = env_step(\"crm\", \"update_stage\", stage=\"contacted\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "\n",
+     "    # Screen\n",
+     "    for topic in [\"experience\", \"deal_breakers\", \"pay\", \"home_time\"]:\n",
+     "        if state.get(\"done\"): break\n",
+     "        state = env_step(\"messaging\", \"send_message\", topic=topic)\n",
+     "        total_reward += state[\"reward\"]\n",
+     "        state = env_step(\"messaging\", \"read_reply\")\n",
+     "        total_reward += state[\"reward\"]\n",
+     "        obs = state[\"observation\"]\n",
+     "        print(f\" {topic}: reward={state['reward']:.1f}\")\n",
+     "\n",
+     "    print(f\"\\nDiscovered:\\n{obs.get('discovered_info', 'none')[:300]}\")\n",
+     "\n",
+     "    # Approval + hire\n",
+     "    state = env_step(\"crm\", \"update_stage\", stage=\"interested\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    state = env_step(\"approval\", \"request_approval\", job_id=0)\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    state = env_step(\"workflow\", \"wait\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    state = env_step(\"approval\", \"check_approval\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    state = env_step(\"crm\", \"update_stage\", stage=\"approval_pending\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    state = env_step(\"messaging\", \"send_message\", topic=\"offer\", job_id=0)\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    state = env_step(\"messaging\", \"read_reply\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    state = env_step(\"crm\", \"update_stage\", stage=\"offer_sent\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "    state = env_step(\"crm\", \"update_stage\", stage=\"hired\")\n",
+     "    total_reward += state[\"reward\"]\n",
+     "\n",
+     "    obs = state[\"observation\"]\n",
+     "    print(f\"\\nFinal stage: {obs['stage']}\")\n",
+     "    print(f\"Total reward: {total_reward:.1f}\")\n",
+     "    print(f\"Done: {state.get('done')}\")\n",
+     "    return total_reward\n",
+     "\n",
+     "run_demo_episode()"
+    ]
+   },
+ {
227
+ "cell_type": "markdown",
228
+ "metadata": {},
229
+ "source": [
230
+ "## 5. Load Model"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
240
+ "TEMPERATURE = 1.5\n",
241
+ "MAX_NEW_TOKENS = 32\n",
242
+ "MAX_TOPICS = 8\n",
243
+ "\n",
244
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
245
+ "if tokenizer.pad_token_id is None:\n",
246
+ " tokenizer.pad_token_id = tokenizer.eos_token_id\n",
247
+ "\n",
248
+ "model = AutoModelForCausalLM.from_pretrained(\n",
249
+ " MODEL_NAME,\n",
250
+ " torch_dtype=torch.bfloat16,\n",
251
+ " device_map=\"auto\",\n",
252
+ ")\n",
253
+ "model.gradient_checkpointing_enable()\n",
254
+ "\n",
255
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)\n",
256
+ "device = next(model.parameters()).device\n",
257
+ "print(f\"Model loaded on {device}\")"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "markdown",
262
+ "metadata": {},
263
+ "source": [
264
+ "## 6. Training Loop — REINFORCE with Auto-Pilot\n",
265
+ "\n",
266
+ "The model only picks screening topics (a short topic name per decision; generation is capped at 32 tokens).\n",
267
+ "The auto-pilot handles CRM, stages, approval, and hiring.\n",
268
+ "Rewards come from the full episode outcome."
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": null,
274
+ "metadata": {},
275
+ "outputs": [],
276
+ "source": [
277
+ "def rollout_episode(model, tokenizer, device, seed=None):\n",
278
+ " \"\"\"Run one auto-piloted episode. Model picks topics, wrapper does the rest.\"\"\"\n",
279
+ " if seed is None:\n",
280
+ " seed = random.randint(0, 2**31 - 1)\n",
281
+ "\n",
282
+ " state = env_reset(seed=seed)\n",
283
+ " obs = state[\"observation\"]\n",
284
+ " total_reward = 0.0\n",
285
+ "\n",
286
+ " # Auto: read CRM\n",
287
+ " state = env_step(\"crm\", \"read_candidate\")\n",
288
+ " total_reward += state[\"reward\"]\n",
289
+ " obs = state[\"observation\"]\n",
290
+ "\n",
291
+ " if state.get(\"done\"):\n",
292
+ " return None\n",
293
+ "\n",
294
+ " turn_data = []\n",
295
+ " asked = []\n",
296
+ " contacted = False\n",
297
+ "\n",
298
+ " for _ in range(MAX_TOPICS):\n",
299
+ " if state.get(\"done\"):\n",
300
+ " break\n",
301
+ "\n",
302
+ " # Build prompt\n",
303
+ " obs_text = build_prompt(obs, asked)\n",
304
+ " messages = [\n",
305
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
306
+ " {\"role\": \"user\", \"content\": obs_text},\n",
307
+ " ]\n",
308
+ " prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
309
+ " input_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
310
+ "\n",
311
+ " # Model generates topic\n",
312
+ " with torch.no_grad():\n",
313
+ " output = model.generate(\n",
314
+ " input_ids, max_new_tokens=MAX_NEW_TOKENS,\n",
315
+ " do_sample=True, temperature=TEMPERATURE,\n",
316
+ " pad_token_id=tokenizer.pad_token_id,\n",
317
+ " )\n",
318
+ " gen_ids = output[0, input_ids.shape[1]:].tolist()\n",
319
+ " response = tokenizer.decode(gen_ids, skip_special_tokens=True)\n",
320
+ " topic = parse_topic(response)\n",
321
+ "\n",
322
+ " turn_data.append({\n",
323
+ " \"prompt_ids\": input_ids[0].tolist(),\n",
324
+ " \"gen_ids\": gen_ids,\n",
325
+ " \"topic\": topic,\n",
326
+ " \"turn_reward\": 0.0,\n",
327
+ " })\n",
328
+ "\n",
329
+ " if topic == \"done\":\n",
330
+ " break\n",
331
+ " if topic in asked:\n",
332
+ " total_reward -= 0.5\n",
333
+ " turn_data[-1][\"turn_reward\"] = -0.5\n",
334
+ " asked.append(topic)\n",
335
+ " continue\n",
336
+ "\n",
337
+ " asked.append(topic)\n",
338
+ "\n",
339
+ " # Auto: send_message + read_reply\n",
340
+ " state = env_step(\"messaging\", \"send_message\", topic=topic)\n",
341
+ " turn_reward = state[\"reward\"]\n",
342
+ " total_reward += state[\"reward\"]\n",
343
+ " obs = state[\"observation\"]\n",
344
+ "\n",
345
+ " if not state.get(\"done\"):\n",
346
+ " state = env_step(\"messaging\", \"read_reply\")\n",
347
+ " turn_reward += state[\"reward\"]\n",
348
+ " total_reward += state[\"reward\"]\n",
349
+ " obs = state[\"observation\"]\n",
350
+ "\n",
351
+ " # Auto: update stage after contact\n",
352
+ " if topic in (\"greeting\", \"call\") and not contacted and not state.get(\"done\"):\n",
353
+ " contacted = True\n",
354
+ " state = env_step(\"crm\", \"update_stage\", stage=\"contacted\")\n",
355
+ " turn_reward += state[\"reward\"]\n",
356
+ " total_reward += state[\"reward\"]\n",
357
+ " obs = state[\"observation\"]\n",
358
+ "\n",
359
+ " turn_data[-1][\"turn_reward\"] = turn_reward\n",
360
+ "\n",
361
+ " if not turn_data:\n",
362
+ " return None\n",
363
+ "\n",
364
+ " # Auto: approval + offer + hire\n",
365
+ " if not state.get(\"done\") and contacted:\n",
366
+ " for action_spec in [\n",
367
+ " (\"crm\", \"update_stage\", {\"stage\": \"interested\"}),\n",
368
+ " (\"approval\", \"request_approval\", {\"job_id\": 0}),\n",
369
+ " (\"workflow\", \"wait\", {}),\n",
370
+ " (\"approval\", \"check_approval\", {}),\n",
371
+ " (\"crm\", \"update_stage\", {\"stage\": \"approval_pending\"}),\n",
372
+ " (\"messaging\", \"send_message\", {\"topic\": \"offer\", \"job_id\": 0}),\n",
373
+ " (\"messaging\", \"read_reply\", {}),\n",
374
+ " (\"crm\", \"update_stage\", {\"stage\": \"offer_sent\"}),\n",
375
+ " (\"crm\", \"update_stage\", {\"stage\": \"hired\"}),\n",
376
+ " ]:\n",
377
+ " if state.get(\"done\"): break\n",
378
+ " state = env_step(action_spec[0], action_spec[1], **action_spec[2])\n",
379
+ " total_reward += state[\"reward\"]\n",
380
+ "\n",
381
+ " # Sample one turn for training\n",
382
+ " t = random.randrange(len(turn_data))\n",
383
+ " td = turn_data[t]\n",
384
+ "\n",
385
+ " return {\n",
386
+ " \"prompt_ids\": td[\"prompt_ids\"],\n",
387
+ " \"gen_ids\": td[\"gen_ids\"][:MAX_NEW_TOKENS],\n",
388
+ " \"reward\": total_reward,\n",
389
+ " \"stage\": obs.get(\"stage\", \"unknown\"),\n",
390
+ " \"topic\": td[\"topic\"],\n",
391
+ " \"num_topics\": len(asked),\n",
392
+ " }\n",
393
+ "\n",
394
+ "# Quick test\n",
395
+ "ep = rollout_episode(model, tokenizer, device)\n",
396
+ "if ep:\n",
397
+ " print(f\"Topic chosen: {ep['topic']}, Reward: {ep['reward']:.1f}, Stage: {ep['stage']}, Topics asked: {ep['num_topics']}\")"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "# --- REINFORCE Training Loop ---\n",
407
+ "BATCH_SIZE = 4\n",
408
+ "NUM_STEPS = 50\n",
409
+ "\n",
410
+ "print(f\"Training for {NUM_STEPS} steps, batch size {BATCH_SIZE}\")\n",
411
+ "print(\"=\" * 60)\n",
412
+ "\n",
413
+ "history = {\"loss\": [], \"reward_mean\": [], \"reward_std\": [], \"grad_norm\": []}\n",
414
+ "\n",
415
+ "for step in range(1, NUM_STEPS + 1):\n",
416
+ " # --- Rollout (no gradients) ---\n",
417
+ " model.eval()\n",
418
+ " episodes = []\n",
419
+ " for i in range(BATCH_SIZE):\n",
420
+ " ep = rollout_episode(model, tokenizer, device)\n",
421
+ " if ep and ep[\"gen_ids\"]:\n",
422
+ " episodes.append(ep)\n",
423
+ "\n",
424
+ " if len(episodes) < 2:\n",
425
+ " print(f\"Step {step}: not enough episodes, skipping\")\n",
426
+ " continue\n",
427
+ "\n",
428
+ " # --- Batch-level advantages ---\n",
429
+ " rewards = [ep[\"reward\"] for ep in episodes]\n",
430
+ " mean_r = sum(rewards) / len(rewards)\n",
431
+ " std_r = max(torch.tensor(rewards).std().item(), 1e-4)\n",
432
+ " advantages = [(r - mean_r) / std_r for r in rewards]\n",
433
+ "\n",
434
+ " # --- REINFORCE update ---\n",
435
+ " model.train()\n",
436
+ " optimizer.zero_grad()\n",
437
+ " total_loss = 0.0\n",
438
+ "\n",
439
+ " for ep, adv in zip(episodes, advantages):\n",
440
+ " input_ids = torch.tensor(\n",
441
+ " [ep[\"prompt_ids\"] + ep[\"gen_ids\"]], device=device\n",
442
+ " )\n",
443
+ " prompt_len = len(ep[\"prompt_ids\"])\n",
444
+ " comp_len = len(ep[\"gen_ids\"])\n",
445
+ " if comp_len == 0:\n",
446
+ " continue\n",
447
+ "\n",
448
+ " outputs = model(input_ids)\n",
449
+ " logits = outputs.logits[0, prompt_len - 1 : prompt_len + comp_len - 1]\n",
450
+ " targets = input_ids[0, prompt_len : prompt_len + comp_len]\n",
451
+ " log_probs = F.log_softmax(logits, dim=-1)\n",
452
+ " token_lps = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)\n",
453
+ "\n",
454
+ " loss = -(adv * token_lps.sum()) / len(episodes)\n",
455
+ " loss.backward()\n",
456
+ " total_loss += loss.item()\n",
457
+ "\n",
458
+ " grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()\n",
459
+ " optimizer.step()\n",
460
+ "\n",
461
+ " # --- Log ---\n",
462
+ " history[\"loss\"].append(total_loss)\n",
463
+ " history[\"reward_mean\"].append(mean_r)\n",
464
+ " history[\"reward_std\"].append(std_r)\n",
465
+ " history[\"grad_norm\"].append(grad_norm)\n",
466
+ "\n",
467
+ " topics = [ep[\"topic\"] for ep in episodes]\n",
468
+ " print(f\"Step {step:3d} | loss={total_loss:+.3f} | reward={mean_r:+.1f}±{std_r:.1f} | \"\n",
469
+ " f\"grad={grad_norm:.3f} | topics={topics}\")\n",
470
+ "\n",
471
+ "print(\"\\nTraining complete!\")"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "markdown",
476
+ "metadata": {},
477
+ "source": [
478
+ "## 7. Plot Training Curves"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "metadata": {},
485
+ "outputs": [],
486
+ "source": [
487
+ "import matplotlib.pyplot as plt\n",
488
+ "\n",
489
+ "fig, axes = plt.subplots(2, 2, figsize=(12, 8))\n",
490
+ "fig.suptitle(\"Driver Recruit RL Training\", fontsize=14)\n",
491
+ "\n",
492
+ "axes[0, 0].plot(history[\"loss\"])\n",
493
+ "axes[0, 0].set_title(\"Loss\")\n",
494
+ "axes[0, 0].set_xlabel(\"Step\")\n",
495
+ "\n",
496
+ "axes[0, 1].plot(history[\"reward_mean\"])\n",
497
+ "axes[0, 1].set_title(\"Mean Reward\")\n",
498
+ "axes[0, 1].set_xlabel(\"Step\")\n",
499
+ "\n",
500
+ "axes[1, 0].plot(history[\"reward_std\"])\n",
501
+ "axes[1, 0].set_title(\"Reward Std\")\n",
502
+ "axes[1, 0].set_xlabel(\"Step\")\n",
503
+ "\n",
504
+ "axes[1, 1].plot(history[\"grad_norm\"])\n",
505
+ "axes[1, 1].set_title(\"Gradient Norm\")\n",
506
+ "axes[1, 1].set_xlabel(\"Step\")\n",
507
+ "\n",
508
+ "plt.tight_layout()\n",
509
+ "plt.show()"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "markdown",
514
+ "metadata": {},
515
+ "source": [
516
+ "## 8. Test the Trained Model"
517
+ ]
518
+ },
519
+ {
520
+ "cell_type": "code",
521
+ "execution_count": null,
522
+ "metadata": {},
523
+ "outputs": [],
524
+ "source": [
525
+ "print(\"=== Testing trained model ===\")\n",
526
+ "model.eval()\n",
527
+ "test_rewards = []\n",
528
+ "for i in range(5):\n",
529
+ " ep = rollout_episode(model, tokenizer, device)\n",
530
+ " if ep:\n",
531
+ " test_rewards.append(ep[\"reward\"])\n",
532
+ " print(f\" Episode {i+1}: reward={ep['reward']:.1f}, stage={ep['stage']}, \"\n",
533
+ " f\"topics={ep['num_topics']}, chose={ep['topic']}\")\n",
534
+ "\n",
535
+ "if test_rewards:\n",
536
+ " print(f\"\\nMean test reward: {sum(test_rewards)/len(test_rewards):.1f}\")"
537
+ ]
538
+ }
539
+ ],
540
+ "metadata": {
541
+ "kernelspec": {
542
+ "display_name": "Python 3",
543
+ "language": "python",
544
+ "name": "python3"
545
+ },
546
+ "language_info": {
547
+ "name": "python",
548
+ "version": "3.10.0"
549
+ },
550
+ "accelerator": "GPU",
551
+ "colab": {
552
+ "gpuType": "T4",
553
+ "provenance": []
554
+ }
555
+ },
556
+ "nbformat": 4,
557
+ "nbformat_minor": 4
558
+ }
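The batch-level advantage step in the notebook's REINFORCE loop can be sketched in isolation (stdlib only; `statistics.stdev` matches the unbiased default of `torch.Tensor.std`, and the `1e-4` floor is the same guard the notebook uses against a near-uniform batch):

```python
import statistics

def batch_advantages(rewards, std_floor=1e-4):
    """Normalize episode rewards into per-episode advantages.

    Subtract the batch mean and divide by the batch standard deviation,
    floored at std_floor so a batch of nearly identical rewards does not
    produce exploding advantages.
    """
    mean_r = sum(rewards) / len(rewards)
    # statistics.stdev is the sample std (ddof=1), same as torch's default
    std_r = max(statistics.stdev(rewards), std_floor) if len(rewards) > 1 else std_floor
    return [(r - mean_r) / std_r for r in rewards]

print(batch_advantages([1.0, 3.0]))  # symmetric around zero
```

With a single-episode batch the mean equals the reward, so every advantage collapses to zero, which is why the training loop skips steps with fewer than two episodes.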
train_grpo.py ADDED
@@ -0,0 +1,431 @@
1
+ """
2
+ GRPO training script for the Driver Recruit Environment.
3
+
4
+ Uses TRL's GRPOTrainer with rollout_func for multi-turn episodes.
5
+ The model controls EVERY action in the episode via tool calls.
6
+
7
+ Usage:
8
+ python train_grpo.py --model Qwen/Qwen2.5-3B-Instruct --use-qlora
9
+ """
10
+
11
+ import argparse
12
+ import json
13
+ import random
14
+
15
+ from datasets import Dataset
16
+ from transformers import AutoTokenizer, BitsAndBytesConfig
17
+ import torch
18
+
19
+ from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
20
+ from trl import GRPOConfig, GRPOTrainer
21
+ from trl.experimental.openenv import generate_rollout_completions
22
+
23
+ # --- Prompt templates ---
24
+
25
+ SYSTEM_PROMPT = """You are a truck driver recruiter using a CRM system. You only know the driver's name. You must discover their qualifications through conversation, record info in the CRM, get approval, and hire them.
26
+
27
+ You have 4 tools:
28
+
29
+ ## crm
30
+ - read_candidate: Read the current CRM record
31
+ - update_stage: Advance pipeline (contacted → interested → approval_pending → offer_sent → hired)
32
+ - update_field: Record info (field + value)
33
+ - add_note: Add a free-text note
34
+
35
+ ## messaging
36
+ - send_message: Send a message (topic: greeting, call, experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references, pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern)
37
+ - read_reply: Read the driver's response
38
+
39
+ ## approval
40
+ - request_approval: Request approval for a job (needs job_id)
41
+ - check_approval: Check approval status
42
+
43
+ ## workflow
44
+ - wait: Advance time (needed for approval processing)
45
+
46
+ ## Rules
47
+ - Must read CRM before messaging
48
+ - Must read_reply before sending another message
49
+ - Must request_approval and wait before sending offer
50
+ - Must follow stage order: lead → contacted → interested → approval_pending → offer_sent → hired
51
+ - Record important info in CRM with update_field
52
+ - Too many messages hurt trust
53
+
54
+ ## Workflow
55
+ 1. crm.read_candidate
56
+ 2. messaging.send_message (greeting/call) → read_reply → update_stage(contacted)
57
+ 3. messaging.send_message (screening topics) → read_reply → crm.update_field
58
+ 4. crm.update_stage(interested)
59
+ 5. approval.request_approval → workflow.wait → approval.check_approval
60
+ 6. crm.update_stage(approval_pending)
61
+ 7. messaging.send_message(offer) → read_reply
62
+ 8. crm.update_stage(offer_sent) → crm.update_stage(hired)
63
+
64
+ Respond with ONLY JSON:
65
+ {"tool": "crm", "action": "read_candidate"}
66
+ {"tool": "messaging", "action": "send_message", "topic": "experience"}
67
+ {"tool": "messaging", "action": "read_reply"}
68
+ {"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
69
+ {"tool": "crm", "action": "update_stage", "stage": "contacted"}
70
+ {"tool": "approval", "action": "request_approval", "job_id": 2}
71
+ {"tool": "workflow", "action": "wait"}
72
+ {"tool": "approval", "action": "check_approval"}
73
+ {"tool": "messaging", "action": "send_message", "topic": "offer", "job_id": 2}
74
+ {"tool": "crm", "action": "update_stage", "stage": "hired"}"""
75
+
76
+
77
+ def format_observation(obs):
78
+ """Format observation into a user prompt for the LLM."""
79
+ parts = [f"Driver: {obs.driver_name}"]
80
+ if obs.crm_summary:
81
+ parts.append(f"CRM:\n{obs.crm_summary}")
82
+ if obs.jobs_summary:
83
+ parts.append(f"Jobs:\n{obs.jobs_summary}")
84
+ if obs.discovered_info:
85
+ parts.append(f"Discovered:\n{obs.discovered_info}")
86
+ status = f"Stage: {obs.stage}"
87
+ if obs.pending_reply:
88
+ status += " | PENDING REPLY"
89
+ parts.append(status)
90
+ if obs.feedback:
91
+ parts.append(f"Result: {obs.feedback}")
92
+ return "\n".join(parts)
93
+
94
+
95
+ def format_observation_compact(obs):
96
+ """Compact observation for embedding in completion_ids (~30-60 tokens)."""
97
+ parts = [f"Stage: {obs.stage}"]
98
+ if obs.pending_reply:
99
+ parts.append("PENDING REPLY")
100
+ if obs.feedback:
101
+ parts.append(obs.feedback[:200])
102
+ if obs.discovered_info:
103
+ parts.append(obs.discovered_info[:200])
104
+ return "\n".join(parts)
105
+
106
+
107
+ def parse_action(text):
108
+ """Parse LLM output into a RecruitopenenvAction."""
109
+ text = text.strip()
110
+
111
+ # Remove markdown fences
112
+ if "```" in text:
113
+ for part in text.split("```"):
114
+ part = part.strip()
115
+ if part.startswith("json"):
116
+ part = part[4:].strip()
117
+ if part.startswith("{"):
118
+ text = part
119
+ break
120
+
121
+ # Try JSON
122
+ try:
123
+ data = json.loads(text)
124
+ if isinstance(data, list):
125
+ data = data[0] if data else {}
126
+ if isinstance(data, dict) and "tool" in data and "action" in data:
127
+ return RecruitopenenvAction(
128
+ tool=data["tool"],
129
+ action=data["action"],
130
+ topic=data.get("topic", ""),
131
+ job_id=int(data.get("job_id", -1)),
132
+ stage=str(data.get("stage", "")),
133
+ field=str(data.get("field", "")),
134
+ value=str(data.get("value", "")),
135
+ )
136
+ except (json.JSONDecodeError, KeyError, IndexError, ValueError, TypeError):
137
+ pass
138
+
139
+ # Fallback: try to detect intent
140
+ text_lower = text.lower()
141
+ if "read_candidate" in text_lower:
142
+ return RecruitopenenvAction(tool="crm", action="read_candidate")
143
+ if "read_reply" in text_lower:
144
+ return RecruitopenenvAction(tool="messaging", action="read_reply")
145
+ if "check_approval" in text_lower:
146
+ return RecruitopenenvAction(tool="approval", action="check_approval")
147
+ if "wait" in text_lower:
148
+ return RecruitopenenvAction(tool="workflow", action="wait")
149
+
150
+ # Default to reading CRM
151
+ return RecruitopenenvAction(tool="crm", action="read_candidate")
152
+
153
+
154
+ # --- Multi-turn rollout ---
155
+
156
+ ENV_URL = "http://localhost:8001"
157
+ MAX_COMPLETION_TOKENS = 1536
158
+
159
+
160
+ def _build_chat_transition(tokenizer, obs_text):
161
+ """Build chat-formatted transition tokens: end assistant turn, user obs, start assistant.
162
+
163
+ Result: <|im_end|>\n<|im_start|>user\n{obs}<|im_end|>\n<|im_start|>assistant\n
164
+ This ensures the model sees proper chat structure during the forward pass.
165
+ """
166
+ im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
167
+ im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
168
+
169
+ # Encode role tags and newlines
170
+ nl = tokenizer.encode("\n", add_special_tokens=False)
171
+ user_tag = tokenizer.encode("user", add_special_tokens=False)
172
+ asst_tag = tokenizer.encode("assistant", add_special_tokens=False)
173
+ obs_ids = tokenizer.encode(obs_text, add_special_tokens=False)[:60]
174
+
175
+ # <|im_end|>\n<|im_start|>user\n{obs}<|im_end|>\n<|im_start|>assistant\n
176
+ return (
177
+ [im_end] + nl +
178
+ [im_start] + user_tag + nl +
179
+ obs_ids +
180
+ [im_end] + nl +
181
+ [im_start] + asst_tag + nl
182
+ )
183
+
184
+
185
+ def rollout_once(trainer, env, tokenizer, prompt_text, system_prompt, max_turns=15):
186
+ """Run one multi-turn episode with chat-formatted transitions.
187
+
188
+ completion_ids: [action1, <|im_end|>user obs<|im_start|>assistant, action2, ...]
189
+ The chat template structure lets the forward pass assign proper logprobs.
190
+ """
191
+ seed = random.randint(0, 2**31 - 1)
192
+ result = env.reset(seed=seed)
193
+ obs = result.observation
194
+
195
+ prompt_ids = []
196
+ completion_ids = []
197
+ logprobs = []
198
+ env_mask = []
199
+ total_reward = 0.0
200
+ steps = 0
201
+
202
+ messages = [
203
+ {"role": "system", "content": system_prompt},
204
+ {"role": "user", "content": format_observation(obs)},
205
+ ]
206
+
207
+ while not result.done and steps < max_turns:
208
+ # Check if we're near the token budget (need room for action + transition)
209
+ if len(completion_ids) > MAX_COMPLETION_TOKENS - 60:
210
+ break
211
+
212
+ current_prompt = tokenizer.apply_chat_template(
213
+ messages, add_generation_prompt=True, tokenize=False
214
+ )
215
+
216
+ rollout_outputs = generate_rollout_completions(trainer, [current_prompt])[0]
217
+
218
+ if steps == 0:
219
+ prompt_ids = list(rollout_outputs["prompt_ids"])
220
+
221
+ action_ids = list(rollout_outputs["completion_ids"])
222
+ action_logprobs = list(rollout_outputs["logprobs"])
223
+
224
+ # Add action tokens (these get gradients)
225
+ completion_ids.extend(action_ids)
226
+ logprobs.extend(action_logprobs)
227
+ env_mask.extend([1] * len(action_ids))
228
+
229
+ response = rollout_outputs.get("text") or tokenizer.decode(
230
+ action_ids, skip_special_tokens=True
231
+ )
232
+ messages.append({"role": "assistant", "content": response})
233
+
234
+ action = parse_action(response)
235
+ result = env.step(action)
236
+ obs = result.observation
237
+ total_reward += result.reward
238
+ steps += 1
239
+
240
+ if not result.done:
241
+ # Build chat-formatted transition so forward pass sees proper structure
242
+ obs_text = format_observation_compact(obs)
243
+ transition_ids = _build_chat_transition(tokenizer, obs_text)
244
+
245
+ completion_ids.extend(transition_ids)
246
+ logprobs.extend([0.0] * len(transition_ids))
247
+ env_mask.extend([0] * len(transition_ids))
248
+
249
+ messages.append({"role": "user", "content": format_observation(obs)})
250
+
251
+ # Truncate to fit max_completion_length
252
+ completion_ids = completion_ids[:MAX_COMPLETION_TOKENS]
253
+ logprobs = logprobs[:MAX_COMPLETION_TOKENS]
254
+ env_mask = env_mask[:MAX_COMPLETION_TOKENS]
255
+
256
+ return {
257
+ "prompt_ids": prompt_ids,
258
+ "completion_ids": completion_ids,
259
+ "logprobs": logprobs,
260
+ "env_mask": env_mask,
261
+ "env_reward": total_reward,
262
+ "steps": steps,
263
+ "final_stage": obs.stage,
264
+ }
265
+
266
+
267
+ def rollout_func(prompts, trainer):
268
+ """Multi-turn rollout: model controls every action in the episode."""
269
+ tokenizer = trainer.processing_class
270
+ env = RecruitopenenvEnv(base_url=ENV_URL)
271
+
272
+ all_prompt_ids = []
273
+ all_completion_ids = []
274
+ all_logprobs = []
275
+ all_env_rewards = []
276
+ all_env_mask = []
277
+
278
+ for prompt_text in prompts:
279
+ episode = rollout_once(trainer, env, tokenizer, prompt_text, SYSTEM_PROMPT)
280
+
281
+ if episode["completion_ids"]:
282
+ all_prompt_ids.append(episode["prompt_ids"])
283
+ all_completion_ids.append(episode["completion_ids"])
284
+ all_logprobs.append(episode["logprobs"])
285
+ all_env_mask.append(episode["env_mask"])
286
+ else:
287
+ tok_ids = tokenizer.encode("wait", add_special_tokens=False)
288
+ all_prompt_ids.append(episode["prompt_ids"] or tok_ids)
289
+ all_completion_ids.append(tok_ids)
290
+ all_logprobs.append([0.0] * len(tok_ids))
291
+ all_env_mask.append([1] * len(tok_ids))
292
+
293
+ all_env_rewards.append(episode["env_reward"])
294
+ print(f" Episode {len(all_env_rewards)}: reward={episode['env_reward']:.1f}, "
295
+ f"steps={episode['steps']}, stage={episode['final_stage']}")
296
+
297
+ env.close()
298
+
299
+ mean_r = sum(all_env_rewards) / len(all_env_rewards)
300
+ std_r = torch.tensor(all_env_rewards).std().item()
301
+ print(f"Rollout done: {len(all_env_rewards)} episodes, mean_reward={mean_r:.2f}, std={std_r:.2f}")
302
+
303
+ return {
304
+ "prompt_ids": all_prompt_ids,
305
+ "completion_ids": all_completion_ids,
306
+ "logprobs": [[(lp,) for lp in seq] for seq in all_logprobs],
307
+ "env_reward": all_env_rewards,
308
+ "env_mask": all_env_mask,
309
+ }
310
+
311
+
312
+ # --- Reward function (fallback, rewards come from rollout) ---
313
+
314
+ def reward_total(completions, **kwargs):
315
+ """Extract environment rewards passed via rollout_func kwargs."""
316
+ env_rewards = kwargs.get("env_reward", [])
317
+ if env_rewards:
318
+ return [float(r) for r in env_rewards]
319
+ return [0.0] * len(completions)
320
+
321
+
322
+ # --- Main ---
323
+
324
+ def main():
325
+ parser = argparse.ArgumentParser(description="GRPO training for Driver Recruit Environment")
326
+ parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct", help="Model to train")
327
+ parser.add_argument("--env-url", default="http://localhost:8001", help="Environment server URL")
328
+ parser.add_argument("--num-episodes", type=int, default=16, help="Number of training episodes (dataset size)")
329
+ parser.add_argument("--num-generations", type=int, default=4, help="GRPO generations per prompt")
330
+ parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size")
331
+ parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
332
+ parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
333
+ parser.add_argument("--output-dir", default="./recruit-grpo-output", help="Output directory")
334
+ parser.add_argument("--vllm-mode", default="colocate", choices=["colocate", "server"],
335
+ help="vLLM mode: colocate (1 GPU) or server (2+ GPUs)")
336
+ parser.add_argument("--use-qlora", action="store_true", help="Use QLoRA (4-bit) for memory efficiency")
337
+ parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
338
+ parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
339
+ args = parser.parse_args()
340
+
341
+ global ENV_URL
342
+ ENV_URL = args.env_url
343
+
344
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
345
+
346
+ prompts = []
347
+ env = RecruitopenenvEnv(base_url=args.env_url)
348
+ for i in range(args.num_episodes):
349
+ result = env.reset()
350
+ obs = result.observation
351
+ user_prompt = format_observation(obs)
352
+ messages = [
353
+ {"role": "system", "content": SYSTEM_PROMPT},
354
+ {"role": "user", "content": user_prompt},
355
+ ]
356
+ prompt_text = tokenizer.apply_chat_template(
357
+ messages, add_generation_prompt=True, tokenize=False
358
+ )
359
+ prompts.append(prompt_text)
360
+ env.close()
361
+
362
+ dataset = Dataset.from_dict({"prompt": prompts})
363
+
364
+ peft_config = None
365
+ model_kwargs = {}
366
+ if args.use_qlora:
367
+ from peft import LoraConfig
368
+ peft_config = LoraConfig(
369
+ r=args.lora_r,
370
+ lora_alpha=args.lora_alpha,
371
+ lora_dropout=0.05,
372
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
373
+ "gate_proj", "up_proj", "down_proj"],
374
+ task_type="CAUSAL_LM",
375
+ )
376
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(
377
+ load_in_4bit=True,
378
+ bnb_4bit_compute_dtype=torch.bfloat16,
379
+ bnb_4bit_quant_type="nf4",
380
+ )
381
+ print(f"Using QLoRA: r={args.lora_r}, alpha={args.lora_alpha}, 4-bit")
382
+
383
+ grpo_config = GRPOConfig(
384
+ output_dir=args.output_dir,
385
+ use_vllm=True,
386
+ vllm_mode=args.vllm_mode,
387
+ num_train_epochs=args.epochs,
388
+ num_generations=args.num_generations,
389
+ max_completion_length=1536,
390
+ per_device_train_batch_size=args.batch_size,
391
+ gradient_accumulation_steps=4,
392
+ gradient_checkpointing=True,
393
+ learning_rate=args.lr,
394
+ temperature=0.7,
395
+ logging_steps=1,
396
+ save_steps=50,
397
+ bf16=True,
398
+ report_to="wandb",
399
+ run_name="recruit-grpo-tools",
400
+ model_init_kwargs=model_kwargs if model_kwargs else None,
401
+ )
402
+
403
+ trainer_kwargs = dict(
404
+ model=args.model,
405
+ processing_class=tokenizer,
406
+ reward_funcs=[reward_total],
407
+ train_dataset=dataset,
408
+ args=grpo_config,
409
+ rollout_func=rollout_func,
410
+ )
411
+ if peft_config is not None:
412
+ trainer_kwargs["peft_config"] = peft_config
413
+
414
+ trainer = GRPOTrainer(**trainer_kwargs)
415
+
416
+ print("=" * 50)
417
+ print(f"Training {args.model} (TOOL-BASED MULTI-TURN)")
418
+ print(f"Environment: {args.env_url}")
419
+ print(f"QLoRA: {args.use_qlora}")
420
+ print(f"Episodes: {args.num_episodes}")
421
+ print(f"Epochs: {args.epochs}")
422
+ print(f"Generations per prompt: {args.num_generations}")
423
+ print("=" * 50)
424
+
425
+ trainer.train()
426
+ trainer.save_model(args.output_dir)
427
+ print(f"\nModel saved to {args.output_dir}")
428
+
429
+
430
+ if __name__ == "__main__":
431
+ main()
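The `env_mask` bookkeeping in `rollout_once` exists so that chat-transition tokens injected into `completion_ids` carry no policy-gradient signal: action tokens get mask 1, environment/observation tokens get mask 0. The intended downstream use can be sketched with a hypothetical helper (not part of the script):

```python
def masked_logprob_sum(logprobs, env_mask):
    """Sum log-probabilities over model-generated tokens only.

    Tokens with mask 0 (the injected <|im_end|>/user-observation/
    <|im_start|> transition tokens) are excluded, so only the model's
    own actions contribute to the policy-gradient objective.
    """
    assert len(logprobs) == len(env_mask), "mask must align with logprobs"
    return sum(lp for lp, m in zip(logprobs, env_mask) if m)

# Two action tokens, one transition token, one more action token:
print(masked_logprob_sum([-1.0, -2.0, 0.0, -3.0], [1, 1, 0, 1]))  # -6.0
```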
uv.lock ADDED
The diff for this file is too large to render. See raw diff
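The markdown-fence stripping in `parse_action` can be exercised standalone; below is a minimal sketch that returns a plain dict instead of `RecruitopenenvAction` (assumed unavailable outside the environment package) and keeps the same safe fallback:

```python
import json

def parse_tool_call(text):
    """Extract a {"tool": ..., "action": ...} JSON object from model output.

    Handles markdown-fenced responses the same way train_grpo.py does:
    split on ``` and keep the first segment that looks like JSON.
    """
    text = text.strip()
    if "```" in text:
        for part in text.split("```"):
            part = part.strip()
            if part.startswith("json"):
                part = part[4:].strip()
            if part.startswith("{"):
                text = part
                break
    try:
        data = json.loads(text)
        if isinstance(data, dict) and "tool" in data and "action" in data:
            return data
    except json.JSONDecodeError:
        pass
    # Fallback mirrors the script: default to a harmless read action
    return {"tool": "crm", "action": "read_candidate"}

raw = '```json\n{"tool": "messaging", "action": "send_message", "topic": "pay"}\n```'
print(parse_tool_call(raw)["topic"])  # pay
```

Defaulting to `read_candidate` on unparseable output keeps the rollout moving instead of crashing the episode on a malformed generation.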