TheAarvee05 commited on
Commit
7a6856d
·
verified ·
1 Parent(s): 0900b91

Upload server/app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/app.py +13 -6
server/app.py CHANGED
@@ -8,7 +8,7 @@ from uuid import uuid4
8
  import gradio as gr
9
  from fastapi import FastAPI, HTTPException
10
  from fastapi.responses import RedirectResponse
11
- from pydantic import BaseModel, Field
12
  import uvicorn
13
 
14
  from meta_ads_env import MetaAdsAttributionEnv
@@ -134,10 +134,15 @@ with gr.Blocks(title="Meta Ads RL Playground") as demo:
134
  app = gr.mount_gradio_app(app, demo, path="/web")
135
 
136
 
137
-
138
  class ResetRequest(BaseModel):
139
- task_id: str = "easy_attribution_window"
140
- session_id: str | None = None
 
 
 
 
 
 
141
 
142
 
143
  class StepRequest(BaseModel):
@@ -187,7 +192,9 @@ def tasks() -> dict:
187
 
188
 
189
  @app.post("/reset")
190
- def reset_episode(req: ResetRequest) -> dict:
 
 
191
  if req.task_id not in TASK_REGISTRY:
192
  raise HTTPException(
193
  status_code=400,
@@ -275,4 +282,4 @@ def main() -> None:
275
 
276
 
277
  if __name__ == "__main__":
278
- main()
 
8
  import gradio as gr
9
  from fastapi import FastAPI, HTTPException
10
  from fastapi.responses import RedirectResponse
11
+ from pydantic import AliasChoices, BaseModel, Field
12
  import uvicorn
13
 
14
  from meta_ads_env import MetaAdsAttributionEnv
 
134
  app = gr.mount_gradio_app(app, demo, path="/web")
135
 
136
 
 
137
  class ResetRequest(BaseModel):
138
+ task_id: str = Field(
139
+ default="easy_attribution_window",
140
+ validation_alias=AliasChoices("task_id", "task"),
141
+ )
142
+ session_id: str | None = Field(
143
+ default=None,
144
+ validation_alias=AliasChoices("session_id", "session"),
145
+ )
146
 
147
 
148
  class StepRequest(BaseModel):
 
192
 
193
 
194
  @app.post("/reset")
195
+ def reset_episode(req: ResetRequest | None = None) -> dict:
196
+ req = req or ResetRequest()
197
+
198
  if req.task_id not in TASK_REGISTRY:
199
  raise HTTPException(
200
  status_code=400,
 
282
 
283
 
284
  if __name__ == "__main__":
285
+ main()