Spaces:
Running
Running
Kunal Pai commited on
Commit ·
1b3d55c
1
Parent(s): 4a2efcf
Refactor GroqAgent initialization and enhance error handling; update AgentCostManager with Groq model costs
Browse files
src/manager/agent_manager.py
CHANGED
|
@@ -147,43 +147,72 @@ class GroqAgent(Agent):
|
|
| 147 |
def __init__(
|
| 148 |
self,
|
| 149 |
agent_name: str,
|
| 150 |
-
base_model: str
|
| 151 |
-
system_prompt: str
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
):
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
| 156 |
api_key = os.getenv("GROQ_API_KEY")
|
|
|
|
|
|
|
| 157 |
self.client = Groq(api_key=api_key)
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
def create_model(self) -> None:
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
pass
|
| 165 |
|
| 166 |
def ask_agent(self, prompt: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
messages = [
|
| 168 |
-
{"role": "system", "content": self.
|
| 169 |
{"role": "user", "content": prompt},
|
| 170 |
]
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
def delete_agent(self) -> None:
|
| 180 |
-
|
| 181 |
pass
|
| 182 |
|
| 183 |
-
def get_type(self):
|
|
|
|
| 184 |
return self.type
|
| 185 |
|
| 186 |
-
|
| 187 |
@singleton
|
| 188 |
class AgentManager():
|
| 189 |
budget_manager: BudgetManager = BudgetManager()
|
|
|
|
| 147 |
def __init__(
|
| 148 |
self,
|
| 149 |
agent_name: str,
|
| 150 |
+
base_model: str,
|
| 151 |
+
system_prompt: str,
|
| 152 |
+
create_resource_cost: int,
|
| 153 |
+
invoke_resource_cost: int,
|
| 154 |
+
create_expense_cost: int = 0,
|
| 155 |
+
invoke_expense_cost: int = 0,
|
| 156 |
):
|
| 157 |
+
# Call the parent class constructor first
|
| 158 |
+
super().__init__(agent_name, base_model, system_prompt,
|
| 159 |
+
create_resource_cost, invoke_resource_cost,
|
| 160 |
+
create_expense_cost, invoke_expense_cost)
|
| 161 |
+
|
| 162 |
+
# Groq-specific API client setup
|
| 163 |
api_key = os.getenv("GROQ_API_KEY")
|
| 164 |
+
if not api_key:
|
| 165 |
+
raise ValueError("GROQ_API_KEY environment variable not set. Please set it in your .env file or environment.")
|
| 166 |
self.client = Groq(api_key=api_key)
|
| 167 |
+
|
| 168 |
+
if self.base_model and "groq-" in self.base_model:
|
| 169 |
+
self.groq_api_model_name = self.base_model.split("groq-", 1)[1]
|
| 170 |
+
else:
|
| 171 |
+
# Fallback or error if the naming convention isn't followed.
|
| 172 |
+
# This ensures that if a non-prefixed model name is somehow passed,
|
| 173 |
+
# it might still work, or you can raise an error.
|
| 174 |
+
self.groq_api_model_name = self.base_model
|
| 175 |
+
print(f"Warning: GroqAgent base_model '{self.base_model}' does not follow 'groq-' prefix convention.")
|
| 176 |
|
| 177 |
def create_model(self) -> None:
|
| 178 |
+
"""
|
| 179 |
+
Create and Initialize agent.
|
| 180 |
+
For Groq, models are pre-existing on their cloud.
|
| 181 |
+
This method is called by Agent's __init__.
|
| 182 |
+
"""
|
| 183 |
pass
|
| 184 |
|
| 185 |
def ask_agent(self, prompt: str) -> str:
|
| 186 |
+
"""Ask agent a question"""
|
| 187 |
+
if not self.client:
|
| 188 |
+
raise ConnectionError("Groq client not initialized. Check API key and constructor.")
|
| 189 |
+
if not self.groq_api_model_name:
|
| 190 |
+
raise ValueError("Groq API model name not set. Check base_model configuration.")
|
| 191 |
+
|
| 192 |
messages = [
|
| 193 |
+
{"role": "system", "content": self.system_prompt},
|
| 194 |
{"role": "user", "content": prompt},
|
| 195 |
]
|
| 196 |
+
try:
|
| 197 |
+
response = self.client.chat.completions.create(
|
| 198 |
+
messages=messages,
|
| 199 |
+
model=self.groq_api_model_name, # Use the derived model name for Groq API
|
| 200 |
+
)
|
| 201 |
+
result = response.choices[0].message.content
|
| 202 |
+
return result
|
| 203 |
+
except Exception as e:
|
| 204 |
+
# Handle API errors or other exceptions during the call
|
| 205 |
+
print(f"Error calling Groq API: {e}")
|
| 206 |
+
raise # Re-raise the exception or handle it as appropriate
|
| 207 |
|
| 208 |
def delete_agent(self) -> None:
|
| 209 |
+
"""Delete agent"""
|
| 210 |
pass
|
| 211 |
|
| 212 |
+
def get_type(self) -> str: # Ensure return type hint matches Agent ABC
|
| 213 |
+
"""Get agent type"""
|
| 214 |
return self.type
|
| 215 |
|
|
|
|
| 216 |
@singleton
|
| 217 |
class AgentManager():
|
| 218 |
budget_manager: BudgetManager = BudgetManager()
|
src/tools/default_tools/agent_cost_manager.py
CHANGED
|
@@ -60,6 +60,11 @@ class AgentCostManager():
|
|
| 60 |
"create_expense_cost": 0,
|
| 61 |
"invoke_expense_cost": 0.0375,
|
| 62 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
}
|
| 64 |
|
| 65 |
def get_costs(self):
|
|
|
|
| 60 |
"create_expense_cost": 0,
|
| 61 |
"invoke_expense_cost": 0.0375,
|
| 62 |
},
|
| 63 |
+
"groq-qwen-qwq-32b": {
|
| 64 |
+
"description": "Avg Accuracy: 60.0%, 70.0% on multi-task understanding, 80.0% on Math",
|
| 65 |
+
"create_expense_cost": 0,
|
| 66 |
+
"invoke_expense_cost": 0.05,
|
| 67 |
+
},
|
| 68 |
}
|
| 69 |
|
| 70 |
def get_costs(self):
|