File size: 4,315 Bytes
bf0c1ca
 
66baeca
 
bf0c1ca
 
 
 
 
66baeca
bf0c1ca
 
 
 
 
 
242e068
 
 
 
 
 
 
 
 
 
 
 
a82c5a0
bf0c1ca
 
 
 
 
 
 
7c430fa
 
bf0c1ca
7c430fa
 
0f2eb35
bf0c1ca
7c430fa
 
 
 
bf0c1ca
 
 
8580f79
 
bf0c1ca
8580f79
bf0c1ca
 
 
 
 
 
 
8580f79
 
bf0c1ca
 
 
8580f79
 
 
 
 
bf0c1ca
 
66baeca
 
 
 
 
 
 
 
 
bf0c1ca
 
66baeca
 
 
 
 
bf0c1ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8580f79
 
 
7c430fa
8580f79
bf0c1ca
8580f79
bf0c1ca
 
 
 
 
 
 
 
0f2eb35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""CVAT webhook listener — triggers finish_review when a job is marked completed."""

import hmac
import json
import os
import subprocess
import sys
import tempfile
import threading
from hashlib import sha256
from pathlib import Path

from fastapi import FastAPI, Request, HTTPException

# FastAPI application object; the route decorators further down register
# the /webhook and /health endpoints on it.
app = FastAPI()

def _require_env(name: str) -> str:
    """Return the whitespace-stripped value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is unset, empty, or only whitespace.
    """
    cleaned = os.environ.get(name, "").strip()
    if cleaned:
        return cleaned
    raise RuntimeError(f"Missing required env var: {name}")

# Required configuration, read once at import time. A missing or blank value
# makes _require_env raise RuntimeError, so the module fails to import.
DATASET = _require_env("HF_DATASET")  # forwarded to finish_review.py as --dataset
CVAT_TOKEN = _require_env("CVAT_TOKEN")  # forwarded to finish_review.py as --cvat-token
CVAT_WEBHOOK_SECRET = _require_env("CVAT_WEBHOOK_SECRET")  # HMAC key for webhook signatures
GITHUB_PAT = _require_env("GITHUB_PAT")  # embedded in the clone URL for github.com remotes
REPO_URL = _require_env("REPO_URL")
REPO_REF = _require_env("REPO_REF")  # passed to `git clone -b`
# Optional. A blank CVAT_URL falls back to the default instead of producing an
# empty URL, matching how _require_env treats blank values as missing.
CVAT_URL = os.environ.get("CVAT_URL", "").strip() or "https://app.cvat.ai"


def _clone_repo(workdir: Path) -> Path:
    """Shallow-clone ``REPO_REF`` of ``REPO_URL`` into ``workdir / "repo"``.

    When the URL contains ``github.com``, ``GITHUB_PAT`` is embedded into the
    HTTPS URL as credentials. The token is scrubbed from any git error output
    before that output is printed or placed in an exception message.

    Args:
        workdir: Directory that will receive the ``repo`` checkout.

    Returns:
        Path of the cloned repository.

    Raises:
        RuntimeError: if ``git clone`` exits with a non-zero status.
        subprocess.TimeoutExpired: if the clone exceeds the 60 second timeout.
    """
    repo_url = REPO_URL
    if GITHUB_PAT and "github.com" in repo_url:
        # Only the leading scheme should be rewritten, hence count=1.
        repo_url = repo_url.replace("https://", f"https://{GITHUB_PAT}@", 1)
    repo_dir = workdir / "repo"
    print(f"Cloning {REPO_REF}...")
    result = subprocess.run(
        ["git", "clone", "--depth", "1", "-b", REPO_REF, repo_url, str(repo_dir)],
        capture_output=True,
        text=True,
        timeout=60,
        # Disable git's interactive credential prompt so a bad token fails
        # fast instead of waiting for input.
        env={**os.environ, "GIT_TERMINAL_PROMPT": "0"},
    )
    if result.returncode != 0:
        # git may echo the remote URL (token included) in its error output;
        # redact it before it reaches logs or exception messages.
        stderr = result.stderr
        if GITHUB_PAT:
            stderr = stderr.replace(GITHUB_PAT, "***")
        print(f"Clone failed: {stderr}")
        raise RuntimeError(f"git clone failed: {stderr[:200]}")
    print("Clone done")
    return repo_dir


def _run_finish_review(repo_dir: Path, task_id: int) -> None:
    """Run ``scripts/finish_review.py`` from the checkout, relaying its output.

    The script is started with the current interpreter in unbuffered mode
    (``-u``), with ``repo_dir`` as working directory. Its stdout and stderr
    are merged and echoed line by line to this process's stdout.

    Args:
        repo_dir: Root of the cloned repository.
        task_id: CVAT task id, passed as ``--task-id`` and used to build the
            ``cvat_review_<task_id>`` experiment name.

    Raises:
        RuntimeError: if the script exits with a non-zero return code.
    """
    cmd = [
        sys.executable, "-u", str(repo_dir / "scripts" / "finish_review.py"),
        "--task-id", str(task_id),
        "--dataset", DATASET,
        "--experiment", f"cvat_review_{task_id}",
        "--labelmap", str(repo_dir / "labelmap.txt"),
        "--cvat-url", CVAT_URL,
        "--cvat-token", CVAT_TOKEN,
    ]
    # The context manager closes the pipe and waits for the child on every
    # exit path, so no file descriptor or zombie process is left behind.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        cwd=str(repo_dir),
    ) as proc:
        try:
            for line in proc.stdout:
                print(line, end="", flush=True)
        except BaseException:
            # Relaying output failed: stop the child rather than leaving it
            # running unattended, then let the error propagate.
            proc.kill()
            raise
    if proc.returncode != 0:
        raise RuntimeError(f"finish_review exited with code {proc.returncode}")


def _verify_signature(body: bytes, signature: str) -> bool:
    """Check *signature* against the HMAC-SHA256 of *body*.

    The expected value is ``"sha256=" + hexdigest``, keyed with
    ``CVAT_WEBHOOK_SECRET``. The comparison is constant-time.

    Args:
        body: Raw request body exactly as received.
        signature: Value of the signature header; may be empty or malformed.

    Returns:
        True when the signature matches (or when no secret is configured),
        False otherwise. Malformed input never raises.
    """
    if not CVAT_WEBHOOK_SECRET:
        # Kept for backward compatibility: an empty secret disables the check.
        return True
    expected = "sha256=" + hmac.new(
        CVAT_WEBHOOK_SECRET.encode("utf-8"), body, digestmod=sha256
    ).hexdigest()
    # compare_digest raises TypeError for str arguments containing non-ASCII
    # characters, so compare bytes: a garbage header is then rejected with
    # False instead of surfacing as an unhandled error. Unencodable characters
    # become "?", which never occurs in the expected value.
    return hmac.compare_digest(
        signature.encode("utf-8", "replace"), expected.encode("ascii")
    )


@app.post("/webhook")
async def cvat_webhook(request: Request):
    """Handle a CVAT webhook and start finish_review when a job completes.

    The request is accepted only if its ``X-Signature-256`` header matches
    the body. Events other than ``update:job``, and job updates that are not a
    transition into the ``completed`` state, are acknowledged and ignored.
    For a qualifying event the repository is cloned and ``finish_review`` is
    run on a background daemon thread, so the response returns immediately.

    Raises:
        HTTPException: 403 for an invalid signature; 400 for a body that is
            not a JSON object or that lacks ``job.task_id``.
    """
    raw_body = await request.body()
    signature = request.headers.get("X-Signature-256", "")
    if not _verify_signature(raw_body, signature):
        raise HTTPException(status_code=403, detail="Invalid signature")

    # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
    try:
        body = json.loads(raw_body)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Payload must be a JSON object")

    event = body.get("event", "")
    if event != "update:job":
        return {"status": "ignored", "event": event}

    # An explicit null (or any non-object) is treated like a missing key so
    # the lookups below cannot fail on it.
    job = body.get("job")
    if not isinstance(job, dict):
        job = {}
    before = body.get("before_update")
    if not isinstance(before, dict):
        before = {}

    state = job.get("state", "")
    prev_state = before.get("state", "")

    # Act only on the transition into "completed", not on later edits to a
    # job that was already completed.
    if state != "completed" or prev_state == "completed":
        return {"status": "ignored", "reason": f"state={state}, prev={prev_state}"}

    task_id = job.get("task_id")
    if not task_id:
        raise HTTPException(status_code=400, detail="No task_id in payload")

    print(f"Job completed — task_id={task_id}, running finish_review in background...")

    def _run_in_background(tid: int):
        # Runs on a worker thread; failures are logged, never raised, because
        # the HTTP response has already been sent.
        try:
            with tempfile.TemporaryDirectory() as workdir:
                repo_dir = _clone_repo(Path(workdir))
                print(f"Running finish_review for task {tid}...", flush=True)
                _run_finish_review(repo_dir, tid)
                print(f"finish_review completed for task {tid}", flush=True)
        except subprocess.TimeoutExpired:
            print(f"finish_review timed out for task {tid}", flush=True)
        except Exception as exc:
            print(f"finish_review failed for task {tid}: {exc}", flush=True)

    threading.Thread(target=_run_in_background, args=(task_id,), daemon=True).start()

    return {"status": "accepted", "task_id": task_id}


@app.get("/health")
async def health():
    """Health-check endpoint: report ``ok`` together with the configured dataset."""
    return dict(status="ok", dataset=DATASET)