# pythermalcomfort_Chat / scripts / test_gemini_live.py
# (Hugging Face Space upload metadata: sadickam — "Prepare for HF Space deployment", rev d01a7e3)
#!/usr/bin/env python3
"""Live integration test for Gemini LLM provider.
This script tests the GeminiLLM implementation against the real Gemini API.
It validates all functionality including:
- Provider availability check
- Health check connectivity
- Complete response generation
- Streaming response generation
Usage:
# Option 1: Add GEMINI_API_KEY to .env file
echo 'GEMINI_API_KEY="your-api-key"' >> .env
# Option 2: Set in environment
export GEMINI_API_KEY="your-api-key"
# Run the test
poetry run python scripts/test_gemini_live.py
Requirements:
- Valid GEMINI_API_KEY in .env file or environment variable
- google-generativeai package installed
Note:
This script makes real API calls and may consume API quota.
"""
from __future__ import annotations
import asyncio
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
def load_env() -> None:
    """Load environment variables from the project root ``.env`` file.

    Prefers ``python-dotenv`` when installed; otherwise falls back to a
    minimal manual KEY=VALUE parser.  In both cases variables that are
    already set in the environment are left untouched, matching
    ``load_dotenv``'s default ``override=False`` behaviour (the original
    manual path clobbered existing values, which was inconsistent).
    """
    # The .env file is expected in the project root (one level above scripts/).
    env_path = Path(__file__).parent.parent / ".env"
    if not env_path.exists():
        print(f"[INFO] No .env file found at {env_path}")
        return
    try:
        from dotenv import load_dotenv
    except ImportError:
        # dotenv not installed, try manual parsing
        print("[INFO] python-dotenv not installed, parsing .env manually")
        with open(env_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Skip blank lines, comments, and lines without KEY=VALUE shape.
                if line and not line.startswith("#") and "=" in line:
                    key, _, value = line.partition("=")
                    key = key.strip()
                    value = value.strip().strip('"').strip("'")
                    if key and value:
                        # Do not override values already present in the environment.
                        os.environ.setdefault(key, value)
        print(f"[OK] Manually loaded environment from {env_path}")
    else:
        load_dotenv(env_path)
        print(f"[OK] Loaded environment from {env_path}")
async def run_live_test() -> bool:
    """Run live integration tests against the Gemini API.

    Exercises the GeminiLLM provider end-to-end in five phases:
    import/initialization, the ``is_available`` flag, ``check_health()``,
    a complete ``generate()`` call, and a ``stream()`` call.  Makes real
    API calls, so a valid ``GEMINI_API_KEY`` must be set in the
    environment before calling.

    Returns:
        True if all tests pass, False otherwise.
    """
    print("=" * 60)
    print("GEMINI LLM LIVE INTEGRATION TEST")
    print("=" * 60)
    print()
    # Check for API key before importing anything — fail fast with guidance.
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        print("[FAIL] GEMINI_API_KEY environment variable not set")
        print()
        print("Please set your API key:")
        print(" export GEMINI_API_KEY='your-api-key'")
        return False
    print(f"[OK] API key found (length: {len(api_key)} chars)")
    print()
    # TEST 1: Import the provider (lazy loading test) and construct it.
    print("-" * 60)
    print("TEST 1: Import and Initialization")
    print("-" * 60)
    try:
        from rag_chatbot.llm.gemini import GeminiLLM, RateLimitError
        from rag_chatbot.llm.base import LLMRequest
        print("[OK] Imports successful (lazy loading verified)")
    except ImportError as e:
        print(f"[FAIL] Import error: {e}")
        return False
    # Initialize provider with gemma-3-27b-it model
    try:
        llm = GeminiLLM(api_key=api_key, model="gemma-3-27b-it")
        # Plain string (was an f-string with no placeholders — ruff F541).
        print("[OK] GeminiLLM initialized")
        print(f" Provider: {llm.provider_name}")
        print(f" Model: {llm.model_name}")
        print(f" Timeout: {llm.timeout_ms}ms")
    except Exception as e:
        print(f"[FAIL] Initialization error: {e}")
        return False
    print()
    # TEST 2: is_available should report True for a configured provider.
    print("-" * 60)
    print("TEST 2: is_available Property")
    print("-" * 60)
    if llm.is_available:
        print("[OK] Provider reports available")
    else:
        print("[FAIL] Provider reports unavailable")
        return False
    print()
    # TEST 3: check_health makes a lightweight real API call.
    print("-" * 60)
    print("TEST 3: check_health() Method")
    print("-" * 60)
    print("Making lightweight API call to verify connectivity...")
    try:
        is_healthy = await llm.check_health()
        if is_healthy:
            print("[OK] Health check passed - API is reachable")
        else:
            # Treated as a warning (not a failure): the key may still work
            # for generation even when the health probe is flaky.
            print("[WARN] Health check returned False")
            print(" This may indicate API issues or invalid key")
    except Exception as e:
        print(f"[FAIL] Health check error: {e}")
        return False
    print()
    # TEST 4: complete (non-streaming) generation with RAG-style context.
    print("-" * 60)
    print("TEST 4: generate() Method")
    print("-" * 60)
    request = LLMRequest(
        query="What is thermal comfort in buildings? Answer in 2-3 sentences.",
        context=[
            "Thermal comfort is the condition of mind that expresses satisfaction with the thermal environment.",
            "ASHRAE Standard 55 defines thermal comfort conditions for building occupants.",
        ],
        max_tokens=256,
        temperature=0.7,
    )
    print(f"Query: {request.query}")
    print(f"Context chunks: {len(request.context)}")
    print(f"Max tokens: {request.max_tokens}")
    print(f"Temperature: {request.temperature}")
    print()
    print("Generating response...")
    try:
        response = await llm.generate(request)
        print()
        print("[OK] Response received!")
        print(f" Provider: {response.provider}")
        print(f" Model: {response.model}")
        print(f" Tokens used: {response.tokens_used}")
        print(f" Latency: {response.latency_ms}ms")
        print()
        print("Response content:")
        print("-" * 40)
        print(response.content)
        print("-" * 40)
    except RateLimitError as e:
        # Rate limiting is expected on free-tier keys; warn but keep going.
        print(f"[WARN] Rate limit hit: {e}")
        print(f" Retry after: {e.retry_after} seconds")
    except TimeoutError as e:
        print(f"[FAIL] Timeout: {e}")
        return False
    except RuntimeError as e:
        print(f"[FAIL] API error: {e}")
        return False
    except Exception as e:
        print(f"[FAIL] Unexpected error: {type(e).__name__}: {e}")
        return False
    print()
    # TEST 5: streaming generation — chunks printed as they arrive.
    print("-" * 60)
    print("TEST 5: stream() Method")
    print("-" * 60)
    stream_request = LLMRequest(
        query="What is PMV? Answer in one sentence.",
        context=["PMV stands for Predicted Mean Vote, a thermal comfort index."],
        max_tokens=128,
        temperature=0.5,
    )
    print(f"Query: {stream_request.query}")
    print()
    print("Streaming response:")
    print("-" * 40)
    try:
        chunks_received = 0
        full_response = ""
        async for chunk in llm.stream(stream_request):
            print(chunk, end="", flush=True)
            full_response += chunk
            chunks_received += 1
        print()
        print("-" * 40)
        print()
        # Plain string (was an f-string with no placeholders — ruff F541).
        print("[OK] Streaming complete!")
        print(f" Chunks received: {chunks_received}")
        print(f" Total length: {len(full_response)} chars")
    except RateLimitError as e:
        print()
        print(f"[WARN] Rate limit hit: {e}")
        print(f" Retry after: {e.retry_after} seconds")
    except TimeoutError as e:
        print()
        print(f"[FAIL] Timeout: {e}")
        return False
    except RuntimeError as e:
        print()
        print(f"[FAIL] API error: {e}")
        return False
    except Exception as e:
        print()
        print(f"[FAIL] Unexpected error: {type(e).__name__}: {e}")
        return False
    print()
    # Summary
    print("=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)
    print()
    print("[OK] All live integration tests PASSED!")
    print()
    print("The GeminiLLM provider is working correctly with:")
    print(" - Valid API key authentication")
    print(" - Health check connectivity")
    print(" - Complete response generation")
    print(" - Streaming response generation")
    print()
    return True
def main() -> int:
    """Script entry point: load .env configuration, then run the async suite.

    Returns a process exit status: 0 on success, 1 on any test failure,
    and 130 when the user aborts with Ctrl-C (conventional SIGINT code).
    """
    # Populate os.environ from the project .env before anything reads it.
    load_env()
    print()
    try:
        return 0 if asyncio.run(run_live_test()) else 1
    except KeyboardInterrupt:
        print("\n\nTest interrupted by user")
        return 130
# Run the live test suite and propagate its exit status to the shell.
if __name__ == "__main__":
    sys.exit(main())