AdithyaSK HF Staff committed on
Commit
1a3a8ee
·
verified ·
1 Parent(s): d81f3f0

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. server/app.py +20 -11
  2. server/opencode_environment.py +75 -83
server/app.py CHANGED
@@ -75,20 +75,29 @@ app = create_app(
75
 
76
 
77
  def _find_active_environment(request):
78
- """Locate the currently-active OpenCodeEnvironment instance for a request.
79
 
80
- ``create_app`` keeps per-session envs behind the scenes; for the SSE
81
- endpoint we just grab the most recent one (single-worker Space), so
82
- we poke at ``app.state.env_cache`` and fall back to ``web_manager``.
 
 
83
  """
84
- cache = getattr(app.state, "env_cache", None)
85
- if cache:
86
- try:
87
- return next(iter(cache.values()))
88
- except StopIteration:
89
- pass
 
 
 
 
 
 
 
90
  try:
91
- return _web_manager.get_environment() # type: ignore[name-defined]
92
  except Exception:
93
  return None
94
 
 
75
 
76
 
77
  def _find_active_environment(request):
78
+ """Locate a currently-active OpenCodeEnvironment instance.
79
 
80
+ ``create_app`` stores per-session envs internally; we don't have a
81
+ public accessor, so we poke at ``app.state`` attributes that match
82
+ OpenEnv's conventions. As a last resort we create a fresh env —
83
+ fine for single-worker Spaces because registries live in-process
84
+ and the default env is idle until a tool is called.
85
  """
86
+ # Most recent "env" attribute on app.state that looks like ours.
87
+ for attr_name in ("env_cache", "envs", "environments", "_envs"):
88
+ cache = getattr(app.state, attr_name, None)
89
+ if cache:
90
+ try:
91
+ if isinstance(cache, dict):
92
+ return next(iter(cache.values()))
93
+ if isinstance(cache, (list, tuple)):
94
+ return cache[-1]
95
+ except Exception:
96
+ pass
97
+ # Fallback — make a new env. Safe because the SSE endpoint only
98
+ # needs the _registry dict, which we then look up rollout_id in.
99
  try:
100
+ return OpenCodeEnvironment()
101
  except Exception:
102
  return None
103
 
server/opencode_environment.py CHANGED
@@ -538,74 +538,7 @@ class OpenCodeEnvironment(MCPEnvironment):
538
 
539
  return result.model_dump_json()
540
 
541
-
542
- # ── Helpers ─────────────────────────────────────────────────────────────────
543
-
544
-
545
- def _qualify_model(provider: str, model: str) -> str:
546
- """Return a ``<provider>/<model>`` string the primitive can split cleanly.
547
-
548
- The primitive splits ``config.model`` on the first ``/`` to recover the
549
- upstream model id. If the caller passes a model that already contains a
550
- slash (e.g. ``Qwen/Qwen3.5-4B``), we still prepend the provider so the
551
- split separates provider from model and the model part round-trips
552
- intact (``openai_compatible/Qwen/Qwen3.5-4B`` → upstream ``Qwen/Qwen3.5-4B``).
553
- """
554
- # Strip an existing <provider>/ prefix only if it matches the configured
555
- # provider verbatim — otherwise treat the whole string as the model id.
556
- if model.startswith(provider + "/"):
557
- return model
558
- return f"{provider}/{model}"
559
-
560
-
561
- def _read_reward(sandbox: Any, reward_path: str) -> Optional[float]:
562
- try:
563
- raw = sandbox.read_text(reward_path).strip()
564
- except Exception:
565
- return None
566
- if not raw:
567
- return None
568
- try:
569
- return float(raw)
570
- except ValueError:
571
- return None
572
-
573
-
574
- def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]:
575
- """Clamp per-turn payload sizes to keep responses under a reasonable cap."""
576
- out = dict(turn)
577
- raw_response = out.get("response") or {}
578
- choices = raw_response.get("choices") or []
579
- first_choice = choices[0] if choices else {}
580
- compact: dict[str, Any] = {
581
- "finish_reason": first_choice.get("finish_reason"),
582
- "usage": raw_response.get("usage"),
583
- }
584
- # Surface upstream errors captured by the proxy so they reach the client.
585
- if raw_response.get("upstream_error") is not None:
586
- compact["upstream_error"] = raw_response["upstream_error"]
587
- if raw_response.get("upstream_status") is not None:
588
- compact["upstream_status"] = raw_response["upstream_status"]
589
- out["response"] = compact
590
- req = out.get("request") or {}
591
- messages = req.get("messages") or []
592
- # Keep request messages (trainer needs them) but drop very long tool schemas.
593
- req = {
594
- "model": req.get("model"),
595
- "messages": messages,
596
- "temperature": req.get("temperature"),
597
- "top_p": req.get("top_p"),
598
- "max_tokens": req.get("max_tokens"),
599
- "max_completion_tokens": req.get("max_completion_tokens"),
600
- "logprobs": req.get("logprobs"),
601
- "top_logprobs": req.get("top_logprobs"),
602
- "stream": req.get("stream"),
603
- }
604
- out["request"] = req
605
- return out
606
-
607
-
608
- # ── Async rollout plumbing ─────────────────────────────────────────────
609
 
610
  def _spawn_async_rollout(
611
  self,
@@ -627,7 +560,6 @@ def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]:
627
  ) -> _RolloutHandle:
628
  from opencode_env import OpenCodeTask
629
 
630
- # Build the task payload up-front; staging happens on the worker.
631
  merged_uploads = dict(upload_files)
632
  if test_script:
633
  merged_uploads[REMOTE_TEST_PATH] = test_script
@@ -652,11 +584,8 @@ def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]:
652
  handle = _RolloutHandle(
653
  rollout_id=rollout_id,
654
  task_id=task_id,
655
- session_factory_kwargs={
656
- "config": config,
657
- "mode": mode,
658
- "agent_timeout_s": agent_timeout_s,
659
- },
660
  task=task,
661
  )
662
  handle._test_script = test_script
@@ -669,12 +598,9 @@ def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]:
669
  sandbox_backend=self._E2BSandboxBackend(),
670
  mode=mode,
671
  verifier=None,
672
- driver="serve", # Phase 2b path
673
  )
674
  handle.session = factory.create(task=task)
675
- # Block until the agent idles. The caller can abort via
676
- # ``abort_rollout`` any time; that triggers the serve
677
- # ``/abort`` endpoint and ``wait_for_completion`` returns.
678
  try:
679
  handle.session.wait_for_completion(timeout_s=agent_timeout_s)
680
  except Exception as exc: # noqa: BLE001
@@ -691,17 +617,17 @@ def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]:
691
  return handle
692
 
693
  def _finalize_handle(self, handle: _RolloutHandle) -> str:
694
- """Run the verifier (if test_script present), collect the trace, and
695
- return a JSON-serialized :class:`RolloutResult` matching the shape
696
- returned by ``run_rollout``. Closes the session + sandbox."""
697
- result = self._result_cls(task_id=handle.task_id, mode=handle._kwargs.get("mode", ""))
698
  session = handle.session
699
  if session is None:
700
  result.error = handle.error or "session never created"
701
  return result.model_dump_json()
702
 
703
  result.sandbox_id = session.sandbox.sandbox_id
704
- result.exit_code = 0 # serve-driver has no exit code; use 0 unless aborted
705
  wall_s = (handle.finished_at or time.time()) - handle.started_at
706
  result.wall_s = round(wall_s, 3)
707
 
@@ -749,6 +675,72 @@ def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]:
749
  return result.model_dump_json()
750
 
751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
  def _tail(events: list[dict[str, Any]], n: int) -> str:
753
  """Return the last ``n`` opencode event lines as a newline-joined string."""
754
  if not events:
 
538
 
539
  return result.model_dump_json()
540
 
541
+ # ── Async rollout plumbing (Phase 2b) ────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
 
543
  def _spawn_async_rollout(
544
  self,
 
560
  ) -> _RolloutHandle:
561
  from opencode_env import OpenCodeTask
562
 
 
563
  merged_uploads = dict(upload_files)
564
  if test_script:
565
  merged_uploads[REMOTE_TEST_PATH] = test_script
 
584
  handle = _RolloutHandle(
585
  rollout_id=rollout_id,
586
  task_id=task_id,
587
+ session_factory_kwargs={"config": config, "mode": mode,
588
+ "agent_timeout_s": agent_timeout_s},
 
 
 
589
  task=task,
590
  )
591
  handle._test_script = test_script
 
598
  sandbox_backend=self._E2BSandboxBackend(),
599
  mode=mode,
600
  verifier=None,
601
+ driver="serve",
602
  )
603
  handle.session = factory.create(task=task)
 
 
 
604
  try:
605
  handle.session.wait_for_completion(timeout_s=agent_timeout_s)
606
  except Exception as exc: # noqa: BLE001
 
617
  return handle
618
 
619
  def _finalize_handle(self, handle: _RolloutHandle) -> str:
620
+ """Run the verifier (if present), collect the trace + workdir, and
621
+ return a JSON-serialized :class:`RolloutResult`. Closes the session."""
622
+ result = self._result_cls(task_id=handle.task_id,
623
+ mode=handle._kwargs.get("mode", ""))
624
  session = handle.session
625
  if session is None:
626
  result.error = handle.error or "session never created"
627
  return result.model_dump_json()
628
 
629
  result.sandbox_id = session.sandbox.sandbox_id
630
+ result.exit_code = 0
631
  wall_s = (handle.finished_at or time.time()) - handle.started_at
632
  result.wall_s = round(wall_s, 3)
633
 
 
675
  return result.model_dump_json()
676
 
677
 
678
+ # ── Helpers ─────────────────────────────────────────────────────────────────
679
+
680
+
681
+ def _qualify_model(provider: str, model: str) -> str:
682
+ """Return a ``<provider>/<model>`` string the primitive can split cleanly.
683
+
684
+ The primitive splits ``config.model`` on the first ``/`` to recover the
685
+ upstream model id. If the caller passes a model that already contains a
686
+ slash (e.g. ``Qwen/Qwen3.5-4B``), we still prepend the provider so the
687
+ split separates provider from model and the model part round-trips
688
+ intact (``openai_compatible/Qwen/Qwen3.5-4B`` → upstream ``Qwen/Qwen3.5-4B``).
689
+ """
690
+ # Strip an existing <provider>/ prefix only if it matches the configured
691
+ # provider verbatim — otherwise treat the whole string as the model id.
692
+ if model.startswith(provider + "/"):
693
+ return model
694
+ return f"{provider}/{model}"
695
+
696
+
697
+ def _read_reward(sandbox: Any, reward_path: str) -> Optional[float]:
698
+ try:
699
+ raw = sandbox.read_text(reward_path).strip()
700
+ except Exception:
701
+ return None
702
+ if not raw:
703
+ return None
704
+ try:
705
+ return float(raw)
706
+ except ValueError:
707
+ return None
708
+
709
+
710
+ def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]:
711
+ """Clamp per-turn payload sizes to keep responses under a reasonable cap."""
712
+ out = dict(turn)
713
+ raw_response = out.get("response") or {}
714
+ choices = raw_response.get("choices") or []
715
+ first_choice = choices[0] if choices else {}
716
+ compact: dict[str, Any] = {
717
+ "finish_reason": first_choice.get("finish_reason"),
718
+ "usage": raw_response.get("usage"),
719
+ }
720
+ # Surface upstream errors captured by the proxy so they reach the client.
721
+ if raw_response.get("upstream_error") is not None:
722
+ compact["upstream_error"] = raw_response["upstream_error"]
723
+ if raw_response.get("upstream_status") is not None:
724
+ compact["upstream_status"] = raw_response["upstream_status"]
725
+ out["response"] = compact
726
+ req = out.get("request") or {}
727
+ messages = req.get("messages") or []
728
+ # Keep request messages (trainer needs them) but drop very long tool schemas.
729
+ req = {
730
+ "model": req.get("model"),
731
+ "messages": messages,
732
+ "temperature": req.get("temperature"),
733
+ "top_p": req.get("top_p"),
734
+ "max_tokens": req.get("max_tokens"),
735
+ "max_completion_tokens": req.get("max_completion_tokens"),
736
+ "logprobs": req.get("logprobs"),
737
+ "top_logprobs": req.get("top_logprobs"),
738
+ "stream": req.get("stream"),
739
+ }
740
+ out["request"] = req
741
+ return out
742
+
743
+
744
  def _tail(events: list[dict[str, Any]], n: int) -> str:
745
  """Return the last ``n`` opencode event lines as a newline-joined string."""
746
  if not events: