File size: 1,570 Bytes
d7d1d4e
bbde124
60aa09c
a25d048
 
 
d7d1d4e
bbde124
 
 
d7d1d4e
60aa09c
d7d1d4e
 
60aa09c
3a74ace
60aa09c
 
d7d1d4e
 
 
60aa09c
 
 
d7d1d4e
60aa09c
 
bbde124
60aa09c
bbde124
60aa09c
 
bbde124
60aa09c
3a74ace
d7d1d4e
60aa09c
 
bbde124
60aa09c
 
bbde124
60aa09c
 
 
 
 
d7d1d4e
3a74ace
60aa09c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import logging
from typing import List, Optional, Dict
from dotenv import load_dotenv

# Load variables from a .env file into the process environment (python-dotenv)
# so the CEREBRAS_* settings read by InstanceProvider via os.getenv can be
# supplied from that file.
load_dotenv()

# NOTE(review): basicConfig() at import time configures the root logger for the
# whole process; if this module is imported as a library, consider moving this
# call to the application entry point instead.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by InstanceProvider to report configuration problems.
logger = logging.getLogger(__name__)

class InstanceProvider:
    """Hand out Cerebras API credentials, rotating round-robin over several keys.

    Configuration is read from the environment when the provider is created:

    * ``CEREBRAS_API_KEYS`` -- comma-separated list of API keys; blank entries
      and surrounding whitespace are ignored.
    * ``CEREBRAS_BASE_URL`` -- optional base URL; ``None`` when unset.
    * ``CEREBRAS_MODEL`` -- model name, defaulting to ``"llama3.1-70b"``.
    """

    def __init__(self):
        self.api_keys: List[str] = []
        # Position of the key that the next get_next_instance() call returns.
        self.current_index: int = 0
        self.base_url: Optional[str] = os.getenv("CEREBRAS_BASE_URL")
        self.model_name: str = os.getenv("CEREBRAS_MODEL", "llama3.1-70b")
        self._initialize_instances()

    def _initialize_instances(self) -> None:
        """(Re)load the API keys from ``CEREBRAS_API_KEYS``.

        Logs an error (but does not raise) when no usable key is found; in
        that case get_next_instance() returns ``None``.
        """
        keys_str = os.getenv("CEREBRAS_API_KEYS", "")
        self.api_keys = [k.strip() for k in keys_str.split(",") if k.strip()]
        # Restart the rotation: an index left over from a previous, longer
        # key list would otherwise point past the end of the new one.
        self.current_index = 0

        if not self.api_keys:
            logger.error("No API keys found in CEREBRAS_API_KEYS")

    def get_next_instance(self) -> Optional[Dict[str, Optional[str]]]:
        """Return credentials for the next key in round-robin order.

        Returns:
            ``{'api_key': str, 'base_url': str | None, 'model': str}``, or
            ``None`` when no API keys are configured. ``base_url`` is
            ``None`` when ``CEREBRAS_BASE_URL`` is unset.
        """
        if not self.api_keys:
            return None

        # Wrap defensively so an index made stale by changing ``api_keys``
        # (e.g. assigning a shorter list) can never raise IndexError.
        index = self.current_index % len(self.api_keys)
        key = self.api_keys[index]

        # Advance the rotation for the next call.
        self.current_index = (index + 1) % len(self.api_keys)

        return {
            "api_key": key,
            "base_url": self.base_url,
            "model": self.model_name,
        }

    def get_total_instances(self) -> int:
        """Return the number of configured API keys."""
        return len(self.api_keys)