Spaces:
Paused
Paused
Mirrowel commited on
Commit ·
576a3ec
1
Parent(s): 5fb237f
feat: Safety settings addition(only gemini for now)
Browse files
src/rotator_library/client.py
CHANGED
|
@@ -104,7 +104,27 @@ class RotatingClient:
|
|
| 104 |
lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
|
| 105 |
|
| 106 |
litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
if "gemma-3" in model and "messages" in litellm_kwargs:
|
| 109 |
new_messages = [
|
| 110 |
{"role": "user", "content": m["content"]} if m.get("role") == "system" else m
|
|
|
|
| 104 |
lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
|
| 105 |
|
| 106 |
litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
|
| 107 |
+
|
| 108 |
+
if provider in self._provider_instances:
|
| 109 |
+
provider_instance = self._provider_instances[provider]
|
| 110 |
+
|
| 111 |
+
# Ensure safety_settings are present, defaulting to lowest if not provided
|
| 112 |
+
if "safety_settings" not in litellm_kwargs:
|
| 113 |
+
litellm_kwargs["safety_settings"] = {
|
| 114 |
+
"harassment": "BLOCK_NONE",
|
| 115 |
+
"hate_speech": "BLOCK_NONE",
|
| 116 |
+
"sexually_explicit": "BLOCK_NONE",
|
| 117 |
+
"dangerous_content": "BLOCK_NONE",
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
|
| 121 |
+
|
| 122 |
+
if converted_settings is not None:
|
| 123 |
+
litellm_kwargs["safety_settings"] = converted_settings
|
| 124 |
+
else:
|
| 125 |
+
# If conversion returns None, remove it to avoid sending empty settings
|
| 126 |
+
del litellm_kwargs["safety_settings"]
|
| 127 |
+
|
| 128 |
if "gemma-3" in model and "messages" in litellm_kwargs:
|
| 129 |
new_messages = [
|
| 130 |
{"role": "user", "content": m["content"]} if m.get("role") == "system" else m
|
src/rotator_library/providers/gemini_provider.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import httpx
|
| 2 |
import logging
|
| 3 |
-
from typing import List
|
| 4 |
from .provider_interface import ProviderInterface
|
| 5 |
|
| 6 |
lib_logger = logging.getLogger('rotator_library')
|
|
@@ -26,3 +26,27 @@ class GeminiProvider(ProviderInterface):
|
|
| 26 |
except httpx.RequestError as e:
|
| 27 |
lib_logger.error(f"Failed to fetch Gemini models: {e}")
|
| 28 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import httpx
|
| 2 |
import logging
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
from .provider_interface import ProviderInterface
|
| 5 |
|
| 6 |
lib_logger = logging.getLogger('rotator_library')
|
|
|
|
| 26 |
except httpx.RequestError as e:
|
| 27 |
lib_logger.error(f"Failed to fetch Gemini models: {e}")
|
| 28 |
return []
|
| 29 |
+
|
| 30 |
+
def convert_safety_settings(self, settings: Dict[str, str]) -> List[Dict[str, Any]]:
|
| 31 |
+
"""
|
| 32 |
+
Converts generic safety settings to the Gemini-specific format.
|
| 33 |
+
"""
|
| 34 |
+
if not settings:
|
| 35 |
+
return []
|
| 36 |
+
|
| 37 |
+
gemini_settings = []
|
| 38 |
+
category_map = {
|
| 39 |
+
"harassment": "HARM_CATEGORY_HARASSMENT",
|
| 40 |
+
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
|
| 41 |
+
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
| 42 |
+
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
for generic_category, threshold in settings.items():
|
| 46 |
+
if generic_category in category_map:
|
| 47 |
+
gemini_settings.append({
|
| 48 |
+
"category": category_map[generic_category],
|
| 49 |
+
"threshold": threshold.upper()
|
| 50 |
+
})
|
| 51 |
+
|
| 52 |
+
return gemini_settings
|
src/rotator_library/providers/provider_interface.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
-
from typing import List
|
| 3 |
import httpx
|
| 4 |
|
| 5 |
class ProviderInterface(ABC):
|
|
@@ -21,3 +21,15 @@ class ProviderInterface(ABC):
|
|
| 21 |
A list of model name strings.
|
| 22 |
"""
|
| 23 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
import httpx
|
| 4 |
|
| 5 |
class ProviderInterface(ABC):
|
|
|
|
| 21 |
A list of model name strings.
|
| 22 |
"""
|
| 23 |
pass
|
| 24 |
+
|
| 25 |
+
def convert_safety_settings(self, settings: Dict[str, str]) -> List[Dict[str, Any]]:
|
| 26 |
+
"""
|
| 27 |
+
Converts a generic safety settings dictionary to the provider-specific format.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
settings: A dictionary with generic harm categories and thresholds.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
A list of provider-specific safety setting objects or None.
|
| 34 |
+
"""
|
| 35 |
+
return None
|