# Restored test.py (commit 4770b76)
"""
Podcast Assistant Test Runner - Debug mode with message flow and metrics.
Usage:
uv run python test.py
"""
import os
import time
from datetime import datetime
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.chat_models import init_chat_model
from search_podcasts import search_podcasts
# Load environment variables from the .env file one directory above this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
load_dotenv(dotenv_path=os.path.join(SCRIPT_DIR, '..', '.env'))

# OpenAI pricing (per 1M tokens) - update as needed.
# Rates are stored as USD per single token (quoted per-1M price / 1_000_000).
#   "input"        — uncached prompt tokens
#   "input_cached" — prompt tokens served from the provider's prompt cache
#                    (discount varies by model: 50%–90% off the input rate)
#   "output"       — completion tokens
PRICING = {
    "gpt-5-nano": {
        "input": 0.05 / 1_000_000,
        "input_cached": 0.005 / 1_000_000,
        "output": 0.40 / 1_000_000,
    },
    "gpt-5-mini": {
        "input": 0.25 / 1_000_000,
        "input_cached": 0.025 / 1_000_000,
        "output": 2.00 / 1_000_000,
    },
    "gpt-4.1-nano": {
        "input": 0.10 / 1_000_000,
        "input_cached": 0.025 / 1_000_000,
        "output": 0.40 / 1_000_000,
    },
    "gpt-4.1-mini": {
        "input": 0.40 / 1_000_000,
        "input_cached": 0.10 / 1_000_000,
        "output": 1.60 / 1_000_000,
    },
    "gpt-4o-mini": {
        "input": 0.15 / 1_000_000,
        "input_cached": 0.075 / 1_000_000,
        "output": 0.60 / 1_000_000,
    },
    "gpt-4o": {
        "input": 2.50 / 1_000_000,
        "input_cached": 1.25 / 1_000_000,
        "output": 10.00 / 1_000_000,
    },
}

# Default model used by run_test() when no model is passed explicitly.
MODEL = "gpt-4.1-mini"
def compare_models(conversation_index: int = 0, models: list = None):
    """Run the same scripted conversation through several models and print a comparison.

    Falls back to a default trio of low-cost models when `models` is None.
    Each model is exercised via run_test(..., return_metrics=True).
    """
    if models is None:
        models = ["gpt-5-nano", "gpt-4.1-nano", "gpt-4o-mini"]

    results = []
    for model in models:
        print(f"\n{'#' * 80}")
        print(f" TESTING: {model} ".center(80, "#"))
        print("#" * 80)
        metrics = run_test(conversation_index=conversation_index, model=model, return_metrics=True)
        results.append({"model": model, **metrics})

    # Comparison table across all tested models.
    print("\n" + "=" * 80)
    print(" MODEL COMPARISON ".center(80, "="))
    print("=" * 80)
    print("\n📊 Summary:")
    print(f" {'Model':<15} {'Input':>10} {'Output':>10} {'Cost':>12} {'Time':>10}")
    print(f" {'-'*15} {'-'*10} {'-'*10} {'-'*12} {'-'*10}")
    for entry in results:
        print(f" {entry['model']:<15} {entry['total_api_input']:>10,} {entry['total_api_output']:>10,} ${entry['total_cost']:>11.6f} {entry['total_time']:>9.2f}s")

    # Highlight the winners on cost and end-to-end latency.
    cheapest = min(results, key=lambda entry: entry['total_cost'])
    fastest = min(results, key=lambda entry: entry['total_time'])
    print(f"\n 💰 Cheapest: {cheapest['model']} (${cheapest['total_cost']:.6f})")
    print(f" ⚡ Fastest: {fastest['model']} ({fastest['total_time']:.2f}s)")
def get_usage_from_response(response) -> dict:
    """Extract token-usage counts from a LangChain chat response.

    Reads `response.usage_metadata` (populated by the OpenAI API) and returns
    a dict with keys "input", "output", and "cache_read" — all zero when the
    response carries no usage metadata.
    """
    metadata = getattr(response, 'usage_metadata', None)
    if not metadata:
        return {"input": 0, "output": 0, "cache_read": 0}
    details = metadata.get("input_token_details", {})
    return {
        "input": metadata.get("input_tokens", 0),
        "output": metadata.get("output_tokens", 0),
        "cache_read": details.get("cache_read", 0),
    }
def calculate_cost(input_tokens: int, output_tokens: int, cache_read: int = 0, model: str = MODEL) -> float:
    """Return the USD cost of one API call.

    Cached prompt tokens (`cache_read`) are billed at the model's discounted
    cached-input rate; the remaining input tokens at the full input rate.
    Unknown model names fall back to gpt-4o-mini pricing.
    """
    rates = PRICING.get(model, PRICING["gpt-4o-mini"])
    uncached_input = input_tokens - cache_read
    input_cost = uncached_input * rates["input"] + cache_read * rates["input_cached"]
    output_cost = output_tokens * rates["output"]
    return input_cost + output_cost
def truncate(text: str, max_chars: int = 150) -> str:
    """Return `text` unchanged if it fits, else the first `max_chars` chars plus "..."."""
    return text if len(text) <= max_chars else text[:max_chars] + "..."
# =============================================================================
# Test Conversations (multi-turn)
# =============================================================================
# Each entry is one scripted conversation: a list of user turns sent in order.
# The position in this list is the `conversation_index` accepted by run_test().
TEST_CONVERSATIONS = [
    # 0: Multi-turn with tool use → follow-up may trigger another search → context-based recommendation
    [
        "What is the future of AI and AGI according to experts?",
        "What benefits and risks did they mention?",
        "Which episode should I watch first?",
    ],
    # 1: Tool use → context-based summary (tests if AI reuses previous results instead of searching again)
    [
        "How should young people approach their career and education?",
        "Can you summarize that in 3 actionable points?"
    ],
    # 2: Multiple tool calls across turns (each question may require fresh search)
    [
        "What habits and routines do high performers follow?",
        "What about sleep habits?",
        "Any book recommendations from them?",
    ],
    # 3: Off-topic question (tests if AI correctly skips tool use)
    [
        "What is 2 + 2?",
    ],
    # 4: Single turn with one tool call (baseline for model comparison)
    [
        "What are the best books or films that influenced successful people?",
    ],
]
# =============================================================================
# Message Formatting
# =============================================================================
def print_header(title: str, width: int = 80):
    """Print `title` centered in a line of '=' characters, `width` chars wide total."""
    left = (width - len(title) - 2) // 2
    right = width - left - len(title) - 2
    print(f"{'=' * left} {title} {'=' * right}")
def print_message(role: str, content: str, tool_calls: list = None, full: bool = False):
    """Print one conversation message under a role banner.

    Content is truncated unless full=True. Tool calls, when present, are
    rendered one per line as name(arg=value, ...).
    """
    banner = {
        "system": "System Message",
        "human": "Human Message",
        "ai": "Ai Message",
        "tool": "Tool Message",
    }.get(role, role)
    print_header(banner)

    if content:
        print(content if full else truncate(content))

    if tool_calls:
        print("Tool Calls:")
        for call in tool_calls:
            rendered_args = ', '.join(f'{k}={repr(v)}' for k, v in call['args'].items())
            print(f" {call['name']}({rendered_args})")
def print_metrics_summary(metrics: dict, model: str = MODEL):
    """Print per-turn and aggregate token usage, cost, and timing.

    Args:
        metrics: Dict produced by run_test(): a 'turns' list of per-turn dicts
            plus total_api_input / total_cache_read / total_api_output /
            total_cost aggregate fields.
        model: Model name shown in the header (defaults to the module MODEL).
    """
    print("\n" + "=" * 80)
    print(" METRICS SUMMARY ".center(80, "="))
    print("=" * 80)
    print(f"\n📊 Token Usage & Cost ({model}):")
    for tm in metrics['turns']:
        tool_str = " [tool]" if tm.get('used_tool') else ""
        cache_str = f" ({tm['cache_read']:,} cached)" if tm['cache_read'] > 0 else ""
        # Fix: the separator between input and output counts was missing,
        # which printed e.g. "1,200 in340 out".
        print(f" Turn {tm['turn']}: {tm['api_input']:,} in{cache_str} → {tm['api_output']:,} out = ${tm['cost']:.6f}{tool_str}")
    print(f"\n ─────────────────────────────")
    print(f" Total input tokens: {metrics['total_api_input']:,}")
    # Fix: the cached-token discount varies per model (see PRICING, 50%–90%),
    # so the hard-coded "(50% discount)" label was misleading.
    print(f" Total cached tokens: {metrics['total_cache_read']:,} (discounted rate)")
    print(f" Total output tokens: {metrics['total_api_output']:,}")
    print(f" Total cost: ${metrics['total_cost']:.6f}")

    print("\n⏱️ Timing:")
    total_llm = 0
    total_tool = 0
    for tm in metrics['turns']:
        llm_time = tm.get('llm_time', 0)
        tool_time = tm.get('tool_time', 0)
        total_llm += llm_time
        total_tool += tool_time
        tool_str = f" + tool {tool_time:.2f}s" if tool_time > 0 else ""
        print(f" Turn {tm['turn']}: LLM {llm_time:.2f}s{tool_str}")
    print(f" ─────────────────────────────")
    print(f" Total LLM time: {total_llm:.2f}s")
    print(f" Total tool time: {total_tool:.2f}s")
    print(f" Total: {total_llm + total_tool:.2f}s")
# =============================================================================
# Main Test Runner
# =============================================================================
def run_test(conversation_index: int = 0, model: str = None, return_metrics: bool = False):
    """Run a multi-turn conversation test with metrics.

    Plays the scripted user turns from TEST_CONVERSATIONS[conversation_index]
    against the podcast-assistant prompt, printing the full message flow
    (system / human / ai / tool) and a token/cost/timing summary at the end.

    Args:
        conversation_index: Index into TEST_CONVERSATIONS selecting the script.
        model: Chat model name; falls back to the module-level MODEL when None.
        return_metrics: When True, return the collected metrics dict
            (used by compare_models); otherwise returns None.
    """
    model = model or MODEL
    queries = TEST_CONVERSATIONS[conversation_index]
    # Run-level aggregates; one per-turn dict is appended to 'turns' each loop.
    metrics = {
        'turns': [],
        'total_api_input': 0,
        'total_api_output': 0,
        'total_cache_read': 0,
        'total_cost': 0,
        'total_time': 0,
    }

    # Load system prompt template from alongside this script.
    with open(os.path.join(SCRIPT_DIR, "prompt.md"), "r") as f:
        prompt_template = f.read()
    today = datetime.now().strftime("%A, %B %d, %Y")
    # Substitute the literal "{today_date}" placeholder so the assistant
    # knows the current date.
    system_prompt = prompt_template.replace("{today_date}", today)

    # Initialize LLM and expose the search tool to it.
    llm = init_chat_model(model=model)
    tools = [search_podcasts]
    llm_with_tools = llm.bind_tools(tools)
    tools_dict = {t.name: t for t in tools}  # name -> tool, for dispatching tool calls

    # Build initial messages; this list accumulates the whole conversation
    # and is re-sent to the model on every invoke.
    messages = [SystemMessage(content=system_prompt)]

    print("\n" + "=" * 80)
    print(" PODCAST ASSISTANT TEST ".center(80, "="))
    print("=" * 80)
    print(f"Model: {model}")
    print(f"Conversation {conversation_index + 1}: {len(queries)} turns")
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 80)

    # Print System Message (truncated)
    print_message("system", system_prompt)

    # Process each query in the conversation (turns are 1-indexed for display).
    for turn, query in enumerate(queries, 1):
        turn_metrics = {
            'turn': turn,
            'api_input': 0,
            'api_output': 0,
            'cache_read': 0,
            'cost': 0,
            'llm_time': 0,
            'tool_time': 0,
            'used_tool': False
        }
        print(f"\n{'─' * 80}")
        print(f" TURN {turn}/{len(queries)} ".center(80, "─"))
        print("─" * 80)

        # Add user message
        messages.append(HumanMessage(content=query))
        print_message("human", query)

        # Get AI response (first model call of the turn; may request tools).
        t0 = time.perf_counter()
        response = llm_with_tools.invoke(messages)
        turn_metrics['llm_time'] = time.perf_counter() - t0

        # Track API usage
        usage = get_usage_from_response(response)
        turn_metrics['api_input'] += usage['input']
        turn_metrics['api_output'] += usage['output']
        turn_metrics['cache_read'] += usage['cache_read']
        messages.append(response)

        # Print AI Message with tool calls
        if response.tool_calls:
            turn_metrics['used_tool'] = True
            print_message("ai", "", tool_calls=response.tool_calls)

            # Process tool calls; invoking with the tool_call dict produces a
            # ToolMessage that must be appended before the next model call.
            # NOTE(review): tool calls whose name is not in tools_dict are
            # silently skipped, leaving no ToolMessage for them — confirm
            # this cannot occur with the bound tools.
            for tool_call in response.tool_calls:
                tool_name = tool_call["name"]
                if tool_name in tools_dict:
                    t0 = time.perf_counter()
                    tool_result = tools_dict[tool_name].invoke(tool_call)
                    turn_metrics['tool_time'] += time.perf_counter() - t0
                    messages.append(tool_result)
                    # Show FULL tool results (RAG context)
                    print_message("tool", tool_result.content, full=True)

            # Get final AI response (second model call, now with tool output).
            t0 = time.perf_counter()
            final_response = llm_with_tools.invoke(messages)
            turn_metrics['llm_time'] += time.perf_counter() - t0

            # Track API usage for second call
            usage2 = get_usage_from_response(final_response)
            turn_metrics['api_input'] += usage2['input']
            turn_metrics['api_output'] += usage2['output']
            turn_metrics['cache_read'] += usage2['cache_read']
            messages.append(final_response)

            # Full AI response
            print_message("ai", final_response.content, full=True)
        else:
            # No tool call, just AI response
            print_message("ai", response.content, full=True)

        # Calculate cost for this turn (with cache discount)
        turn_metrics['cost'] = calculate_cost(
            turn_metrics['api_input'],
            turn_metrics['api_output'],
            turn_metrics['cache_read'],
            model=model
        )

        # Update totals
        metrics['total_api_input'] += turn_metrics['api_input']
        metrics['total_api_output'] += turn_metrics['api_output']
        metrics['total_cache_read'] += turn_metrics['cache_read']
        metrics['total_cost'] += turn_metrics['cost']
        metrics['turns'].append(turn_metrics)

    # Calculate total wall time (LLM + tool) across all turns.
    metrics['total_time'] = sum(t['llm_time'] + t['tool_time'] for t in metrics['turns'])

    # Print summary
    print_metrics_summary(metrics, model=model)

    if return_metrics:
        return metrics
# Script entry point: run the first (multi-turn, tool-using) conversation.
if __name__ == "__main__":
    run_test(conversation_index=0)