Dmitry Beresnev committed
Commit a8f6b6b · 1 Parent(s): 130d9e3

Log elapsed time and token rate when the response arrives.


Add a hard timeout (e.g., 300s) and return a friendly 504.
Add a rule to drop system prompts above a cap unless explicitly allowed.
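For context, a minimal standalone sketch of the timeout-and-metrics pattern this commit applies: wrap the upstream call in asyncio.wait_for, translate asyncio.TimeoutError into an HTTP 504, and log elapsed time plus tokens per second when an OpenAI-style usage block is present. The HARD_REQUEST_TIMEOUT default and the 504 detail string mirror the diff below; the FastAPI app, UPSTREAM_URL, and the upstream_chat helper are illustrative assumptions, not the app's actual wiring.

import asyncio
import logging
import os
import time

import aiohttp
from fastapi import FastAPI, HTTPException

HARD_REQUEST_TIMEOUT = int(os.getenv("HARD_REQUEST_TIMEOUT", "300"))
UPSTREAM_URL = "http://localhost:8080/v1/chat/completions"  # illustrative only

logger = logging.getLogger(__name__)
app = FastAPI()


@app.post("/v1/chat/completions")
async def chat_completions(payload: dict):
    async def upstream_chat() -> dict:
        # One-off session for the sketch; the real app reuses a shared session.
        async with aiohttp.ClientSession() as session:
            async with session.post(UPSTREAM_URL, json=payload) as response:
                response.raise_for_status()
                return await response.json()

    start = time.time()
    try:
        # Hard cap on total upstream time; the task is cancelled on expiry.
        result = await asyncio.wait_for(upstream_chat(), timeout=HARD_REQUEST_TIMEOUT)
    except asyncio.TimeoutError:
        # Friendly 504 instead of letting the client hang indefinitely.
        raise HTTPException(status_code=504, detail="Upstream model timed out. Please retry.")

    # Elapsed time and token rate, if the upstream reports OpenAI-style usage.
    elapsed = time.time() - start
    completion_tokens = (result.get("usage") or {}).get("completion_tokens", 0)
    if completion_tokens:
        logger.info("done time=%.2fs tokens=%d tok/s=%.1f",
                    elapsed, completion_tokens, completion_tokens / max(elapsed, 1e-6))
    else:
        logger.info("done time=%.2fs", elapsed)
    return result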

Files changed (1)
  1. app.py +74 -24
app.py CHANGED
@@ -116,6 +116,9 @@ CONTEXT_SIZE = int(os.getenv("CONTEXT_SIZE", "2048"))
 
 PROMPT_MARGIN_TOKENS = int(os.getenv("PROMPT_MARGIN_TOKENS", "256"))
 CHARS_PER_TOKEN_EST = float(os.getenv("CHARS_PER_TOKEN_EST", "4.0"))
+SYSTEM_PROMPT_MAX_TOKENS = int(os.getenv("SYSTEM_PROMPT_MAX_TOKENS", "512"))
+ALLOW_LONG_SYSTEM_PROMPT = os.getenv("ALLOW_LONG_SYSTEM_PROMPT", "0") == "1"
+HARD_REQUEST_TIMEOUT = int(os.getenv("HARD_REQUEST_TIMEOUT", "300"))
 
 
 def _estimate_tokens(text: str) -> int:
@@ -157,7 +160,10 @@ def _compact_messages(messages: list[dict], max_tokens: int) -> list[dict]:
     system_cap = min(1024, max(256, prompt_budget // 3))
     for msg in compacted:
         if msg.get("role") == "system" and "content" in msg:
-            msg["content"] = _truncate_text_to_tokens(str(msg["content"]), system_cap)
+            if not ALLOW_LONG_SYSTEM_PROMPT and _estimate_tokens(str(msg["content"])) > SYSTEM_PROMPT_MAX_TOKENS:
+                msg["content"] = ""
+            else:
+                msg["content"] = _truncate_text_to_tokens(str(msg["content"]), system_cap)
 
     def total_tokens(msgs: list[dict]) -> int:
         return sum(_estimate_tokens(str(m.get("content", ""))) for m in msgs)
@@ -979,6 +985,16 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
             f"original_tokens≈{original_tokens} budget≈{prompt_budget}"
         )
 
+        # Drop system prompts above cap unless allowed
+        if not ALLOW_LONG_SYSTEM_PROMPT:
+            for msg in request.messages:
+                if msg.get("role") == "system" and _estimate_tokens(str(msg.get("content", ""))) > SYSTEM_PROMPT_MAX_TOKENS:
+                    logger.warning(
+                        f"request_id={request_id} system_prompt_dropped "
+                        f"tokens≈{_estimate_tokens(str(msg.get('content', '')))} cap≈{SYSTEM_PROMPT_MAX_TOKENS}"
+                    )
+                    break
+
         compacted_messages = _compact_messages(request.messages, request.max_tokens)
 
         compacted_tokens = _estimate_messages_tokens(compacted_messages)
@@ -1000,18 +1016,21 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
             "max_tokens": request.max_tokens,
            "temperature": request.temperature,
         }
-        async with http_session.post(
-            f"{cached_model.url}/v1/chat/completions",
-            json=payload
-        ) as response:
-            if response.status >= 400:
-                error_text = await response.text()
-                logger.error(
-                    f"request_id={request_id} llama-server {response.status} "
-                    f"error_body={error_text[:1000]}"
-                )
-            response.raise_for_status()
-            result = await response.json()
+        async def _do_request():
+            async with http_session.post(
+                f"{cached_model.url}/v1/chat/completions",
+                json=payload
+            ) as response:
+                if response.status >= 400:
+                    error_text = await response.text()
+                    logger.error(
+                        f"request_id={request_id} llama-server {response.status} "
+                        f"error_body={error_text[:1000]}"
+                    )
+                response.raise_for_status()
+                return await response.json()
+
+        result = await asyncio.wait_for(_do_request(), timeout=HARD_REQUEST_TIMEOUT)
 
         # Update metrics
         request_latency = time.time() - request_start
@@ -1019,6 +1038,18 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
         cached_model.total_latency += request_latency
         metrics.record_request(current_model, request_latency)
 
+        # Log elapsed time and token rate (if usage available)
+        usage = result.get("usage") if isinstance(result, dict) else None
+        if usage and usage.get("completion_tokens"):
+            completion_tokens = usage.get("completion_tokens", 0)
+            tok_per_sec = completion_tokens / max(request_latency, 1e-6)
+            logger.info(
+                f"request_id={request_id} done "
+                f"time={request_latency:.2f}s tokens={completion_tokens} tok/s={tok_per_sec:.1f}"
+            )
+        else:
+            logger.info(f"request_id={request_id} done time={request_latency:.2f}s")
+
         return result
     except aiohttp.ClientResponseError as e:
         logger.exception(f"request_id={request_id} llama-server error")
@@ -1026,6 +1057,9 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
     except aiohttp.ClientError as e:
         logger.exception(f"request_id={request_id} llama-server error")
         raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
+    except asyncio.TimeoutError:
+        logger.error(f"request_id={request_id} timeout after {HARD_REQUEST_TIMEOUT}s")
+        raise HTTPException(status_code=504, detail="Upstream model timed out. Please retry.")
    except Exception:
         logger.exception(f"request_id={request_id} chat_completions error")
         raise
@@ -1150,6 +1184,16 @@ Always cite sources when using information from the search results."""
             f"original_tokens≈{original_tokens} budget≈{prompt_budget}"
         )
 
+        # Drop system prompts above cap unless allowed
+        if not ALLOW_LONG_SYSTEM_PROMPT:
+            for msg in augmented_messages:
+                if msg.get("role") == "system" and _estimate_tokens(str(msg.get("content", ""))) > SYSTEM_PROMPT_MAX_TOKENS:
+                    logger.warning(
+                        f"request_id={request_id} system_prompt_dropped "
+                        f"tokens≈{_estimate_tokens(str(msg.get('content', '')))} cap≈{SYSTEM_PROMPT_MAX_TOKENS}"
+                    )
+                    break
+
         augmented_messages = _compact_messages(augmented_messages, request.max_tokens)
 
         compacted_tokens = _estimate_messages_tokens(augmented_messages)
@@ -1170,17 +1214,20 @@ Always cite sources when using information from the search results."""
         if not cached_model:
             raise HTTPException(status_code=500, detail="Current model not loaded")
 
-        # Forward to llama-server with augmented context
-        async with http_session.post(
-            f"{cached_model.url}/v1/chat/completions",
-            json={
-                "messages": augmented_messages,
-                "max_tokens": request.max_tokens,
-                "temperature": request.temperature,
-            }
-        ) as response:
-            response.raise_for_status()
-            result = await response.json()
+        async def _do_request():
+            # Forward to llama-server with augmented context
+            async with http_session.post(
+                f"{cached_model.url}/v1/chat/completions",
+                json={
+                    "messages": augmented_messages,
+                    "max_tokens": request.max_tokens,
+                    "temperature": request.temperature,
+                }
+            ) as response:
+                response.raise_for_status()
+                return await response.json()
+
+        result = await asyncio.wait_for(_do_request(), timeout=HARD_REQUEST_TIMEOUT)
 
         # Add metadata about search results
         result["web_search"] = {
@@ -1195,6 +1242,9 @@ Always cite sources when using information from the search results."""
     except aiohttp.ClientError as e:
         logger.exception(f"request_id={request_id} llama-server error")
         raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
+    except asyncio.TimeoutError:
+        logger.error(f"request_id={request_id} timeout after {HARD_REQUEST_TIMEOUT}s")
+        raise HTTPException(status_code=504, detail="Upstream model timed out. Please retry.")
     except HTTPException:
         raise
     except Exception as e:
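For reference, the new system-prompt cap can be exercised in isolation. The sketch below restates the rule as a small pure function under the same environment defaults; apply_system_prompt_cap, estimate_tokens, and the sample messages are hypothetical names for illustration, with estimate_tokens standing in for the app's _estimate_tokens heuristic.

import os

SYSTEM_PROMPT_MAX_TOKENS = int(os.getenv("SYSTEM_PROMPT_MAX_TOKENS", "512"))
ALLOW_LONG_SYSTEM_PROMPT = os.getenv("ALLOW_LONG_SYSTEM_PROMPT", "0") == "1"
CHARS_PER_TOKEN_EST = float(os.getenv("CHARS_PER_TOKEN_EST", "4.0"))


def estimate_tokens(text: str) -> int:
    # Rough chars-per-token heuristic standing in for the app's _estimate_tokens().
    return int(len(text) / CHARS_PER_TOKEN_EST) + 1


def apply_system_prompt_cap(messages: list[dict]) -> list[dict]:
    # Blank out any system prompt above the cap unless explicitly allowed,
    # mirroring the rule this commit adds to _compact_messages.
    for msg in messages:
        if msg.get("role") == "system" and "content" in msg:
            if not ALLOW_LONG_SYSTEM_PROMPT and estimate_tokens(str(msg["content"])) > SYSTEM_PROMPT_MAX_TOKENS:
                msg["content"] = ""
    return messages


if __name__ == "__main__":
    msgs = [
        {"role": "system", "content": "x" * 10_000},  # ~2500 estimated tokens, over the 512-token cap
        {"role": "user", "content": "Hello"},
    ]
    print(apply_system_prompt_cap(msgs)[0]["content"] == "")  # True with default env settings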