Replace server.py with payment_credit_env FastAPI server

#1
Files changed (1) hide show
  1. server.py +231 -171
server.py CHANGED
@@ -1,171 +1,231 @@
1
- """FastAPI server exposing the Scheduling Optimisation Environment as an HTTP API."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Any
6
-
7
- from fastapi import FastAPI, HTTPException
8
- from pydantic import BaseModel
9
-
10
- from environment import SchedulingOptEnv
11
- from graders.grader_classification import ConflictGrader
12
- from graders.grader_detection import FeasibilityGrader
13
- from graders.grader_fix import RepairGrader
14
- from models import Action, Observation
15
-
16
- app = FastAPI(
17
- title="Scheduling Optimisation Environment",
18
- description=(
19
- "OpenEnv-compatible environment for training AI agents on combinatorial "
20
- "scheduling optimisation problems."
21
- ),
22
- version="1.0.0",
23
- )
24
-
25
- # Single shared environment instance.
26
- env = SchedulingOptEnv()
27
-
28
-
29
- # ---------------------------------------------------------------------------
30
- # Request / response schemas
31
- # ---------------------------------------------------------------------------
32
-
33
-
34
- class ResetRequest(BaseModel):
35
- task_id: str = "feasibility_check"
36
-
37
-
38
- class StepResponse(BaseModel):
39
- observation: Observation
40
- reward: float
41
- done: bool
42
- info: dict[str, Any]
43
-
44
-
45
- class GradeRequest(BaseModel):
46
- action: Action
47
- ground_truth: dict[str, Any]
48
-
49
-
50
- class GradeResponse(BaseModel):
51
- score: float
52
-
53
-
54
- # ---------------------------------------------------------------------------
55
- # Endpoints
56
- # ---------------------------------------------------------------------------
57
-
58
-
59
- @app.get("/health")
60
- def health() -> dict[str, str]:
61
- """Health check for Hugging Face Spaces liveness probe."""
62
- return {"status": "ok"}
63
-
64
-
65
- @app.post("/reset", response_model=Observation)
66
- def reset(req: ResetRequest) -> Observation:
67
- """Reset the environment and start a new episode.
68
-
69
- Body: {"task_id": "feasibility_check" | "conflict_classification" | "schedule_repair"}
70
- """
71
- valid_tasks = {"feasibility_check", "conflict_classification", "schedule_repair"}
72
- if req.task_id not in valid_tasks:
73
- raise HTTPException(
74
- status_code=400,
75
- detail=f"Invalid task_id. Choose from: {sorted(valid_tasks)}",
76
- )
77
- return env.reset(task_id=req.task_id)
78
-
79
-
80
- @app.post("/step", response_model=StepResponse)
81
- def step(action: Action) -> StepResponse:
82
- """Submit an action and advance the environment by one step.
83
-
84
- Body: {"response": "<answer>", "task_id": "<task_id>"}
85
- """
86
- obs, reward, done, info = env.step(action)
87
- return StepResponse(observation=obs, reward=reward, done=done, info=info)
88
-
89
-
90
- @app.get("/state")
91
- def state() -> dict[str, Any]:
92
- """Return the full current environment state."""
93
- return env.state()
94
-
95
-
96
- @app.get("/tasks")
97
- def tasks() -> list[dict[str, Any]]:
98
- """List available tasks with their action schemas."""
99
- return [
100
- {
101
- "task_id": "feasibility_check",
102
- "name": "Feasibility Check",
103
- "difficulty": "easy",
104
- "max_steps": 3,
105
- "action_schema": {
106
- "response": "feasible | infeasible",
107
- "task_id": "feasibility_check",
108
- },
109
- },
110
- {
111
- "task_id": "conflict_classification",
112
- "name": "Conflict Classification",
113
- "difficulty": "medium",
114
- "max_steps": 5,
115
- "action_schema": {
116
- "response": (
117
- "resource_overload | deadline_violation | precedence_violation | "
118
- "availability_conflict | capacity_exceeded"
119
- ),
120
- "task_id": "conflict_classification",
121
- },
122
- },
123
- {
124
- "task_id": "schedule_repair",
125
- "name": "Schedule Repair",
126
- "difficulty": "hard",
127
- "max_steps": 8,
128
- "action_schema": {
129
- "response": '{"assignments": [{"job_id": "J1", "machine_id": "M1", "start_time": 0}, ...]}',
130
- "task_id": "schedule_repair",
131
- },
132
- },
133
- ]
134
-
135
-
136
- @app.post("/grader", response_model=GradeResponse)
137
- def grader(req: GradeRequest) -> GradeResponse:
138
- """Directly invoke a grader with an action and ground truth.
139
-
140
- Body: {"action": {"response": "...", "task_id": "..."}, "ground_truth": {...}}
141
- """
142
- task_id = req.action.task_id
143
- grader_map = {
144
- "feasibility_check": FeasibilityGrader(),
145
- "conflict_classification": ConflictGrader(),
146
- "schedule_repair": RepairGrader(),
147
- }
148
- g = grader_map.get(task_id)
149
- if g is None:
150
- raise HTTPException(
151
- status_code=400, detail=f"No grader for task_id={task_id}"
152
- )
153
- score = g.grade(req.action, req.ground_truth)
154
- return GradeResponse(score=max(0.0, min(1.0, score)))
155
-
156
-
157
- @app.get("/baseline")
158
- def baseline() -> dict[str, Any]:
159
- """Trigger the baseline inference agent and return per-task scores.
160
-
161
- Falls back to mock oracle responses when OPENAI_API_KEY is not set,
162
- so this endpoint always returns a valid result.
163
- """
164
- try:
165
- from baseline import run_baseline
166
- return run_baseline()
167
- except Exception as exc:
168
- raise HTTPException(
169
- status_code=500,
170
- detail=f"Baseline run failed: {exc}",
171
- ) from exc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server exposing the Payment Credit Environment as an HTTP API."""
2
+
3
+ from fastapi import FastAPI
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel, Field
6
+ from typing import List, Dict, Optional, Literal
7
+ from datetime import date, datetime, timedelta
8
+ import random
9
+ import uuid
10
+
11
+ app = FastAPI(
12
+ title="payment_credit_env",
13
+ version="0.2.0",
14
+ description="OpenEnv-compatible payment credit decision environment for hackathon."
15
+ )
16
+
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"],
20
+ allow_credentials=True,
21
+ allow_methods=["*"],
22
+ allow_headers=["*"],
23
+ )
24
+
25
+ ActionType = Literal["approve_card_1", "approve_card_2", "route_to_debit", "deny_transaction", "request_more_info", "adjust_credit_limit", "offer_installments", "escalate_review"]
26
+ ALL_ACTIONS = ["approve_card_1", "approve_card_2", "route_to_debit", "deny_transaction", "request_more_info", "adjust_credit_limit", "offer_installments", "escalate_review"]
27
+
28
+ class TransactionState(BaseModel):
29
+ transaction_id: str
30
+ amount: float
31
+ credit_score: int
32
+ available_credit: float
33
+ monthly_spend: float
34
+ debt_to_income: float
35
+ payment_history: float
36
+ credit_utilization: float
37
+ last_payment_date: str
38
+ account_age_months: int
39
+
40
+
41
+ class StepRequest(BaseModel):
42
+ action: ActionType
43
+
44
+
45
+ class LeaderboardEntry(BaseModel):
46
+ action: str
47
+ reward: float
48
+
49
+
50
+ class PolicyCheck(BaseModel):
51
+ name: str
52
+ passed: bool
53
+ detail: str
54
+
55
+
56
+ class StepResponse(BaseModel):
57
+ transaction_id: str
58
+ action: str
59
+ reward: float
60
+ done: bool = False
61
+ state: TransactionState
62
+ risk_band: str
63
+ recommended_action: str
64
+ reasons: List[str]
65
+ leaderboard: List[LeaderboardEntry]
66
+ policy_checks: List[PolicyCheck]
67
+
68
+
69
+ class AuditLogEntry(BaseModel):
70
+ timestamp: str
71
+ transaction_id: str
72
+ action: str
73
+ reward: float
74
+ risk_band: str
75
+ recommended_action: str
76
+
77
+
78
+ class EnvStore:
79
+ def __init__(self):
80
+ self.current_state = None
81
+ self.audit_log = []
82
+
83
+ def reset(self):
84
+ self.current_state = self._sample_transaction()
85
+ self.audit_log = []
86
+ return self.current_state
87
+
88
+ def get_state(self):
89
+ if self.current_state is None:
90
+ self.current_state = self._sample_transaction()
91
+ return self.current_state
92
+
93
+ def _sample_transaction(self):
94
+ tid = f"TXN{random.randint(1000, 9999)}"
95
+ amt = round(random.uniform(500, 15000), 2)
96
+ cs = random.randint(580, 780)
97
+ avail = round(random.uniform(2000, 20000), 2)
98
+ spend = round(random.uniform(300, 3000), 2)
99
+ dti = round(random.uniform(0.15, 0.65), 2)
100
+ ph = round(random.uniform(0.75, 1.0), 2)
101
+ cu = round(random.uniform(0.2, 0.95), 2)
102
+ last_pay = (date.today() - timedelta(days=random.randint(0, 60))).isoformat()
103
+ age = random.randint(6, 120)
104
+ return TransactionState(
105
+ transaction_id=tid, amount=amt, credit_score=cs,
106
+ available_credit=avail, monthly_spend=spend, debt_to_income=dti,
107
+ payment_history=ph, credit_utilization=cu, last_payment_date=last_pay,
108
+ account_age_months=age
109
+ )
110
+
111
+
112
+ def get_risk_band(s: TransactionState) -> str:
113
+ score = 0
114
+ if s.credit_score >= 750: score += 2
115
+ elif s.credit_score >= 680: score += 1
116
+ if s.debt_to_income < 0.35: score += 1
117
+ elif s.debt_to_income > 0.5: score -= 1
118
+ if s.payment_history >= 0.95: score += 1
119
+ elif s.payment_history < 0.85: score -= 1
120
+ if s.credit_utilization < 0.3: score += 1
121
+ elif s.credit_utilization > 0.7: score -= 1
122
+ if s.account_age_months >= 36: score += 1
123
+ if score >= 3: return "low"
124
+ elif score >= 1: return "medium"
125
+ else: return "high"
126
+
127
+
128
+ def recommended_action_for_band(s: TransactionState, risk_band: str) -> str:
129
+ if risk_band == "low": return "approve_card_1"
130
+ elif risk_band == "medium": return "offer_installments"
131
+ else: return "escalate_review"
132
+
133
+
134
+ def build_reasons(s: TransactionState, risk_band: str) -> List[str]:
135
+ reasons = []
136
+ if s.credit_score >= 750: reasons.append(f"Excellent credit score ({s.credit_score})")
137
+ elif s.credit_score < 650: reasons.append(f"Low credit score ({s.credit_score})")
138
+ if s.debt_to_income < 0.35: reasons.append(f"Healthy debt-to-income ({s.debt_to_income:.1%})")
139
+ elif s.debt_to_income > 0.5: reasons.append(f"High debt-to-income ({s.debt_to_income:.1%})")
140
+ if s.payment_history >= 0.95: reasons.append(f"Strong payment history ({s.payment_history:.1%})")
141
+ elif s.payment_history < 0.85: reasons.append(f"Weak payment history ({s.payment_history:.1%})")
142
+ if s.credit_utilization < 0.3: reasons.append(f"Low credit utilization ({s.credit_utilization:.1%})")
143
+ elif s.credit_utilization > 0.7: reasons.append(f"High credit utilization ({s.credit_utilization:.1%})")
144
+ if not reasons: reasons.append("Transaction within normal parameters")
145
+ return reasons[:4]
146
+
147
+
148
+ def build_leaderboard(s: TransactionState) -> List[LeaderboardEntry]:
149
+ rewards = {a: round(random.uniform(-1.0, 1.0), 2) for a in ALL_ACTIONS}
150
+ sorted_actions = sorted(rewards.items(), key=lambda x: x[1], reverse=True)
151
+ return [LeaderboardEntry(action=a, reward=r) for a, r in sorted_actions[:5]]
152
+
153
+
154
+ def policy_checks(s: TransactionState, action: str) -> List[PolicyCheck]:
155
+ checks = [
156
+ PolicyCheck(name="Amount within limit", passed=s.amount <= s.available_credit, detail=f"Amount {s.amount} vs available {s.available_credit}"),
157
+ PolicyCheck(name="Credit score threshold", passed=s.credit_score >= 600, detail=f"Score {s.credit_score} >= 600"),
158
+ PolicyCheck(name="Recent payment activity", passed=(datetime.now() - datetime.fromisoformat(s.last_payment_date)).days < 45, detail=f"Last payment {s.last_payment_date}"),
159
+ PolicyCheck(name="Account age minimum", passed=s.account_age_months >= 3, detail=f"Account age {s.account_age_months} months")
160
+ ]
161
+ return checks
162
+
163
+
164
+ def score_action(s: TransactionState, action: str) -> float:
165
+ base = random.uniform(0.5, 1.0)
166
+ if action == "approve_card_1" and s.credit_score >= 700: base += 0.3
167
+ elif action == "escalate_review" and s.credit_score < 650: base += 0.3
168
+ elif action == "deny_transaction" and s.credit_score < 600: base += 0.3
169
+ return round(min(1.0, max(-1.0, base)), 2)
170
+
171
+
172
+ env = EnvStore()
173
+
174
+
175
+ @app.get("/")
176
+ def root():
177
+ return {"service": "payment_credit_env", "status": "running", "version": "0.2.0"}
178
+
179
+
180
+ @app.get("/health")
181
+ def health():
182
+ return {"status": "ok"}
183
+
184
+
185
+ @app.post("/reset")
186
+ def reset():
187
+ s = env.reset()
188
+ risk_band = get_risk_band(s)
189
+ rec = recommended_action_for_band(s, risk_band)
190
+ return {"state": s, "risk_band": risk_band, "recommended_action": rec}
191
+
192
+
193
+ @app.get("/state")
194
+ def state():
195
+ return env.get_state()
196
+
197
+
198
+ @app.post("/step", response_model=StepResponse)
199
+ def step(req: StepRequest):
200
+ s = env.get_state()
201
+ risk_band = get_risk_band(s)
202
+ rec = recommended_action_for_band(s, risk_band)
203
+ reasons = build_reasons(s, risk_band)
204
+ leaderboard = build_leaderboard(s)
205
+ checks = policy_checks(s, req.action)
206
+ reward = score_action(s, req.action)
207
+ env.audit_log.append(AuditLogEntry(
208
+ timestamp=datetime.now().isoformat(),
209
+ transaction_id=s.transaction_id,
210
+ action=req.action,
211
+ reward=reward,
212
+ risk_band=risk_band,
213
+ recommended_action=rec
214
+ ))
215
+ return StepResponse(
216
+ transaction_id=s.transaction_id,
217
+ action=req.action,
218
+ reward=reward,
219
+ done=True,
220
+ state=s,
221
+ risk_band=risk_band,
222
+ recommended_action=rec,
223
+ reasons=reasons,
224
+ leaderboard=leaderboard,
225
+ policy_checks=checks
226
+ )
227
+
228
+
229
+ @app.get("/audit-log")
230
+ def get_audit_log():
231
+ return env.audit_log