AgentApiBench / mock_api /server.py
pandey019
fix
ff07c35
import json
import uuid
from datetime import datetime
from pathlib import Path
from fastapi import FastAPI, Header, Query, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional, Any
from mock_api.call_logger import CallLogger
app = FastAPI(title="AgentAPIBench Mock API", version="1.0.0")
logger = CallLogger()
DATA_DIR = Path(__file__).parent / "data"
# Load fake data
CUSTOMERS = json.loads((DATA_DIR / "customers.json").read_text())
INVOICES = json.loads((DATA_DIR / "invoices.json").read_text())
PAYMENTS = json.loads((DATA_DIR / "payments.json").read_text())
VALID_TOKENS = {
"sk-bench-4921x": "test_account_1",
"sk-bench-7823y": "test_account_2",
"sk-bench-1155z": "test_account_3",
}
def verify_auth(authorization: Optional[str]) -> str:
"""Returns account_id or raises 401."""
if not authorization:
raise HTTPException(
status_code=401,
detail="Authorization header required. Use: Authorization: Bearer <token>",
)
if not authorization.startswith("Bearer "):
raise HTTPException(
status_code=401,
detail=f"Authorization header must use Bearer scheme. Got: {authorization}",
)
token = authorization.replace("Bearer ", "").strip()
if token not in VALID_TOKENS:
raise HTTPException(
status_code=401, detail=f"Invalid token. Check your API credentials."
)
return VALID_TOKENS[token]
def log_call(method: str, path: str, params: dict, status: int, response: dict):
logger.log(
{
"method": method,
"path": path,
"params": params,
"status": status,
"response": response,
"timestamp": datetime.utcnow().isoformat(),
}
)
# Catch-all middleware to log 422s that fail FastAPI validation before reaching routes
@app.middleware("http")
async def log_failed_requests(request: Request, call_next):
# Safe clone of query params
params = dict(request.query_params)
response = await call_next(request)
if response.status_code >= 400 and not request.url.path.startswith("/v1/_internal"):
# The route handler didn't get to log it because of an exception, log it now
# We can't easily read the body here without consuming the stream, but we can log the 4xx
log_call(
request.method,
request.url.path,
params,
response.status_code,
{"error": "Request failed HTTP validation"},
)
return response
# ─── CUSTOMERS ───────────────────────────────────────────────────────────────
@app.get("/v1/customers/{customer_id}")
def get_customer(
customer_id: str,
include: Optional[str] = Query(None),
authorization: Optional[str] = Header(None),
):
verify_auth(authorization)
customer = next((c for c in CUSTOMERS if c["id"] == customer_id), None)
if not customer:
log_call("GET", f"/v1/customers/{customer_id}", {}, 404, {})
raise HTTPException(
status_code=404, detail=f"Customer '{customer_id}' not found."
)
result = {
"id": customer["id"],
"name": customer["name"],
"email": customer["email"],
"created_at": customer["created_at"],
}
if include:
fields = [f.strip() for f in include.split(",")]
if "subscription" in fields:
result["subscription"] = customer.get("subscription", {})
if "billing" in fields:
result["billing"] = customer.get("billing", {})
log_call("GET", f"/v1/customers/{customer_id}", {"include": include}, 200, result)
return result
@app.get("/v1/customers")
def list_customers(
status: Optional[str] = None,
limit: int = 20,
authorization: Optional[str] = Header(None),
):
verify_auth(authorization)
results = CUSTOMERS
if status:
results = [c for c in results if c.get("status") == status]
result = {"customers": results[:limit], "total": len(results)}
log_call("GET", "/v1/customers", {"status": status, "limit": limit}, 200, result)
return result
# ─── INVOICES ────────────────────────────────────────────────────────────────
@app.get("/v1/invoices")
def list_invoices(
customer_id: Optional[str] = None,
status: Optional[str] = None,
authorization: Optional[str] = Header(None),
):
verify_auth(authorization)
results = INVOICES
if customer_id:
results = [i for i in results if i["customer_id"] == customer_id]
if status:
results = [i for i in results if i["status"] == status]
result = {"invoices": results, "total": len(results)}
log_call(
"GET",
"/v1/invoices",
{"customer_id": customer_id, "status": status},
200,
result,
)
return result
class RemindBody(BaseModel):
channel: str # "email" | "sms" | "push"
@app.post("/v1/invoices/{invoice_id}/remind")
def send_reminder(
invoice_id: str, body: RemindBody, authorization: Optional[str] = Header(None)
):
verify_auth(authorization)
invoice = next((i for i in INVOICES if i["id"] == invoice_id), None)
if not invoice:
log_call(
"POST",
f"/v1/invoices/{invoice_id}/remind",
{"channel": body.channel},
404,
{},
)
raise HTTPException(
status_code=404, detail=f"Invoice '{invoice_id}' not found."
)
if body.channel not in ["email", "sms", "push"]:
log_call(
"POST",
f"/v1/invoices/{invoice_id}/remind",
{"channel": body.channel},
422,
{},
)
raise HTTPException(
status_code=422,
detail=f"Invalid channel '{body.channel}'. Must be: email, sms, or push.",
)
result = {
"reminder_id": f"rem_{uuid.uuid4().hex[:8]}",
"invoice_id": invoice_id,
"channel": body.channel,
"sent_at": datetime.utcnow().isoformat(),
"status": "sent",
}
log_call(
"POST",
f"/v1/invoices/{invoice_id}/remind",
{"channel": body.channel},
200,
result,
)
return result
# ─── PAYMENTS ────────────────────────────────────────────────────────────────
class PaymentBody(BaseModel):
customer_id: str
amount: float
currency: str # Required β€” ISO 4217 (USD, EUR, GBP...)
@app.post("/v1/payments")
def create_payment(body: PaymentBody, authorization: Optional[str] = Header(None)):
verify_auth(authorization)
if body.amount <= 0:
log_call("POST", "/v1/payments", body.dict(), 422, {})
raise HTTPException(status_code=422, detail="Amount must be greater than 0.")
if len(body.currency) != 3:
log_call("POST", "/v1/payments", body.dict(), 422, {})
raise HTTPException(
status_code=422,
detail=f"Currency must be a 3-letter ISO 4217 code (USD, EUR, GBP). Got: {body.currency}",
)
result = {
"payment_id": f"pay_{uuid.uuid4().hex[:8]}",
"customer_id": body.customer_id,
"amount": body.amount,
"currency": body.currency,
"status": "succeeded",
"created_at": datetime.utcnow().isoformat(),
}
log_call("POST", "/v1/payments", body.dict(), 200, result)
return result
# ─── CALL LOG ACCESS (for graders) ───────────────────────────────────────────
@app.get("/v1/_internal/call_log")
def get_call_log(authorization: Optional[str] = Header(None)):
verify_auth(authorization)
return {"calls": logger.get_calls()}
@app.delete("/v1/_internal/call_log")
def clear_call_log(authorization: Optional[str] = Header(None)):
verify_auth(authorization)
logger.clear()
return {"message": "Call log cleared"}