Sagar Chapara committed on
Commit
f8a321a
·
1 Parent(s): d1221ff

Polish Space startup and root route

Browse files
Files changed (3) hide show
  1. Dockerfile +8 -8
  2. server/app.py +17 -0
  3. server/environment.py +13 -4
Dockerfile CHANGED
@@ -26,14 +26,14 @@ COPY inference.py .
26
 
27
  # Pre-download datasets at build time to avoid cold-start delays
28
  # (Falls back to hardcoded samples if download fails)
29
- RUN python -c "\
30
- try:\
31
- from datasets import load_dataset;\
32
- load_dataset('rajpurkar/squad', split='validation[:200]');\
33
- print('SQuAD cached.');\
34
- except Exception as e:\
35
- print(f'SQuAD cache skipped: {e}');\
36
- " || true
37
 
38
  # Hugging Face Spaces commonly uses 7860, while local OpenEnv docker providers
39
  # often inject PORT=8000. Support both.
 
26
 
27
  # Pre-download datasets at build time to avoid cold-start delays
28
  # (Falls back to hardcoded samples if download fails)
29
+ RUN python - <<'PY' || true
30
+ try:
31
+ from datasets import load_dataset
32
+ load_dataset("rajpurkar/squad", split="validation[:200]")
33
+ print("SQuAD cached.")
34
+ except Exception as e:
35
+ print(f"SQuAD cache skipped: {e}")
36
+ PY
37
 
38
  # Hugging Face Spaces commonly uses 7860, while local OpenEnv docker providers
39
  # often inject PORT=8000. Support both.
server/app.py CHANGED
@@ -26,6 +26,23 @@ app = create_fastapi_app(
26
  )
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def main() -> None:
30
  """Run the environment server for local validation and script entrypoints."""
31
  port = int(os.environ.get("PORT", "7860"))
 
26
  )
27
 
28
 
29
+ @app.get("/")
30
+ def root() -> dict:
31
+ """Friendly landing page for Spaces and browser visits."""
32
+ return {
33
+ "name": "Long-Context Summarization",
34
+ "status": "healthy",
35
+ "docs": {
36
+ "health": "/health",
37
+ "schema": "/schema",
38
+ "metadata": "/metadata",
39
+ "reset": "POST /reset",
40
+ "step": "POST /step",
41
+ "state": "GET /state",
42
+ },
43
+ }
44
+
45
+
46
  def main() -> None:
47
  """Run the environment server for local validation and script entrypoints."""
48
  port = int(os.environ.get("PORT", "7860"))
server/environment.py CHANGED
@@ -38,8 +38,8 @@ class SummarizationEnvironment(Environment):
38
  SUPPORTS_CONCURRENT_SESSIONS = False
39
 
40
  def __init__(self):
41
- logger.info("Initialising SummarizationEnvironment — loading datasets...")
42
- self._tasks = {name: get_task(name) for name in ("easy", "medium", "hard")}
43
  self._reset_episode_state()
44
  logger.info("Environment ready.")
45
 
@@ -61,6 +61,15 @@ class SummarizationEnvironment(Environment):
61
  # Hard task only: second chunk shown after first summary
62
  self._hard_chunk2: Optional[str] = None
63
 
 
 
 
 
 
 
 
 
 
64
  # ------------------------------------------------------------------
65
  # OpenEnv API
66
  # ------------------------------------------------------------------
@@ -93,7 +102,7 @@ class SummarizationEnvironment(Environment):
93
  self._episode_id = episode_id or f"ep_{random.randint(10000, 99999)}"
94
 
95
  rng_seed = seed
96
- task = self._tasks[task_name]
97
  sample = task.get_sample(seed=rng_seed)
98
 
99
  # Store episode data
@@ -129,7 +138,7 @@ class SummarizationEnvironment(Environment):
129
  # Append model response to conversation history
130
  self._messages.append({"role": "assistant", "content": response})
131
 
132
- task = self._tasks[self._task_name]
133
 
134
  # ── Summarize step ─────────────────────────────────────────────
135
  if self._step_type == "summarize":
 
38
  SUPPORTS_CONCURRENT_SESSIONS = False
39
 
40
  def __init__(self):
41
+ logger.info("Initialising SummarizationEnvironment...")
42
+ self._tasks: Dict[str, Any] = {}
43
  self._reset_episode_state()
44
  logger.info("Environment ready.")
45
 
 
61
  # Hard task only: second chunk shown after first summary
62
  self._hard_chunk2: Optional[str] = None
63
 
64
+ def _get_task(self, task_name: str):
65
+ """Lazily initialize tasks so app startup stays fast on Spaces."""
66
+ task = self._tasks.get(task_name)
67
+ if task is None:
68
+ logger.info("Loading task '%s'...", task_name)
69
+ task = get_task(task_name)
70
+ self._tasks[task_name] = task
71
+ return task
72
+
73
  # ------------------------------------------------------------------
74
  # OpenEnv API
75
  # ------------------------------------------------------------------
 
102
  self._episode_id = episode_id or f"ep_{random.randint(10000, 99999)}"
103
 
104
  rng_seed = seed
105
+ task = self._get_task(task_name)
106
  sample = task.get_sample(seed=rng_seed)
107
 
108
  # Store episode data
 
138
  # Append model response to conversation history
139
  self._messages.append({"role": "assistant", "content": response})
140
 
141
+ task = self._get_task(self._task_name)
142
 
143
  # ── Summarize step ─────────────────────────────────────────────
144
  if self._step_type == "summarize":