helpdesk_env / inference.py
Freakdivi's picture
Upload folder using huggingface_hub
9a013d0 verified
import json
import importlib
import os
import sys
import textwrap
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type, cast
from openai import OpenAI
ROOT = Path(__file__).resolve().parent
def _load_dotenv() -> None:
env_path = ROOT / ".env"
if not env_path.exists():
return
for raw_line in env_path.read_text(encoding="utf-8").splitlines():
line = raw_line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, value = line.split("=", 1)
os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))
_load_dotenv()
if TYPE_CHECKING:
from .models import Action
from .server.helpdesk_environment import HelpdeskEnv
def _import_local_modules() -> Tuple[Type["HelpdeskEnv"], Type["Action"], Any]:
if __package__ not in (None, ""):
from .models import Action, normalize_action
from .server.helpdesk_environment import HelpdeskEnv
return HelpdeskEnv, Action, normalize_action
package_parent = ROOT.parent
package_name = ROOT.name
if str(package_parent) not in sys.path:
sys.path.insert(0, str(package_parent))
helpdesk_environment = importlib.import_module(
f"{package_name}.server.helpdesk_environment"
)
models = importlib.import_module(f"{package_name}.models")
return helpdesk_environment.HelpdeskEnv, models.Action, models.normalize_action
HelpdeskEnv, Action, normalize_action = cast(
Tuple[Type["HelpdeskEnv"], Type["Action"], Any],
_import_local_modules(),
)
if __package__ not in (None, ""):
from .graders.score_utils import ensure_open_unit_interval
else:
from graders.score_utils import ensure_open_unit_interval
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "helpdesk-openenv")
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL") or os.getenv("MODEL_NAME") or "gpt-5"
API_KEY = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY")
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://freakdivi-helpdesk-env.hf.space")
HF_SPACE_TOKEN = os.getenv("HF_SPACE_TOKEN", "")
TASK_NAME = os.getenv("TASK_NAME", "all")
BENCHMARK = os.getenv("BENCHMARK", "helpdesk_env")
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "120"))
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.50"))
DISCOUNT_GAMMA = 0.9
KB_CANDIDATE_LIMIT = 6
REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "8"))
MAX_STEPS_BY_TASK = {
"easy": 1,
"medium": 3,
"hard": 8,
}
SUPPORTED_TASKS = ("easy", "medium", "hard")
CURRENT_SCORE_METHOD = "single"
SYSTEM_PROMPT_BASE = (
"You are a banking customer support agent for a UPI payments app. "
"Never ask for PIN, OTP, CVV, or full card details. "
"You must return exactly one JSON object with keys from: "
"action_type, category, faq_id, message. "
"Valid action_type values are exactly: classify, lookup_faq, ask_clarification, "
"reply, escalate, resolve_ticket."
)
def system_prompt_for_task(task_id: str) -> str:
if task_id == "easy":
return (
SYSTEM_PROMPT_BASE
+ " For easy tasks, classify the issue into exactly one category from "
"observation.available_categories."
)
if task_id == "medium":
return (
SYSTEM_PROMPT_BASE
+ " For medium tasks, choose lookup_faq with the best faq_id from "
"observation.knowledge_base, or use escalate when fraud or overdue review requires manual handling."
)
return (
SYSTEM_PROMPT_BASE
+ " For hard tasks, ask for clarification first, then retrieve the right FAQ, "
"then reply with safe guidance, and only resolve after the customer confirms the issue is fixed."
)
def build_user_prompt(task_id: str, observation_json: str, history: List[str]) -> str:
history_block = "\n".join(history[-4:]) if history else "None"
return textwrap.dedent(
f"""
Task: {task_id}
Observation JSON:
{observation_json}
Recent action history:
{history_block}
Return the next action as one JSON object only.
"""
).strip()
def _tokenize_text(text: str) -> List[str]:
cleaned = []
for raw in text.lower().replace("/", " ").replace("_", " ").split():
token = "".join(ch for ch in raw if ch.isalnum())
if len(token) >= 3:
cleaned.append(token)
return cleaned
def _compact_text(text: str, limit: int) -> str:
normalized = " ".join(text.split())
if len(normalized) <= limit:
return normalized
return normalized[: limit - 3].rstrip() + "..."
def _score_faq_candidate(entry: Dict[str, Any], query_terms: List[str]) -> int:
searchable_parts = [
str(entry.get("faq_id") or entry.get("id") or ""),
str(entry.get("category") or ""),
str(entry.get("title") or entry.get("question") or ""),
str(entry.get("content") or entry.get("answer") or ""),
" ".join(str(tag) for tag in entry.get("tags", []) if isinstance(tag, str)),
]
searchable_text = " ".join(searchable_parts).lower()
return sum(3 if term in searchable_text else 0 for term in query_terms)
def _candidate_faqs(observation: Any, history: List[str], limit: int = KB_CANDIDATE_LIMIT) -> List[Dict[str, Any]]:
query_terms = _tokenize_text(
" ".join(
[
observation.customer_message,
*[turn.get("content", "") for turn in observation.conversation_history[-4:]],
*history[-4:],
]
)
)
scored_entries: List[Tuple[int, int, Dict[str, Any]]] = []
for index, entry in enumerate(observation.knowledge_base):
score = _score_faq_candidate(entry, query_terms)
scored_entries.append((score, -index, entry))
ranked_entries = [
entry
for score, _neg_index, entry in sorted(scored_entries, reverse=True)
if score > 0
]
fallback_entries = [entry for _score, _neg_index, entry in sorted(scored_entries, reverse=True)]
selected = (ranked_entries or fallback_entries)[:limit]
compact_entries: List[Dict[str, Any]] = []
for entry in selected:
compact_entries.append(
{
"faq_id": entry.get("faq_id") or entry.get("id"),
"category": entry.get("category"),
"title": entry.get("title") or entry.get("question"),
"content": _compact_text(
str(entry.get("content") or entry.get("answer") or ""),
220,
),
"tags": entry.get("tags", [])[:5],
}
)
return compact_entries
def _serialize_observation(task_id: str, observation: Any, history: List[str]) -> str:
payload: Dict[str, Any] = {
"case_id": observation.case_id,
"task_id": task_id,
"turn_number": observation.turn_number,
"customer_message": observation.customer_message,
"conversation_history": observation.conversation_history[-4:],
"required_slots": observation.required_slots,
}
if task_id == "easy":
payload["available_categories"] = observation.available_categories
else:
payload["knowledge_base"] = _candidate_faqs(observation, history)
if task_id == "hard":
payload["clarification_received"] = observation.known_facts.get(
"clarification_received", False
)
payload["faq_retrieved"] = observation.known_facts.get("faq_retrieved", False)
payload["issue_resolved"] = observation.known_facts.get("issue_resolved", False)
payload["collected_slots"] = observation.known_facts.get("collected_slots", {})
return json.dumps(payload, separators=(",", ":"), ensure_ascii=True)
def log_start(task: str, env: str, model: str) -> None:
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
error_val = error if error else "null"
print(
f"[STEP] step={step} action={action} reward={reward:.2f} "
f"done={str(done).lower()} error={error_val}",
flush=True,
)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} "
f"score={score:.3f} rewards={rewards_str}",
flush=True,
)
def _extract_json_object(text: str) -> str:
text = text.strip()
if text.startswith("```"):
lines = text.split("\n")
if len(lines) >= 2 and lines[0].startswith("```"):
lines = lines[1:]
if lines and lines[-1].strip() == "```":
lines = lines[:-1]
text = "\n".join(lines).strip()
return text
_VALID_ACTIONS = frozenset(
{
"classify",
"lookup_faq",
"ask_clarification",
"reply",
"escalate",
"resolve_ticket",
}
)
ActionType = Literal[
"classify",
"lookup_faq",
"ask_clarification",
"reply",
"escalate",
"resolve_ticket",
]
def _normalize_action_type(raw: object) -> Optional[ActionType]:
if raw is None:
return None
value = str(raw).strip().lower().replace("-", "_")
return cast(ActionType, value) if value in _VALID_ACTIONS else None
def _fallback_action(task_id: str, turn_number: int) -> Dict[str, Any]:
# Fallback actions are only for genuine parse failures or runtime exceptions.
# They must stay conservative and must not act as a scoring strategy.
if task_id == "easy":
return {"action_type": "classify", "category": "payment_failure"}
if task_id == "medium":
return {"action_type": "escalate", "message": "Escalating for manual review."}
if turn_number == 0:
return {
"action_type": "ask_clarification",
"message": "Please share the UTR, amount, and exact issue.",
}
if turn_number == 1:
return {
"action_type": "escalate",
"message": "Unable to process request. Escalating for manual review.",
}
if turn_number in (2, 3):
return {
"action_type": "reply",
"message": "Please follow the safe steps in the app and confirm the result.",
}
return {"action_type": "resolve_ticket"}
def parse_action(response_text: str, task_id: str, turn_number: int) -> Dict[str, Any]:
text = _extract_json_object(response_text)
try:
payload = json.loads(text)
except json.JSONDecodeError:
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
try:
payload = json.loads(text[start : end + 1])
except json.JSONDecodeError:
payload = {}
else:
payload = {}
action_type = _normalize_action_type(payload.get("action_type"))
if not action_type:
return _fallback_action(task_id, turn_number)
try:
return {
"action_type": action_type,
"category": payload.get("category"),
"faq_id": payload.get("faq_id"),
"message": payload.get("message"),
}
except Exception:
return _fallback_action(task_id, turn_number)
def get_model_action(
client: OpenAI,
task_id: str,
observation_json: str,
history: List[str],
turn_number: int,
) -> Dict[str, Any]:
user_prompt = build_user_prompt(task_id, observation_json, history)
completion = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": system_prompt_for_task(task_id)},
{"role": "user", "content": user_prompt},
],
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
response_format={"type": "json_object"},
)
text = completion.choices[0].message.content or ""
return parse_action(text, task_id, turn_number)
def _resolve_requested_tasks(task_name: str) -> List[str]:
normalized = task_name.strip().lower()
if not normalized or normalized == "all":
return list(SUPPORTED_TASKS)
requested = [task.strip().lower() for task in task_name.split(",") if task.strip()]
invalid = [task for task in requested if task not in SUPPORTED_TASKS]
if invalid:
raise ValueError(
f"Unsupported TASK_NAME value(s): {', '.join(invalid)}. "
f"Expected one of: {', '.join(SUPPORTED_TASKS)} or 'all'."
)
return requested
def _run_task(client: OpenAI, task_id: str) -> None:
global CURRENT_SCORE_METHOD
env = HelpdeskEnv()
history: List[str] = []
rewards: List[float] = []
steps_taken = 0
score = ensure_open_unit_interval(0.0)
success = False
CURRENT_SCORE_METHOD = "single"
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
try:
observation = env.reset(task_id)
done = False
for step in range(1, MAX_STEPS_BY_TASK.get(task_id, 3) + 1):
if done:
break
error: Optional[str] = None
try:
raw_action = get_model_action(
client=client,
task_id=task_id,
observation_json=_serialize_observation(task_id, observation, history),
history=history,
turn_number=observation.turn_number,
)
action = normalize_action(raw_action)
observation, reward, done, _info = env.step(action)
reward_value = ensure_open_unit_interval(reward.value)
except Exception as exc:
raw_action = _fallback_action(task_id, observation.turn_number)
action = normalize_action(raw_action)
reward_value = ensure_open_unit_interval(0.0)
done = True
error = str(exc)
action_str = json.dumps(action.model_dump(exclude_none=True), separators=(",", ":"))
log_step(
step=step,
action=action_str,
reward=reward_value,
done=done,
error=error,
)
rewards.append(reward_value)
steps_taken = step
history.append(f"step={step} action={action_str} reward={reward_value:.2f}")
if task_id == "easy":
CURRENT_SCORE_METHOD = "single"
raw_score = rewards[-1] if rewards else 0.0
elif task_id == "medium":
CURRENT_SCORE_METHOD = "terminal"
raw_score = rewards[-1] if rewards else 0.0
else:
CURRENT_SCORE_METHOD = "discounted"
discount_weights = [DISCOUNT_GAMMA**t for t in range(len(rewards))]
discounted_sum = sum(
reward * weight for reward, weight in zip(rewards, discount_weights)
)
normalizer = sum(discount_weights)
raw_score = (discounted_sum / normalizer) if normalizer else 0.0
score = ensure_open_unit_interval(raw_score)
success = score >= SUCCESS_SCORE_THRESHOLD
finally:
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
def main() -> None:
if not API_KEY:
raise RuntimeError(
"Set API_KEY, OPENAI_API_KEY, or GROQ_API_KEY before running inference.py"
)
client = OpenAI(
base_url=API_BASE_URL,
api_key=API_KEY,
timeout=REQUEST_TIMEOUT_SECONDS,
max_retries=0,
)
for task_id in _resolve_requested_tasks(TASK_NAME):
_run_task(client, task_id)
if __name__ == "__main__":
main()