# Source: RL-IVR/server/customer_env.py
# Uploaded by hrajgarhia943 via huggingface_hub (commit 7d80981, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import csv
import json
import math
import os
import random
from uuid import uuid4

from openai import OpenAI
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

# Use explicit relative or local imports
from models import CustomerAction, CustomerObservation
# Client for the local OpenAI-compatible server used both to simulate the
# customer and to judge transcripts (presumably Ollama, given port 11434).
# Each setting can be overridden through an environment variable so the same
# code runs where the server is not on localhost; the defaults are the
# previous hard-coded values.
local_llm = OpenAI(
    base_url=os.environ.get("LOCAL_LLM_BASE_URL", "http://localhost:11434/v1"),
    api_key=os.environ.get("LOCAL_LLM_API_KEY", "local-dev"),
)
MODEL_NAME = os.environ.get("LOCAL_LLM_MODEL", "llama3")
class CustomerEnvironment(Environment):
    """Simulated banking-support phone call, framed as a POMDP.

    Each episode draws one scenario (hidden intent, persona, opening line).
    The agent only sees the conversation, not the intent. A local LLM plays
    the customer, and when the episode ends the same LLM acts as a judge;
    its score is added to the final step's reward.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    # Episode length cap; reaching it ends the episode and triggers the judge.
    MAX_STEPS: int = 15

    # Judge scores are clamped to this range (matches the rubric in the prompt).
    JUDGE_MIN_SCORE: float = -5.0
    JUDGE_MAX_SCORE: float = 10.0
    # Score used when the judge call fails or returns something unusable.
    JUDGE_FAILURE_SCORE: float = -2.0

    # Returned in place of a customer reply when the simulator LLM fails.
    FALLBACK_CUSTOMER_REPLY: str = "Sorry, could you repeat that?"

    def __init__(self, scenarios_path: str = "basic_scenarios.csv"):
        """Initialize the Customer POMDP environment.

        Args:
            scenarios_path: CSV file with ``intent``, ``persona`` and
                ``starting_utterance`` columns. A relative path is resolved
                against the current working directory. If the file cannot be
                read or has no data rows, a single neutral fallback scenario
                is used so ``reset()`` always has something to choose from.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count = 0
        # True once an episode has ended (end_call or step cap); cleared by reset().
        self._episode_done = False
        # Tools already rewarded this episode, so repeats cannot farm reward.
        self._tools_used: set[str] = set()
        self.hidden_intent = ""
        self.persona = ""
        self.conversation_history = ""
        self.scenarios = self._load_scenarios(scenarios_path)

    @staticmethod
    def _load_scenarios(path: str) -> list[dict]:
        """Load scenario rows from a CSV file; the result is never empty.

        Missing or blank ``intent``/``persona``/``starting_utterance`` cells
        are filled from the built-in default scenario. If the file cannot be
        read, or contains no data rows, the default scenario is used alone.
        """
        default_scenario = {
            "intent": "unknown",
            "persona": "neutral",
            "starting_utterance": "I need help.",
        }
        scenarios: list[dict] = []
        try:
            # "utf-8-sig" strips a leading BOM (common in spreadsheet exports).
            # Otherwise the BOM is glued onto the first header name and
            # scenario["intent"] raises KeyError. newline="" is what the csv
            # module expects so quoted embedded newlines parse correctly.
            with open(path, mode="r", encoding="utf-8-sig", newline="") as f:
                for row in csv.DictReader(f):
                    scenario = dict(row)
                    for key, default in default_scenario.items():
                        # Short rows yield None for missing columns.
                        if not scenario.get(key):
                            scenario[key] = default
                    scenarios.append(scenario)
        except (OSError, UnicodeDecodeError, csv.Error) as e:
            print(f"Warning: Could not load {path}. {e}")
        if not scenarios:
            # Covers both a failed read and a header-only/empty file; the
            # latter used to leave the list empty and crash reset().
            print(f"Warning: No scenarios loaded from {path}; using default scenario.")
            scenarios.append(default_scenario)
        return scenarios

    def reset(self) -> CustomerObservation:
        """Reset the environment, pick a new hidden intent and persona.

        Returns:
            The opening observation: the customer's first utterance, the
            transcript so far, ``done=False`` and zero reward.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count += 1
        self._episode_done = False
        self._tools_used = set()
        scenario = random.choice(self.scenarios)
        self.hidden_intent = scenario["intent"]
        self.persona = scenario["persona"]
        start_msg = scenario["starting_utterance"]
        self.conversation_history = f"System: Call connected.\nCustomer: {start_msg}"
        return CustomerObservation(
            customer_reply=start_msg,
            tool_response=None,
            conversation_history=self.conversation_history,
            done=False,
            reward=0.0,
            metadata={"step": self._state.step_count},
        )

    def step(self, action: CustomerAction) -> CustomerObservation:
        """Apply one agent action and return the resulting observation.

        Supported ``action.action_type`` values:

        * ``"tool_call"``: ``action.content`` names a tool. Only the mocked
          ``lookup_account`` exists. Its first use in an episode earns +0.5
          and later uses earn 0. Unknown tools cost -0.5.
        * ``"speak"``: ``action.content`` is said to the customer, whose
          LLM-generated reply is returned. Each turn costs -0.1.
        * ``"end_call"``: ends the episode.

        Any other ``action_type`` gets an error in ``tool_response`` and a
        -0.5 penalty.

        The episode also ends once ``MAX_STEPS`` steps have been taken. On the
        ending step the judge scores the whole transcript, the score is added
        to the reward, and ``metadata`` also carries ``hidden_intent`` and
        ``judge_reasoning``. Calling ``step`` after the episode has ended
        returns a terminal observation with zero reward and does not re-run
        the judge.
        """
        if self._episode_done:
            # Without this guard, every step past the cap re-ran the judge
            # and added its score again.
            return CustomerObservation(
                customer_reply=None,
                tool_response="Error: Episode has ended. Call reset() to start a new one.",
                conversation_history=self.conversation_history,
                done=True,
                reward=0.0,
                metadata={"step": self._state.step_count},
            )

        self._state.step_count += 1
        step_reward = 0.0
        done = False
        tool_response = None
        customer_reply = None

        if action.action_type == "tool_call":
            tool_response, step_reward = self._run_tool(action.content)
        elif action.action_type == "speak":
            self.conversation_history += f"\nAgent: {action.content}"
            # Call made to the LLM
            customer_reply = self._get_customer_reply(action.content)
            self.conversation_history += f"\nCustomer: {customer_reply}"
            step_reward -= 0.1  # Small penalty per turn to encourage efficiency
        elif action.action_type == "end_call":
            done = True
        else:
            # Previously ignored silently at zero cost, which made invalid
            # actions cheaper than speaking.
            tool_response = f"Error: Unknown action type '{action.action_type}'."
            step_reward -= 0.5

        if self._state.step_count >= self.MAX_STEPS:
            done = True

        metadata = {"step": self._state.step_count}
        # THE JUDGE LLM EVALUATION
        if done:
            self._episode_done = True
            final_score, reasoning = self._evaluate_with_judge()
            step_reward += final_score
            metadata["hidden_intent"] = self.hidden_intent
            metadata["judge_reasoning"] = reasoning

        return CustomerObservation(
            customer_reply=customer_reply,
            tool_response=tool_response,
            conversation_history=self.conversation_history,
            done=done,
            reward=step_reward,
            metadata=metadata,
        )

    def _run_tool(self, tool_name: str) -> tuple[str, float]:
        """Run a (mocked) tool, append it to the transcript, return (response, reward)."""
        # Mocking the database lookup for now
        if tool_name == "lookup_account":
            tool_response = "{'status': 'verified', 'balance': '$500'}"
            # Only the first use is rewarded; repeated lookups used to earn
            # +0.5 each, an easy reward exploit.
            reward = 0.0 if tool_name in self._tools_used else 0.5
            self._tools_used.add(tool_name)
        else:
            tool_response = f"Error: Tool '{tool_name}' not found."
            reward = -0.5
        self.conversation_history += f"\nAgent [Action]: Used {tool_name}"
        self.conversation_history += f"\nSystem: {tool_response}"
        return tool_response, reward

    def _evaluate_with_judge(self) -> tuple[float, str]:
        """Use the local LLM as a judge to score the final transcript.

        Returns:
            ``(score, reasoning)``, with the score clamped to
            ``[JUDGE_MIN_SCORE, JUDGE_MAX_SCORE]``. If the call fails, the
            reply is not a JSON object with a numeric score, or the score is
            not finite, the result is ``JUDGE_FAILURE_SCORE`` with the reason
            "Judge LLM failed to parse transcript.".
        """
        # NOTE(review): the transcript includes agent-authored text, so a
        # policy under training could try to address the judge directly
        # (e.g. "rate this 10"). Consider delimiting or escaping it.
        judge_prompt = (
            "You are an expert QA Judge for a banking call center.\n"
            "Review the transcript and score the Agent's performance from -5.0 to +10.0.\n"
            f"TRUE CUSTOMER INTENT: {self.hidden_intent}\n"
            "SCORING RUBRIC:\n"
            "- +10.0: Perfect. Intent captured, correct tools used, issue resolved efficiently.\n"
            "- +5.0: Okay. Found the intent but took too many turns or was awkward.\n"
            "- 0.0: Neutral. Didn't solve the issue but didn't hallucinate.\n"
            "- -5.0: Failure. Missed the intent, hallucinated tools, or was rude.\n"
            "TRANSCRIPT:\n"
            f"{self.conversation_history}\n"
            "Respond ONLY with a valid JSON object in this exact format:\n"
            '{"score": 8.5, "reasoning": "A brief explanation of why."}\n'
        )
        try:
            response = local_llm.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": judge_prompt}],
                response_format={"type": "json_object"},
                temperature=0.0,
            )
            result = json.loads(response.choices[0].message.content)
            score = float(result.get("score", 0.0))
            reasoning = str(result.get("reasoning") or "No reasoning provided.")
        except Exception as e:
            # Fallback if the local LLM fails or does not produce valid JSON
            # (covers network errors, None content, non-dict JSON, bad score).
            print(f"Judge Error: {e}")
            return self.JUDGE_FAILURE_SCORE, "Judge LLM failed to parse transcript."
        # json.loads accepts NaN/Infinity, and NaN slips through max/min
        # clamping (min(10.0, nan) == 10.0), so a garbled score used to be
        # awarded the maximum.
        if not math.isfinite(score):
            print(f"Judge Error: non-finite score {score!r}")
            return self.JUDGE_FAILURE_SCORE, "Judge LLM failed to parse transcript."
        # Clamp the score just in case the LLM goes rogue
        score = max(self.JUDGE_MIN_SCORE, min(self.JUDGE_MAX_SCORE, score))
        return score, reasoning

    def _get_customer_reply(self, agent_text: str) -> str:
        """Use the local LLM to generate the simulated customer's reply.

        The system prompt carries the hidden intent, persona and transcript.
        When called from ``step`` the transcript already ends with the
        agent's latest line, which is also sent as the user message. Returns
        ``FALLBACK_CUSTOMER_REPLY`` if the call fails or produces no text.
        """
        system_prompt = (
            "You are a banking customer calling support.\n"
            f"Your secret intent is: {self.hidden_intent}.\n"
            f"Your mood is: {self.persona}.\n"
            "RULES:\n"
            "1. Keep it under 2 sentences.\n"
            "2. Do NOT reveal your full intent immediately. Wait for the agent to probe.\n"
            "3. Respond naturally to what the agent just said.\n"
            "Conversation history:\n"
            f"{self.conversation_history}"
        )
        try:
            response = local_llm.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": agent_text},
                ],
                temperature=0.7,
                max_tokens=60,
            )
            reply = response.choices[0].message.content
        except Exception as e:
            # Degrade instead of crashing the step, matching the judge's handling.
            print(f"Customer LLM Error: {e}")
            return self.FALLBACK_CUSTOMER_REPLY
        # content can be None; calling .strip() on it used to raise AttributeError.
        reply = (reply or "").strip()
        return reply or self.FALLBACK_CUSTOMER_REPLY

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state