github-actions[bot] committed on
Commit
1af7678
·
1 Parent(s): 4c52012

🚀 Auto-deploy backend from GitHub (3d8979f)

Browse files
Files changed (2) hide show
  1. services/ai_client.py +13 -1
  2. services/inference_client.py +185 -3
services/ai_client.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
- from openai import OpenAI, APIError, RateLimitError, APITimeoutError
3
  from functools import lru_cache
4
 
5
  __all__ = [
6
  "get_deepseek_client",
 
7
  "CHAT_MODEL",
8
  "REASONER_MODEL",
9
  "DEEPSEEK_BASE_URL",
@@ -25,4 +26,15 @@ def get_deepseek_client() -> OpenAI:
25
  return OpenAI(
26
  api_key=api_key,
27
  base_url=DEEPSEEK_BASE_URL,
 
 
 
 
 
 
 
 
 
 
 
28
  )
 
1
  import os
2
+ from openai import OpenAI, AsyncOpenAI, APIError, RateLimitError, APITimeoutError
3
  from functools import lru_cache
4
 
5
  __all__ = [
6
  "get_deepseek_client",
7
+ "get_async_deepseek_client",
8
  "CHAT_MODEL",
9
  "REASONER_MODEL",
10
  "DEEPSEEK_BASE_URL",
 
26
  return OpenAI(
27
  api_key=api_key,
28
  base_url=DEEPSEEK_BASE_URL,
29
+ )
30
+
31
+
32
@lru_cache(maxsize=1)
def get_async_deepseek_client() -> AsyncOpenAI:
    """Build the asynchronous DeepSeek client.

    The API key is read from the ``DEEPSEEK_API_KEY`` environment variable
    and the client is pointed at ``DEEPSEEK_BASE_URL``. Because the function
    takes no arguments and is wrapped in ``lru_cache(maxsize=1)``, the first
    successfully built client is reused by later calls.

    Raises:
        ValueError: if ``DEEPSEEK_API_KEY`` is unset or empty.
    """
    key = os.getenv("DEEPSEEK_API_KEY")
    if key:
        return AsyncOpenAI(api_key=key, base_url=DEEPSEEK_BASE_URL)
    raise ValueError("DEEPSEEK_API_KEY environment variable not set")
services/inference_client.py CHANGED
@@ -6,13 +6,13 @@ import random
6
  from threading import Lock
7
  from dataclasses import dataclass
8
  from pathlib import Path
9
- from typing import Any, Dict, List, Optional, Tuple
10
 
11
  import requests
12
  import yaml
13
- from openai import OpenAI, APIError, RateLimitError, APITimeoutError
14
 
15
- from .ai_client import get_deepseek_client, CHAT_MODEL, REASONER_MODEL, DEEPSEEK_BASE_URL
16
  from .logging_utils import configure_structured_logging, log_model_call
17
 
18
  LOGGER = configure_structured_logging("mathpulse.inference")
@@ -271,6 +271,13 @@ class InferenceClient:
271
  primary = primary_cfg
272
 
273
  self.provider = "deepseek"
 
 
 
 
 
 
 
274
  self.ds_api_key = os.getenv("DEEPSEEK_API_KEY", "")
275
  self.ds_base_url = os.getenv("DEEPSEEK_BASE_URL", DEEPSEEK_BASE_URL)
276
  self.ds_chat_model = os.getenv("DEEPSEEK_MODEL", CHAT_MODEL)
@@ -674,6 +681,10 @@ class InferenceClient:
674
  return self.interactive_timeout_sec
675
  return self.background_timeout_sec
676
 
 
 
 
 
677
  def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
678
  parts: List[str] = []
679
  for msg in messages:
@@ -880,6 +891,177 @@ class InferenceClient:
880
 
881
  raise RuntimeError(f"DeepSeek call failed after {max_retries} attempts")
882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883
  def _call_local_space(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
884
  target_model = req.model or self.default_model
885
  url = f"{self.local_space_url.rstrip('/')}{self.local_generate_path}"
 
6
  from threading import Lock
7
  from dataclasses import dataclass
8
  from pathlib import Path
9
+ from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple
10
 
11
  import requests
12
  import yaml
13
+ from openai import OpenAI, AsyncOpenAI, APIError, RateLimitError, APITimeoutError
14
 
15
+ from .ai_client import get_deepseek_client, get_async_deepseek_client, CHAT_MODEL, REASONER_MODEL, DEEPSEEK_BASE_URL
16
  from .logging_utils import configure_structured_logging, log_model_call
17
 
18
  LOGGER = configure_structured_logging("mathpulse.inference")
 
271
  primary = primary_cfg
272
 
273
  self.provider = "deepseek"
274
+ self.cpu_provider = os.getenv("INFERENCE_CPU_PROVIDER", "deepseek").strip().lower() or self.provider
275
+ self.gpu_provider = os.getenv("INFERENCE_GPU_PROVIDER", "deepseek").strip().lower() or self.provider
276
+ # Pro provider not used in current setup
277
+ self.pro_enabled = False
278
+ self.pro_provider = self.provider
279
+ self.pro_priority_tasks: Set[str] = set()
280
+ self.enable_provider_fallback = False
281
  self.ds_api_key = os.getenv("DEEPSEEK_API_KEY", "")
282
  self.ds_base_url = os.getenv("DEEPSEEK_BASE_URL", DEEPSEEK_BASE_URL)
283
  self.ds_chat_model = os.getenv("DEEPSEEK_MODEL", CHAT_MODEL)
 
681
  return self.interactive_timeout_sec
682
  return self.background_timeout_sec
683
 
684
+ def _provider_chain_for_task(self, task_type: str) -> List[str]:
685
+ """Return provider chain for task. All inference uses deepseek."""
686
+ return ["deepseek"]
687
+
688
  def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
689
  parts: List[str] = []
690
  for msg in messages:
 
891
 
892
  raise RuntimeError(f"DeepSeek call failed after {max_retries} attempts")
893
 
894
async def _call_deepseek_stream(self, req: InferenceRequest, fallback_depth: int) -> AsyncIterator[str]:
    """Stream a DeepSeek chat completion, yielding content chunks as they arrive.

    The request is sent through the OpenAI-compatible async client with
    ``stream=True``; every non-empty ``delta.content`` fragment is yielded
    to the caller in order.

    Failed attempts are retried with jittered linear backoff, but only
    while nothing has been yielded yet: once the consumer has received
    part of the answer, starting over would duplicate text, so the error
    is surfaced instead.

    Args:
        req: The inference request (messages, model, sampling parameters,
            task type and request tag are read from it).
        fallback_depth: Passed through to attempt recording and logging.

    Yields:
        Non-empty content fragments of the assistant message.

    Raises:
        RuntimeError: if ``DEEPSEEK_API_KEY`` is not configured, or when a
            rate-limit, timeout or other API error persists after the
            last permitted attempt.
        Exception: any non-API exception from the final attempt is
            re-raised unchanged.
    """
    # Function-scope import: the module-level import of asyncio is not
    # guaranteed from this block's point of view, and the backoff below
    # needs asyncio.sleep.
    import asyncio

    if not self.ds_api_key:
        raise RuntimeError("DEEPSEEK_API_KEY is not set")

    target_model = req.model or self.default_model
    route = "deepseek"
    task_type = req.task_type or "default"

    LOGGER.debug(
        f"📞 Streaming DeepSeek: task={task_type} model={target_model} route={route}"
    )

    timeout = self._timeout_for(req, "deepseek")
    max_retries, backoff_sec = self._retry_profile(task_type)

    client = get_async_deepseek_client()

    params: Dict[str, Any] = {
        "model": target_model,
        "messages": req.messages,
        "max_tokens": req.max_new_tokens or self.default_max_new_tokens,
        "stream": True,
    }

    if target_model == REASONER_MODEL:
        # Sampling parameters are deliberately not sent for the reasoner model.
        params["max_tokens"] = req.max_new_tokens or 1024
    else:
        params["temperature"] = req.temperature
        params["top_p"] = req.top_p

    def _log_attempt(status: str, attempt: int, started: float, **error_fields: str) -> None:
        # Single place for the structured log record shared by all outcomes.
        log_model_call(
            LOGGER,
            provider="deepseek",
            model=target_model,
            endpoint=self.ds_base_url,
            latency_ms=(time.perf_counter() - started) * 1000,
            input_tokens=None,
            output_tokens=None,
            status=status,
            task_type=task_type,
            request_tag=req.request_tag,
            retry_attempt=attempt + 1,
            fallback_depth=fallback_depth,
            route=route,
            **error_fields,
        )

    for attempt in range(max_retries):
        self._record_attempt(
            task_type=task_type,
            provider="deepseek",
            route=route,
            fallback_depth=fallback_depth,
        )
        start = time.perf_counter()
        emitted = False  # becomes True once any chunk reached the consumer
        try:
            # create(stream=True) returns an async iterator of completion
            # chunks; the text increment lives in choices[0].delta.content.
            stream = await client.chat.completions.create(**params, timeout=timeout)
            async for chunk in stream:
                if not chunk.choices:
                    # Chunks without choices carry no text.
                    continue
                piece = chunk.choices[0].delta.content
                if piece:
                    emitted = True
                    yield piece

        except Exception as exc:
            # RateLimitError and APITimeoutError are APIError subclasses,
            # so the most specific classes are tested first.
            wrapped: Optional[Exception]
            if isinstance(exc, RateLimitError):
                error_class, error_message = "RateLimitError", "rate limited"
                wrapped = RuntimeError("DeepSeek API rate limit reached. Please try again shortly.")
            elif isinstance(exc, APITimeoutError):
                error_class, error_message = "APITimeoutError", "timeout"
                wrapped = RuntimeError("DeepSeek API timed out. Please retry.")
            elif isinstance(exc, APIError):
                error_class, error_message = "APIError", str(exc)[:200]
                wrapped = RuntimeError(f"DeepSeek API error: {str(exc)}")
            else:
                error_class, error_message = exc.__class__.__name__, str(exc)[:200]
                wrapped = None  # unknown errors are re-raised unchanged

            _log_attempt(
                "error",
                attempt,
                start,
                error_class=error_class,
                error_message=error_message,
            )

            # Retrying after partial output would replay the answer from
            # the beginning and duplicate text for the consumer.
            if attempt < max_retries - 1 and not emitted:
                self._bump_metric("retries_total", 1)
                await asyncio.sleep(backoff_sec * (attempt + 1) * random.uniform(0.9, 1.2))
                continue

            self._bump_metric("requests_error", 1)
            if wrapped is None:
                raise
            raise wrapped from exc

        else:
            _log_attempt("ok", attempt, start)
            self._bump_metric("requests_ok", 1)
            return

    # Only reachable when the retry profile allows zero attempts.
    raise RuntimeError(f"DeepSeek stream failed after {max_retries} attempts")
1064
+
1065
  def _call_local_space(self, req: InferenceRequest, *, provider: str, route: str, fallback_depth: int) -> str:
1066
  target_model = req.model or self.default_model
1067
  url = f"{self.local_space_url.rstrip('/')}{self.local_generate_path}"