Spaces:
Running
Running
github-actions[bot] commited on
Commit ·
1af7678
1
Parent(s): 4c52012
🚀 Auto-deploy backend from GitHub (3d8979f)
Browse files- services/ai_client.py +13 -1
- services/inference_client.py +185 -3
services/ai_client.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
import os
|
| 2 |
-
from openai import OpenAI, APIError, RateLimitError, APITimeoutError
|
| 3 |
from functools import lru_cache
|
| 4 |
|
| 5 |
__all__ = [
|
| 6 |
"get_deepseek_client",
|
|
|
|
| 7 |
"CHAT_MODEL",
|
| 8 |
"REASONER_MODEL",
|
| 9 |
"DEEPSEEK_BASE_URL",
|
|
@@ -25,4 +26,15 @@ def get_deepseek_client() -> OpenAI:
|
|
| 25 |
return OpenAI(
|
| 26 |
api_key=api_key,
|
| 27 |
base_url=DEEPSEEK_BASE_URL,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
)
|
|
|
|
| 1 |
import os
|
| 2 |
+
from openai import OpenAI, AsyncOpenAI, APIError, RateLimitError, APITimeoutError
|
| 3 |
from functools import lru_cache
|
| 4 |
|
| 5 |
__all__ = [
|
| 6 |
"get_deepseek_client",
|
| 7 |
+
"get_async_deepseek_client",
|
| 8 |
"CHAT_MODEL",
|
| 9 |
"REASONER_MODEL",
|
| 10 |
"DEEPSEEK_BASE_URL",
|
|
|
|
| 26 |
return OpenAI(
|
| 27 |
api_key=api_key,
|
| 28 |
base_url=DEEPSEEK_BASE_URL,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@lru_cache(maxsize=1)
|
| 33 |
+
def get_async_deepseek_client() -> AsyncOpenAI:
|
| 34 |
+
api_key = os.getenv("DEEPSEEK_API_KEY")
|
| 35 |
+
if not api_key:
|
| 36 |
+
raise ValueError("DEEPSEEK_API_KEY environment variable not set")
|
| 37 |
+
return AsyncOpenAI(
|
| 38 |
+
api_key=api_key,
|
| 39 |
+
base_url=DEEPSEEK_BASE_URL,
|
| 40 |
)
|
services/inference_client.py
CHANGED
|
@@ -6,13 +6,13 @@ import random
|
|
| 6 |
from threading import Lock
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import Any, Dict, List, Optional, Tuple
|
| 10 |
|
| 11 |
import requests
|
| 12 |
import yaml
|
| 13 |
-
from openai import OpenAI, APIError, RateLimitError, APITimeoutError
|
| 14 |
|
| 15 |
-
from .ai_client import get_deepseek_client, CHAT_MODEL, REASONER_MODEL, DEEPSEEK_BASE_URL
|
| 16 |
from .logging_utils import configure_structured_logging, log_model_call
|
| 17 |
|
| 18 |
LOGGER = configure_structured_logging("mathpulse.inference")
|
|
@@ -271,6 +271,13 @@ class InferenceClient:
|
|
| 271 |
primary = primary_cfg
|
| 272 |
|
| 273 |
self.provider = "deepseek"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
self.ds_api_key = os.getenv("DEEPSEEK_API_KEY", "")
|
| 275 |
self.ds_base_url = os.getenv("DEEPSEEK_BASE_URL", DEEPSEEK_BASE_URL)
|
| 276 |
self.ds_chat_model = os.getenv("DEEPSEEK_MODEL", CHAT_MODEL)
|
|
@@ -674,6 +681,10 @@ class InferenceClient:
|
|
| 674 |
return self.interactive_timeout_sec
|
| 675 |
return self.background_timeout_sec
|
| 676 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
|
| 678 |
parts: List[str] = []
|
| 679 |
for msg in messages:
|
|
@@ -880,6 +891,177 @@ class InferenceClient:
|
|
| 880 |
|
| 881 |
raise RuntimeError(f"DeepSeek call failed after {max_retries} attempts")
|
| 882 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 883 |
def _call_local_space(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
|
| 884 |
target_model = req.model or self.default_model
|
| 885 |
url = f"{self.local_space_url.rstrip('/')}{self.local_generate_path}"
|
|
|
|
| 6 |
from threading import Lock
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple
|
| 10 |
|
| 11 |
import requests
|
| 12 |
import yaml
|
| 13 |
+
from openai import OpenAI, AsyncOpenAI, APIError, RateLimitError, APITimeoutError
|
| 14 |
|
| 15 |
+
from .ai_client import get_deepseek_client, get_async_deepseek_client, CHAT_MODEL, REASONER_MODEL, DEEPSEEK_BASE_URL
|
| 16 |
from .logging_utils import configure_structured_logging, log_model_call
|
| 17 |
|
| 18 |
LOGGER = configure_structured_logging("mathpulse.inference")
|
|
|
|
| 271 |
primary = primary_cfg
|
| 272 |
|
| 273 |
self.provider = "deepseek"
|
| 274 |
+
self.cpu_provider = os.getenv("INFERENCE_CPU_PROVIDER", "deepseek").strip().lower() or self.provider
|
| 275 |
+
self.gpu_provider = os.getenv("INFERENCE_GPU_PROVIDER", "deepseek").strip().lower() or self.provider
|
| 276 |
+
# Pro provider not used in current setup
|
| 277 |
+
self.pro_enabled = False
|
| 278 |
+
self.pro_provider = self.provider
|
| 279 |
+
self.pro_priority_tasks: Set[str] = set()
|
| 280 |
+
self.enable_provider_fallback = False
|
| 281 |
self.ds_api_key = os.getenv("DEEPSEEK_API_KEY", "")
|
| 282 |
self.ds_base_url = os.getenv("DEEPSEEK_BASE_URL", DEEPSEEK_BASE_URL)
|
| 283 |
self.ds_chat_model = os.getenv("DEEPSEEK_MODEL", CHAT_MODEL)
|
|
|
|
| 681 |
return self.interactive_timeout_sec
|
| 682 |
return self.background_timeout_sec
|
| 683 |
|
| 684 |
+
def _provider_chain_for_task(self, task_type: str) -> List[str]:
|
| 685 |
+
"""Return provider chain for task. All inference uses deepseek."""
|
| 686 |
+
return ["deepseek"]
|
| 687 |
+
|
| 688 |
def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
|
| 689 |
parts: List[str] = []
|
| 690 |
for msg in messages:
|
|
|
|
| 891 |
|
| 892 |
raise RuntimeError(f"DeepSeek call failed after {max_retries} attempts")
|
| 893 |
|
| 894 |
+
async def _call_deepseek_stream(self, req: InferenceRequest, fallback_depth: int) -> AsyncIterator[str]:
|
| 895 |
+
"""Stream DeepSeek API with OpenAI-compatible chat completions. Yields content chunks."""
|
| 896 |
+
if not self.ds_api_key:
|
| 897 |
+
raise RuntimeError("DEEPSEEK_API_KEY is not set")
|
| 898 |
+
|
| 899 |
+
target_model = req.model or self.default_model
|
| 900 |
+
route = "deepseek"
|
| 901 |
+
task_type = req.task_type or "default"
|
| 902 |
+
|
| 903 |
+
LOGGER.debug(
|
| 904 |
+
f"📞 Streaming DeepSeek: task={task_type} model={target_model} route={route}"
|
| 905 |
+
)
|
| 906 |
+
|
| 907 |
+
timeout = self._timeout_for(req, "deepseek")
|
| 908 |
+
max_retries, backoff_sec = self._retry_profile(task_type)
|
| 909 |
+
|
| 910 |
+
client = get_async_deepseek_client()
|
| 911 |
+
|
| 912 |
+
params: Dict[str, Any] = {
|
| 913 |
+
"model": target_model,
|
| 914 |
+
"messages": req.messages,
|
| 915 |
+
"max_tokens": req.max_new_tokens or self.default_max_new_tokens,
|
| 916 |
+
"stream": True,
|
| 917 |
+
}
|
| 918 |
+
|
| 919 |
+
if target_model == REASONER_MODEL:
|
| 920 |
+
params["max_tokens"] = req.max_new_tokens or 1024
|
| 921 |
+
else:
|
| 922 |
+
params["temperature"] = req.temperature
|
| 923 |
+
params["top_p"] = req.top_p
|
| 924 |
+
|
| 925 |
+
last_error: Optional[Exception] = None
|
| 926 |
+
for attempt in range(max_retries):
|
| 927 |
+
self._record_attempt(
|
| 928 |
+
task_type=task_type,
|
| 929 |
+
provider="deepseek",
|
| 930 |
+
route=route,
|
| 931 |
+
fallback_depth=fallback_depth,
|
| 932 |
+
)
|
| 933 |
+
start = time.perf_counter()
|
| 934 |
+
try:
|
| 935 |
+
async with client.chat.completions.stream(**params, timeout=timeout) as stream:
|
| 936 |
+
async for event in stream:
|
| 937 |
+
if event.type == "content.delta" and event.content:
|
| 938 |
+
yield event.content
|
| 939 |
+
|
| 940 |
+
latency_ms = (time.perf_counter() - start) * 1000
|
| 941 |
+
log_model_call(
|
| 942 |
+
LOGGER,
|
| 943 |
+
provider="deepseek",
|
| 944 |
+
model=target_model,
|
| 945 |
+
endpoint=self.ds_base_url,
|
| 946 |
+
latency_ms=latency_ms,
|
| 947 |
+
input_tokens=None,
|
| 948 |
+
output_tokens=None,
|
| 949 |
+
status="ok",
|
| 950 |
+
task_type=task_type,
|
| 951 |
+
request_tag=req.request_tag,
|
| 952 |
+
retry_attempt=attempt + 1,
|
| 953 |
+
fallback_depth=fallback_depth,
|
| 954 |
+
route=route,
|
| 955 |
+
)
|
| 956 |
+
self._bump_metric("requests_ok", 1)
|
| 957 |
+
return
|
| 958 |
+
|
| 959 |
+
except RateLimitError:
|
| 960 |
+
latency_ms = (time.perf_counter() - start) * 1000
|
| 961 |
+
if attempt < max_retries - 1:
|
| 962 |
+
log_model_call(
|
| 963 |
+
LOGGER,
|
| 964 |
+
provider="deepseek",
|
| 965 |
+
model=target_model,
|
| 966 |
+
endpoint=self.ds_base_url,
|
| 967 |
+
latency_ms=latency_ms,
|
| 968 |
+
input_tokens=None,
|
| 969 |
+
output_tokens=None,
|
| 970 |
+
status="error",
|
| 971 |
+
error_class="RateLimitError",
|
| 972 |
+
error_message="rate limited",
|
| 973 |
+
task_type=task_type,
|
| 974 |
+
request_tag=req.request_tag,
|
| 975 |
+
retry_attempt=attempt + 1,
|
| 976 |
+
fallback_depth=fallback_depth,
|
| 977 |
+
route=route,
|
| 978 |
+
)
|
| 979 |
+
self._bump_metric("retries_total", 1)
|
| 980 |
+
await asyncio.sleep(backoff_sec * (attempt + 1) * random.uniform(0.9, 1.2))
|
| 981 |
+
continue
|
| 982 |
+
self._bump_metric("requests_error", 1)
|
| 983 |
+
raise RuntimeError("DeepSeek API rate limit reached. Please try again shortly.")
|
| 984 |
+
|
| 985 |
+
except APITimeoutError:
|
| 986 |
+
latency_ms = (time.perf_counter() - start) * 1000
|
| 987 |
+
if attempt < max_retries - 1:
|
| 988 |
+
log_model_call(
|
| 989 |
+
LOGGER,
|
| 990 |
+
provider="deepseek",
|
| 991 |
+
model=target_model,
|
| 992 |
+
endpoint=self.ds_base_url,
|
| 993 |
+
latency_ms=latency_ms,
|
| 994 |
+
input_tokens=None,
|
| 995 |
+
output_tokens=None,
|
| 996 |
+
status="error",
|
| 997 |
+
error_class="APITimeoutError",
|
| 998 |
+
error_message="timeout",
|
| 999 |
+
task_type=task_type,
|
| 1000 |
+
request_tag=req.request_tag,
|
| 1001 |
+
retry_attempt=attempt + 1,
|
| 1002 |
+
fallback_depth=fallback_depth,
|
| 1003 |
+
route=route,
|
| 1004 |
+
)
|
| 1005 |
+
self._bump_metric("retries_total", 1)
|
| 1006 |
+
await asyncio.sleep(backoff_sec * (attempt + 1) * random.uniform(0.9, 1.2))
|
| 1007 |
+
continue
|
| 1008 |
+
self._bump_metric("requests_error", 1)
|
| 1009 |
+
raise RuntimeError("DeepSeek API timed out. Please retry.")
|
| 1010 |
+
|
| 1011 |
+
except APIError as e:
|
| 1012 |
+
latency_ms = (time.perf_counter() - start) * 1000
|
| 1013 |
+
if attempt < max_retries - 1:
|
| 1014 |
+
log_model_call(
|
| 1015 |
+
LOGGER,
|
| 1016 |
+
provider="deepseek",
|
| 1017 |
+
model=target_model,
|
| 1018 |
+
endpoint=self.ds_base_url,
|
| 1019 |
+
latency_ms=latency_ms,
|
| 1020 |
+
input_tokens=None,
|
| 1021 |
+
output_tokens=None,
|
| 1022 |
+
status="error",
|
| 1023 |
+
error_class="APIError",
|
| 1024 |
+
error_message=str(e)[:200],
|
| 1025 |
+
task_type=task_type,
|
| 1026 |
+
request_tag=req.request_tag,
|
| 1027 |
+
retry_attempt=attempt + 1,
|
| 1028 |
+
fallback_depth=fallback_depth,
|
| 1029 |
+
route=route,
|
| 1030 |
+
)
|
| 1031 |
+
self._bump_metric("retries_total", 1)
|
| 1032 |
+
await asyncio.sleep(backoff_sec * (attempt + 1) * random.uniform(0.9, 1.2))
|
| 1033 |
+
continue
|
| 1034 |
+
self._bump_metric("requests_error", 1)
|
| 1035 |
+
raise RuntimeError(f"DeepSeek API error: {str(e)}")
|
| 1036 |
+
|
| 1037 |
+
except Exception as exc:
|
| 1038 |
+
latency_ms = (time.perf_counter() - start) * 1000
|
| 1039 |
+
self._bump_metric("requests_error", 1)
|
| 1040 |
+
last_error = exc
|
| 1041 |
+
log_model_call(
|
| 1042 |
+
LOGGER,
|
| 1043 |
+
provider="deepseek",
|
| 1044 |
+
model=target_model,
|
| 1045 |
+
endpoint=self.ds_base_url,
|
| 1046 |
+
latency_ms=latency_ms,
|
| 1047 |
+
input_tokens=None,
|
| 1048 |
+
output_tokens=None,
|
| 1049 |
+
status="error",
|
| 1050 |
+
error_class=exc.__class__.__name__,
|
| 1051 |
+
error_message=str(exc)[:200],
|
| 1052 |
+
task_type=task_type,
|
| 1053 |
+
request_tag=req.request_tag,
|
| 1054 |
+
retry_attempt=attempt + 1,
|
| 1055 |
+
fallback_depth=fallback_depth,
|
| 1056 |
+
route=route,
|
| 1057 |
+
)
|
| 1058 |
+
if attempt < max_retries - 1:
|
| 1059 |
+
await asyncio.sleep(backoff_sec * (attempt + 1) * random.uniform(0.9, 1.2))
|
| 1060 |
+
continue
|
| 1061 |
+
raise
|
| 1062 |
+
|
| 1063 |
+
raise last_error or RuntimeError(f"DeepSeek stream failed after {max_retries} attempts")
|
| 1064 |
+
|
| 1065 |
def _call_local_space(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
|
| 1066 |
target_model = req.model or self.default_model
|
| 1067 |
url = f"{self.local_space_url.rstrip('/')}{self.local_generate_path}"
|