srinjoyd committed on
Commit
e205960
·
1 Parent(s): 7f2deb2
Files changed (1) hide show
  1. training/grpo_train.py +50 -12
training/grpo_train.py CHANGED
@@ -26,11 +26,17 @@ Usage (local, from repo root):
26
  HF Jobs (`hf jobs uv run`) often uploads **only the script file** to `/`,
27
  so `server/` is missing → `ModuleNotFoundError: server`. Fix by either:
28
 
29
- A) Clone the full repo inside the job, then run from that directory, e.g.:
30
- hf jobs uv run --flavor h200 ... -- bash -lc '
 
 
 
 
 
 
31
  git clone https://github.com/<you>/scaler-hackathon.git /tmp/repo &&
32
- cd /tmp/repo &&
33
- python training/grpo_train.py --model ... ...
34
  '
35
 
36
  B) Set an explicit root (if your job packs the tree elsewhere):
@@ -103,14 +109,46 @@ def _find_repo_root() -> Path:
103
  )
104
 
105
  _REPO = _find_repo_root()
106
- if str(_REPO) not in sys.path:
107
- sys.path.insert(0, str(_REPO))
108
-
109
- from server.incident_environment import IncidentEnvironment # noqa: E402
110
- from tasks import compute_r_cross # noqa: E402
111
- from pools import POOLS, sample_task # noqa: E402
112
- from training.curriculum import CurriculumConfig, CurriculumRunner # noqa: E402
113
- from training.segment_grpo import GRPOGroup, Segment, grpo_advantages # noqa: E402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  # ──────────────────────────────────────────────────────────────────────
 
26
  HF Jobs (`hf jobs uv run`) often uploads **only the script file** to `/`,
27
  so `server/` is missing → `ModuleNotFoundError: server`. Fix by either:
28
 
29
+ A) Clone the full repo inside the job, then run from that directory.
30
+
31
+ **Trap:** ``bash -lc`` often resets ``PATH`` so ``python`` is *not* the
32
+ same interpreter that ``hf jobs uv run`` installed ``torch`` into →
33
+ ``ModuleNotFoundError: torch``. Prefer ``bash -ec`` (no login) **or**
34
+ nest ``uv run`` after ``cd`` so deps apply to the training process:
35
+
36
+ hf jobs uv run --flavor h200 ... -- bash -ec '
37
  git clone https://github.com/<you>/scaler-hackathon.git /tmp/repo &&
38
+ cd /tmp/repo && git checkout <branch> &&
39
+ uv run --no-project --with torch --with transformers --with accelerate --with peft --with bitsandbytes --with tqdm --with fastapi --with uvicorn --with pydantic python training/grpo_train.py --model ... ...
40
  '
41
 
42
  B) Set an explicit root (if your job packs the tree elsewhere):
 
109
  )
110
 
111
  _REPO = _find_repo_root()
112
+
113
+
114
+ def _register_incident_env_pkg(repo: Path) -> None:
115
+ """
116
+ The repo is laid out as a *single* installable tree (``models.py``,
117
+ ``server/``, ``scenarios/``, …) but **without** a physical ``incident_env/``
118
+ directory. Subpackages use relative imports (e.g. ``from ..models`` in
119
+ ``server/``), so they must be loaded as ``incident_env.server``, not as a
120
+ bare top-level ``server`` (which breaks ``..``).
121
+
122
+ Register a synthetic parent package ``incident_env`` whose ``__path__`` is
123
+ the repository root. Then import only ``incident_env.*`` below.
124
+ """
125
+ import importlib.machinery
126
+ import types
127
+
128
+ root = repo.resolve()
129
+ root_s = str(root)
130
+ name = "incident_env"
131
+
132
+ existing = sys.modules.get(name)
133
+ if existing is not None and getattr(existing, "__path__", None):
134
+ return
135
+
136
+ pkg = types.ModuleType(name)
137
+ pkg.__path__ = [root_s]
138
+ spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
139
+ spec.submodule_search_locations = [root_s]
140
+ pkg.__spec__ = spec
141
+ pkg.__package__ = name
142
+ sys.modules[name] = pkg
143
+
144
+
145
+ _register_incident_env_pkg(_REPO)
146
+
147
+ from incident_env.server.incident_environment import IncidentEnvironment # noqa: E402
148
+ from incident_env.tasks import compute_r_cross # noqa: E402
149
+ from incident_env.pools import POOLS, sample_task # noqa: E402
150
+ from incident_env.training.curriculum import CurriculumConfig, CurriculumRunner # noqa: E402
151
+ from incident_env.training.segment_grpo import GRPOGroup, Segment, grpo_advantages # noqa: E402
152
 
153
 
154
  # ──────────────────────────────────────────────────────────────────────