# QuerySphere/generation/token_manager.py
# DEPENDENCIES
import tiktoken
from typing import List
from typing import Dict
from typing import Optional
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import TokenManagementError
# Setup Settings and Logging
# Module-level singletons shared by all TokenManager instances in this module
settings = get_settings()
logger = get_logger(__name__)
class TokenManager:
    """
    Token management for LLM context windows: Handles token counting, context window management, and optimization for different LLM providers (Ollama, OpenAI)
    """

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize token manager
        Arguments:
        ----------
        model_name { str } : Model name for tokenizer selection
        """
        self.logger = logger
        self.settings = get_settings()
        # Fall back to the configured default Ollama model when none is supplied
        self.model_name = model_name or self.settings.OLLAMA_MODEL
        # None means "no tokenizer available" -> word-based approximation is used
        self.encoding = None
        self.context_window = self._get_context_window()
        self._initialize_tokenizer()

    def _initialize_tokenizer(self):
        """
        Initialize appropriate tokenizer based on model
        Sets self.encoding to a tiktoken encoding, or to None when no
        tokenizer could be initialized (token counts are then approximated).
        """
        try:
            # Determine tokenizer based on model
            if self.model_name.startswith(('gpt-3.5', 'gpt-4')):
                # OpenAI models have dedicated tiktoken encodings
                self.encoding = tiktoken.encoding_for_model(self.model_name)
                self.logger.debug(f"Initialized tiktoken for {self.model_name}")
            else:
                # Default for Ollama/local models: cl100k_base is a reasonable proxy
                self.encoding = tiktoken.get_encoding("cl100k_base")
                self.logger.debug(f"Using cl100k_base tokenizer for local model {self.model_name}")
        except Exception as e:
            self.logger.warning(f"Failed to initialize specific tokenizer: {repr(e)}, using approximation")
            self.encoding = None

    def _get_context_window(self) -> int:
        """
        Get context window size based on model
        Returns:
        --------
        { int } : Context window size in tokens
        """
        model_contexts = {"gpt-3.5-turbo" : 4096,
                          "mistral:7b"    : 8192,
                          }
        # Find matching model (substring match, case-insensitive)
        for model_pattern, context_size in model_contexts.items():
            if model_pattern in self.model_name.lower():
                return context_size
        # Default context window from configuration
        default_context = self.settings.CONTEXT_WINDOW
        self.logger.info(f"Using default context window {default_context} for model {self.model_name}")
        return default_context

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in text
        Arguments:
        ----------
        text { str } : Input text
        Returns:
        --------
        { int } : Number of tokens
        """
        if not text:
            return 0
        if self.encoding is not None:
            try:
                return len(self.encoding.encode(text))
            except Exception as e:
                # Fall through to the approximation below on tokenizer failure
                self.logger.warning(f"Tokenizer failed, using approximation: {repr(e)}")
        # Fallback approximation
        return self._approximate_token_count(text = text)

    def _approximate_token_count(self, text: str) -> int:
        """
        Approximate token count when tokenizer is unavailable
        Arguments:
        ----------
        text { str } : Input text
        Returns:
        --------
        { int } : Approximate token count
        """
        if not text:
            return 0
        # Use word-based approximation (more reliable than char-based)
        words = text.split()
        # English text averages ~1.3 tokens per word : (accounting for punctuation and subword tokenization)
        estimated_tokens = int(len(words) * 1.3)
        # Add 5% buffer for punctuation and special tokens
        estimated_tokens = int(estimated_tokens * 1.05)
        return estimated_tokens

    def count_message_tokens(self, messages: List[Dict]) -> int:
        """
        Count tokens in chat messages
        Arguments:
        ----------
        messages { list } : List of message dictionaries
        Returns:
        --------
        { int } : Total tokens in messages
        """
        if not messages:
            return 0
        total_tokens = 0
        for message in messages:
            # Count content tokens
            content = message.get('content', '')
            total_tokens += self.count_tokens(text = content)
            # Count role tokens (approximate)
            role = message.get('role', '')
            total_tokens += self.count_tokens(text = role)
            # Add overhead for message structure: Approximate overhead per message
            total_tokens += 5
        return total_tokens

    def fits_in_context(self, prompt: str, max_completion_tokens: int = 1000) -> bool:
        """
        Check if prompt fits in context window with room for completion
        Arguments:
        ----------
        prompt { str } : Prompt text
        max_completion_tokens { int } : Tokens to reserve for completion
        Returns:
        --------
        { bool } : True if prompt fits
        """
        prompt_tokens = self.count_tokens(text = prompt)
        total_required = prompt_tokens + max_completion_tokens
        return (total_required <= self.context_window)

    def truncate_to_fit(self, text: str, max_tokens: int, strategy: str = "end") -> str:
        """
        Truncate text to fit within token limit
        Arguments:
        ----------
        text { str } : Text to truncate
        max_tokens { int } : Maximum tokens allowed
        strategy { str } : Truncation strategy ('end', 'start', 'middle')
        Returns:
        --------
        { str } : Truncated text
        """
        current_tokens = self.count_tokens(text = text)
        if (current_tokens <= max_tokens):
            return text
        if (strategy == "end"):
            return self._truncate_from_end(text = text,
                                           max_tokens = max_tokens,
                                           )
        elif (strategy == "start"):
            return self._truncate_from_start(text = text,
                                             max_tokens = max_tokens,
                                             )
        elif (strategy == "middle"):
            return self._truncate_from_middle(text = text,
                                              max_tokens = max_tokens,
                                              )
        else:
            self.logger.warning(f"Unknown truncation strategy: {strategy}, using 'end'")
            return self._truncate_from_end(text = text,
                                           max_tokens = max_tokens,
                                           )

    def _truncate_from_end(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the end of the text (keeps the first max_tokens tokens)
        """
        if max_tokens <= 0:
            return ""
        if self.encoding is not None:
            tokens = self.encoding.encode(text)
            truncated_tokens = tokens[:max_tokens]
            return self.encoding.decode(truncated_tokens)
        # Approximate truncation
        words = text.split()
        # Conservative estimate
        target_words = int(max_tokens * 0.75)
        truncated_words = words[:target_words]
        return " ".join(truncated_words)

    def _truncate_from_start(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the start of the text (keeps the last max_tokens tokens)
        """
        # Guard: a non-positive budget must yield an empty string.
        # Without it, the negative-index slices below (seq[-0:]) would
        # return the ENTIRE sequence instead of nothing.
        if max_tokens <= 0:
            return ""
        if self.encoding is not None:
            tokens = self.encoding.encode(text)
            # Take from end
            truncated_tokens = tokens[-max_tokens:]
            return self.encoding.decode(truncated_tokens)
        # Approximate truncation
        words = text.split()
        target_words = int(max_tokens * 0.75)
        if target_words <= 0:
            return ""
        # Take from end
        truncated_words = words[-target_words:]
        return " ".join(truncated_words)

    def _truncate_from_middle(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the middle of the text (keeps the beginning and the end)
        """
        # Guard against seq[-0:] returning the whole sequence (see _truncate_from_start)
        if max_tokens <= 0:
            return ""
        separator = " [...] "
        if self.encoding is not None:
            tokens = self.encoding.encode(text)
            if (len(tokens) <= max_tokens):
                return text
            # Reserve room for the separator tokens so the final result
            # does not exceed max_tokens; fall back to end-truncation
            # when the budget is too small to keep content on both sides
            budget = max_tokens - len(self.encoding.encode(separator))
            if budget <= 0:
                return self._truncate_from_end(text = text, max_tokens = max_tokens)
            # Keep beginning and end, remove middle
            keep_start = budget // 3
            keep_end = budget - keep_start
            start_tokens = tokens[:keep_start]
            end_tokens = tokens[-keep_end:]
            return self.encoding.decode(start_tokens) + separator + self.encoding.decode(end_tokens)
        # Approximate truncation (word-based; separator overhead is negligible here)
        words = text.split()
        if (len(words) <= max_tokens):
            return text
        keep_start = max_tokens // 3
        keep_end = max_tokens - keep_start
        start_words = words[:keep_start]
        end_words = words[-keep_end:]
        return " ".join(start_words) + separator + " ".join(end_words)

    def calculate_max_completion_tokens(self, prompt: str, reserve_tokens: int = 100) -> int:
        """
        Calculate maximum completion tokens given prompt length
        Arguments:
        ----------
        prompt { str } : Prompt text
        reserve_tokens { int } : Tokens to reserve for safety
        Returns:
        --------
        { int } : Maximum completion tokens
        """
        prompt_tokens = self.count_tokens(text = prompt)
        available_tokens = self.context_window - prompt_tokens - reserve_tokens
        # Never report a negative budget
        return max(0, available_tokens)

    def optimize_context_usage(self, context: str, prompt: str, max_completion_tokens: int = 1000) -> str:
        """
        Optimize context to fit within context window
        Arguments:
        ----------
        context { str } : Context text
        prompt { str } : Prompt template (must contain a '{context}' placeholder)
        max_completion_tokens { int } : Tokens needed for completion
        Returns:
        --------
        { str } : Optimized context
        """
        # Measure the prompt template with an empty context slot
        total_prompt_tokens = self.count_tokens(text = prompt.format(context=""))
        available_for_context = self.context_window - total_prompt_tokens - max_completion_tokens
        if (available_for_context <= 0):
            self.logger.warning("Prompt too large for context window")
            return ""
        context_tokens = self.count_tokens(text = context)
        if (context_tokens <= available_for_context):
            return context
        # Truncate context to fit
        optimized_context = self.truncate_to_fit(context, available_for_context, strategy="end")
        reduction_pct = ((context_tokens - self.count_tokens(text = optimized_context)) / context_tokens) * 100
        self.logger.info(f"Context reduced by {reduction_pct:.1f}% to fit context window")
        return optimized_context

    def get_token_stats(self, text: str) -> Dict:
        """
        Get detailed token statistics
        Arguments:
        ----------
        text { str } : Input text
        Returns:
        --------
        { dict } : Token statistics
        """
        tokens = self.count_tokens(text = text)
        chars = len(text)
        words = len(text.split())
        return {"tokens"          : tokens,
                "characters"      : chars,
                "words"           : words,
                "chars_per_token" : chars / tokens if tokens > 0 else 0,
                "tokens_per_word" : tokens / words if words > 0 else 0,
                "context_window"  : self.context_window,
                "model"           : self.model_name,
                }
# Global token manager instance (lazily created by get_token_manager)
_token_manager = None

def get_token_manager(model_name: Optional[str] = None) -> TokenManager:
    """
    Get global token manager instance
    The cached instance is created on first use, and replaced when the
    caller asks for a different model than the cached one was built with.
    Arguments:
    ----------
    model_name { str } : Model name for tokenizer selection
    Returns:
    --------
    { TokenManager } : TokenManager instance
    """
    global _token_manager
    # Rebuild when missing, or when an explicit model name differs from the cached one
    if _token_manager is None or (model_name and _token_manager.model_name != model_name):
        _token_manager = TokenManager(model_name)
    return _token_manager
@handle_errors(error_type = TokenManagementError, log_error = True, reraise = False)
def count_tokens_safe(text: str, model_name: Optional[str] = None) -> int:
    """
    Safe token counting with error handling
    Arguments:
    ----------
    text { str } : Text to count tokens for
    model_name { str } : Model name for tokenizer
    Returns:
    --------
    { int } : Token count (0 on error)
    """
    try:
        manager = get_token_manager(model_name = model_name)
        return manager.count_tokens(text = text)
    except Exception:
        # Best-effort counting: any failure (tokenizer, settings, model lookup)
        # yields 0 rather than propagating; the decorator is a second safety net
        return 0