from typing import List, Literal, Optional

import tiktoken
import anthropic

# Gemini requires the Vertex AI SDK
try:
    from vertexai.preview import tokenization as vertex_tokenization
except ImportError:
    vertex_tokenization = None

# Mistral requires the SentencePiece tokenizer
try:
    import sentencepiece as spm
except ImportError:
    spm = None


# ---------------------------
# Individual Token Counters
# ---------------------------

def count_tokens_openai(text: str, model_name: str) -> int:
    """Count tokens locally with tiktoken; no API call required."""
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Unknown model: fall back to the cl100k_base encoding
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))


def count_tokens_anthropic(text: str, model_name: str) -> int:
    """Count tokens via the Anthropic API (requires ANTHROPIC_API_KEY)."""
    try:
        client = anthropic.Anthropic()
        response = client.messages.count_tokens(
            model=model_name,
            messages=[{"role": "user", "content": text}],
        )
        # The SDK returns a response object, not a dict
        return response.input_tokens
    except Exception as e:
        raise RuntimeError(f"Anthropic token counting failed: {e}")


def count_tokens_gemini(text: str, model_name: str) -> int:
    """Count tokens locally with the Vertex AI tokenizer."""
    if vertex_tokenization is None:
        raise ImportError(
            "Please install vertexai: pip install google-cloud-aiplatform[tokenization]"
        )
    try:
        # Note: local tokenization is only available for supported Gemini models
        tokenizer = vertex_tokenization.get_tokenizer_for_model(model_name)
        result = tokenizer.count_tokens(text)
        return result.total_tokens
    except Exception as e:
        raise RuntimeError(f"Gemini token counting failed: {e}")


def count_tokens_mistral(text: str) -> int:
    """Count tokens locally with a SentencePiece tokenizer model."""
    if spm is None:
        raise ImportError("Please install sentencepiece: pip install sentencepiece")
    try:
        sp = spm.SentencePieceProcessor()
        # IMPORTANT: you must provide the correct path to the tokenizer model file
        sp.load("mistral_tokenizer.model")
        tokens = sp.encode(text, out_type=str)
        return len(tokens)
    except Exception as e:
        raise RuntimeError(f"Mistral token counting failed: {e}")


# ---------------------------
# Unified Token Counter
# ---------------------------

def count_tokens(
    text: str,
    model_name: str,
    provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"],
) -> int:
    """Dispatch to the provider-specific token counter."""
    if provider == "OpenAI":
        return count_tokens_openai(text, model_name)
    elif provider == "Anthropic":
        return count_tokens_anthropic(text, model_name)
    elif provider == "Gemini":
        return count_tokens_gemini(text, model_name)
    elif provider == "Mistral":
        return count_tokens_mistral(text)
    else:
        raise ValueError(f"Unsupported provider: {provider}")


def get_token_limit_for_model(model_name: Optional[str], provider: Optional[str]) -> int:
    """Return the context-window size in tokens.

    Example values; verify against current provider documentation.
    """
    provider = (provider or "").lower()  # accept "OpenAI" and "openai" alike
    model_name = model_name or ""
    if provider == "openai":
        if "gpt-4.1-nano" in model_name:
            return 1047576
        elif "gpt-4o-mini" in model_name:
            return 128000
    elif provider == "anthropic":
        if "claude-3-opus" in model_name:
            return 200000
        elif "claude-3-sonnet" in model_name:
            return 200000
    elif provider == "gemini":
        if "gemini-2.0-flash-lite" in model_name:
            return 1048576
        elif "gemini-1.5-flash" in model_name:
            return 1048576
    elif provider == "mistral":
        if "mistral-small" in model_name:
            return 32000
        elif "mistral-medium" in model_name:
            return 32000
    return 8000  # default fallback


def estimate_avg_tokens_per_doc(
    docs: List[str],
    model_name: str,
    provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"],
) -> float:
    """
    Estimate the average number of tokens per document for the given model.

    Args:
        docs (List[str]): List of documents.
        model_name (str): Model name.
        provider (Literal): LLM provider.

    Returns:
        float: Average number of tokens per document.
    """
    if not docs:
        return 0.0
    token_counts = [count_tokens(doc, model_name, provider) for doc in docs]
    return sum(token_counts) / len(token_counts)


def estimate_max_k(
    docs: List[str],
    model_name: str,
    provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"],
    margin_ratio: float = 0.1,
) -> int:
    """
    Estimate the maximum number of documents that fit in the context window,
    reserving a safety margin (margin_ratio) for the rest of the prompt.

    Returns:
        int: Estimated K.
    """
    if not docs:
        return 0
    max_tokens = get_token_limit_for_model(model_name, provider)
    margin = int(max_tokens * margin_ratio)
    available_tokens = max_tokens - margin
    avg_tokens_per_doc = estimate_avg_tokens_per_doc(docs, model_name, provider)
    if avg_tokens_per_doc == 0:
        return 0
    return min(len(docs), int(available_tokens // avg_tokens_per_doc))


def estimate_max_k_fast(
    docs: List[str],
    margin_ratio: float = 0.1,
    max_tokens: int = 8000,
    model_name: str = "gpt-3.5-turbo",
) -> int:
    """Cheaper variant: sample at most the first 20 documents and tokenize
    with tiktoken only (OpenAI-style encoding, regardless of provider)."""
    if not docs:
        return 0
    enc = tiktoken.encoding_for_model(model_name)
    avg_len = sum(len(enc.encode(doc)) for doc in docs[:20]) / min(20, len(docs))
    if avg_len == 0:
        return 0
    margin = int(max_tokens * margin_ratio)
    available = max_tokens - margin
    return min(len(docs), int(available // avg_len))


def estimate_k_max_from_word_stats(
    avg_words_per_doc: float,
    margin_ratio: float = 0.1,
    avg_tokens_per_word: float = 1.3,
    model_name: Optional[str] = None,
    provider: Optional[str] = None,
) -> int:
    """Tokenizer-free estimate: assume ~1.3 tokens per English word."""
    model_token_limit = get_token_limit_for_model(model_name, provider)
    effective_limit = int(model_token_limit * (1 - margin_ratio))
    est_tokens_per_doc = avg_words_per_doc * avg_tokens_per_word
    if est_tokens_per_doc == 0:
        return 0
    return int(effective_limit // est_tokens_per_doc)
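

# ---------------------------
# Usage sketch
# ---------------------------
# A minimal, illustrative example of how the helpers above compose. The
# documents and model names below are placeholders, not values from any real
# corpus; the OpenAI path is used because tiktoken runs locally and needs no
# API credentials, unlike the Anthropic counter.

if __name__ == "__main__":
    sample_docs = [
        "Retrieval-augmented generation stuffs retrieved documents into the prompt.",
        "Each document consumes context-window tokens, so K must be bounded.",
        "Token counts vary by tokenizer, which is why counting is per provider.",
    ]

    # Exact per-document counting through the unified dispatcher.
    first = count_tokens(sample_docs[0], "gpt-4o-mini", "OpenAI")
    print(f"Tokens in first doc: {first}")

    # Average tokens per doc, then the largest K that fits with a 10% margin.
    avg = estimate_avg_tokens_per_doc(sample_docs, "gpt-4o-mini", "OpenAI")
    k = estimate_max_k(sample_docs, "gpt-4o-mini", "OpenAI", margin_ratio=0.1)
    print(f"Avg tokens/doc: {avg:.1f}, estimated max K: {k}")

    # Cheapest path: no tokenizer at all, only word-count statistics.
    k_rough = estimate_k_max_from_word_stats(
        avg_words_per_doc=40.0, model_name="gpt-4o-mini", provider="OpenAI"
    )
    print(f"Rough K from word stats: {k_rough}")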