Spaces:
Running
Running
| from __future__ import annotations | |
| import gzip | |
| import json | |
| import sys | |
| import zlib | |
| from types import SimpleNamespace | |
| import pytest | |
| from headroom.proxy.handlers import batch as batch_module | |
| from headroom.proxy import helpers as proxy_helpers | |
| class FakeResponse: | |
| def __init__( | |
| self, | |
| *, | |
| status_code: int = 200, | |
| content: bytes = b"{}", | |
| headers: dict[str, str] | None = None, | |
| text: str | None = None, | |
| json_data=None, # noqa: ANN001 | |
| ) -> None: | |
| self.status_code = status_code | |
| self.content = content | |
| self.headers = headers or {} | |
| self.text = text if text is not None else content.decode("utf-8", errors="ignore") | |
| self._json_data = json_data | |
| def json(self): # noqa: ANN201 | |
| if self._json_data is not None: | |
| return self._json_data | |
| return json.loads(self.text) | |
class FakeHttpClient:
    """Recording async HTTP client double.

    Every call is appended to ``posts``, ``gets`` or ``requests`` (the URL plus
    all keyword arguments) before the canned response is returned or the
    configured exception is raised, so a failing call is still recorded.
    """

    def __init__(self) -> None:
        """Start with empty call logs, default responses and no injected errors."""
        self.posts: list[dict[str, object]] = []
        self.gets: list[dict[str, object]] = []
        self.requests: list[dict[str, object]] = []
        self.post_response = FakeResponse()
        self.get_response = FakeResponse()
        self.raise_post: Exception | None = None
        self.raise_get: Exception | None = None

    async def post(self, url: str, **kwargs):  # noqa: ANN003, ANN201
        """Record a POST, then raise ``raise_post`` if set, else return ``post_response``."""
        self.posts.append({"url": url, **kwargs})
        if self.raise_post is not None:
            raise self.raise_post
        return self.post_response

    async def get(self, url: str, **kwargs):  # noqa: ANN003, ANN201
        """Record a GET, then raise ``raise_get`` if set, else return ``get_response``."""
        self.gets.append({"url": url, **kwargs})
        if self.raise_get is not None:
            raise self.raise_get
        return self.get_response

    async def request(self, method: str, url: str, **kwargs):  # noqa: ANN003, ANN201
        """Record a generic request (including its ``method``).

        There is no separate response/error slot for ``request``: it shares
        ``raise_get`` and ``get_response`` with ``get``.
        """
        self.requests.append({"method": method, "url": url, **kwargs})
        if self.raise_get is not None:
            raise self.raise_get
        return self.get_response
class FakeMetrics:
    """Metrics double that stores the keyword arguments of every recorded call."""

    def __init__(self) -> None:
        """Start with empty success and failure logs."""
        self.record_calls: list[dict[str, object]] = []
        self.failed_calls: list[dict[str, object]] = []

    async def record_request(self, **kwargs) -> None:  # noqa: ANN003
        """Append the keyword arguments of a recorded request to ``record_calls``."""
        self.record_calls.append(kwargs)

    async def record_failed(self, **kwargs) -> None:  # noqa: ANN003
        """Append the keyword arguments of a recorded failure to ``failed_calls``."""
        self.failed_calls.append(kwargs)
class DummyBatchHandler(batch_module.BatchHandlerMixin):
    """Concrete host for ``BatchHandlerMixin`` wired entirely to in-memory fakes.

    Supplies the attributes and collaborator methods the tests in this module
    rely on (HTTP client, metrics, config flags, provider, pipeline, request-id
    generation, retry and Gemini conversion helpers).
    """

    OPENAI_API_URL = "https://openai.example"
    GEMINI_API_URL = "https://gemini.example"

    def __init__(self) -> None:
        """Create the handler with optimisation and CCR injection switched off."""
        self.http_client = FakeHttpClient()
        self.metrics = FakeMetrics()
        self.config = SimpleNamespace(
            optimize=False,
            ccr_inject_tool=False,
            ccr_inject_system_instructions=False,
        )
        self.openai_provider = SimpleNamespace(get_context_limit=lambda model: 8192)
        # Default pipeline is a no-op returning None; tests that exercise
        # optimisation replace it with their own stub.
        self.openai_pipeline = SimpleNamespace(apply=lambda **kwargs: None)
        self._request_counter = 0
        self._retry_response = FakeResponse()

    async def _next_request_id(self) -> str:
        """Return sequential ids ``req-1``, ``req-2``, ... for deterministic assertions."""
        self._request_counter += 1
        return f"req-{self._request_counter}"

    async def handle_passthrough(self, request, base_url):  # noqa: ANN001, ANN201
        """Echo the arguments back so a test can see what was forwarded."""
        return {"request": request, "base_url": base_url}

    async def _retry_request(self, method, url, headers, body):  # noqa: ANN001, ANN201
        """Return the canned ``_retry_response`` without performing any I/O."""
        return self._retry_response

    def _gemini_contents_to_messages(self, contents, system_instruction):  # noqa: ANN001, ANN201
        """Map each content entry's first text part to a ``user`` message.

        ``system_instruction`` is ignored. The second element of the returned
        tuple is always an empty list here.
        NOTE(review): presumably it lists indices of contents that must be
        preserved untouched -- confirm against the real mixin.
        """
        messages = [{"role": "user", "content": part["parts"][0]["text"]} for part in contents]
        return messages, []

    def _messages_to_gemini_contents(self, messages):  # noqa: ANN001, ANN201
        """Wrap each message's content in a single text part; no system instruction."""
        return ([{"parts": [{"text": message["content"]}]} for message in messages], None)
| class FakeRequest: | |
| def __init__( | |
| self, | |
| body: bytes | str, | |
| *, | |
| headers: dict[str, str] | None = None, | |
| method: str = "POST", | |
| path: str = "/v1/batches", | |
| query: str = "", | |
| ) -> None: | |
| self._body = body.encode("utf-8") if isinstance(body, str) else body | |
| self.headers = headers or {} | |
| self.method = method | |
| self.url = SimpleNamespace(path=path, query=query) | |
| async def body(self) -> bytes: | |
| return self._body | |
async def test_read_request_json_enforces_body_size_limit() -> None:
    """A plain body longer than ``MAX_REQUEST_BODY_SIZE`` raises ``ValueError``."""
    # The string value alone already exceeds the limit by one byte.
    oversized_value = b"a" * (proxy_helpers.MAX_REQUEST_BODY_SIZE + 1)
    body = b'{"x":"' + oversized_value + b'"}'
    request = FakeRequest(body)

    with pytest.raises(ValueError, match="Request body too large"):
        await proxy_helpers._read_request_json(request)
async def test_read_request_json_handles_gzip_and_deflate_with_size_cap() -> None:
    """Compressed bodies are decoded, and the size cap applies to the decompressed size."""
    payload = {"messages": [{"role": "user", "content": "hello"}], "model": "m"}
    raw = json.dumps(payload).encode("utf-8")

    # Small payloads round-trip through both supported encodings.
    gzip_req = FakeRequest(gzip.compress(raw), headers={"content-encoding": "gzip"})
    assert await proxy_helpers._read_request_json(gzip_req) == payload

    deflate_req = FakeRequest(zlib.compress(raw), headers={"content-encoding": "deflate"})
    assert await proxy_helpers._read_request_json(deflate_req) == payload

    # Highly compressible payload: tiny on the wire, over the limit once inflated.
    big = b'{"x":"' + (b"a" * (proxy_helpers.MAX_REQUEST_BODY_SIZE + 1)) + b'"}'

    gzip_big = FakeRequest(gzip.compress(big), headers={"content-encoding": "gzip"})
    with pytest.raises(ValueError, match="Decompressed request body too large"):
        await proxy_helpers._read_request_json(gzip_big)

    deflate_big = FakeRequest(zlib.compress(big), headers={"content-encoding": "deflate"})
    with pytest.raises(ValueError, match="Decompressed request body too large"):
        await proxy_helpers._read_request_json(deflate_big)
def install_batch_support_modules(
    monkeypatch: pytest.MonkeyPatch,
    *,
    injector_result=None,  # noqa: ANN001
    tokenizer_count: int = 10,
) -> None:
    """Replace ``headroom.ccr``, ``headroom.tokenizers`` and ``headroom.utils`` with stubs.

    The stubs are installed into ``sys.modules`` through ``monkeypatch.setitem``.

    Args:
        monkeypatch: pytest fixture used to install the stub modules.
        injector_result: Value returned by the stub injector's ``process_request``.
            When ``None`` the injector returns ``(messages, tools, False)``.
        tokenizer_count: Token count the stub tokenizer reports for any message list.
    """

    class FakeInjector:
        """Stub ``CCRToolInjector`` that records its constructor kwargs."""

        def __init__(self, **kwargs) -> None:  # noqa: ANN003
            self.kwargs = kwargs

        def process_request(self, messages, tools):  # noqa: ANN001, ANN201
            if injector_result is not None:
                return injector_result
            return messages, tools, False

    class FakeTokenizer:
        """Stub tokenizer that reports a fixed count regardless of input."""

        def count_messages(self, messages) -> int:  # noqa: ANN001
            return tokenizer_count

    monkeypatch.setitem(sys.modules, "headroom.ccr", SimpleNamespace(CCRToolInjector=FakeInjector))
    monkeypatch.setitem(
        sys.modules,
        "headroom.tokenizers",
        SimpleNamespace(get_tokenizer=lambda model: FakeTokenizer()),
    )
    monkeypatch.setitem(
        sys.modules,
        "headroom.utils",
        SimpleNamespace(extract_user_query=lambda messages: "query"),
    )
async def test_compress_batch_jsonl_without_optimization_handles_invalid_lines(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With optimisation off, lines pass through and an unparsable line counts as an error."""
    install_batch_support_modules(monkeypatch, tokenizer_count=12)
    handler = DummyBatchHandler()

    # Three JSONL lines: a normal request, a request with no messages, and garbage.
    content = "\n".join(
        [
            json.dumps(
                {"body": {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}}
            ),
            json.dumps({"body": {"model": "gpt-4o", "messages": []}}),
            "not-json",
        ]
    )

    lines, stats = await handler._compress_batch_jsonl(content, "req-1")

    assert len(lines) == 3
    assert json.loads(lines[0])["body"]["messages"][0]["content"] == "hi"
    # The invalid line is emitted unchanged.
    assert lines[2] == "not-json"
    assert stats == {
        "total_requests": 3,
        "total_original_tokens": 12,
        "total_compressed_tokens": 12,
        "total_tokens_saved": 0,
        "savings_percent": 0.0,
        "errors": 1,
    }
async def test_compress_batch_jsonl_uses_pipeline_and_ccr_injection(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With optimisation and CCR injection on, the injector output and pipeline stats are used."""
    install_batch_support_modules(
        monkeypatch,
        injector_result=(
            [{"role": "system", "content": "compressed"}],
            [{"name": "retrieval"}],
            True,
        ),
    )
    handler = DummyBatchHandler()
    handler.config.optimize = True
    handler.config.ccr_inject_tool = True
    # Pipeline stub reporting a 100 -> 40 token reduction.
    handler.openai_pipeline = SimpleNamespace(
        apply=lambda **kwargs: SimpleNamespace(
            messages=[{"role": "assistant", "content": "short"}],
            tokens_before=100,
            tokens_after=40,
        )
    )

    lines, stats = await handler._compress_batch_jsonl(
        json.dumps(
            {
                "body": {
                    "model": "gpt-4o-mini",
                    "messages": [{"role": "user", "content": "hello"}],
                    "tools": [{"name": "existing"}],
                }
            }
        ),
        "req-2",
    )

    body = json.loads(lines[0])["body"]
    # The emitted body carries the injector's messages and tools.
    assert body["messages"] == [{"role": "system", "content": "compressed"}]
    assert body["tools"] == [{"name": "retrieval"}]
    assert stats["total_tokens_saved"] == 60
    assert stats["savings_percent"] == 60.0
async def test_compress_batch_jsonl_falls_back_when_pipeline_raises(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """If the pipeline raises, the original messages are kept and no savings are reported."""
    install_batch_support_modules(monkeypatch, tokenizer_count=33)
    handler = DummyBatchHandler()
    handler.config.optimize = True

    def raise_boom(**kwargs):  # noqa: ANN003, ANN202
        """Pipeline stub that always fails."""
        raise RuntimeError("boom")

    handler.openai_pipeline = SimpleNamespace(apply=raise_boom)

    lines, stats = await handler._compress_batch_jsonl(
        json.dumps({"body": {"messages": [{"role": "user", "content": "hello"}]}}),
        "req-3",
    )

    assert json.loads(lines[0])["body"]["messages"][0]["content"] == "hello"
    # Both totals equal the stub tokenizer's count.
    assert stats["total_original_tokens"] == 33
    assert stats["total_compressed_tokens"] == 33
async def test_batch_passthrough_forwards_request_and_strips_response_headers() -> None:
    """``_batch_passthrough`` posts to the OpenAI batches URL and drops ``content-encoding``."""
    handler = DummyBatchHandler()
    handler.http_client.post_response = FakeResponse(
        content=b'{"ok":true}',
        headers={"content-encoding": "gzip", "content-length": "20", "x-kept": "1"},
    )

    response = await handler._batch_passthrough(
        FakeRequest(
            '{"input_file_id":"file-1"}', headers={"host": "example", "content-length": "10"}
        ),
        {"input_file_id": "file-1"},
    )

    assert response.status_code == 200
    # Unrelated upstream headers survive; the encoding header does not.
    assert dict(response.headers)["x-kept"] == "1"
    assert "content-encoding" not in dict(response.headers)
    assert handler.http_client.posts[0]["url"] == "https://openai.example/v1/batches"
async def test_handle_batch_create_validates_json_and_required_fields(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``handle_batch_create`` answers 400 for bad JSON and for each missing required field."""
    handler = DummyBatchHandler()

    # Case 1: the body reader raises ValueError.
    async def raise_bad_json(request):  # noqa: ANN001
        raise ValueError("bad json")

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", raise_bad_json)
    bad = await handler.handle_batch_create(FakeRequest("{}"))
    assert bad.status_code == 400
    # Membership test rather than ``find(...) > 0``, which would miss a match at index 0.
    assert "invalid_json" in bad.body.decode()

    # Case 2: ``input_file_id`` is absent.
    async def missing_file_payload(request):  # noqa: ANN001
        return {"endpoint": "/v1/chat/completions"}

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", missing_file_payload)
    missing_file = await handler.handle_batch_create(FakeRequest("{}"))
    assert missing_file.status_code == 400
    assert "input_file_id is required" in missing_file.body.decode()

    # Case 3: ``endpoint`` is absent.
    async def missing_endpoint_payload(request):  # noqa: ANN001
        return {"input_file_id": "file-1"}

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", missing_endpoint_payload)
    missing_endpoint = await handler.handle_batch_create(FakeRequest("{}"))
    assert missing_endpoint.status_code == 400
    assert "endpoint is required" in missing_endpoint.body.decode()
async def test_handle_batch_create_passthrough_and_download_failure(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``/v1/responses`` batch is passed through; a failed file download yields 404."""
    handler = DummyBatchHandler()
    passthrough_response = SimpleNamespace(marker="passthrough")

    async def fake_passthrough(request, body):  # noqa: ANN001
        return passthrough_response

    monkeypatch.setattr(handler, "_batch_passthrough", fake_passthrough)

    async def passthrough_payload(request):  # noqa: ANN001
        return {"input_file_id": "file-1", "endpoint": "/v1/responses"}

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", passthrough_payload)
    # The handler must return the passthrough result object itself.
    assert await handler.handle_batch_create(FakeRequest("{}")) is passthrough_response

    async def download_missing_payload(request):  # noqa: ANN001
        return {"input_file_id": "file-1", "endpoint": "/v1/chat/completions"}

    async def missing_download(file_id, headers):  # noqa: ANN001
        return None

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", download_missing_payload)
    monkeypatch.setattr(handler, "_download_openai_file", missing_download)
    missing = await handler.handle_batch_create(FakeRequest("{}"))
    assert missing.status_code == 404
    # Membership test rather than ``find(...) > 0``, which would miss a match at index 0.
    assert "file_not_found" in missing.body.decode()
async def test_handle_batch_create_handles_empty_upload_failure_and_success(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Walk ``handle_batch_create`` through an empty file, a failed upload and a success."""
    handler = DummyBatchHandler()

    async def request_payload(request):  # noqa: ANN001
        return {
            "input_file_id": "file-1",
            "endpoint": "/v1/chat/completions",
            "completion_window": "12h",
            "metadata": {"source": "test"},
        }

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", request_payload)

    async def fake_download(file_id, headers):  # noqa: ANN001
        return "downloaded"

    monkeypatch.setattr(handler, "_download_openai_file", fake_download)

    # Phase 1: compression yields no lines -> 400 empty_file.
    async def empty_compress(content, request_id):  # noqa: ANN001
        return [], {
            "total_requests": 0,
            "total_original_tokens": 0,
            "total_compressed_tokens": 0,
            "total_tokens_saved": 0,
            "savings_percent": 0.0,
            "errors": 0,
        }

    monkeypatch.setattr(handler, "_compress_batch_jsonl", empty_compress)
    empty = await handler.handle_batch_create(FakeRequest("{}"))
    assert empty.status_code == 400
    # Membership test rather than ``find(...) > 0``, which would miss a match at index 0.
    assert "empty_file" in empty.body.decode()

    # Phase 2: compression succeeds but the upload returns no file id -> 500 upload_failed.
    async def compressed(content, request_id):  # noqa: ANN001
        return ['{"body":{}}'], {
            "total_requests": 1,
            "total_original_tokens": 20,
            "total_compressed_tokens": 10,
            "total_tokens_saved": 10,
            "savings_percent": 50.0,
            "errors": 0,
        }

    monkeypatch.setattr(handler, "_compress_batch_jsonl", compressed)

    async def upload_failed_file(content, filename, headers):  # noqa: ANN001
        return None

    monkeypatch.setattr(handler, "_upload_openai_file", upload_failed_file)
    upload_failed = await handler.handle_batch_create(FakeRequest("{}"))
    assert upload_failed.status_code == 500
    assert "upload_failed" in upload_failed.body.decode()

    # Phase 3: upload succeeds and the upstream batch creation returns 200.
    handler.http_client.post_response = FakeResponse(
        content=b'{"id":"batch_123","object":"batch"}',
        headers={"content-encoding": "gzip", "content-length": "12", "x-openai": "1"},
    )

    async def upload_success(content, filename, headers):  # noqa: ANN001
        return "file-compressed"

    monkeypatch.setattr(handler, "_upload_openai_file", upload_success)
    success = await handler.handle_batch_create(
        FakeRequest(
            "{}", headers={"host": "proxy", "content-length": "4", "authorization": "Bearer test"}
        )
    )
    assert success.status_code == 200
    success_headers = dict(success.headers)
    # Savings from the compression stats are surfaced as response headers.
    assert success_headers["x-headroom-tokens-saved"] == "10"
    assert success_headers["x-headroom-savings-percent"] == "50.0"
    assert success_headers["x-openai"] == "1"

    # The forwarded batch request is tagged with compression metadata.
    sent_body = handler.http_client.posts[-1]["json"]
    assert sent_body["metadata"]["headroom_compressed"] == "true"
    assert sent_body["metadata"]["headroom_original_file_id"] == "file-1"
    assert handler.metrics.record_calls[-1]["provider"] == "openai"
async def test_handle_batch_create_records_failure_on_exception(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An unexpected exception during download yields 500 and one failed-metrics record."""
    handler = DummyBatchHandler()

    async def request_payload(request):  # noqa: ANN001
        return {"input_file_id": "file-1", "endpoint": "/v1/chat/completions"}

    async def boom(file_id, headers):  # noqa: ANN001
        raise RuntimeError("boom")

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", request_payload)
    monkeypatch.setattr(handler, "_download_openai_file", boom)

    response = await handler.handle_batch_create(FakeRequest("{}"))

    assert response.status_code == 500
    assert handler.metrics.failed_calls == [{"provider": "batch"}]
async def test_download_and_upload_openai_file_helpers() -> None:
    """Exercise ``_download_openai_file`` and ``_upload_openai_file`` success and failure paths."""
    handler = DummyBatchHandler()

    # Download: a 200 returns the response text and hits the file-content URL.
    handler.http_client.get_response = FakeResponse(status_code=200, text="jsonl-content")
    downloaded = await handler._download_openai_file("file-1", {"authorization": "Bearer token"})
    assert downloaded == "jsonl-content"
    assert handler.http_client.gets[0]["url"] == "https://openai.example/v1/files/file-1/content"

    # Download: a 404 returns None.
    handler.http_client.get_response = FakeResponse(status_code=404, text="missing")
    assert await handler._download_openai_file("file-2", {}) is None

    # Upload: a 200 returns the new file id.
    handler.http_client.post_response = FakeResponse(
        status_code=200,
        json_data={"id": "file-uploaded"},
        headers={"content-type": "application/json"},
    )
    file_id = await handler._upload_openai_file(
        '{"body":{}}',
        "compressed.jsonl",
        {"authorization": "Bearer token", "content-type": "application/json"},
    )
    assert file_id == "file-uploaded"
    post_call = handler.http_client.posts[-1]
    # The caller's content-type header is not forwarded with the file upload.
    assert post_call["headers"] == {"authorization": "Bearer token"}
    assert post_call["files"]["file"][0] == "compressed.jsonl"

    # Upload: a 500 response returns None.
    handler.http_client.post_response = FakeResponse(status_code=500, text="fail")
    assert await handler._upload_openai_file("{}", "bad.jsonl", {}) is None

    # Upload: an exception from the client returns None instead of propagating.
    handler.http_client.raise_post = RuntimeError("network")
    assert await handler._upload_openai_file("{}", "bad.jsonl", {}) is None
async def test_store_google_batch_context_persists_transformed_requests(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``_store_google_batch_context`` stores one context holding the converted requests."""
    stored_contexts: list[object] = []

    class FakeBatchContext:
        """Stub batch context recording constructor kwargs and added requests."""

        def __init__(self, **kwargs) -> None:  # noqa: ANN003
            self.kwargs = kwargs
            self.requests: list[object] = []

        def add_request(self, request) -> None:  # noqa: ANN001
            self.requests.append(request)

    class FakeBatchRequestContext:
        """Stub per-request context recording constructor kwargs."""

        def __init__(self, **kwargs) -> None:  # noqa: ANN003
            self.kwargs = kwargs

    class FakeStore:
        """Stub store that captures whatever context is persisted."""

        async def store(self, context) -> None:  # noqa: ANN001
            stored_contexts.append(context)

    monkeypatch.setitem(
        sys.modules,
        "headroom.ccr",
        SimpleNamespace(
            BatchContext=FakeBatchContext,
            BatchRequestContext=FakeBatchRequestContext,
            get_batch_context_store=lambda: FakeStore(),
        ),
    )
    handler = DummyBatchHandler()

    await handler._store_google_batch_context(
        "batches/123",
        [
            {
                "metadata": {"key": "req-1"},
                "request": {
                    "contents": [{"parts": [{"text": "hello"}]}],
                    "systemInstruction": {"parts": [{"text": "system"}]},
                    "tools": [{"name": "tool"}],
                },
            }
        ],
        "gemini-2.0",
        "api-key",
    )

    context = stored_contexts[0]
    assert context.kwargs["batch_id"] == "batches/123"
    # ``metadata.key`` becomes the per-request custom id.
    assert context.requests[0].kwargs["custom_id"] == "req-1"
    assert context.requests[0].kwargs["messages"] == [{"role": "user", "content": "hello"}]
    assert context.requests[0].kwargs["system_instruction"] == "system"
async def test_handle_google_batch_results_passes_through_early_exit_cases(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``handle_google_batch_results`` relays the upstream status in each early-exit case.

    Cases covered, in order: upstream error status, non-JSON body, batch still
    running, finished batch with no responses, and CCR tool injection disabled.
    """

    class FakeStore:
        """Stub context store that never finds a stored batch."""

        async def get(self, batch_name):  # noqa: ANN001
            return None

    monkeypatch.setitem(
        sys.modules,
        "headroom.ccr",
        SimpleNamespace(
            BatchResultProcessor=lambda http_client: None,
            get_batch_context_store=lambda: FakeStore(),
        ),
    )
    handler = DummyBatchHandler()
    request = FakeRequest(
        "{}", headers={"x-goog-api-key": "secret"}, method="GET", path="/v1beta/batches/b1"
    )

    # Upstream error: status and headers are relayed.
    handler.http_client.get_response = FakeResponse(
        status_code=500, content=b"bad", headers={"x-upstream": "1"}
    )
    error_response = await handler.handle_google_batch_results(request, "batches/b1")
    assert error_response.status_code == 500
    assert dict(error_response.headers)["x-upstream"] == "1"

    class BadJsonResponse(FakeResponse):
        """Response whose ``json()`` always fails to decode."""

        def json(self):  # noqa: ANN201
            raise json.JSONDecodeError("bad", "x", 0)

    # Body that is not JSON: still relayed with a 200.
    handler.http_client.get_response = BadJsonResponse(
        status_code=200, content=b"plain", headers={"x-upstream": "2"}
    )
    non_json = await handler.handle_google_batch_results(request, "batches/b1")
    assert non_json.status_code == 200
    assert dict(non_json.headers)["x-upstream"] == "2"

    # Batch not finished yet.
    handler.http_client.get_response = FakeResponse(
        status_code=200,
        content=b"{}",
        json_data={"metadata": {"state": "RUNNING"}},
    )
    running = await handler.handle_google_batch_results(request, "batches/b1")
    assert running.status_code == 200

    # Finished batch with an empty responses list.
    handler.http_client.get_response = FakeResponse(
        status_code=200,
        content=b"{}",
        json_data={"metadata": {"state": "SUCCEEDED"}, "response": {"responses": []}},
    )
    no_results = await handler.handle_google_batch_results(request, "batches/b1")
    assert no_results.status_code == 200

    # Finished batch with results, but CCR tool injection is disabled.
    handler.http_client.get_response = FakeResponse(
        status_code=200,
        content=b"{}",
        json_data={"metadata": {"state": "SUCCEEDED"}, "response": {"responses": [{"id": 1}]}},
    )
    handler.config.ccr_inject_tool = False
    no_ccr = await handler.handle_google_batch_results(request, "batches/b1")
    assert no_ccr.status_code == 200
    # The API key from the request header is sent upstream as a query parameter.
    assert "key=secret" in handler.http_client.gets[-1]["url"]
async def test_handle_google_batch_results_processes_completed_results(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With CCR enabled and a stored context, responses are replaced by the processor output."""
    processed_calls: list[tuple[str, list[object], str]] = []

    class FakeProcessed:
        """Stub processed-result record with the attributes the test sets."""

        def __init__(
            self, result: object, custom_id: str, was_processed: bool, continuation_rounds: int
        ) -> None:
            self.result = result
            self.custom_id = custom_id
            self.was_processed = was_processed
            self.continuation_rounds = continuation_rounds

    class FakeProcessor:
        """Stub ``BatchResultProcessor`` that records its input and returns fixed results."""

        def __init__(self, http_client) -> None:  # noqa: ANN001
            self.http_client = http_client

        async def process_results(self, batch_name, results, provider):  # noqa: ANN001
            processed_calls.append((batch_name, results, provider))
            return [
                FakeProcessed({"id": "processed"}, "req-1", True, 2),
                FakeProcessed({"id": "unchanged"}, "req-2", False, 0),
            ]

    class FakeStore:
        """Stub context store that always finds a context for the batch."""

        async def get(self, batch_name):  # noqa: ANN001
            return SimpleNamespace(batch_name=batch_name)

    monkeypatch.setitem(
        sys.modules,
        "headroom.ccr",
        SimpleNamespace(
            BatchResultProcessor=FakeProcessor,
            get_batch_context_store=lambda: FakeStore(),
        ),
    )
    handler = DummyBatchHandler()
    handler.config.ccr_inject_tool = True
    handler.http_client.get_response = FakeResponse(
        status_code=200,
        content=b"{}",
        json_data={
            "metadata": {"state": "SUCCEEDED"},
            "response": {"responses": [{"id": "raw-1"}, {"id": "raw-2"}]},
        },
    )

    response = await handler.handle_google_batch_results(
        FakeRequest("{}", method="GET", path="/v1beta/batches/b1"),
        "batches/b1",
    )

    payload = json.loads(response.body)
    # The returned responses are the processor's ``result`` values, in order.
    assert payload["response"]["responses"] == [{"id": "processed"}, {"id": "unchanged"}]
    # The processor received the raw upstream responses and the "google" provider tag.
    assert processed_calls == [("batches/b1", [{"id": "raw-1"}, {"id": "raw-2"}], "google")]
    assert handler.metrics.record_calls[-1]["model"] == "batch:ccr-processed"
async def test_google_batch_passthrough_helpers_forward_and_track_metrics() -> None:
    """Both Google passthrough helpers forward upstream, keep headers and record metrics."""
    handler = DummyBatchHandler()

    # --- _google_batch_passthrough (batch creation, sent via POST) ---
    handler.http_client.post_response = FakeResponse(
        content=b'{"ok":true}',
        headers={"content-encoding": "gzip", "content-length": "10", "x-kept": "1"},
    )
    passthrough = await handler._google_batch_passthrough(
        FakeRequest(
            "body", headers={"host": "proxy", "content-length": "4", "x-goog-api-key": "secret"}
        ),
        "gemini-pro",
        {"batch": {}},
    )
    assert passthrough.status_code == 200
    assert dict(passthrough.headers)["x-kept"] == "1"
    # The API key header is forwarded as a ``key`` query parameter.
    assert "key=secret" in handler.http_client.posts[-1]["url"]
    assert handler.metrics.record_calls[-1]["model"] == "passthrough:batch:gemini-pro"

    # --- handle_google_batch_passthrough (generic method, sent via request()) ---
    # FakeHttpClient.request() answers with ``get_response``.
    handler.http_client.get_response = FakeResponse(
        content=b'{"state":"ok"}',
        headers={"content-encoding": "gzip", "content-length": "10", "x-kept": "2"},
    )
    response = await handler.handle_google_batch_passthrough(
        FakeRequest(
            "ping",
            headers={"host": "proxy", "x-goog-api-key": "secret"},
            method="DELETE",
            path="/v1beta/batches/b1",
            query="alt=json",
        ),
        "b1",
    )
    assert response.status_code == 200
    assert dict(response.headers)["x-kept"] == "2"
    forwarded_call = handler.http_client.requests[-1]
    # The original query string is preserved and the key is appended to it.
    assert (
        forwarded_call["url"] == "https://gemini.example/v1beta/batches/b1?alt=json&key=secret"
    )
    assert handler.metrics.record_calls[-1]["model"] == "passthrough:batches"
async def test_handle_google_batch_create_validates_and_passthroughs(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``handle_google_batch_create`` rejects bad input and passes through empty batches."""
    install_batch_support_modules(monkeypatch)
    handler = DummyBatchHandler()

    # A declared content-length of 200 MiB is refused with 413.
    too_large = await handler.handle_google_batch_create(
        FakeRequest("{}", headers={"content-length": str(200 * 1024 * 1024)}),
        "gemini-pro",
    )
    assert too_large.status_code == 413

    # A body reader that raises ValueError yields 400.
    async def bad_json(request):  # noqa: ANN001
        raise ValueError("bad json")

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", bad_json)
    invalid = await handler.handle_google_batch_create(FakeRequest("{}"), "gemini-pro")
    assert invalid.status_code == 400

    # A batch with no inline requests is delegated to the passthrough helper.
    passthrough_response = SimpleNamespace(kind="passthrough")

    async def fake_google_passthrough(request, model, body=None):  # noqa: ANN001
        return passthrough_response

    async def no_inline(request):  # noqa: ANN001
        return {"batch": {"input_config": {"requests": {"requests": []}}}}

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", no_inline)
    monkeypatch.setattr(handler, "_google_batch_passthrough", fake_google_passthrough)
    assert (
        await handler.handle_google_batch_create(FakeRequest("{}"), "gemini-pro")
        is passthrough_response
    )
async def test_handle_google_batch_create_success_and_failure_paths(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An optimised Google batch is forwarded, measured and stored; a forward failure gives 500."""
    install_batch_support_modules(monkeypatch)
    handler = DummyBatchHandler()
    handler.config.optimize = True
    handler.config.ccr_inject_tool = True
    # Pipeline stub reporting a 100 -> 40 token reduction.
    handler.openai_pipeline = SimpleNamespace(
        apply=lambda **kwargs: SimpleNamespace(
            messages=[{"role": "user", "content": "compressed"}],
            timing={"compress": 1.2},
            tokens_before=100,
            tokens_after=40,
        )
    )

    class FakeInjector:
        """Stub injector that appends a system message and reports an injection."""

        def __init__(self, **kwargs) -> None:  # noqa: ANN003
            pass

        def process_request(self, messages, tools):  # noqa: ANN001, ANN201
            return (
                messages + [{"role": "system", "content": "retrieval"}],
                [{"name": "retrieval"}],
                True,
            )

    # Overrides the ``headroom.ccr`` stub installed by install_batch_support_modules.
    monkeypatch.setitem(sys.modules, "headroom.ccr", SimpleNamespace(CCRToolInjector=FakeInjector))

    stored: list[tuple[str, list[dict[str, object]], str, str | None]] = []

    async def fake_store(batch_name, requests_list, model, api_key):  # noqa: ANN001
        stored.append((batch_name, requests_list, model, api_key))

    async def fake_retry(method, url, headers, body):  # noqa: ANN001
        return FakeResponse(
            status_code=200,
            content=b'{"name":"batches/123"}',
            headers={"content-encoding": "gzip", "content-length": "10", "x-upstream": "1"},
            json_data={"name": "batches/123"},
        )

    async def good_payload(request):  # noqa: ANN001
        return {
            "batch": {
                "input_config": {
                    "requests": {
                        "requests": [
                            {
                                "request": {
                                    "contents": [{"parts": [{"text": "hello"}]}],
                                    "tools": [{"functionDeclarations": [{"name": "existing"}]}],
                                },
                                "metadata": {"key": "req-1"},
                            }
                        ]
                    }
                }
            }
        }

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", good_payload)
    monkeypatch.setattr(handler, "_retry_request", fake_retry)
    monkeypatch.setattr(handler, "_store_google_batch_context", fake_store)

    response = await handler.handle_google_batch_create(
        FakeRequest("{}", headers={"x-goog-api-key": "secret"}),
        "gemini-pro",
    )

    assert response.status_code == 200
    assert dict(response.headers)["x-upstream"] == "1"
    assert handler.metrics.record_calls[-1]["provider"] == "google"
    assert handler.metrics.record_calls[-1]["tokens_saved"] == 60
    # The context is stored under the upstream batch name with the model and API key.
    assert stored[0][0] == "batches/123"
    assert stored[0][2:] == ("gemini-pro", "secret")
    assert stored[0][1][0]["metadata"] == {"key": "req-1"}

    # Failure path: the forwarding call raises.
    async def broken_retry(method, url, headers, body):  # noqa: ANN001
        raise RuntimeError("forward failed")

    monkeypatch.setattr(handler, "_retry_request", broken_retry)
    failed = await handler.handle_google_batch_create(FakeRequest("{}"), "gemini-pro")
    assert failed.status_code == 500
async def test_handle_google_batch_create_covers_passthrough_revert_and_store_failures(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Mixed batch: empty, binary and text requests, an inflating pipeline and a failing store.

    Asserts that the request still succeeds with 200, that the pipeline runs
    exactly once across the three requests, that no token savings are recorded
    when the pipeline reports more tokens after than before, and that a store
    failure does not fail the request.
    """
    install_batch_support_modules(
        monkeypatch, injector_result=([{"role": "user", "content": "kept"}], None, False)
    )
    handler = DummyBatchHandler()
    handler.config.optimize = True
    handler.config.ccr_inject_tool = True

    pipeline_calls: list[dict[str, object]] = []

    def fake_apply(**kwargs):  # noqa: ANN003, ANN202
        """Record the call and report an *increase* in tokens (40 -> 80)."""
        pipeline_calls.append(kwargs)
        return SimpleNamespace(
            messages=[{"role": "user", "content": "inflated"}],
            timing={},
            tokens_before=40,
            tokens_after=80,
        )

    handler.openai_pipeline = SimpleNamespace(apply=fake_apply)

    def fake_to_messages(contents, system_instruction):  # noqa: ANN001, ANN201
        # Contents whose first part is inline data report index 0 in the second
        # tuple element; text contents report an empty list.
        if contents and "inlineData" in contents[0]["parts"][0]:
            return ([{"role": "user", "content": "binary"}], [0])
        return ([{"role": "user", "content": "compress"}], [])

    def fake_to_gemini(messages):  # noqa: ANN001, ANN201
        return ([{"parts": [{"text": "new"}]}], {"parts": [{"text": "sys"}]})

    async def payload(request):  # noqa: ANN001
        return {
            "batch": {
                "input_config": {
                    "requests": {
                        "requests": [
                            {"request": {"contents": []}, "metadata": {"key": "empty"}},
                            {
                                "request": {"contents": [{"parts": [{"inlineData": "x"}]}]},
                                "metadata": {"key": "preserved"},
                            },
                            {
                                "request": {
                                    "contents": [{"parts": [{"text": "hello"}]}],
                                    "tools": [
                                        {"other": True},
                                        {"functionDeclarations": [{"name": "existing"}]},
                                    ],
                                },
                                "metadata": {"key": "optimized"},
                            },
                        ]
                    }
                }
            }
        }

    seen_bodies: list[dict[str, object]] = []

    async def retry(method, url, headers, body):  # noqa: ANN001
        seen_bodies.append(body)
        return FakeResponse(status_code=200, content=b"{}", json_data={"name": "batches/123"})

    async def broken_store(batch_name, requests_list, model, api_key):  # noqa: ANN001
        raise RuntimeError("store failed")

    monkeypatch.setattr("headroom.proxy.helpers._read_request_json", payload)
    monkeypatch.setattr(handler, "_gemini_contents_to_messages", fake_to_messages)
    monkeypatch.setattr(handler, "_messages_to_gemini_contents", fake_to_gemini)
    monkeypatch.setattr(handler, "_retry_request", retry)
    monkeypatch.setattr(handler, "_store_google_batch_context", broken_store)

    response = await handler.handle_google_batch_create(FakeRequest("{}"), "gemini-pro")

    assert response.status_code == 200
    assert len(pipeline_calls) == 1
    assert handler.metrics.record_calls[-1]["tokens_saved"] == 0

    forwarded_requests = seen_bodies[0]["batch"]["input_config"]["requests"]["requests"]
    # The empty request is still forwarded, in its original position.
    assert forwarded_requests[0]["metadata"]["key"] == "empty"
    # The text request carries the converted contents and system instruction.
    optimized = forwarded_requests[2]["request"]
    assert optimized["contents"][0] == {"parts": [{"text": "new"}]}
    assert optimized["systemInstruction"] == {"parts": [{"text": "sys"}]}
async def test_google_batch_passthrough_without_body_and_query_variants() -> None:
    """Check the two Google batch passthrough entry points against the fake client.

    ``_google_batch_passthrough`` receives a POST whose body is ``raw-body``
    and the test expects those bytes in the recorded upstream post.
    ``handle_google_batch_passthrough`` receives a GET carrying an
    ``x-goog-api-key`` header and the test expects the recorded upstream URL
    to end in ``?key=secret``.
    """
    handler = DummyBatchHandler()

    # POST variant: the request body is expected upstream as raw bytes.
    handler.http_client.post_response = FakeResponse(content=b"ok", headers={"x-upstream": "1"})
    post_request = FakeRequest("raw-body", headers={"host": "proxy"}, method="POST")
    response = await handler._google_batch_passthrough(post_request, "gemini-pro")
    assert response.status_code == 200
    assert handler.http_client.posts[-1]["content"] == b"raw-body"

    # GET variant: the upstream URL is expected to carry the API key as a query parameter.
    handler.http_client.get_response = FakeResponse(content=b"{}", headers={"x-upstream": "2"})
    get_request = FakeRequest(
        "{}",
        headers={"host": "proxy", "x-goog-api-key": "secret"},
        method="GET",
        path="/v1beta/batches/b1",
    )
    passthrough = await handler.handle_google_batch_passthrough(get_request, "b1")
    assert passthrough.status_code == 200

    expected_url = "https://gemini.example/v1beta/batches/b1?key=secret"
    assert handler.http_client.requests[-1]["url"] == expected_url
async def test_batch_helper_methods_and_openai_file_error_branches() -> None:
    """Cover the batch list/get/cancel helpers and two OpenAI file failure paths.

    With ``handle_passthrough`` replaced by a stub, each of the three batch
    helpers is expected to return the stub's sentinel. ``_download_openai_file``
    is expected to return ``None`` when the client's GET raises, and
    ``_upload_openai_file`` is expected to return ``None`` when the upload
    response JSON is an empty object.
    """
    handler = DummyBatchHandler()
    sentinel = object()

    async def fake_passthrough(request, base_url):  # noqa: ANN001
        return sentinel

    handler.handle_passthrough = fake_passthrough
    request = FakeRequest("{}")

    listed = await handler.handle_batch_list(request)
    assert listed is sentinel
    fetched = await handler.handle_batch_get(request, "b1")
    assert fetched is sentinel
    cancelled = await handler.handle_batch_cancel(request, "b1")
    assert cancelled is sentinel

    # Download path: the fake client raises on GET.
    handler.http_client.raise_get = RuntimeError("download boom")
    downloaded = await handler._download_openai_file("file-1", {})
    assert downloaded is None

    # Upload path: a 200 response whose JSON body is an empty object.
    handler.http_client.raise_get = None
    handler.http_client.post_response = FakeResponse(status_code=200, json_data={})
    uploaded = await handler._upload_openai_file("{}", "missing-id.jsonl", {})
    assert uploaded is None
async def test_store_google_batch_context_without_system_text(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Store a Google batch context whose system instruction has no text part.

    ``headroom.ccr`` is replaced in ``sys.modules`` by a namespace of
    recording fakes. The single request has no metadata and a
    ``systemInstruction`` whose only part is the bare string ``"bad"``. The
    test expects the stored context to have ``api_key`` of ``None`` and its
    one request to have an empty ``custom_id`` and ``system_instruction`` of
    ``None``.
    """
    stored_contexts: list[object] = []

    class FakeBatchContext:
        def __init__(self, **kwargs) -> None:  # noqa: ANN003
            self.kwargs = kwargs
            self.requests: list[object] = []

        def add_request(self, request) -> None:  # noqa: ANN001
            self.requests.append(request)

    class FakeBatchRequestContext:
        def __init__(self, **kwargs) -> None:  # noqa: ANN003
            self.kwargs = kwargs

    class FakeStore:
        async def store(self, context) -> None:  # noqa: ANN001
            stored_contexts.append(context)

    handler = DummyBatchHandler()

    # Calling the class with no arguments yields a fresh store, like a factory.
    fake_ccr = SimpleNamespace(
        BatchContext=FakeBatchContext,
        BatchRequestContext=FakeBatchRequestContext,
        get_batch_context_store=FakeStore,
    )
    monkeypatch.setitem(sys.modules, "headroom.ccr", fake_ccr)

    batch_requests = [
        {
            "request": {
                "contents": [{"parts": [{"text": "hello"}]}],
                "systemInstruction": {"parts": ["bad"]},
            }
        }
    ]
    await handler._store_google_batch_context("batches/456", batch_requests, "gemini-2.0", None)

    context = stored_contexts[0]
    assert context.kwargs["api_key"] is None

    first_request = context.requests[0]
    assert first_request.kwargs["custom_id"] == ""
    assert first_request.kwargs["system_instruction"] is None
async def test_compress_batch_jsonl_skips_blank_lines_and_preserves_tools_when_not_injected(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Compress a JSONL payload made of a blank line followed by one request.

    The injector stub is installed with a ``False`` third element and the
    request's original tools. The test expects a single output line whose
    ``tools`` equal the input tools, and stats reporting one request and no
    errors.
    """
    install_batch_support_modules(
        monkeypatch,
        injector_result=([{"role": "assistant", "content": "short"}], [{"name": "orig"}], False),
    )
    handler = DummyBatchHandler()
    handler.config.optimize = True
    handler.config.ccr_inject_tool = True

    def fake_apply(**kwargs):  # noqa: ANN003, ANN201
        return SimpleNamespace(
            messages=[{"role": "assistant", "content": "short"}],
            tokens_before=50,
            tokens_after=10,
        )

    handler.openai_pipeline = SimpleNamespace(apply=fake_apply)

    request_line = json.dumps(
        {
            "body": {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "hello"}],
                "tools": [{"name": "orig"}],
            }
        }
    )
    # The leading newline produces the blank first line this test is named after.
    jsonl_text = "\n" + request_line + "\n"

    lines, stats = await handler._compress_batch_jsonl(jsonl_text, "req-extra")

    assert len(lines) == 1
    body = json.loads(lines[0])["body"]
    assert body["tools"] == [{"name": "orig"}]
    assert stats["total_requests"] == 1
    assert stats["errors"] == 0