|
|
""" |
|
|
Agent cache management for per-session agents. |
|
|
|
|
|
This module handles storing and retrieving agents for different users/sessions.
|
|
Each agent is cached by (session_id, provider, model, api_key_hash, mode) to avoid recreating them.
|
|
""" |
|
|
|
|
|
from datetime import datetime, timedelta |
|
|
from typing import Dict, Tuple, Any |
|
|
import hashlib |
|
|
|
|
|
|
|
|
# Live agent objects, keyed by
# (session_id, provider, model, api_key_hash, mode).
agent_cache: Dict[Tuple[str, str, str, str, str], Any] = dict()

# Last-access timestamp (naive datetime.now()) for each key in agent_cache;
# cleanup_old_agents() scans this to find idle entries to evict.
agent_last_used: Dict[Tuple[str, str, str, str, str], datetime] = dict()
|
|
|
|
|
async def get_or_create_agent(
    session_id: str,
    provider: str,
    api_key: str,
    model: str,
    mode: str,
    agent_factory_method
):
    """
    Return the cached agent for this session/configuration, building one on a miss.

    The cache key is (session_id, provider, model, api_key_hash, mode), where
    api_key_hash is the first 16 hex characters of the SHA-256 digest of
    ``api_key``. Only that truncated digest, not the raw key, becomes part of
    the key. Every hit or store refreshes the entry's timestamp in
    ``agent_last_used``.

    Args:
        session_id: Unique identifier for user session (from gr.Request)
        provider: "huggingface" or "openai"
        api_key: User's API key or OAuth token
        model: Model name/repo ID
        mode: Agent mode (e.g., "Single Agent (All Tools)",
            "Specialized Subagents (3 Specialists)")
        agent_factory_method: Zero-argument callable whose return value is
            awaited to build the agent. It is called only on a cache miss.

    Returns:
        The previously cached agent, or the freshly built one (which is cached
        before being returned).

    Example:
        agent = await get_or_create_agent(
            session_id="abc123",
            provider="openai",
            api_key="sk-...",
            model="gpt-4o-mini",
            mode="Single Agent (All Tools)",
            agent_factory_method=lambda: AgentFactory.create_streaming_agent_with_openai(...)
        )
    """
    # Key on a short digest of the credential so the raw key is never held
    # in the cache key itself.
    key_digest = hashlib.sha256(api_key.encode()).hexdigest()[:16]
    lookup = (session_id, provider, model, key_digest, mode)

    try:
        existing = agent_cache[lookup]
    except KeyError:
        pass
    else:
        print(f"[CACHE HIT] Reusing agent for session {session_id[:8]}...")
        agent_last_used[lookup] = datetime.now()
        return existing

    print(f"[CACHE MISS] Creating new {provider} agent for session {session_id[:8]}...")
    fresh_agent = await agent_factory_method()

    agent_cache[lookup] = fresh_agent
    agent_last_used[lookup] = datetime.now()
    print(f"[CACHE] Stored agent. Total agents in cache: {len(agent_cache)}")

    return fresh_agent
|
|
|
|
|
def cleanup_old_agents(max_age_hours: float = 1) -> int:
    """
    Remove agents that haven't been used in max_age_hours.

    Call this periodically to prevent memory leaks.

    An entry is stale when the time elapsed since its ``agent_last_used``
    timestamp is strictly greater than ``max_age_hours``. Each stale key is
    removed from both ``agent_cache`` and ``agent_last_used``.

    Args:
        max_age_hours: Remove agents idle for longer than this many hours.
            Fractional values (e.g. 0.5) are accepted.

    Returns:
        Number of stale entries removed
    """
    # The cutoff does not depend on the entry, so build it once.
    max_age = timedelta(hours=max_age_hours)
    now = datetime.now()

    # Collect stale keys first. Deleting from a dict while iterating over it
    # raises RuntimeError.
    to_remove = [
        cache_key
        for cache_key, last_used in agent_last_used.items()
        if now - last_used > max_age
    ]

    for cache_key in to_remove:
        print(f"[CLEANUP] Removing stale agent: {cache_key}")
        # pop() with a default tolerates the two dicts having drifted out of
        # sync (e.g. an entry already removed from agent_cache elsewhere).
        # A bare `del` would raise KeyError here and abort cleanup part-way,
        # leaving the remaining stale entries behind.
        agent_cache.pop(cache_key, None)
        agent_last_used.pop(cache_key, None)

    return len(to_remove)
|
|
|
|
|
def get_cache_stats():
    """
    Summarize the current contents of the agent cache.

    Returns:
        A dict with "total_agents" (the number of cached agents) and
        "cache_keys" (a new list holding the cache's key tuples).
    """
    keys_snapshot = [*agent_cache]
    return {"total_agents": len(keys_snapshot), "cache_keys": keys_snapshot}