Soumik Bose commited on
Commit
bbde124
·
1 Parent(s): 03e9433
Files changed (1) hide show
  1. cerebras_instance_provider.py +61 -24
cerebras_instance_provider.py CHANGED
@@ -1,48 +1,85 @@
1
  # instance_provider.py
2
  import os
3
- from typing import List, Optional
4
- from pydantic_ai.models.openai import OpenAIModel
5
- from pydantic_ai.providers.openai import OpenAIProvider
6
  from dotenv import load_dotenv
7
 
8
  load_dotenv()
9
 
 
 
 
 
10
  class InstanceProvider:
11
- """Manages multiple Cerebras API instances with simple rotation"""
12
 
13
  def __init__(self):
14
- self.instances: List[OpenAIModel] = []
15
  self.current_index = 0
 
16
  self._initialize_instances()
17
 
18
  def _initialize_instances(self):
19
- """Load all API keys and create instances"""
 
20
  api_keys = os.getenv("CEREBRAS_API_KEYS", "").split(",")
21
  base_url = os.getenv("CEREBRAS_BASE_URL")
22
- model_name = os.getenv("CEREBRAS_MODEL")
23
 
24
  for key in api_keys:
25
  key = key.strip()
26
  if key:
27
- self.instances.append(
28
- OpenAIModel(
29
- model_name,
30
- provider=OpenAIProvider(
31
- base_url=base_url,
32
- api_key=key
33
- )
34
  )
35
- )
36
-
37
- def get_next_instance(self) -> Optional[OpenAIModel]:
38
- """Get next instance in rotation"""
39
- if not self.instances:
 
 
 
 
 
40
  return None
41
 
42
- instance = self.instances[self.current_index]
43
- self.current_index = (self.current_index + 1) % len(self.instances)
44
- return instance
 
 
 
 
45
 
46
  def get_total_instances(self) -> int:
47
- """Return total number of instances available"""
48
- return len(self.instances)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # instance_provider.py
2
  import os
3
+ import logging
4
+ from typing import List, Optional, Tuple
5
+ from openai import OpenAI
6
  from dotenv import load_dotenv
7
 
8
  load_dotenv()
9
 
10
+ # Setup basic logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
  class InstanceProvider:
15
+ """Manages multiple Cerebras/OpenAI clients with simple rotation"""
16
 
17
  def __init__(self):
18
+ self.clients: List[OpenAI] = []
19
  self.current_index = 0
20
+ self.model_name = os.getenv("CEREBRAS_MODEL") or "llama3.1-70b"
21
  self._initialize_instances()
22
 
23
  def _initialize_instances(self):
24
+ """Load all API keys and create OpenAI clients"""
25
+ # Split keys by comma
26
  api_keys = os.getenv("CEREBRAS_API_KEYS", "").split(",")
27
  base_url = os.getenv("CEREBRAS_BASE_URL")
 
28
 
29
  for key in api_keys:
30
  key = key.strip()
31
  if key:
32
+ try:
33
+ # Create a standard OpenAI client for this key
34
+ client = OpenAI(
35
+ base_url=base_url,
36
+ api_key=key
 
 
37
  )
38
+ self.clients.append(client)
39
+ except Exception as e:
40
+ logger.error(f"Failed to initialize key {key[:4]}...: {e}")
41
+
42
+ def get_next_instance(self) -> Optional[Tuple[OpenAI, str]]:
43
+ """
44
+ Get next client in rotation.
45
+ Returns: Tuple (OpenAI_Client, Model_Name)
46
+ """
47
+ if not self.clients:
48
  return None
49
 
50
+ # Get current client
51
+ client = self.clients[self.current_index]
52
+
53
+ # Rotate index for the next call (Round Robin)
54
+ self.current_index = (self.current_index + 1) % len(self.clients)
55
+
56
+ return client, self.model_name
57
 
58
  def get_total_instances(self) -> int:
59
+ """Return total number of active clients available"""
60
+ return len(self.clients)
61
+
62
+ def chat_completion_with_retry(self, messages: list, **kwargs):
63
+ """
64
+ Helper function that automatically retries across all instances
65
+ if one fails.
66
+ """
67
+ total_attempts = self.get_total_instances()
68
+
69
+ for attempt in range(total_attempts):
70
+ client, model = self.get_next_instance()
71
+
72
+ try:
73
+ # Execute the API call
74
+ response = client.chat.completions.create(
75
+ model=model,
76
+ messages=messages,
77
+ **kwargs
78
+ )
79
+ return response
80
+ except Exception as e:
81
+ logger.warning(f"Instance failed (Attempt {attempt+1}/{total_attempts}): {e}")
82
+ # Loop continues to next instance automatically
83
+ continue
84
+
85
+ raise RuntimeError(f"All {total_attempts} instances failed.")