Spaces:
Running
Running
"""Integration tests for the Gemini native API endpoints with real API calls.

These tests require a valid GEMINI_API_KEY environment variable; the whole
module is skipped when it is unset or empty. They exercise both the
/v1beta/models/{model}:generateContent and /v1beta/models/{model}:countTokens
endpoints through the proxy with compression enabled.

Run with:
    GEMINI_API_KEY=your-key pytest tests/test_proxy_gemini_native_integration.py -v
"""

import json
import os

import pytest

# Skip the entire module if no API key is available: every test below makes
# real calls to the Gemini API.
pytestmark = pytest.mark.skipif(
    not os.environ.get("GEMINI_API_KEY"), reason="GEMINI_API_KEY not set"
)

# Skip (rather than error out) when the web-stack dependencies used by the
# proxy app and its TestClient are not installed.
pytest.importorskip("fastapi")
pytest.importorskip("httpx")

from fastapi.testclient import TestClient  # noqa: E402
from headroom.proxy.server import ProxyConfig, create_app  # noqa: E402
@pytest.fixture
def gemini_native_client():
    """Yield a TestClient for a proxy app with optimization (compression) enabled.

    Response caching, rate limiting and cost tracking are switched off; only
    ``optimize`` is on. The fixture uses pytest's default function scope, so a
    new app is built for every test.
    NOTE(review): whether ``/stats`` counters are per-app is decided inside
    ``create_app``; the stats assertions in this module only check lower bounds.
    """
    config = ProxyConfig(
        optimize=True,  # Enable compression
        cache_enabled=False,
        rate_limit_enabled=False,
        cost_tracking_enabled=False,
    )
    app = create_app(config)
    # Using TestClient as a context manager runs the app's startup/shutdown
    # handlers around the test.
    with TestClient(app) as client:
        yield client
@pytest.fixture
def api_key():
    """Return the Gemini API key read from the GEMINI_API_KEY environment variable.

    The module-level ``pytestmark`` skips every test when the variable is unset
    or empty, so tests that request this fixture receive a non-empty string.
    """
    return os.environ.get("GEMINI_API_KEY")
class TestGeminiNativeGenerateContent:
    """Test /v1beta/models/{model}:generateContent endpoint."""

    @staticmethod
    def _first_candidate_text(data):
        """Concatenate every text part of the first candidate in response *data*.

        Joining all parts (instead of reading only ``parts[0]["text"]``) keeps the
        assertions working when the model splits its answer across several parts
        or emits a non-text part first.
        """
        parts = data["candidates"][0]["content"]["parts"]
        return "".join(part.get("text", "") for part in parts)

    def test_basic_generation(self, gemini_native_client, api_key):
        """Basic text generation works."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={"contents": [{"parts": [{"text": "What is 2+2? Reply with just the number."}]}]},
        )
        assert response.status_code == 200, response.text
        data = response.json()

        # Verify Gemini native response format
        assert "candidates" in data
        assert data["candidates"], data
        assert "content" in data["candidates"][0]
        assert "parts" in data["candidates"][0]["content"]
        text = self._first_candidate_text(data)
        assert "4" in text, text

        # Verify usage metadata
        assert "usageMetadata" in data
        assert "promptTokenCount" in data["usageMetadata"]

    def test_with_system_instruction(self, gemini_native_client, api_key):
        """System instruction works correctly."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={
                "contents": [{"parts": [{"text": "Hello"}]}],
                "systemInstruction": {"parts": [{"text": "Always respond with exactly one word."}]},
            },
        )
        assert response.status_code == 200, response.text
        text = self._first_candidate_text(response.json())
        # Should be a short response due to the one-word system instruction;
        # allow a little slack for punctuation/formatting.
        assert len(text.split()) <= 3, text

    def test_multi_turn_conversation(self, gemini_native_client, api_key):
        """Multi-turn conversations maintain context."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={
                "contents": [
                    {"role": "user", "parts": [{"text": "My name is TestUser456."}]},
                    {"role": "model", "parts": [{"text": "Nice to meet you, TestUser456!"}]},
                    {"role": "user", "parts": [{"text": "What is my name?"}]},
                ]
            },
        )
        assert response.status_code == 200, response.text
        text = self._first_candidate_text(response.json()).lower()
        assert "testuser456" in text, text

    def test_function_calling(self, gemini_native_client, api_key):
        """Function calling / tools work correctly."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={
                "contents": [{"parts": [{"text": "What is the weather in Tokyo?"}]}],
                "tools": [
                    {
                        "functionDeclarations": [
                            {
                                "name": "get_weather",
                                "description": "Get current weather for a location",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "location": {"type": "string", "description": "City name"}
                                    },
                                    "required": ["location"],
                                },
                            }
                        ]
                    }
                ],
            },
        )
        assert response.status_code == 200, response.text

        # The function call may not be the first part; take the first one found.
        parts = response.json()["candidates"][0]["content"]["parts"]
        function_call = next(
            (part["functionCall"] for part in parts if "functionCall" in part), None
        )
        assert function_call is not None, parts
        assert function_call["name"] == "get_weather"
        assert "tokyo" in function_call["args"]["location"].lower()

    def test_generation_config(self, gemini_native_client, api_key):
        """Generation config parameters are respected."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={
                "contents": [{"parts": [{"text": "Write a very short poem about AI."}]}],
                "generationConfig": {"maxOutputTokens": 50, "temperature": 0.1},
            },
        )
        assert response.status_code == 200, response.text
        usage = response.json()["usageMetadata"]
        # Output should be limited by maxOutputTokens=50; allow a small buffer.
        assert usage["candidatesTokenCount"] <= 60, usage
class TestGeminiNativeCompression:
    """Test that compression works with Gemini native API."""

    def test_compression_on_model_message(self, gemini_native_client, api_key):
        """A large JSON payload in a model message survives the optimizing proxy.

        The model must still count all 100 items, and ``/stats`` must report a
        non-negative token-savings figure. Savings may legitimately be zero
        (compression depends on payload size), so they are not required.
        """
        # Large JSON data simulating a tool's output.
        items = [
            {"id": i, "name": f"Item {i}", "desc": f"Description for item {i}"} for i in range(100)
        ]
        tool_output = json.dumps(items)

        # Send it as a model message, as if a tool had returned the data.
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={
                "contents": [
                    {"role": "user", "parts": [{"text": "Get items from database"}]},
                    {"role": "model", "parts": [{"text": f"Here are the results:\n{tool_output}"}]},
                    {"role": "user", "parts": [{"text": "How many items are there?"}]},
                ]
            },
        )
        assert response.status_code == 200, response.text
        parts = response.json()["candidates"][0]["content"]["parts"]
        text = "".join(part.get("text", "") for part in parts)

        # Model should still count the items correctly after optimization.
        assert "100" in text, text

        # Savings may be zero when the payload is not compressed; only require a
        # well-formed, non-negative counter.
        stats = gemini_native_client.get("/stats").json()
        assert stats["tokens"]["saved"] >= 0, stats["tokens"]

    def test_user_messages_protected(self, gemini_native_client, api_key):
        """A request carrying large data in a user message succeeds.

        User messages are meant to be left uncompressed by design. This test only
        checks that such a request goes through; it does not inspect what the
        proxy forwarded upstream.
        """
        items = [{"id": i} for i in range(50)]
        user_data = json.dumps(items)

        # Single request with the data embedded in the user message.
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={
                "contents": [
                    {"role": "user", "parts": [{"text": f"Analyze this data: {user_data}"}]}
                ]
            },
        )
        assert response.status_code == 200, response.text
class TestGeminiNativeStats:
    """Test that proxy stats track Gemini native requests correctly."""

    def test_stats_track_gemini_provider(self, gemini_native_client, api_key):
        """Stats show requests under 'gemini' provider."""
        # Make a request first. Fail here, not in the stats assertions, if the
        # upstream call itself did not succeed.
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={"contents": [{"parts": [{"text": "Hi"}]}]},
        )
        assert response.status_code == 200, response.text

        by_provider = gemini_native_client.get("/stats").json()["requests"]["by_provider"]
        assert by_provider.get("gemini", 0) >= 1, by_provider

    def test_stats_track_model(self, gemini_native_client, api_key):
        """Stats track the specific model used."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={"contents": [{"parts": [{"text": "Hi"}]}]},
        )
        assert response.status_code == 200, response.text

        by_model = gemini_native_client.get("/stats").json()["requests"]["by_model"]
        assert "gemini-2.0-flash" in by_model, by_model
class TestGeminiNativeErrorHandling:
    """Error paths of the Gemini native endpoint: bad key, unknown model, empty input."""

    def test_invalid_api_key(self, gemini_native_client):
        """A bogus API key must produce an HTTP error status (400 or above)."""
        payload = {"contents": [{"parts": [{"text": "Hi"}]}]}
        resp = gemini_native_client.post(
            "/v1beta/models/gemini-2.0-flash:generateContent?key=invalid-key-123",
            json=payload,
        )
        assert 400 <= resp.status_code

    def test_invalid_model(self, gemini_native_client, api_key):
        """An unknown model name must produce an HTTP error status (400 or above)."""
        payload = {"contents": [{"parts": [{"text": "Hi"}]}]}
        resp = gemini_native_client.post(
            f"/v1beta/models/nonexistent-model-xyz:generateContent?key={api_key}",
            json=payload,
        )
        assert 400 <= resp.status_code

    def test_empty_contents(self, gemini_native_client, api_key):
        """An empty ``contents`` list is either accepted or rejected with a 400."""
        resp = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={"contents": []},
        )
        # Either outcome is acceptable as long as it is not some other failure.
        assert resp.status_code in {200, 400}
class TestGeminiNativeHeaderAuth:
    """Authentication through the ``x-goog-api-key`` request header."""

    def test_header_auth(self, gemini_native_client, api_key):
        """Sending the key in a header (no ``key`` query parameter) succeeds."""
        auth_headers = {"x-goog-api-key": api_key}
        payload = {"contents": [{"parts": [{"text": "Hi"}]}]}
        resp = gemini_native_client.post(
            "/v1beta/models/gemini-2.0-flash:generateContent",
            headers=auth_headers,
            json=payload,
        )
        assert resp.status_code == 200
class TestGeminiNativeCountTokens:
    """Test /v1beta/models/{model}:countTokens endpoint with compression."""

    def test_count_tokens_basic(self, gemini_native_client, api_key):
        """Basic token counting works."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:countTokens?key={api_key}",
            json={"contents": [{"parts": [{"text": "Hello, world!"}]}]},
        )
        assert response.status_code == 200, response.text
        data = response.json()

        # Verify response format
        assert "totalTokens" in data
        assert isinstance(data["totalTokens"], int)
        assert data["totalTokens"] > 0

    def test_count_tokens_with_system_instruction(self, gemini_native_client, api_key):
        """countTokens accepts a request carrying a system instruction.

        Only checks that a positive count comes back; it does not verify that the
        system instruction's tokens are included in the total.
        """
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:countTokens?key={api_key}",
            json={
                "contents": [{"parts": [{"text": "Hello"}]}],
                "systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
            },
        )
        # systemInstruction may not be supported by countTokens in all API
        # versions, so a 400 is tolerated; a 200 must carry a positive count.
        assert response.status_code in (200, 400), response.text
        if response.status_code == 200:
            data = response.json()
            assert "totalTokens" in data, data
            assert data["totalTokens"] > 0

    def test_count_tokens_reflects_compression(self, gemini_native_client, api_key):
        """countTokens handles a conversation with a large, compressible model message.

        The count is presumably computed on the request the proxy forwards (after
        any compression). This test does not compare it against the uncompressed
        size. It only checks that a positive count is returned and that the
        request is recorded under the gemini provider in ``/stats``.
        """
        # Large, repetitive JSON data that is a candidate for compression.
        items = [
            {
                "id": i,
                "name": f"Item {i}",
                "description": f"This is the description for item number {i}",
            }
            for i in range(100)
        ]
        tool_output = json.dumps(items)

        # Put the large data in a model message, since user messages are not
        # compressed.
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:countTokens?key={api_key}",
            json={
                "contents": [
                    {"role": "user", "parts": [{"text": "Get items from database"}]},
                    {"role": "model", "parts": [{"text": f"Here are the results:\n{tool_output}"}]},
                    {"role": "user", "parts": [{"text": "Summarize these items"}]},
                ]
            },
        )
        assert response.status_code == 200, response.text
        data = response.json()

        # Verify we got a token count
        assert "totalTokens" in data, data
        assert data["totalTokens"] > 0

        # The request should have been tracked under the gemini provider.
        by_provider = gemini_native_client.get("/stats").json()["requests"]["by_provider"]
        assert by_provider.get("gemini", 0) >= 1, by_provider

    def test_count_tokens_multi_turn(self, gemini_native_client, api_key):
        """Token counting works for multi-turn conversations."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:countTokens?key={api_key}",
            json={
                "contents": [
                    {"role": "user", "parts": [{"text": "My name is Alice."}]},
                    {"role": "model", "parts": [{"text": "Nice to meet you, Alice!"}]},
                    {"role": "user", "parts": [{"text": "What is my name?"}]},
                ]
            },
        )
        assert response.status_code == 200, response.text
        data = response.json()
        assert "totalTokens" in data
        assert data["totalTokens"] > 0

    def test_count_tokens_header_auth(self, gemini_native_client, api_key):
        """API key in header works for countTokens."""
        response = gemini_native_client.post(
            "/v1beta/models/gemini-2.0-flash:countTokens",
            headers={"x-goog-api-key": api_key},
            json={"contents": [{"parts": [{"text": "Hello"}]}]},
        )
        assert response.status_code == 200, response.text
        data = response.json()
        assert "totalTokens" in data