| """ |
| Token usage tracking for OpenAI API calls using tiktoken. |
| Provides accurate token counting and cost estimation. |
| """ |
|
|
| import tiktoken |
| from typing import Dict, List, Any, Optional |
| from dataclasses import dataclass, field |
| from datetime import datetime |
|
|
| from ankigen_core.logging import logger |
|
|
|
|
@dataclass
class TokenUsage:
    """Track token usage for a single request"""

    # Tokens consumed by the prompt (input) side of the request.
    prompt_tokens: int
    # Tokens generated in the completion (output).
    completion_tokens: int
    # Sum of prompt_tokens and completion_tokens.
    total_tokens: int
    # Cost in USD when the provider reported one; None when unknown.
    estimated_cost: Optional[float]
    # OpenAI model identifier the request was made against.
    model: str
    # Local wall-clock time the usage was recorded.
    timestamp: datetime = field(default_factory=datetime.now)
|
|
|
|
class TokenTracker:
    """Track token usage and estimated cost across multiple requests."""

    def __init__(self):
        # Chronological record of every tracked request this session.
        self.usage_history: List[TokenUsage] = []
        # Running session totals; cost only accumulates reported costs.
        self.total_cost = 0.0
        self.total_tokens = 0

    @staticmethod
    def _get_encoding(model: str):
        """Return the tiktoken encoding for ``model``.

        Falls back to ``o200k_base`` (the GPT-4o / o-series encoding) when
        tiktoken does not recognize the model name.
        """
        try:
            return tiktoken.encoding_for_model(model)
        except KeyError:
            return tiktoken.get_encoding("o200k_base")

    def count_tokens_for_messages(
        self, messages: List[Dict[str, str]], model: str
    ) -> int:
        """
        Count total tokens for a list of chat messages using tiktoken.

        Implements OpenAI's token counting algorithm for chat completions:
        - Each message adds 3 tokens for role/content/structure overhead
        - Message names add an additional token
        - The entire message list adds 3 tokens for conversation wrapper

        The encoding is selected based on the model:
        - Attempts to use model-specific encoding via tiktoken
        - Falls back to 'o200k_base' (the GPT-4o / o-series encoding) for
          unknown models

        Args:
            messages: List of message dicts (each with 'role', 'content', optional 'name')
            model: OpenAI model identifier (e.g., 'gpt-5.2', 'gpt-4o')

        Returns:
            Total tokens required to send these messages to the model
        """
        encoding = self._get_encoding(model)

        tokens_per_message = 3  # per-message role/content framing overhead
        tokens_per_name = 1  # extra token when a 'name' field is present

        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name

        # Conversation wrapper: every reply is primed with 3 extra tokens.
        num_tokens += 3
        return num_tokens

    def count_tokens_for_text(self, text: str, model: str) -> int:
        """Count the tokens in a plain text string for the given model."""
        return len(self._get_encoding(model).encode(text))

    def track_usage_from_response(
        self, response_data, model: str
    ) -> Optional[TokenUsage]:
        """Record usage from an OpenAI API response object.

        Reads ``response_data.usage`` (prompt/completion token counts and,
        when present, a provider-reported ``total_cost`` or ``cost``
        attribute).

        Args:
            response_data: Response object; must expose a ``usage`` attribute
                for anything to be recorded.
            model: Model identifier to attribute the usage to.

        Returns:
            The recorded TokenUsage, or None when the response carries no
            usage data or extraction fails.
        """
        try:
            if hasattr(response_data, "usage"):
                usage = response_data.usage
                prompt_tokens = usage.prompt_tokens
                completion_tokens = usage.completion_tokens

                # Some providers expose the actual cost directly on the
                # usage object, under differing attribute names.
                actual_cost = None
                if hasattr(usage, "total_cost"):
                    actual_cost = usage.total_cost
                elif hasattr(usage, "cost"):
                    actual_cost = usage.cost

                return self.track_usage(
                    prompt_tokens, completion_tokens, model, actual_cost
                )
            return None
        except Exception as e:
            # Best-effort accounting: never let tracking break the caller.
            logger.error(f"Failed to track usage from response: {e}")
            return None

    def track_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        model: str,
        actual_cost: Optional[float] = None,
    ) -> TokenUsage:
        """Record one request's token usage and update session totals.

        Args:
            prompt_tokens: Tokens sent in the prompt.
            completion_tokens: Tokens generated in the completion.
            model: Model identifier used for the request.
            actual_cost: Provider-reported cost in USD, when known.

        Returns:
            The TokenUsage entry appended to the session history.
        """
        total_tokens = prompt_tokens + completion_tokens

        final_cost = actual_cost

        usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            estimated_cost=final_cost,
            model=model,
        )

        self.usage_history.append(usage)
        # Explicit None check so a legitimate $0.00 reported cost is still
        # counted and logged rather than silently treated as "no cost".
        if final_cost is not None:
            self.total_cost += final_cost
        self.total_tokens += total_tokens

        logger.info(
            f"π° Token usage - Model: {model}, Prompt: {prompt_tokens}, Completion: {completion_tokens}, Cost: ${final_cost:.4f}"
            if final_cost is not None
            else f"π° Token usage - Model: {model}, Prompt: {prompt_tokens}, Completion: {completion_tokens}"
        )

        return usage

    def get_session_summary(self) -> Dict[str, Any]:
        """Summarize session usage: overall totals plus a per-model breakdown.

        Returns:
            Dict with 'total_requests', 'total_tokens', 'total_cost', and
            'by_model' (per-model dicts of 'requests', 'tokens', 'cost').
        """
        if not self.usage_history:
            return {
                "total_requests": 0,
                "total_tokens": 0,
                "total_cost": 0.0,
                "by_model": {},
            }

        by_model: Dict[str, Dict[str, Any]] = {}
        for usage in self.usage_history:
            if usage.model not in by_model:
                by_model[usage.model] = {"requests": 0, "tokens": 0, "cost": 0.0}
            by_model[usage.model]["requests"] += 1
            by_model[usage.model]["tokens"] += usage.total_tokens
            if usage.estimated_cost is not None:
                by_model[usage.model]["cost"] += usage.estimated_cost

        return {
            "total_requests": len(self.usage_history),
            "total_tokens": self.total_tokens,
            "total_cost": self.total_cost,
            "by_model": by_model,
        }

    def get_session_usage(self) -> Dict[str, Any]:
        """Alias for get_session_summary (kept for backward compatibility)."""
        return self.get_session_summary()

    def reset_session(self):
        """Clear the usage history and reset session totals to zero."""
        self.usage_history.clear()
        self.total_cost = 0.0
        self.total_tokens = 0
        logger.info("π Token usage tracking reset")

    def track_usage_from_agents_sdk(
        self, usage_dict: Dict[str, Any], model: str
    ) -> Optional[TokenUsage]:
        """Track usage from OpenAI Agents SDK usage format.

        Expects 'input_tokens' / 'output_tokens' keys and, optionally,
        'total_tokens'.

        Returns:
            The recorded TokenUsage, or None when the dict is empty,
            represents zero usage, or extraction fails.
        """
        try:
            if not usage_dict:
                return None

            prompt_tokens = usage_dict.get("input_tokens", 0)
            completion_tokens = usage_dict.get("output_tokens", 0)

            # Fall back to the input/output sum when the SDK payload omits
            # 'total_tokens'; previously a missing key was treated as zero
            # usage and valid data was silently dropped.
            total = usage_dict.get(
                "total_tokens", prompt_tokens + completion_tokens
            )
            if total == 0:
                return None

            return self.track_usage(prompt_tokens, completion_tokens, model)
        except Exception as e:
            logger.error(f"Failed to track usage from agents SDK: {e}")
            return None
|
|
|
|
| |
# Module-level singleton so all callers share one session of usage data.
_global_tracker = TokenTracker()


def get_token_tracker() -> TokenTracker:
    """Return the process-wide shared TokenTracker instance."""
    return _global_tracker
|
|
|
|
def track_agent_usage(
    prompt_text: str,
    completion_text: str,
    model: str,
    actual_cost: Optional[float] = None,
) -> TokenUsage:
    """Count tokens for an agent exchange and record them globally.

    Token counts are derived from the raw prompt/completion text with
    tiktoken, then recorded on the shared tracker together with the
    optional provider-reported cost.

    Args:
        prompt_text: Full prompt text sent to the model.
        completion_text: Full completion text returned by the model.
        model: Model identifier used for the exchange.
        actual_cost: Provider-reported cost in USD, when known.

    Returns:
        The TokenUsage entry recorded on the global tracker.
    """
    shared_tracker = get_token_tracker()
    n_in = shared_tracker.count_tokens_for_text(prompt_text, model)
    n_out = shared_tracker.count_tokens_for_text(completion_text, model)
    return shared_tracker.track_usage(n_in, n_out, model, actual_cost)
|
|
|
|
def track_usage_from_openai_response(response_data, model: str) -> Optional[TokenUsage]:
    """Record usage from an OpenAI response object on the global tracker.

    Thin convenience wrapper around
    TokenTracker.track_usage_from_response; returns its result
    (a TokenUsage, or None when the response has no usage data).
    """
    return get_token_tracker().track_usage_from_response(response_data, model)
|
|
|
|
def track_usage_from_agents_sdk(
    usage_dict: Dict[str, Any], model: str
) -> Optional[TokenUsage]:
    """Track usage from OpenAI Agents SDK usage format"""
    # Delegate to the shared tracker's SDK-format handler.
    return get_token_tracker().track_usage_from_agents_sdk(usage_dict, model)
|
|