Pulastya B committed
Commit
be4eb33
·
1 Parent(s): 7d775b3

Add Mistral AI integration - 1B tokens/month (33x Groq limits)


NEW PROVIDER: Mistral AI
- Token limits: 1 billion tokens/month (~33M tokens/day)
- 500K tokens/minute (42x better than Groq's 12K TPM)
- 500 requests/minute (17x better than Groq's 30 RPM; the multipliers are checked in the sketch below)
- OpenAI-compatible API (minimal code changes)
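
A quick arithmetic check of those multipliers (a sketch; the Groq baselines of 12K TPM and 30 RPM are as quoted in this commit message, not re-verified against Groq's current published limits):

    # Sanity-check the quoted rate-limit multipliers
    MISTRAL_TOKENS_PER_MONTH = 1_000_000_000
    MISTRAL_TPM, MISTRAL_RPM = 500_000, 500
    GROQ_TPM, GROQ_RPM = 12_000, 30  # baselines as quoted above

    print(f"Tokens/day:   {MISTRAL_TOKENS_PER_MONTH / 30:,.0f}")  # ~33,333,333
    print(f"TPM multiple: {MISTRAL_TPM / GROQ_TPM:.1f}x")         # 41.7x, rounded to 42x
    print(f"RPM multiple: {MISTRAL_RPM / GROQ_RPM:.1f}x")         # 16.7x, rounded to 17x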

CHANGES:
1. requirements.txt: Added mistralai>=1.0.0

2. .env (full example after this list):
- LLM_PROVIDER=mistral (new default)
- MISTRAL_API_KEY env var
- MISTRAL_MODEL=mistral-large-latest

3. orchestrator.py:
- Added Mistral client initialization
- Set TPM=500K, RPM=500 limits
- Added mistral provider to API call logic
- Mirrors the Groq code path (same OpenAI-style request format)
- Token budget management applies to both (see the sketch below)

4. api/app.py:
- Updated startup to default to mistral
- Compact prompts enabled for mistral/groq
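
For reference, the resulting .env entries would look like the following (a sketch: the key is a placeholder and unrelated variables are omitted):

    LLM_PROVIDER=mistral
    MISTRAL_API_KEY=<your-api-key>
    MISTRAL_MODEL=mistral-large-latest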

BENEFITS:
- 33x more daily tokens than Groq (no TPD exhaustion)
- 42x higher TPM limit (no minute-level throttling)
- Production-ready for heavy usage
- Backward compatible (groq/gemini still work)

ARCHITECTURE IMPACT: Minimal (~80 lines)
- Mistral uses the same API format as Groq
- All existing logic preserved
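
To make the shared token budget concrete, here is the per-minute window as a standalone sketch. The counter fields (tokens_this_minute, minute_start_time) and the reset-after-60s behaviour come from the diff below; the wrapper class and the sleep-until-reset policy are assumptions, since the throttling branch itself is not shown in this diff:

    import time

    class TokenBudget:
        # Sketch of the per-minute TPM budget shared by Mistral and Groq
        def __init__(self, tpm_limit: int):
            self.tpm_limit = tpm_limit
            self.tokens_this_minute = 0
            self.minute_start_time = time.time()

        def reserve(self, estimated_tokens: int) -> None:
            elapsed = time.time() - self.minute_start_time
            if elapsed > 60:
                # Reset minute counter if needed (as in the diff)
                self.tokens_this_minute = 0
                self.minute_start_time = time.time()
            elif self.tokens_this_minute + estimated_tokens > self.tpm_limit:
                # Out of budget: wait for the window to roll over (assumed policy)
                time.sleep(max(0.0, 60 - elapsed))
                self.tokens_this_minute = 0
                self.minute_start_time = time.time()

        def record(self, tokens_used: int) -> None:
            # Fed from response.usage.total_tokens after each call
            self.tokens_this_minute += tokens_used

    budget = TokenBudget(tpm_limit=500_000)  # Mistral limit set in this commit
    budget.reserve(estimated_tokens=4_096)   # only sleeps when near the limit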

Files changed (3)
  1. requirements.txt +1 -0
  2. src/api/app.py +3 -3
  3. src/orchestrator.py +63 -11
requirements.txt CHANGED
@@ -1,5 +1,6 @@
 # Core Dependencies
 groq>=0.13.0  # Updated for httpx compatibility
+mistralai>=1.0.0  # Mistral AI - 1B tokens/month
 python-dotenv==1.0.0
 
 # Data Processing
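
After pip install -r requirements.txt, a quick standard-library check that the new dependency resolved at a compatible version:

    # Confirm mistralai resolved at >=1.0.0, matching the requirement above
    from importlib.metadata import version
    assert int(version("mistralai").split(".")[0]) >= 1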
src/api/app.py CHANGED
@@ -61,9 +61,9 @@ async def startup_event():
     global agent
     try:
         logger.info("Initializing DataScienceCopilot...")
-        provider = os.getenv("LLM_PROVIDER", "groq")
-        # Auto-enable compact prompts for Groq (small context window)
-        use_compact = provider.lower() == "groq"
+        provider = os.getenv("LLM_PROVIDER", "mistral")
+        # Auto-enable compact prompts for Mistral/Groq (smaller context windows)
+        use_compact = provider.lower() in ["mistral", "groq"]
 
         agent = DataScienceCopilot(
             reasoning_effort="medium",
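
A tiny sketch replaying the two changed lines to confirm the new default (assuming nothing else sets LLM_PROVIDER first):

    import os

    os.environ.pop("LLM_PROVIDER", None)           # simulate an unset variable
    provider = os.getenv("LLM_PROVIDER", "mistral")
    use_compact = provider.lower() in ["mistral", "groq"]
    assert provider == "mistral" and use_compact   # compact prompts on by default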
src/orchestrator.py CHANGED
@@ -135,6 +135,7 @@ class DataScienceCopilot:
 
     def __init__(self, groq_api_key: Optional[str] = None,
                  google_api_key: Optional[str] = None,
+                 mistral_api_key: Optional[str] = None,
                  cache_db_path: Optional[str] = None,
                  reasoning_effort: str = "medium",
                  provider: Optional[str] = None,
@@ -147,6 +148,7 @@ class DataScienceCopilot:
         Args:
             groq_api_key: Groq API key (or set GROQ_API_KEY env var)
             google_api_key: Google API key (or set GOOGLE_API_KEY env var)
+            mistral_api_key: Mistral API key (or set MISTRAL_API_KEY env var)
             cache_db_path: Path to cache database
             reasoning_effort: Reasoning effort for Groq ('low', 'medium', 'high')
             provider: LLM provider - 'groq' or 'gemini' (or set LLM_PROVIDER env var)
@@ -158,12 +160,26 @@ class DataScienceCopilot:
         load_dotenv()
 
         # Determine provider
-        self.provider = provider or os.getenv("LLM_PROVIDER", "groq").lower()
+        self.provider = provider or os.getenv("LLM_PROVIDER", "mistral").lower()
 
-        # Set compact prompts: Auto-enable for Groq, manual for others
-        self.use_compact_prompts = use_compact_prompts or (self.provider == "groq")
+        # Set compact prompts: Auto-enable for Groq/Mistral, manual for others
+        self.use_compact_prompts = use_compact_prompts or (self.provider in ["groq", "mistral"])
 
-        if self.provider == "groq":
+        if self.provider == "mistral":
+            # Initialize Mistral client (OpenAI-compatible)
+            api_key = mistral_api_key or os.getenv("MISTRAL_API_KEY")
+            if not api_key:
+                raise ValueError("Mistral API key must be provided or set in MISTRAL_API_KEY env var")
+
+            from mistralai import Mistral
+            self.mistral_client = Mistral(api_key=api_key)
+            self.model = os.getenv("MISTRAL_MODEL", "mistral-large-latest")
+            self.reasoning_effort = reasoning_effort
+            self.gemini_model = None
+            self.groq_client = None
+            print(f"🤖 Initialized with Mistral provider - Model: {self.model}")
+
+        elif self.provider == "groq":
             # Initialize Groq client
             api_key = groq_api_key or os.getenv("GROQ_API_KEY")
             if not api_key:
@@ -173,6 +189,7 @@ class DataScienceCopilot:
             self.model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
             self.reasoning_effort = reasoning_effort
             self.gemini_model = None
+            self.mistral_client = None
             print(f"🤖 Initialized with Groq provider - Model: {self.model}")
 
         elif self.provider == "gemini":
@@ -198,9 +215,11 @@ class DataScienceCopilot:
                 safety_settings=safety_settings
             )
             self.groq_client = None
+            self.mistral_client = None
             print(f"🤖 Initialized with Gemini provider - Model: {self.model}")
 
         else:
+            raise ValueError(f"Invalid provider: {self.provider}. Must be 'mistral', 'groq', or 'gemini'")
             raise ValueError(f"Unsupported provider: {self.provider}. Choose 'groq' or 'gemini'")
 
         # Initialize cache
@@ -253,7 +272,11 @@ class DataScienceCopilot:
         self.api_calls_made = 0
 
         # Provider-specific limits
-        if self.provider == "groq":
+        if self.provider == "mistral":
+            self.tpm_limit = 500000  # 500K tokens/minute (very generous)
+            self.rpm_limit = 500  # 500 requests/minute
+            self.min_api_call_interval = 0.1  # Minimal delay
+        elif self.provider == "groq":
             self.tpm_limit = 12000  # Tokens per minute
             self.rpm_limit = 30  # Requests per minute
             self.min_api_call_interval = 0.5  # Wait between calls
@@ -1726,8 +1749,8 @@ You are a DOER. Complete workflows based on user intent."""
             messages = [messages[0], messages[1]] + messages[-4:]
             print(f"⚠️ Emergency pruning (conversation > 8K tokens)")
 
-        # 💰 Token budget management (Groq TPM limit)
-        if self.provider == "groq":
+        # 💰 Token budget management (TPM limit)
+        if self.provider in ["mistral", "groq"]:
             # Reset minute counter if needed
             elapsed = time.time() - self.minute_start_time
             if elapsed > 60:
@@ -1764,7 +1787,36 @@ You are a DOER. Complete workflows based on user intent."""
         response_message = None
 
         # Call LLM with function calling (provider-specific)
-        if self.provider == "groq":
+        if self.provider == "mistral":
+            try:
+                response = self.mistral_client.chat.complete(
+                    model=self.model,
+                    messages=messages,
+                    tools=tools_to_use,
+                    tool_choice="auto",
+                    temperature=0.1,
+                    max_tokens=4096
+                )
+
+                self.api_calls_made += 1
+                self.last_api_call_time = time.time()
+
+                # Track tokens used (for TPM budget management)
+                if hasattr(response, 'usage') and response.usage:
+                    tokens_used = response.usage.total_tokens
+                    self.tokens_this_minute += tokens_used
+                    print(f"📊 Tokens: {tokens_used} this call | {self.tokens_this_minute}/{self.tpm_limit} this minute")
+
+                response_message = response.choices[0].message
+                tool_calls = response_message.tool_calls
+                final_content = response_message.content
+
+            except Exception as mistral_error:
+                error_str = str(mistral_error)
+                print(f"❌ MISTRAL ERROR: {error_str[:300]}")
+                raise
+
+        elif self.provider == "groq":
             try:
                 response = self.groq_client.chat.completions.create(
                     model=self.model,
@@ -2406,9 +2458,9 @@ You are a DOER. Complete workflows based on user intent."""
             self._update_workflow_state(tool_name, tool_result)
 
             # ⚡ CRITICAL FIX: Add tool result back to messages so LLM sees it in next iteration!
-            if self.provider == "groq":
-                # For Groq, add tool message with the result
-                # **COMPRESS RESULT** for small context models (Groq 12K token limit)
+            if self.provider in ["mistral", "groq"]:
+                # For Mistral/Groq, add tool message with the result
+                # **COMPRESS RESULT** for small context models
                 clean_tool_result = self._make_json_serializable(tool_result)
 
                 # Smart compression: Keep only what LLM needs for next decision
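
For context, the new Mistral code path as a standalone sketch. The client construction, the chat.complete call shape, and the usage/tool_calls field accesses are taken from the diff above; the tool schema and its get_row_count name are invented for illustration:

    import os
    from mistralai import Mistral

    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

    # Illustrative OpenAI-style tool schema (not from the repo)
    tools = [{
        "type": "function",
        "function": {
            "name": "get_row_count",
            "description": "Return the number of rows in the loaded dataset",
            "parameters": {"type": "object", "properties": {}},
        },
    }]

    response = client.chat.complete(
        model=os.getenv("MISTRAL_MODEL", "mistral-large-latest"),
        messages=[{"role": "user", "content": "How many rows does the dataset have?"}],
        tools=tools,
        tool_choice="auto",
        temperature=0.1,
        max_tokens=4096,
    )

    message = response.choices[0].message
    if message.tool_calls:
        # The orchestrator would execute the tool here, then append the result
        # as a role="tool" message before the next LLM iteration.
        print("tool requested:", message.tool_calls[0].function.name)
    else:
        print(message.content)
    if response.usage:
        print("total_tokens:", response.usage.total_tokens)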