headroom / tests /test_proxy /test_proxy_gemini_native_integration.py
tudragon154203
fix: route count_tokens to api.anthropic.com, not proxy base_url
0adb431
"""Integration tests for Gemini native API endpoint with real API calls.
These tests require a valid GEMINI_API_KEY environment variable.
They test the /v1beta/models/{model}:generateContent endpoint with compression.
Run with:
GEMINI_API_KEY=your-key pytest tests/test_proxy/test_proxy_gemini_native_integration.py -v
"""
import json
import os
import pytest
# Skip every test in this module when GEMINI_API_KEY is unset or empty; the
# tests below make real API calls and cannot run without a key.
pytestmark = pytest.mark.skipif(
    not os.environ.get("GEMINI_API_KEY"), reason="GEMINI_API_KEY not set"
)
# Skip the module when the optional web dependencies are not installed. These
# calls must stay above the fastapi/headroom imports that follow, which is why
# those imports carry "noqa: E402".
pytest.importorskip("fastapi")
pytest.importorskip("httpx")
from fastapi.testclient import TestClient # noqa: E402
from headroom.proxy.server import ProxyConfig, create_app # noqa: E402
@pytest.fixture
def gemini_native_client():
    """Yield a TestClient for a proxy app built with optimization turned on.

    The app is created from a ProxyConfig in which caching, rate limiting and
    cost tracking are all disabled. The client is used as a context manager
    and is closed once the requesting test finishes.
    """
    proxy_app = create_app(
        ProxyConfig(
            optimize=True,  # turn compression on
            cache_enabled=False,
            rate_limit_enabled=False,
            cost_tracking_enabled=False,
        )
    )
    test_client = TestClient(proxy_app)
    with test_client as active_client:
        yield active_client
@pytest.fixture
def api_key():
    """Return the GEMINI_API_KEY environment variable, or None when it is unset."""
    return os.getenv("GEMINI_API_KEY")
class TestGeminiNativeGenerateContent:
    """Exercise the /v1beta/models/{model}:generateContent route end to end."""

    def test_basic_generation(self, gemini_native_client, api_key):
        """A plain text prompt yields a native-format Gemini response."""
        url = f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}"
        payload = {"contents": [{"parts": [{"text": "What is 2+2? Reply with just the number."}]}]}

        response = gemini_native_client.post(url, json=payload)
        assert response.status_code == 200
        body = response.json()

        # Shape of the native Gemini response: candidates -> content -> parts.
        assert "candidates" in body
        candidates = body["candidates"]
        assert len(candidates) > 0
        first_candidate = candidates[0]
        assert "content" in first_candidate
        assert "parts" in first_candidate["content"]
        answer = first_candidate["content"]["parts"][0]["text"]
        assert "4" in answer

        # Usage metadata must be present and carry the prompt token count.
        assert "usageMetadata" in body
        assert "promptTokenCount" in body["usageMetadata"]

    def test_with_system_instruction(self, gemini_native_client, api_key):
        """A systemInstruction is honored by the model."""
        url = f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}"
        payload = {
            "contents": [{"parts": [{"text": "Hello"}]}],
            "systemInstruction": {"parts": [{"text": "Always respond with exactly one word."}]},
        }

        response = gemini_native_client.post(url, json=payload)
        assert response.status_code == 200

        reply = response.json()["candidates"][0]["content"]["parts"][0]["text"]
        # The instruction asks for one word; allow up to three for leniency.
        word_count = len(reply.split())
        assert word_count <= 3

    def test_multi_turn_conversation(self, gemini_native_client, api_key):
        """Earlier turns remain available to the model in later turns."""
        url = f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}"
        history = [
            {"role": "user", "parts": [{"text": "My name is TestUser456."}]},
            {"role": "model", "parts": [{"text": "Nice to meet you, TestUser456!"}]},
            {"role": "user", "parts": [{"text": "What is my name?"}]},
        ]

        response = gemini_native_client.post(url, json={"contents": history})
        assert response.status_code == 200

        reply = response.json()["candidates"][0]["content"]["parts"][0]["text"].lower()
        assert "testuser456" in reply

    def test_function_calling(self, gemini_native_client, api_key):
        """A declared tool is invoked by the model with the expected arguments."""
        url = f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}"
        weather_declaration = {
            "name": "get_weather",
            "description": "Get current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string", "description": "City name"}},
                "required": ["location"],
            },
        }
        payload = {
            "contents": [{"parts": [{"text": "What is the weather in Tokyo?"}]}],
            "tools": [{"functionDeclarations": [weather_declaration]}],
        }

        response = gemini_native_client.post(url, json=payload)
        assert response.status_code == 200

        # Pick the first part that carries a functionCall, if any.
        parts = response.json()["candidates"][0]["content"]["parts"]
        function_call = next(
            (part["functionCall"] for part in parts if "functionCall" in part),
            None,
        )
        assert function_call is not None
        assert function_call["name"] == "get_weather"
        assert "tokyo" in function_call["args"]["location"].lower()

    def test_generation_config(self, gemini_native_client, api_key):
        """generationConfig limits are reflected in the reported usage."""
        url = f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}"
        payload = {
            "contents": [{"parts": [{"text": "Write a very short poem about AI."}]}],
            "generationConfig": {"maxOutputTokens": 50, "temperature": 0.1},
        }

        response = gemini_native_client.post(url, json=payload)
        assert response.status_code == 200

        usage = response.json()["usageMetadata"]
        # maxOutputTokens is 50; the bound of 60 leaves a small buffer.
        assert usage["candidatesTokenCount"] <= 60
class TestGeminiNativeCompression:
    """Check that requests carrying bulky content still work through the proxy."""

    def test_compression_on_model_message(self, gemini_native_client, api_key):
        """A large JSON blob in a model turn is still answered correctly."""
        url = f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}"
        # 100 small records stand in for a bulky tool result.
        tool_output = json.dumps(
            [
                {"id": i, "name": f"Item {i}", "desc": f"Description for item {i}"}
                for i in range(100)
            ]
        )
        # The blob is placed in the model turn, as if a tool had returned it.
        conversation = [
            {"role": "user", "parts": [{"text": "Get items from database"}]},
            {"role": "model", "parts": [{"text": f"Here are the results:\n{tool_output}"}]},
            {"role": "user", "parts": [{"text": "How many items are there?"}]},
        ]

        response = gemini_native_client.post(url, json={"contents": conversation})
        assert response.status_code == 200

        answer = response.json()["candidates"][0]["content"]["parts"][0]["text"]
        # The model must still be able to count all the items.
        assert "100" in answer

        saved_tokens = gemini_native_client.get("/stats").json()["tokens"]["saved"]
        # Only non-negativity is required: compression may or may not kick in
        # depending on the payload size.
        assert saved_tokens >= 0

    def test_user_messages_protected(self, gemini_native_client, api_key):
        """A request with a large blob in the user turn succeeds.

        By design user messages are not compressed; this test only asserts
        that such a request goes through.
        """
        url = f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}"
        user_data = json.dumps([{"id": i} for i in range(50)])
        user_turn = {"role": "user", "parts": [{"text": f"Analyze this data: {user_data}"}]}

        response = gemini_native_client.post(url, json={"contents": [user_turn]})
        assert response.status_code == 200
class TestGeminiNativeStats:
    """Test that proxy stats track Gemini native requests correctly."""

    def test_stats_track_gemini_provider(self, gemini_native_client, api_key):
        """Stats show requests under the 'gemini' provider."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={"contents": [{"parts": [{"text": "Hi"}]}]},
        )
        # Check the call itself first, as the other tests in this module do, so
        # an upstream failure is reported as such rather than as a stats
        # mismatch further down.
        assert response.status_code == 200

        stats = gemini_native_client.get("/stats").json()
        by_provider = stats["requests"]["by_provider"]
        assert "gemini" in by_provider
        assert by_provider["gemini"] >= 1

    def test_stats_track_model(self, gemini_native_client, api_key):
        """Stats record the specific model that was requested."""
        response = gemini_native_client.post(
            f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}",
            json={"contents": [{"parts": [{"text": "Hi"}]}]},
        )
        # Same reasoning as above: fail on the request, not on the stats lookup.
        assert response.status_code == 200

        stats = gemini_native_client.get("/stats").json()
        assert "gemini-2.0-flash" in stats["requests"]["by_model"]
class TestGeminiNativeErrorHandling:
    """Failure-path checks for the native generateContent route."""

    def test_invalid_api_key(self, gemini_native_client):
        """A bogus API key results in an error status."""
        url = "/v1beta/models/gemini-2.0-flash:generateContent?key=invalid-key-123"
        payload = {"contents": [{"parts": [{"text": "Hi"}]}]}

        response = gemini_native_client.post(url, json=payload)
        assert response.status_code >= 400

    def test_invalid_model(self, gemini_native_client, api_key):
        """An unknown model name results in an error status."""
        url = f"/v1beta/models/nonexistent-model-xyz:generateContent?key={api_key}"
        payload = {"contents": [{"parts": [{"text": "Hi"}]}]}

        response = gemini_native_client.post(url, json=payload)
        assert response.status_code >= 400

    def test_empty_contents(self, gemini_native_client, api_key):
        """An empty contents list is either accepted or rejected with a 400."""
        url = f"/v1beta/models/gemini-2.0-flash:generateContent?key={api_key}"

        response = gemini_native_client.post(url, json={"contents": []})
        # Both outcomes count as graceful handling.
        acceptable_statuses = (200, 400)
        assert response.status_code in acceptable_statuses
class TestGeminiNativeHeaderAuth:
    """Authentication through the x-goog-api-key request header."""

    def test_header_auth(self, gemini_native_client, api_key):
        """A key supplied in the header (not the query string) is accepted."""
        auth_headers = {"x-goog-api-key": api_key}
        payload = {"contents": [{"parts": [{"text": "Hi"}]}]}

        response = gemini_native_client.post(
            "/v1beta/models/gemini-2.0-flash:generateContent",
            headers=auth_headers,
            json=payload,
        )
        assert response.status_code == 200
class TestGeminiNativeCountTokens:
    """Exercise the /v1beta/models/{model}:countTokens route through the proxy."""

    def test_count_tokens_basic(self, gemini_native_client, api_key):
        """A simple prompt returns a positive integer totalTokens."""
        url = f"/v1beta/models/gemini-2.0-flash:countTokens?key={api_key}"
        payload = {"contents": [{"parts": [{"text": "Hello, world!"}]}]}

        response = gemini_native_client.post(url, json=payload)
        assert response.status_code == 200

        body = response.json()
        # Response format: an integer totalTokens greater than zero.
        assert "totalTokens" in body
        total = body["totalTokens"]
        assert isinstance(total, int)
        assert total > 0

    def test_count_tokens_with_system_instruction(self, gemini_native_client, api_key):
        """countTokens with a systemInstruction either counts or returns 400."""
        url = f"/v1beta/models/gemini-2.0-flash:countTokens?key={api_key}"
        payload = {
            "contents": [{"parts": [{"text": "Hello"}]}],
            "systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
        }

        response = gemini_native_client.post(url, json=payload)
        status = response.status_code
        # systemInstruction may not be supported by countTokens in all versions.
        assert status in (200, 400)
        if status != 200:
            return

        body = response.json()
        assert "totalTokens" in body
        assert body["totalTokens"] > 0

    def test_count_tokens_reflects_compression(self, gemini_native_client, api_key):
        """countTokens succeeds on a large model turn and the request is tracked.

        NOTE(review): no uncompressed baseline is compared here, so the effect
        of compression on the count is not itself asserted.
        """
        url = f"/v1beta/models/gemini-2.0-flash:countTokens?key={api_key}"
        # Large, repetitive JSON that is a good candidate for compression.
        tool_output = json.dumps(
            [
                {
                    "id": i,
                    "name": f"Item {i}",
                    "description": f"This is the description for item number {i}",
                }
                for i in range(100)
            ]
        )
        # The bulky data sits in the model turn.
        conversation = [
            {"role": "user", "parts": [{"text": "Get items from database"}]},
            {"role": "model", "parts": [{"text": f"Here are the results:\n{tool_output}"}]},
            {"role": "user", "parts": [{"text": "Summarize these items"}]},
        ]

        response = gemini_native_client.post(url, json={"contents": conversation})
        assert response.status_code == 200

        body = response.json()
        assert "totalTokens" in body
        assert body["totalTokens"] > 0

        # The request must show up under the gemini provider in the stats.
        provider_counts = gemini_native_client.get("/stats").json()["requests"]["by_provider"]
        assert provider_counts.get("gemini", 0) >= 1

    def test_count_tokens_multi_turn(self, gemini_native_client, api_key):
        """A multi-turn conversation can be counted."""
        url = f"/v1beta/models/gemini-2.0-flash:countTokens?key={api_key}"
        history = [
            {"role": "user", "parts": [{"text": "My name is Alice."}]},
            {"role": "model", "parts": [{"text": "Nice to meet you, Alice!"}]},
            {"role": "user", "parts": [{"text": "What is my name?"}]},
        ]

        response = gemini_native_client.post(url, json={"contents": history})
        assert response.status_code == 200

        body = response.json()
        assert "totalTokens" in body
        assert body["totalTokens"] > 0

    def test_count_tokens_header_auth(self, gemini_native_client, api_key):
        """countTokens accepts the key via the x-goog-api-key header."""
        auth_headers = {"x-goog-api-key": api_key}
        payload = {"contents": [{"parts": [{"text": "Hello"}]}]}

        response = gemini_native_client.post(
            "/v1beta/models/gemini-2.0-flash:countTokens",
            headers=auth_headers,
            json=payload,
        )
        assert response.status_code == 200
        assert "totalTokens" in response.json()