Dmitry Beresnev committed on
Commit
6381e7f
·
1 Parent(s): 62a5a49

add simple compacting

Browse files
Files changed (1) hide show
  1. app.py +69 -1
app.py CHANGED
@@ -114,6 +114,70 @@ LOG_REQUEST_BODY = os.getenv("LOG_REQUEST_BODY", "1") == "1"
114
  LOG_REQUEST_BODY_MAX_CHARS = int(os.getenv("LOG_REQUEST_BODY_MAX_CHARS", "2000"))
115
  CONTEXT_SIZE = int(os.getenv("CONTEXT_SIZE", "2048"))
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  @dataclass
119
  class CachedModel:
@@ -901,8 +965,9 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
901
  raise HTTPException(status_code=500, detail="Current model not loaded")
902
 
903
  # Forward to llama-server using aiohttp
 
904
  payload = {
905
- "messages": request.messages,
906
  "max_tokens": request.max_tokens,
907
  "temperature": request.temperature,
908
  }
@@ -1043,6 +1108,9 @@ Always cite sources when using information from the search results."""
1043
  if not http_session or http_session.closed:
1044
  raise HTTPException(status_code=500, detail="HTTP session not initialized")
1045
 
 
 
 
1046
  # Get current model from cache
1047
  cached_model = model_cache.get(current_model)
1048
  if not cached_model:
 
114
  LOG_REQUEST_BODY_MAX_CHARS = int(os.getenv("LOG_REQUEST_BODY_MAX_CHARS", "2000"))
115
  CONTEXT_SIZE = int(os.getenv("CONTEXT_SIZE", "2048"))
116
 
117
+ PROMPT_MARGIN_TOKENS = int(os.getenv("PROMPT_MARGIN_TOKENS", "256"))
118
+ CHARS_PER_TOKEN_EST = float(os.getenv("CHARS_PER_TOKEN_EST", "4.0"))
119
+
120
+
121
+ def _estimate_tokens(text: str) -> int:
122
+ """Rough token estimate based on character count."""
123
+ if not text:
124
+ return 0
125
+ return int(len(text) / CHARS_PER_TOKEN_EST) + 1
126
+
127
+
128
+ def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
129
+ """Truncate text to an approximate token budget."""
130
+ if not text or max_tokens <= 0:
131
+ return ""
132
+ max_chars = int(max_tokens * CHARS_PER_TOKEN_EST)
133
+ if len(text) <= max_chars:
134
+ return text
135
+ return text[:max_chars] + "...[truncated]"
136
+
137
+
138
def _compact_messages(messages: list[dict], max_tokens: int) -> list[dict]:
    """
    Compact chat messages to fit within the model's prompt budget.

    Strategy:
      - Cap each system message's content size.
      - Drop the oldest non-system messages until within budget, but always
        keep the most recent non-system message (the user's current turn).
      - As a last resort, truncate that remaining non-system message.

    Args:
        messages: Chat messages as dicts with "role" and "content" keys.
        max_tokens: Requested completion size, reserved out of CONTEXT_SIZE.
            A falsy value (e.g. None/0) reserves nothing.

    Returns:
        A new list of (possibly truncated) shallow-copied message dicts;
        the caller's list and dicts are never mutated.
    """
    if not messages:
        return messages

    # Guard against max_tokens being None (optional in OpenAI-style requests).
    prompt_budget = CONTEXT_SIZE - (max_tokens or 0) - PROMPT_MARGIN_TOKENS
    if prompt_budget <= 0:
        # Budget is already impossible; pass through unchanged and let the
        # backend report the context overflow.
        return messages

    # Work on shallow copies to avoid mutating caller input.
    compacted = [dict(m) for m in messages]

    # Cap system messages to roughly a third of the budget, clamped to 256..1024.
    system_cap = min(1024, max(256, prompt_budget // 3))
    for msg in compacted:
        if msg.get("role") == "system" and "content" in msg:
            msg["content"] = _truncate_text_to_tokens(str(msg["content"]), system_cap)

    def total_tokens(msgs: list[dict]) -> int:
        return sum(_estimate_tokens(str(m.get("content", ""))) for m in msgs)

    def non_system_indices(msgs: list[dict]) -> list[int]:
        return [i for i, m in enumerate(msgs) if m.get("role") != "system"]

    # Drop oldest non-system messages until under budget, but never drop the
    # most recent one: removing the user's current turn would leave the
    # request meaningless. (Previously the loop could pop every non-system
    # message, which also made the truncation fallback below unreachable.)
    while total_tokens(compacted) > prompt_budget:
        candidates = non_system_indices(compacted)
        if len(candidates) <= 1:
            break
        compacted.pop(candidates[0])

    # Last resort: truncate the remaining non-system message to whatever
    # budget is left after the (already capped) system messages.
    if total_tokens(compacted) > prompt_budget:
        candidates = non_system_indices(compacted)
        if candidates:
            idx = candidates[0]
            content = str(compacted[idx].get("content", ""))
            remaining_budget = max(
                1, prompt_budget - (total_tokens(compacted) - _estimate_tokens(content))
            )
            compacted[idx]["content"] = _truncate_text_to_tokens(content, remaining_budget)

    return compacted
180
+
181
 
182
  @dataclass
183
  class CachedModel:
 
965
  raise HTTPException(status_code=500, detail="Current model not loaded")
966
 
967
  # Forward to llama-server using aiohttp
968
+ compacted_messages = _compact_messages(request.messages, request.max_tokens)
969
  payload = {
970
+ "messages": compacted_messages,
971
  "max_tokens": request.max_tokens,
972
  "temperature": request.temperature,
973
  }
 
1108
  if not http_session or http_session.closed:
1109
  raise HTTPException(status_code=500, detail="HTTP session not initialized")
1110
 
1111
+ # Compact messages to fit within context
1112
+ augmented_messages = _compact_messages(augmented_messages, request.max_tokens)
1113
+
1114
  # Get current model from cache
1115
  cached_model = model_cache.get(current_model)
1116
  if not cached_model: