Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| import random, csv, json | |
| from uuid import uuid4 | |
| # Use explicit relative or local imports | |
| from models import CustomerAction, CustomerObservation | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from openai import OpenAI | |
# Client for a local OpenAI-compatible endpoint (port 11434 is Ollama's
# default); the api_key is a required placeholder, not a real secret.
local_llm = OpenAI(base_url="http://localhost:11434/v1", api_key="local-dev")
# Model name the local server must have available (used for both the
# simulated customer and the judge).
MODEL_NAME = "llama3"
class CustomerEnvironment(Environment):
    """POMDP customer-support environment for training call-center agents.

    The agent converses with an LLM-simulated banking customer whose intent
    and persona are hidden state. When the episode ends (agent hangs up or
    the step cap is reached), an LLM judge scores the full transcript.
    """

    # One conversation per environment instance at a time.
    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self, scenarios_path: str = "basic_scenarios.csv",
                 max_steps: int = 15):
        """Initialize the Customer POMDP environment.

        Args:
            scenarios_path: CSV with ``intent``, ``persona`` and
                ``starting_utterance`` columns (previously hard-coded).
            max_steps: Step count at which the episode is force-ended
                (previously the magic literal 15 in ``step``).
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count = 0
        self._max_steps = max_steps
        self.hidden_intent = ""
        self.persona = ""
        self.conversation_history = ""
        self.scenarios: list[dict] = []
        try:
            with open(scenarios_path, mode="r", encoding="utf-8") as f:
                self.scenarios.extend(csv.DictReader(f))
        except Exception as e:
            # Best-effort load: a missing/unreadable file must not prevent
            # construction; the fallback below keeps reset() working.
            print(f"Warning: Could not load scenarios.csv. {e}")
        if not self.scenarios:
            # Also covers a present-but-empty (or headers-only) CSV, which
            # previously left self.scenarios == [] and made random.choice()
            # raise IndexError in reset().
            self.scenarios.append({
                "intent": "unknown",
                "persona": "neutral",
                "starting_utterance": "I need help.",
            })

    def reset(self) -> CustomerObservation:
        """Reset the environment, pick a new hidden intent and persona.

        Returns the customer's opening utterance as the first observation.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count += 1
        scenario = random.choice(self.scenarios)
        self.hidden_intent = scenario["intent"]
        self.persona = scenario["persona"]
        start_msg = scenario["starting_utterance"]
        self.conversation_history = f"System: Call connected.\nCustomer: {start_msg}"
        return CustomerObservation(
            customer_reply=start_msg,
            tool_response=None,
            conversation_history=self.conversation_history,
            done=False,
            reward=0.0,
            metadata={"step": self._state.step_count},
        )

    def step(self, action: CustomerAction) -> CustomerObservation:
        """Advance the conversation by one agent action.

        Rewards: +0.5 for the known tool, -0.5 for an unknown tool,
        -0.1 per spoken turn, plus the judge's score (-5.0..+10.0)
        when the episode ends.
        """
        self._state.step_count += 1
        step_reward = 0.0
        done = False
        tool_response = None
        customer_reply = None
        if action.action_type == "tool_call":
            tool_name = action.content
            # Mocking the database lookup for now
            if tool_name == "lookup_account":
                tool_response = "{'status': 'verified', 'balance': '$500'}"
                step_reward += 0.5
            else:
                tool_response = f"Error: Tool '{tool_name}' not found."
                step_reward -= 0.5
            self.conversation_history += f"\nAgent [Action]: Used {tool_name}"
            self.conversation_history += f"\nSystem: {tool_response}"
        elif action.action_type == "speak":
            self.conversation_history += f"\nAgent: {action.content}"
            # LLM roleplays the customer's next line.
            customer_reply = self._get_customer_reply(action.content)
            self.conversation_history += f"\nCustomer: {customer_reply}"
            step_reward -= 0.1  # Small penalty per turn to encourage efficiency
        elif action.action_type == "end_call":
            done = True
        # Hard cap on episode length (configurable via __init__).
        if self._state.step_count >= self._max_steps:
            done = True
        if done:
            # LLM-as-judge scores the whole transcript exactly once.
            final_score, reasoning = self._evaluate_with_judge()
            step_reward += final_score
            metadata = {
                "step": self._state.step_count,
                "hidden_intent": self.hidden_intent,
                "judge_reasoning": reasoning,
            }
        else:
            metadata = {"step": self._state.step_count}
        return CustomerObservation(
            customer_reply=customer_reply,
            tool_response=tool_response,
            conversation_history=self.conversation_history,
            done=done,
            reward=step_reward,
            metadata=metadata,
        )

    def _evaluate_with_judge(self) -> tuple[float, str]:
        """
        Uses local LLM as a Judge to score the final transcript.
        Returns a tuple of (score, reasoning).
        """
        judge_prompt = f"""You are an expert QA Judge for a banking call center.
Review the transcript and score the Agent's performance from -5.0 to +10.0.
TRUE CUSTOMER INTENT: {self.hidden_intent}
SCORING RUBRIC:
- +10.0: Perfect. Intent captured, correct tools used, issue resolved efficiently.
- +5.0: Okay. Found the intent but took too many turns or was awkward.
- 0.0: Neutral. Didn't solve the issue but didn't hallucinate.
- -5.0: Failure. Missed the intent, hallucinated tools, or was rude.
TRANSCRIPT:
{self.conversation_history}
Respond ONLY with a valid JSON object in this exact format:
{{"score": 8.5, "reasoning": "A brief explanation of why."}}
"""
        try:
            response = local_llm.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": judge_prompt}],
                response_format={"type": "json_object"},
                temperature=0.0,  # deterministic judging
            )
            result = json.loads(response.choices[0].message.content)
            score = float(result.get("score", 0.0))
            reasoning = result.get("reasoning", "No reasoning provided.")
            # Clamp the score just in case the LLM goes rogue
            score = max(-5.0, min(10.0, score))
            return score, reasoning
        except Exception as e:
            # Fallback if the local LLM fails to generate valid JSON
            print(f"Judge Error: {e}")
            return -2.0, "Judge LLM failed to parse transcript."

    def _get_customer_reply(self, agent_text: str) -> str:
        """Uses local LLM to simulate the customer.

        Falls back to a canned reply on LLM failure so a transport/model
        error cannot crash step() (consistent with _evaluate_with_judge).
        """
        system_prompt = f"""You are a banking customer calling support.
Your secret intent is: {self.hidden_intent}.
Your mood is: {self.persona}.
RULES:
1. Keep it under 2 sentences.
2. Do NOT reveal your full intent immediately. Wait for the agent to probe.
3. Respond naturally to what the agent just said.
Conversation history:
{self.conversation_history}"""
        try:
            response = local_llm.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": agent_text},
                ],
                temperature=0.7,
                max_tokens=60,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            # Previously any LLM failure here propagated out of step().
            print(f"Customer LLM Error: {e}")
            return "Sorry, could you repeat that?"

    def state(self) -> State:
        """Return the current episode State (id and step count)."""
        return self._state