Mirrowel committed
Commit 8cf8bab · 1 Parent(s): 2ccd2a1

feat(providers): ✨ inject NVIDIA 'thinking' flag for DeepSeek v3.1+ models


Add automatic injection of a chat-template "thinking" toggle for NVIDIA NIM requests when DeepSeek v3.1+ variants are used and a reasoning budget is provided.

- Implement NvidiaProvider.handle_thinking_parameter(payload, model):
  - Recognizes DeepSeek v3.1+ model names (v3.1, v3.1-terminus, v3.2)
  - Checks the payload's reasoning_effort for "low", "medium", or "high"
  - Ensures extra_body.chat_template_kwargs exists and sets thinking = True
  - Logs when the flag is enabled
- Call the new handler from RotatingClient for provider == "nvidia_nim" at the same payload-prep points used for gemini

This change ensures outgoing NVIDIA NIM payloads include the required internal flag so DeepSeek models can enable their optimized chat-template behavior based on the requested reasoning budget.
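
For illustration, a minimal sketch of the intended payload mutation. The payload contents and the no-argument NvidiaProvider construction are assumptions for this example, not part of the change:

    # Hypothetical payload, shaped like the litellm kwargs RotatingClient prepares.
    payload = {
        "model": "nvidia_nim/deepseek-ai/deepseek-v3.1",
        "messages": [{"role": "user", "content": "Hello"}],
        "reasoning_effort": "high",
    }

    # Assumes NvidiaProvider can be constructed with no arguments.
    NvidiaProvider().handle_thinking_parameter(payload, payload["model"])

    # The handler mutates the payload in place, injecting the template toggle:
    assert payload["extra_body"] == {"chat_template_kwargs": {"thinking": True}}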

src/rotator_library/client.py CHANGED
@@ -490,6 +490,8 @@ class RotatingClient:
 
         if provider == "gemini" and provider_instance:
             provider_instance.handle_thinking_parameter(litellm_kwargs, model)
+        if provider == "nvidia_nim" and provider_instance:
+            provider_instance.handle_thinking_parameter(litellm_kwargs, model)
 
         if "gemma-3" in model and "messages" in litellm_kwargs:
             litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
@@ -757,6 +759,8 @@ class RotatingClient:
 
         if provider == "gemini" and provider_instance:
             provider_instance.handle_thinking_parameter(litellm_kwargs, model)
+        if provider == "nvidia_nim" and provider_instance:
+            provider_instance.handle_thinking_parameter(litellm_kwargs, model)
 
         if "gemma-3" in model and "messages" in litellm_kwargs:
             litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
src/rotator_library/providers/nvidia_provider.py CHANGED
@@ -1,6 +1,6 @@
 import httpx
 import logging
-from typing import List
+from typing import List, Dict, Any
 import litellm
 from .provider_interface import ProviderInterface
 
@@ -29,3 +29,27 @@ class NvidiaProvider(ProviderInterface):
         except httpx.RequestError as e:
             lib_logger.error(f"Failed to fetch NVIDIA models: {e}")
             return []
+
+    def handle_thinking_parameter(self, payload: Dict[str, Any], model: str):
+        """
+        Adds the 'thinking' parameter for specific DeepSeek models on the NVIDIA
+        provider, only if reasoning_effort is set to low, medium, or high.
+        """
+        deepseek_models = [
+            "deepseek-ai/deepseek-v3.1",
+            "deepseek-ai/deepseek-v3.1-terminus",
+            "deepseek-ai/deepseek-v3.2"
+        ]
+
+        # The model name in the payload is prefixed with 'nvidia_nim/'.
+        model_name = model.split('/', 1)[1] if '/' in model else model
+        reasoning_effort = payload.get("reasoning_effort")
+
+        if model_name in deepseek_models and reasoning_effort in ["low", "medium", "high"]:
+            if "extra_body" not in payload:
+                payload["extra_body"] = {}
+            if "chat_template_kwargs" not in payload["extra_body"]:
+                payload["extra_body"]["chat_template_kwargs"] = {}
+
+            payload["extra_body"]["chat_template_kwargs"]["thinking"] = True
+            lib_logger.info(f"Enabled 'thinking' parameter for model: {model_name} due to reasoning_effort: '{reasoning_effort}'")
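
As a quick sanity check of the gating logic, a hedged sketch of the no-op paths (the model names and the no-argument constructor are assumptions for illustration): a non-DeepSeek model, or a missing reasoning_effort, leaves the payload untouched:

    payload = {"model": "nvidia_nim/meta/llama-3.1-8b-instruct", "messages": []}
    NvidiaProvider().handle_thinking_parameter(payload, payload["model"])
    assert "extra_body" not in payload  # nothing injected for non-DeepSeek models

    payload = {"model": "nvidia_nim/deepseek-ai/deepseek-v3.1", "messages": []}
    NvidiaProvider().handle_thinking_parameter(payload, payload["model"])
    assert "extra_body" not in payload  # nothing injected without a reasoning_effort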