vivekvish2004's picture
Upload folder using huggingface_hub
dc97fe1 verified
from fastapi import FastAPI, HTTPException, Query, Response, Request
from fastapi.middleware.cors import CORSMiddleware
import os
import json
from openai import OpenAI
from .env import CustomerSupportEnv
from .models import Action, Observation, SYSTEM_PROMPT, DEFAULT_MODEL, DEFAULT_API_BASE
def load_tasks_from_json():
"""Load tasks from tasks.json strictly."""
json_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tasks.json")
if os.path.exists(json_path):
try:
with open(json_path, "r") as f:
return json.load(f)
except Exception:
return []
return []
TASKS = load_tasks_from_json()
app = FastAPI(
title="OpenEnv Customer Support API",
version="1.0.0",
description="Enterprise AI Customer Support OpenEnv simulation environment.",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/favicon.ico", include_in_schema=False)
async def favicon():
return Response(status_code=204)
# AI Configuration
# Mandatory Pre-Submission Configuration
API_KEY = os.getenv("HF_TOKEN")
API_BASE_URL = os.getenv("API_BASE_URL") or DEFAULT_API_BASE
MODEL_NAME = os.getenv("MODEL_NAME") or DEFAULT_MODEL
# Global session manager to support concurrent evaluations
SESSIONS = {}
ai_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if API_KEY else None
def get_env(session_id: str = "default") -> CustomerSupportEnv:
"""Retrieve or create an environment instance for a specific session."""
if session_id not in SESSIONS:
SESSIONS[session_id] = CustomerSupportEnv()
return SESSIONS[session_id]
# ───────────────────────────────────────────────────────────────────────────────
# OpenEnv Standard Endpoints
# ───────────────────────────────────────────────────────────────────────────────
@app.get("/health", tags=["Health"])
def health_check():
"""Standard health check endpoint required by OpenEnv runtime validator."""
return {"status": "healthy", "service": "customer-support-env"}
@app.get("/metadata", tags=["Environment Info"])
def get_metadata():
"""Return environment metadata β€” required by OpenEnv runtime validator."""
return {
"name": "customer-support-env",
"description": "Enterprise AI Customer Support simulation where an agent processes a queue of support tickets through classification, prioritization, response generation, and resolution.",
"version": "1.0.0",
"tags": ["customer-support", "enterprise-ai", "decision-making"],
"mode": "simulation",
}
@app.get("/schema", tags=["Schema"])
def get_schema():
"""Return JSON schemas for action, observation, and state β€” required by OpenEnv validator."""
return {
"action": {
"type": "object",
"properties": {
"action_type": {
"type": "string",
"enum": ["classify_ticket", "assign_priority", "generate_response", "resolve", "escalate", "search_kb", "ask_clarification"],
"description": "The type of action to perform on the current ticket."
},
"payload": {
"type": "object",
"description": "Action-specific parameters.",
"examples": [
{"classification": "refund"},
{"priority": "high"},
{"response": "I am sorry for the inconvenience..."},
{}
]
}
},
"required": ["action_type", "payload"]
},
"observation": {
"type": "object",
"properties": {
"state": {"type": "object", "description": "Current environment state dict"},
"info": {"type": "object", "description": "Additional metadata about the current state"}
},
"required": ["state"]
},
"state": {
"type": "object",
"properties": {
"ticket_text": {"type": "string"},
"sentiment": {"type": "string", "enum": ["angry", "neutral", "panicked", "curious", "happy", "concerned"]},
"priority": {"type": ["string", "null"], "enum": ["low", "medium", "high", None]},
"status": {"type": "string", "enum": ["open", "closed", "session_complete"]},
"classification": {"type": ["string", "null"]},
"response": {"type": ["string", "null"]},
"queue_size": {"type": "integer"},
"resolved": {"type": "integer"},
"total_reward": {"type": "number"},
"last_step_status": {"type": "string", "enum": ["success", "failed", "neutral"]}
}
}
}
@app.get("/reset", tags=["Environment Control"], operation_id="reset_env_get")
@app.post("/reset", tags=["Environment Control"], operation_id="reset_env_post")
def reset_env(session_id: str = "default"):
"""Reset the environment for a specific session."""
env = get_env(session_id)
obs = env.reset()
state = obs.state
return {
"observation": state,
"state": state,
"reward": 0.0,
"done": False,
"session_id": session_id
}
@app.post("/step", tags=["Environment Control"])
def step_env(action: Action, session_id: str = "default"):
"""Submit an action to a specific session."""
env = get_env(session_id)
if not env.queue:
env.reset()
obs, reward, done, info = env.step(action)
state = obs.state
return {
"observation": state,
"state": state,
"reward": float(reward.value),
"done": bool(done),
"info": info,
"session_id": session_id
}
@app.get("/state", tags=["State Management"])
def get_state(session_id: str = "default"):
"""Retrieve the current deterministic state of a session."""
env = get_env(session_id)
obs = env.state()
state = obs.state
if state.get("status") == "session_complete":
obs = env.reset()
state = obs.state
return {
"observation": state,
"state": state,
"session_id": session_id
}
@app.get("/tasks", tags=["Environment Info"])
def get_tasks(session_id: str = "default"):
"""Retrieve all available tasks for a session."""
env = get_env(session_id)
return env.get_tasks()
@app.get("/grader", tags=["Environment Info"])
def run_grader(
task_id: str = Query(..., description="Task ID to grade (e.g. 'task_easy_1')"),
session_id: str = "default"
):
"""Grade a specific task for a session."""
env = get_env(session_id)
tasks = env.get_tasks()
task_meta = next((t for t in tasks if t["id"] == task_id), None)
if task_meta is None:
raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found.")
if not task_meta.get("grader"):
raise HTTPException(status_code=400, detail=f"Task '{task_id}' does not have a grader.")
difficulty = task_meta.get("difficulty", "EASY")
mock_state = _build_mock_state(difficulty)
ground_truth = {
"expected_classification": "refund",
"expected_priority": "high",
"sentiment": "angry",
}
try:
score = env.grade(task_id, [{"state": mock_state}], ground_truth)
score = float(max(0.0, min(1.0, score)))
return {
"task_id": task_id,
"score": score,
"reward": score,
"success": score >= 0.5,
"message": f"Task '{task_id}' graded with score {score:.4f}",
"difficulty": difficulty,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Grader execution failed: {str(e)}")
def _build_mock_state(difficulty: str) -> dict:
"""Build a near-perfect mock state for deterministic grader testing."""
return {
"ticket_text": "I bought a premium subscription but it's not working. I want my money back right now!",
"sentiment": "angry",
"classification": "refund",
"priority": "high",
"response": "I am so sorry for the inconvenience. We completely understand your frustration.",
"status": "closed",
"queue_size": 0,
"resolved": 1,
"total_reward": 0.8,
}
@app.post("/mcp", tags=["Environment Info"])
async def mcp_endpoint(request: Request):
"""Minimal JSON-RPC 2.0 endpoint required by OpenEnv runtime validator."""
try:
body = await request.json()
except Exception:
body = {}
method = body.get("method", "")
req_id = body.get("id", 1)
if method == "initialize":
return {
"jsonrpc": "2.0",
"id": req_id,
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"serverInfo": {"name": "customer-support-env", "version": "1.0.0"},
}
}
elif method == "tools/list":
return {
"jsonrpc": "2.0",
"id": req_id,
"result": {
"tools": [
{
"name": "step",
"description": "Take a step in the customer support environment. Available actions: classify_ticket, assign_priority, generate_response, search_kb, ask_clarification, resolve, escalate.",
"inputSchema": {
"type": "object",
"properties": {
"action_type": {"type": "string"},
"payload": {"type": "object"}
}
}
}
]
}
}
else:
return {"jsonrpc": "2.0", "id": req_id, "result": {}}
@app.get("/baseline", tags=["Environment Control"])
def run_baseline(session_id: str = "default"):
"""Execute a hardcoded 'perfect' baseline workflow in isolation."""
env = get_env(session_id)
if not env.queue:
env.reset()
gt = env_instance.ground_truth
baseline_sequence = [
{"action_type": "classify_ticket", "payload": {"classification": gt["expected_classification"]}},
{"action_type": "assign_priority", "payload": {"priority": gt["expected_priority"]}},
{"action_type": "generate_response", "payload": {"response": "I am so sorry for the inconvenience. That is completely fixed now."}},
{"action_type": "resolve", "payload": {}}
]
trace_results = []
for step_logic in baseline_sequence:
action = Action(**step_logic)
obs, reward, done, info = env_instance.step(action)
trace_results.append({
"action": step_logic,
"reward_earned": reward.value,
"done": done
})
if done:
break
return {
"message": "Baseline ideal sequence successfully executed against ground truth.",
"trace": trace_results,
"final_state": env.current_state,
"session_id": session_id
}
@app.get("/predict", tags=["Environment Control"])
async def predict_action(session_id: str = "default"):
"""Ask the AI Model to suggest the next logical action for the current ticket."""
env = get_env(session_id)
if env.current_state is None or not env.queue:
raise HTTPException(status_code=400, detail="No active session or queue is empty.")
if not ai_client:
raise HTTPException(status_code=500, detail="AI Client not configured. Ensure HF_TOKEN is set.")
try:
completion = ai_client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": f"Current State: {json.dumps(env.current_state)}"}
],
temperature=0.0,
response_format={"type": "json_object"}
)
return json.loads(completion.choices[0].message.content)
except Exception as e:
raise HTTPException(status_code=500, detail=f"LLM Prediction failed: {str(e)}")
def main():
import uvicorn
print("πŸš€ Starting OpenEnv Customer Support Backend...")
uvicorn.run("backend.main:app", host="0.0.0.0", port=7860, reload=False, log_level="info")
if __name__ == "__main__":
main()