ColettoG commited on
Commit
92f2b7d
·
1 Parent(s): 21d8407
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  python-dotenv==1.0.0
2
- google-generativeai>=0.8.0
3
- langchain-google-genai>=2.0.0
4
  langchain-core>=0.2.43
5
  fastapi==0.109.2
6
  uvicorn==0.27.1
@@ -13,6 +13,9 @@ langgraph-supervisor>=0.0.1
13
  langchain>=0.2.0
14
  langchain-community>=0.2.0
15
 
 
 
 
16
 
17
  # Testing dependencies
18
  pytest==8.0.0
@@ -25,9 +28,13 @@ numpy>=1.24.0
25
  pandas>=2.0.0
26
 
27
  # Monitoring and logging
28
- structlog>=23.0.0
29
- prometheus-client>=0.17.0
30
  langsmith>=0.1.0
 
 
 
 
31
 
32
  # Database (optional for production)
33
  sqlalchemy>=2.0.0
 
1
  python-dotenv==1.0.0
2
+ google-genai>=1.0.0
3
+ langchain-google-genai>=4.1.0
4
  langchain-core>=0.2.43
5
  fastapi==0.109.2
6
  uvicorn==0.27.1
 
13
  langchain>=0.2.0
14
  langchain-community>=0.2.0
15
 
16
+ # Multi-provider LLM support
17
+ langchain-openai>=0.2.0
18
+ langchain-anthropic>=0.2.0
19
 
20
  # Testing dependencies
21
  pytest==8.0.0
 
28
  pandas>=2.0.0
29
 
30
  # Monitoring and logging
31
+ structlog>=24.0.0
32
+ prometheus-client>=0.20.0
33
  langsmith>=0.1.0
34
+ colorlog>=6.8.0
35
+
36
+ # Rate limiting
37
+ slowapi>=0.1.9
38
 
39
  # Database (optional for production)
40
  sqlalchemy>=2.0.0
src/agents/base_agent.py DELETED
@@ -1,238 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- from typing import Dict, Any, List, Optional
3
- import logging
4
- from datetime import datetime
5
-
6
- from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
7
- from langchain_core.language_models import BaseLanguageModel
8
-
9
- from src.models.chatMessage import ChatMessage, AgentResponse, AgentType, MessageRole
10
- from src.agents.config import Config
11
-
12
-
13
- class BaseAgent(ABC):
14
- """Base class for all agents in the multi-agent system"""
15
-
16
- def __init__(self, name: str, agent_type: AgentType, llm: BaseLanguageModel, description: str = ""):
17
- self.name = name
18
- self.agent_type = agent_type
19
- self.llm = llm
20
- self.description = description
21
- self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
22
-
23
- # Agent state
24
- self.is_active = True
25
- self.created_at = datetime.utcnow()
26
- self.last_used = None
27
- self.usage_count = 0
28
-
29
- # Performance metrics
30
- self.response_times = []
31
- self.success_count = 0
32
- self.error_count = 0
33
-
34
- self.logger.info(f"Initialized agent {name} of type {agent_type}")
35
-
36
- @abstractmethod
37
- async def process_message(self, message: str, context: Dict[str, Any] = None) -> AgentResponse:
38
- """Process a message and return a response"""
39
- pass
40
-
41
- @abstractmethod
42
- def get_capabilities(self) -> List[str]:
43
- """Return list of agent capabilities"""
44
- pass
45
-
46
- def can_handle(self, message: str, context: Dict[str, Any] = None) -> bool:
47
- """Determine if this agent can handle the given message"""
48
- # Default implementation - can be overridden by subclasses
49
- return True
50
-
51
- def get_confidence_score(self, message: str, context: Dict[str, Any] = None) -> float:
52
- """Get confidence score for handling this message (0.0 to 1.0)"""
53
- # Default implementation - can be overridden by subclasses
54
- return 0.5
55
-
56
- def prepare_context_messages(self, context: Dict[str, Any] = None) -> List[SystemMessage]:
57
- """Prepare context messages for the LLM"""
58
- context_messages = []
59
-
60
- if context:
61
- # Add relevant context information
62
- if context.get("crypto_related"):
63
- context_messages.append(SystemMessage(
64
- content="This conversation involves cryptocurrency-related topics."
65
- ))
66
-
67
- if context.get("user_preferences"):
68
- context_messages.append(SystemMessage(
69
- content=f"User preferences: {context['user_preferences']}"
70
- ))
71
-
72
- if context.get("conversation_history"):
73
- context_messages.append(SystemMessage(
74
- content=f"Previous context: {context['conversation_history']}"
75
- ))
76
-
77
- return context_messages
78
-
79
- def create_agent_response(
80
- self,
81
- content: str,
82
- success: bool = True,
83
- error_message: Optional[str] = None,
84
- metadata: Dict[str, Any] = None,
85
- tools_used: List[str] = None,
86
- next_agent: Optional[str] = None,
87
- requires_followup: bool = False
88
- ) -> AgentResponse:
89
- """Create a standardized agent response"""
90
- return AgentResponse(
91
- content=content,
92
- agent_name=self.name,
93
- agent_type=self.agent_type,
94
- success=success,
95
- error_message=error_message,
96
- metadata=metadata or {},
97
- tools_used=tools_used or [],
98
- next_agent=next_agent,
99
- requires_followup=requires_followup,
100
- timestamp=datetime.utcnow()
101
- )
102
-
103
- def update_metrics(self, response_time: float, success: bool):
104
- """Update agent performance metrics"""
105
- self.response_times.append(response_time)
106
- self.last_used = datetime.utcnow()
107
- self.usage_count += 1
108
-
109
- if success:
110
- self.success_count += 1
111
- else:
112
- self.error_count += 1
113
-
114
- # Keep only last 100 response times
115
- if len(self.response_times) > 100:
116
- self.response_times = self.response_times[-100:]
117
-
118
- def get_performance_metrics(self) -> Dict[str, Any]:
119
- """Get agent performance metrics"""
120
- avg_response_time = sum(self.response_times) / len(self.response_times) if self.response_times else 0
121
- success_rate = self.success_count / self.usage_count if self.usage_count > 0 else 0
122
-
123
- return {
124
- "name": self.name,
125
- "agent_type": self.agent_type.value,
126
- "usage_count": self.usage_count,
127
- "success_count": self.success_count,
128
- "error_count": self.error_count,
129
- "success_rate": success_rate,
130
- "average_response_time": avg_response_time,
131
- "last_used": self.last_used.isoformat() if self.last_used else None,
132
- "is_active": self.is_active
133
- }
134
-
135
- def activate(self):
136
- """Activate the agent"""
137
- self.is_active = True
138
- self.logger.info(f"Agent {self.name} activated")
139
-
140
- def deactivate(self):
141
- """Deactivate the agent"""
142
- self.is_active = False
143
- self.logger.info(f"Agent {self.name} deactivated")
144
-
145
- def reset_metrics(self):
146
- """Reset performance metrics"""
147
- self.response_times = []
148
- self.success_count = 0
149
- self.error_count = 0
150
- self.usage_count = 0
151
- self.logger.info(f"Reset metrics for agent {self.name}")
152
-
153
- def get_agent_info(self) -> Dict[str, Any]:
154
- """Get comprehensive agent information"""
155
- return {
156
- "name": self.name,
157
- "type": self.agent_type.value,
158
- "description": self.description,
159
- "capabilities": self.get_capabilities(),
160
- "is_active": self.is_active,
161
- "created_at": self.created_at.isoformat(),
162
- "last_used": self.last_used.isoformat() if self.last_used else None,
163
- "performance_metrics": self.get_performance_metrics()
164
- }
165
-
166
-
167
- class AgentRegistry:
168
- """Registry for managing all agents in the system"""
169
-
170
- def __init__(self):
171
- self.agents: Dict[str, BaseAgent] = {}
172
- self.logger = logging.getLogger(__name__)
173
-
174
- def register_agent(self, agent: BaseAgent) -> None:
175
- """Register an agent"""
176
- if agent.name in self.agents:
177
- self.logger.warning(f"Agent {agent.name} already registered, overwriting")
178
-
179
- self.agents[agent.name] = agent
180
- self.logger.info(f"Registered agent {agent.name}")
181
-
182
- def unregister_agent(self, agent_name: str) -> bool:
183
- """Unregister an agent"""
184
- if agent_name in self.agents:
185
- del self.agents[agent_name]
186
- self.logger.info(f"Unregistered agent {agent_name}")
187
- return True
188
- return False
189
-
190
- def get_agent(self, agent_name: str) -> Optional[BaseAgent]:
191
- """Get agent by name"""
192
- return self.agents.get(agent_name)
193
-
194
- def get_active_agents(self) -> List[BaseAgent]:
195
- """Get all active agents"""
196
- return [agent for agent in self.agents.values() if agent.is_active]
197
-
198
- def get_agents_by_type(self, agent_type: AgentType) -> List[BaseAgent]:
199
- """Get agents by type"""
200
- return [agent for agent in self.agents.values() if agent.agent_type == agent_type]
201
-
202
- def find_best_agent(self, message: str, context: Dict[str, Any] = None) -> Optional[BaseAgent]:
203
- """Find the best agent to handle a message"""
204
- best_agent = None
205
- best_score = 0.0
206
-
207
- for agent in self.get_active_agents():
208
- if agent.can_handle(message, context):
209
- confidence = agent.get_confidence_score(message, context)
210
- if confidence > best_score:
211
- best_score = confidence
212
- best_agent = agent
213
-
214
- return best_agent
215
-
216
- def get_all_agents_info(self) -> List[Dict[str, Any]]:
217
- """Get information about all agents"""
218
- return [agent.get_agent_info() for agent in self.agents.values()]
219
-
220
- def get_agent_performance_summary(self) -> Dict[str, Any]:
221
- """Get performance summary for all agents"""
222
- total_agents = len(self.agents)
223
- active_agents = len(self.get_active_agents())
224
- total_usage = sum(agent.usage_count for agent in self.agents.values())
225
- total_success = sum(agent.success_count for agent in self.agents.values())
226
-
227
- return {
228
- "total_agents": total_agents,
229
- "active_agents": active_agents,
230
- "total_usage": total_usage,
231
- "total_success": total_success,
232
- "overall_success_rate": total_success / total_usage if total_usage > 0 else 0,
233
- "agents": self.get_all_agents_info()
234
- }
235
-
236
-
237
- # Global agent registry
238
- agent_registry = AgentRegistry()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agents/config.py CHANGED
@@ -1,25 +1,34 @@
1
  import os
 
 
2
  from dotenv import load_dotenv
3
- from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
4
- from typing import Optional
 
 
5
 
6
  load_dotenv()
7
 
8
- gemini_api_key = os.getenv("GEMINI_API_KEY")
9
- if not gemini_api_key:
10
- raise ValueError("GEMINI_API_KEY não encontrada nas variáveis de ambiente")
11
 
12
  class Config:
13
- # Model configuration
14
- GEMINI_MODEL = "gemini-2.5-flash"
15
- GEMINI_EMBEDDING_MODEL = "models/embedding-001"
16
- GEMINI_API_KEY = gemini_api_key
17
-
 
 
 
 
 
18
  # Application configuration
19
  MAX_UPLOAD_LENGTH = 16 * 1024 * 1024
20
- MAX_CONVERSATION_LENGTH = 100 # Maximum messages per conversation
21
- MAX_CONTEXT_MESSAGES = 10 # Maximum messages to include in context
22
-
23
  # Agent configuration
24
  AGENTS_CONFIG = {
25
  "agents": [
@@ -28,84 +37,143 @@ class Config:
28
  "description": "Handles cryptocurrency-related queries",
29
  "type": "specialized",
30
  "enabled": True,
31
- "priority": 1
32
  },
33
  {
34
  "name": "general",
35
  "description": "Handles general conversation and queries",
36
  "type": "general",
37
  "enabled": True,
38
- "priority": 2
39
- }
40
  ]
41
  }
42
-
43
  # LangGraph configuration
44
  LANGGRAPH_CONFIG = {
45
  "max_iterations": 10,
46
  "timeout": 30,
47
  "memory_window": 10,
48
- "enable_memory": True
49
  }
50
-
51
  # Conversation configuration
52
  CONVERSATION_CONFIG = {
53
  "default_user_id": "anonymous",
54
  "max_conversations_per_user": 50,
55
  "conversation_timeout_hours": 24,
56
- "enable_context_extraction": True
57
  }
58
-
59
- # LLM instances (singleton pattern)
60
- _llm_instance: Optional[ChatGoogleGenerativeAI] = None
61
- _embeddings_instance: Optional[GoogleGenerativeAIEmbeddings] = None
62
-
 
63
  @classmethod
64
- def get_llm(cls) -> ChatGoogleGenerativeAI:
65
- """Get or create LLM instance (singleton)"""
66
- if cls._llm_instance is None:
67
- cls._llm_instance = ChatGoogleGenerativeAI(
68
- model=cls.GEMINI_MODEL,
69
- temperature=0.7,
70
- google_api_key=cls.GEMINI_API_KEY
71
- )
72
- return cls._llm_instance
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  @classmethod
75
  def get_embeddings(cls) -> GoogleGenerativeAIEmbeddings:
76
- """Get or create embeddings instance (singleton)"""
77
  if cls._embeddings_instance is None:
78
  cls._embeddings_instance = GoogleGenerativeAIEmbeddings(
79
- model=cls.GEMINI_EMBEDDING_MODEL,
80
- google_api_key=cls.GEMINI_API_KEY
81
  )
82
  return cls._embeddings_instance
83
-
 
 
 
 
 
 
 
84
  @classmethod
85
- def get_agent_config(cls, agent_name: str) -> Optional[dict]:
86
- """Get configuration for a specific agent"""
87
  for agent in cls.AGENTS_CONFIG["agents"]:
88
  if agent["name"] == agent_name:
89
  return agent
90
  return None
91
-
92
  @classmethod
93
- def get_enabled_agents(cls) -> list:
94
- """Get list of enabled agents"""
95
  return [
96
- agent for agent in cls.AGENTS_CONFIG["agents"]
 
97
  if agent.get("enabled", True)
98
  ]
99
-
 
 
 
 
 
 
 
 
 
 
100
  @classmethod
101
  def validate_config(cls) -> bool:
102
- """Validate configuration"""
103
  try:
104
- # Test LLM connection
105
- llm = cls.get_llm()
106
- # Test embeddings connection
107
  embeddings = cls.get_embeddings()
108
  return True
109
  except Exception as e:
110
  print(f"Configuration validation failed: {e}")
111
- return False
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Literal
3
+
4
  from dotenv import load_dotenv
5
+ from langchain_core.language_models import BaseChatModel
6
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
7
+
8
+ from src.llm import LLMFactory, CostTrackingCallback
9
 
10
  load_dotenv()
11
 
12
+ # Type alias for providers
13
+ Provider = Literal["google", "openai", "anthropic"]
14
+
15
 
16
  class Config:
17
+ """Application configuration with multi-provider LLM support."""
18
+
19
+ # Default model configuration
20
+ DEFAULT_MODEL = os.getenv("DEFAULT_LLM_MODEL", "gemini-3-pro-preview")
21
+ DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_LLM_TEMPERATURE", "0.7"))
22
+ DEFAULT_PROVIDER: Provider = "google"
23
+
24
+ # Embedding configuration
25
+ EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "models/embedding-001")
26
+
27
  # Application configuration
28
  MAX_UPLOAD_LENGTH = 16 * 1024 * 1024
29
+ MAX_CONVERSATION_LENGTH = 100
30
+ MAX_CONTEXT_MESSAGES = 10
31
+
32
  # Agent configuration
33
  AGENTS_CONFIG = {
34
  "agents": [
 
37
  "description": "Handles cryptocurrency-related queries",
38
  "type": "specialized",
39
  "enabled": True,
40
+ "priority": 1,
41
  },
42
  {
43
  "name": "general",
44
  "description": "Handles general conversation and queries",
45
  "type": "general",
46
  "enabled": True,
47
+ "priority": 2,
48
+ },
49
  ]
50
  }
51
+
52
  # LangGraph configuration
53
  LANGGRAPH_CONFIG = {
54
  "max_iterations": 10,
55
  "timeout": 30,
56
  "memory_window": 10,
57
+ "enable_memory": True,
58
  }
59
+
60
  # Conversation configuration
61
  CONVERSATION_CONFIG = {
62
  "default_user_id": "anonymous",
63
  "max_conversations_per_user": 50,
64
  "conversation_timeout_hours": 24,
65
+ "enable_context_extraction": True,
66
  }
67
+
68
+ # Instance caches
69
+ _llm_instance: BaseChatModel | None = None
70
+ _embeddings_instance: GoogleGenerativeAIEmbeddings | None = None
71
+ _cost_tracker: CostTrackingCallback | None = None
72
+
73
  @classmethod
74
+ def get_llm(
75
+ cls,
76
+ model: str | None = None,
77
+ temperature: float | None = None,
78
+ with_cost_tracking: bool = True,
79
+ ) -> BaseChatModel:
80
+ """
81
+ Get or create LLM instance using the factory.
82
+
83
+ Args:
84
+ model: Model name (defaults to DEFAULT_MODEL)
85
+ temperature: Sampling temperature (defaults to DEFAULT_TEMPERATURE)
86
+ with_cost_tracking: Whether to attach cost tracking callback
87
+
88
+ Returns:
89
+ BaseChatModel instance
90
+ """
91
+ model = model or cls.DEFAULT_MODEL
92
+ temperature = temperature if temperature is not None else cls.DEFAULT_TEMPERATURE
93
+
94
+ # Use cache for default config
95
+ use_cache = model == cls.DEFAULT_MODEL and temperature == cls.DEFAULT_TEMPERATURE
96
+
97
+ if use_cache and cls._llm_instance is not None:
98
+ return cls._llm_instance
99
+
100
+ # Build callbacks
101
+ callbacks = []
102
+ if with_cost_tracking:
103
+ callbacks.append(cls.get_cost_tracker())
104
+
105
+ llm = LLMFactory.create(
106
+ model=model,
107
+ temperature=temperature,
108
+ callbacks=callbacks if callbacks else None,
109
+ use_cache=False, # We handle caching ourselves
110
+ )
111
+
112
+ if use_cache:
113
+ cls._llm_instance = llm
114
+
115
+ return llm
116
+
117
  @classmethod
118
  def get_embeddings(cls) -> GoogleGenerativeAIEmbeddings:
119
+ """Get or create embeddings instance (singleton)."""
120
  if cls._embeddings_instance is None:
121
  cls._embeddings_instance = GoogleGenerativeAIEmbeddings(
122
+ model=cls.EMBEDDING_MODEL,
123
+ google_api_key=os.getenv("GEMINI_API_KEY"),
124
  )
125
  return cls._embeddings_instance
126
+
127
+ @classmethod
128
+ def get_cost_tracker(cls) -> CostTrackingCallback:
129
+ """Get or create cost tracker instance (singleton)."""
130
+ if cls._cost_tracker is None:
131
+ cls._cost_tracker = CostTrackingCallback(log_calls=True)
132
+ return cls._cost_tracker
133
+
134
  @classmethod
135
+ def get_agent_config(cls, agent_name: str) -> dict | None:
136
+ """Get configuration for a specific agent."""
137
  for agent in cls.AGENTS_CONFIG["agents"]:
138
  if agent["name"] == agent_name:
139
  return agent
140
  return None
141
+
142
  @classmethod
143
+ def get_enabled_agents(cls) -> list[dict]:
144
+ """Get list of enabled agents."""
145
  return [
146
+ agent
147
+ for agent in cls.AGENTS_CONFIG["agents"]
148
  if agent.get("enabled", True)
149
  ]
150
+
151
+ @classmethod
152
+ def list_available_models(cls) -> list[str]:
153
+ """List all available LLM models."""
154
+ return LLMFactory.list_models()
155
+
156
+ @classmethod
157
+ def list_available_providers(cls) -> list[str]:
158
+ """List all available LLM providers."""
159
+ return LLMFactory.list_providers()
160
+
161
  @classmethod
162
  def validate_config(cls) -> bool:
163
+ """Validate configuration by testing connections."""
164
  try:
165
+ llm = cls.get_llm(with_cost_tracking=False)
 
 
166
  embeddings = cls.get_embeddings()
167
  return True
168
  except Exception as e:
169
  print(f"Configuration validation failed: {e}")
170
+ return False
171
+
172
+ @classmethod
173
+ def reset_instances(cls) -> None:
174
+ """Reset all cached instances."""
175
+ cls._llm_instance = None
176
+ cls._embeddings_instance = None
177
+ if cls._cost_tracker:
178
+ cls._cost_tracker.reset()
179
+ LLMFactory.clear_cache()
src/agents/conversation_manager.py DELETED
@@ -1,275 +0,0 @@
1
- import logging
2
- import uuid
3
- from typing import Dict, List, Optional, Any
4
- from datetime import datetime, timedelta
5
- from dataclasses import dataclass, asdict
6
- import json
7
-
8
- from src.models.chatMessage import ConversationState, ChatMessage, MessageRole, AgentType
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- @dataclass
14
- class ConversationMetadata:
15
- """Metadata for conversation tracking"""
16
- conversation_id: str
17
- user_id: str
18
- created_at: datetime
19
- updated_at: datetime
20
- message_count: int
21
- current_agent: Optional[str]
22
- is_active: bool
23
- context_summary: Dict[str, Any]
24
-
25
-
26
- class ConversationManager:
27
- """Manages conversation state and persistence for multi-agent system"""
28
-
29
- def __init__(self):
30
- self.conversations: Dict[str, ConversationState] = {}
31
- self.metadata: Dict[str, ConversationMetadata] = {}
32
- self.user_conversations: Dict[str, List[str]] = {}
33
-
34
- def create_conversation(self, user_id: str, conversation_id: Optional[str] = None) -> str:
35
- """Create a new conversation"""
36
- if not conversation_id:
37
- conversation_id = str(uuid.uuid4())
38
-
39
- key = f"{user_id}:{conversation_id}"
40
-
41
- # Create conversation state
42
- conversation_state = ConversationState(
43
- conversation_id=conversation_id,
44
- user_id=user_id,
45
- messages=[],
46
- context={},
47
- memory={},
48
- agent_history=[],
49
- current_agent=None,
50
- last_message_id=None,
51
- created_at=datetime.utcnow(),
52
- updated_at=datetime.utcnow(),
53
- is_active=True
54
- )
55
-
56
- # Create metadata
57
- metadata = ConversationMetadata(
58
- conversation_id=conversation_id,
59
- user_id=user_id,
60
- created_at=datetime.utcnow(),
61
- updated_at=datetime.utcnow(),
62
- message_count=0,
63
- current_agent=None,
64
- is_active=True,
65
- context_summary={}
66
- )
67
-
68
- # Store conversation
69
- self.conversations[key] = conversation_state
70
- self.metadata[key] = metadata
71
-
72
- # Update user conversations
73
- if user_id not in self.user_conversations:
74
- self.user_conversations[user_id] = []
75
- self.user_conversations[user_id].append(conversation_id)
76
-
77
- logger.info(f"Created conversation {conversation_id} for user {user_id}")
78
- return conversation_id
79
-
80
- def get_conversation(self, conversation_id: str, user_id: str) -> Optional[ConversationState]:
81
- """Get conversation by ID and user"""
82
- key = f"{user_id}:{conversation_id}"
83
- return self.conversations.get(key)
84
-
85
- def get_or_create_conversation(self, conversation_id: str, user_id: str) -> ConversationState:
86
- """Get existing conversation or create new one"""
87
- conversation = self.get_conversation(conversation_id, user_id)
88
- if not conversation:
89
- self.create_conversation(user_id, conversation_id)
90
- conversation = self.get_conversation(conversation_id, user_id)
91
- return conversation
92
-
93
- def add_message(self, conversation_id: str, user_id: str, message: ChatMessage) -> None:
94
- """Add message to conversation"""
95
- conversation = self.get_or_create_conversation(conversation_id, user_id)
96
- key = f"{user_id}:{conversation_id}"
97
-
98
- # Add message
99
- conversation.messages.append(message)
100
- conversation.last_message_id = message.message_id
101
- conversation.updated_at = datetime.utcnow()
102
-
103
- # Update metadata
104
- if key in self.metadata:
105
- self.metadata[key].message_count = len(conversation.messages)
106
- self.metadata[key].updated_at = datetime.utcnow()
107
- self.metadata[key].current_agent = conversation.current_agent
108
-
109
- logger.info(f"Added message to conversation {conversation_id}")
110
-
111
- def update_conversation_context(self, conversation_id: str, user_id: str, context_updates: Dict[str, Any]) -> None:
112
- """Update conversation context"""
113
- conversation = self.get_conversation(conversation_id, user_id)
114
- if conversation:
115
- conversation.context.update(context_updates)
116
- conversation.updated_at = datetime.utcnow()
117
-
118
- # Update metadata
119
- key = f"{user_id}:{conversation_id}"
120
- if key in self.metadata:
121
- self.metadata[key].context_summary.update(context_updates)
122
- self.metadata[key].updated_at = datetime.utcnow()
123
-
124
- def update_agent_history(self, conversation_id: str, user_id: str, agent_info: Dict[str, Any]) -> None:
125
- """Update agent interaction history"""
126
- conversation = self.get_conversation(conversation_id, user_id)
127
- if conversation:
128
- conversation.agent_history.append(agent_info)
129
- conversation.updated_at = datetime.utcnow()
130
-
131
- def get_conversation_messages(self, conversation_id: str, user_id: str, limit: Optional[int] = None) -> List[ChatMessage]:
132
- """Get messages from conversation"""
133
- conversation = self.get_conversation(conversation_id, user_id)
134
- if not conversation:
135
- return []
136
-
137
- messages = conversation.messages
138
- if limit:
139
- messages = messages[-limit:]
140
-
141
- return messages
142
-
143
- def get_user_conversations(self, user_id: str) -> List[Dict[str, Any]]:
144
- """Get all conversations for a user"""
145
- user_conversations = []
146
-
147
- for conversation_id in self.user_conversations.get(user_id, []):
148
- key = f"{user_id}:{conversation_id}"
149
- metadata = self.metadata.get(key)
150
-
151
- if metadata:
152
- user_conversations.append(asdict(metadata))
153
-
154
- return user_conversations
155
-
156
- def delete_conversation(self, conversation_id: str, user_id: str) -> bool:
157
- """Delete a conversation"""
158
- key = f"{user_id}:{conversation_id}"
159
-
160
- if key in self.conversations:
161
- del self.conversations[key]
162
-
163
- if key in self.metadata:
164
- del self.metadata[key]
165
-
166
- # Remove from user conversations
167
- if user_id in self.user_conversations:
168
- if conversation_id in self.user_conversations[user_id]:
169
- self.user_conversations[user_id].remove(conversation_id)
170
-
171
- logger.info(f"Deleted conversation {conversation_id} for user {user_id}")
172
- return True
173
-
174
- def reset_conversation(self, conversation_id: str, user_id: str) -> None:
175
- """Reset conversation (clear messages but keep conversation)"""
176
- conversation = self.get_conversation(conversation_id, user_id)
177
- if conversation:
178
- conversation.messages = []
179
- conversation.context = {}
180
- conversation.agent_history = []
181
- conversation.current_agent = None
182
- conversation.last_message_id = None
183
- conversation.updated_at = datetime.utcnow()
184
-
185
- # Update metadata
186
- key = f"{user_id}:{conversation_id}"
187
- if key in self.metadata:
188
- self.metadata[key].message_count = 0
189
- self.metadata[key].current_agent = None
190
- self.metadata[key].context_summary = {}
191
- self.metadata[key].updated_at = datetime.utcnow()
192
-
193
- def cleanup_old_conversations(self, max_age_hours: int = 24) -> int:
194
- """Clean up old conversations"""
195
- cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)
196
- deleted_count = 0
197
-
198
- conversations_to_delete = []
199
-
200
- for key, metadata in self.metadata.items():
201
- if metadata.updated_at < cutoff_time and not metadata.is_active:
202
- conversations_to_delete.append(key)
203
-
204
- for key in conversations_to_delete:
205
- user_id, conversation_id = key.split(":", 1)
206
- if self.delete_conversation(conversation_id, user_id):
207
- deleted_count += 1
208
-
209
- logger.info(f"Cleaned up {deleted_count} old conversations")
210
- return deleted_count
211
-
212
- def get_conversation_stats(self, user_id: str) -> Dict[str, Any]:
213
- """Get conversation statistics for a user"""
214
- user_conversations = self.get_user_conversations(user_id)
215
-
216
- total_conversations = len(user_conversations)
217
- active_conversations = sum(1 for conv in user_conversations if conv["is_active"])
218
- total_messages = sum(conv["message_count"] for conv in user_conversations)
219
-
220
- # Agent usage statistics
221
- agent_usage = {}
222
- for conv in user_conversations:
223
- conversation = self.get_conversation(conv["conversation_id"], user_id)
224
- if conversation:
225
- for agent_info in conversation.agent_history:
226
- agent_name = agent_info.get("agent", "unknown")
227
- agent_usage[agent_name] = agent_usage.get(agent_name, 0) + 1
228
-
229
- return {
230
- "total_conversations": total_conversations,
231
- "active_conversations": active_conversations,
232
- "total_messages": total_messages,
233
- "agent_usage": agent_usage,
234
- "average_messages_per_conversation": total_messages / total_conversations if total_conversations > 0 else 0
235
- }
236
-
237
- def export_conversation(self, conversation_id: str, user_id: str) -> Dict[str, Any]:
238
- """Export conversation data"""
239
- conversation = self.get_conversation(conversation_id, user_id)
240
- if not conversation:
241
- return {}
242
-
243
- return {
244
- "conversation_id": conversation_id,
245
- "user_id": user_id,
246
- "messages": [msg.dict() for msg in conversation.messages],
247
- "context": conversation.context,
248
- "agent_history": conversation.agent_history,
249
- "metadata": asdict(self.metadata.get(f"{user_id}:{conversation_id}", {}))
250
- }
251
-
252
- def import_conversation(self, conversation_data: Dict[str, Any]) -> str:
253
- """Import conversation data"""
254
- conversation_id = conversation_data.get("conversation_id", str(uuid.uuid4()))
255
- user_id = conversation_data.get("user_id", "anonymous")
256
-
257
- # Create conversation
258
- self.create_conversation(user_id, conversation_id)
259
-
260
- # Import messages
261
- for msg_data in conversation_data.get("messages", []):
262
- message = ChatMessage(**msg_data)
263
- self.add_message(conversation_id, user_id, message)
264
-
265
- # Import context and history
266
- conversation = self.get_conversation(conversation_id, user_id)
267
- if conversation:
268
- conversation.context.update(conversation_data.get("context", {}))
269
- conversation.agent_history.extend(conversation_data.get("agent_history", []))
270
-
271
- return conversation_id
272
-
273
-
274
- # Global conversation manager instance
275
- conversation_manager = ConversationManager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agents/supervisor/agent.py CHANGED
@@ -1,4 +1,3 @@
1
- from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
2
  from langgraph_supervisor import create_supervisor
3
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
4
  from src.agents.config import Config
@@ -29,16 +28,8 @@ from src.agents.staking.prompt import STAKING_AGENT_SYSTEM_PROMPT
29
  from src.agents.search.agent import SearchAgent
30
  from src.agents.database.client import is_database_available
31
 
32
- llm = ChatGoogleGenerativeAI(
33
- model=Config.GEMINI_MODEL,
34
- temperature=0.7,
35
- google_api_key=Config.GEMINI_API_KEY
36
- )
37
-
38
- embeddings = GoogleGenerativeAIEmbeddings(
39
- model=Config.GEMINI_EMBEDDING_MODEL,
40
- google_api_key=Config.GEMINI_API_KEY
41
- )
42
 
43
 
44
  class ChatMessage(TypedDict):
@@ -50,7 +41,7 @@ class Supervisor:
50
  def __init__(self, llm):
51
  self.llm = llm
52
 
53
- cryptoDataAgentClass = CryptoDataAgent(llm)
54
  cryptoDataAgent = cryptoDataAgentClass.agent
55
 
56
  agents = [cryptoDataAgent]
@@ -60,7 +51,7 @@ class Supervisor:
60
 
61
  # Conditionally include database agent
62
  if is_database_available():
63
- databaseAgent = DatabaseAgent(llm)
64
  agents.append(databaseAgent)
65
  available_agents_text += (
66
  "- database_agent: Handles database queries and data analysis. Can search and analyze data from the database.\n"
@@ -68,42 +59,42 @@ class Supervisor:
68
  else:
69
  databaseAgent = None
70
 
71
- swapAgent = SwapAgent(llm)
72
  self.swap_agent = swapAgent.agent
73
  agents.append(self.swap_agent)
74
  available_agents_text += (
75
  "- swap_agent: Handles swap operations on the Avalanche network and any other swap question related.\n"
76
  )
77
 
78
- dcaAgent = DcaAgent(llm)
79
  self.dca_agent = dcaAgent.agent
80
  agents.append(self.dca_agent)
81
  available_agents_text += (
82
  "- dca_agent: Plans DCA swap workflows, consulting strategy docs, validating parameters, and confirming automation metadata.\n"
83
  )
84
 
85
- lendingAgent = LendingAgent(llm)
86
  self.lending_agent = lendingAgent.agent
87
  agents.append(self.lending_agent)
88
  available_agents_text += (
89
  "- lending_agent: Handles lending operations (supply, borrow, repay, withdraw) on DeFi protocols like Aave.\n"
90
  )
91
 
92
- stakingAgent = StakingAgent(llm)
93
  self.staking_agent = stakingAgent.agent
94
  agents.append(self.staking_agent)
95
  available_agents_text += (
96
  "- staking_agent: Handles staking operations (stake ETH, unstake stETH) via Lido on Ethereum.\n"
97
  )
98
 
99
- searchAgent = SearchAgent(llm)
100
  self.search_agent = searchAgent.agent
101
  agents.append(self.search_agent)
102
  available_agents_text += (
103
  "- search_agent: Uses web search tools for current events and factual lookups.\n"
104
  )
105
 
106
- defaultAgent = DefaultAgent(llm)
107
  self.default_agent = defaultAgent.agent
108
  agents.append(self.default_agent)
109
 
@@ -252,7 +243,7 @@ Examples of general queries to handle directly:
252
 
253
  self.supervisor = create_supervisor(
254
  agents,
255
- model=llm,
256
  prompt=system_prompt,
257
  output_mode="last_message"
258
  )
 
 
1
  from langgraph_supervisor import create_supervisor
2
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
3
  from src.agents.config import Config
 
28
  from src.agents.search.agent import SearchAgent
29
  from src.agents.database.client import is_database_available
30
 
31
+ # Embeddings singleton
32
+ embeddings = Config.get_embeddings()
 
 
 
 
 
 
 
 
33
 
34
 
35
  class ChatMessage(TypedDict):
 
41
  def __init__(self, llm):
42
  self.llm = llm
43
 
44
+ cryptoDataAgentClass = CryptoDataAgent(self.llm)
45
  cryptoDataAgent = cryptoDataAgentClass.agent
46
 
47
  agents = [cryptoDataAgent]
 
51
 
52
  # Conditionally include database agent
53
  if is_database_available():
54
+ databaseAgent = DatabaseAgent(self.llm)
55
  agents.append(databaseAgent)
56
  available_agents_text += (
57
  "- database_agent: Handles database queries and data analysis. Can search and analyze data from the database.\n"
 
59
  else:
60
  databaseAgent = None
61
 
62
+ swapAgent = SwapAgent(self.llm)
63
  self.swap_agent = swapAgent.agent
64
  agents.append(self.swap_agent)
65
  available_agents_text += (
66
  "- swap_agent: Handles swap operations on the Avalanche network and any other swap question related.\n"
67
  )
68
 
69
+ dcaAgent = DcaAgent(self.llm)
70
  self.dca_agent = dcaAgent.agent
71
  agents.append(self.dca_agent)
72
  available_agents_text += (
73
  "- dca_agent: Plans DCA swap workflows, consulting strategy docs, validating parameters, and confirming automation metadata.\n"
74
  )
75
 
76
+ lendingAgent = LendingAgent(self.llm)
77
  self.lending_agent = lendingAgent.agent
78
  agents.append(self.lending_agent)
79
  available_agents_text += (
80
  "- lending_agent: Handles lending operations (supply, borrow, repay, withdraw) on DeFi protocols like Aave.\n"
81
  )
82
 
83
+ stakingAgent = StakingAgent(self.llm)
84
  self.staking_agent = stakingAgent.agent
85
  agents.append(self.staking_agent)
86
  available_agents_text += (
87
  "- staking_agent: Handles staking operations (stake ETH, unstake stETH) via Lido on Ethereum.\n"
88
  )
89
 
90
+ searchAgent = SearchAgent(self.llm)
91
  self.search_agent = searchAgent.agent
92
  agents.append(self.search_agent)
93
  available_agents_text += (
94
  "- search_agent: Uses web search tools for current events and factual lookups.\n"
95
  )
96
 
97
+ defaultAgent = DefaultAgent(self.llm)
98
  self.default_agent = defaultAgent.agent
99
  agents.append(self.default_agent)
100
 
 
243
 
244
  self.supervisor = create_supervisor(
245
  agents,
246
+ model=self.llm,
247
  prompt=system_prompt,
248
  output_mode="last_message"
249
  )
src/app.py CHANGED
@@ -1,16 +1,15 @@
1
- import logging
2
- logging.basicConfig(
3
- level=logging.DEBUG,
4
- format="%(asctime)s %(levelname)s %(name)s: %(message)s",
5
- handlers=[logging.StreamHandler()]
6
- )
7
- logging.info("Test log from app.py startup")
8
- from fastapi import FastAPI, HTTPException, Request
9
  from fastapi.middleware.cors import CORSMiddleware
 
10
  from pydantic import BaseModel
11
- from typing import List
12
- import re
13
 
 
 
14
  from src.agents.config import Config
15
  from src.agents.supervisor.agent import Supervisor
16
  from src.models.chatMessage import ChatMessage
@@ -19,8 +18,23 @@ from src.service.chat_manager import chat_manager_instance
19
  from src.agents.crypto_data.tools import get_coingecko_id, get_tradingview_symbol
20
  from src.agents.metadata import metadata
21
 
 
 
 
 
 
 
 
 
22
  # Initialize FastAPI app
23
- app = FastAPI(title="Zico Agent API", version="1.0")
 
 
 
 
 
 
 
24
 
25
  # Enable CORS for local/frontend dev
26
  app.add_middleware(
@@ -31,9 +45,8 @@ app.add_middleware(
31
  allow_headers=["*"],
32
  )
33
 
34
- # Instantiate Supervisor agent (singleton LLM)
35
- supervisor = Supervisor(Config.get_llm())
36
- logger = logging.getLogger(__name__)
37
 
38
  class ChatRequest(BaseModel):
39
  message: ChatMessage
@@ -156,6 +169,54 @@ def _resolve_identity(request: ChatRequest) -> tuple[str, str]:
156
  def health_check():
157
  return {"status": "ok"}
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  @app.get("/chat/messages")
160
  def get_messages(request: Request):
161
  params = request.query_params
@@ -214,13 +275,27 @@ def chat(request: ChatRequest):
214
  conversation_id=conversation_id,
215
  user_id=user_id
216
  )
217
-
 
 
 
 
218
  # Invoke the supervisor agent with the conversation
219
  result = supervisor.invoke(
220
  conversation_messages,
221
  conversation_id=conversation_id,
222
  user_id=user_id,
223
  )
 
 
 
 
 
 
 
 
 
 
224
  logger.debug(
225
  "Supervisor returned result for user=%s conversation=%s: %s",
226
  user_id,
@@ -353,5 +428,303 @@ def chat(request: ChatRequest):
353
  )
354
  raise HTTPException(status_code=500, detail=str(e))
355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  # Include chat manager router
357
  app.include_router(chat_manager_router)
 
1
+ import base64
2
+ import json
3
+ import os
4
+ from typing import List, Optional
5
+
6
+ from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
 
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from langchain_core.messages import HumanMessage
9
  from pydantic import BaseModel
 
 
10
 
11
+ from src.infrastructure.logging import setup_logging, get_logger
12
+ from src.infrastructure.rate_limiter import setup_rate_limiter, limiter
13
  from src.agents.config import Config
14
  from src.agents.supervisor.agent import Supervisor
15
  from src.models.chatMessage import ChatMessage
 
18
  from src.agents.crypto_data.tools import get_coingecko_id, get_tradingview_symbol
19
  from src.agents.metadata import metadata
20
 
21
+ # Setup structured logging
22
+ log_level = os.getenv("LOG_LEVEL", "INFO")
23
+ log_format = os.getenv("LOG_FORMAT", "color")
24
+ setup_logging(level=log_level, format_type=log_format)
25
+ logger = get_logger(__name__)
26
+
27
+ logger.info("Starting Zico Agent API")
28
+
29
  # Initialize FastAPI app
30
+ app = FastAPI(
31
+ title="Zico Agent API",
32
+ version="2.0",
33
+ description="Multi-agent AI assistant with streaming support",
34
+ )
35
+
36
+ # Setup rate limiting
37
+ setup_rate_limiter(app)
38
 
39
  # Enable CORS for local/frontend dev
40
  app.add_middleware(
 
45
  allow_headers=["*"],
46
  )
47
 
48
+ # Instantiate Supervisor agent (singleton LLM with cost tracking)
49
+ supervisor = Supervisor(Config.get_llm(with_cost_tracking=True))
 
50
 
51
  class ChatRequest(BaseModel):
52
  message: ChatMessage
 
169
  def health_check():
170
  return {"status": "ok"}
171
 
172
+
173
+ @app.get("/costs")
174
+ def get_costs():
175
+ """Get current LLM cost summary."""
176
+ cost_tracker = Config.get_cost_tracker()
177
+ return cost_tracker.get_summary()
178
+
179
+
180
+ @app.get("/costs/detailed")
181
+ def get_detailed_costs():
182
+ """Get detailed LLM cost report."""
183
+ cost_tracker = Config.get_cost_tracker()
184
+ return cost_tracker.get_detailed_report()
185
+
186
+
187
+ @app.get("/costs/conversation")
188
+ def get_conversation_costs(request: Request):
189
+ """Get accumulated LLM costs for a specific conversation."""
190
+ params = request.query_params
191
+ conversation_id = params.get("conversation_id")
192
+ user_id = params.get("user_id")
193
+
194
+ if not conversation_id or not user_id:
195
+ raise HTTPException(
196
+ status_code=400,
197
+ detail="Both 'conversation_id' and 'user_id' query parameters are required.",
198
+ )
199
+
200
+ costs = chat_manager_instance.get_conversation_costs(
201
+ conversation_id=conversation_id,
202
+ user_id=user_id,
203
+ )
204
+ return {
205
+ "conversation_id": conversation_id,
206
+ "user_id": user_id,
207
+ "costs": costs,
208
+ }
209
+
210
+
211
+ @app.get("/models")
212
+ def get_available_models():
213
+ """List available LLM models."""
214
+ return {
215
+ "models": Config.list_available_models(),
216
+ "providers": Config.list_available_providers(),
217
+ "default": Config.DEFAULT_MODEL,
218
+ }
219
+
220
  @app.get("/chat/messages")
221
  def get_messages(request: Request):
222
  params = request.query_params
 
275
  conversation_id=conversation_id,
276
  user_id=user_id
277
  )
278
+
279
+ # Take cost snapshot before invoking
280
+ cost_tracker = Config.get_cost_tracker()
281
+ cost_snapshot = cost_tracker.get_snapshot()
282
+
283
  # Invoke the supervisor agent with the conversation
284
  result = supervisor.invoke(
285
  conversation_messages,
286
  conversation_id=conversation_id,
287
  user_id=user_id,
288
  )
289
+
290
+ # Calculate and save cost delta for this request
291
+ cost_delta = cost_tracker.calculate_delta(cost_snapshot)
292
+ if cost_delta.get("cost", 0) > 0 or cost_delta.get("calls", 0) > 0:
293
+ chat_manager_instance.update_conversation_costs(
294
+ cost_delta,
295
+ conversation_id=conversation_id,
296
+ user_id=user_id,
297
+ )
298
+
299
  logger.debug(
300
  "Supervisor returned result for user=%s conversation=%s: %s",
301
  user_id,
 
428
  )
429
  raise HTTPException(status_code=500, detail=str(e))
430
 
431
+
432
+ # Supported audio MIME types
433
+ AUDIO_MIME_TYPES = {
434
+ ".mp3": "audio/mpeg",
435
+ ".wav": "audio/wav",
436
+ ".flac": "audio/flac",
437
+ ".ogg": "audio/ogg",
438
+ ".webm": "audio/webm",
439
+ ".m4a": "audio/mp4",
440
+ ".aac": "audio/aac",
441
+ }
442
+
443
+ # Max audio file size (20MB)
444
+ MAX_AUDIO_SIZE = 20 * 1024 * 1024
445
+
446
+
447
+ def _get_audio_mime_type(filename: str, content_type: str | None) -> str:
448
+ """Determine the MIME type for an audio file."""
449
+ # Try from filename extension first
450
+ if filename:
451
+ ext = os.path.splitext(filename.lower())[1]
452
+ if ext in AUDIO_MIME_TYPES:
453
+ return AUDIO_MIME_TYPES[ext]
454
+
455
+ # Fall back to content type from upload
456
+ if content_type and content_type.startswith("audio/"):
457
+ return content_type
458
+
459
+ # Default to mpeg
460
+ return "audio/mpeg"
461
+
462
+
463
+ @app.post("/chat/audio")
464
+ async def chat_audio(
465
+ audio: UploadFile = File(..., description="Audio file (mp3, wav, flac, ogg, webm, m4a)"),
466
+ user_id: str = Form(..., description="User ID"),
467
+ conversation_id: str = Form(..., description="Conversation ID"),
468
+ wallet_address: str = Form("default", description="Wallet address"),
469
+ ):
470
+ """
471
+ Process audio input through the agent pipeline.
472
+
473
+ The audio is first transcribed using Gemini, then the transcription
474
+ is passed to the supervisor agent for processing (just like text input).
475
+ """
476
+ request_user_id: str | None = user_id
477
+ request_conversation_id: str | None = conversation_id
478
+
479
+ try:
480
+ # Validate user_id
481
+ if not user_id or user_id.lower() == "anonymous":
482
+ wallet = (wallet_address or "").strip()
483
+ if wallet and wallet.lower() != "default":
484
+ request_user_id = f"wallet::{wallet.lower()}"
485
+ else:
486
+ raise HTTPException(
487
+ status_code=400,
488
+ detail="A stable 'user_id' or wallet_address is required.",
489
+ )
490
+
491
+ logger.debug(
492
+ "Received audio chat request user=%s conversation=%s filename=%s",
493
+ request_user_id,
494
+ request_conversation_id,
495
+ audio.filename,
496
+ )
497
+
498
+ # Validate file size
499
+ audio_content = await audio.read()
500
+ if len(audio_content) > MAX_AUDIO_SIZE:
501
+ raise HTTPException(
502
+ status_code=413,
503
+ detail=f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024*1024)}MB.",
504
+ )
505
+
506
+ if len(audio_content) == 0:
507
+ raise HTTPException(
508
+ status_code=400,
509
+ detail="Audio file is empty.",
510
+ )
511
+
512
+ # Get MIME type
513
+ mime_type = _get_audio_mime_type(audio.filename or "", audio.content_type)
514
+ logger.debug("Audio MIME type: %s, size: %d bytes", mime_type, len(audio_content))
515
+
516
+ # Encode audio to base64
517
+ encoded_audio = base64.b64encode(audio_content).decode("utf-8")
518
+
519
+ # Ensure session exists
520
+ wallet = wallet_address.strip() if wallet_address else None
521
+ if wallet and wallet.lower() == "default":
522
+ wallet = None
523
+
524
+ chat_manager_instance.ensure_session(
525
+ request_user_id,
526
+ request_conversation_id,
527
+ wallet_address=wallet,
528
+ )
529
+
530
+ # Take cost snapshot before invoking
531
+ cost_tracker = Config.get_cost_tracker()
532
+ cost_snapshot = cost_tracker.get_snapshot()
533
+
534
+ # Step 1: Transcribe the audio using Gemini
535
+ transcription_message = HumanMessage(
536
+ content=[
537
+ {"type": "text", "text": "Transcribe exactly what is being said in this audio. Return ONLY the transcription, nothing else."},
538
+ {"type": "media", "data": encoded_audio, "mime_type": mime_type},
539
+ ]
540
+ )
541
+
542
+ llm = Config.get_llm(with_cost_tracking=True)
543
+ transcription_response = llm.invoke([transcription_message])
544
+
545
+ # Extract transcription text
546
+ transcribed_text = transcription_response.content
547
+ if isinstance(transcribed_text, list):
548
+ text_parts = []
549
+ for part in transcribed_text:
550
+ if isinstance(part, dict) and part.get("text"):
551
+ text_parts.append(part["text"])
552
+ elif isinstance(part, str):
553
+ text_parts.append(part)
554
+ transcribed_text = " ".join(text_parts).strip()
555
+
556
+ if not transcribed_text:
557
+ raise HTTPException(
558
+ status_code=400,
559
+ detail="Could not transcribe the audio. Please try again with a clearer recording.",
560
+ )
561
+
562
+ logger.info("Audio transcribed: %s", transcribed_text[:200])
563
+
564
+ # Step 2: Store the user message with the transcription
565
+ user_message = ChatMessage(
566
+ role="user",
567
+ content=transcribed_text,
568
+ metadata={
569
+ "source": "audio",
570
+ "audio_filename": audio.filename,
571
+ "audio_size": len(audio_content),
572
+ "audio_mime_type": mime_type,
573
+ },
574
+ )
575
+ chat_manager_instance.add_message(
576
+ message=user_message.dict(),
577
+ conversation_id=request_conversation_id,
578
+ user_id=request_user_id,
579
+ )
580
+
581
+ # Step 3: Get conversation history and invoke supervisor
582
+ conversation_messages = chat_manager_instance.get_messages(
583
+ conversation_id=request_conversation_id,
584
+ user_id=request_user_id,
585
+ )
586
+
587
+ result = supervisor.invoke(
588
+ conversation_messages,
589
+ conversation_id=request_conversation_id,
590
+ user_id=request_user_id,
591
+ )
592
+
593
+ # Calculate and save cost delta
594
+ cost_delta = cost_tracker.calculate_delta(cost_snapshot)
595
+ if cost_delta.get("cost", 0) > 0 or cost_delta.get("calls", 0) > 0:
596
+ chat_manager_instance.update_conversation_costs(
597
+ cost_delta,
598
+ conversation_id=request_conversation_id,
599
+ user_id=request_user_id,
600
+ )
601
+
602
+ logger.debug(
603
+ "Supervisor returned result for audio user=%s conversation=%s: %s",
604
+ request_user_id,
605
+ request_conversation_id,
606
+ result,
607
+ )
608
+
609
+ # Step 4: Process and store the agent response (same as /chat endpoint)
610
+ if result and isinstance(result, dict):
611
+ agent_name = result.get("agent", "supervisor")
612
+ agent_name = _map_agent_type(agent_name)
613
+
614
+ response_metadata = {"supervisor_result": result, "source": "audio"}
615
+ swap_meta_snapshot = None
616
+
617
+ if isinstance(result, dict) and result.get("metadata"):
618
+ response_metadata.update(result.get("metadata") or {})
619
+ elif agent_name == "token swap":
620
+ swap_meta = metadata.get_swap_agent(
621
+ user_id=request_user_id,
622
+ conversation_id=request_conversation_id,
623
+ )
624
+ if swap_meta:
625
+ response_metadata.update(swap_meta)
626
+ swap_meta_snapshot = swap_meta
627
+ elif agent_name == "lending":
628
+ lending_meta = metadata.get_lending_agent(
629
+ user_id=request_user_id,
630
+ conversation_id=request_conversation_id,
631
+ )
632
+ if lending_meta:
633
+ response_metadata.update(lending_meta)
634
+ elif agent_name == "staking":
635
+ staking_meta = metadata.get_staking_agent(
636
+ user_id=request_user_id,
637
+ conversation_id=request_conversation_id,
638
+ )
639
+ if staking_meta:
640
+ response_metadata.update(staking_meta)
641
+
642
+ response_message = ChatMessage(
643
+ role="assistant",
644
+ content=result.get("response", "No response available"),
645
+ agent_name=agent_name,
646
+ agent_type=_map_agent_type(agent_name),
647
+ metadata=result.get("metadata", {}),
648
+ conversation_id=request_conversation_id,
649
+ user_id=request_user_id,
650
+ requires_action=True if agent_name in ["token swap", "lending", "staking"] else False,
651
+ action_type="swap" if agent_name == "token swap" else "lending" if agent_name == "lending" else "staking" if agent_name == "staking" else None,
652
+ )
653
+
654
+ chat_manager_instance.add_message(
655
+ message=response_message.dict(),
656
+ conversation_id=request_conversation_id,
657
+ user_id=request_user_id,
658
+ )
659
+
660
+ # Build response payload
661
+ response_payload = {
662
+ "response": result.get("response", "No response available"),
663
+ "agentName": agent_name,
664
+ "transcription": transcribed_text,
665
+ }
666
+
667
+ response_meta = result.get("metadata") or {}
668
+ if agent_name == "token swap" and not response_meta:
669
+ if swap_meta_snapshot:
670
+ response_meta = swap_meta_snapshot
671
+ else:
672
+ swap_meta = metadata.get_swap_agent(
673
+ user_id=request_user_id,
674
+ conversation_id=request_conversation_id,
675
+ )
676
+ if swap_meta:
677
+ response_meta = swap_meta
678
+
679
+ if response_meta:
680
+ response_payload["metadata"] = response_meta
681
+
682
+ # Clear metadata after ready events (same as /chat)
683
+ if agent_name == "token swap":
684
+ should_clear = False
685
+ if response_meta:
686
+ status = response_meta.get("status") if isinstance(response_meta, dict) else None
687
+ event = response_meta.get("event") if isinstance(response_meta, dict) else None
688
+ should_clear = status == "ready" or event == "swap_intent_ready"
689
+ if should_clear:
690
+ metadata.set_swap_agent({}, user_id=request_user_id, conversation_id=request_conversation_id)
691
+
692
+ if agent_name == "lending":
693
+ should_clear = False
694
+ if response_meta:
695
+ status = response_meta.get("status") if isinstance(response_meta, dict) else None
696
+ event = response_meta.get("event") if isinstance(response_meta, dict) else None
697
+ should_clear = status == "ready" or event == "lending_intent_ready"
698
+ if should_clear:
699
+ metadata.set_lending_agent({}, user_id=request_user_id, conversation_id=request_conversation_id)
700
+
701
+ if agent_name == "staking":
702
+ should_clear = False
703
+ if response_meta:
704
+ status = response_meta.get("status") if isinstance(response_meta, dict) else None
705
+ event = response_meta.get("event") if isinstance(response_meta, dict) else None
706
+ should_clear = status == "ready" or event == "staking_intent_ready"
707
+ if should_clear:
708
+ metadata.set_staking_agent({}, user_id=request_user_id, conversation_id=request_conversation_id)
709
+
710
+ return response_payload
711
+
712
+ return {
713
+ "response": "No response available",
714
+ "agentName": "supervisor",
715
+ "transcription": transcribed_text,
716
+ }
717
+
718
+ except HTTPException:
719
+ raise
720
+ except Exception as e:
721
+ logger.exception(
722
+ "Audio chat handler failed for user=%s conversation=%s",
723
+ request_user_id,
724
+ request_conversation_id,
725
+ )
726
+ raise HTTPException(status_code=500, detail=str(e))
727
+
728
+
729
  # Include chat manager router
730
  app.include_router(chat_manager_router)
src/infrastructure/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Infrastructure Module - Cross-cutting concerns.
3
+
4
+ This module provides:
5
+ - Logging: Structured logging with color support
6
+ - Rate Limiting: API rate limiting with SlowAPI
7
+ - Metrics: Prometheus metrics for observability
8
+ - Retry: Retry utilities with exponential backoff
9
+ """
10
+
11
+ from .logging import setup_logging, get_logger
12
+ from .rate_limiter import limiter, setup_rate_limiter, limit_chat, limit_stream
13
+ from .retry import execute_with_retry, RetryConfig
14
+
15
+ __all__ = [
16
+ # Logging
17
+ "setup_logging",
18
+ "get_logger",
19
+ # Rate limiting
20
+ "limiter",
21
+ "setup_rate_limiter",
22
+ "limit_chat",
23
+ "limit_stream",
24
+ # Retry
25
+ "execute_with_retry",
26
+ "RetryConfig",
27
+ ]
src/infrastructure/logging.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Structured Logging Configuration.
3
+
4
+ Supports:
5
+ - Color output for development (colorlog)
6
+ - JSON output for production (structlog)
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ import sys
12
+ from typing import Literal
13
+
14
+ LogFormat = Literal["color", "json"]
15
+
16
+
17
+ def setup_logging(
18
+ level: int | str = logging.INFO,
19
+ format_type: LogFormat | None = None,
20
+ json_indent: int | None = None,
21
+ ) -> logging.Logger:
22
+ """
23
+ Configure application logging.
24
+
25
+ Args:
26
+ level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
27
+ format_type: Output format ("color" for dev, "json" for prod)
28
+ If None, reads from LOG_FORMAT env var (defaults to "color")
29
+ json_indent: Indentation for JSON output (None for compact)
30
+
31
+ Returns:
32
+ Configured root logger
33
+ """
34
+ # Determine format from env if not specified
35
+ if format_type is None:
36
+ format_type = os.getenv("LOG_FORMAT", "color").lower()
37
+ if format_type not in ("color", "json"):
38
+ format_type = "color"
39
+
40
+ # Parse level if string
41
+ if isinstance(level, str):
42
+ level = getattr(logging, level.upper(), logging.INFO)
43
+
44
+ # Get root logger
45
+ root_logger = logging.getLogger()
46
+ root_logger.setLevel(level)
47
+
48
+ # Remove existing handlers
49
+ for handler in root_logger.handlers[:]:
50
+ root_logger.removeHandler(handler)
51
+
52
+ # Create handler
53
+ handler = logging.StreamHandler(sys.stdout)
54
+ handler.setLevel(level)
55
+
56
+ if format_type == "color":
57
+ formatter = _create_color_formatter()
58
+ else:
59
+ formatter = _create_json_formatter(json_indent)
60
+
61
+ handler.setFormatter(formatter)
62
+ root_logger.addHandler(handler)
63
+
64
+ # Reduce noise from third-party libraries
65
+ logging.getLogger("httpx").setLevel(logging.WARNING)
66
+ logging.getLogger("httpcore").setLevel(logging.WARNING)
67
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
68
+ logging.getLogger("langchain").setLevel(logging.WARNING)
69
+ logging.getLogger("langsmith").setLevel(logging.WARNING)
70
+
71
+ return root_logger
72
+
73
+
74
+ def _create_color_formatter() -> logging.Formatter:
75
+ """Create colorized formatter for development."""
76
+ try:
77
+ import colorlog
78
+
79
+ return colorlog.ColoredFormatter(
80
+ fmt="%(log_color)s%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
81
+ datefmt="%Y-%m-%d %H:%M:%S",
82
+ log_colors={
83
+ "DEBUG": "cyan",
84
+ "INFO": "green",
85
+ "WARNING": "yellow",
86
+ "ERROR": "red",
87
+ "CRITICAL": "bold_red",
88
+ },
89
+ secondary_log_colors={},
90
+ style="%",
91
+ )
92
+ except ImportError:
93
+ # Fallback if colorlog not installed
94
+ return logging.Formatter(
95
+ fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
96
+ datefmt="%Y-%m-%d %H:%M:%S",
97
+ )
98
+
99
+
100
+ def _create_json_formatter(indent: int | None = None) -> logging.Formatter:
101
+ """Create JSON formatter for production."""
102
+ try:
103
+ import structlog
104
+
105
+ # Configure structlog
106
+ structlog.configure(
107
+ processors=[
108
+ structlog.stdlib.filter_by_level,
109
+ structlog.stdlib.add_logger_name,
110
+ structlog.stdlib.add_log_level,
111
+ structlog.stdlib.PositionalArgumentsFormatter(),
112
+ structlog.processors.TimeStamper(fmt="iso"),
113
+ structlog.processors.StackInfoRenderer(),
114
+ structlog.processors.format_exc_info,
115
+ structlog.processors.UnicodeDecoder(),
116
+ structlog.processors.JSONRenderer(indent=indent),
117
+ ],
118
+ wrapper_class=structlog.stdlib.BoundLogger,
119
+ context_class=dict,
120
+ logger_factory=structlog.stdlib.LoggerFactory(),
121
+ cache_logger_on_first_use=True,
122
+ )
123
+
124
+ # Return a simple formatter since structlog handles formatting
125
+ return logging.Formatter("%(message)s")
126
+
127
+ except ImportError:
128
+ # Fallback JSON formatter
129
+ import json
130
+
131
+ class JsonFormatter(logging.Formatter):
132
+ def format(self, record: logging.LogRecord) -> str:
133
+ log_data = {
134
+ "timestamp": self.formatTime(record, "%Y-%m-%dT%H:%M:%S"),
135
+ "level": record.levelname,
136
+ "logger": record.name,
137
+ "message": record.getMessage(),
138
+ }
139
+ if record.exc_info:
140
+ log_data["exception"] = self.formatException(record.exc_info)
141
+ return json.dumps(log_data)
142
+
143
+ return JsonFormatter()
144
+
145
+
146
+ def get_logger(name: str) -> logging.Logger:
147
+ """
148
+ Get a logger with the specified name.
149
+
150
+ Args:
151
+ name: Logger name (typically __name__)
152
+
153
+ Returns:
154
+ Logger instance
155
+ """
156
+ return logging.getLogger(name)
157
+
158
+
159
+ class LoggerMixin:
160
+ """Mixin class to add logging capability to any class."""
161
+
162
+ @property
163
+ def logger(self) -> logging.Logger:
164
+ """Get logger for this class."""
165
+ if not hasattr(self, "_logger"):
166
+ self._logger = logging.getLogger(
167
+ f"{self.__class__.__module__}.{self.__class__.__name__}"
168
+ )
169
+ return self._logger
src/infrastructure/rate_limiter.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Rate Limiting Configuration using SlowAPI.
3
+
4
+ Provides rate limiting for FastAPI endpoints to prevent abuse.
5
+ """
6
+
7
+ import os
8
+ from typing import Callable
9
+
10
+ from fastapi import FastAPI, Request, Response
11
+ from slowapi import Limiter, _rate_limit_exceeded_handler
12
+ from slowapi.errors import RateLimitExceeded
13
+ from slowapi.util import get_remote_address
14
+
15
+
16
+ def _get_identifier(request: Request) -> str:
17
+ """
18
+ Get identifier for rate limiting.
19
+
20
+ Uses X-Forwarded-For header if behind a proxy,
21
+ otherwise falls back to remote address.
22
+ Also considers user_id from query params or body if available.
23
+ """
24
+ # Try to get user_id for more granular limiting
25
+ user_id = request.query_params.get("user_id")
26
+ if user_id and user_id != "anonymous":
27
+ return f"user:{user_id}"
28
+
29
+ # Fall back to IP-based limiting
30
+ forwarded = request.headers.get("X-Forwarded-For")
31
+ if forwarded:
32
+ return forwarded.split(",")[0].strip()
33
+
34
+ return get_remote_address(request)
35
+
36
+
37
+ # Create global limiter instance
38
+ limiter = Limiter(
39
+ key_func=_get_identifier,
40
+ default_limits=[os.getenv("RATE_LIMIT_DEFAULT", "100/minute")],
41
+ )
42
+
43
+
44
+ def setup_rate_limiter(app: FastAPI) -> None:
45
+ """
46
+ Configure rate limiting on a FastAPI application.
47
+
48
+ Args:
49
+ app: FastAPI application instance
50
+ """
51
+ app.state.limiter = limiter
52
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
53
+
54
+
55
+ def limit_chat(func: Callable) -> Callable:
56
+ """
57
+ Rate limit decorator for chat endpoints.
58
+
59
+ Default: 30 requests per minute.
60
+ """
61
+ limit = os.getenv("RATE_LIMIT_CHAT", "30/minute")
62
+ return limiter.limit(limit)(func)
63
+
64
+
65
+ def limit_stream(func: Callable) -> Callable:
66
+ """
67
+ Rate limit decorator for streaming endpoints.
68
+
69
+ Default: 10 requests per minute (streaming is more resource-intensive).
70
+ """
71
+ limit = os.getenv("RATE_LIMIT_STREAM", "10/minute")
72
+ return limiter.limit(limit)(func)
73
+
74
+
75
+ def limit_health(func: Callable) -> Callable:
76
+ """
77
+ Rate limit decorator for health check endpoints.
78
+
79
+ Default: 100 requests per minute.
80
+ """
81
+ limit = os.getenv("RATE_LIMIT_HEALTH", "100/minute")
82
+ return limiter.limit(limit)(func)
83
+
84
+
85
+ def limit_custom(limit_string: str) -> Callable:
86
+ """
87
+ Create a custom rate limit decorator.
88
+
89
+ Args:
90
+ limit_string: Rate limit string (e.g., "10/minute", "100/hour")
91
+
92
+ Returns:
93
+ Decorator function
94
+ """
95
+ return limiter.limit(limit_string)
src/infrastructure/retry.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Retry utilities with exponential backoff.
3
+
4
+ Provides retry logic for unreliable operations like LLM calls.
5
+ """
6
+
7
+ import asyncio
8
+ import logging
9
+ from dataclasses import dataclass, field
10
+ from functools import wraps
11
+ from typing import Any, Callable, TypeVar, ParamSpec
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ T = TypeVar("T")
16
+ P = ParamSpec("P")
17
+
18
+
19
@dataclass
class RetryConfig:
    """Configuration for retry behavior."""

    # Total number of attempts made (the first try counts toward this number).
    max_retries: int = 3
    base_delay: float = 1.0  # Base delay in seconds
    max_delay: float = 30.0  # Maximum delay in seconds
    exponential_base: float = 2.0  # Exponential backoff base
    # NOTE(review): the default tuple includes Exception, which makes the
    # TimeoutError/ConnectionError entries redundant — effectively *every*
    # error is retried. Confirm this catch-all default is intentional.
    retryable_exceptions: tuple[type[BaseException], ...] = field(
        default_factory=lambda: (TimeoutError, ConnectionError, Exception)
    )
30
+
31
+
32
# Default configuration
# Module-level singleton used by execute_with_retry / with_retry whenever the
# caller does not pass an explicit RetryConfig.
DEFAULT_RETRY_CONFIG = RetryConfig()
34
+
35
+
36
async def execute_with_retry(
    func: Callable[P, T],
    *args: P.args,
    config: RetryConfig | None = None,
    fallback_response: T | None = None,
    on_retry: Callable[[int, Exception], None] | None = None,
    **kwargs: P.kwargs,
) -> T:
    """
    Execute a function with retry logic and exponential backoff.

    Args:
        func: Function to execute (can be sync or async)
        *args: Positional arguments for the function
        config: Retry configuration (uses defaults if None)
        fallback_response: Value to return if all retries fail (if None, raises exception)
        on_retry: Optional callback called on each retry (receives attempt number and exception)
        **kwargs: Keyword arguments for the function

    Returns:
        The function result or fallback_response

    Raises:
        The last exception if all retries fail and no fallback is provided
    """
    cfg = config or DEFAULT_RETRY_CONFIG
    caught: Exception | None = None
    # func never changes between attempts, so check its kind once up front.
    is_async = asyncio.iscoroutinefunction(func)

    for attempt_idx in range(cfg.max_retries):
        try:
            if is_async:
                return await func(*args, **kwargs)
            return func(*args, **kwargs)

        except cfg.retryable_exceptions as exc:
            caught = exc
            remaining = cfg.max_retries - attempt_idx - 1

            if remaining == 0:
                # Out of attempts: record the failure and fall through to the
                # fallback / re-raise logic below.
                logger.error(
                    f"All {cfg.max_retries} attempts failed for {func.__name__}. "
                    f"Last error: {exc}"
                )
                continue

            # Exponential backoff, capped at max_delay.
            wait_s = min(
                cfg.base_delay * cfg.exponential_base**attempt_idx,
                cfg.max_delay,
            )

            logger.warning(
                f"Attempt {attempt_idx + 1}/{cfg.max_retries} failed for {func.__name__}: {exc}. "
                f"Retrying in {wait_s:.1f}s..."
            )

            # Give the caller a hook (metrics, circuit breakers, ...).
            if on_retry:
                on_retry(attempt_idx + 1, exc)

            await asyncio.sleep(wait_s)

    # All retries exhausted.
    if fallback_response is not None:
        logger.info(f"Using fallback response for {func.__name__}")
        return fallback_response

    if caught:
        raise caught

    raise RuntimeError(f"Unexpected state in retry logic for {func.__name__}")
107
+
108
+
109
def with_retry(
    config: RetryConfig | None = None,
    fallback_response: Any = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator to add retry logic to a function.

    NOTE: sync functions are run through asyncio.run, so a decorated sync
    callable must not be invoked from inside an already-running event loop
    (asyncio.run raises RuntimeError there).

    Args:
        config: Retry configuration
        fallback_response: Value to return if all retries fail

    Returns:
        Decorated function with retry logic
    """
    retry_cfg = config or DEFAULT_RETRY_CONFIG

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        # Build only the wrapper that matches the wrapped function's kind.
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                return await execute_with_retry(
                    func,
                    *args,
                    config=retry_cfg,
                    fallback_response=fallback_response,
                    **kwargs,
                )

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            return asyncio.run(
                execute_with_retry(
                    func,
                    *args,
                    config=retry_cfg,
                    fallback_response=fallback_response,
                    **kwargs,
                )
            )

        return sync_wrapper

    return decorator
153
+
154
+
155
class RetryableMixin:
    """
    Mixin class that adds retry capability to any class.

    Usage:
        class MyAgent(RetryableMixin):
            async def call_llm(self, prompt):
                return await self.with_retry(
                    self._do_call_llm,
                    prompt,
                    fallback_response="Sorry, I couldn't process that."
                )
    """

    # Class-level default shared by all instances; set_retry_config installs
    # an instance-level override without affecting other instances.
    _retry_config: RetryConfig = DEFAULT_RETRY_CONFIG

    async def with_retry(
        self,
        func: Callable[P, T],
        *args: P.args,
        fallback_response: T | None = None,
        **kwargs: P.kwargs,
    ) -> T:
        """Execute a method with retry logic."""
        # Delegates to the module-level helper using this object's config.
        return await execute_with_retry(
            func,
            *args,
            config=self._retry_config,
            fallback_response=fallback_response,
            **kwargs,
        )

    def set_retry_config(self, config: RetryConfig) -> None:
        """Update retry configuration."""
        self._retry_config = config
src/llm/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Module - Multi-provider LLM abstraction layer.
3
+
4
+ This module provides:
5
+ - LLMFactory: Create LLM instances for multiple providers (Google, OpenAI, Anthropic)
6
+ - CostTrackingCallback: Track token usage and costs per LLM call
7
+ """
8
+
9
+ from .factory import LLMFactory, detect_provider, MODEL_PROVIDERS
10
+ from .cost_tracker import CostTrackingCallback
11
+ from .exceptions import (
12
+ LLMError,
13
+ LLMProviderError,
14
+ LLMTimeoutError,
15
+ LLMRateLimitError,
16
+ LLMInvalidModelError,
17
+ )
18
+
19
+ __all__ = [
20
+ # Factory
21
+ "LLMFactory",
22
+ "detect_provider",
23
+ "MODEL_PROVIDERS",
24
+ # Cost tracking
25
+ "CostTrackingCallback",
26
+ # Exceptions
27
+ "LLMError",
28
+ "LLMProviderError",
29
+ "LLMTimeoutError",
30
+ "LLMRateLimitError",
31
+ "LLMInvalidModelError",
32
+ ]
src/llm/cost_tracker.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cost Tracking Callback for LLM usage monitoring.
3
+
4
+ Tracks token usage and calculates costs per LLM call.
5
+ """
6
+
7
+ import logging
8
+ from datetime import datetime
9
+ from typing import Any
10
+
11
+ from langchain_core.callbacks import BaseCallbackHandler
12
+ from langchain_core.outputs import LLMResult
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class CostTrackingCallback(BaseCallbackHandler):
    """
    LangChain callback handler for tracking LLM costs.

    Tracks:
    - Input/output token counts
    - Cost per call and cumulative
    - Model usage statistics

    Usage:
        callback = CostTrackingCallback()
        llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", callbacks=[callback])
        response = llm.invoke("Hello!")
        print(callback.get_summary())
    """

    # Pricing per 1M tokens (USD) - Update as needed
    PRICING: dict[str, dict[str, float]] = {
        # Google Gemini
        "gemini-3-pro-preview": {"input": 1.50, "output": 6.00, "cache": 0.40},
        "gemini-2.5-flash": {"input": 0.15, "output": 0.60, "cache": 0.02},
        "gemini-2.5-pro": {"input": 1.25, "output": 5.00, "cache": 0.32},
        "gemini-2.0-flash": {"input": 0.10, "output": 0.40, "cache": 0.01},
        "gemini-1.5-flash": {"input": 0.075, "output": 0.30, "cache": 0.02},
        "gemini-1.5-pro": {"input": 1.25, "output": 5.00, "cache": 0.32},
        # OpenAI
        "gpt-4o": {"input": 2.50, "output": 10.00, "cache": 1.25},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60, "cache": 0.08},
        "gpt-4-turbo": {"input": 10.00, "output": 30.00, "cache": 5.00},
        "gpt-4": {"input": 30.00, "output": 60.00, "cache": 15.00},
        "gpt-3.5-turbo": {"input": 0.50, "output": 1.50, "cache": 0.25},
        # Anthropic Claude
        "claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00, "cache": 0.30},
        "claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00, "cache": 0.30},
        "claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00, "cache": 0.08},
        "claude-3-opus-20240229": {"input": 15.00, "output": 75.00, "cache": 1.50},
    }

    # Default pricing for unknown models
    DEFAULT_PRICING = {"input": 1.00, "output": 3.00, "cache": 0.10}

    def __init__(self, log_calls: bool = True):
        """
        Initialize the cost tracker.

        Args:
            log_calls: Whether to log each LLM call
        """
        super().__init__()
        self.log_calls = log_calls
        # Running totals across all calls seen by this handler instance.
        self.total_cost: float = 0.0
        self.total_tokens: dict[str, int] = {"input": 0, "output": 0, "cache": 0}
        # One entry per LLM call (see call_info in on_llm_end).
        self.calls: list[dict[str, Any]] = []
        # NOTE(review): datetime.utcnow() is naive UTC and deprecated on
        # Python 3.12+; consider datetime.now(timezone.utc) in a future change.
        self.start_time: datetime = datetime.utcnow()

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs) -> None:
        """Called when LLM starts processing."""
        pass  # Could track start time per call if needed

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """
        Called when LLM finishes processing.

        Calculates and records the cost of the call. Token usage is looked up
        in two places because providers report it differently (see below).
        """
        logger.debug(f"[COST DEBUG] on_llm_end called. llm_output: {response.llm_output}")

        # Try to extract usage from multiple sources (different providers put it in different places)
        input_tokens = 0
        output_tokens = 0
        cache_tokens = 0
        model = "unknown"

        # Source 1: llm_output (OpenAI, Anthropic style)
        if response.llm_output:
            model = self._extract_model_name(response.llm_output)
            usage = response.llm_output.get("token_usage", {})
            input_tokens = usage.get("prompt_tokens", 0) or usage.get("input_tokens", 0)
            output_tokens = usage.get("completion_tokens", 0) or usage.get("output_tokens", 0)
            cache_tokens = usage.get("cache_read_input_tokens", 0) or usage.get("cached_tokens", 0)

        # Source 2: generations metadata (Google Gemini style).
        # Only consulted when Source 1 yielded nothing; stops at the first
        # generation that carries usage data.
        if input_tokens == 0 and output_tokens == 0 and response.generations:
            for gen_list in response.generations:
                for gen in gen_list:
                    # Check generation_info
                    gen_info = getattr(gen, "generation_info", {}) or {}
                    usage_meta = gen_info.get("usage_metadata", {})
                    if usage_meta:
                        input_tokens = usage_meta.get("input_tokens", 0)
                        output_tokens = usage_meta.get("output_tokens", 0)
                        cache_details = usage_meta.get("input_token_details", {})
                        cache_tokens = cache_details.get("cache_read", 0)
                        if model == "unknown":
                            model = gen_info.get("model_name", "unknown")
                        break

                    # Check message attribute (for ChatGeneration)
                    msg = getattr(gen, "message", None)
                    if msg:
                        msg_usage = getattr(msg, "usage_metadata", None)
                        if msg_usage:
                            input_tokens = msg_usage.get("input_tokens", 0)
                            output_tokens = msg_usage.get("output_tokens", 0)
                            cache_details = msg_usage.get("input_token_details", {})
                            cache_tokens = cache_details.get("cache_read", 0)
                            resp_meta = getattr(msg, "response_metadata", {}) or {}
                            if model == "unknown":
                                model = resp_meta.get("model_name", "unknown")
                            break
                if input_tokens > 0 or output_tokens > 0:
                    break

        # Skip if no usage data found
        if input_tokens == 0 and output_tokens == 0:
            logger.debug("[COST DEBUG] No token usage found in response, skipping cost tracking")
            return

        logger.debug(f"[COST DEBUG] Extracted: model={model}, input={input_tokens}, output={output_tokens}, cache={cache_tokens}")

        # Calculate cost from the per-1M-token price table; unknown models
        # fall back to DEFAULT_PRICING.
        pricing = self.PRICING.get(model, self.DEFAULT_PRICING)
        input_cost = (input_tokens * pricing["input"]) / 1_000_000
        output_cost = (output_tokens * pricing["output"]) / 1_000_000
        cache_cost = (cache_tokens * pricing["cache"]) / 1_000_000
        total_call_cost = input_cost + output_cost + cache_cost

        # Update totals
        self.total_cost += total_call_cost
        self.total_tokens["input"] += input_tokens
        self.total_tokens["output"] += output_tokens
        self.total_tokens["cache"] += cache_tokens

        # Record call details
        call_info = {
            "timestamp": datetime.utcnow().isoformat(),
            "model": model,
            "tokens": {
                "input": input_tokens,
                "output": output_tokens,
                "cache": cache_tokens,
            },
            "cost": {
                "input": input_cost,
                "output": output_cost,
                "cache": cache_cost,
                "total": total_call_cost,
            },
        }
        self.calls.append(call_info)

        # Log if enabled
        if self.log_calls:
            logger.info(
                f"[COST] {model} | "
                f"Tokens: {input_tokens:,} in / {output_tokens:,} out"
                + (f" / {cache_tokens:,} cache" if cache_tokens else "")
                + f" | Cost: ${total_call_cost:.6f} | Total: ${self.total_cost:.6f}"
            )

    def on_llm_error(self, error: Exception, **kwargs) -> None:
        """Called when LLM encounters an error."""
        logger.error(f"[COST] LLM Error: {error}")

    def _extract_model_name(self, llm_output: dict[str, Any]) -> str:
        """Extract model name from LLM output."""
        # Try common keys
        for key in ["model_name", "model", "model_id"]:
            if key in llm_output:
                return llm_output[key]

        # Check nested structure
        if "model_info" in llm_output:
            return llm_output["model_info"].get("model", "unknown")

        return "unknown"

    def get_summary(self) -> dict[str, Any]:
        """
        Get a summary of all tracked costs.

        Returns:
            Dictionary with cost summary
        """
        duration = (datetime.utcnow() - self.start_time).total_seconds()

        return {
            "total_cost": round(self.total_cost, 6),
            "total_tokens": self.total_tokens.copy(),
            "calls_count": len(self.calls),
            "duration_seconds": round(duration, 2),
            "avg_cost_per_call": round(self.total_cost / len(self.calls), 6) if self.calls else 0,
            "models_used": list(set(call["model"] for call in self.calls)),
            "start_time": self.start_time.isoformat(),
        }

    def get_detailed_report(self) -> dict[str, Any]:
        """
        Get a detailed report including all calls.

        Returns:
            Dictionary with full cost details
        """
        summary = self.get_summary()
        summary["calls"] = self.calls
        return summary

    def get_cost_by_model(self) -> dict[str, dict[str, float]]:
        """
        Get costs aggregated by model.

        Returns:
            Dictionary mapping model names to their costs
        """
        by_model: dict[str, dict[str, float]] = {}

        for call in self.calls:
            model = call["model"]
            if model not in by_model:
                by_model[model] = {"cost": 0.0, "input_tokens": 0, "output_tokens": 0, "calls": 0}

            by_model[model]["cost"] += call["cost"]["total"]
            by_model[model]["input_tokens"] += call["tokens"]["input"]
            by_model[model]["output_tokens"] += call["tokens"]["output"]
            by_model[model]["calls"] += 1

        return by_model

    def reset(self) -> None:
        """Reset all tracked data."""
        self.total_cost = 0.0
        self.total_tokens = {"input": 0, "output": 0, "cache": 0}
        self.calls = []
        self.start_time = datetime.utcnow()

    def get_snapshot(self) -> dict[str, Any]:
        """
        Get a snapshot of current totals for delta calculation.

        Returns:
            Dictionary with current cost and token totals
        """
        return {
            "total_cost": self.total_cost,
            "total_tokens": self.total_tokens.copy(),
            "calls_count": len(self.calls),
        }

    def calculate_delta(self, previous_snapshot: dict[str, Any]) -> dict[str, Any]:
        """
        Calculate the delta between current state and a previous snapshot.

        Args:
            previous_snapshot: Snapshot from get_snapshot()

        Returns:
            Dictionary with cost and token deltas for this period
        """
        prev_cost = previous_snapshot.get("total_cost", 0.0)
        prev_tokens = previous_snapshot.get("total_tokens", {"input": 0, "output": 0, "cache": 0})
        prev_calls = previous_snapshot.get("calls_count", 0)

        return {
            "cost": round(self.total_cost - prev_cost, 6),
            "tokens": {
                "input": self.total_tokens["input"] - prev_tokens.get("input", 0),
                "output": self.total_tokens["output"] - prev_tokens.get("output", 0),
                "cache": self.total_tokens["cache"] - prev_tokens.get("cache", 0),
            },
            "calls": len(self.calls) - prev_calls,
        }

    def __str__(self) -> str:
        """String representation of current costs."""
        return (
            f"CostTracker: ${self.total_cost:.6f} total | "
            f"{self.total_tokens['input']:,} in / {self.total_tokens['output']:,} out | "
            f"{len(self.calls)} calls"
        )
src/llm/exceptions.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom exceptions for LLM module.
3
+ """
4
+
5
+
6
+ class LLMError(Exception):
7
+ """Base exception for LLM-related errors."""
8
+
9
+ def __init__(self, message: str, provider: str | None = None, model: str | None = None):
10
+ self.provider = provider
11
+ self.model = model
12
+ super().__init__(message)
13
+
14
+
15
class LLMProviderError(LLMError):
    """Raised when there's an error with the LLM provider."""
19
+
20
+
21
class LLMTimeoutError(LLMError):
    """Raised when an LLM request times out."""
25
+
26
+
27
class LLMRateLimitError(LLMError):
    """Raised when rate limit is exceeded."""

    def __init__(
        self,
        message: str,
        provider: str | None = None,
        model: str | None = None,
        retry_after: int | None = None,
    ):
        super().__init__(message, provider, model)
        # Seconds to wait before retrying, when the provider reported one.
        self.retry_after = retry_after
39
+
40
+
41
class LLMInvalidModelError(LLMError):
    """Raised when an invalid model is specified."""

    def __init__(self, model: str, available_models: list[str] | None = None):
        self.available_models = available_models or []
        # Compose the message from its parts so the "Available models" tail
        # only appears when there is something to list.
        parts = [f"Invalid model: {model}"]
        if available_models:
            parts.append(f"Available models: {', '.join(available_models)}")
        super().__init__(". ".join(parts), model=model)
src/llm/factory.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Factory - Multi-provider LLM abstraction.
3
+
4
+ Supports:
5
+ - Google (Gemini)
6
+ - OpenAI (GPT)
7
+ - Anthropic (Claude)
8
+ """
9
+
10
+ import os
11
+ from typing import Literal
12
+
13
+ from langchain_core.language_models import BaseChatModel
14
+
15
+ from .exceptions import LLMInvalidModelError, LLMProviderError
16
+
17
# Provider identifier used throughout this module.
Provider = Literal["google", "openai", "anthropic"]

# Known models per provider. detect_provider() first matches on model-name
# prefix and only falls back to this table for exact-name lookups.
MODEL_PROVIDERS: dict[Provider, list[str]] = {
    "google": [
        "gemini-3-pro-preview",
        "gemini-2.5-flash",
        "gemini-2.5-pro",
        "gemini-2.0-flash",
        "gemini-1.5-flash",
        "gemini-1.5-pro",
    ],
    "openai": [
        "gpt-4o",
        "gpt-4o-mini",
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
    ],
    "anthropic": [
        "claude-sonnet-4-20250514",
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229",
    ],
}

# Flatten for quick lookup
ALL_MODELS: set[str] = {model for models in MODEL_PROVIDERS.values() for model in models}
45
+
46
+
47
def detect_provider(model: str) -> Provider:
    """
    Detect the provider based on model name.

    Args:
        model: The model name (e.g., 'gemini-2.5-flash', 'gpt-4o')

    Returns:
        The provider name ('google', 'openai', 'anthropic')

    Raises:
        LLMInvalidModelError: If the model is not recognized
    """
    normalized = model.lower()

    # Prefix conventions cover current and future model versions.
    prefix_to_provider: dict[str, Provider] = {
        "gemini": "google",
        "gpt": "openai",
        "claude": "anthropic",
    }
    for prefix, provider in prefix_to_provider.items():
        if normalized.startswith(prefix):
            return provider

    # Exact-name lookup for anything that doesn't follow the prefix rule.
    for provider, models in MODEL_PROVIDERS.items():
        if model in models:
            return provider

    raise LLMInvalidModelError(model, list(ALL_MODELS))
76
+
77
+
78
+ class LLMFactory:
79
+ """Factory for creating LLM instances across multiple providers."""
80
+
81
+ # Cache for LLM instances (singleton per model+config)
82
+ _instances: dict[str, BaseChatModel] = {}
83
+
84
+ @classmethod
85
+ def create(
86
+ cls,
87
+ model: str,
88
+ temperature: float = 0.7,
89
+ max_retries: int = 3,
90
+ timeout: int = 60,
91
+ api_key: str | None = None,
92
+ use_cache: bool = True,
93
+ **kwargs,
94
+ ) -> BaseChatModel:
95
+ """
96
+ Create an LLM instance for the specified model.
97
+
98
+ Args:
99
+ model: Model name (e.g., 'gemini-2.5-flash', 'gpt-4o', 'claude-sonnet-4-20250514')
100
+ temperature: Sampling temperature (0.0 to 1.0)
101
+ max_retries: Maximum number of retries on failure
102
+ timeout: Request timeout in seconds
103
+ api_key: Optional API key (defaults to environment variable)
104
+ use_cache: Whether to use cached instances
105
+ **kwargs: Additional provider-specific arguments
106
+
107
+ Returns:
108
+ BaseChatModel instance
109
+
110
+ Raises:
111
+ LLMInvalidModelError: If model is not recognized
112
+ LLMProviderError: If provider initialization fails
113
+ """
114
+ # Check cache
115
+ cache_key = f"{model}:{temperature}:{timeout}"
116
+ if use_cache and cache_key in cls._instances:
117
+ return cls._instances[cache_key]
118
+
119
+ provider = detect_provider(model)
120
+
121
+ try:
122
+ llm = cls._create_for_provider(
123
+ provider=provider,
124
+ model=model,
125
+ temperature=temperature,
126
+ max_retries=max_retries,
127
+ timeout=timeout,
128
+ api_key=api_key,
129
+ **kwargs,
130
+ )
131
+
132
+ if use_cache:
133
+ cls._instances[cache_key] = llm
134
+
135
+ return llm
136
+
137
+ except ImportError as e:
138
+ raise LLMProviderError(
139
+ f"Provider '{provider}' dependencies not installed: {e}",
140
+ provider=provider,
141
+ model=model,
142
+ )
143
+ except Exception as e:
144
+ raise LLMProviderError(
145
+ f"Failed to create LLM for '{model}': {e}",
146
+ provider=provider,
147
+ model=model,
148
+ )
149
+
150
+ @classmethod
151
+ def _create_for_provider(
152
+ cls,
153
+ provider: Provider,
154
+ model: str,
155
+ temperature: float,
156
+ max_retries: int,
157
+ timeout: int,
158
+ api_key: str | None,
159
+ **kwargs,
160
+ ) -> BaseChatModel:
161
+ """Create LLM instance for a specific provider."""
162
+ match provider:
163
+ case "google":
164
+ return cls._create_google(
165
+ model, temperature, max_retries, timeout, api_key, **kwargs
166
+ )
167
+ case "openai":
168
+ return cls._create_openai(
169
+ model, temperature, max_retries, timeout, api_key, **kwargs
170
+ )
171
+ case "anthropic":
172
+ return cls._create_anthropic(
173
+ model, temperature, max_retries, timeout, api_key, **kwargs
174
+ )
175
+
176
+ @staticmethod
177
+ def _create_google(
178
+ model: str,
179
+ temperature: float,
180
+ max_retries: int,
181
+ timeout: int,
182
+ api_key: str | None,
183
+ callbacks: list | None = None,
184
+ **kwargs,
185
+ ) -> BaseChatModel:
186
+ """Create Google Gemini LLM instance."""
187
+ from langchain_google_genai import ChatGoogleGenerativeAI
188
+
189
+ return ChatGoogleGenerativeAI(
190
+ model=model,
191
+ temperature=temperature,
192
+ max_retries=max_retries,
193
+ timeout=timeout,
194
+ google_api_key=api_key or os.getenv("GEMINI_API_KEY"),
195
+ callbacks=callbacks,
196
+ **kwargs,
197
+ )
198
+
199
+ @staticmethod
200
+ def _create_openai(
201
+ model: str,
202
+ temperature: float,
203
+ max_retries: int,
204
+ timeout: int,
205
+ api_key: str | None,
206
+ callbacks: list | None = None,
207
+ **kwargs,
208
+ ) -> BaseChatModel:
209
+ """Create OpenAI LLM instance."""
210
+ from langchain_openai import ChatOpenAI
211
+
212
+ return ChatOpenAI(
213
+ model=model,
214
+ temperature=temperature,
215
+ max_retries=max_retries,
216
+ timeout=timeout,
217
+ api_key=api_key or os.getenv("OPENAI_API_KEY"),
218
+ callbacks=callbacks,
219
+ **kwargs,
220
+ )
221
+
222
+ @staticmethod
223
+ def _create_anthropic(
224
+ model: str,
225
+ temperature: float,
226
+ max_retries: int,
227
+ timeout: int,
228
+ api_key: str | None,
229
+ callbacks: list | None = None,
230
+ **kwargs,
231
+ ) -> BaseChatModel:
232
+ """Create Anthropic Claude LLM instance."""
233
+ from langchain_anthropic import ChatAnthropic
234
+
235
+ return ChatAnthropic(
236
+ model=model,
237
+ temperature=temperature,
238
+ max_retries=max_retries,
239
+ timeout=timeout,
240
+ api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
241
+ callbacks=callbacks,
242
+ **kwargs,
243
+ )
244
+
245
+ @classmethod
246
+ def list_models(cls, provider: Provider | None = None) -> list[str]:
247
+ """
248
+ List available models.
249
+
250
+ Args:
251
+ provider: Optional provider to filter by
252
+
253
+ Returns:
254
+ List of model names
255
+ """
256
+ if provider:
257
+ return MODEL_PROVIDERS.get(provider, [])
258
+ return list(ALL_MODELS)
259
+
260
+ @classmethod
261
+ def list_providers(cls) -> list[Provider]:
262
+ """List available providers."""
263
+ return list(MODEL_PROVIDERS.keys())
264
+
265
+ @classmethod
266
+ def clear_cache(cls) -> None:
267
+ """Clear the LLM instance cache."""
268
+ cls._instances.clear()
269
+
270
+ @classmethod
271
+ def get_default_model(cls, provider: Provider | None = None) -> str:
272
+ """
273
+ Get the default model for a provider.
274
+
275
+ Args:
276
+ provider: Provider name (defaults to 'google')
277
+
278
+ Returns:
279
+ Default model name
280
+ """
281
+ provider = provider or "google"
282
+ models = MODEL_PROVIDERS.get(provider, [])
283
+ if not models:
284
+ raise LLMProviderError(f"No models available for provider: {provider}")
285
+ return models[0]
src/service/chat_manager.py CHANGED
@@ -168,6 +168,62 @@ class ChatManager:
168
  display_name=display_name,
169
  )
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  # Singleton-style accessor for the FastAPI routes
173
  chat_manager_instance = ChatManager()
 
168
  display_name=display_name,
169
  )
170
 
171
    # ---- Cost tracking ----------------------------------------------------
    def update_conversation_costs(
        self,
        cost_delta: Dict,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> Dict:
        """
        Update the conversation with cost data from the latest request.

        Best-effort: store failures are logged and swallowed so that cost
        bookkeeping can never fail the chat request itself.

        Args:
            cost_delta: Cost delta dict with 'cost', 'tokens', 'calls' keys
            conversation_id: Conversation ID (resolved via self._resolve_ids when None)
            user_id: User ID (resolved via self._resolve_ids when None)

        Returns:
            Updated conversation data, or {} if the store update failed
        """
        conversation_id, user_id = self._resolve_ids(conversation_id, user_id)
        try:
            result = self._store.update_conversation_costs(user_id, conversation_id, cost_delta)
            logger.info(
                "Updated costs for user=%s conversation=%s: +$%.6f (%d calls)",
                user_id,
                conversation_id,
                cost_delta.get("cost", 0),
                cost_delta.get("calls", 0),
            )
            return result
        except Exception as exc:
            # Deliberate swallow: cost tracking is observability, not core flow.
            logger.warning(
                "Failed to update costs for user=%s conversation=%s: %s",
                user_id,
                conversation_id,
                exc,
            )
            return {}
208
+
209
+ def get_conversation_costs(
210
+ self,
211
+ conversation_id: Optional[str] = None,
212
+ user_id: Optional[str] = None,
213
+ ) -> Dict:
214
+ """
215
+ Get the accumulated costs for a conversation.
216
+
217
+ Args:
218
+ conversation_id: Conversation ID
219
+ user_id: User ID
220
+
221
+ Returns:
222
+ Cost data dict
223
+ """
224
+ conversation_id, user_id = self._resolve_ids(conversation_id, user_id)
225
+ return self._store.get_conversation_costs(user_id, conversation_id)
226
+
227
 
228
  # Singleton-style accessor for the FastAPI routes
229
  chat_manager_instance = ChatManager()
src/service/panorama_store.py CHANGED
@@ -388,3 +388,99 @@ class PanoramaStore:
388
  )
389
  conversation = self.ensure_conversation(user_id, conversation_id)
390
  return user, conversation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  )
389
  conversation = self.ensure_conversation(user_id, conversation_id)
390
  return user, conversation
391
+
392
+ # ---- cost tracking -----------------------------------------------------
393
+ def update_conversation_costs(
394
+ self,
395
+ user_id: str,
396
+ conversation_id: str,
397
+ cost_delta: Dict[str, Any],
398
+ ) -> Dict[str, Any]:
399
+ """
400
+ Update the conversation with accumulated cost data.
401
+
402
+ Args:
403
+ user_id: User ID
404
+ conversation_id: Conversation ID
405
+ cost_delta: Cost delta from this request (cost, tokens, calls)
406
+
407
+ Returns:
408
+ Updated conversation data
409
+ """
410
+ conv_key = _conversation_key(user_id, conversation_id)
411
+ try:
412
+ conversation = self._client.get("conversations", conv_key)
413
+ except PanoramaGatewayError as exc:
414
+ if exc.status_code == 404:
415
+ self._logger.warning("Conversation %s not found for cost update", conv_key)
416
+ return {}
417
+ raise
418
+
419
+ # Get existing cost data from contextState
420
+ context_state = conversation.get("contextState", {}) or {}
421
+ existing_costs = context_state.get("costs", {
422
+ "total_cost": 0.0,
423
+ "total_tokens": {"input": 0, "output": 0, "cache": 0},
424
+ "total_calls": 0,
425
+ })
426
+
427
+ # Accumulate costs
428
+ delta_tokens = cost_delta.get("tokens", {})
429
+ existing_tokens = existing_costs.get("total_tokens", {"input": 0, "output": 0, "cache": 0})
430
+
431
+ updated_costs = {
432
+ "total_cost": round(existing_costs.get("total_cost", 0.0) + cost_delta.get("cost", 0.0), 6),
433
+ "total_tokens": {
434
+ "input": existing_tokens.get("input", 0) + delta_tokens.get("input", 0),
435
+ "output": existing_tokens.get("output", 0) + delta_tokens.get("output", 0),
436
+ "cache": existing_tokens.get("cache", 0) + delta_tokens.get("cache", 0),
437
+ },
438
+ "total_calls": existing_costs.get("total_calls", 0) + cost_delta.get("calls", 0),
439
+ "last_updated": _utc_now_iso(),
440
+ }
441
+
442
+ # Update contextState with new costs
443
+ context_state["costs"] = updated_costs
444
+ try:
445
+ return self._client.update(
446
+ "conversations",
447
+ conv_key,
448
+ {"contextState": context_state, "updatedAt": _utc_now_iso()},
449
+ )
450
+ except PanoramaGatewayError as exc:
451
+ self._logger.error(
452
+ "Failed to update costs for conversation %s: status=%s",
453
+ conv_key,
454
+ exc.status_code,
455
+ )
456
+ raise
457
+
458
+ def get_conversation_costs(
459
+ self,
460
+ user_id: str,
461
+ conversation_id: str,
462
+ ) -> Dict[str, Any]:
463
+ """
464
+ Get the accumulated costs for a conversation.
465
+
466
+ Args:
467
+ user_id: User ID
468
+ conversation_id: Conversation ID
469
+
470
+ Returns:
471
+ Cost data or empty dict if not found
472
+ """
473
+ conv_key = _conversation_key(user_id, conversation_id)
474
+ try:
475
+ conversation = self._client.get("conversations", conv_key)
476
+ except PanoramaGatewayError as exc:
477
+ if exc.status_code == 404:
478
+ return {}
479
+ raise
480
+
481
+ context_state = conversation.get("contextState", {}) or {}
482
+ return context_state.get("costs", {
483
+ "total_cost": 0.0,
484
+ "total_tokens": {"input": 0, "output": 0, "cache": 0},
485
+ "total_calls": 0,
486
+ })