Spaces:
Paused
Paused
Commit Β·
812dbcc
1
Parent(s): b71b25c
perf: add encode+step benchmark to logs, validate CUDA speed
Browse files
app.py
CHANGED
|
@@ -164,9 +164,32 @@ def _training_thread():
|
|
| 164 |
)
|
| 165 |
obs, info = env.reset()
|
| 166 |
env.step(env.action_space.sample())
|
| 167 |
-
env.close()
|
| 168 |
_log(f"Smoke test OK β obs shape {obs.shape}")
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
# ββ Training ββββββββββββββββββββββββββββββββββββββββ
|
| 171 |
import torch, yaml
|
| 172 |
from sb3_contrib import RecurrentPPO
|
|
|
|
| 164 |
)
|
| 165 |
obs, info = env.reset()
|
| 166 |
env.step(env.action_space.sample())
|
|
|
|
| 167 |
_log(f"Smoke test OK β obs shape {obs.shape}")
|
| 168 |
|
| 169 |
+
# ββ Benchmark: encode speed and full step speed ββββββ
|
| 170 |
+
_log("Benchmarking SentenceTransformer encode speed...")
|
| 171 |
+
_N_enc = 50
|
| 172 |
+
_t0 = time.perf_counter()
|
| 173 |
+
for _ in range(_N_enc):
|
| 174 |
+
env.registry.embed_query("Software engineering task requiring specialist delegation")
|
| 175 |
+
_enc_ms = (time.perf_counter() - _t0) / _N_enc * 1000
|
| 176 |
+
_enc_device = "CUDA β fast" if _enc_ms < 50 else "CPU β slow, patch may have failed"
|
| 177 |
+
_log(f"Encode speed : {_enc_ms:.1f} ms/call [{_enc_device}]")
|
| 178 |
+
|
| 179 |
+
_log("Benchmarking full env.step() speed...")
|
| 180 |
+
_N_steps = 30
|
| 181 |
+
obs_b, _ = env.reset()
|
| 182 |
+
_t0 = time.perf_counter()
|
| 183 |
+
for _ in range(_N_steps):
|
| 184 |
+
obs_b, _, _d, _, _ = env.step(env.action_space.sample())
|
| 185 |
+
if _d:
|
| 186 |
+
obs_b, _ = env.reset()
|
| 187 |
+
_step_ms = (time.perf_counter() - _t0) / _N_steps * 1000
|
| 188 |
+
_step_ok = "fast β" if _step_ms < 100 else "slow β check logs"
|
| 189 |
+
_log(f"Step speed : {_step_ms:.1f} ms/step [{_step_ok}]")
|
| 190 |
+
_log(f"Projected 100k steps: {100_000 * _step_ms / 1000 / 60:.0f} min")
|
| 191 |
+
env.close()
|
| 192 |
+
|
| 193 |
# ββ Training ββββββββββββββββββββββββββββββββββββββββ
|
| 194 |
import torch, yaml
|
| 195 |
from sb3_contrib import RecurrentPPO
|