File size: 3,710 Bytes
c1ae554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import annotations

import json
import time
import uuid
from typing import Any, Dict, Optional

import requests
from .logging import logger

from .config import (
    BRIDGE_BASE_URL,
    FALLBACK_BRIDGE_URLS,
    WARMUP_INIT_RETRIES,
    WARMUP_INIT_DELAY_S,
    WARMUP_REQUEST_RETRIES,
    WARMUP_REQUEST_DELAY_S,
)
from .packets import packet_template
from .state import STATE, ensure_tool_ids


def bridge_send_stream(packet: Dict[str, Any]) -> Dict[str, Any]:
    last_exc: Optional[Exception] = None
    for base in FALLBACK_BRIDGE_URLS:
        url = f"{base}/api/warp/send_stream"
        try:
            wrapped_packet = {"json_data": packet, "message_type": "warp.multi_agent.v1.Request"}
            try:
                logger.info("[OpenAI Compat] Bridge request URL: %s", url)
                logger.info("[OpenAI Compat] Bridge request payload: %s", json.dumps(wrapped_packet, ensure_ascii=False))
            except Exception:
                logger.info("[OpenAI Compat] Bridge request payload serialization failed for URL %s", url)
            r = requests.post(url, json=wrapped_packet, timeout=(5.0, 180.0))
            if r.status_code == 200:
                try:
                    logger.info("[OpenAI Compat] Bridge response (raw text): %s", r.text)
                except Exception:
                    pass
                return r.json()
            else:
                txt = r.text
                last_exc = Exception(f"bridge_error: HTTP {r.status_code} {txt}")
        except Exception as e:
            last_exc = e
            continue
    if last_exc:
        raise last_exc
    raise Exception("bridge_unreachable")


def initialize_once() -> None:
    if STATE.conversation_id:
        return

    ensure_tool_ids()

    first_task_id = STATE.baseline_task_id or str(uuid.uuid4())
    STATE.baseline_task_id = first_task_id

    health_urls = [f"{base}/healthz" for base in FALLBACK_BRIDGE_URLS]
    last_err: Optional[str] = None
    for _ in range(WARMUP_INIT_RETRIES):
        try:
            ok = False
            last_err = None
            for h in health_urls:
                try:
                    resp = requests.get(h, timeout=5.0)
                    if resp.status_code == 200:
                        ok = True
                        break
                    else:
                        last_err = f"HTTP {resp.status_code} at {h}"
                except Exception as he:
                    last_err = f"{type(he).__name__}: {he} at {h}"
            if ok:
                break
        except Exception as e:
            last_err = str(e)
        time.sleep(WARMUP_INIT_DELAY_S)
    else:
        raise RuntimeError(f"Bridge server not ready: {last_err}")

    pkt = packet_template()
    pkt["task_context"]["active_task_id"] = first_task_id
    pkt["input"]["user_inputs"]["inputs"].append({"user_query": {"query": "warmup"}})

    last_exc: Optional[Exception] = None
    for attempt in range(1, WARMUP_REQUEST_RETRIES + 1):
        try:
            resp = bridge_send_stream(pkt)
            break
        except Exception as e:
            last_exc = e
            logger.warning(f"[OpenAI Compat] Warmup attempt {attempt}/{WARMUP_REQUEST_RETRIES} failed: {e}")
            if attempt < WARMUP_REQUEST_RETRIES:
                time.sleep(WARMUP_REQUEST_DELAY_S)
            else:
                raise

    STATE.conversation_id = resp.get("conversation_id") or STATE.conversation_id
    ret_task_id = resp.get("task_id")
    if isinstance(ret_task_id, str) and ret_task_id:
        STATE.baseline_task_id = ret_task_id