File size: 9,319 Bytes
5e40307
 
 
 
 
8c5e6cc
5e40307
 
 
 
 
 
 
 
 
 
8c5e6cc
5e40307
 
8c5e6cc
5e40307
 
 
 
 
 
 
 
 
 
8c5e6cc
 
 
 
 
 
 
5e40307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be80524
5e40307
 
 
 
 
 
 
 
 
 
 
 
8c5e6cc
 
 
 
 
 
 
 
 
 
 
5e40307
8c5e6cc
5e40307
 
 
8c5e6cc
5e40307
 
be80524
 
5e40307
 
 
 
 
 
 
 
 
 
 
fad52c2
5e40307
 
 
fad52c2
5e40307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be80524
5e40307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c5e6cc
5e40307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be80524
5e40307
 
 
 
 
 
be80524
5e40307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
#!/usr/bin/env python3
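"""
FastAPI server for Trace Model inference.

Usage:
    python eval_server.py --model-id mihirgrao/trace-model --port 8000

    (If --port is omitted, the server listens on port 8001.)

Endpoints:
    POST /predict         - Single image + instruction
    POST /predict_batch   - Batch of (image, instruction) pairs
    POST /evaluate_batch  - Alias for /predict_batch (RFM-style clients)
    GET  /health          - Health check
    GET  /model_info      - Model information
    GET  /gpu_status      - Server status and job counters (RFM-compatible name)
"""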

import argparse
import base64
import io
import logging
import os
import re
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware

from trace_inference import (
    DEFAULT_MODEL_ID,
    build_prompt,
    load_model,
    run_inference,
)
from trace_inference import _model_state as _trace_model_state
from trajectory_viz import extract_trajectory_from_text

# Module-level logger. When this file is run as a script, main() configures
# output via logging.basicConfig(level=logging.INFO).
logger = logging.getLogger(__name__)

# --- Trace Eval Server ---


class TraceEvalServer:
    """Inference server wrapping the trace model.

    The model is loaded once at construction time via ``load_model``. The class
    exposes single-sample and batch prediction plus status/introspection helpers
    used by the FastAPI app.

    Note: ``self.executor`` is created here and shut down by :meth:`shutdown`,
    but the methods of this class run synchronously in the calling thread; the
    caller decides whether to dispatch work onto the executor.
    """

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL_ID,
        max_workers: int = 1,
    ):
        """Load ``model_id`` and set up job bookkeeping.

        Args:
            model_id: Model identifier passed to ``load_model``.
            max_workers: Size of the thread pool held in ``self.executor``.

        Raises:
            RuntimeError: If ``load_model`` reports failure.
        """
        self.model_id = model_id
        self.max_workers = max_workers
        self._job_counter = 0  # jobs started; doubles as the per-job id
        self._completed_jobs = 0  # jobs finished (successfully or with an error)
        self._lock = Lock()  # guards the two counters above
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        logger.info("Loading trace model: %s", model_id)
        success, msg = load_model(model_id)
        if not success:
            # Don't leave an orphaned pool behind when construction fails.
            self.executor.shutdown(wait=False)
            raise RuntimeError(f"Failed to load model: {msg}")
        logger.info(msg)

    @staticmethod
    def _write_base64_image_to_temp_png(image_base64: str) -> str:
        """Decode a base64 image (optionally a data URL) into a temporary PNG file.

        Returns:
            Path of the temporary file. The caller owns it and must delete it.

        Raises:
            Any error from base64 decoding, PIL decoding, or writing the file.
            The temp file is removed if writing fails part-way.
        """
        from PIL import Image

        b64_str = image_base64.strip()
        # Strip a data URL header such as "data:image/png;base64," if present.
        if b64_str.startswith("data:"):
            header, sep, payload = b64_str.partition(",")
            if sep and header.endswith(";base64"):
                b64_str = payload
        image_bytes = base64.b64decode(b64_str, validate=False)

        # Decode the image fully before touching the filesystem so invalid data
        # fails without creating a file.
        with Image.open(io.BytesIO(image_bytes)) as src:
            img = src.convert("RGB")

        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            # Write through the open handle instead of reopening by name
            # (reopening an open NamedTemporaryFile fails on Windows).
            with tmp:
                img.save(tmp, format="PNG")
        except Exception:
            TraceEvalServer._remove_temp_file(tmp.name)
            raise
        return tmp.name

    @staticmethod
    def _remove_temp_file(path: Optional[str]) -> None:
        """Best-effort delete of ``path``; failures are logged, never raised."""
        if not path:
            return
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning("Could not remove temp file %s: %s", path, e)

    def predict_one(
        self,
        image_path: Optional[str] = None,
        image_base64: Optional[str] = None,
        instruction: str = "",
        is_oxe: bool = False,
    ) -> Dict[str, Any]:
        """Run inference on a single image.

        Provide either ``image_path`` (file path) or ``image_base64`` (base64 or
        data-URL encoded image). If both are given, ``image_path`` wins.

        NOTE(review): ``image_path`` is used as-is, so an HTTP client can point
        it at any file readable by the server process. Consider restricting it
        to an allowed directory if the server is exposed.

        Returns:
            ``{"prediction": str, "trajectory": ...}`` on success, or
            ``{"error": str}`` on bad input or inference failure.
        """
        if image_path is None and image_base64 is None:
            return {"error": "Provide image_path or image_base64"}

        temp_file_path: Optional[str] = None
        if image_path is None:
            try:
                temp_file_path = self._write_base64_image_to_temp_png(image_base64)
            except Exception as e:
                return {"error": f"Invalid image data: {e}"}
            image_path = temp_file_path

        try:
            prompt = build_prompt(instruction, is_oxe=is_oxe)
            prediction, _, _ = run_inference(image_path, prompt, self.model_id)
        except Exception as e:
            # Report per-request instead of failing the whole HTTP call/batch.
            logger.exception("Inference failed for %s", image_path)
            return {"error": f"Inference failed: {e}"}
        finally:
            self._remove_temp_file(temp_file_path)

        # Outputs starting with these prefixes are reported as errors rather
        # than as predictions.
        if prediction.startswith(("Error:", "Please ")):
            return {"error": prediction}

        return {
            "prediction": prediction,
            "trajectory": extract_trajectory_from_text(prediction),
        }

    def predict_batch(
        self,
        samples: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Run :meth:`predict_one` sequentially over ``samples``.

        Each sample is a dict with ``image_path`` or ``image_base64``, plus
        optional ``instruction`` and ``is_oxe``. A failing sample yields an
        ``{"error": ...}`` entry in its slot; it does not abort the batch.

        Returns:
            ``{"results": [...]}`` with one entry per input sample, in order.
        """
        results = []
        for sample in samples:
            with self._lock:
                self._job_counter += 1
                job_id = self._job_counter

            start = time.perf_counter()
            if isinstance(sample, dict):
                result = self.predict_one(
                    image_path=sample.get("image_path"),
                    image_base64=sample.get("image_base64"),
                    instruction=sample.get("instruction", ""),
                    is_oxe=sample.get("is_oxe", False),
                )
            else:
                result = {"error": "Each sample must be a JSON object"}
            elapsed = time.perf_counter() - start

            with self._lock:
                self._completed_jobs += 1

            logger.debug("[job %d] completed in %.3fs", job_id, elapsed)
            results.append(result)

        return {"results": results}

    def get_status(self) -> Dict[str, Any]:
        """Return model id, pool size, and job counters (a consistent snapshot)."""
        with self._lock:
            completed = self._completed_jobs
            started = self._job_counter
        return {
            "model_id": self.model_id,
            "max_workers": self.max_workers,
            "completed_jobs": completed,
            "job_counter": started,
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Describe the loaded model (class name, parameter count).

        Reads the model from ``trace_inference``'s shared state. Any failure is
        reported as an ``"error"`` entry rather than raised.
        """
        try:
            model = _trace_model_state.get("model")
            if model is None:
                return {"model_id": self.model_id, "status": "not_loaded"}

            all_params = sum(p.numel() for p in model.parameters())
            return {
                "model_id": self.model_id,
                "model_class": model.__class__.__name__,
                "total_parameters": all_params,
            }
        except Exception as e:
            return {"model_id": self.model_id, "error": str(e)}

    def shutdown(self):
        """Shut down the executor, waiting for queued work to finish."""
        self.executor.shutdown(wait=True)


def create_app(
    model_id: str = DEFAULT_MODEL_ID,
    max_workers: int = 1,
    server: Optional[TraceEvalServer] = None,
) -> FastAPI:
    """Build the FastAPI application that serves the trace model.

    Args:
        model_id: Model to load when ``server`` is not supplied.
        max_workers: Thread-pool size for a newly created server. Inference
            requests are dispatched onto that pool.
        server: Pre-built server to reuse. When given, ``model_id`` and
            ``max_workers`` are ignored.

    Returns:
        The configured FastAPI app.
    """
    import asyncio

    app = FastAPI(title="Trace Model Evaluation Server")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    trace_server = (
        server
        if server is not None
        else TraceEvalServer(model_id=model_id, max_workers=max_workers)
    )

    async def _offload(fn):
        """Run blocking ``fn`` on the server's executor.

        Model inference is synchronous. Calling it directly inside an
        ``async def`` endpoint would block the event loop, stalling /health and
        every other request. ``max_workers`` bounds how many inferences run at
        once. NOTE(review): values > 1 assume the underlying model tolerates
        concurrent calls; confirm.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(trace_server.executor, fn)

    async def _read_json_object(request: Request) -> Optional[Dict[str, Any]]:
        """Return the request body as a dict, or None if it is not a JSON object."""
        try:
            body = await request.json()
        except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
            return None
        return body if isinstance(body, dict) else None

    @app.post("/predict")
    async def predict(request: Request) -> Dict[str, Any]:
        """
        Predict trace for a single image.

        JSON body:
            - image_path: (optional) path to image file
            - image_base64: (optional) base64-encoded image
            - instruction: natural language task description
            - is_oxe: (optional) if true, use OXE prompt format
        """
        body = await _read_json_object(request)
        if body is None:
            return {"error": "Request body must be a JSON object"}
        return await _offload(
            lambda: trace_server.predict_one(
                image_path=body.get("image_path"),
                image_base64=body.get("image_base64"),
                instruction=body.get("instruction", ""),
                is_oxe=body.get("is_oxe", False),
            )
        )

    @app.post("/predict_batch")
    async def predict_batch(request: Request) -> Dict[str, Any]:
        """
        Predict trace for a batch of images.

        JSON body:
            - samples: list of {image_path?, image_base64?, instruction}
        """
        body = await _read_json_object(request)
        samples = body.get("samples", []) if body is not None else []
        if not isinstance(samples, list) or not samples:
            return {"error": "samples list is required", "results": []}
        return await _offload(lambda: trace_server.predict_batch(samples))

    @app.post("/evaluate_batch")
    async def evaluate_batch(request: Request) -> Dict[str, Any]:
        """
        Alias for /predict_batch for compatibility with RFM-style clients.
        Accepts same format as /predict_batch.
        """
        return await predict_batch(request)

    @app.get("/health")
    def health() -> Dict[str, Any]:
        """Health check."""
        status = trace_server.get_status()
        return {
            "status": "healthy",
            "model_id": status["model_id"],
        }

    @app.get("/model_info")
    def model_info() -> Dict[str, Any]:
        """Get model information."""
        return trace_server.get_model_info()

    @app.get("/gpu_status")
    def gpu_status() -> Dict[str, Any]:
        """Get server status (RFM-compatible endpoint name)."""
        return trace_server.get_status()

    @app.on_event("shutdown")
    async def shutdown_event():
        # Waits for in-flight inference jobs on the executor to finish.
        trace_server.shutdown()

    return app


def main():
    """CLI entry point: parse flags, build the app, and serve it with uvicorn."""
    parser = argparse.ArgumentParser(description="Trace Model Evaluation Server")

    # (flag, converter, default, help text) for every supported option.
    option_specs = (
        ("--model-id", str, DEFAULT_MODEL_ID, f"Model ID (default: {DEFAULT_MODEL_ID})"),
        ("--host", str, "0.0.0.0", "Server host"),
        ("--port", int, 8001, "Server port"),
        ("--max-workers", int, 1, "Max worker threads for batch processing"),
    )
    for flag, converter, fallback, help_text in option_specs:
        parser.add_argument(flag, type=converter, default=fallback, help=help_text)
    cli = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    application = create_app(model_id=cli.model_id, max_workers=cli.max_workers)
    banner = (
        f"Trace eval server starting on {cli.host}:{cli.port}",
        f"Model: {cli.model_id}",
    )
    for line in banner:
        print(line)
    uvicorn.run(application, host=cli.host, port=cli.port)


if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported
    # (e.g. when a caller uses create_app() directly).
    main()