uparekh01151 committed on
Commit
e65f7af
·
1 Parent(s): 710113d

Add Qwen3-Next-80B model with Together provider and update InferenceClient to support multiple providers

Browse files
Files changed (2) hide show
  1. config/models.yaml +6 -24
  2. src/models_registry.py +19 -16
config/models.yaml CHANGED
@@ -1,28 +1,10 @@
1
  models:
2
- # Working Models - Using Hugging Face Inference API
3
- - name: "BlenderBot-400M"
4
- provider: "huggingface"
5
- model_id: "facebook/blenderbot-400M-distill"
6
  params:
7
- max_new_tokens: 128
8
  temperature: 0.1
9
  top_p: 0.9
10
- description: "BlenderBot 400M - Conversational AI model"
11
-
12
- - name: "DialoGPT-Medium"
13
- provider: "huggingface"
14
- model_id: "microsoft/DialoGPT-medium"
15
- params:
16
- max_new_tokens: 128
17
- temperature: 0.1
18
- top_p: 0.9
19
- description: "DialoGPT Medium - Conversational model"
20
-
21
- - name: "GPT-2"
22
- provider: "huggingface"
23
- model_id: "gpt2"
24
- params:
25
- max_new_tokens: 128
26
- temperature: 0.1
27
- top_p: 0.9
28
- description: "GPT-2 model - Original transformer model"
 
1
  models:
2
+ # Qwen Model with Together Provider
3
+ - name: "Qwen3-Next-80B"
4
+ provider: "together"
5
+ model_id: "Qwen/Qwen3-Next-80B-A3B-Instruct"
6
  params:
7
+ max_new_tokens: 256
8
  temperature: 0.1
9
  top_p: 0.9
10
+ description: "Qwen3-Next-80B - Advanced instruction-following model via Together AI"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models_registry.py CHANGED
@@ -74,17 +74,19 @@ class HuggingFaceInference:
74
 
75
  def __init__(self, api_token: Optional[str] = None):
76
  self.api_token = api_token or os.getenv("HF_TOKEN")
77
- # Initialize InferenceClient with correct provider
78
- self.client = InferenceClient(
79
- provider="hf-inference", # Fixed: was "huggingface", now "hf-inference"
80
- api_key=os.environ.get("HF_TOKEN")
81
- )
82
 
83
- def generate(self, model_id: str, prompt: str, params: Dict[str, Any]) -> str:
84
- """Generate text using Hugging Face Inference API."""
85
  try:
 
 
 
 
 
 
86
  # Use text_generation method with correct parameters
87
- result = self.client.text_generation(
88
  prompt=prompt,
89
  model=model_id,
90
  max_new_tokens=params.get('max_new_tokens', 128),
@@ -101,15 +103,15 @@ class HuggingFaceInference:
101
  print(f"🔍 Debug - Full error: {error_msg}")
102
 
103
  if "404" in error_msg or "Not Found" in error_msg:
104
- raise Exception(f"Model not found: {model_id} - Model may not be available via Inference API")
105
  elif "401" in error_msg or "Unauthorized" in error_msg:
106
- raise Exception(f"Authentication failed - check HF_TOKEN")
107
  elif "503" in error_msg or "Service Unavailable" in error_msg:
108
- raise Exception(f"Model {model_id} is loading, please try again in a moment")
109
  elif "timeout" in error_msg.lower():
110
- raise Exception(f"Request timeout - model may be loading")
111
  else:
112
- raise Exception(f"Hugging Face API error: {error_msg}")
113
 
114
 
115
  class ModelInterface:
@@ -168,12 +170,13 @@ class ModelInterface:
168
  return self._generate_mock_sql(model_config, prompt)
169
 
170
  try:
171
- if model_config.provider == "huggingface":
172
- print(f"🤗 Using Hugging Face Inference API for {model_config.name}")
173
  return self.hf_interface.generate(
174
  model_config.model_id,
175
  prompt,
176
- model_config.params
 
177
  )
178
  else:
179
  raise ValueError(f"Unsupported provider: {model_config.provider}")
 
74
 
75
  def __init__(self, api_token: Optional[str] = None):
76
  self.api_token = api_token or os.getenv("HF_TOKEN")
77
+ # We'll create clients dynamically based on provider
 
 
 
 
78
 
79
+ def generate(self, model_id: str, prompt: str, params: Dict[str, Any], provider: str = "hf-inference") -> str:
80
+ """Generate text using Hugging Face Inference API with specified provider."""
81
  try:
82
+ # Create InferenceClient with the specified provider
83
+ client = InferenceClient(
84
+ provider=provider,
85
+ api_key=os.environ.get("HF_TOKEN")
86
+ )
87
+
88
  # Use text_generation method with correct parameters
89
+ result = client.text_generation(
90
  prompt=prompt,
91
  model=model_id,
92
  max_new_tokens=params.get('max_new_tokens', 128),
 
103
  print(f"🔍 Debug - Full error: {error_msg}")
104
 
105
  if "404" in error_msg or "Not Found" in error_msg:
106
+ raise Exception(f"Model not found: {model_id} - Model may not be available via {provider} provider")
107
  elif "401" in error_msg or "Unauthorized" in error_msg:
108
+ raise Exception(f"Authentication failed - check HF_TOKEN for {provider} provider")
109
  elif "503" in error_msg or "Service Unavailable" in error_msg:
110
+ raise Exception(f"Model {model_id} is loading on {provider}, please try again in a moment")
111
  elif "timeout" in error_msg.lower():
112
+ raise Exception(f"Request timeout - model may be loading on {provider}")
113
  else:
114
+ raise Exception(f"{provider} API error: {error_msg}")
115
 
116
 
117
  class ModelInterface:
 
170
  return self._generate_mock_sql(model_config, prompt)
171
 
172
  try:
173
+ if model_config.provider in ["huggingface", "hf-inference", "together"]:
174
+ print(f"🤗 Using {model_config.provider} Inference API for {model_config.name}")
175
  return self.hf_interface.generate(
176
  model_config.model_id,
177
  prompt,
178
+ model_config.params,
179
+ model_config.provider
180
  )
181
  else:
182
  raise ValueError(f"Unsupported provider: {model_config.provider}")