File size: 11,723 Bytes
a82c744
 
 
 
 
 
 
 
 
 
 
 
 
 
a9aa4ae
a82c744
 
a9aa4ae
 
a82c744
 
 
a9aa4ae
a82c744
a9aa4ae
a82c744
a9aa4ae
a82c744
 
 
 
 
 
 
a9aa4ae
 
 
a82c744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9aa4ae
 
a82c744
 
 
 
 
a9aa4ae
 
 
a82c744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9aa4ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a82c744
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""FastAPI server for GPU Goblin.

One audit endpoint plus a health probe. Streams the agent loop's `SSEEvent`s
to the UI via Server-Sent Events. CORS is wide open because Streamlit runs on
a different port — fine for a hackathon.

The agent runs on Qwen via Hugging Face Inference Providers. HF_TOKEN is
read at startup; if it's missing the server still starts (so the offline-
replay UI lane keeps working) but `/audit` yields a single error event.
We never crash on missing keys.
"""

from __future__ import annotations

import asyncio
import json
import os
import subprocess
import sys
import tempfile
from collections.abc import AsyncIterator
from pathlib import Path
from typing import Any

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from agent.backends import active_backend_name
from agent.loop import run_audit
from agent.schemas import SSEEvent
from agent.tools import ALL_TOOLS

# Repository root: the parent of the directory that contains this file.
_REPO_ROOT = Path(__file__).resolve().parent.parent
# CLI script that the /auto-tune endpoint spawns as a subprocess.
_AUTO_TUNE_SCRIPT = _REPO_ROOT / "scripts" / "auto_tune.py"

# The ASGI application object; endpoints below register themselves on it.
app = FastAPI(title="GPU Goblin Agent", version="0.1.0")

# Wide-open CORS (any origin, method and header) with credentials disabled.
# The module docstring gives the reason: the Streamlit UI is served from a
# different port than this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _has_hf_token() -> bool:
    """Report whether a Hugging Face token is present in the environment.

    Looks at `HF_TOKEN` and then `HUGGINGFACEHUB_API_TOKEN`. A variable that
    is set to the empty string counts as missing.
    """
    token_vars = ("HF_TOKEN", "HUGGINGFACEHUB_API_TOKEN")
    return any(os.environ.get(var) for var in token_vars)


@app.get("/healthz")
async def healthz() -> dict:
    """Health probe: liveness flag, tool names, and active backend details.

    The UI calls this to confirm the agent is reachable and to see which
    backend/model it is configured for. The backend-specific keys differ:
    `qwen-vllm` reports `vllm_url`, every other backend reports `provider`.
    """
    backend = active_backend_name()
    fallback_model = "Qwen/Qwen2.5-7B-Instruct"

    if backend == "qwen-vllm":
        backend_details = {
            "model": os.environ.get("GOBLIN_QWEN_VLLM_MODEL", fallback_model),
            "vllm_url": os.environ.get(
                "GOBLIN_QWEN_VLLM_URL", "http://localhost:8000/v1"
            ),
            # vLLM doesn't require an API key by default.
            "has_api_key": True,
        }
    else:
        backend_details = {
            "model": os.environ.get("GOBLIN_QWEN_MODEL", fallback_model),
            "provider": os.environ.get("GOBLIN_QWEN_PROVIDER", "auto"),
            "has_api_key": _has_hf_token(),
        }

    return {
        "ok": True,
        "tools": [tool.name for tool in ALL_TOOLS],
        "backend": backend,
        **backend_details,
    }


async def _stream_audit(file_path: str) -> AsyncIterator[dict]:
    """Adapt `run_audit`'s SSEEvent stream to the dicts sse-starlette expects.

    Args:
        file_path: Path of the file to audit; passed straight to `run_audit`.

    Yields:
        `{"data": <SSEEvent JSON>}` dicts, one per SSE message. When the
        active backend needs a Hugging Face token and none is configured, a
        single `error` event is yielded and the audit is not started.
        Exceptions escaping `run_audit` are converted into a final `error`
        event instead of propagating.
    """
    # `healthz` reports that the local vLLM backend needs no API key, so the
    # token gate only applies to the other (Hugging Face hosted) backends.
    needs_hf_token = active_backend_name() != "qwen-vllm"
    if needs_hf_token and not _has_hf_token():
        # Surface a clean error instead of letting the loop crash on missing key.
        yield {
            "data": SSEEvent(
                type="error",
                data={
                    "message": (
                        "HF_TOKEN not set on the server — Qwen agent loop is "
                        "unavailable. Set HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN) "
                        "or use the offline-replay UI lane."
                    )
                },
            ).model_dump_json()
        }
        return

    try:
        async for event in run_audit(file_path):
            yield {"data": event.model_dump_json()}
    except Exception as exc:  # defence in depth — run_audit already wraps itself
        yield {
            "data": SSEEvent(
                type="error", data={"message": f"server: {exc}"}
            ).model_dump_json()
        }


@app.post("/audit")
async def audit(file: UploadFile = File(...)) -> EventSourceResponse:
    """Accept a multipart file upload and stream the agent's audit events.

    The upload is copied to a temp file that keeps the original extension
    (so `parse_config`'s extension-dispatched parser picks the right path;
    `.bin` is used when the upload has no extension) and the path is handed
    to `_stream_audit`.

    On success the temp file is deliberately not deleted here — the audit
    might still be reading it while the response streams. If saving the
    upload fails, the partial temp file is removed and the response is a
    single `error` event.
    """
    suffix = Path(file.filename or "").suffix or ".bin"
    fd, tmp_path = tempfile.mkstemp(prefix="goblin_upload_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            # Copy in bounded chunks so a large upload is never held in
            # memory all at once.
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                f.write(chunk)
    except Exception:
        # Nothing will ever read the partial file, so don't leave it behind.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

        # If we couldn't even land the upload, surface that immediately.
        async def _err() -> AsyncIterator[dict]:
            yield {
                "data": SSEEvent(
                    type="error",
                    data={"message": "Failed to save uploaded file."},
                ).model_dump_json()
            }

        return EventSourceResponse(_err())

    return EventSourceResponse(_stream_audit(tmp_path))


# ---------------------------------------------------------------------------
# Auto-tune endpoint — lets a UI on a CPU-only host (e.g. an HF Space) drive
# scripts/auto_tune.py running on a remote MI300X server. The endpoint
# spawns the CLI, tails its --events NDJSON stream, and re-emits each line
# as an SSE message. Subprocess output is discarded; everything the UI
# needs is in the structured events.
# ---------------------------------------------------------------------------


class AutoTuneRequest(BaseModel):
    """Request body accepted by the /auto-tune endpoint.

    Mirrors the auto_tune.py CLI surface so the UI just sends what the user
    picked in the form. The "exactly one of `model` / `workload`" rule is
    not enforced by this model; `_stream_auto_tune` checks it and reports a
    violation as an SSE `error` event.
    """

    # Tuning target, option 1: forwarded to the CLI as `--model <id>`.
    model: str | None = Field(
        default=None,
        description="HuggingFace model id (e.g. Qwen/Qwen2.5-7B-Instruct). "
        "Mutually exclusive with `workload`.",
    )
    # Tuning target, option 2: forwarded to the CLI as a bare positional argument.
    workload: str | None = Field(
        default=None,
        description="Path to a workload script ON THE SERVER's filesystem. "
        "Mutually exclusive with `model`.",
    )
    # Forwarded as `--mode`; the pattern restricts it to the three listed values.
    mode: str = Field(default="hardcoded", pattern="^(hardcoded|llm|llm-explore)$")
    # Only forwarded (as `--candidates-per-iteration`) when mode is "llm-explore".
    candidates_per_iteration: int = Field(default=3, ge=2, le=10)
    # The remaining fields are always forwarded as the same-named CLI flags
    # (`--steps`, `--max-iterations`, `--early-stop-after`, `--max-crashes`,
    # `--improvement-threshold`); the bounds below are validated by pydantic.
    steps: int = Field(default=20, ge=1, le=500)
    max_iterations: int = Field(default=10, ge=1, le=50)
    early_stop_after: int = Field(default=3, ge=1, le=20)
    max_crashes: int = Field(default=4, ge=1, le=20)
    improvement_threshold: float = Field(default=0.0, ge=0.0, le=20.0)


def _build_auto_tune_cmd(req: AutoTuneRequest, events_file: Path) -> list[str]:
    """Translate an auto-tune request into the argv list for auto_tune.py.

    Args:
        req: Request whose fields map onto the CLI flags.
        events_file: Path passed as `--events`.

    Returns:
        The argv list, starting with the current interpreter in unbuffered
        (`-u`) mode and the script path. `model` takes precedence over
        `workload`; if neither is set no target argument is added.
        `--candidates-per-iteration` is appended only for "llm-explore" mode.
    """
    argv: list[str] = [sys.executable, "-u", str(_AUTO_TUNE_SCRIPT)]

    # Target: a model id behind --model, or a workload path as a positional.
    if req.model:
        argv += ["--model", req.model]
    elif req.workload:
        argv += [req.workload]

    argv += ["--mode", req.mode]
    stringified_flags = (
        ("--steps", req.steps),
        ("--max-iterations", req.max_iterations),
        ("--early-stop-after", req.early_stop_after),
        ("--max-crashes", req.max_crashes),
        ("--improvement-threshold", req.improvement_threshold),
        ("--events", events_file),
    )
    for flag, value in stringified_flags:
        argv += [flag, str(value)]

    if req.mode == "llm-explore":
        argv += ["--candidates-per-iteration", str(req.candidates_per_iteration)]
    return argv


def _read_event_lines(
    events_file: Path, offset: int, *, include_partial: bool = False
) -> tuple[list[str], int]:
    """Read NDJSON lines appended to `events_file` after byte `offset`.

    The file is read in binary mode so `offset` is a true byte position;
    lines are split on `\\n` only and decoded as UTF-8 afterwards.

    Args:
        events_file: File the subprocess appends events to.
        offset: Byte position up to which lines were already consumed.
        include_partial: When False, an unterminated final line is left for
            a later call (the writer may not have flushed the rest yet).
            When True, it is returned as well.

    Returns:
        `(lines, new_offset)`: the stripped, non-empty lines that were
        consumed and the byte position to resume from. If the file cannot be
        read, returns `([], offset)`.
    """
    try:
        with events_file.open("rb") as f:
            f.seek(offset)
            chunk = f.read()
    except OSError:
        return [], offset

    if not include_partial:
        # Keep everything through the last newline; rfind() == -1 keeps nothing.
        chunk = chunk[: chunk.rfind(b"\n") + 1]

    lines: list[str] = []
    for raw in chunk.split(b"\n"):
        text = raw.decode("utf-8", errors="replace").strip()
        if text:
            lines.append(text)
    return lines, offset + len(chunk)


async def _stream_auto_tune(req: AutoTuneRequest) -> AsyncIterator[dict]:
    """Spawn auto_tune.py and forward its NDJSON --events stream as SSE.

    Each event line is forwarded verbatim as `{"data": <line>}`. The
    subprocess's stdout/stderr are discarded. If the process exits with a
    non-zero code, a final `process_exit` event is emitted.

    Args:
        req: The validated request body. Exactly one of `model` / `workload`
            must be set; otherwise a single `error` event is yielded and
            nothing is spawned.

    Yields:
        One dict per SSE message.

    Raises:
        OSError: If the subprocess cannot be started.
    """
    # Validate before creating any resources. (Pydantic can't express
    # "exactly one of A or B" cleanly, so we check here.)
    if not req.model and not req.workload:
        yield {"data": json.dumps({
            "type": "error",
            "message": "Pass either `model` or `workload`, not neither."
        })}
        return
    if req.model and req.workload:
        yield {"data": json.dumps({
            "type": "error",
            "message": "Pass either `model` or `workload`, not both."
        })}
        return

    # mkstemp creates the file atomically; the deprecated mktemp only picks a
    # name, leaving a window in which someone else can create that path.
    fd, events_path = tempfile.mkstemp(prefix="auto_tune_events_", suffix=".ndjson")
    os.close(fd)
    events_file = Path(events_path)

    proc = None
    try:
        proc = subprocess.Popen(
            _build_auto_tune_cmd(req, events_file),
            cwd=str(_REPO_ROOT),
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            env={**os.environ},
        )

        seen_bytes = 0
        while True:
            # Poll the events file for newly completed lines.
            lines, seen_bytes = _read_event_lines(events_file, seen_bytes)
            for line in lines:
                yield {"data": line}

            if proc.poll() is not None:
                # Subprocess exited. Drain whatever's left on disk, including
                # a final line that was never newline-terminated.
                tail, seen_bytes = _read_event_lines(
                    events_file, seen_bytes, include_partial=True
                )
                for line in tail:
                    yield {"data": line}
                if proc.returncode != 0:
                    yield {"data": json.dumps({
                        "type": "process_exit",
                        "returncode": proc.returncode,
                        "message": (
                            f"auto_tune.py exited with code {proc.returncode}. "
                            "Check the server's stderr or check `last_runner_failure_*` "
                            "in `bench_cache/` for goblin_runner.sh failure logs."
                        ),
                    })}
                break

            await asyncio.sleep(0.5)
    finally:
        # Runs on normal completion, on errors, and when the consumer stops
        # iterating early: stop the child and remove the events file.
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                proc.kill()
        try:
            events_file.unlink()
        except OSError:
            pass


@app.post("/auto-tune")
async def auto_tune_endpoint(req: AutoTuneRequest) -> EventSourceResponse:
    """Stream auto_tune.py events back to the caller as SSE.

    A UI on any host (HF Spaces, a local laptop) can point at this endpoint
    while the actual GPU work happens on the machine hosting this FastAPI
    app. Subprocess output is discarded — only the --events NDJSON stream
    crosses the wire, one structured event per SSE message.

    Raises:
        HTTPException: 500 when the auto_tune.py script is missing on disk.
    """
    if _AUTO_TUNE_SCRIPT.exists():
        return EventSourceResponse(_stream_auto_tune(req))

    missing_detail = f"auto_tune.py not found at {_AUTO_TUNE_SCRIPT}"
    raise HTTPException(status_code=500, detail=missing_detail)


# Public surface of this module: only the ASGI `app` object, which is what
# `python -m uvicorn agent.server:app --reload` loads.
__all__ = ["app"]


def _decode_event(raw: str) -> dict:
    """Parse one serialized SSEEvent JSON payload.

    Kept in this module so the CLI driver (`__main__.py`) and the tests
    share a single parser. Raises whatever `json.loads` raises on invalid
    input.
    """
    decoded = json.loads(raw)
    return decoded