File size: 5,131 Bytes
b49d66f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1b468e
b49d66f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1b468e
b49d66f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
API Key Manager - Round-robin management for multiple Gemini API keys.

Selects the least-used key to avoid rate limiting.
Tracks usage per key index (not the actual key for security).
"""
import os
import logging
from datetime import datetime
from typing import Optional, List, Tuple
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

logger = logging.getLogger(__name__)

# Cache for API keys (loaded once from env). None means "not loaded yet";
# an empty list means "loaded, but nothing usable was configured".
_api_keys: Optional[List[str]] = None


def get_api_keys() -> List[str]:
    """Load API keys from the environment, caching the result.

    GEMINI_API_KEYS is read as a comma-separated list; each entry is stripped
    and empty entries are dropped. If that yields no keys (variable unset,
    empty, or containing only whitespace/commas), GEMINI_API_KEY is used as a
    single-key fallback, also stripped.

    Returns:
        The cached list of keys, empty when nothing usable is configured.
        The environment is only read on the first call; later changes to the
        environment variables are not picked up.
    """
    global _api_keys
    if _api_keys is None:
        keys_str = os.getenv("GEMINI_API_KEYS", "")
        keys = [k.strip() for k in keys_str.split(",") if k.strip()]

        if not keys:
            # Fall back to the single-key variable. This also covers a
            # GEMINI_API_KEYS value such as " , " that parses to nothing.
            single_key = os.getenv("GEMINI_API_KEY", "").strip()
            if single_key:
                keys = [single_key]

        _api_keys = keys

        if _api_keys:
            # Only the count is logged, never the key material itself.
            logger.info(f"Loaded {len(_api_keys)} API key(s)")
        else:
            logger.warning(
                "No API keys configured. Set GEMINI_API_KEYS or GEMINI_API_KEY in environment."
            )

    return _api_keys


def get_key_count() -> int:
    """Return how many API keys get_api_keys() currently reports."""
    configured = get_api_keys()
    return len(configured)


async def get_least_used_key(db: AsyncSession) -> Tuple[int, str]:
    """
    Pick the configured API key whose recorded request total is lowest.

    Keys are scanned in index order. The first key that has no ApiKeyUsage
    row yet gets a zeroed row added to the session and is chosen immediately;
    otherwise the key with the smallest total_requests wins, and ties go to
    the lowest index.

    Args:
        db: Database session used to read usage rows and to stage new ones.
            Nothing is committed here.

    Returns:
        Tuple of (key_index, api_key)

    Raises:
        ValueError: If no API keys are configured
    """
    from core.models import ApiKeyUsage

    api_keys = get_api_keys()
    if not api_keys:
        raise ValueError("No API keys configured. Set GEMINI_API_KEYS or GEMINI_API_KEY in environment.")

    # Load every usage row and index it by key position.
    rows = await db.execute(select(ApiKeyUsage).order_by(ApiKeyUsage.total_requests))
    usage_by_index = {row.key_index: row for row in rows.scalars().all()}

    selected_index = 0
    fewest = float('inf')

    for position, _ in enumerate(api_keys):
        tracked = usage_by_index.get(position)
        if tracked is None:
            # Never-seen key: stage a zeroed row and use this key right away.
            # The caller owns the transaction, so there is no commit here.
            db.add(ApiKeyUsage(key_index=position, total_requests=0, success_count=0, failure_count=0))
            selected_index = position
            break
        if tracked.total_requests < fewest:
            fewest = tracked.total_requests
            selected_index = position

    logger.debug(f"Selected API key index {selected_index} (least used)")
    return selected_index, api_keys[selected_index]


async def record_usage(db: AsyncSession, key_index: int, success: bool, error_message: Optional[str] = None):
    """
    Update the usage counters of one API key after a request.

    Looks up the ApiKeyUsage row for key_index (staging a zeroed one in the
    session if none exists), bumps total_requests plus either success_count
    or failure_count, and stamps last_used_at. Nothing is committed here.

    Args:
        db: Database session
        key_index: Index of the key used
        success: Whether the request succeeded
        error_message: Error message if request failed; stored in last_error
            (first 1000 characters) only on failure and only when non-empty
    """
    from core.models import ApiKeyUsage

    lookup = await db.execute(
        select(ApiKeyUsage).where(ApiKeyUsage.key_index == key_index)
    )
    usage = lookup.scalar_one_or_none()

    if not usage:
        # First time this index is recorded: start from zeroed counters.
        usage = ApiKeyUsage(key_index=key_index, total_requests=0, success_count=0, failure_count=0)
        db.add(usage)

    usage.total_requests += 1
    if not success:
        usage.failure_count += 1
        if error_message:
            # Keep only the first 1000 characters of the error text.
            usage.last_error = error_message[:1000]
    else:
        usage.success_count += 1
    usage.last_used_at = datetime.utcnow()

    # The caller commits, so the update stays part of its transaction.
    logger.debug(f"Recorded {'success' if success else 'failure'} for key index {key_index}")


async def get_all_usage_stats(db: AsyncSession) -> List[dict]:
    """
    Build a usage report with one entry per configured API key.

    Keys that have no ApiKeyUsage row yet are reported with zero counters
    and None for last_error / last_used_at.

    Args:
        db: Database session used to read the usage rows.

    Returns:
        List of dicts, ordered by key index, each with key_index,
        total_requests, success_count, failure_count, last_error and
        last_used_at (ISO-8601 string or None)
    """
    from core.models import ApiKeyUsage

    keys = get_api_keys()
    rows = await db.execute(select(ApiKeyUsage).order_by(ApiKeyUsage.key_index))
    usage_by_index = {row.key_index: row for row in rows.scalars().all()}

    report = []
    for position in range(len(keys)):
        row = usage_by_index.get(position)
        if row is None:
            # No row recorded for this key yet: report an all-zero entry.
            entry = {
                "key_index": position,
                "total_requests": 0,
                "success_count": 0,
                "failure_count": 0,
                "last_error": None,
                "last_used_at": None
            }
        else:
            last_used = row.last_used_at.isoformat() if row.last_used_at else None
            entry = {
                "key_index": position,
                "total_requests": row.total_requests,
                "success_count": row.success_count,
                "failure_count": row.failure_count,
                "last_error": row.last_error,
                "last_used_at": last_used
            }
        report.append(entry)

    return report