Spaces:
Running
Running
| # DEPENDENCIES | |
| import tiktoken | |
| from typing import List | |
| from typing import Dict | |
| from typing import Optional | |
| from config.settings import get_settings | |
| from config.logging_config import get_logger | |
| from utils.error_handler import handle_errors | |
| from utils.error_handler import TokenManagementError | |
# Setup Settings and Logging
# Both objects are created once, when this module is first imported.
# NOTE(review): the module-level `settings` is not referenced in the visible
# code (TokenManager.__init__ calls get_settings() itself); confirm whether
# other parts of the file use it before removing.
settings = get_settings()
# Module logger; TokenManager instances reuse it as self.logger
logger = get_logger(__name__)
class TokenManager:
    """
    Token management for LLM context windows: Handles token counting, context window management, and optimization for different LLM providers (Ollama, OpenAI)

    Token counts come from a tiktoken encoding when one could be initialized.
    Otherwise, or if the encoder raises, a word-based approximation is used.
    The truncation helpers use that same approximation, so the text they return
    fits under this class's own count_tokens().
    """
    # Word-based approximation: English text averages ~1.3 tokens per word
    # (punctuation and subword tokenization), plus a 5% buffer for special tokens
    _TOKENS_PER_WORD = 1.3
    _APPROX_BUFFER   = 1.05

    # Marker placed where middle truncation removed text
    _MIDDLE_MARKER   = " [...] "

    def __init__(self, model_name: str = None):
        """
        Initialize token manager

        Arguments:
        ----------
            model_name { str } : Model name for tokenizer selection; when empty, settings.OLLAMA_MODEL is used
        """
        self.logger         = logger
        self.settings       = get_settings()
        self.model_name     = model_name or self.settings.OLLAMA_MODEL
        self.encoding       = None
        self.context_window = self._get_context_window()
        self._initialize_tokenizer()

    def _initialize_tokenizer(self):
        """
        Initialize appropriate tokenizer based on model

        Models whose name starts with 'gpt-3.5' or 'gpt-4' get tiktoken's
        model-specific encoding. Every other model uses cl100k_base as a
        stand-in; presumably this is only an estimate for models with their own
        vocabularies. Any failure leaves self.encoding as None, which switches
        the class to the word-based approximation.
        """
        try:
            # Determine tokenizer based on model
            if self.model_name.startswith(('gpt-3.5', 'gpt-4')):
                # OpenAI models
                self.encoding = tiktoken.encoding_for_model(self.model_name)
                self.logger.debug(f"Initialized tiktoken for {self.model_name}")

            else:
                # Default for Ollama/local models
                self.encoding = tiktoken.get_encoding("cl100k_base")
                self.logger.debug(f"Using cl100k_base tokenizer for local model {self.model_name}")

        except Exception as e:
            self.logger.warning(f"Failed to initialize specific tokenizer: {repr(e)}, using approximation")
            self.encoding = None

    def _get_context_window(self) -> int:
        """
        Get context window size based on model

        The first known pattern contained in the lower-cased model name wins;
        otherwise settings.CONTEXT_WINDOW is used.

        Returns:
        --------
            { int } : Context window size in tokens
        """
        model_contexts = {"gpt-3.5-turbo" : 4096,
                          "mistral:7b"    : 8192,
                         }

        # Find matching model
        for model_pattern, context_size in model_contexts.items():
            if model_pattern in self.model_name.lower():
                return context_size

        # Default context window
        default_context = self.settings.CONTEXT_WINDOW
        self.logger.info(f"Using default context window {default_context} for model {self.model_name}")

        return default_context

    def _safe_encode(self, text: str) -> Optional[List[int]]:
        """
        Encode text with the tokenizer, or return None when no tokenizer is
        available or it raises (callers then use the word approximation)

        Arguments:
        ----------
            text { str } : Input text

        Returns:
        --------
            { list | None } : Token ids, or None to signal "use approximation"
        """
        if self.encoding is None:
            return None

        try:
            return self.encoding.encode(text)

        except Exception as e:
            self.logger.warning(f"Tokenizer failed, using approximation: {repr(e)}")
            return None

    @staticmethod
    def _tail(items: List, count: int) -> List:
        """
        Return the last `count` items; empty for count <= 0

        Unlike items[-count:], this does not return the whole list when count is 0.
        """
        return items[max(0, len(items) - count):]

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in text

        Arguments:
        ----------
            text { str } : Input text

        Returns:
        --------
            { int } : Number of tokens (0 for empty/None text)
        """
        if not text:
            return 0

        tokens = self._safe_encode(text)

        if tokens is not None:
            return len(tokens)

        # Fallback approximation
        return self._approximate_token_count(text = text)

    def _approx_tokens_for_words(self, word_count: int) -> int:
        """
        Approximate token count for `word_count` whitespace-separated words
        """
        estimated_tokens = int(word_count * self._TOKENS_PER_WORD)

        # Add 5% buffer for punctuation and special tokens
        return int(estimated_tokens * self._APPROX_BUFFER)

    def _approximate_token_count(self, text: str) -> int:
        """
        Approximate token count when tokenizer is unavailable

        Arguments:
        ----------
            text { str } : Input text

        Returns:
        --------
            { int } : Approximate token count
        """
        if not text:
            return 0

        # Use word-based approximation (more reliable than char-based)
        return self._approx_tokens_for_words(len(text.split()))

    def _approx_word_budget(self, max_tokens: int) -> int:
        """
        Largest number of words whose approximate token count is <= max_tokens

        Derived from the same formula as _approximate_token_count, so text cut
        to this many words is guaranteed to fit under count_tokens() when no
        tokenizer is in use.

        Arguments:
        ----------
            max_tokens { int } : Token budget

        Returns:
        --------
            { int } : Word budget (0 for non-positive token budgets)
        """
        if max_tokens <= 0:
            return 0

        words = int(max_tokens / (self._TOKENS_PER_WORD * self._APPROX_BUFFER))

        # The closed-form estimate can be off by a word because of the two int()
        # truncations; the count is monotone in words, so nudge to the exact edge
        while self._approx_tokens_for_words(words + 1) <= max_tokens:
            words += 1

        while words > 0 and self._approx_tokens_for_words(words) > max_tokens:
            words -= 1

        return words

    def count_message_tokens(self, messages: List[Dict]) -> int:
        """
        Count tokens in chat messages

        Arguments:
        ----------
            messages { list } : List of message dictionaries (reads 'content' and 'role')

        Returns:
        --------
            { int } : Total tokens in messages
        """
        if not messages:
            return 0

        total_tokens = 0

        for message in messages:
            # Count content tokens
            content       = message.get('content', '')
            total_tokens += self.count_tokens(text = content)

            # Count role tokens (approximate)
            role          = message.get('role', '')
            total_tokens += self.count_tokens(text = role)

            # Add overhead for message structure: Approximate overhead per message
            total_tokens += 5

        return total_tokens

    def fits_in_context(self, prompt: str, max_completion_tokens: int = 1000) -> bool:
        """
        Check if prompt fits in context window with room for completion

        Arguments:
        ----------
            prompt                { str } : Prompt text

            max_completion_tokens { int } : Tokens to reserve for completion

        Returns:
        --------
                    { bool }              : True if prompt fits
        """
        prompt_tokens  = self.count_tokens(text = prompt)
        total_required = prompt_tokens + max_completion_tokens

        return (total_required <= self.context_window)

    def truncate_to_fit(self, text: str, max_tokens: int, strategy: str = "end") -> str:
        """
        Truncate text to fit within token limit

        Arguments:
        ----------
            text       { str } : Text to truncate

            max_tokens { int } : Maximum tokens allowed (negative values are treated as 0)

            strategy   { str } : Truncation strategy ('end', 'start', 'middle'); unknown values fall back to 'end'

        Returns:
        --------
                { str }        : Truncated text (unchanged if it already fits)
        """
        # A negative budget would turn tokens[:max_tokens] into "drop the last N"
        max_tokens     = max(0, max_tokens)
        current_tokens = self.count_tokens(text = text)

        if (current_tokens <= max_tokens):
            return text

        truncators = {"end"    : self._truncate_from_end,
                      "start"  : self._truncate_from_start,
                      "middle" : self._truncate_from_middle,
                     }

        truncator  = truncators.get(strategy)

        if truncator is None:
            self.logger.warning(f"Unknown truncation strategy: {strategy}, using 'end'")
            truncator = self._truncate_from_end

        return truncator(text       = text,
                         max_tokens = max_tokens,
                        )

    def _truncate_from_end(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the end of the text (keep the first max_tokens tokens)

        On the approximation path words are re-joined with single spaces, so
        original whitespace such as newlines is not preserved.
        """
        max_tokens = max(0, max_tokens)
        tokens     = self._safe_encode(text)

        if tokens is not None:
            return self.encoding.decode(tokens[:max_tokens])

        # Approximate truncation
        words = text.split()

        return " ".join(words[:self._approx_word_budget(max_tokens)])

    def _truncate_from_start(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the start of the text (keep the last max_tokens tokens)

        On the approximation path words are re-joined with single spaces.
        """
        tokens = self._safe_encode(text)

        if tokens is not None:
            # _tail() returns nothing for a 0 budget, whereas tokens[-0:] is everything
            return self.encoding.decode(self._tail(tokens, max_tokens))

        # Approximate truncation
        words = text.split()

        return " ".join(self._tail(words, self._approx_word_budget(max_tokens)))

    def _truncate_from_middle(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the middle of the text: keep roughly the first third and
        last two thirds of the budget, joined by the " [...] " marker

        The marker's own cost is taken out of the budget. If the marker alone
        does not fit, plain end truncation is used instead.
        """
        tokens = self._safe_encode(text)

        if tokens is not None:
            if (len(tokens) <= max_tokens):
                return text

            budget = max_tokens - len(self.encoding.encode(self._MIDDLE_MARKER))

            if (budget <= 0):
                return self._truncate_from_end(text       = text,
                                               max_tokens = max_tokens,
                                              )

            # Keep beginning and end, remove middle
            keep_start = budget // 3
            keep_end   = budget - keep_start

            return self.encoding.decode(tokens[:keep_start]) + self._MIDDLE_MARKER + self.encoding.decode(self._tail(tokens, keep_end))

        # Approximate truncation: compare in (approximate) tokens, not raw words
        if (self._approximate_token_count(text = text) <= max_tokens):
            return text

        words  = text.split()

        # "[...]" is one extra word under the approximation, so reserve it
        budget = self._approx_word_budget(max_tokens) - 1

        if (budget <= 0):
            return self._truncate_from_end(text       = text,
                                           max_tokens = max_tokens,
                                          )

        keep_start = budget // 3
        keep_end   = budget - keep_start

        return " ".join(words[:keep_start]) + self._MIDDLE_MARKER + " ".join(self._tail(words, keep_end))

    def calculate_max_completion_tokens(self, prompt: str, reserve_tokens: int = 100) -> int:
        """
        Calculate maximum completion tokens given prompt length

        Arguments:
        ----------
            prompt         { str } : Prompt text

            reserve_tokens { int } : Tokens to reserve for safety

        Returns:
        --------
                { int }            : Maximum completion tokens (never negative)
        """
        prompt_tokens    = self.count_tokens(text = prompt)
        available_tokens = self.context_window - prompt_tokens - reserve_tokens

        return max(0, available_tokens)

    def optimize_context_usage(self, context: str, prompt: str, max_completion_tokens: int = 1000) -> str:
        """
        Optimize context to fit within context window

        Arguments:
        ----------
            context               { str } : Context text

            prompt                { str } : Prompt template containing a {context} placeholder

            max_completion_tokens { int } : Tokens needed for completion

        Returns:
        --------
                    { str }               : Optimized context ("" when the prompt alone leaves no room)
        """
        try:
            prompt_without_context = prompt.format(context = "")

        except (KeyError, IndexError, ValueError) as e:
            # Template has other placeholders or literal braces: count it as written
            self.logger.warning(f"Could not format prompt template, counting it verbatim: {repr(e)}")
            prompt_without_context = prompt

        total_prompt_tokens   = self.count_tokens(text = prompt_without_context)
        available_for_context = self.context_window - total_prompt_tokens - max_completion_tokens

        if (available_for_context <= 0):
            self.logger.warning("Prompt too large for context window")
            return ""

        context_tokens = self.count_tokens(text = context)

        if (context_tokens <= available_for_context):
            return context

        # Truncate context to fit
        optimized_context = self.truncate_to_fit(text       = context,
                                                 max_tokens = available_for_context,
                                                 strategy   = "end",
                                                )

        # context_tokens > available_for_context > 0 here, so no division by zero
        reduction_pct     = ((context_tokens - self.count_tokens(text = optimized_context)) / context_tokens) * 100

        self.logger.info(f"Context reduced by {reduction_pct:.1f}% to fit context window")

        return optimized_context

    def get_token_stats(self, text: str) -> Dict:
        """
        Get detailed token statistics

        Arguments:
        ----------
            text { str } : Input text (None is treated as empty)

        Returns:
        --------
            { dict } : Token statistics
        """
        text   = text or ""
        tokens = self.count_tokens(text = text)
        chars  = len(text)
        words  = len(text.split())

        return {"tokens"          : tokens,
                "characters"      : chars,
                "words"           : words,
                "chars_per_token" : chars / tokens if tokens > 0 else 0,
                "tokens_per_word" : tokens / words if words > 0 else 0,
                "context_window"  : self.context_window,
                "model"           : self.model_name,
               }
# Lazily created TokenManager shared by the module-level helpers
_token_manager = None


def get_token_manager(model_name: str = None) -> TokenManager:
    """
    Return the shared TokenManager, building it on first use

    The cached instance is reused unless a non-empty model_name is requested
    that differs from the cached manager's model, in which case a new manager
    replaces it. Calling with no model_name returns whatever manager is cached.

    Arguments:
    ----------
        model_name { str } : Model name for tokenizer selection

    Returns:
    --------
        { TokenManager } : TokenManager instance
    """
    global _token_manager

    cached = _token_manager

    # Reuse the cache when it exists and no different model was asked for
    if cached is not None and (not model_name or cached.model_name == model_name):
        return cached

    _token_manager = TokenManager(model_name)

    return _token_manager
def count_tokens_safe(text: str, model_name: str = None) -> int:
    """
    Safe token counting with error handling

    Best-effort: any failure (building the shared manager or counting) is
    logged and reported as 0 instead of being raised to the caller.

    Arguments:
    ----------
        text       { str } : Text to count tokens for

        model_name { str } : Model name for tokenizer

    Returns:
    --------
            { int }        : Token count (0 on error)
    """
    try:
        manager = get_token_manager(model_name = model_name)
        return manager.count_tokens(text = text)

    except Exception as e:
        # Deliberately non-raising, but don't hide the failure entirely
        logger.warning(f"Token counting failed for model {model_name!r}, returning 0: {repr(e)}")
        return 0