uparekh01151 committed on
Commit
682dc03
·
1 Parent(s): 2ec7e68

Fix Nebius provider to use chat.completions.create for conversational models

Browse files
Files changed (1) hide show
  1. src/models_registry.py +30 -11
src/models_registry.py CHANGED
@@ -85,17 +85,34 @@ class HuggingFaceInference:
85
  api_key=os.environ.get("HF_TOKEN")
86
  )
87
 
88
- # Use text_generation for all providers (simplified approach)
89
- result = client.text_generation(
90
- prompt=prompt,
91
- model=model_id,
92
- max_new_tokens=params.get('max_new_tokens', 128),
93
- temperature=params.get('temperature', 0.1),
94
- top_p=params.get('top_p', 0.9),
95
- return_full_text=False # Only return the generated part
96
- )
97
-
98
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  except Exception as e:
101
  # Improved error handling with detailed error messages
@@ -112,6 +129,8 @@ class HuggingFaceInference:
112
  raise Exception(f"Request timeout - model may be loading on {provider}")
113
  elif "not supported for task" in error_msg:
114
  raise Exception(f"Model {model_id} task not supported by {provider} provider: {error_msg}")
 
 
115
  else:
116
  raise Exception(f"{provider} API error: {error_msg}")
117
 
 
85
  api_key=os.environ.get("HF_TOKEN")
86
  )
87
 
88
+ # Use different methods based on provider capabilities
89
+ if provider == "nebius":
90
+ # Nebius provider only supports conversational tasks, use chat completion
91
+ completion = client.chat.completions.create(
92
+ model=model_id,
93
+ messages=[
94
+ {
95
+ "role": "user",
96
+ "content": prompt
97
+ }
98
+ ],
99
+ max_tokens=params.get('max_new_tokens', 128),
100
+ temperature=params.get('temperature', 0.1),
101
+ top_p=params.get('top_p', 0.9)
102
+ )
103
+ # Extract the content from the response
104
+ return completion.choices[0].message.content
105
+ else:
106
+ # Other providers use text_generation
107
+ result = client.text_generation(
108
+ prompt=prompt,
109
+ model=model_id,
110
+ max_new_tokens=params.get('max_new_tokens', 128),
111
+ temperature=params.get('temperature', 0.1),
112
+ top_p=params.get('top_p', 0.9),
113
+ return_full_text=False # Only return the generated part
114
+ )
115
+ return result
116
 
117
  except Exception as e:
118
  # Improved error handling with detailed error messages
 
129
  raise Exception(f"Request timeout - model may be loading on {provider}")
130
  elif "not supported for task" in error_msg:
131
  raise Exception(f"Model {model_id} task not supported by {provider} provider: {error_msg}")
132
+ elif "not supported by provider" in error_msg:
133
+ raise Exception(f"Model {model_id} not supported by {provider} provider: {error_msg}")
134
  else:
135
  raise Exception(f"{provider} API error: {error_msg}")
136