File size: 2,394 Bytes
2f30c49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""HF Sandbox 入口(docker SDK,监听 7860)。



启动后:

  1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log)

  2. 主进程开 HTTP server on :7860,返回最新日志



阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。

阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。

"""
import os
import subprocess
import sys
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer

# File that receives the training output and is served back over HTTP.
LOG_PATH = "/tmp/wjad.log"
# TCP port the HTTP server binds to.
PORT = 7860
# SANDBOX_MODE=real_data runs real labels with placeholder videos; any other
# value (default "smoke") runs the random-tensor smoke test.
_MODE = os.environ.get("SANDBOX_MODE", "smoke")
LAUNCH_CMD = [
    sys.executable,
    "scripts/sandbox_real_data.py" if _MODE == "real_data" else "scripts/smoke_train.py",
]


def _print_env(f):
    """Write a banner and runtime details to the writable text stream ``f``.

    The banner and the Python version are written first. The function then
    tries to import torch and reports its version and CUDA availability; when
    CUDA is available it also reports the name and total memory (in GB) of
    device 0. Any exception raised in the torch section is written to ``f``
    as a ``torch import failed`` line instead of propagating. The stream is
    flushed before returning.
    """
    rule = "=" * 72
    # Emit the banner before touching torch so it is on the stream even if
    # the torch section below is slow or fails.
    f.write(f"{rule}\n WJAD HF Sandbox\n{rule}\nPython: {sys.version}\n")
    try:
        import torch

        has_cuda = torch.cuda.is_available()
        f.write(f"torch: {torch.__version__} cuda_avail={has_cuda}\n")
        if has_cuda:
            props = torch.cuda.get_device_properties(0)
            vram_gb = props.total_memory / 1024**3
            f.write(f"device: {props.name}  vram={vram_gb:.2f} GB\n")
    except Exception as err:
        # Best-effort diagnostics: record the problem and keep going.
        f.write(f"torch import failed: {err}\n")
    f.flush()


def run_training():
    """Run ``LAUNCH_CMD`` as a child process and record everything in ``LOG_PATH``.

    The log file is truncated, the environment banner from ``_print_env`` and
    the command line are written, and the child's stdout and stderr are both
    redirected into the same file. The call blocks until the child exits and
    then appends its exit code. If the child cannot be started at all (for
    example the interpreter or the working directory is missing), the error
    is written to the log and the function returns instead of raising.
    """
    # Explicit UTF-8 so the log does not depend on the container's locale.
    with open(LOG_PATH, "w", buffering=1, encoding="utf-8") as f:
        _print_env(f)
        f.write(f"$ {' '.join(LAUNCH_CMD)}\n")
        # Flush before handing the descriptor to the child so the header
        # cannot end up after the child's first output.
        f.flush()
        child_env = os.environ.copy()
        # A Python child block-buffers stdout when it is redirected to a
        # file; unbuffered output keeps the log current. setdefault leaves
        # an explicit setting from the environment untouched.
        child_env.setdefault("PYTHONUNBUFFERED", "1")
        try:
            p = subprocess.Popen(
                LAUNCH_CMD,
                stdout=f,
                stderr=subprocess.STDOUT,
                cwd="/app",
                env=child_env,
            )
        except OSError as e:
            # Make a failed launch visible in the log.
            f.write(f"\n[launch failed: {e}]\n")
            return
        rc = p.wait()
        f.write(f"\n[exit code = {rc}]\n")


class Handler(BaseHTTPRequestHandler):
    """HTTP handler that answers every GET with the current contents of ``LOG_PATH``."""

    def do_GET(self):
        """Reply 200 with the log as UTF-8 plain text.

        If the log file does not exist yet, the body is ``starting...``.
        """
        try:
            # The file can be read while it is still being written, so it may
            # end in the middle of a multi-byte character or contain bytes
            # that are not valid UTF-8. errors="replace" substitutes U+FFFD
            # for those instead of failing the whole request.
            with open(LOG_PATH, "r", encoding="utf-8", errors="replace") as f:
                body = f.read()
        except FileNotFoundError:
            body = "starting..."
        payload = body.encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        # Tell the client how many bytes to expect.
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, fmt, *args):
        """Discard the per-request log line the base class would emit."""
        return


if __name__ == "__main__":
    # Training runs on a daemon thread so it does not block the HTTP server
    # and does not keep the process alive on its own.
    trainer = threading.Thread(target=run_training, daemon=True)
    trainer.start()
    # Serve the log on all interfaces until the process is stopped.
    server = HTTPServer(("0.0.0.0", PORT), Handler)
    server.serve_forever()