Dmitry Beresnev committed
Commit a8f6b6b · Parent(s): 130d9e3
Log elapsed time and token rate when the response arrives.
Add a hard timeout (e.g., 300s) and return a friendly 504.
Add a rule to drop system prompts above a cap unless explicitly allowed.
app.py
CHANGED
@@ -116,6 +116,9 @@ CONTEXT_SIZE = int(os.getenv("CONTEXT_SIZE", "2048"))
 
 PROMPT_MARGIN_TOKENS = int(os.getenv("PROMPT_MARGIN_TOKENS", "256"))
 CHARS_PER_TOKEN_EST = float(os.getenv("CHARS_PER_TOKEN_EST", "4.0"))
+SYSTEM_PROMPT_MAX_TOKENS = int(os.getenv("SYSTEM_PROMPT_MAX_TOKENS", "512"))
+ALLOW_LONG_SYSTEM_PROMPT = os.getenv("ALLOW_LONG_SYSTEM_PROMPT", "0") == "1"
+HARD_REQUEST_TIMEOUT = int(os.getenv("HARD_REQUEST_TIMEOUT", "300"))
 
 
 def _estimate_tokens(text: str) -> int:
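These three knobs are read once at import time, so overrides must already be in the environment when app.py is imported. A minimal sketch of that pattern (the module name app is an assumption for illustration):

import os

# Hypothetical launcher: set overrides before the module loads,
# because the constants above are evaluated at import time.
os.environ["SYSTEM_PROMPT_MAX_TOKENS"] = "1024"  # raise the cap
os.environ["ALLOW_LONG_SYSTEM_PROMPT"] = "1"     # only the exact string "1" enables it
os.environ["HARD_REQUEST_TIMEOUT"] = "120"       # tighter hard timeout

import app  # assumed module name for app.py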
@@ -157,7 +160,10 @@ def _compact_messages(messages: list[dict], max_tokens: int) -> list[dict]:
     system_cap = min(1024, max(256, prompt_budget // 3))
     for msg in compacted:
         if msg.get("role") == "system" and "content" in msg:
-            msg["content"] = _truncate_text_to_tokens(str(msg["content"]), system_cap)
+            if not ALLOW_LONG_SYSTEM_PROMPT and _estimate_tokens(str(msg["content"])) > SYSTEM_PROMPT_MAX_TOKENS:
+                msg["content"] = ""
+            else:
+                msg["content"] = _truncate_text_to_tokens(str(msg["content"]), system_cap)
 
     def total_tokens(msgs: list[dict]) -> int:
         return sum(_estimate_tokens(str(m.get("content", ""))) for m in msgs)
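A self-contained sketch of the same drop-or-truncate rule, with the app's helpers stubbed by the chars-per-token heuristic the config suggests (estimate_tokens and truncate_to_tokens below are local stand-ins, not the app's _estimate_tokens/_truncate_text_to_tokens):

CHARS_PER_TOKEN_EST = 4.0
SYSTEM_PROMPT_MAX_TOKENS = 512
ALLOW_LONG_SYSTEM_PROMPT = False

def estimate_tokens(text: str) -> int:
    # Stand-in: length-based token estimate.
    return int(len(text) / CHARS_PER_TOKEN_EST)

def truncate_to_tokens(text: str, cap: int) -> str:
    # Stand-in: cut at the estimated character budget.
    return text[: int(cap * CHARS_PER_TOKEN_EST)]

def cap_system_prompt(content: str, system_cap: int) -> str:
    # Over-cap system prompts are dropped entirely unless explicitly
    # allowed; otherwise they are truncated to the per-request budget.
    if not ALLOW_LONG_SYSTEM_PROMPT and estimate_tokens(content) > SYSTEM_PROMPT_MAX_TOKENS:
        return ""
    return truncate_to_tokens(content, system_cap)

print(len(cap_system_prompt("x" * 4000, 256)))  # 0    (~1000 tokens > 512 cap: dropped)
print(len(cap_system_prompt("x" * 1000, 256)))  # 1000 (~250 tokens: kept, under both caps)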
@@ -979,6 +985,16 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
             f"original_tokens≈{original_tokens} budget≈{prompt_budget}"
         )
 
+        # Drop system prompts above cap unless allowed
+        if not ALLOW_LONG_SYSTEM_PROMPT:
+            for msg in request.messages:
+                if msg.get("role") == "system" and _estimate_tokens(str(msg.get("content", ""))) > SYSTEM_PROMPT_MAX_TOKENS:
+                    logger.warning(
+                        f"request_id={request_id} system_prompt_dropped "
+                        f"tokens≈{_estimate_tokens(str(msg.get('content', '')))} cap≈{SYSTEM_PROMPT_MAX_TOKENS}"
+                    )
+                    break
+
         compacted_messages = _compact_messages(request.messages, request.max_tokens)
 
         compacted_tokens = _estimate_messages_tokens(compacted_messages)
@@ -1000,18 +1016,21 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
             "max_tokens": request.max_tokens,
             "temperature": request.temperature,
         }
-        [12 removed lines not shown in the collapsed view]
+        async def _do_request():
+            async with http_session.post(
+                f"{cached_model.url}/v1/chat/completions",
+                json=payload
+            ) as response:
+                if response.status >= 400:
+                    error_text = await response.text()
+                    logger.error(
+                        f"request_id={request_id} llama-server {response.status} "
+                        f"error_body={error_text[:1000]}"
+                    )
+                response.raise_for_status()
+                return await response.json()
+
+        result = await asyncio.wait_for(_do_request(), timeout=HARD_REQUEST_TIMEOUT)
 
         # Update metrics
         request_latency = time.time() - request_start
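The hard timeout is enforced client-side: asyncio.wait_for cancels the in-flight coroutine once the deadline passes, and the resulting asyncio.TimeoutError is mapped to a 504 in the except clause added further below. A self-contained sketch of the pattern, with a placeholder upstream URL:

import asyncio
import aiohttp

HARD_REQUEST_TIMEOUT = 300  # seconds, mirroring the env default

async def forward(payload: dict) -> dict:
    async with aiohttp.ClientSession() as session:

        async def _do_request() -> dict:
            # Placeholder URL standing in for the cached model's endpoint
            async with session.post(
                "http://localhost:8080/v1/chat/completions", json=payload
            ) as resp:
                resp.raise_for_status()
                return await resp.json()

        # Cancels _do_request if it runs past the deadline
        return await asyncio.wait_for(_do_request(), timeout=HARD_REQUEST_TIMEOUT)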
@@ -1019,6 +1038,18 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
         cached_model.total_latency += request_latency
         metrics.record_request(current_model, request_latency)
 
+        # Log elapsed time and token rate (if usage available)
+        usage = result.get("usage") if isinstance(result, dict) else None
+        if usage and usage.get("completion_tokens"):
+            completion_tokens = usage.get("completion_tokens", 0)
+            tok_per_sec = completion_tokens / max(request_latency, 1e-6)
+            logger.info(
+                f"request_id={request_id} done "
+                f"time={request_latency:.2f}s tokens={completion_tokens} tok/s={tok_per_sec:.1f}"
+            )
+        else:
+            logger.info(f"request_id={request_id} done time={request_latency:.2f}s")
+
         return result
     except aiohttp.ClientResponseError as e:
         logger.exception(f"request_id={request_id} llama-server error")
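The tok/s figure is completion tokens over wall-clock latency, with a tiny floor on the denominator to avoid division by zero. The same computation in isolation, assuming the OpenAI-style usage block llama-server returns:

def tokens_per_second(result: dict, request_latency: float) -> float | None:
    # usage may be absent (e.g., an upstream that omits it)
    usage = result.get("usage") if isinstance(result, dict) else None
    if not usage or not usage.get("completion_tokens"):
        return None
    return usage["completion_tokens"] / max(request_latency, 1e-6)

print(tokens_per_second({"usage": {"completion_tokens": 96}}, 4.0))  # 24.0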
@@ -1026,6 +1057,9 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
     except aiohttp.ClientError as e:
         logger.exception(f"request_id={request_id} llama-server error")
         raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
+    except asyncio.TimeoutError:
+        logger.error(f"request_id={request_id} timeout after {HARD_REQUEST_TIMEOUT}s")
+        raise HTTPException(status_code=504, detail="Upstream model timed out. Please retry.")
     except Exception:
         logger.exception(f"request_id={request_id} chat_completions error")
         raise
@@ -1150,6 +1184,16 @@ Always cite sources when using information from the search results."""
             f"original_tokens≈{original_tokens} budget≈{prompt_budget}"
         )
 
+        # Drop system prompts above cap unless allowed
+        if not ALLOW_LONG_SYSTEM_PROMPT:
+            for msg in augmented_messages:
+                if msg.get("role") == "system" and _estimate_tokens(str(msg.get("content", ""))) > SYSTEM_PROMPT_MAX_TOKENS:
+                    logger.warning(
+                        f"request_id={request_id} system_prompt_dropped "
+                        f"tokens≈{_estimate_tokens(str(msg.get('content', '')))} cap≈{SYSTEM_PROMPT_MAX_TOKENS}"
+                    )
+                    break
+
         augmented_messages = _compact_messages(augmented_messages, request.max_tokens)
 
         compacted_tokens = _estimate_messages_tokens(augmented_messages)
@@ -1170,17 +1214,20 @@ Always cite sources when using information from the search results."""
         if not cached_model:
             raise HTTPException(status_code=500, detail="Current model not loaded")
 
-        [11 removed lines not shown in the collapsed view]
+        async def _do_request():
+            # Forward to llama-server with augmented context
+            async with http_session.post(
+                f"{cached_model.url}/v1/chat/completions",
+                json={
+                    "messages": augmented_messages,
+                    "max_tokens": request.max_tokens,
+                    "temperature": request.temperature,
+                }
+            ) as response:
+                response.raise_for_status()
+                return await response.json()
+
+        result = await asyncio.wait_for(_do_request(), timeout=HARD_REQUEST_TIMEOUT)
 
         # Add metadata about search results
         result["web_search"] = {
@@ -1195,6 +1242,9 @@ Always cite sources when using information from the search results."""
     except aiohttp.ClientError as e:
         logger.exception(f"request_id={request_id} llama-server error")
         raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
+    except asyncio.TimeoutError:
+        logger.error(f"request_id={request_id} timeout after {HARD_REQUEST_TIMEOUT}s")
+        raise HTTPException(status_code=504, detail="Upstream model timed out. Please retry.")
     except HTTPException:
         raise
     except Exception as e: