Spaces:
Sleeping
Sleeping
import importlib
import json
import os
import sys
import textwrap
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    cast,
    get_args,
)

from openai import OpenAI
ROOT = Path(__file__).resolve().parent


def _strip_matching_quotes(value: str) -> str:
    """Remove one pair of identical surrounding quotes (``"…"`` or ``'…'``)."""
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
        return value[1:-1]
    return value


def _load_dotenv(env_path: Optional[Path] = None) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` dotenv file.

    Variables that are already set in the environment are never overwritten
    (``setdefault``). Blank lines, ``#`` comments and lines without ``=`` are
    skipped. An optional leading ``export`` keyword is accepted, and a single
    pair of matching surrounding quotes is removed from the value.

    Args:
        env_path: File to read; defaults to ``ROOT / ".env"``. A missing file
            is silently ignored.
    """
    path = ROOT / ".env" if env_path is None else Path(env_path)
    if not path.exists():
        return
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        # Shell-style files often write `export KEY=value`.
        if key.startswith("export "):
            key = key[len("export "):].strip()
        # An empty name (line like "=value") would make os.environ raise ValueError.
        if not key:
            continue
        os.environ.setdefault(key, _strip_matching_quotes(value.strip()))


_load_dotenv()
| if TYPE_CHECKING: | |
| from .models import Action | |
| from .server.helpdesk_environment import HelpdeskEnv | |
def _import_local_modules() -> Tuple[Type["HelpdeskEnv"], Type["Action"], Any]:
    """Load ``(HelpdeskEnv, Action, normalize_action)`` from the sibling modules.

    Inside a package, plain relative imports are used. When the file is run as
    a script (no ``__package__``), the directory containing ``ROOT`` is put on
    ``sys.path`` and the modules are imported through the package named after
    ``ROOT``'s directory.
    """
    running_as_package = __package__ not in (None, "")
    if running_as_package:
        from .models import Action, normalize_action
        from .server.helpdesk_environment import HelpdeskEnv

        return HelpdeskEnv, Action, normalize_action

    parent_dir = str(ROOT.parent)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)
    pkg = ROOT.name
    env_module = importlib.import_module(f"{pkg}.server.helpdesk_environment")
    models_module = importlib.import_module(f"{pkg}.models")
    return (
        env_module.HelpdeskEnv,
        models_module.Action,
        models_module.normalize_action,
    )


# Bind the resolved classes/helpers at module level; the cast only informs
# static type checkers and has no runtime effect.
HelpdeskEnv, Action, normalize_action = cast(
    Tuple[Type["HelpdeskEnv"], Type["Action"], Any],
    _import_local_modules(),
)
| if __package__ not in (None, ""): | |
| from .graders.score_utils import ensure_open_unit_interval | |
| else: | |
| from graders.score_utils import ensure_open_unit_interval | |
# ---------------------------------------------------------------------------
# Runtime configuration. Each value can be overridden through the environment
# (including variables loaded from the .env file earlier in this module).
# ---------------------------------------------------------------------------
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "helpdesk-openenv")
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
# First non-empty of MODEL / MODEL_NAME wins; otherwise "gpt-5".
MODEL_NAME = os.getenv("MODEL") or os.getenv("MODEL_NAME") or "gpt-5"
# First non-empty key wins; may end up None or "", which main() rejects.
API_KEY = (
    os.getenv("API_KEY")
    or os.getenv("OPENAI_API_KEY")
    or os.getenv("GROQ_API_KEY")
)
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://freakdivi-helpdesk-env.hf.space")
HF_SPACE_TOKEN = os.getenv("HF_SPACE_TOKEN", "")
TASK_NAME = os.getenv("TASK_NAME", "all")
BENCHMARK = os.getenv("BENCHMARK", "helpdesk_env")

# Sampling and request limits.
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "120"))
REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "8"))

# Scoring and retrieval.
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.50"))
DISCOUNT_GAMMA = 0.9
KB_CANDIDATE_LIMIT = 6

# Per-task step budget; its key order also defines SUPPORTED_TASKS.
MAX_STEPS_BY_TASK = {
    "easy": 1,
    "medium": 3,
    "hard": 8,
}
SUPPORTED_TASKS = tuple(MAX_STEPS_BY_TASK)

# Scoring method of the task currently being run (rewritten by _run_task).
CURRENT_SCORE_METHOD = "single"

# Shared system prompt; task-specific guidance is appended per task.
SYSTEM_PROMPT_BASE = " ".join(
    (
        "You are a banking customer support agent for a UPI payments app.",
        "Never ask for PIN, OTP, CVV, or full card details.",
        "You must return exactly one JSON object with keys from: "
        "action_type, category, faq_id, message.",
        "Valid action_type values are exactly: classify, lookup_faq, "
        "ask_clarification, reply, escalate, resolve_ticket.",
    )
)
def system_prompt_for_task(task_id: str) -> str:
    """Return the system prompt: the shared base text plus task guidance.

    Any task id other than "easy" or "medium" receives the hard-task guidance.
    """
    hard_guidance = (
        " For hard tasks, ask for clarification first, then retrieve the right FAQ, "
        "then reply with safe guidance, and only resolve after the customer confirms the issue is fixed."
    )
    guidance_by_task = {
        "easy": (
            " For easy tasks, classify the issue into exactly one category from "
            "observation.available_categories."
        ),
        "medium": (
            " For medium tasks, choose lookup_faq with the best faq_id from "
            "observation.knowledge_base, or use escalate when fraud or overdue review requires manual handling."
        ),
    }
    return SYSTEM_PROMPT_BASE + guidance_by_task.get(task_id, hard_guidance)
def build_user_prompt(task_id: str, observation_json: str, history: List[str]) -> str:
    """Render the per-step user message sent to the model.

    Only the last four history entries are included; "None" is shown when
    there is no history yet.

    The prompt is assembled line by line instead of ``textwrap.dedent`` over an
    interpolated f-string. ``dedent`` computes the common indent *after*
    interpolation, so a multi-line history block (whose continuation lines start
    at column 0) disabled dedenting and leaked the template indentation into
    the prompt.
    """
    history_block = "\n".join(history[-4:]) if history else "None"
    prompt_lines = [
        f"Task: {task_id}",
        "Observation JSON:",
        observation_json,
        "Recent action history:",
        history_block,
        "Return the next action as one JSON object only.",
    ]
    return "\n".join(prompt_lines)
def _tokenize_text(text: str) -> List[str]:
    """Split *text* into lowercase alphanumeric tokens at least 3 chars long.

    ``/`` and ``_`` act as extra word separators. Every other
    non-alphanumeric character is dropped from inside each whitespace-delimited
    word before the length check.
    """
    spaced = text.lower().replace("/", " ").replace("_", " ")
    alnum_words = ("".join(filter(str.isalnum, word)) for word in spaced.split())
    return [word for word in alnum_words if len(word) >= 3]
def _compact_text(text: str, limit: int) -> str:
    """Collapse whitespace in *text* and truncate it to at most *limit* chars.

    When truncation is needed and the limit leaves room, the result ends with
    ``"..."`` and its total length never exceeds *limit*. If *limit* is too
    small to hold the ellipsis, the text is hard-cut to *limit* characters
    (empty for ``limit <= 0``). Previously a negative slice bound made the
    result overshoot the limit.
    """
    ellipsis = "..."
    normalized = " ".join(text.split())
    if len(normalized) <= limit:
        return normalized
    if limit < len(ellipsis):
        return normalized[: max(limit, 0)]
    return normalized[: limit - len(ellipsis)].rstrip() + ellipsis
def _score_faq_candidate(entry: Dict[str, Any], query_terms: List[str]) -> int:
    """Score how well a knowledge-base entry matches the query terms.

    Each query term that occurs as a substring of the entry's searchable text
    adds 3 points. The searchable text combines the id, category,
    title/question, content/answer and string tags, lowercased. Terms repeated
    in *query_terms* count every time.

    ``tags`` may be missing or ``None`` (treated as no tags) or a single
    string (treated as one tag). Before, ``None`` raised ``TypeError`` and a
    string was split into characters.
    """
    raw_tags = entry.get("tags")
    if isinstance(raw_tags, str):
        tags: List[Any] = [raw_tags]
    elif isinstance(raw_tags, (list, tuple)):
        tags = list(raw_tags)
    else:
        tags = []
    searchable_text = " ".join(
        [
            str(entry.get("faq_id") or entry.get("id") or ""),
            str(entry.get("category") or ""),
            str(entry.get("title") or entry.get("question") or ""),
            str(entry.get("content") or entry.get("answer") or ""),
            " ".join(tag for tag in tags if isinstance(tag, str)),
        ]
    ).lower()
    return 3 * sum(1 for term in query_terms if term in searchable_text)
def _compact_faq_entry(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Project a knowledge-base entry onto the few fields shown to the model.

    Accepts either ``faq_id``/``title``/``content`` or ``id``/``question``/
    ``answer`` naming. ``content`` is trimmed to 220 characters and at most five
    tags are kept. Missing or ``None`` tags become ``[]``; a bare string tag is
    wrapped in a list.
    """
    raw_tags = entry.get("tags")
    if isinstance(raw_tags, str):
        tags: List[Any] = [raw_tags]
    elif isinstance(raw_tags, (list, tuple)):
        tags = list(raw_tags[:5])
    else:
        tags = []
    return {
        "faq_id": entry.get("faq_id") or entry.get("id"),
        "category": entry.get("category"),
        "title": entry.get("title") or entry.get("question"),
        "content": _compact_text(
            str(entry.get("content") or entry.get("answer") or ""),
            220,
        ),
        "tags": tags,
    }


def _candidate_faqs(
    observation: Any, history: List[str], limit: int = KB_CANDIDATE_LIMIT
) -> List[Dict[str, Any]]:
    """Pick the knowledge-base entries most relevant to the current conversation.

    Query terms come from three sources: the customer message, the ``content``
    of the last four conversation turns, and the last four agent history lines.
    Entries are ranked by keyword score, highest first; ties keep their
    original position in ``observation.knowledge_base``. Only entries with a
    positive score are returned. If none match, the first *limit* entries in
    KB order are used, so the model always sees some candidates.

    Args:
        observation: Environment observation exposing ``customer_message``,
            ``conversation_history`` (list of dicts) and ``knowledge_base``.
        history: Agent-side history lines built by the caller.
        limit: Maximum number of entries to return (non-positive -> none).

    Returns:
        Compact entry dicts as produced by ``_compact_faq_entry``.
    """
    # ``or ""`` guards against None message/content values.
    turn_texts = [
        str(turn.get("content") or "")
        for turn in observation.conversation_history[-4:]
    ]
    query_terms = _tokenize_text(
        " ".join(
            [str(observation.customer_message or ""), *turn_texts, *history[-4:]]
        )
    )

    # (score, -index) is unique per entry, so a single sort gives a total
    # order: best score first, earlier KB position first among equal scores.
    ranked = sorted(
        (
            (_score_faq_candidate(entry, query_terms), -index, entry)
            for index, entry in enumerate(observation.knowledge_base)
        ),
        key=lambda item: (item[0], item[1]),
        reverse=True,
    )
    matching = [entry for score, _neg_index, entry in ranked if score > 0]
    pool = matching or [entry for _score, _neg_index, entry in ranked]
    return [_compact_faq_entry(entry) for entry in pool[: max(limit, 0)]]
def _serialize_observation(task_id: str, observation: Any, history: List[str]) -> str:
    """Encode the task-relevant part of *observation* as compact ASCII JSON.

    Every task gets the case id, turn number, customer message, the last four
    conversation turns and the required slots. Easy tasks add the available
    categories. Other tasks add ranked FAQ candidates instead. Hard tasks also
    expose progress flags and collected slots from ``known_facts``.
    """
    payload: Dict[str, Any] = {}
    payload["case_id"] = observation.case_id
    payload["task_id"] = task_id
    payload["turn_number"] = observation.turn_number
    payload["customer_message"] = observation.customer_message
    payload["conversation_history"] = observation.conversation_history[-4:]
    payload["required_slots"] = observation.required_slots

    if task_id != "easy":
        payload["knowledge_base"] = _candidate_faqs(observation, history)
    else:
        payload["available_categories"] = observation.available_categories

    if task_id == "hard":
        facts = observation.known_facts
        progress_defaults = (
            ("clarification_received", False),
            ("faq_retrieved", False),
            ("issue_resolved", False),
            ("collected_slots", {}),
        )
        for fact_name, fallback in progress_defaults:
            payload[fact_name] = facts.get(fact_name, fallback)

    return json.dumps(payload, separators=(",", ":"), ensure_ascii=True)
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens a task's log (flushed immediately)."""
    fields = f"task={task} env={env} model={model}"
    print("[START] " + fields, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line for a single environment step.

    ``reward`` is shown with two decimals and ``done`` as ``true``/``false``.
    ``error`` is printed as ``null`` when absent. Otherwise its whitespace
    (including newlines from multi-line exception messages) is collapsed to
    single spaces, so each step always produces exactly one log line.
    """
    cleaned_error = " ".join(error.split()) if error else ""
    error_val = cleaned_error or "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` summary line for a task.

    ``score`` uses three decimals. Per-step rewards use two decimals and are
    comma-separated; the list is empty when no step ran.
    """
    summary = " ".join(
        [
            "[END]",
            f"success={str(success).lower()}",
            f"steps={steps}",
            f"score={score:.3f}",
            "rewards=" + ",".join(map("{:.2f}".format, rewards)),
        ]
    )
    print(summary, flush=True)
def _extract_json_object(text: str) -> str:
    """Strip surrounding whitespace and a Markdown code fence, if present.

    For fenced input, the opening fence line (e.g. ```` ```json ````) is
    dropped when there is more than one line. A trailing line consisting only
    of ```` ``` ```` is also dropped. Single-line fences are returned unchanged
    for the caller's ``{...}`` fallback.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    body = stripped.split("\n")
    if len(body) >= 2:
        # The first line is the opening fence (it starts with ``` by construction).
        body = body[1:]
    if body and body[-1].strip() == "```":
        body.pop()
    return "\n".join(body).strip()
# Closed set of action names the environment accepts. ``ActionType`` is the
# single source of truth; ``_VALID_ACTIONS`` is derived from it so the runtime
# membership check and the static type cannot drift apart.
ActionType = Literal[
    "classify",
    "lookup_faq",
    "ask_clarification",
    "reply",
    "escalate",
    "resolve_ticket",
]
_VALID_ACTIONS = frozenset(get_args(ActionType))
def _normalize_action_type(raw: object) -> Optional[ActionType]:
    """Map a raw ``action_type`` value to a canonical action name, or ``None``.

    Matching is case-insensitive, ignores surrounding whitespace and treats
    ``-`` as ``_`` (so ``" Resolve-Ticket "`` becomes ``"resolve_ticket"``).
    ``None`` and unknown names yield ``None``.
    """
    if raw is not None:
        candidate = str(raw).strip().lower().replace("-", "_")
        if candidate in _VALID_ACTIONS:
            return cast(ActionType, candidate)
    return None
def _fallback_action(task_id: str, turn_number: int) -> Dict[str, Any]:
    """Return a conservative action for parse failures or runtime exceptions.

    Fallbacks exist only for genuine failures; they must stay conservative and
    must not act as a scoring strategy.

    - easy: classify as ``payment_failure``.
    - medium: escalate for manual review.
    - hard (and any other task id), by ``turn_number``: 0 asks for details,
      1 escalates, 2-3 send a generic safe reply, later turns resolve.

    A new dict is built on every call, so callers may mutate the result.
    """
    if task_id == "easy":
        return {"action_type": "classify", "category": "payment_failure"}
    if task_id == "medium":
        return {"action_type": "escalate", "message": "Escalating for manual review."}

    generic_reply = {
        "action_type": "reply",
        "message": "Please follow the safe steps in the app and confirm the result.",
    }
    hard_script: Dict[int, Dict[str, Any]] = {
        0: {
            "action_type": "ask_clarification",
            "message": "Please share the UTR, amount, and exact issue.",
        },
        1: {
            "action_type": "escalate",
            "message": "Unable to process request. Escalating for manual review.",
        },
        2: generic_reply,
        3: generic_reply,
    }
    return hard_script.get(turn_number, {"action_type": "resolve_ticket"})
def parse_action(response_text: str, task_id: str, turn_number: int) -> Dict[str, Any]:
    """Turn raw model output into an action dict.

    The text is first stripped of code fences and parsed as JSON. If that
    fails, the outermost ``{...}`` span is tried instead, because models
    sometimes wrap the object in prose. When no JSON *object* with a
    recognised ``action_type`` can be recovered, the conservative
    ``_fallback_action`` for the task/turn is returned.

    Returns:
        ``{"action_type", "category", "faq_id", "message"}``. Absent fields
        are ``None``.
    """
    text = _extract_json_object(response_text)
    payload: Any
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        payload = {}
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            try:
                payload = json.loads(text[start : end + 1])
            except json.JSONDecodeError:
                payload = {}

    # Valid JSON that is not an object (list, string, number, null) has no
    # action_type. Treat it like unparseable output instead of crashing on .get().
    if not isinstance(payload, dict):
        payload = {}

    action_type = _normalize_action_type(payload.get("action_type"))
    if not action_type:
        return _fallback_action(task_id, turn_number)
    return {
        "action_type": action_type,
        "category": payload.get("category"),
        "faq_id": payload.get("faq_id"),
        "message": payload.get("message"),
    }
def get_model_action(
    client: OpenAI,
    task_id: str,
    observation_json: str,
    history: List[str],
    turn_number: int,
) -> Dict[str, Any]:
    """Ask the chat model for the next action and parse its reply.

    Sends the task's system prompt plus the rendered user prompt in JSON mode
    and hands the first choice's text (``""`` when empty) to ``parse_action``.
    Client errors (timeouts, HTTP failures) and an empty ``choices`` list
    propagate to the caller.
    """
    user_prompt = build_user_prompt(task_id, observation_json, history)
    request_messages = [
        {"role": "system", "content": system_prompt_for_task(task_id)},
        {"role": "user", "content": user_prompt},
    ]
    # NOTE(review): OpenAI reasoning models (e.g. the default "gpt-5") reportedly
    # reject `max_tokens` in favour of `max_completion_tokens` and only accept the
    # default temperature. Confirm against the configured API_BASE_URL/MODEL_NAME.
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=request_messages,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        response_format={"type": "json_object"},
    )
    first_choice = completion.choices[0]
    reply_text = first_choice.message.content or ""
    return parse_action(reply_text, task_id, turn_number)
def _resolve_requested_tasks(task_name: str) -> List[str]:
    """Expand a ``TASK_NAME`` setting into the ordered list of tasks to run.

    Blank or ``"all"`` (case-insensitive) selects every supported task.
    Otherwise the value is a comma-separated list. Entries are trimmed and
    lowercased, empty entries are ignored, and order and duplicates are kept.

    Raises:
        ValueError: if any entry is not a supported task.
    """
    if task_name.strip().lower() in ("", "all"):
        return list(SUPPORTED_TASKS)
    chosen = [part.strip().lower() for part in task_name.split(",") if part.strip()]
    unknown = [name for name in chosen if name not in SUPPORTED_TASKS]
    if not unknown:
        return chosen
    raise ValueError(
        f"Unsupported TASK_NAME value(s): {', '.join(unknown)}. "
        f"Expected one of: {', '.join(SUPPORTED_TASKS)} or 'all'."
    )
def _score_rewards(task_id: str, rewards: List[float]) -> Tuple[str, float]:
    """Return ``(score_method, raw_score)`` for a finished episode.

    - ``"easy"``  -> ``"single"``: the last reward.
    - ``"medium"`` -> ``"terminal"``: the last reward.
    - anything else -> ``"discounted"``: rewards averaged with weights
      ``DISCOUNT_GAMMA**t``, normalised by the weight sum.

    An empty reward list scores ``0.0``.
    """
    if task_id == "easy":
        return "single", (rewards[-1] if rewards else 0.0)
    if task_id == "medium":
        return "terminal", (rewards[-1] if rewards else 0.0)
    weights = [DISCOUNT_GAMMA**t for t in range(len(rewards))]
    normalizer = sum(weights)
    if not normalizer:
        return "discounted", 0.0
    weighted_sum = sum(reward * weight for reward, weight in zip(rewards, weights))
    return "discounted", weighted_sum / normalizer


def _run_task(client: OpenAI, task_id: str) -> None:
    """Run one episode of *task_id* and log it as ``[START]``/``[STEP]``/``[END]``.

    A fresh ``HelpdeskEnv`` is reset for the task. The model is asked for an
    action each step, up to ``MAX_STEPS_BY_TASK[task_id]`` (default 3), and the
    loop stops early once the env reports ``done``. Any exception while
    choosing or applying an action ends the episode: the task's fallback
    action is logged (but not sent to the env) together with the error.

    Side effects: rewrites the module-level ``CURRENT_SCORE_METHOD``. The
    ``[END]`` line is always printed (via ``finally``), even if ``env.reset``
    raises; that exception still propagates.
    """
    global CURRENT_SCORE_METHOD
    env = HelpdeskEnv()
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = ensure_open_unit_interval(0.0)
    success = False
    CURRENT_SCORE_METHOD = "single"
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        observation = env.reset(task_id)
        done = False
        for step in range(1, MAX_STEPS_BY_TASK.get(task_id, 3) + 1):
            if done:
                break
            error: Optional[str] = None
            try:
                raw_action = get_model_action(
                    client=client,
                    task_id=task_id,
                    observation_json=_serialize_observation(task_id, observation, history),
                    history=history,
                    turn_number=observation.turn_number,
                )
                action = normalize_action(raw_action)
                observation, reward, done, _info = env.step(action)
                reward_value = ensure_open_unit_interval(reward.value)
            except Exception as exc:
                # The fallback is only logged for traceability; it is not
                # stepped, and the episode ends here.
                raw_action = _fallback_action(task_id, observation.turn_number)
                action = normalize_action(raw_action)
                reward_value = ensure_open_unit_interval(0.0)
                done = True
                # Some exceptions (e.g. bare timeouts) stringify to "", which
                # log_step would print as error=null; fall back to the type name.
                error = str(exc) or type(exc).__name__
            action_str = json.dumps(action.model_dump(exclude_none=True), separators=(",", ":"))
            log_step(
                step=step,
                action=action_str,
                reward=reward_value,
                done=done,
                error=error,
            )
            rewards.append(reward_value)
            steps_taken = step
            history.append(f"step={step} action={action_str} reward={reward_value:.2f}")
        CURRENT_SCORE_METHOD, raw_score = _score_rewards(task_id, rewards)
        score = ensure_open_unit_interval(raw_score)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
def main() -> None:
    """Run every task selected by ``TASK_NAME`` against the configured model.

    Raises:
        RuntimeError: if no API key was found in the environment.
        ValueError: if ``TASK_NAME`` names an unsupported task.
    """
    api_key = API_KEY
    if not api_key:
        raise RuntimeError(
            "Set API_KEY, OPENAI_API_KEY, or GROQ_API_KEY before running inference.py"
        )
    # SDK retries are disabled (max_retries=0). A failed or timed-out request
    # surfaces inside _run_task, which logs it and ends that episode.
    client_settings = {
        "base_url": API_BASE_URL,
        "api_key": api_key,
        "timeout": REQUEST_TIMEOUT_SECONDS,
        "max_retries": 0,
    }
    client = OpenAI(**client_settings)
    for requested_task in _resolve_requested_tasks(TASK_NAME):
        _run_task(client, requested_task)


if __name__ == "__main__":
    main()