WIZARDIAN's picture
update
6e43149
"""
Banking Support Agent — FastAPI Server
"""
try:
from models import BankingSupportAction, BankingSupportObservation
from server.banking_support_environment import BankingSupportEnvironment
except ImportError:
try:
from ..models import BankingSupportAction, BankingSupportObservation
from .banking_support_environment import BankingSupportEnvironment
except ImportError:
from banking_support_env.models import BankingSupportAction, BankingSupportObservation
from banking_support_env.server.banking_support_environment import BankingSupportEnvironment
import traceback
from fastapi import FastAPI
from fastapi.responses import RedirectResponse, HTMLResponse
app = FastAPI(title="Banking Support Agent Environment")

# Live environments keyed by name. Every endpoint in this module reads and
# writes only the "default" key, so all clients share a single episode.
_envs: dict = dict()

# Formatted traceback captured if mounting the Gradio UI fails; stays None
# when the UI mounts successfully.
_gradio_error = None
# Mount the Gradio UI at /ui. This must run at import time so the routes exist
# before uvicorn starts serving. Any failure is recorded in _gradio_error
# (and printed) instead of preventing the API from starting.
try:
    import os as _os
    import sys as _sys

    # Put the project root (the parent of this file's directory; noted as
    # /app in the deployment) on sys.path so the top-level `gradio_ui`
    # module is importable.
    _server_dir = _os.path.dirname(_os.path.abspath(__file__))
    _root = _os.path.dirname(_server_dir)
    if _root not in _sys.path:
        _sys.path.insert(0, _root)

    import gradio as gr
    from gradio_ui import create_gradio_app

    _demo = create_gradio_app()
    gr.mount_gradio_app(app, _demo, path="/ui")
    print("[INFO] Gradio UI mounted at /ui", flush=True)
except Exception:
    # Deliberately broad: a missing or broken UI must not take down the API.
    _gradio_error = traceback.format_exc()
    print(f"[WARNING] Gradio UI not loaded:\n{_gradio_error}", flush=True)
@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL.

    Goes to the Gradio UI at /ui when it mounted. Otherwise goes to the
    interactive API docs at /docs.
    """
    target = "/docs" if _gradio_error is not None else "/ui"
    return RedirectResponse(url=target)
@app.get("/ui-debug")
def ui_debug():
    """Report whether the Gradio UI mounted.

    Returns:
        An HTML page containing the captured traceback when mounting the UI
        failed. Otherwise the JSON object
        ``{"gradio_loaded": True, "ui_path": "/ui"}``.
    """
    import html

    if _gradio_error:
        # Tracebacks contain text such as "in <module>" and
        # "<frozen importlib._bootstrap>". Escape it so the browser shows the
        # trace verbatim instead of parsing those fragments as HTML tags.
        safe_trace = html.escape(_gradio_error)
        return HTMLResponse(f"""
<h2>Gradio UI failed to load</h2>
<p>Use <a href="/docs">/docs</a> for the API.</p>
<pre style="background:#fee;padding:12px">{safe_trace}</pre>
""")
    return {"gradio_loaded": True, "ui_path": "/ui"}
@app.get("/health")
def health():
    """Liveness check: always reports an ok status for this environment."""
    return dict(status="ok", environment="banking_support")
@app.post("/reset")
def reset(task_name: str = "balance_inquiry"):
    """Start a new episode for ``task_name`` and make it the shared env.

    Returns:
        ``{"observation": {...}, "reward": ..., "done": ...}``. The
        observation dict mirrors the fields of the observation returned by
        ``BankingSupportEnvironment.reset``.
    """
    fresh_env = BankingSupportEnvironment()
    first_obs = fresh_env.reset(task_name=task_name)
    # Stored only after reset() returns, so a failing reset leaves any
    # previously stored episode in place.
    _envs["default"] = fresh_env

    field_names = (
        "done",
        "reward",
        "customer_message",
        "tool_result",
        "available_tools",
        "task_description",
        "conversation_history",
        "error",
    )
    observation = {name: getattr(first_obs, name) for name in field_names}
    return {
        "observation": observation,
        "reward": first_obs.reward,
        "done": first_obs.done,
    }
@app.post("/step")
def step(action: dict):
    """Apply one agent action to the current ("default") episode.

    Args:
        action: JSON request body. The keys read are ``tool_name``,
            ``tool_args``, ``message`` and ``resolve``. Missing keys fall back
            to ``None``, ``{}``, ``None`` and ``False`` respectively. An
            explicit JSON ``null`` for ``tool_args`` or ``resolve`` is treated
            the same as an omitted key.

    Returns:
        ``{"observation": {...}, "reward": ..., "done": ...}`` built from the
        observation returned by ``env.step`` (same shape as ``/reset``).
        Returns ``{"error": ...}`` when no episode has been started or the
        action fields are rejected.
    """
    env = _envs.get("default")
    if env is None:
        return {"error": "Call /reset first"}

    # Normalise explicit nulls to the same defaults used for missing keys, so
    # the action always receives a dict for tool_args and a non-null resolve.
    tool_args = action.get("tool_args")
    resolve = action.get("resolve")
    try:
        act = BankingSupportAction(
            tool_name=action.get("tool_name"),
            tool_args={} if tool_args is None else tool_args,
            message=action.get("message"),
            resolve=False if resolve is None else resolve,
        )
    except ValueError as exc:
        # NOTE(review): assumes BankingSupportAction reports invalid field
        # values with a ValueError subclass (true for pydantic's
        # ValidationError). Report it in the same {"error": ...} shape as the
        # missing-episode case instead of an unhandled 500.
        return {"error": f"Invalid action: {exc}"}

    obs = env.step(act)
    return {
        "observation": {
            "done": obs.done,
            "reward": obs.reward,
            "customer_message": obs.customer_message,
            "tool_result": obs.tool_result,
            "available_tools": obs.available_tools,
            "task_description": obs.task_description,
            "conversation_history": obs.conversation_history,
            "error": obs.error,
        },
        "reward": obs.reward,
        "done": obs.done,
    }
@app.get("/state")
def get_state():
    """Return a snapshot of the current episode's ``env.state`` attributes.

    Returns ``{"error": "Call /reset first"}`` if no episode exists yet.
    """
    current = _envs.get("default")
    if current is None:
        return {"error": "Call /reset first"}

    snapshot = current.state  # read once, as the handler has always done
    exported = (
        "episode_id",
        "step_count",
        "task_name",
        "max_steps",
        "tools_used",
        "resolution_status",
        "score",
    )
    return {attr: getattr(snapshot, attr) for attr in exported}
@app.get("/grade")
def grade():
    """Return ``get_rubric_report()`` for the current episode.

    Returns ``{"error": "Call /reset first"}`` if no episode exists yet.
    """
    current = _envs.get("default")
    if current is not None:
        return current.get_rubric_report()
    return {"error": "Call /reset first"}
# (name, difficulty, optimal_steps) for every task advertised by /tasks.
_TASK_CATALOG = (
    ("balance_inquiry", "easy", 4),
    ("fraud_dispute", "medium", 6),
    ("loan_emi_dispute", "medium", 5),
    ("complex_escalation", "hard", 8),
    ("account_takeover", "hard", 7),
)


@app.get("/tasks")
def list_tasks():
    """List the available tasks with their difficulty and optimal step count."""
    # Build fresh dicts on every call so callers never share mutable state.
    catalog = [
        {"name": name, "difficulty": level, "optimal_steps": steps}
        for name, level, steps in _TASK_CATALOG
    ]
    return {"tasks": catalog}
def main(host: str = "0.0.0.0", port: int = 7860):
    """Serve ``app`` (including the Gradio UI, if it mounted) with uvicorn.

    Args:
        host: Interface to bind. The default "0.0.0.0" listens on all
            interfaces.
        port: TCP port to listen on. Defaults to 7860.
    """
    # Imported lazily so importing this module (e.g. by an external ASGI
    # server) does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()