Mirrowel commited on
Commit
576a3ec
·
1 Parent(s): 5fb237f

feat: Safety settings addition(only gemini for now)

Browse files
src/rotator_library/client.py CHANGED
@@ -104,7 +104,27 @@ class RotatingClient:
104
  lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
105
 
106
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
107
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  if "gemma-3" in model and "messages" in litellm_kwargs:
109
  new_messages = [
110
  {"role": "user", "content": m["content"]} if m.get("role") == "system" else m
 
104
  lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
105
 
106
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
107
+
108
+ if provider in self._provider_instances:
109
+ provider_instance = self._provider_instances[provider]
110
+
111
+ # Ensure safety_settings are present, defaulting to lowest if not provided
112
+ if "safety_settings" not in litellm_kwargs:
113
+ litellm_kwargs["safety_settings"] = {
114
+ "harassment": "BLOCK_NONE",
115
+ "hate_speech": "BLOCK_NONE",
116
+ "sexually_explicit": "BLOCK_NONE",
117
+ "dangerous_content": "BLOCK_NONE",
118
+ }
119
+
120
+ converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
121
+
122
+ if converted_settings is not None:
123
+ litellm_kwargs["safety_settings"] = converted_settings
124
+ else:
125
+ # If conversion returns None, remove it to avoid sending empty settings
126
+ del litellm_kwargs["safety_settings"]
127
+
128
  if "gemma-3" in model and "messages" in litellm_kwargs:
129
  new_messages = [
130
  {"role": "user", "content": m["content"]} if m.get("role") == "system" else m
src/rotator_library/providers/gemini_provider.py CHANGED
@@ -1,6 +1,6 @@
1
  import httpx
2
  import logging
3
- from typing import List
4
  from .provider_interface import ProviderInterface
5
 
6
  lib_logger = logging.getLogger('rotator_library')
@@ -26,3 +26,27 @@ class GeminiProvider(ProviderInterface):
26
  except httpx.RequestError as e:
27
  lib_logger.error(f"Failed to fetch Gemini models: {e}")
28
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import httpx
2
  import logging
3
+ from typing import List, Dict, Any
4
  from .provider_interface import ProviderInterface
5
 
6
  lib_logger = logging.getLogger('rotator_library')
 
26
  except httpx.RequestError as e:
27
  lib_logger.error(f"Failed to fetch Gemini models: {e}")
28
  return []
29
+
30
+ def convert_safety_settings(self, settings: Dict[str, str]) -> List[Dict[str, Any]]:
31
+ """
32
+ Converts generic safety settings to the Gemini-specific format.
33
+ """
34
+ if not settings:
35
+ return []
36
+
37
+ gemini_settings = []
38
+ category_map = {
39
+ "harassment": "HARM_CATEGORY_HARASSMENT",
40
+ "hate_speech": "HARM_CATEGORY_HATE_SPEECH",
41
+ "sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
42
+ "dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
43
+ }
44
+
45
+ for generic_category, threshold in settings.items():
46
+ if generic_category in category_map:
47
+ gemini_settings.append({
48
+ "category": category_map[generic_category],
49
+ "threshold": threshold.upper()
50
+ })
51
+
52
+ return gemini_settings
src/rotator_library/providers/provider_interface.py CHANGED
@@ -1,5 +1,5 @@
1
  from abc import ABC, abstractmethod
2
- from typing import List
3
  import httpx
4
 
5
  class ProviderInterface(ABC):
@@ -21,3 +21,15 @@ class ProviderInterface(ABC):
21
  A list of model name strings.
22
  """
23
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
+ from typing import List, Dict, Any
3
  import httpx
4
 
5
  class ProviderInterface(ABC):
 
21
  A list of model name strings.
22
  """
23
  pass
24
+
25
+ def convert_safety_settings(self, settings: Dict[str, str]) -> List[Dict[str, Any]]:
26
+ """
27
+ Converts a generic safety settings dictionary to the provider-specific format.
28
+
29
+ Args:
30
+ settings: A dictionary with generic harm categories and thresholds.
31
+
32
+ Returns:
33
+ A list of provider-specific safety setting objects or None.
34
+ """
35
+ return None