Mark-Lasfar committed on
Commit
01237fb
·
1 Parent(s): 7f3503f

Update Model

Browse files
Files changed (2) hide show
  1. api/endpoints.py +17 -15
  2. utils/generation.py +47 -26
api/endpoints.py CHANGED
@@ -28,10 +28,11 @@ BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
28
  if not BACKUP_HF_TOKEN:
29
  logger.warning("BACKUP_HF_TOKEN is not set. Fallback to secondary model will not work if primary token fails.")
30
 
 
31
  API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
32
  FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co")
33
- MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1") # Changed to supported model
34
- SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
35
  TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "gpt2")
36
  CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
37
  CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
@@ -116,7 +117,7 @@ async def model_info():
116
  {"alias": "audio", "description": "Audio transcription model (default)"},
117
  {"alias": "tts", "description": "Text-to-speech model (default)"}
118
  ],
119
- "api_base": API_ENDPOINT,
120
  "fallback_api_base": FALLBACK_API_ENDPOINT,
121
  "status": "online"
122
  }
@@ -182,13 +183,13 @@ async def chat_endpoint(
182
  )
183
  if req.output_format == "audio":
184
  audio_chunks = []
185
- for chunk in stream: # Changed from async for
186
  if isinstance(chunk, bytes):
187
  audio_chunks.append(chunk)
188
  audio_data = b"".join(audio_chunks)
189
  return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
190
  response_chunks = []
191
- for chunk in stream: # Changed from async for
192
  if isinstance(chunk, str):
193
  response_chunks.append(chunk)
194
  response = "".join(response_chunks)
@@ -255,7 +256,7 @@ async def audio_transcription_endpoint(
255
  output_format="text"
256
  )
257
  response_chunks = []
258
- for chunk in stream: # Changed from async for
259
  if isinstance(chunk, str):
260
  response_chunks.append(chunk)
261
  response = "".join(response_chunks)
@@ -300,7 +301,7 @@ async def text_to_speech_endpoint(
300
  output_format="audio"
301
  )
302
  audio_chunks = []
303
- for chunk in stream: # Changed from async for
304
  if isinstance(chunk, bytes):
305
  audio_chunks.append(chunk)
306
  audio_data = b"".join(audio_chunks)
@@ -340,13 +341,13 @@ async def code_endpoint(
340
  )
341
  if output_format == "audio":
342
  audio_chunks = []
343
- for chunk in stream: # Changed from async for
344
  if isinstance(chunk, bytes):
345
  audio_chunks.append(chunk)
346
  audio_data = b"".join(audio_chunks)
347
  return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
348
  response_chunks = []
349
- for chunk in stream: # Changed from async for
350
  if isinstance(chunk, str):
351
  response_chunks.append(chunk)
352
  response = "".join(response_chunks)
@@ -383,13 +384,13 @@ async def analysis_endpoint(
383
  )
384
  if output_format == "audio":
385
  audio_chunks = []
386
- for chunk in stream: # Changed from async for
387
  if isinstance(chunk, bytes):
388
  audio_chunks.append(chunk)
389
  audio_data = b"".join(audio_chunks)
390
  return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
391
  response_chunks = []
392
- for chunk in stream: # Changed from async for
393
  if isinstance(chunk, str):
394
  response_chunks.append(chunk)
395
  response = "".join(response_chunks)
@@ -446,13 +447,13 @@ async def image_analysis_endpoint(
446
  )
447
  if output_format == "audio":
448
  audio_chunks = []
449
- for chunk in stream: # Changed from async for
450
  if isinstance(chunk, bytes):
451
  audio_chunks.append(chunk)
452
  audio_data = b"".join(audio_chunks)
453
  return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
454
  response_chunks = []
455
- for chunk in stream: # Changed from async for
456
  if isinstance(chunk, str):
457
  response_chunks.append(chunk)
458
  response = "".join(response_chunks)
@@ -473,9 +474,10 @@ async def image_analysis_endpoint(
473
  return {"image_analysis": response}
474
 
475
  @router.get("/api/test-model")
476
- async def test_model(model: str = MODEL_NAME, endpoint: str = API_ENDPOINT):
477
  try:
478
- client = OpenAI(api_key=HF_TOKEN, base_url=endpoint, timeout=60.0)
 
479
  response = client.chat.completions.create(
480
  model=model,
481
  messages=[{"role": "user", "content": "Test"}],
 
28
  if not BACKUP_HF_TOKEN:
29
  logger.warning("BACKUP_HF_TOKEN is not set. Fallback to secondary model will not work if primary token fails.")
30
 
31
+ ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
32
  API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
33
  FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co")
34
+ MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b") # Updated to target model
35
+ SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
36
  TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "gpt2")
37
  CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
38
  CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
 
117
  {"alias": "audio", "description": "Audio transcription model (default)"},
118
  {"alias": "tts", "description": "Text-to-speech model (default)"}
119
  ],
120
+ "api_base": ROUTER_API_URL,
121
  "fallback_api_base": FALLBACK_API_ENDPOINT,
122
  "status": "online"
123
  }
 
183
  )
184
  if req.output_format == "audio":
185
  audio_chunks = []
186
+ for chunk in stream:
187
  if isinstance(chunk, bytes):
188
  audio_chunks.append(chunk)
189
  audio_data = b"".join(audio_chunks)
190
  return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
191
  response_chunks = []
192
+ for chunk in stream:
193
  if isinstance(chunk, str):
194
  response_chunks.append(chunk)
195
  response = "".join(response_chunks)
 
256
  output_format="text"
257
  )
258
  response_chunks = []
259
+ for chunk in stream:
260
  if isinstance(chunk, str):
261
  response_chunks.append(chunk)
262
  response = "".join(response_chunks)
 
301
  output_format="audio"
302
  )
303
  audio_chunks = []
304
+ for chunk in stream:
305
  if isinstance(chunk, bytes):
306
  audio_chunks.append(chunk)
307
  audio_data = b"".join(audio_chunks)
 
341
  )
342
  if output_format == "audio":
343
  audio_chunks = []
344
+ for chunk in stream:
345
  if isinstance(chunk, bytes):
346
  audio_chunks.append(chunk)
347
  audio_data = b"".join(audio_chunks)
348
  return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
349
  response_chunks = []
350
+ for chunk in stream:
351
  if isinstance(chunk, str):
352
  response_chunks.append(chunk)
353
  response = "".join(response_chunks)
 
384
  )
385
  if output_format == "audio":
386
  audio_chunks = []
387
+ for chunk in stream:
388
  if isinstance(chunk, bytes):
389
  audio_chunks.append(chunk)
390
  audio_data = b"".join(audio_chunks)
391
  return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
392
  response_chunks = []
393
+ for chunk in stream:
394
  if isinstance(chunk, str):
395
  response_chunks.append(chunk)
396
  response = "".join(response_chunks)
 
447
  )
448
  if output_format == "audio":
449
  audio_chunks = []
450
+ for chunk in stream:
451
  if isinstance(chunk, bytes):
452
  audio_chunks.append(chunk)
453
  audio_data = b"".join(audio_chunks)
454
  return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
455
  response_chunks = []
456
+ for chunk in stream:
457
  if isinstance(chunk, str):
458
  response_chunks.append(chunk)
459
  response = "".join(response_chunks)
 
474
  return {"image_analysis": response}
475
 
476
  @router.get("/api/test-model")
477
+ async def test_model(model: str = MODEL_NAME, endpoint: str = ROUTER_API_URL):
478
  try:
479
+ _, api_key, selected_endpoint = check_model_availability(model, HF_TOKEN)
480
+ client = OpenAI(api_key=api_key, base_url=selected_endpoint, timeout=60.0)
481
  response = client.chat.completions.create(
482
  model=model,
483
  messages=[{"role": "user", "content": "Test"}],
utils/generation.py CHANGED
@@ -30,19 +30,32 @@ LATEX_DELIMS = [
30
  {"left": "\\(", "right": "\\)", "display": False},
31
  ]
32
 
33
- # إعداد العميل لـ Hugging Face Inference API
34
  HF_TOKEN = os.getenv("HF_TOKEN")
35
  BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
 
36
  API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
37
  FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co")
38
- MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1") # Changed to supported model
39
- SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
40
  TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "gpt2")
41
  CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
42
  CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
43
  ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
44
  TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
45
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # Model alias mapping
47
  MODEL_ALIASES = {
48
  "advanced": MODEL_NAME,
@@ -54,37 +67,45 @@ MODEL_ALIASES = {
54
  "tts": TTS_MODEL
55
  }
56
 
57
- def check_model_availability(model_name: str, api_base: str, api_key: str) -> tuple[bool, str]:
58
  try:
59
  response = requests.get(
60
- f"{api_base}/models/{model_name}",
61
  headers={"Authorization": f"Bearer {api_key}"},
62
- timeout=30 # Increased timeout
63
  )
64
  if response.status_code == 200:
65
- logger.info(f"Model {model_name} is available")
66
- return True, api_key
 
 
 
 
 
 
 
 
 
67
  elif response.status_code == 429 and BACKUP_HF_TOKEN and api_key != BACKUP_HF_TOKEN:
68
  logger.warning(f"Rate limit reached for token {api_key}. Switching to backup token.")
69
- return check_model_availability(model_name, api_base, BACKUP_HF_TOKEN)
70
  logger.error(f"Model {model_name} not available: {response.status_code} - {response.text}")
71
- return False, api_key
72
  except Exception as e:
73
  logger.error(f"Failed to check model availability for {model_name}: {e}")
74
  if BACKUP_HF_TOKEN and api_key != BACKUP_HF_TOKEN:
75
  logger.warning(f"Retrying with backup token for {model_name}")
76
- return check_model_availability(model_name, api_base, BACKUP_HF_TOKEN)
77
- return False, api_key
78
 
79
  def select_model(query: str, input_type: str = "text", preferred_model: Optional[str] = None) -> tuple[str, str]:
80
  # If user has a preferred model, use it unless the input type requires a specific model
81
  if preferred_model and preferred_model in MODEL_ALIASES:
82
  model_name = MODEL_ALIASES[preferred_model]
83
- api_endpoint = API_ENDPOINT if model_name in [MODEL_NAME, TERTIARY_MODEL_NAME] else FALLBACK_API_ENDPOINT
84
- is_available, _ = check_model_availability(model_name, api_endpoint, HF_TOKEN)
85
  if is_available:
86
- logger.info(f"Selected preferred model {model_name} with endpoint {api_endpoint} for query: {query}")
87
- return model_name, api_endpoint
88
 
89
  query_lower = query.lower()
90
  # دعم الصوت
@@ -111,10 +132,10 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
111
  (TERTIARY_MODEL_NAME, API_ENDPOINT)
112
  ]
113
  for model_name, api_endpoint in available_models:
114
- is_available, _ = check_model_availability(model_name, api_endpoint, HF_TOKEN)
115
  if is_available:
116
- logger.info(f"Selected {model_name} with endpoint {api_endpoint} for query: {query}")
117
- return model_name, api_endpoint
118
  logger.error("No models available. Falling back to default.")
119
  return MODEL_NAME, API_ENDPOINT
120
 
@@ -137,7 +158,7 @@ def request_generation(
137
  image_data: Optional[bytes] = None,
138
  output_format: str = "text"
139
  ) -> Generator[bytes | str, None, None]:
140
- is_available, selected_api_key = check_model_availability(model_name, api_base, api_key)
141
  if not is_available:
142
  yield f"Error: Model {model_name} is not available. Please check the model endpoint or token."
143
  return
@@ -158,7 +179,7 @@ def request_generation(
158
  yield chunk
159
  return
160
 
161
- client = OpenAI(api_key=selected_api_key, base_url=api_base, timeout=120.0)
162
  task_type = "general"
163
  enhanced_system_prompt = system_prompt
164
 
@@ -391,7 +412,7 @@ def request_generation(
391
  logger.warning(f"Retrying with backup token for model {model_name}")
392
  for chunk in request_generation(
393
  api_key=BACKUP_HF_TOKEN,
394
- api_base=api_base,
395
  message=message,
396
  system_prompt=system_prompt,
397
  model_name=model_name,
@@ -414,11 +435,11 @@ def request_generation(
414
  fallback_endpoint = FALLBACK_API_ENDPOINT
415
  logger.info(f"Retrying with fallback model: {fallback_model} on {fallback_endpoint}")
416
  try:
417
- is_available, selected_api_key = check_model_availability(fallback_model, fallback_endpoint, selected_api_key)
418
  if not is_available:
419
  yield f"Error: Fallback model {fallback_model} is not available."
420
  return
421
- client = OpenAI(api_key=selected_api_key, base_url=fallback_endpoint, timeout=120.0)
422
  stream = client.chat.completions.create(
423
  model=fallback_model,
424
  messages=input_messages,
@@ -496,11 +517,11 @@ def request_generation(
496
  except Exception as e2:
497
  logger.exception(f"[Gateway] Streaming failed for fallback model {fallback_model}: {e2}")
498
  try:
499
- is_available, selected_api_key = check_model_availability(TERTIARY_MODEL_NAME, API_ENDPOINT, selected_api_key)
500
  if not is_available:
501
  yield f"Error: Tertiary model {TERTIARY_MODEL_NAME} is not available."
502
  return
503
- client = OpenAI(api_key=selected_api_key, base_url=API_ENDPOINT, timeout=120.0)
504
  stream = client.chat.completions.create(
505
  model=TERTIARY_MODEL_NAME,
506
  messages=input_messages,
 
30
  {"left": "\\(", "right": "\\)", "display": False},
31
  ]
32
 
33
+ # إعداد العميل لـ Hugging Face Router API
34
  HF_TOKEN = os.getenv("HF_TOKEN")
35
  BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
36
+ ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
37
  API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
38
  FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co")
39
+ MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b") # Updated to target model
40
+ SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
41
  TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "gpt2")
42
  CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
43
  CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
44
  ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
45
  TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
46
 
47
+ # Provider endpoints (based on Router API providers)
48
+ PROVIDER_ENDPOINTS = {
49
+ "together": "https://api.together.xyz/v1",
50
+ "fireworks-ai": "https://api.fireworks.ai/inference/v1",
51
+ "nebius": "https://api.nebius.ai/v1",
52
+ "novita": "https://api.novita.ai/v1",
53
+ "groq": "https://api.groq.com/openai/v1",
54
+ "cerebras": "https://api.cerebras.ai/v1",
55
+ "hyperbolic": "https://api.hyperbolic.xyz/v1",
56
+ "nscale": "https://api.nscale.ai/v1"
57
+ }
58
+
59
  # Model alias mapping
60
  MODEL_ALIASES = {
61
  "advanced": MODEL_NAME,
 
67
  "tts": TTS_MODEL
68
  }
69
 
70
+ def check_model_availability(model_name: str, api_key: str) -> tuple[bool, str, str]:
71
  try:
72
  response = requests.get(
73
+ f"{ROUTER_API_URL}/v1/models/{model_name}",
74
  headers={"Authorization": f"Bearer {api_key}"},
75
+ timeout=30
76
  )
77
  if response.status_code == 200:
78
+ data = response.json().get("data", {})
79
+ providers = data.get("providers", [])
80
+ # Select the first available provider (e.g., 'together')
81
+ for provider in providers:
82
+ if provider.get("status") == "live":
83
+ provider_name = provider.get("provider")
84
+ endpoint = PROVIDER_ENDPOINTS.get(provider_name, API_ENDPOINT)
85
+ logger.info(f"Model {model_name} is available via provider {provider_name} at {endpoint}")
86
+ return True, api_key, endpoint
87
+ logger.error(f"No live providers found for model {model_name}")
88
+ return False, api_key, API_ENDPOINT
89
  elif response.status_code == 429 and BACKUP_HF_TOKEN and api_key != BACKUP_HF_TOKEN:
90
  logger.warning(f"Rate limit reached for token {api_key}. Switching to backup token.")
91
+ return check_model_availability(model_name, BACKUP_HF_TOKEN)
92
  logger.error(f"Model {model_name} not available: {response.status_code} - {response.text}")
93
+ return False, api_key, API_ENDPOINT
94
  except Exception as e:
95
  logger.error(f"Failed to check model availability for {model_name}: {e}")
96
  if BACKUP_HF_TOKEN and api_key != BACKUP_HF_TOKEN:
97
  logger.warning(f"Retrying with backup token for {model_name}")
98
+ return check_model_availability(model_name, BACKUP_HF_TOKEN)
99
+ return False, api_key, API_ENDPOINT
100
 
101
  def select_model(query: str, input_type: str = "text", preferred_model: Optional[str] = None) -> tuple[str, str]:
102
  # If user has a preferred model, use it unless the input type requires a specific model
103
  if preferred_model and preferred_model in MODEL_ALIASES:
104
  model_name = MODEL_ALIASES[preferred_model]
105
+ is_available, _, endpoint = check_model_availability(model_name, HF_TOKEN)
 
106
  if is_available:
107
+ logger.info(f"Selected preferred model {model_name} with endpoint {endpoint} for query: {query}")
108
+ return model_name, endpoint
109
 
110
  query_lower = query.lower()
111
  # دعم الصوت
 
132
  (TERTIARY_MODEL_NAME, API_ENDPOINT)
133
  ]
134
  for model_name, api_endpoint in available_models:
135
+ is_available, _, endpoint = check_model_availability(model_name, HF_TOKEN)
136
  if is_available:
137
+ logger.info(f"Selected {model_name} with endpoint {endpoint} for query: {query}")
138
+ return model_name, endpoint
139
  logger.error("No models available. Falling back to default.")
140
  return MODEL_NAME, API_ENDPOINT
141
 
 
158
  image_data: Optional[bytes] = None,
159
  output_format: str = "text"
160
  ) -> Generator[bytes | str, None, None]:
161
+ is_available, selected_api_key, selected_endpoint = check_model_availability(model_name, api_key)
162
  if not is_available:
163
  yield f"Error: Model {model_name} is not available. Please check the model endpoint or token."
164
  return
 
179
  yield chunk
180
  return
181
 
182
+ client = OpenAI(api_key=selected_api_key, base_url=selected_endpoint, timeout=120.0)
183
  task_type = "general"
184
  enhanced_system_prompt = system_prompt
185
 
 
412
  logger.warning(f"Retrying with backup token for model {model_name}")
413
  for chunk in request_generation(
414
  api_key=BACKUP_HF_TOKEN,
415
+ api_base=selected_endpoint,
416
  message=message,
417
  system_prompt=system_prompt,
418
  model_name=model_name,
 
435
  fallback_endpoint = FALLBACK_API_ENDPOINT
436
  logger.info(f"Retrying with fallback model: {fallback_model} on {fallback_endpoint}")
437
  try:
438
+ is_available, selected_api_key, selected_endpoint = check_model_availability(fallback_model, selected_api_key)
439
  if not is_available:
440
  yield f"Error: Fallback model {fallback_model} is not available."
441
  return
442
+ client = OpenAI(api_key=selected_api_key, base_url=selected_endpoint, timeout=120.0)
443
  stream = client.chat.completions.create(
444
  model=fallback_model,
445
  messages=input_messages,
 
517
  except Exception as e2:
518
  logger.exception(f"[Gateway] Streaming failed for fallback model {fallback_model}: {e2}")
519
  try:
520
+ is_available, selected_api_key, selected_endpoint = check_model_availability(TERTIARY_MODEL_NAME, selected_api_key)
521
  if not is_available:
522
  yield f"Error: Tertiary model {TERTIARY_MODEL_NAME} is not available."
523
  return
524
+ client = OpenAI(api_key=selected_api_key, base_url=selected_endpoint, timeout=120.0)
525
  stream = client.chat.completions.create(
526
  model=TERTIARY_MODEL_NAME,
527
  messages=input_messages,