# ankigen/ankigen_core/agents/token_tracker.py
"""
Token usage tracking for OpenAI API calls using tiktoken.
Provides accurate token counting and cost estimation.
"""
import tiktoken
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
from ankigen_core.logging import logger
@dataclass
class TokenUsage:
    """Track token usage for a single request"""

    # Tokens consumed by the prompt (input) side of the request.
    prompt_tokens: int
    # Tokens generated in the completion (output).
    completion_tokens: int
    # Sum of prompt_tokens and completion_tokens.
    total_tokens: int
    # Cost in currency units as reported by the API/caller; None when no
    # cost was provided (the module does no local price estimation).
    estimated_cost: Optional[float]
    # Model identifier the request was made against.
    model: str
    # When the usage record was created (defaults to record-creation time).
    timestamp: datetime = field(default_factory=datetime.now)
class TokenTracker:
    """Track token usage and cost across multiple API requests.

    Maintains a chronological history of TokenUsage records together with
    running totals for tokens and cost. Costs are only accumulated when the
    API (or the caller) supplies one; no local price table is consulted.
    """

    def __init__(self):
        # One TokenUsage record per tracked request, in arrival order.
        self.usage_history: List[TokenUsage] = []
        # Running sum of provided costs; requests without a cost add nothing.
        self.total_cost = 0.0
        # Running sum of prompt + completion tokens across all requests.
        self.total_tokens = 0

    @staticmethod
    def _encoding_for(model: str):
        """Return the tiktoken encoding for *model*.

        Falls back to 'o200k_base' (the GPT-4o-family encoding) when tiktoken
        does not recognize the model name. Centralizes the lookup that was
        previously duplicated in both token-counting methods.
        """
        try:
            return tiktoken.encoding_for_model(model)
        except KeyError:
            return tiktoken.get_encoding("o200k_base")

    def count_tokens_for_messages(
        self, messages: List[Dict[str, str]], model: str
    ) -> int:
        """
        Count total tokens for a list of chat messages using tiktoken.

        Implements OpenAI's token counting algorithm for chat completions:
        - Each message adds 3 tokens for role/content/structure overhead
        - Message names add an additional token
        - The entire message list adds 3 tokens for the reply primer

        The encoding is selected based on the model:
        - Attempts to use model-specific encoding via tiktoken
        - Falls back to 'o200k_base' (GPT-4o-family encoding) for unknown
          models

        Args:
            messages: List of message dicts (each with 'role', 'content',
                optional 'name')
            model: OpenAI model identifier (e.g., 'gpt-4o')

        Returns:
            Total tokens required to send these messages to the model
        """
        encoding = self._encoding_for(model)
        tokens_per_message = 3  # per-message structural overhead
        tokens_per_name = 1  # extra token when a 'name' field is present
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        # Every reply is primed with an assistant-start wrapper (+3 tokens).
        num_tokens += 3
        return num_tokens

    def count_tokens_for_text(self, text: str, model: str) -> int:
        """Count tokens in a plain text string using the model's encoding."""
        return len(self._encoding_for(model).encode(text))

    def track_usage_from_response(
        self, response_data, model: str
    ) -> Optional[TokenUsage]:
        """Record usage from an OpenAI-style response object.

        Reads ``response_data.usage`` (``prompt_tokens`` /
        ``completion_tokens`` and, when present, a provider-supplied cost
        under ``total_cost`` or ``cost``).

        Returns:
            The recorded TokenUsage, or None when the response carries no
            usage attribute or extraction fails.
        """
        try:
            if not hasattr(response_data, "usage"):
                return None
            usage = response_data.usage
            prompt_tokens = usage.prompt_tokens
            completion_tokens = usage.completion_tokens
            # Prefer 'total_cost'; fall back to 'cost' for providers that
            # report under a different attribute name.
            actual_cost = getattr(usage, "total_cost", None)
            if actual_cost is None:
                actual_cost = getattr(usage, "cost", None)
            return self.track_usage(
                prompt_tokens, completion_tokens, model, actual_cost
            )
        except Exception as e:
            # Tracking is best-effort; never let it break the caller.
            logger.error(f"Failed to track usage from response: {e}")
            return None

    def track_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        model: str,
        actual_cost: Optional[float] = None,
    ) -> TokenUsage:
        """Record a single request's token counts and optional cost.

        Args:
            prompt_tokens: Tokens in the prompt/input.
            completion_tokens: Tokens in the completion/output.
            model: Model identifier the request was made against.
            actual_cost: API-provided cost, if any. Cost estimation was
                removed - only API-provided costs are accumulated.

        Returns:
            The TokenUsage record that was appended to the history.
        """
        total_tokens = prompt_tokens + completion_tokens
        final_cost = actual_cost  # Cost estimation removed - rely on API-provided costs
        usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            estimated_cost=final_cost,
            model=model,
        )
        self.usage_history.append(usage)
        # Explicit None check so a legitimate $0.00 cost is still recorded
        # and logged (a plain truthiness test would silently drop it).
        if final_cost is not None:
            self.total_cost += final_cost
        self.total_tokens += total_tokens
        if final_cost is not None:
            logger.info(
                f"πŸ’° Token usage - Model: {model}, Prompt: {prompt_tokens}, Completion: {completion_tokens}, Cost: ${final_cost:.4f}"
            )
        else:
            logger.info(
                f"πŸ’° Token usage - Model: {model}, Prompt: {prompt_tokens}, Completion: {completion_tokens}"
            )
        return usage

    def get_session_summary(self) -> Dict[str, Any]:
        """Summarize all usage tracked this session.

        Returns:
            Dict with 'total_requests', 'total_tokens', 'total_cost', and a
            per-model breakdown under 'by_model' (requests/tokens/cost each).
        """
        if not self.usage_history:
            return {
                "total_requests": 0,
                "total_tokens": 0,
                "total_cost": 0.0,
                "by_model": {},
            }
        by_model: Dict[str, Dict[str, Any]] = {}
        for usage in self.usage_history:
            stats = by_model.setdefault(
                usage.model, {"requests": 0, "tokens": 0, "cost": 0.0}
            )
            stats["requests"] += 1
            stats["tokens"] += usage.total_tokens
            # None check (not truthiness) so an explicit 0.0 cost is summed.
            if usage.estimated_cost is not None:
                stats["cost"] += usage.estimated_cost
        return {
            "total_requests": len(self.usage_history),
            "total_tokens": self.total_tokens,
            "total_cost": self.total_cost,
            "by_model": by_model,
        }

    def get_session_usage(self) -> Dict[str, Any]:
        """Alias for get_session_summary (kept for backward compatibility)."""
        return self.get_session_summary()

    def reset_session(self):
        """Clear the history and reset running totals to zero."""
        self.usage_history.clear()
        self.total_cost = 0.0
        self.total_tokens = 0
        logger.info("πŸ”„ Token usage tracking reset")

    def track_usage_from_agents_sdk(
        self, usage_dict: Dict[str, Any], model: str
    ) -> Optional[TokenUsage]:
        """Track usage from OpenAI Agents SDK usage format.

        Expects 'input_tokens' / 'output_tokens' keys; 'total_tokens' is used
        only as an emptiness check.

        Returns:
            The recorded TokenUsage, or None for empty/zero usage or on
            failure.
        """
        try:
            if not usage_dict or usage_dict.get("total_tokens", 0) == 0:
                return None
            prompt_tokens = usage_dict.get("input_tokens", 0)
            completion_tokens = usage_dict.get("output_tokens", 0)
            return self.track_usage(prompt_tokens, completion_tokens, model)
        except Exception as e:
            logger.error(f"Failed to track usage from agents SDK: {e}")
            return None
# Global token tracker instance
# Module-level singleton shared by all the convenience functions below so
# that usage accumulates across the whole process.
_global_tracker = TokenTracker()


def get_token_tracker() -> TokenTracker:
    """Return the process-wide singleton TokenTracker."""
    return _global_tracker
def track_agent_usage(
    prompt_text: str,
    completion_text: str,
    model: str,
    actual_cost: Optional[float] = None,
) -> TokenUsage:
    """Count tokens for a prompt/completion text pair and record them.

    Token counts are computed locally with tiktoken through the global
    tracker; *actual_cost*, when given, is recorded as-is.
    """
    tracker = get_token_tracker()
    n_prompt = tracker.count_tokens_for_text(prompt_text, model)
    n_completion = tracker.count_tokens_for_text(completion_text, model)
    return tracker.track_usage(n_prompt, n_completion, model, actual_cost)
def track_usage_from_openai_response(response_data, model: str) -> Optional[TokenUsage]:
    """Record token usage from a raw OpenAI response via the global tracker."""
    return get_token_tracker().track_usage_from_response(response_data, model)
def track_usage_from_agents_sdk(
    usage_dict: Dict[str, Any], model: str
) -> Optional[TokenUsage]:
    """Track usage from OpenAI Agents SDK usage format"""
    # Thin module-level convenience wrapper over the shared tracker.
    return get_token_tracker().track_usage_from_agents_sdk(usage_dict, model)