bluewinliang committed on
Commit
83eadcd
·
verified ·
1 Parent(s): 41a07df

Upload proxy_handler.py

Browse files
Files changed (1) hide show
  1. proxy_handler.py +62 -34
proxy_handler.py CHANGED
@@ -22,6 +22,7 @@ class ProxyHandler:
22
  limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
23
  http2=True,
24
  )
 
25
  self.primary_secret = "junjie".encode('utf-8')
26
 
27
  async def aclose(self):
@@ -32,46 +33,54 @@ class ProxyHandler:
32
  return int(time.time() * 1000)
33
 
34
  def _parse_jwt_token(self, token: str) -> Dict[str, str]:
 
35
  try:
36
  parts = token.split('.')
37
  if len(parts) != 3: return {"user_id": ""}
38
  payload_b64 = parts[1]
39
- payload_b64 += '=' * (-len(payload_b64) % 4)
40
  payload_json = base64.urlsafe_b64decode(payload_b64).decode('utf-8')
41
  payload = json.loads(payload_json)
42
  return {"user_id": payload.get("sub", "")}
43
  except Exception:
 
44
  return {"user_id": ""}
45
 
46
- def _generate_signature(self, e_payload: str, t_payload: str, timestamp_ms: int) -> Dict[str, Any]:
47
  """
48
  Generates the signature based on the logic from the reference JS code.
49
- This version corrects the HMAC chaining issue by using .digest() for the intermediate key.
50
 
51
  Args:
52
  e_payload (str): The simplified payload string (e.g., "requestId,...,timestamp,...").
53
  t_payload (str): The last message content.
54
- timestamp_ms (int): The consistent timestamp generated for the request.
55
 
56
  Returns:
57
  A dictionary with 'signature' and 'timestamp'.
58
  """
59
- b64_encoded_t = base64.b64encode(t_payload.encode("utf-8")).decode("utf-8")
60
- message_string = f"{e_payload}|{b64_encoded_t}|{timestamp_ms}"
61
- n = timestamp_ms // (5 * 60 * 1000)
 
 
 
 
 
 
 
 
 
62
 
63
- # --- MODIFICATION START: Correct HMAC Chaining ---
 
64
 
65
- # 1. First HMAC: Calculate intermediate key as RAW BYTES using .digest()
66
  msg1 = str(n).encode("utf-8")
67
- intermediate_key_bytes = hmac.new(self.primary_secret, msg1, hashlib.sha256).digest()
68
 
69
- # 2. Second HMAC: Use the raw bytes of the intermediate key directly.
70
- # The final result is converted to a hex string for the header.
71
  msg2 = message_string.encode("utf-8")
72
- final_signature = hmac.new(intermediate_key_bytes, msg2, hashlib.sha256).hexdigest()
73
-
74
- # --- MODIFICATION END ---
75
 
76
  return {"signature": final_signature, "timestamp": timestamp_ms}
77
 
@@ -86,14 +95,13 @@ class ProxyHandler:
86
 
87
  def _clean_answer_content(self, text: str) -> str:
88
  if not text: return ""
89
- cleaned_text = re.sub(r'<details[^>]*>.*?</details>', '', text, flags=re.DOTALL)
90
- cleaned_text = re.sub(r'<glm_block.*?</glm_block>|<summary>.*?</summary>', '', text, flags=re.DOTALL)
91
- cleaned_text = re.sub(r'\s*duration="\d+"[^>]*>', '', cleaned_text)
92
  return cleaned_text
93
 
94
  def _serialize_msgs(self, msgs) -> list:
95
  out = []
96
  for m in msgs:
 
97
  if hasattr(m, "dict"): out.append(m.dict())
98
  elif hasattr(m, "model_dump"): out.append(m.model_dump())
99
  elif isinstance(m, dict): out.append(m)
@@ -101,30 +109,39 @@ class ProxyHandler:
101
  return out
102
 
103
  async def _prep_upstream(self, req: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str], str, str]:
 
104
  ck = await cookie_manager.get_next_cookie()
105
  if not ck: raise HTTPException(503, "No available cookies")
106
 
107
  model = settings.UPSTREAM_MODEL if req.model == settings.MODEL_NAME else req.model
 
 
108
 
109
- timestamp_ms = self._get_timestamp_millis()
110
- payload_timestamp = str(timestamp_ms)
111
-
 
 
112
  payload_user_id = str(uuid.uuid4())
113
  payload_request_id = str(uuid.uuid4())
 
114
 
 
115
  e_payload = f"requestId,{payload_request_id},timestamp,{payload_timestamp},user_id,{payload_user_id}"
116
 
 
117
  t_payload = ""
118
  if req.messages:
119
  last_message = req.messages[-1]
120
  if isinstance(last_message.content, str):
121
  t_payload = last_message.content
122
 
123
- signature_data = self._generate_signature(e_payload, t_payload, timestamp_ms)
124
-
125
  signature = signature_data["signature"]
126
  signature_timestamp = signature_data["timestamp"]
127
 
 
128
  url_params = {
129
  "requestId": payload_request_id,
130
  "timestamp": payload_timestamp,
@@ -132,14 +149,16 @@ class ProxyHandler:
132
  "signature_timestamp": str(signature_timestamp)
133
  }
134
 
 
 
135
  final_url = httpx.URL(settings.UPSTREAM_URL).copy_with(params=url_params)
136
 
137
  body = {
138
  "stream": True,
139
  "model": model,
140
  "messages": self._serialize_msgs(req.messages),
141
- "chat_id": str(uuid.uuid4()),
142
- "id": str(uuid.uuid4()),
143
  "features": {
144
  "image_generation": False,
145
  "web_search": False,
@@ -209,6 +228,7 @@ class ProxyHandler:
209
  line = line.strip()
210
  if not line.startswith('data: '): continue
211
  payload_str = line[6:]
 
212
  if payload_str == '[DONE]':
213
  if think_open:
214
  yield f"data: {json.dumps({'id': comp_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model, 'choices': [{'index': 0, 'delta': {'content': '</think>'}, 'finish_reason': None}]})}\n\n"
@@ -222,7 +242,11 @@ class ProxyHandler:
222
  phase = dat.get("phase")
223
  content_chunk = dat.get("delta_content") or dat.get("edit_content")
224
  if not content_chunk:
225
- continue
 
 
 
 
226
 
227
  if phase == "thinking":
228
  current_raw_thinking = content_chunk if dat.get("edit_content") is not None else current_raw_thinking + content_chunk
@@ -231,12 +255,10 @@ class ProxyHandler:
231
  elif phase == "answer":
232
  content_to_process = content_chunk
233
  if is_first_answer_chunk:
234
- last_bracket_pos = content_to_process.rfind('>')
235
- if last_bracket_pos != -1:
236
- content_to_process = content_to_process[last_bracket_pos + 1:]
237
- content_to_process = content_to_process.lstrip()
238
  is_first_answer_chunk = False
239
-
240
  if content_to_process:
241
  async for item in yield_delta("answer", content_to_process):
242
  yield item
@@ -244,9 +266,12 @@ class ProxyHandler:
244
  logger.exception("Stream error"); raise
245
 
246
  async def non_stream_proxy_response(self, req: ChatCompletionRequest) -> ChatCompletionResponse:
 
 
247
  ck = None
248
  try:
249
  body, headers, ck, url = await self._prep_upstream(req)
 
250
  body["stream"] = False
251
 
252
  async with self.client.post(url, json=body, headers=headers) as resp:
@@ -256,9 +281,13 @@ class ProxyHandler:
256
  raise HTTPException(resp.status_code, f"Upstream error: {error_detail}")
257
 
258
  await cookie_manager.mark_cookie_success(ck)
 
 
259
  response_data = resp.json()
 
 
260
  final_content = ""
261
- finish_reason = "stop"
262
 
263
  if "choices" in response_data and response_data["choices"]:
264
  first_choice = response_data["choices"][0]
@@ -267,8 +296,6 @@ class ProxyHandler:
267
  if "finish_reason" in first_choice:
268
  finish_reason = first_choice["finish_reason"]
269
 
270
- final_content = self._clean_answer_content(final_content)
271
-
272
  return ChatCompletionResponse(
273
  id=response_data.get("id", f"chatcmpl-{uuid.uuid4().hex[:29]}"),
274
  created=int(time.time()),
@@ -279,6 +306,7 @@ class ProxyHandler:
279
  logger.exception("Non-stream processing failed"); raise
280
 
281
  async def handle_chat_completion(self, req: ChatCompletionRequest):
 
282
  stream = bool(req.stream) if req.stream is not None else settings.DEFAULT_STREAM
283
  if stream:
284
  return StreamingResponse(self.stream_proxy_response(req), media_type="text/event-stream",
 
22
  limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
23
  http2=True,
24
  )
25
+ # The primary secret key from the reference code.
26
  self.primary_secret = "junjie".encode('utf-8')
27
 
28
  async def aclose(self):
 
33
  return int(time.time() * 1000)
34
 
35
  def _parse_jwt_token(self, token: str) -> Dict[str, str]:
36
+ """A simple JWT payload decoder to get user ID ('sub' claim)."""
37
  try:
38
  parts = token.split('.')
39
  if len(parts) != 3: return {"user_id": ""}
40
  payload_b64 = parts[1]
41
+ payload_b64 += '=' * (-len(payload_b64) % 4) # Add padding if needed
42
  payload_json = base64.urlsafe_b64decode(payload_b64).decode('utf-8')
43
  payload = json.loads(payload_json)
44
  return {"user_id": payload.get("sub", "")}
45
  except Exception:
46
+ # It's okay if this fails; we'll proceed with an empty user_id.
47
  return {"user_id": ""}
48
 
49
+ def _generate_signature(self, e_payload: str, t_payload: str) -> Dict[str, Any]:
50
  """
51
  Generates the signature based on the logic from the reference JS code.
52
+ This is a two-level HMAC-SHA256 process.
53
 
54
  Args:
55
  e_payload (str): The simplified payload string (e.g., "requestId,...,timestamp,...").
56
  t_payload (str): The last message content.
 
57
 
58
  Returns:
59
  A dictionary with 'signature' and 'timestamp'.
60
  """
61
+ # The provided reference code uses a different logic for the key derivation.
62
+ # It's based on a timestamp bucket. Let's re-implement that one.
63
+ # However, the OTHER reference code `signature_generator.py` uses a different method.
64
+ # Let's stick to the one from the new `utils.py` and `signature_generator.py` for now.
65
+ # The provided python snippet in the prompt is actually different from the JS.
66
+ # The python snippet is: `n = timestamp_ms // (5 * 60 * 1000)`
67
+ # The JS snippet is: `minuteBucket = Math.floor(timestampMs / 60000)`
68
+ # Let's trust the JS one as it's more complete. Let's try the python one first as it's provided.
69
+
70
+ # --- Let's use the Python snippet logic from the prompt first ---
71
+ timestamp_ms = self._get_timestamp_millis()
72
+ message_string = f"{e_payload}|{t_payload}|{timestamp_ms}"
73
 
74
+ # Per the Python snippet: n is a 5-minute bucket
75
+ n = timestamp_ms // (5 * 60 * 1000)
76
 
77
+ # Intermediate key derivation
78
  msg1 = str(n).encode("utf-8")
79
+ intermediate_key = hmac.new(self.primary_secret, msg1, hashlib.sha256).hexdigest()
80
 
81
+ # Final signature
 
82
  msg2 = message_string.encode("utf-8")
83
+ final_signature = hmac.new(intermediate_key.encode("utf-8"), msg2, hashlib.sha256).hexdigest()
 
 
84
 
85
  return {"signature": final_signature, "timestamp": timestamp_ms}
86
 
 
95
 
96
  def _clean_answer_content(self, text: str) -> str:
97
  if not text: return ""
98
+ cleaned_text = re.sub(r'<glm_block.*?</glm_block>|<details[^>]*>.*?</details>|<summary>.*?</summary>', '', text, flags=re.DOTALL)
 
 
99
  return cleaned_text
100
 
101
  def _serialize_msgs(self, msgs) -> list:
102
  out = []
103
  for m in msgs:
104
+ # Adapting to Pydantic v1/v2 and dicts
105
  if hasattr(m, "dict"): out.append(m.dict())
106
  elif hasattr(m, "model_dump"): out.append(m.model_dump())
107
  elif isinstance(m, dict): out.append(m)
 
109
  return out
110
 
111
  async def _prep_upstream(self, req: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str], str, str]:
112
+ """Prepares the request body, headers, cookie, and URL for the upstream API."""
113
  ck = await cookie_manager.get_next_cookie()
114
  if not ck: raise HTTPException(503, "No available cookies")
115
 
116
  model = settings.UPSTREAM_MODEL if req.model == settings.MODEL_NAME else req.model
117
+ chat_id = str(uuid.uuid4())
118
+ request_id = str(uuid.uuid4())
119
 
120
+ # --- NEW Simplified Signature Payload Logic ---
121
+ user_info = self._parse_jwt_token(ck)
122
+ user_id = user_info.get("user_id", "")
123
+ # The reference code uses a separate UUID for user_id in payload, let's follow that.
124
+ # This seems strange, but let's replicate the reference code exactly.
125
  payload_user_id = str(uuid.uuid4())
126
  payload_request_id = str(uuid.uuid4())
127
+ payload_timestamp = str(self._get_timestamp_millis())
128
 
129
+ # e: The simplified payload for the signature
130
  e_payload = f"requestId,{payload_request_id},timestamp,{payload_timestamp},user_id,{payload_user_id}"
131
 
132
+ # t: The last message content
133
  t_payload = ""
134
  if req.messages:
135
  last_message = req.messages[-1]
136
  if isinstance(last_message.content, str):
137
  t_payload = last_message.content
138
 
139
+ # Generate the signature
140
+ signature_data = self._generate_signature(e_payload, t_payload)
141
  signature = signature_data["signature"]
142
  signature_timestamp = signature_data["timestamp"]
143
 
144
+ # The reference code sends these as URL parameters, not in the body.
145
  url_params = {
146
  "requestId": payload_request_id,
147
  "timestamp": payload_timestamp,
 
149
  "signature_timestamp": str(signature_timestamp)
150
  }
151
 
152
+ # Construct URL with query parameters
153
+ # Note: The reference code has a typo `f"{BASE_URL}/api/chat/completions"`, it should be `z.ai`
154
  final_url = httpx.URL(settings.UPSTREAM_URL).copy_with(params=url_params)
155
 
156
  body = {
157
  "stream": True,
158
  "model": model,
159
  "messages": self._serialize_msgs(req.messages),
160
+ "chat_id": chat_id,
161
+ "id": request_id,
162
  "features": {
163
  "image_generation": False,
164
  "web_search": False,
 
228
  line = line.strip()
229
  if not line.startswith('data: '): continue
230
  payload_str = line[6:]
231
+ # The reference code has a special 'done' phase, but the original Z.AI uses [DONE]
232
  if payload_str == '[DONE]':
233
  if think_open:
234
  yield f"data: {json.dumps({'id': comp_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model, 'choices': [{'index': 0, 'delta': {'content': '</think>'}, 'finish_reason': None}]})}\n\n"
 
242
  phase = dat.get("phase")
243
  content_chunk = dat.get("delta_content") or dat.get("edit_content")
244
  if not content_chunk:
245
+ # Handle case where chunk is just usage info, etc.
246
+ if phase == 'other' and dat.get('usage'):
247
+ pass # In streaming, usage might come with the final chunk
248
+ else:
249
+ continue
250
 
251
  if phase == "thinking":
252
  current_raw_thinking = content_chunk if dat.get("edit_content") is not None else current_raw_thinking + content_chunk
 
255
  elif phase == "answer":
256
  content_to_process = content_chunk
257
  if is_first_answer_chunk:
258
+ if '</details>' in content_to_process:
259
+ parts = content_to_process.split('</details>', 1)
260
+ content_to_process = parts[1] if len(parts) > 1 else ""
 
261
  is_first_answer_chunk = False
 
262
  if content_to_process:
263
  async for item in yield_delta("answer", content_to_process):
264
  yield item
 
266
  logger.exception("Stream error"); raise
267
 
268
  async def non_stream_proxy_response(self, req: ChatCompletionRequest) -> ChatCompletionResponse:
269
+ # This part of the code can be simplified as well, but let's focus on fixing the streaming first.
270
+ # The logic will be almost identical to the streaming one.
271
  ck = None
272
  try:
273
  body, headers, ck, url = await self._prep_upstream(req)
274
+ # For non-stream, set stream to False in the body
275
  body["stream"] = False
276
 
277
  async with self.client.post(url, json=body, headers=headers) as resp:
 
281
  raise HTTPException(resp.status_code, f"Upstream error: {error_detail}")
282
 
283
  await cookie_manager.mark_cookie_success(ck)
284
+
285
+ # Z.AI non-stream response is a single JSON object
286
  response_data = resp.json()
287
+
288
+ # We need to adapt Z.AI's response format to OpenAI's format
289
  final_content = ""
290
+ finish_reason = "stop" # Default
291
 
292
  if "choices" in response_data and response_data["choices"]:
293
  first_choice = response_data["choices"][0]
 
296
  if "finish_reason" in first_choice:
297
  finish_reason = first_choice["finish_reason"]
298
 
 
 
299
  return ChatCompletionResponse(
300
  id=response_data.get("id", f"chatcmpl-{uuid.uuid4().hex[:29]}"),
301
  created=int(time.time()),
 
306
  logger.exception("Non-stream processing failed"); raise
307
 
308
  async def handle_chat_completion(self, req: ChatCompletionRequest):
309
+ """Determines whether to stream or not and handles the request."""
310
  stream = bool(req.stream) if req.stream is not None else settings.DEFAULT_STREAM
311
  if stream:
312
  return StreamingResponse(self.stream_proxy_response(req), media_type="text/event-stream",