Upload server/app.py with huggingface_hub
Browse files- server/app.py +13 -6
server/app.py
CHANGED
|
@@ -8,7 +8,7 @@ from uuid import uuid4
|
|
| 8 |
import gradio as gr
|
| 9 |
from fastapi import FastAPI, HTTPException
|
| 10 |
from fastapi.responses import RedirectResponse
|
| 11 |
-
from pydantic import BaseModel, Field
|
| 12 |
import uvicorn
|
| 13 |
|
| 14 |
from meta_ads_env import MetaAdsAttributionEnv
|
|
@@ -134,10 +134,15 @@ with gr.Blocks(title="Meta Ads RL Playground") as demo:
|
|
| 134 |
app = gr.mount_gradio_app(app, demo, path="/web")
|
| 135 |
|
| 136 |
|
| 137 |
-
|
| 138 |
class ResetRequest(BaseModel):
|
| 139 |
-
task_id: str =
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
class StepRequest(BaseModel):
|
|
@@ -187,7 +192,9 @@ def tasks() -> dict:
|
|
| 187 |
|
| 188 |
|
| 189 |
@app.post("/reset")
|
| 190 |
-
def reset_episode(req: ResetRequest) -> dict:
|
|
|
|
|
|
|
| 191 |
if req.task_id not in TASK_REGISTRY:
|
| 192 |
raise HTTPException(
|
| 193 |
status_code=400,
|
|
@@ -275,4 +282,4 @@ def main() -> None:
|
|
| 275 |
|
| 276 |
|
| 277 |
if __name__ == "__main__":
|
| 278 |
-
main()
|
|
|
|
| 8 |
import gradio as gr
|
| 9 |
from fastapi import FastAPI, HTTPException
|
| 10 |
from fastapi.responses import RedirectResponse
|
| 11 |
+
from pydantic import AliasChoices, BaseModel, Field
|
| 12 |
import uvicorn
|
| 13 |
|
| 14 |
from meta_ads_env import MetaAdsAttributionEnv
|
|
|
|
| 134 |
app = gr.mount_gradio_app(app, demo, path="/web")
|
| 135 |
|
| 136 |
|
|
|
|
| 137 |
class ResetRequest(BaseModel):
|
| 138 |
+
task_id: str = Field(
|
| 139 |
+
default="easy_attribution_window",
|
| 140 |
+
validation_alias=AliasChoices("task_id", "task"),
|
| 141 |
+
)
|
| 142 |
+
session_id: str | None = Field(
|
| 143 |
+
default=None,
|
| 144 |
+
validation_alias=AliasChoices("session_id", "session"),
|
| 145 |
+
)
|
| 146 |
|
| 147 |
|
| 148 |
class StepRequest(BaseModel):
|
|
|
|
| 192 |
|
| 193 |
|
| 194 |
@app.post("/reset")
|
| 195 |
+
def reset_episode(req: ResetRequest | None = None) -> dict:
|
| 196 |
+
req = req or ResetRequest()
|
| 197 |
+
|
| 198 |
if req.task_id not in TASK_REGISTRY:
|
| 199 |
raise HTTPException(
|
| 200 |
status_code=400,
|
|
|
|
| 282 |
|
| 283 |
|
| 284 |
if __name__ == "__main__":
|
| 285 |
+
main()
|