ATISHAY005 commited on
Commit
6bad4aa
·
1 Parent(s): cfa64d1

Fix server entrypoint for OpenEnv validation

Browse files
Files changed (2) hide show
  1. pyproject.toml +1 -1
  2. server/app.py +15 -1
pyproject.toml CHANGED
@@ -13,7 +13,7 @@ dependencies = [
13
  ]
14
 
15
  [project.scripts]
16
- server = "server.app:app"
17
 
18
  [build-system]
19
  requires = ["setuptools"]
 
13
  ]
14
 
15
  [project.scripts]
16
+ server = "server.app:main"
17
 
18
  [build-system]
19
  requires = ["setuptools"]
server/app.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI
2
  import json
 
3
 
4
  from env.feed_env import FeedRankingEnv
5
  from agents.random_agent import RandomAgent
@@ -12,15 +13,18 @@ with open("data/posts.json", "r") as f:
12
  env = FeedRankingEnv(posts, task="hard")
13
  agent = RandomAgent()
14
 
 
15
  @app.get("/")
16
  def root():
17
  return {"message": "OpenEnv server running"}
18
 
 
19
  @app.post("/reset")
20
  def reset():
21
  state = env.reset()
22
  return {"state": state.__dict__}
23
 
 
24
  @app.post("/step")
25
  def step():
26
  action = agent.act(env.state(), env.posts)
@@ -30,4 +34,14 @@ def step():
30
  "state": state.__dict__,
31
  "reward": reward,
32
  "done": done
33
- }
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  import json
3
+ import uvicorn
4
 
5
  from env.feed_env import FeedRankingEnv
6
  from agents.random_agent import RandomAgent
 
13
  env = FeedRankingEnv(posts, task="hard")
14
  agent = RandomAgent()
15
 
16
+
17
  @app.get("/")
18
  def root():
19
  return {"message": "OpenEnv server running"}
20
 
21
+
22
  @app.post("/reset")
23
  def reset():
24
  state = env.reset()
25
  return {"state": state.__dict__}
26
 
27
+
28
  @app.post("/step")
29
  def step():
30
  action = agent.act(env.state(), env.posts)
 
34
  "state": state.__dict__,
35
  "reward": reward,
36
  "done": done
37
+ }
38
+
39
+
40
+ # 🔥 REQUIRED MAIN FUNCTION
41
+ def main():
42
+ uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
43
+
44
+
45
+ # 🔥 REQUIRED ENTRY POINT
46
+ if __name__ == "__main__":
47
+ main()