"""HF Sandbox 入口(docker SDK,监听 7860)。 启动后: 1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log) 2. 主进程开 HTTP server on :7860,返回最新日志 阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。 阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。 """ import os import subprocess import sys import threading from http.server import BaseHTTPRequestHandler, HTTPServer LOG_PATH = "/tmp/wjad.log" PORT = 7860 # 当 SANDBOX_MODE=real_data 时跑真实标签 + 占位视频;否则跑随机张量 smoke。 _MODE = os.environ.get("SANDBOX_MODE", "smoke") if _MODE == "real_data": LAUNCH_CMD = [sys.executable, "scripts/sandbox_real_data.py"] else: LAUNCH_CMD = [sys.executable, "scripts/smoke_train.py"] def _print_env(f): f.write("=" * 72 + "\n") f.write(" WJAD HF Sandbox\n") f.write("=" * 72 + "\n") f.write(f"Python: {sys.version}\n") try: import torch f.write(f"torch: {torch.__version__} cuda_avail={torch.cuda.is_available()}\n") if torch.cuda.is_available(): p = torch.cuda.get_device_properties(0) f.write(f"device: {p.name} vram={p.total_memory / 1024**3:.2f} GB\n") except Exception as e: f.write(f"torch import failed: {e}\n") f.flush() def run_training(): with open(LOG_PATH, "w", buffering=1) as f: _print_env(f) f.write(f"$ {' '.join(LAUNCH_CMD)}\n") f.flush() p = subprocess.Popen( LAUNCH_CMD, stdout=f, stderr=subprocess.STDOUT, cwd="/app" ) rc = p.wait() f.write(f"\n[exit code = {rc}]\n") class Handler(BaseHTTPRequestHandler): def do_GET(self): try: with open(LOG_PATH, "r") as f: body = f.read() except FileNotFoundError: body = "starting..." self.send_response(200) self.send_header("Content-Type", "text/plain; charset=utf-8") self.end_headers() self.wfile.write(body.encode("utf-8")) def log_message(self, fmt, *args): return if __name__ == "__main__": threading.Thread(target=run_training, daemon=True).start() HTTPServer(("0.0.0.0", PORT), Handler).serve_forever()