Mirrowel committed on
Commit
adfcd18
·
1 Parent(s): ea15798

feat(gemini): implement default safety settings and image data URL support

Browse files

Default safety settings are now enforced across all Gemini calls to prevent unintentional content filtering by the API's aggressive defaults.

- The RotatingClient applies a standard set of safety thresholds (HARM_CATEGORY_* set to OFF or BLOCK_NONE) if the user has not provided explicit settings or if their settings are incomplete.
- The `GeminiProvider` conversion logic is enhanced to merge missing categories when converting generic safety settings (dict) or when receiving direct Gemini-style lists.
- The `GeminiCliProvider` now supports multimodal inputs by parsing image content provided as data URLs (`data:image/...`) and converting them to `inlineData` parts.
- Usage metadata reporting in the CLI provider is updated to include `thoughtsTokenCount` within the prompt tokens and optionally detail reasoning tokens.
- Function tool schema translation for the CLI provider is corrected to use `parametersJsonSchema` and enforce default schemas.
- Function call IDs during streaming are now generated with nanosecond precision for guaranteed uniqueness.

src/rotator_library/client.py CHANGED
@@ -332,6 +332,53 @@ class RotatingClient:
332
 
333
  return kwargs
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  def get_oauth_credentials(self) -> Dict[str, List[str]]:
336
  return self.oauth_credentials
337
 
@@ -666,6 +713,13 @@ class RotatingClient:
666
 
667
  provider_instance = self._get_provider_instance(provider)
668
  if provider_instance:
 
 
 
 
 
 
 
669
  if "safety_settings" in litellm_kwargs:
670
  converted_settings = (
671
  provider_instance.convert_safety_settings(
@@ -1138,6 +1192,12 @@ class RotatingClient:
1138
 
1139
  provider_instance = self._get_provider_instance(provider)
1140
  if provider_instance:
 
 
 
 
 
 
1141
  if "safety_settings" in litellm_kwargs:
1142
  converted_settings = (
1143
  provider_instance.convert_safety_settings(
 
332
 
333
  return kwargs
334
 
335
+ def _apply_default_safety_settings(self, litellm_kwargs: Dict[str, Any], provider: str):
336
+ """
337
+ Ensure default Gemini safety settings are present when calling the Gemini provider.
338
+ This will not override any explicit settings provided by the request. It accepts
339
+ either OpenAI-compatible generic `safety_settings` (dict) or direct Gemini-style
340
+ `safetySettings` (list of dicts). Missing categories will be added with safe defaults.
341
+ """
342
+ if provider != "gemini":
343
+ return
344
+
345
+ # Generic defaults (openai-compatible style)
346
+ default_generic = {
347
+ "harassment": "OFF",
348
+ "hate_speech": "OFF",
349
+ "sexually_explicit": "OFF",
350
+ "dangerous_content": "OFF",
351
+ "civic_integrity": "BLOCK_NONE",
352
+ }
353
+
354
+ # Gemini defaults (direct Gemini format)
355
+ default_gemini = [
356
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
357
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
358
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
359
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
360
+ {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
361
+ ]
362
+
363
+ # If generic form is present, ensure missing generic keys are filled in
364
+ if "safety_settings" in litellm_kwargs and isinstance(litellm_kwargs["safety_settings"], dict):
365
+ for k, v in default_generic.items():
366
+ if k not in litellm_kwargs["safety_settings"]:
367
+ litellm_kwargs["safety_settings"][k] = v
368
+ return
369
+
370
+ # If Gemini form is present, ensure missing gemini categories are appended
371
+ if "safetySettings" in litellm_kwargs and isinstance(litellm_kwargs["safetySettings"], list):
372
+ present = {item.get("category") for item in litellm_kwargs["safetySettings"] if isinstance(item, dict)}
373
+ for d in default_gemini:
374
+ if d["category"] not in present:
375
+ litellm_kwargs["safetySettings"].append(d)
376
+ return
377
+
378
+ # Neither present: set generic defaults so provider conversion will translate them
379
+ if "safety_settings" not in litellm_kwargs and "safetySettings" not in litellm_kwargs:
380
+ litellm_kwargs["safety_settings"] = default_generic.copy()
381
+
382
  def get_oauth_credentials(self) -> Dict[str, List[str]]:
383
  return self.oauth_credentials
384
 
 
713
 
714
  provider_instance = self._get_provider_instance(provider)
715
  if provider_instance:
716
+ # Ensure default Gemini safety settings are present (without overriding request)
717
+ try:
718
+ self._apply_default_safety_settings(litellm_kwargs, provider)
719
+ except Exception:
720
+ # If anything goes wrong here, avoid breaking the request flow.
721
+ lib_logger.debug("Could not apply default safety settings; continuing.")
722
+
723
  if "safety_settings" in litellm_kwargs:
724
  converted_settings = (
725
  provider_instance.convert_safety_settings(
 
1192
 
1193
  provider_instance = self._get_provider_instance(provider)
1194
  if provider_instance:
1195
+ # Ensure default Gemini safety settings are present (without overriding request)
1196
+ try:
1197
+ self._apply_default_safety_settings(litellm_kwargs, provider)
1198
+ except Exception:
1199
+ lib_logger.debug("Could not apply default safety settings for streaming path; continuing.")
1200
+
1201
  if "safety_settings" in litellm_kwargs:
1202
  converted_settings = (
1203
  provider_instance.convert_safety_settings(
src/rotator_library/providers/gemini_cli_provider.py CHANGED
@@ -201,13 +201,35 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
201
  gemini_role = "model" if role == "assistant" else "tool" if role == "tool" else "user"
202
 
203
  if role == "user":
204
- text_content = ""
205
  if isinstance(content, str):
206
- text_content = content
 
 
207
  elif isinstance(content, list):
208
- text_content = "\n".join(p.get("text", "") for p in content if p.get("type") == "text")
209
- if text_content:
210
- parts.append({"text": text_content})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  elif role == "assistant":
213
  if isinstance(content, str):
@@ -292,12 +314,15 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
292
 
293
  if 'functionCall' in part:
294
  function_call = part['functionCall']
 
 
 
295
  delta['tool_calls'] = [{
296
  "index": 0,
297
- "id": f"tool-call-{time.time()}",
298
  "type": "function",
299
  "function": {
300
- "name": function_call.get('name'),
301
  "arguments": json.dumps(function_call.get('args', {}))
302
  }
303
  }]
@@ -326,11 +351,21 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
326
 
327
  if 'usageMetadata' in response_data:
328
  usage = response_data['usageMetadata']
 
 
 
 
329
  openai_chunk["usage"] = {
330
- "prompt_tokens": usage.get("promptTokenCount", 0),
331
- "completion_tokens": usage.get("candidatesTokenCount", 0),
332
  "total_tokens": usage.get("totalTokenCount", 0),
333
  }
 
 
 
 
 
 
334
 
335
  yield openai_chunk
336
 
@@ -482,9 +517,15 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
482
  # The Gemini CLI API does not support the 'strict' property.
483
  new_function.pop("strict", None)
484
 
485
- if "parameters" in new_function and isinstance(new_function["parameters"], dict):
486
- new_function["parameters"] = self._gemini_cli_transform_schema(new_function["parameters"])
487
-
 
 
 
 
 
 
488
  transformed_declarations.append(new_function)
489
 
490
  return transformed_declarations
@@ -548,8 +589,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
548
  }
549
  if "temperature" in kwargs:
550
  gen_config["temperature"] = kwargs["temperature"]
551
- else:
552
- gen_config["temperature"] = 0.7
553
  if "top_k" in kwargs:
554
  gen_config["topK"] = kwargs["top_k"]
555
  if "top_p" in kwargs:
@@ -583,7 +623,17 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
583
  tool_config = self._translate_tool_choice(kwargs["tool_choice"])
584
  if tool_config:
585
  request_payload["request"]["toolConfig"] = tool_config
586
-
 
 
 
 
 
 
 
 
 
 
587
  # Log the final payload for debugging and to the dedicated file
588
  #lib_logger.debug(f"Gemini CLI Request Payload: {json.dumps(request_payload, indent=2)}")
589
  file_logger.log_request(request_payload)
 
201
  gemini_role = "model" if role == "assistant" else "tool" if role == "tool" else "user"
202
 
203
  if role == "user":
 
204
  if isinstance(content, str):
205
+ # Simple text content
206
+ if content:
207
+ parts.append({"text": content})
208
  elif isinstance(content, list):
209
+ # Multi-part content (text, images, etc.)
210
+ for item in content:
211
+ if item.get("type") == "text":
212
+ text = item.get("text", "")
213
+ if text:
214
+ parts.append({"text": text})
215
+ elif item.get("type") == "image_url":
216
+ # Handle image data URLs
217
+ image_url = item.get("image_url", {}).get("url", "")
218
+ if image_url.startswith("data:"):
219
+ try:
220
+ # Parse: data:image/png;base64,iVBORw0KG...
221
+ header, data = image_url.split(",", 1)
222
+ mime_type = header.split(":")[1].split(";")[0]
223
+ parts.append({
224
+ "inlineData": {
225
+ "mimeType": mime_type,
226
+ "data": data
227
+ }
228
+ })
229
+ except Exception as e:
230
+ lib_logger.warning(f"Failed to parse image data URL: {e}")
231
+ else:
232
+ lib_logger.warning(f"Non-data-URL images not supported: {image_url[:50]}...")
233
 
234
  elif role == "assistant":
235
  if isinstance(content, str):
 
314
 
315
  if 'functionCall' in part:
316
  function_call = part['functionCall']
317
+ function_name = function_call.get('name', 'unknown')
318
+ # Generate unique ID with nanosecond precision (matching Go implementation)
319
+ tool_call_id = f"call_{function_name}_{int(time.time() * 1_000_000_000)}"
320
  delta['tool_calls'] = [{
321
  "index": 0,
322
+ "id": tool_call_id,
323
  "type": "function",
324
  "function": {
325
+ "name": function_name,
326
  "arguments": json.dumps(function_call.get('args', {}))
327
  }
328
  }]
 
351
 
352
  if 'usageMetadata' in response_data:
353
  usage = response_data['usageMetadata']
354
+ prompt_tokens = usage.get("promptTokenCount", 0)
355
+ thoughts_tokens = usage.get("thoughtsTokenCount", 0)
356
+ candidate_tokens = usage.get("candidatesTokenCount", 0)
357
+
358
  openai_chunk["usage"] = {
359
+ "prompt_tokens": prompt_tokens + thoughts_tokens, # Include thoughts in prompt tokens
360
+ "completion_tokens": candidate_tokens,
361
  "total_tokens": usage.get("totalTokenCount", 0),
362
  }
363
+
364
+ # Add reasoning tokens details if present (OpenAI o1 format)
365
+ if thoughts_tokens > 0:
366
+ if "completion_tokens_details" not in openai_chunk["usage"]:
367
+ openai_chunk["usage"]["completion_tokens_details"] = {}
368
+ openai_chunk["usage"]["completion_tokens_details"]["reasoning_tokens"] = thoughts_tokens
369
 
370
  yield openai_chunk
371
 
 
517
  # The Gemini CLI API does not support the 'strict' property.
518
  new_function.pop("strict", None)
519
 
520
+ # Gemini CLI expects 'parametersJsonSchema' instead of 'parameters'
521
+ if "parameters" in new_function:
522
+ schema = self._gemini_cli_transform_schema(new_function["parameters"])
523
+ new_function["parametersJsonSchema"] = schema
524
+ del new_function["parameters"]
525
+ elif "parametersJsonSchema" not in new_function:
526
+ # Set default empty schema if neither exists
527
+ new_function["parametersJsonSchema"] = {"type": "object", "properties": {}}
528
+
529
  transformed_declarations.append(new_function)
530
 
531
  return transformed_declarations
 
589
  }
590
  if "temperature" in kwargs:
591
  gen_config["temperature"] = kwargs["temperature"]
592
+ # No else - let Gemini use its default temperature (matches OpenAI behavior)
 
593
  if "top_k" in kwargs:
594
  gen_config["topK"] = kwargs["top_k"]
595
  if "top_p" in kwargs:
 
623
  tool_config = self._translate_tool_choice(kwargs["tool_choice"])
624
  if tool_config:
625
  request_payload["request"]["toolConfig"] = tool_config
626
+
627
+ # Add default safety settings to prevent content filtering
628
+ if "safetySettings" not in request_payload["request"]:
629
+ request_payload["request"]["safetySettings"] = [
630
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
631
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
632
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
633
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
634
+ {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
635
+ ]
636
+
637
  # Log the final payload for debugging and to the dedicated file
638
  #lib_logger.debug(f"Gemini CLI Request Payload: {json.dumps(request_payload, indent=2)}")
639
  file_logger.log_request(request_payload)
src/rotator_library/providers/gemini_provider.py CHANGED
@@ -32,23 +32,57 @@ class GeminiProvider(ProviderInterface):
32
  Converts generic safety settings to the Gemini-specific format.
33
  """
34
  if not settings:
35
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
 
 
 
 
 
 
 
 
 
 
37
  gemini_settings = []
38
  category_map = {
39
  "harassment": "HARM_CATEGORY_HARASSMENT",
40
  "hate_speech": "HARM_CATEGORY_HATE_SPEECH",
41
  "sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
42
  "dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
 
43
  }
44
 
45
  for generic_category, threshold in settings.items():
46
  if generic_category in category_map:
 
47
  gemini_settings.append({
48
  "category": category_map[generic_category],
49
- "threshold": threshold.upper()
50
  })
51
-
 
 
 
 
 
 
52
  return gemini_settings
53
 
54
  def handle_thinking_parameter(self, payload: Dict[str, Any], model: str):
 
32
  Converts generic safety settings to the Gemini-specific format.
33
  """
34
  if not settings:
35
+ # Return full defaults if nothing provided
36
+ return [
37
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
38
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
39
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
40
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
41
+ {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
42
+ ]
43
+
44
+ # Default gemini-format settings for merging
45
+ default_gemini = {
46
+ "HARM_CATEGORY_HARASSMENT": "OFF",
47
+ "HARM_CATEGORY_HATE_SPEECH": "OFF",
48
+ "HARM_CATEGORY_SEXUALLY_EXPLICIT": "OFF",
49
+ "HARM_CATEGORY_DANGEROUS_CONTENT": "OFF",
50
+ "HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
51
+ }
52
 
53
+ # If the caller already provided Gemini-style list, merge defaults without overwriting
54
+ if isinstance(settings, list):
55
+ existing = {item.get("category"): item for item in settings if isinstance(item, dict) and item.get("category")}
56
+ merged = list(settings)
57
+ for cat, thr in default_gemini.items():
58
+ if cat not in existing:
59
+ merged.append({"category": cat, "threshold": thr})
60
+ return merged
61
+
62
+ # Otherwise assume a generic mapping (dict) and convert
63
  gemini_settings = []
64
  category_map = {
65
  "harassment": "HARM_CATEGORY_HARASSMENT",
66
  "hate_speech": "HARM_CATEGORY_HATE_SPEECH",
67
  "sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
68
  "dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
69
+ "civic_integrity": "HARM_CATEGORY_CIVIC_INTEGRITY",
70
  }
71
 
72
  for generic_category, threshold in settings.items():
73
  if generic_category in category_map:
74
+ thr = (threshold or "").upper()
75
  gemini_settings.append({
76
  "category": category_map[generic_category],
77
+ "threshold": thr if thr else default_gemini[category_map[generic_category]]
78
  })
79
+
80
+ # Add any missing defaults
81
+ present = {s["category"] for s in gemini_settings}
82
+ for cat, thr in default_gemini.items():
83
+ if cat not in present:
84
+ gemini_settings.append({"category": cat, "threshold": thr})
85
+
86
  return gemini_settings
87
 
88
  def handle_thinking_parameter(self, payload: Dict[str, Any], model: str):