"""LLM agent baseline — test how well a base model performs without RL training."""
import argparse
import json
import requests
from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
# System prompt for the base model: describes the four tool groups, the
# pipeline rules, and a suggested strategy, and constrains the model to
# reply with a single JSON action. (Flow arrows restored from mojibake.)
SYSTEM_PROMPT = """You are a truck driver recruiter using a CRM system. You only know the driver's name. You must discover their qualifications through conversation, record info in the CRM, get approval, and hire them.
You have 4 tools:
## crm
- read_candidate: Read the current CRM record
- update_stage: Advance pipeline (contacted → interested → approval_pending → offer_sent → hired)
- update_field: Record info (field + value)
- add_note: Add a free-text note
## messaging
- send_message: Send a message (topic: greeting, call, experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references, pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern)
- read_reply: Read the driver's response
## approval
- request_approval: Request approval for a job (needs job_id)
- check_approval: Check approval status
## workflow
- wait: Advance time (needed for approval processing)
## Rules
- Must read CRM before messaging
- Must read_reply before sending another message
- Must request_approval and wait before sending offer
- Must follow stage order: lead → contacted → interested → approval_pending → offer_sent → hired
- Record important info in CRM with update_field
- Too many messages hurt trust
## Strategy
1. crm.read_candidate → see the lead
2. messaging.send_message(greeting or call) → messaging.read_reply → crm.update_stage(contacted)
3. Screen: send_message(experience) → read_reply → update_field(cdl_class, value) ... repeat for key questions
4. crm.update_stage(interested)
5. approval.request_approval(job_id) → workflow.wait → approval.check_approval
6. crm.update_stage(approval_pending)
7. messaging.send_message(offer) → messaging.read_reply
8. crm.update_stage(offer_sent) → crm.update_stage(hired)
Tips:
- ask_experience is critical (CDL class filters jobs)
- ask_deal_breakers helps avoid trap jobs
- ask_violations and ask_medical_card reveal fatal blockers
- If driver has concerns about offer, use negotiate_pay/negotiate_home_time/address_concern
- If no good match exists, update_stage to lost
Respond with ONLY JSON:
{"tool": "crm", "action": "read_candidate"}
{"tool": "messaging", "action": "send_message", "topic": "experience"}
{"tool": "messaging", "action": "read_reply"}
{"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
{"tool": "approval", "action": "request_approval", "job_id": 2}
{"tool": "crm", "action": "update_stage", "stage": "hired"}"""
def format_observation(obs):
    """Render an environment observation as a plain-text prompt section.

    Includes the driver name, any non-empty summary sections, the current
    pipeline stage (flagging a pending reply), and any action feedback.
    """
    lines = [f"Driver: {obs.driver_name}"]
    # Optional sections, emitted only when non-empty, in a fixed order.
    for label, body in (
        ("CRM", obs.crm_summary),
        ("Jobs", obs.jobs_summary),
        ("Discovered", obs.discovered_info),
    ):
        if body:
            lines.append(f"{label}:\n{body}")
    stage_line = f"Stage: {obs.stage}"
    if obs.pending_reply:
        stage_line += " | PENDING REPLY"
    lines.append(stage_line)
    if obs.feedback:
        lines.append(f"Result: {obs.feedback}")
    return "\n".join(lines)
def ask_llm(messages, llm_url, model, timeout=60.0):
    """POST a chat-completion request and return the assistant's reply text.

    Args:
        messages: Chat history in OpenAI chat-completions message format.
        llm_url: Full URL of the chat-completions endpoint.
        model: Model identifier to request.
        timeout: Seconds before the HTTP request is aborted. Without a
            timeout, requests.post can hang forever on a stuck server.

    Returns:
        The content string of the first completion choice.

    Raises:
        requests.HTTPError: If the server returns an error status (clearer
            than the KeyError on "choices" that would otherwise follow).
    """
    resp = requests.post(
        llm_url,
        json={
            "model": model,
            "messages": messages,
            "temperature": 0.1,
            "max_tokens": 150,
        },
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]
def parse_action(text):
    """Extract a RecruitopenenvAction from a raw LLM response.

    Tries, in order: strip markdown code fences, parse the JSON action,
    then keyword fallbacks for non-JSON replies. Falls back to the safe
    crm.read_candidate action if nothing can be recognized.
    """
    text = text.strip()
    # Remove markdown code fences (```json ... ``` or bare ```).
    if "```" in text:
        parts = text.split("```")
        for part in parts:
            part = part.strip()
            if part.startswith("json"):
                part = part[4:].strip()
            if part.startswith("{"):
                text = part
                break
    # Try JSON parse.
    try:
        data = json.loads(text)
        if isinstance(data, dict) and "tool" in data and "action" in data:
            # Models often emit job_id as a string ("2"); coerce to int so
            # downstream comparisons (e.g. job_id >= 0) behave correctly.
            try:
                job_id = int(data.get("job_id", -1))
            except (TypeError, ValueError):
                job_id = -1
            return RecruitopenenvAction(
                tool=data["tool"],
                action=data["action"],
                topic=data.get("topic", ""),
                job_id=job_id,
                stage=data.get("stage", ""),
                field=data.get("field", ""),
                value=data.get("value", ""),
            )
    except (json.JSONDecodeError, KeyError):
        pass
    # Keyword fallback for free-text replies; "wait" is checked last since
    # it is the most likely substring to appear incidentally.
    text_lower = text.lower()
    if "read_candidate" in text_lower:
        return RecruitopenenvAction(tool="crm", action="read_candidate")
    if "read_reply" in text_lower:
        return RecruitopenenvAction(tool="messaging", action="read_reply")
    if "check_approval" in text_lower:
        return RecruitopenenvAction(tool="approval", action="check_approval")
    if "wait" in text_lower:
        return RecruitopenenvAction(tool="workflow", action="wait")
    # Safe default: reading the CRM record is always a legal action.
    return RecruitopenenvAction(tool="crm", action="read_candidate")
def run_baseline(env_url, llm_url, model, num_episodes):
    """Run the LLM agent for num_episodes episodes and print summary stats.

    Args:
        env_url: Base URL of the environment server.
        llm_url: URL of the LLM chat-completions endpoint.
        model: Model name passed through to the LLM API.
        num_episodes: Number of episodes to run (must be > 0).

    Raises:
        ValueError: If num_episodes is not positive (would otherwise crash
            with ZeroDivisionError when computing averages).
    """
    if num_episodes <= 0:
        raise ValueError("num_episodes must be positive")
    rewards = []
    successes = 0
    total_steps = 0
    env = RecruitopenenvEnv(base_url=env_url)
    try:
        for ep in range(num_episodes):
            result = env.reset()
            obs = result.observation
            ep_reward = 0.0
            steps = 0
            # Fresh conversation per episode; the system prompt defines tools.
            messages = [{"role": "system", "content": SYSTEM_PROMPT}]
            # Hard step cap guards against episodes that never terminate.
            while not result.done and steps < 100:
                obs_text = format_observation(obs)
                messages.append({"role": "user", "content": obs_text})
                llm_response = ask_llm(messages, llm_url, model)
                messages.append({"role": "assistant", "content": llm_response})
                action = parse_action(llm_response)
                result = env.step(action)
                obs = result.observation
                ep_reward += result.reward
                steps += 1
                print(f" Step {steps}: {action.tool}.{action.action}"
                      f"{'(' + action.topic + ')' if action.topic else ''}"
                      f"{'[job=' + str(action.job_id) + ']' if action.job_id >= 0 else ''}"
                      f" -> reward={result.reward:.1f}")
            rewards.append(ep_reward)
            total_steps += steps
            if obs.stage == "hired":
                successes += 1
            print(f"Episode {ep+1}: total_reward={ep_reward:.1f}, steps={steps}, "
                  f"{'HIRED' if obs.stage == 'hired' else 'FAIL (' + obs.stage + ')'}")
            print()
    finally:
        # Release the environment connection even if an episode raised.
        env.close()
    avg_reward = sum(rewards) / len(rewards)
    avg_steps = total_steps / num_episodes
    print("\n========== LLM BASELINE (no RL) ==========")
    print(f"Model: {model}")
    print(f"Episodes: {num_episodes}")
    print(f"Avg reward: {avg_reward:.2f}")
    print(f"Min reward: {min(rewards):.2f}")
    print(f"Max reward: {max(rewards):.2f}")
    print(f"Hire rate: {successes}/{num_episodes} ({100*successes/num_episodes:.1f}%)")
    print(f"Avg steps/episode: {avg_steps:.1f}")
    print("==========================================")
if __name__ == "__main__":
    # CLI entry point: collect connection/model settings, then run the baseline.
    cli = argparse.ArgumentParser(description="LLM baseline for Driver Recruit Environment")
    cli.add_argument("--env-url", default="http://localhost:8001", help="Environment server URL")
    cli.add_argument("--llm-url", default="http://localhost:8033/v1/chat/completions", help="LLM API URL")
    cli.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct", help="Model name")
    cli.add_argument("--episodes", type=int, default=20, help="Number of episodes")
    opts = cli.parse_args()
    run_baseline(opts.env_url, opts.llm_url, opts.model, opts.episodes)