Spaces:
Running
Running
Pulastya B commited on
Commit ·
8b86ea3
1
Parent(s): 94bbef1
Fixed the Pydantic Model Errors for the Token Budget
Browse files- src/utils/token_budget.py +39 -11
src/utils/token_budget.py
CHANGED
|
@@ -72,15 +72,30 @@ class TokenBudgetManager:
|
|
| 72 |
# Fallback estimation: ~4 chars per token
|
| 73 |
return len(text) // 4
|
| 74 |
|
| 75 |
-
def count_message_tokens(self, message
|
| 76 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
# Format: <|role|>content<|endofmessage|>
|
| 78 |
# Approximately 4 tokens overhead per message
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
return content_tokens + role_tokens + 4
|
| 82 |
|
| 83 |
-
def count_messages_tokens(self, messages: List
|
| 84 |
"""Count total tokens in message list."""
|
| 85 |
return sum(self.count_message_tokens(msg) for msg in messages)
|
| 86 |
|
|
@@ -303,12 +318,17 @@ class TokenBudgetManager:
|
|
| 303 |
# Still too large - truncate system prompt
|
| 304 |
print("⚠️ Truncating system prompt to fit budget")
|
| 305 |
system_msg = essential_messages[0]
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
|
| 308 |
# Keep first 1000 chars of system prompt
|
| 309 |
truncated_system = {
|
| 310 |
"role": "system",
|
| 311 |
-
"content": system_content[:1000] + "\n\n... (truncated due to context limit) ..."
|
| 312 |
}
|
| 313 |
|
| 314 |
return [truncated_system] + essential_messages[1:]
|
|
@@ -344,13 +364,21 @@ class TokenBudgetManager:
|
|
| 344 |
# Convert to ConversationMessage objects
|
| 345 |
conv_messages = []
|
| 346 |
for i, msg in enumerate(messages):
|
| 347 |
-
|
| 348 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
msg_type = "tool_result"
|
| 350 |
|
| 351 |
conv_msg = ConversationMessage(
|
| 352 |
-
role=
|
| 353 |
-
content=
|
| 354 |
message_type=msg_type
|
| 355 |
)
|
| 356 |
conv_messages.append(conv_msg)
|
|
|
|
| 72 |
# Fallback estimation: ~4 chars per token
|
| 73 |
return len(text) // 4
|
| 74 |
|
| 75 |
+
def count_message_tokens(self, message) -> int:
|
| 76 |
+
"""
|
| 77 |
+
Count tokens in a message (includes role overhead).
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
message: Either a dict or a Pydantic ChatMessage object
|
| 81 |
+
"""
|
| 82 |
# Format: <|role|>content<|endofmessage|>
|
| 83 |
# Approximately 4 tokens overhead per message
|
| 84 |
+
|
| 85 |
+
# Handle both dict and Pydantic object formats
|
| 86 |
+
if isinstance(message, dict):
|
| 87 |
+
content = message.get("content", "")
|
| 88 |
+
role = message.get("role", "")
|
| 89 |
+
else:
|
| 90 |
+
# Pydantic object (like ChatMessage from Mistral SDK)
|
| 91 |
+
content = getattr(message, "content", "")
|
| 92 |
+
role = getattr(message, "role", "")
|
| 93 |
+
|
| 94 |
+
content_tokens = self.count_tokens(str(content))
|
| 95 |
+
role_tokens = self.count_tokens(str(role))
|
| 96 |
return content_tokens + role_tokens + 4
|
| 97 |
|
| 98 |
+
def count_messages_tokens(self, messages: List) -> int:
|
| 99 |
"""Count total tokens in message list."""
|
| 100 |
return sum(self.count_message_tokens(msg) for msg in messages)
|
| 101 |
|
|
|
|
| 318 |
# Still too large - truncate system prompt
|
| 319 |
print("⚠️ Truncating system prompt to fit budget")
|
| 320 |
system_msg = essential_messages[0]
|
| 321 |
+
|
| 322 |
+
# Handle both dict and Pydantic object formats
|
| 323 |
+
if isinstance(system_msg, dict):
|
| 324 |
+
system_content = system_msg["content"]
|
| 325 |
+
else:
|
| 326 |
+
system_content = getattr(system_msg, "content", "")
|
| 327 |
|
| 328 |
# Keep first 1000 chars of system prompt
|
| 329 |
truncated_system = {
|
| 330 |
"role": "system",
|
| 331 |
+
"content": str(system_content)[:1000] + "\n\n... (truncated due to context limit) ..."
|
| 332 |
}
|
| 333 |
|
| 334 |
return [truncated_system] + essential_messages[1:]
|
|
|
|
| 364 |
# Convert to ConversationMessage objects
|
| 365 |
conv_messages = []
|
| 366 |
for i, msg in enumerate(messages):
|
| 367 |
+
# Handle both dict and Pydantic object formats
|
| 368 |
+
if isinstance(msg, dict):
|
| 369 |
+
role = msg.get("role", "")
|
| 370 |
+
content = msg.get("content", "")
|
| 371 |
+
else:
|
| 372 |
+
role = getattr(msg, "role", "")
|
| 373 |
+
content = getattr(msg, "content", "")
|
| 374 |
+
|
| 375 |
+
msg_type = "system" if i == 0 and role == "system" else "normal"
|
| 376 |
+
if "tool" in str(content).lower() or "function" in str(content).lower():
|
| 377 |
msg_type = "tool_result"
|
| 378 |
|
| 379 |
conv_msg = ConversationMessage(
|
| 380 |
+
role=role,
|
| 381 |
+
content=str(content),
|
| 382 |
message_type=msg_type
|
| 383 |
)
|
| 384 |
conv_messages.append(conv_msg)
|