| """HF Sandbox 入口(docker SDK,监听 7860)。
|
|
|
| 启动后:
|
| 1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log)
|
| 2. 主进程开 HTTP server on :7860,返回最新日志
|
|
|
| 阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。
|
| 阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。
|
| """
|
| import os
|
| import subprocess
|
| import sys
|
| import threading
|
| from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
|
| LOG_PATH = "/tmp/wjad.log"
|
| PORT = 7860
|
|
|
| _MODE = os.environ.get("SANDBOX_MODE", "smoke")
|
| if _MODE == "real_data":
|
| LAUNCH_CMD = [sys.executable, "scripts/sandbox_real_data.py"]
|
| else:
|
| LAUNCH_CMD = [sys.executable, "scripts/smoke_train.py"]
|
|
|
|
|
| def _print_env(f):
|
| f.write("=" * 72 + "\n")
|
| f.write(" WJAD HF Sandbox\n")
|
| f.write("=" * 72 + "\n")
|
| f.write(f"Python: {sys.version}\n")
|
| try:
|
| import torch
|
| f.write(f"torch: {torch.__version__} cuda_avail={torch.cuda.is_available()}\n")
|
| if torch.cuda.is_available():
|
| p = torch.cuda.get_device_properties(0)
|
| f.write(f"device: {p.name} vram={p.total_memory / 1024**3:.2f} GB\n")
|
| except Exception as e:
|
| f.write(f"torch import failed: {e}\n")
|
| f.flush()
|
|
|
|
|
def run_training():
    """Run the configured training command, streaming all output into LOG_PATH.

    Blocks until the child process exits, then records its exit code.
    """
    # buffering=1 → line-buffered, so the HTTP handler sees fresh lines promptly.
    with open(LOG_PATH, "w", buffering=1) as log:
        _print_env(log)
        log.write(f"$ {' '.join(LAUNCH_CMD)}\n")
        log.flush()
        proc = subprocess.Popen(
            LAUNCH_CMD, stdout=log, stderr=subprocess.STDOUT, cwd="/app"
        )
        code = proc.wait()
        log.write(f"\n[exit code = {code}]\n")
|
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """Serves the current contents of LOG_PATH as plain text on every GET."""

    def do_GET(self):
        # EAFP: the log file may not exist yet during the first moments after boot.
        try:
            with open(LOG_PATH, "r") as f:
                payload = f.read()
        except FileNotFoundError:
            payload = "starting..."
        self.send_response(200)
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        self.end_headers()
        self.wfile.write(payload.encode("utf-8"))

    def log_message(self, fmt, *args):
        # Suppress per-request access logging; the log endpoint is polled often.
        return
|
|
|
|
|
| if __name__ == "__main__":
|
| threading.Thread(target=run_training, daemon=True).start()
|
| HTTPServer(("0.0.0.0", PORT), Handler).serve_forever()
|
|
|