Spaces:
Running
Running
| """Integration tests for the proxy with real Gemini API calls. | |
| These tests require a valid GEMINI_API_KEY environment variable. | |
| They test the actual /v1/chat/completions endpoint with real API calls. | |
| Run with: | |
| GEMINI_API_KEY=your-key pytest tests/test_proxy_gemini_integration.py -v | |
| """ | |
| import json | |
| import os | |
| import pytest | |
| # Skip entire module if no API key | |
| pytestmark = pytest.mark.skipif( | |
| not os.environ.get("GEMINI_API_KEY"), reason="GEMINI_API_KEY not set" | |
| ) | |
| pytest.importorskip("fastapi") | |
| pytest.importorskip("httpx") | |
| from fastapi.testclient import TestClient # noqa: E402 | |
| from headroom.proxy.server import ProxyConfig, create_app # noqa: E402 | |
# Gemini's OpenAI-compatible endpoint; passed to ProxyConfig as
# ``openai_api_url`` so the proxy's OpenAI-style routes forward to Gemini.
GEMINI_BASE_URL = (
    "https://generativelanguage.googleapis.com"
    "/v1beta/openai"
)
@pytest.fixture
def gemini_client():
    """Yield a TestClient for a proxy app configured to forward to Gemini.

    The app is built from a ``ProxyConfig`` whose ``openai_api_url`` points at
    ``GEMINI_BASE_URL``, with optimization, caching, rate limiting and cost
    tracking disabled. The client is used as a context manager so the app's
    lifespan (startup/shutdown) runs around each test.

    Yields:
        TestClient: an in-process client for the proxy app.
    """
    config = ProxyConfig(
        optimize=False,
        cache_enabled=False,
        rate_limit_enabled=False,
        cost_tracking_enabled=False,
        openai_api_url=GEMINI_BASE_URL,
    )
    app = create_app(config)
    with TestClient(app) as client:
        yield client
@pytest.fixture
def api_key():
    """Return the Gemini API key from the ``GEMINI_API_KEY`` environment variable.

    The module-level ``skipif`` marker skips every test when the variable is
    unset or empty, so tests that receive this fixture get a non-empty key.
    """
    return os.environ.get("GEMINI_API_KEY")
class TestGeminiChatCompletions:
    """Test /v1/chat/completions with real Gemini API."""

    # Model requested by every test in this class.
    MODEL = "gemini-2.0-flash"

    @staticmethod
    def _auth(api_key):
        """Return the bearer-token header sent with each request to the proxy."""
        return {"Authorization": f"Bearer {api_key}"}

    @staticmethod
    def _parse_sse_chunks(body):
        """Decode the JSON payloads of an OpenAI-style server-sent-event stream.

        Lines are split with ``splitlines()`` and stripped, so LF- and
        CRLF-delimited streams both parse. Non-``data:`` lines, empty
        payloads and the ``[DONE]`` sentinel are skipped.

        Args:
            body: the full response text of the stream.

        Returns:
            list: one decoded object per JSON ``data:`` line, in order.
        """
        chunks = []
        for raw_line in body.splitlines():
            line = raw_line.strip()
            if not line.startswith("data:"):
                continue
            payload = line[len("data:"):].strip()
            if not payload or payload == "[DONE]":
                continue
            chunks.append(json.loads(payload))
        return chunks

    def test_basic_completion(self, gemini_client, api_key):
        """Basic chat completion works."""
        response = gemini_client.post(
            "/v1/chat/completions",
            headers=self._auth(api_key),
            json={
                "model": self.MODEL,
                "messages": [
                    {"role": "user", "content": "What is 2+2? Reply with just the number."}
                ],
            },
        )
        assert response.status_code == 200, response.text
        data = response.json()

        # Verify OpenAI-compatible response format
        assert "choices" in data
        assert len(data["choices"]) > 0
        assert "message" in data["choices"][0]
        message = data["choices"][0]["message"]
        assert "content" in message
        assert "4" in (message["content"] or "")

        # Verify usage stats
        assert "usage" in data
        assert "prompt_tokens" in data["usage"]
        assert "completion_tokens" in data["usage"]

    def test_multi_turn_conversation(self, gemini_client, api_key):
        """Multi-turn conversations maintain context."""
        response = gemini_client.post(
            "/v1/chat/completions",
            headers=self._auth(api_key),
            json={
                "model": self.MODEL,
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant. Be concise."},
                    {"role": "user", "content": "My name is TestUser123."},
                    {"role": "assistant", "content": "Nice to meet you, TestUser123!"},
                    {"role": "user", "content": "What is my name?"},
                ],
            },
        )
        assert response.status_code == 200, response.text
        data = response.json()
        content = (data["choices"][0]["message"]["content"] or "").lower()
        assert "testuser123" in content

    def test_streaming(self, gemini_client, api_key):
        """Streaming responses work correctly."""
        response = gemini_client.post(
            "/v1/chat/completions",
            headers=self._auth(api_key),
            json={
                "model": self.MODEL,
                "stream": True,
                "messages": [{"role": "user", "content": "Count from 1 to 3."}],
            },
        )
        assert response.status_code == 200, response.text

        chunks = self._parse_sse_chunks(response.text)
        assert len(chunks) > 0

        # Verify chunk format
        for chunk in chunks:
            assert chunk["object"] == "chat.completion.chunk"
            assert "choices" in chunk
            # OpenAI-style streams may end with a usage-only chunk whose
            # choices list is empty; only inspect chunks that carry a choice.
            if chunk["choices"]:
                assert "delta" in chunk["choices"][0]
        assert any(chunk["choices"] for chunk in chunks)

    def test_function_calling(self, gemini_client, api_key):
        """Function calling / tools work correctly."""
        response = gemini_client.post(
            "/v1/chat/completions",
            headers=self._auth(api_key),
            json={
                "model": self.MODEL,
                "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "description": "Get the weather for a location",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "location": {"type": "string", "description": "City name"}
                                },
                                "required": ["location"],
                            },
                        },
                    }
                ],
                "tool_choice": "auto",
            },
        )
        assert response.status_code == 200, response.text
        data = response.json()

        # Verify tool call response
        message = data["choices"][0]["message"]
        assert "tool_calls" in message
        assert len(message["tool_calls"]) > 0
        tool_call = message["tool_calls"][0]
        assert tool_call["function"]["name"] == "get_weather"
        assert "paris" in tool_call["function"]["arguments"].lower()

    def test_json_mode(self, gemini_client, api_key):
        """JSON response format works."""
        response = gemini_client.post(
            "/v1/chat/completions",
            headers=self._auth(api_key),
            json={
                "model": self.MODEL,
                "messages": [
                    {
                        "role": "user",
                        "content": "Return a JSON object with keys 'name' and 'age' for a person named Alice who is 30 years old.",
                    }
                ],
                "response_format": {"type": "json_object"},
            },
        )
        assert response.status_code == 200, response.text
        data = response.json()
        content = data["choices"][0]["message"]["content"]

        # Parse the response as JSON
        parsed = json.loads(content)
        assert "name" in parsed or "Name" in parsed
        assert "age" in parsed or "Age" in parsed
class TestGeminiModels:
    """Exercise the /v1/models endpoint through the proxy against Gemini."""

    def test_list_models(self, gemini_client, api_key):
        """Listing models through the proxy succeeds and returns a model-list shape."""
        auth = {"Authorization": f"Bearer {api_key}"}
        # This goes through passthrough handler
        resp = gemini_client.get("/v1/models", headers=auth)

        assert resp.status_code == 200
        payload = resp.json()
        assert any(key in payload for key in ("data", "object"))
class TestProxyStats:
    """Test that proxy stats track Gemini requests correctly."""

    def test_stats_track_requests(self, gemini_client, api_key):
        """Proxy stats record a completed Gemini chat request."""
        response = gemini_client.post(
            "/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={"model": "gemini-2.0-flash", "messages": [{"role": "user", "content": "Hi"}]},
        )
        # Fail fast if the tracked request itself failed, so the stats
        # assertions below are judged against a request that actually completed.
        assert response.status_code == 200, response.text

        # Check stats
        stats_response = gemini_client.get("/stats")
        assert stats_response.status_code == 200
        requests = stats_response.json()["requests"]
        assert requests["total"] >= 1
        # The request goes through the OpenAI-compatible route (openai_api_url),
        # and this test expects it to be counted under the "openai" provider.
        assert requests["by_provider"].get("openai", 0) >= 1
        assert "gemini" in str(requests["by_model"]).lower()
class TestErrorHandling:
    """Test error handling with Gemini."""

    def test_invalid_api_key(self, gemini_client):
        """An invalid API key yields a client (4xx) error."""
        response = gemini_client.post(
            "/v1/chat/completions",
            headers={"Authorization": "Bearer invalid-key-123"},
            json={"model": "gemini-2.0-flash", "messages": [{"role": "user", "content": "Hi"}]},
        )
        # Should return 4xx error; a 5xx would indicate a proxy/server failure,
        # which must not be accepted as "handled".
        assert 400 <= response.status_code < 500, response.text

    def test_invalid_model(self, gemini_client, api_key):
        """An unknown model name yields a client (4xx) error."""
        response = gemini_client.post(
            "/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "model": "nonexistent-model-xyz",
                "messages": [{"role": "user", "content": "Hi"}],
            },
        )
        # Should return 4xx error; a 5xx would indicate a proxy/server failure.
        assert 400 <= response.status_code < 500, response.text