File size: 3,996 Bytes
e65a128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239f219
e65a128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239f219
e65a128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from __future__ import annotations

import json
import os
import socket
import subprocess
import sys
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path

import httpx


# Project root: two levels up from this file (the parent of this file's directory).
ROOT = Path(__file__).resolve().parents[1]

# Interpreter used to launch the app and inference subprocesses.
# The Windows virtualenv layout is tried first (the original, hard-coded
# location), then the POSIX virtualenv layout, and finally the interpreter
# that is running the tests, so the suite is not tied to one platform.
_PYTHON_CANDIDATES = (
    ROOT / ".venv" / "Scripts" / "python.exe",
    ROOT / ".venv" / "bin" / "python",
)
PYTHON = next(
    (candidate for candidate in _PYTHON_CANDIDATES if candidate.is_file()),
    Path(sys.executable),
)


def _free_port() -> int:
    """Return a TCP port number on 127.0.0.1 that the OS reported as free.

    A throwaway socket is bound to port 0 so the OS assigns a port; the
    socket is closed before returning, so the port is only *probably* still
    free by the time the caller binds to it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, assigned_port = probe.getsockname()
        return int(assigned_port)
    finally:
        probe.close()


def _wait_for_health(
    url: str,
    process: subprocess.Popen | None = None,
    timeout: float = 20.0,
    interval: float = 0.5,
) -> bool:
    """Poll *url* until it answers HTTP 200.

    Returns True as soon as a 200 response is seen. Returns False when
    *timeout* seconds pass first, or when *process* (if given) has already
    exited, since a dead server will never become healthy.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if process is not None and process.poll() is not None:
            return False
        try:
            if httpx.get(url, timeout=2).status_code == 200:
                return True
        except httpx.HTTPError:
            pass  # not accepting connections yet; retry after the sleep below
        # Sleep on every unsuccessful attempt (including non-200 replies) so
        # the loop never busy-spins.
        time.sleep(interval)
    return False


def test_inference_uses_proxy_api_key():
    """Check that ``inference.py`` calls the configured proxy with ``API_KEY``.

    The test starts a local fake chat-completions server and the app under
    uvicorn, then runs ``inference.py`` with ``API_BASE_URL`` pointing at the
    fake server, ``API_KEY`` set to the expected token and ``HF_TOKEN`` set to
    a decoy value. It passes when the script exits with code 0, the first
    request recorded by the fake server is a POST to ``/v1/chat/completions``
    carrying ``Authorization: Bearer proxy-check-token``, and stdout contains
    an ``[END]`` line with ``score=``.
    """
    requests_seen: list[dict[str, str | None]] = []

    class ProxyHandler(BaseHTTPRequestHandler):
        """Fake chat-completions endpoint that records each POST it receives."""

        def do_POST(self):
            length = int(self.headers.get("Content-Length", "0"))
            body = self.rfile.read(length).decode("utf-8")
            requests_seen.append(
                {
                    "path": self.path,
                    "authorization": self.headers.get("Authorization"),
                    "body": body,
                }
            )
            # Canned completion whose message content is a JSON-encoded
            # "submit_report" action.
            payload = {
                "id": "chatcmpl-test",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "proxy-test-model",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": json.dumps(
                                {
                                    "action_type": "submit_report",
                                    "answer": "Proxy verified [support_003]",
                                }
                            ),
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
            encoded = json.dumps(payload).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(encoded)))
            self.end_headers()
            self.wfile.write(encoded)

        def log_message(self, format: str, *args):
            # Silence the default per-request logging to stderr.
            return

    # Bind to port 0 and read the assigned port back from the server, which
    # avoids the probe-then-rebind race of reserving the port beforehand.
    proxy_server = HTTPServer(("127.0.0.1", 0), ProxyHandler)
    proxy_port = proxy_server.server_address[1]
    proxy_thread = threading.Thread(target=proxy_server.serve_forever, daemon=True)
    proxy_thread.start()

    app_process = None
    try:
        # Everything that can fail lives inside the try so the proxy server
        # is always shut down, even if the app cannot be launched.
        app_port = _free_port()
        app_process = subprocess.Popen(
            [str(PYTHON), "-m", "uvicorn", "app:app", "--host", "127.0.0.1", "--port", str(app_port)],
            cwd=ROOT,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

        assert _wait_for_health(
            f"http://127.0.0.1:{app_port}/health", process=app_process
        ), "app did not report healthy on /health within 20 seconds"

        env = os.environ.copy()
        env["RAG_ENV_URL"] = f"http://127.0.0.1:{app_port}"
        env["RAG_ENV_TASK"] = "refund_triage_easy"
        env["API_BASE_URL"] = f"http://127.0.0.1:{proxy_port}/v1"
        env["API_KEY"] = "proxy-check-token"
        # Decoy: the Authorization assertion below fails if this value is
        # sent instead of API_KEY.
        env["HF_TOKEN"] = "legacy-should-not-win"
        result = subprocess.run(
            [str(PYTHON), "inference.py"],
            cwd=ROOT,
            env=env,
            capture_output=True,
            text=True,
            timeout=60,
        )
        assert result.returncode == 0, (
            f"inference.py exited with {result.returncode}\n"
            f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )
        assert requests_seen
        assert requests_seen[0]["path"] == "/v1/chat/completions"
        assert requests_seen[0]["authorization"] == "Bearer proxy-check-token"
        assert any(line.startswith("[END]") and "score=" in line for line in result.stdout.splitlines())
    finally:
        proxy_server.shutdown()
        proxy_server.server_close()
        proxy_thread.join(timeout=5)
        if app_process is not None:
            app_process.terminate()
            try:
                app_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                app_process.kill()
                # Reap the killed child so it does not linger as a zombie.
                app_process.wait()