AlazarM committed on
Commit
af30aee
·
verified ·
1 Parent(s): 77e4330

Deploy training space template for oversight

Browse files
Files changed (4) hide show
  1. Dockerfile +19 -0
  2. README.md +8 -5
  3. __pycache__/run_space.cpython-313.pyc +0 -0
  4. run_space.py +223 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Image for the training Space: runs run_space.py, which serves the status
# page on $PORT and launches the training job.
FROM ghcr.io/astral-sh/uv:python3.12-bookworm

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    TRL_EXPERIMENTAL_SILENCE=1 \
    PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
    PORT=7860

# git is needed at runtime because run_space.py clones the training repo.
RUN apt-get update \
    && apt-get install -y --no-install-recommends git curl ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# run_space.py imports huggingface_hub in its own (system) interpreter when it
# uploads the checkpoint, but only installs it into the per-run virtualenv.
# Install it here so the upload step does not fail with ModuleNotFoundError.
RUN uv pip install --system --no-cache huggingface_hub

WORKDIR /app

COPY run_space.py /app/run_space.py

EXPOSE 7860

CMD ["python", "/app/run_space.py"]
README.md CHANGED
@@ -1,10 +1,13 @@
1
  ---
2
- title: Trenches Train Oversight
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: yellow
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
  ---
2
+ title: Trenches Entity Training
3
+ emoji: ⚙️
4
+ colorFrom: red
5
+ colorTo: gray
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ # Trenches Entity Training
12
+
13
+ This Space runs one entity-specific Trenches training job and serves a simple status page on port `7860`.
__pycache__/run_space.cpython-313.pyc ADDED
Binary file (12 kB). View file
 
run_space.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import html
5
+ import os
6
+ import shutil
7
+ import subprocess
8
+ import sys
9
+ import tempfile
10
+ import threading
11
+ import time
12
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
13
+ from pathlib import Path
14
+
15
+
16
# Port the status HTTP server binds to; overridable through the PORT env var.
PORT = int(os.getenv("PORT", "7860"))

# File that accumulates command output and is rendered on the status page.
LOG_PATH = Path("/tmp") / "trenches-space.log"

# Shared job status shown on the status page; guard every access with LOCK.
STATUS = dict(state="starting", summary="Initializing training space")

# Serializes reads and writes of STATUS between the training and HTTP threads.
LOCK = threading.Lock()
23
+
24
+
25
def set_status(state: str, summary: str) -> None:
    """Record the job's current state and a one-line summary.

    Both fields of the shared STATUS mapping are replaced while holding LOCK,
    so a concurrent reader never sees a state paired with a stale summary.
    """
    with LOCK:
        STATUS.update(state=state, summary=summary)
29
+
30
+
31
def append_log(line: str) -> None:
    """Append *line* to the log file at LOG_PATH.

    A trailing newline is added when *line* does not already end with one, so
    each call leaves the file ending on a line boundary.
    """
    terminated = line if line.endswith("\n") else f"{line}\n"
    with LOG_PATH.open("a", encoding="utf-8") as log_file:
        log_file.write(terminated)
36
+
37
+
38
def run_and_stream(command: list[str], *, cwd: Path | None = None, env: dict[str, str] | None = None) -> None:
    """Run *command*, mirroring its combined output to stdout and the log file.

    Args:
        command: Program and arguments, executed without a shell.
        cwd: Working directory for the child, or None to inherit the current one.
        env: Environment for the child, or None to inherit the current one.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero code.
    """
    append_log(f"$ {' '.join(command)}")
    # The context manager closes the pipe and reaps the child on every exit
    # path, instead of leaving the file descriptor to the garbage collector.
    with subprocess.Popen(
        command,
        cwd=str(cwd) if cwd is not None else None,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr with stdout in one stream
        text=True,
        bufsize=1,  # line-buffered so output appears as the child emits it
    ) as process:
        assert process.stdout is not None  # guaranteed by stdout=PIPE; narrows the type
        try:
            for line in process.stdout:
                sys.stdout.write(line)
                sys.stdout.flush()
                append_log(line.rstrip("\n"))
        except BaseException:
            # Streaming or logging failed: do not leave the child running
            # (Popen.__exit__ would otherwise block waiting for it).
            process.kill()
            raise
        return_code = process.wait()
    if return_code != 0:
        raise subprocess.CalledProcessError(return_code, command)
57
+
58
+
59
def upload_output(output_dir: Path) -> None:
    """Upload the contents of *output_dir* to a Hugging Face model repository.

    Reads HF_TOKEN and MODEL_REPO from the environment (KeyError if either is
    missing). The commit message comes from UPLOAD_MESSAGE when set.
    """
    # Imported lazily so the module can be loaded without huggingface_hub.
    from huggingface_hub import HfApi

    hf_token = os.environ["HF_TOKEN"]
    target_repo = os.environ["MODEL_REPO"]
    client = HfApi(token=hf_token)
    upload_kwargs = {
        "repo_id": target_repo,
        "repo_type": "model",
        "folder_path": str(output_dir),
        "commit_message": os.environ.get("UPLOAD_MESSAGE", "Upload Trenches checkpoint"),
    }
    client.upload_folder(**upload_kwargs)
71
+
72
+
73
def _env_flag(name: str) -> bool:
    """Return True when env var *name* is "1", "true" or "yes" (case-insensitive)."""
    return os.environ.get(name, "").lower() in {"1", "true", "yes"}


def _build_train_command(
    python_bin: Path,
    *,
    entity: str,
    replay_id: str,
    model_id: str,
    generation_backend: str,
    output_dir: Path,
) -> list[str]:
    """Assemble the trenches_env.training_cli invocation from env-var overrides."""
    getenv = os.environ.get
    command = [
        str(python_bin),
        "-m",
        "trenches_env.training_cli",
        "--model-id",
        model_id,
        "--generation-backend",
        generation_backend,
        "--training-agent",
        entity,
        "--training-stage",
        getenv("TRAINING_STAGE", "stage_1_dense"),
        "--replay-id",
        replay_id,
        "--train-size",
        getenv("TRAIN_SIZE", "4"),
        "--max-steps",
        getenv("MAX_STEPS", "1"),
        "--num-generations",
        getenv("NUM_GENERATIONS", "4"),
        "--per-device-train-batch-size",
        getenv("PER_DEVICE_TRAIN_BATCH_SIZE", "1"),
        "--gradient-accumulation-steps",
        getenv("GRADIENT_ACCUMULATION_STEPS", "1"),
        "--learning-rate",
        getenv("LEARNING_RATE", "5e-6"),
        "--beta",
        getenv("BETA", "0.001"),
        "--warmup-steps",
        getenv("WARMUP_STEPS", "0"),
        "--temperature",
        getenv("TEMPERATURE", "0.8"),
        "--top-k",
        getenv("TOP_K", "10"),
        "--top-p",
        getenv("TOP_P", "0.95"),
        "--max-prompt-length",
        getenv("MAX_PROMPT_LENGTH", "1024"),
        "--max-completion-length",
        getenv("MAX_COMPLETION_LENGTH", "128"),
        "--save-strategy",
        getenv("SAVE_STRATEGY", "no"),
        "--output-dir",
        str(output_dir),
        "--no-preview",
    ]
    if _env_flag("QUANTIZE_4BIT"):
        command.append("--quantize-4bit")
    return command


def train() -> None:
    """Clone the training repo, install its stack, train one entity, upload the result.

    Configuration comes from environment variables. ENTITY and REPLAY_ID are
    required here; HF_TOKEN and MODEL_REPO are required by the upload step and
    are checked up front so a misconfigured job fails before training rather
    than after it. Progress and failures are published through set_status and
    append_log.

    Raises:
        KeyError: If a required environment variable is missing.
        subprocess.CalledProcessError: If any git, uv or training command fails.
        Exception: Anything raised by the upload step is re-raised after being
            recorded as a failure.
    """
    workroot: Path | None = None
    try:
        # Everything, including configuration reads, happens inside the try so
        # that a missing variable is reported as "failed" on the status page
        # instead of leaving it stuck on "starting".
        entity = os.environ["ENTITY"]
        replay_id = os.environ["REPLAY_ID"]
        missing = [name for name in ("HF_TOKEN", "MODEL_REPO") if name not in os.environ]
        if missing:
            raise KeyError(", ".join(missing))
        model_id = os.environ.get("MODEL_ID", "Qwen/Qwen3-8B")
        git_repo_url = os.environ.get("GIT_REPO_URL", "https://github.com/shlawgathon/trenches.git")
        git_ref = os.environ.get("GIT_REF", "main")
        generation_backend = os.environ.get("GENERATION_BACKEND", "vllm")

        set_status("running", f"Preparing repo for {entity}")
        workroot = Path(tempfile.mkdtemp(prefix="trenches-space-"))
        repo_dir = workroot / "trenches"
        output_dir = workroot / "output"
        output_dir.mkdir(parents=True, exist_ok=True)

        run_and_stream(["git", "clone", "--depth", "1", git_repo_url, str(repo_dir)])
        if git_ref != "main":
            # The shallow clone only has the default branch tip; fetch the
            # requested ref explicitly and check it out detached.
            run_and_stream(["git", "fetch", "--depth", "1", "origin", git_ref], cwd=repo_dir)
            run_and_stream(["git", "checkout", "-q", "FETCH_HEAD"], cwd=repo_dir)

        venv_dir = workroot / ".venv"
        python_bin = venv_dir / "bin" / "python"
        set_status("running", f"Installing training stack for {entity}")
        run_and_stream(["uv", "venv", str(venv_dir), "--python", "3.12"])
        run_and_stream(
            ["uv", "pip", "install", "--python", str(python_bin), "-e", "backend[train]", "huggingface_hub"],
            cwd=repo_dir,
        )
        run_and_stream(
            ["uv", "pip", "install", "--python", str(python_bin), "trl==0.29.0", "vllm"],
            cwd=repo_dir,
        )

        env = dict(os.environ)
        env["TRL_EXPERIMENTAL_SILENCE"] = "1"
        train_cmd = _build_train_command(
            python_bin,
            entity=entity,
            replay_id=replay_id,
            model_id=model_id,
            generation_backend=generation_backend,
            output_dir=output_dir,
        )

        set_status("running", f"Training {entity}")
        run_and_stream(train_cmd, cwd=repo_dir, env=env)

        set_status("running", f"Uploading checkpoint for {entity}")
        upload_output(output_dir)
        set_status("completed", f"Completed training and upload for {entity}")
    except Exception as exc:
        set_status("failed", f"{type(exc).__name__}: {exc}")
        append_log(f"FAILED: {type(exc).__name__}: {exc}")
        raise
    finally:
        # Remove the scratch directory unless KEEP_WORKROOT asks to keep it;
        # workroot is still None when the failure happened before creating it.
        if workroot is not None and not _env_flag("KEEP_WORKROOT"):
            shutil.rmtree(workroot, ignore_errors=True)
178
+
179
+
180
class Handler(BaseHTTPRequestHandler):
    """Serve one HTML status page (state, summary, log tail) for every GET."""

    @staticmethod
    def _read_log_tail(path: Path, max_chars: int = 30000) -> str:
        """Return at most the last *max_chars* characters of the UTF-8 file at *path*.

        Only a bounded window at the end of the file is read, so the cost of a
        request does not grow with the size of the log. A missing file yields
        an empty string; undecodable bytes are replaced rather than raising.
        """
        if max_chars <= 0:
            return ""
        # A UTF-8 character is at most 4 bytes; the extra 4 bytes absorb a
        # character split by the window start, so the window still holds at
        # least max_chars complete characters whenever the file does.
        window = 4 * max_chars + 4
        try:
            with path.open("rb") as fh:
                size = fh.seek(0, os.SEEK_END)
                fh.seek(max(0, size - window))
                raw = fh.read()
        except FileNotFoundError:
            return ""
        text = raw.decode("utf-8", errors="replace")
        # Match text-mode reading, which normalizes all line endings to "\n".
        text = text.replace("\r\n", "\n").replace("\r", "\n")
        return text[-max_chars:]

    def do_GET(self) -> None:  # noqa: N802
        """Respond with the status page; the request path is not inspected."""
        # Copy both fields under the lock so they belong to the same update.
        with LOCK:
            state = STATUS["state"]
            summary = STATUS["summary"]
        log_text = self._read_log_tail(LOG_PATH)
        # Every dynamic value is HTML-escaped; state doubles as the CSS class.
        body = f"""<!doctype html>
<html>
<head>
<meta charset="utf-8">
<title>Trenches Training Space</title>
<style>
body {{ background: #111; color: #eee; font-family: monospace; padding: 24px; }}
.running {{ color: #ffd166; }}
.completed {{ color: #06d6a0; }}
.failed {{ color: #ef476f; }}
pre {{ white-space: pre-wrap; word-break: break-word; background: #181818; padding: 16px; border-radius: 8px; }}
</style>
</head>
<body>
<h1>Trenches Training Space</h1>
<p>Status: <span class="{html.escape(state)}">{html.escape(state)}</span></p>
<p>{html.escape(summary)}</p>
<pre>{html.escape(log_text)}</pre>
</body>
</html>"""
        payload = body.encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)
212
+
213
+
214
def main() -> None:
    """Reset the log, start training in the background, and serve the status page.

    Training runs on a daemon thread, so it does not keep the process alive on
    its own; the HTTP server loop blocks the main thread.
    """
    LOG_PATH.write_text("", encoding="utf-8")  # start each run with an empty log
    threading.Thread(target=train, daemon=True).start()
    bind_address = ("0.0.0.0", PORT)
    ThreadingHTTPServer(bind_address, Handler).serve_forever()


if __name__ == "__main__":
    main()