Spaces:
Runtime error
Runtime error
Mark-Lasfar commited on
Commit ·
01237fb
1
Parent(s): 7f3503f
Update Model
Browse files- api/endpoints.py +17 -15
- utils/generation.py +47 -26
api/endpoints.py
CHANGED
|
@@ -28,10 +28,11 @@ BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
|
|
| 28 |
if not BACKUP_HF_TOKEN:
|
| 29 |
logger.warning("BACKUP_HF_TOKEN is not set. Fallback to secondary model will not work if primary token fails.")
|
| 30 |
|
|
|
|
| 31 |
API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
|
| 32 |
FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co")
|
| 33 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "
|
| 34 |
-
SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/
|
| 35 |
TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "gpt2")
|
| 36 |
CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
|
| 37 |
CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
|
|
@@ -116,7 +117,7 @@ async def model_info():
|
|
| 116 |
{"alias": "audio", "description": "Audio transcription model (default)"},
|
| 117 |
{"alias": "tts", "description": "Text-to-speech model (default)"}
|
| 118 |
],
|
| 119 |
-
"api_base":
|
| 120 |
"fallback_api_base": FALLBACK_API_ENDPOINT,
|
| 121 |
"status": "online"
|
| 122 |
}
|
|
@@ -182,13 +183,13 @@ async def chat_endpoint(
|
|
| 182 |
)
|
| 183 |
if req.output_format == "audio":
|
| 184 |
audio_chunks = []
|
| 185 |
-
for chunk in stream:
|
| 186 |
if isinstance(chunk, bytes):
|
| 187 |
audio_chunks.append(chunk)
|
| 188 |
audio_data = b"".join(audio_chunks)
|
| 189 |
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
|
| 190 |
response_chunks = []
|
| 191 |
-
for chunk in stream:
|
| 192 |
if isinstance(chunk, str):
|
| 193 |
response_chunks.append(chunk)
|
| 194 |
response = "".join(response_chunks)
|
|
@@ -255,7 +256,7 @@ async def audio_transcription_endpoint(
|
|
| 255 |
output_format="text"
|
| 256 |
)
|
| 257 |
response_chunks = []
|
| 258 |
-
for chunk in stream:
|
| 259 |
if isinstance(chunk, str):
|
| 260 |
response_chunks.append(chunk)
|
| 261 |
response = "".join(response_chunks)
|
|
@@ -300,7 +301,7 @@ async def text_to_speech_endpoint(
|
|
| 300 |
output_format="audio"
|
| 301 |
)
|
| 302 |
audio_chunks = []
|
| 303 |
-
for chunk in stream:
|
| 304 |
if isinstance(chunk, bytes):
|
| 305 |
audio_chunks.append(chunk)
|
| 306 |
audio_data = b"".join(audio_chunks)
|
|
@@ -340,13 +341,13 @@ async def code_endpoint(
|
|
| 340 |
)
|
| 341 |
if output_format == "audio":
|
| 342 |
audio_chunks = []
|
| 343 |
-
for chunk in stream:
|
| 344 |
if isinstance(chunk, bytes):
|
| 345 |
audio_chunks.append(chunk)
|
| 346 |
audio_data = b"".join(audio_chunks)
|
| 347 |
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
|
| 348 |
response_chunks = []
|
| 349 |
-
for chunk in stream:
|
| 350 |
if isinstance(chunk, str):
|
| 351 |
response_chunks.append(chunk)
|
| 352 |
response = "".join(response_chunks)
|
|
@@ -383,13 +384,13 @@ async def analysis_endpoint(
|
|
| 383 |
)
|
| 384 |
if output_format == "audio":
|
| 385 |
audio_chunks = []
|
| 386 |
-
for chunk in stream:
|
| 387 |
if isinstance(chunk, bytes):
|
| 388 |
audio_chunks.append(chunk)
|
| 389 |
audio_data = b"".join(audio_chunks)
|
| 390 |
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
|
| 391 |
response_chunks = []
|
| 392 |
-
for chunk in stream:
|
| 393 |
if isinstance(chunk, str):
|
| 394 |
response_chunks.append(chunk)
|
| 395 |
response = "".join(response_chunks)
|
|
@@ -446,13 +447,13 @@ async def image_analysis_endpoint(
|
|
| 446 |
)
|
| 447 |
if output_format == "audio":
|
| 448 |
audio_chunks = []
|
| 449 |
-
for chunk in stream:
|
| 450 |
if isinstance(chunk, bytes):
|
| 451 |
audio_chunks.append(chunk)
|
| 452 |
audio_data = b"".join(audio_chunks)
|
| 453 |
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
|
| 454 |
response_chunks = []
|
| 455 |
-
for chunk in stream:
|
| 456 |
if isinstance(chunk, str):
|
| 457 |
response_chunks.append(chunk)
|
| 458 |
response = "".join(response_chunks)
|
|
@@ -473,9 +474,10 @@ async def image_analysis_endpoint(
|
|
| 473 |
return {"image_analysis": response}
|
| 474 |
|
| 475 |
@router.get("/api/test-model")
|
| 476 |
-
async def test_model(model: str = MODEL_NAME, endpoint: str =
|
| 477 |
try:
|
| 478 |
-
|
|
|
|
| 479 |
response = client.chat.completions.create(
|
| 480 |
model=model,
|
| 481 |
messages=[{"role": "user", "content": "Test"}],
|
|
|
|
| 28 |
if not BACKUP_HF_TOKEN:
|
| 29 |
logger.warning("BACKUP_HF_TOKEN is not set. Fallback to secondary model will not work if primary token fails.")
|
| 30 |
|
| 31 |
+
ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
|
| 32 |
API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
|
| 33 |
FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co")
|
| 34 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b") # Updated to target model
|
| 35 |
+
SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
|
| 36 |
TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "gpt2")
|
| 37 |
CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
|
| 38 |
CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
|
|
|
|
| 117 |
{"alias": "audio", "description": "Audio transcription model (default)"},
|
| 118 |
{"alias": "tts", "description": "Text-to-speech model (default)"}
|
| 119 |
],
|
| 120 |
+
"api_base": ROUTER_API_URL,
|
| 121 |
"fallback_api_base": FALLBACK_API_ENDPOINT,
|
| 122 |
"status": "online"
|
| 123 |
}
|
|
|
|
| 183 |
)
|
| 184 |
if req.output_format == "audio":
|
| 185 |
audio_chunks = []
|
| 186 |
+
for chunk in stream:
|
| 187 |
if isinstance(chunk, bytes):
|
| 188 |
audio_chunks.append(chunk)
|
| 189 |
audio_data = b"".join(audio_chunks)
|
| 190 |
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
|
| 191 |
response_chunks = []
|
| 192 |
+
for chunk in stream:
|
| 193 |
if isinstance(chunk, str):
|
| 194 |
response_chunks.append(chunk)
|
| 195 |
response = "".join(response_chunks)
|
|
|
|
| 256 |
output_format="text"
|
| 257 |
)
|
| 258 |
response_chunks = []
|
| 259 |
+
for chunk in stream:
|
| 260 |
if isinstance(chunk, str):
|
| 261 |
response_chunks.append(chunk)
|
| 262 |
response = "".join(response_chunks)
|
|
|
|
| 301 |
output_format="audio"
|
| 302 |
)
|
| 303 |
audio_chunks = []
|
| 304 |
+
for chunk in stream:
|
| 305 |
if isinstance(chunk, bytes):
|
| 306 |
audio_chunks.append(chunk)
|
| 307 |
audio_data = b"".join(audio_chunks)
|
|
|
|
| 341 |
)
|
| 342 |
if output_format == "audio":
|
| 343 |
audio_chunks = []
|
| 344 |
+
for chunk in stream:
|
| 345 |
if isinstance(chunk, bytes):
|
| 346 |
audio_chunks.append(chunk)
|
| 347 |
audio_data = b"".join(audio_chunks)
|
| 348 |
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
|
| 349 |
response_chunks = []
|
| 350 |
+
for chunk in stream:
|
| 351 |
if isinstance(chunk, str):
|
| 352 |
response_chunks.append(chunk)
|
| 353 |
response = "".join(response_chunks)
|
|
|
|
| 384 |
)
|
| 385 |
if output_format == "audio":
|
| 386 |
audio_chunks = []
|
| 387 |
+
for chunk in stream:
|
| 388 |
if isinstance(chunk, bytes):
|
| 389 |
audio_chunks.append(chunk)
|
| 390 |
audio_data = b"".join(audio_chunks)
|
| 391 |
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
|
| 392 |
response_chunks = []
|
| 393 |
+
for chunk in stream:
|
| 394 |
if isinstance(chunk, str):
|
| 395 |
response_chunks.append(chunk)
|
| 396 |
response = "".join(response_chunks)
|
|
|
|
| 447 |
)
|
| 448 |
if output_format == "audio":
|
| 449 |
audio_chunks = []
|
| 450 |
+
for chunk in stream:
|
| 451 |
if isinstance(chunk, bytes):
|
| 452 |
audio_chunks.append(chunk)
|
| 453 |
audio_data = b"".join(audio_chunks)
|
| 454 |
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
|
| 455 |
response_chunks = []
|
| 456 |
+
for chunk in stream:
|
| 457 |
if isinstance(chunk, str):
|
| 458 |
response_chunks.append(chunk)
|
| 459 |
response = "".join(response_chunks)
|
|
|
|
| 474 |
return {"image_analysis": response}
|
| 475 |
|
| 476 |
@router.get("/api/test-model")
|
| 477 |
+
async def test_model(model: str = MODEL_NAME, endpoint: str = ROUTER_API_URL):
|
| 478 |
try:
|
| 479 |
+
_, api_key, selected_endpoint = check_model_availability(model, HF_TOKEN)
|
| 480 |
+
client = OpenAI(api_key=api_key, base_url=selected_endpoint, timeout=60.0)
|
| 481 |
response = client.chat.completions.create(
|
| 482 |
model=model,
|
| 483 |
messages=[{"role": "user", "content": "Test"}],
|
utils/generation.py
CHANGED
|
@@ -30,19 +30,32 @@ LATEX_DELIMS = [
|
|
| 30 |
{"left": "\\(", "right": "\\)", "display": False},
|
| 31 |
]
|
| 32 |
|
| 33 |
-
# إعداد العميل لـ Hugging Face
|
| 34 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 35 |
BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
|
|
|
|
| 36 |
API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
|
| 37 |
FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co")
|
| 38 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "
|
| 39 |
-
SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/
|
| 40 |
TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "gpt2")
|
| 41 |
CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
|
| 42 |
CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
|
| 43 |
ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
|
| 44 |
TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
# Model alias mapping
|
| 47 |
MODEL_ALIASES = {
|
| 48 |
"advanced": MODEL_NAME,
|
|
@@ -54,37 +67,45 @@ MODEL_ALIASES = {
|
|
| 54 |
"tts": TTS_MODEL
|
| 55 |
}
|
| 56 |
|
| 57 |
-
def check_model_availability(model_name: str,
|
| 58 |
try:
|
| 59 |
response = requests.get(
|
| 60 |
-
f"{
|
| 61 |
headers={"Authorization": f"Bearer {api_key}"},
|
| 62 |
-
timeout=30
|
| 63 |
)
|
| 64 |
if response.status_code == 200:
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
elif response.status_code == 429 and BACKUP_HF_TOKEN and api_key != BACKUP_HF_TOKEN:
|
| 68 |
logger.warning(f"Rate limit reached for token {api_key}. Switching to backup token.")
|
| 69 |
-
return check_model_availability(model_name,
|
| 70 |
logger.error(f"Model {model_name} not available: {response.status_code} - {response.text}")
|
| 71 |
-
return False, api_key
|
| 72 |
except Exception as e:
|
| 73 |
logger.error(f"Failed to check model availability for {model_name}: {e}")
|
| 74 |
if BACKUP_HF_TOKEN and api_key != BACKUP_HF_TOKEN:
|
| 75 |
logger.warning(f"Retrying with backup token for {model_name}")
|
| 76 |
-
return check_model_availability(model_name,
|
| 77 |
-
return False, api_key
|
| 78 |
|
| 79 |
def select_model(query: str, input_type: str = "text", preferred_model: Optional[str] = None) -> tuple[str, str]:
|
| 80 |
# If user has a preferred model, use it unless the input type requires a specific model
|
| 81 |
if preferred_model and preferred_model in MODEL_ALIASES:
|
| 82 |
model_name = MODEL_ALIASES[preferred_model]
|
| 83 |
-
|
| 84 |
-
is_available, _ = check_model_availability(model_name, api_endpoint, HF_TOKEN)
|
| 85 |
if is_available:
|
| 86 |
-
logger.info(f"Selected preferred model {model_name} with endpoint {
|
| 87 |
-
return model_name,
|
| 88 |
|
| 89 |
query_lower = query.lower()
|
| 90 |
# دعم الصوت
|
|
@@ -111,10 +132,10 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
|
|
| 111 |
(TERTIARY_MODEL_NAME, API_ENDPOINT)
|
| 112 |
]
|
| 113 |
for model_name, api_endpoint in available_models:
|
| 114 |
-
is_available, _ = check_model_availability(model_name,
|
| 115 |
if is_available:
|
| 116 |
-
logger.info(f"Selected {model_name} with endpoint {
|
| 117 |
-
return model_name,
|
| 118 |
logger.error("No models available. Falling back to default.")
|
| 119 |
return MODEL_NAME, API_ENDPOINT
|
| 120 |
|
|
@@ -137,7 +158,7 @@ def request_generation(
|
|
| 137 |
image_data: Optional[bytes] = None,
|
| 138 |
output_format: str = "text"
|
| 139 |
) -> Generator[bytes | str, None, None]:
|
| 140 |
-
is_available, selected_api_key = check_model_availability(model_name,
|
| 141 |
if not is_available:
|
| 142 |
yield f"Error: Model {model_name} is not available. Please check the model endpoint or token."
|
| 143 |
return
|
|
@@ -158,7 +179,7 @@ def request_generation(
|
|
| 158 |
yield chunk
|
| 159 |
return
|
| 160 |
|
| 161 |
-
client = OpenAI(api_key=selected_api_key, base_url=
|
| 162 |
task_type = "general"
|
| 163 |
enhanced_system_prompt = system_prompt
|
| 164 |
|
|
@@ -391,7 +412,7 @@ def request_generation(
|
|
| 391 |
logger.warning(f"Retrying with backup token for model {model_name}")
|
| 392 |
for chunk in request_generation(
|
| 393 |
api_key=BACKUP_HF_TOKEN,
|
| 394 |
-
api_base=
|
| 395 |
message=message,
|
| 396 |
system_prompt=system_prompt,
|
| 397 |
model_name=model_name,
|
|
@@ -414,11 +435,11 @@ def request_generation(
|
|
| 414 |
fallback_endpoint = FALLBACK_API_ENDPOINT
|
| 415 |
logger.info(f"Retrying with fallback model: {fallback_model} on {fallback_endpoint}")
|
| 416 |
try:
|
| 417 |
-
is_available, selected_api_key = check_model_availability(fallback_model,
|
| 418 |
if not is_available:
|
| 419 |
yield f"Error: Fallback model {fallback_model} is not available."
|
| 420 |
return
|
| 421 |
-
client = OpenAI(api_key=selected_api_key, base_url=
|
| 422 |
stream = client.chat.completions.create(
|
| 423 |
model=fallback_model,
|
| 424 |
messages=input_messages,
|
|
@@ -496,11 +517,11 @@ def request_generation(
|
|
| 496 |
except Exception as e2:
|
| 497 |
logger.exception(f"[Gateway] Streaming failed for fallback model {fallback_model}: {e2}")
|
| 498 |
try:
|
| 499 |
-
is_available, selected_api_key = check_model_availability(TERTIARY_MODEL_NAME,
|
| 500 |
if not is_available:
|
| 501 |
yield f"Error: Tertiary model {TERTIARY_MODEL_NAME} is not available."
|
| 502 |
return
|
| 503 |
-
client = OpenAI(api_key=selected_api_key, base_url=
|
| 504 |
stream = client.chat.completions.create(
|
| 505 |
model=TERTIARY_MODEL_NAME,
|
| 506 |
messages=input_messages,
|
|
|
|
| 30 |
{"left": "\\(", "right": "\\)", "display": False},
|
| 31 |
]
|
| 32 |
|
| 33 |
+
# إعداد العميل لـ Hugging Face Router API
|
| 34 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 35 |
BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
|
| 36 |
+
ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
|
| 37 |
API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
|
| 38 |
FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co")
|
| 39 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b") # Updated to target model
|
| 40 |
+
SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
|
| 41 |
TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "gpt2")
|
| 42 |
CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
|
| 43 |
CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
|
| 44 |
ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
|
| 45 |
TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
|
| 46 |
|
| 47 |
+
# Provider endpoints (based on Router API providers)
|
| 48 |
+
PROVIDER_ENDPOINTS = {
|
| 49 |
+
"together": "https://api.together.xyz/v1",
|
| 50 |
+
"fireworks-ai": "https://api.fireworks.ai/inference/v1",
|
| 51 |
+
"nebius": "https://api.nebius.ai/v1",
|
| 52 |
+
"novita": "https://api.novita.ai/v1",
|
| 53 |
+
"groq": "https://api.groq.com/openai/v1",
|
| 54 |
+
"cerebras": "https://api.cerebras.ai/v1",
|
| 55 |
+
"hyperbolic": "https://api.hyperbolic.xyz/v1",
|
| 56 |
+
"nscale": "https://api.nscale.ai/v1"
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
# Model alias mapping
|
| 60 |
MODEL_ALIASES = {
|
| 61 |
"advanced": MODEL_NAME,
|
|
|
|
| 67 |
"tts": TTS_MODEL
|
| 68 |
}
|
| 69 |
|
| 70 |
+
def check_model_availability(model_name: str, api_key: str) -> tuple[bool, str, str]:
|
| 71 |
try:
|
| 72 |
response = requests.get(
|
| 73 |
+
f"{ROUTER_API_URL}/v1/models/{model_name}",
|
| 74 |
headers={"Authorization": f"Bearer {api_key}"},
|
| 75 |
+
timeout=30
|
| 76 |
)
|
| 77 |
if response.status_code == 200:
|
| 78 |
+
data = response.json().get("data", {})
|
| 79 |
+
providers = data.get("providers", [])
|
| 80 |
+
# Select the first available provider (e.g., 'together')
|
| 81 |
+
for provider in providers:
|
| 82 |
+
if provider.get("status") == "live":
|
| 83 |
+
provider_name = provider.get("provider")
|
| 84 |
+
endpoint = PROVIDER_ENDPOINTS.get(provider_name, API_ENDPOINT)
|
| 85 |
+
logger.info(f"Model {model_name} is available via provider {provider_name} at {endpoint}")
|
| 86 |
+
return True, api_key, endpoint
|
| 87 |
+
logger.error(f"No live providers found for model {model_name}")
|
| 88 |
+
return False, api_key, API_ENDPOINT
|
| 89 |
elif response.status_code == 429 and BACKUP_HF_TOKEN and api_key != BACKUP_HF_TOKEN:
|
| 90 |
logger.warning(f"Rate limit reached for token {api_key}. Switching to backup token.")
|
| 91 |
+
return check_model_availability(model_name, BACKUP_HF_TOKEN)
|
| 92 |
logger.error(f"Model {model_name} not available: {response.status_code} - {response.text}")
|
| 93 |
+
return False, api_key, API_ENDPOINT
|
| 94 |
except Exception as e:
|
| 95 |
logger.error(f"Failed to check model availability for {model_name}: {e}")
|
| 96 |
if BACKUP_HF_TOKEN and api_key != BACKUP_HF_TOKEN:
|
| 97 |
logger.warning(f"Retrying with backup token for {model_name}")
|
| 98 |
+
return check_model_availability(model_name, BACKUP_HF_TOKEN)
|
| 99 |
+
return False, api_key, API_ENDPOINT
|
| 100 |
|
| 101 |
def select_model(query: str, input_type: str = "text", preferred_model: Optional[str] = None) -> tuple[str, str]:
|
| 102 |
# If user has a preferred model, use it unless the input type requires a specific model
|
| 103 |
if preferred_model and preferred_model in MODEL_ALIASES:
|
| 104 |
model_name = MODEL_ALIASES[preferred_model]
|
| 105 |
+
is_available, _, endpoint = check_model_availability(model_name, HF_TOKEN)
|
|
|
|
| 106 |
if is_available:
|
| 107 |
+
logger.info(f"Selected preferred model {model_name} with endpoint {endpoint} for query: {query}")
|
| 108 |
+
return model_name, endpoint
|
| 109 |
|
| 110 |
query_lower = query.lower()
|
| 111 |
# دعم الصوت
|
|
|
|
| 132 |
(TERTIARY_MODEL_NAME, API_ENDPOINT)
|
| 133 |
]
|
| 134 |
for model_name, api_endpoint in available_models:
|
| 135 |
+
is_available, _, endpoint = check_model_availability(model_name, HF_TOKEN)
|
| 136 |
if is_available:
|
| 137 |
+
logger.info(f"Selected {model_name} with endpoint {endpoint} for query: {query}")
|
| 138 |
+
return model_name, endpoint
|
| 139 |
logger.error("No models available. Falling back to default.")
|
| 140 |
return MODEL_NAME, API_ENDPOINT
|
| 141 |
|
|
|
|
| 158 |
image_data: Optional[bytes] = None,
|
| 159 |
output_format: str = "text"
|
| 160 |
) -> Generator[bytes | str, None, None]:
|
| 161 |
+
is_available, selected_api_key, selected_endpoint = check_model_availability(model_name, api_key)
|
| 162 |
if not is_available:
|
| 163 |
yield f"Error: Model {model_name} is not available. Please check the model endpoint or token."
|
| 164 |
return
|
|
|
|
| 179 |
yield chunk
|
| 180 |
return
|
| 181 |
|
| 182 |
+
client = OpenAI(api_key=selected_api_key, base_url=selected_endpoint, timeout=120.0)
|
| 183 |
task_type = "general"
|
| 184 |
enhanced_system_prompt = system_prompt
|
| 185 |
|
|
|
|
| 412 |
logger.warning(f"Retrying with backup token for model {model_name}")
|
| 413 |
for chunk in request_generation(
|
| 414 |
api_key=BACKUP_HF_TOKEN,
|
| 415 |
+
api_base=selected_endpoint,
|
| 416 |
message=message,
|
| 417 |
system_prompt=system_prompt,
|
| 418 |
model_name=model_name,
|
|
|
|
| 435 |
fallback_endpoint = FALLBACK_API_ENDPOINT
|
| 436 |
logger.info(f"Retrying with fallback model: {fallback_model} on {fallback_endpoint}")
|
| 437 |
try:
|
| 438 |
+
is_available, selected_api_key, selected_endpoint = check_model_availability(fallback_model, selected_api_key)
|
| 439 |
if not is_available:
|
| 440 |
yield f"Error: Fallback model {fallback_model} is not available."
|
| 441 |
return
|
| 442 |
+
client = OpenAI(api_key=selected_api_key, base_url=selected_endpoint, timeout=120.0)
|
| 443 |
stream = client.chat.completions.create(
|
| 444 |
model=fallback_model,
|
| 445 |
messages=input_messages,
|
|
|
|
| 517 |
except Exception as e2:
|
| 518 |
logger.exception(f"[Gateway] Streaming failed for fallback model {fallback_model}: {e2}")
|
| 519 |
try:
|
| 520 |
+
is_available, selected_api_key, selected_endpoint = check_model_availability(TERTIARY_MODEL_NAME, selected_api_key)
|
| 521 |
if not is_available:
|
| 522 |
yield f"Error: Tertiary model {TERTIARY_MODEL_NAME} is not available."
|
| 523 |
return
|
| 524 |
+
client = OpenAI(api_key=selected_api_key, base_url=selected_endpoint, timeout=120.0)
|
| 525 |
stream = client.chat.completions.create(
|
| 526 |
model=TERTIARY_MODEL_NAME,
|
| 527 |
messages=input_messages,
|