Pulastya B commited on
Commit
8b86ea3
·
1 Parent(s): 94bbef1

Fixed the Pydantic Model Errors for the Token Budget

Browse files
Files changed (1) hide show
  1. src/utils/token_budget.py +39 -11
src/utils/token_budget.py CHANGED
@@ -72,15 +72,30 @@ class TokenBudgetManager:
72
  # Fallback estimation: ~4 chars per token
73
  return len(text) // 4
74
 
75
- def count_message_tokens(self, message: Dict[str, str]) -> int:
76
- """Count tokens in a message (includes role overhead)."""
 
 
 
 
 
77
  # Format: <|role|>content<|endofmessage|>
78
  # Approximately 4 tokens overhead per message
79
- content_tokens = self.count_tokens(message.get("content", ""))
80
- role_tokens = self.count_tokens(message.get("role", ""))
 
 
 
 
 
 
 
 
 
 
81
  return content_tokens + role_tokens + 4
82
 
83
- def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
84
  """Count total tokens in message list."""
85
  return sum(self.count_message_tokens(msg) for msg in messages)
86
 
@@ -303,12 +318,17 @@ class TokenBudgetManager:
303
  # Still too large - truncate system prompt
304
  print("⚠️ Truncating system prompt to fit budget")
305
  system_msg = essential_messages[0]
306
- system_content = system_msg["content"]
 
 
 
 
 
307
 
308
  # Keep first 1000 chars of system prompt
309
  truncated_system = {
310
  "role": "system",
311
- "content": system_content[:1000] + "\n\n... (truncated due to context limit) ..."
312
  }
313
 
314
  return [truncated_system] + essential_messages[1:]
@@ -344,13 +364,21 @@ class TokenBudgetManager:
344
  # Convert to ConversationMessage objects
345
  conv_messages = []
346
  for i, msg in enumerate(messages):
347
- msg_type = "system" if i == 0 and msg["role"] == "system" else "normal"
348
- if "tool" in msg.get("content", "").lower() or "function" in msg.get("content", "").lower():
 
 
 
 
 
 
 
 
349
  msg_type = "tool_result"
350
 
351
  conv_msg = ConversationMessage(
352
- role=msg["role"],
353
- content=msg["content"],
354
  message_type=msg_type
355
  )
356
  conv_messages.append(conv_msg)
 
72
  # Fallback estimation: ~4 chars per token
73
  return len(text) // 4
74
 
75
+ def count_message_tokens(self, message) -> int:
76
+ """
77
+ Count tokens in a message (includes role overhead).
78
+
79
+ Args:
80
+ message: Either a dict or a Pydantic ChatMessage object
81
+ """
82
  # Format: <|role|>content<|endofmessage|>
83
  # Approximately 4 tokens overhead per message
84
+
85
+ # Handle both dict and Pydantic object formats
86
+ if isinstance(message, dict):
87
+ content = message.get("content", "")
88
+ role = message.get("role", "")
89
+ else:
90
+ # Pydantic object (like ChatMessage from Mistral SDK)
91
+ content = getattr(message, "content", "")
92
+ role = getattr(message, "role", "")
93
+
94
+ content_tokens = self.count_tokens(str(content))
95
+ role_tokens = self.count_tokens(str(role))
96
  return content_tokens + role_tokens + 4
97
 
98
+ def count_messages_tokens(self, messages: List) -> int:
99
  """Count total tokens in message list."""
100
  return sum(self.count_message_tokens(msg) for msg in messages)
101
 
 
318
  # Still too large - truncate system prompt
319
  print("⚠️ Truncating system prompt to fit budget")
320
  system_msg = essential_messages[0]
321
+
322
+ # Handle both dict and Pydantic object formats
323
+ if isinstance(system_msg, dict):
324
+ system_content = system_msg["content"]
325
+ else:
326
+ system_content = getattr(system_msg, "content", "")
327
 
328
  # Keep first 1000 chars of system prompt
329
  truncated_system = {
330
  "role": "system",
331
+ "content": str(system_content)[:1000] + "\n\n... (truncated due to context limit) ..."
332
  }
333
 
334
  return [truncated_system] + essential_messages[1:]
 
364
  # Convert to ConversationMessage objects
365
  conv_messages = []
366
  for i, msg in enumerate(messages):
367
+ # Handle both dict and Pydantic object formats
368
+ if isinstance(msg, dict):
369
+ role = msg.get("role", "")
370
+ content = msg.get("content", "")
371
+ else:
372
+ role = getattr(msg, "role", "")
373
+ content = getattr(msg, "content", "")
374
+
375
+ msg_type = "system" if i == 0 and role == "system" else "normal"
376
+ if "tool" in str(content).lower() or "function" in str(content).lower():
377
  msg_type = "tool_result"
378
 
379
  conv_msg = ConversationMessage(
380
+ role=role,
381
+ content=str(content),
382
  message_type=msg_type
383
  )
384
  conv_messages.append(conv_msg)