arubaDev commited on
Commit
c0c9f2c
·
verified ·
1 Parent(s): 41c5c08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -123,8 +123,19 @@ def build_api_messages(session_id: int, system_message: str):
123
  msgs.extend(get_messages(session_id))
124
  return msgs
125
 
126
- def get_client(model_choice: str):
127
- return InferenceClient(MODELS.get(model_choice, list(MODELS.values())[0]), token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  def load_dataset_by_name(name: str):
130
  if name == "The Stack": return load_dataset("bigcode/the-stack", split="train")
 
123
  msgs.extend(get_messages(session_id))
124
  return msgs
125
 
126
+ # def get_client(model_choice: str):
127
+ # return InferenceClient(MODELS.get(model_choice, list(MODELS.values())[0]), token=HF_TOKEN)
128
+ def get_client(model_choice):
129
+ # Normalize model_choice -> must be a string
130
+ if isinstance(model_choice, list):
131
+ model_choice = model_choice[0] if model_choice else None
132
+
133
+ if not model_choice:
134
+ model_choice = list(MODELS.keys())[0]
135
+
136
+ model_id = MODELS.get(model_choice, list(MODELS.values())[0])
137
+ return InferenceClient(model_id, token=HF_TOKEN)
138
+
139
 
140
  def load_dataset_by_name(name: str):
141
  if name == "The Stack": return load_dataset("bigcode/the-stack", split="train")