Add Mistral AI integration - 1B tokens/month (33x Groq limits)
NEW PROVIDER: Mistral AI
- Token limits: 1 billion tokens/month (~33M tokens/day)
- 500K tokens/minute (42x better than Groq's 12K TPM)
- 500 requests/minute (17x better than Groq's 30 RPM)
- OpenAI-compatible API (minimal code changes)
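Because the API is OpenAI-compatible, calls use the familiar chat-completions shape. A minimal sketch with the 1.x Python SDK (the prompt and token cap are placeholders, not part of this commit):

import os
from mistralai import Mistral  # mistralai>=1.0.0

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
response = client.chat.complete(
    model="mistral-large-latest",                      # same default as MISTRAL_MODEL below
    messages=[{"role": "user", "content": "Hello"}],   # OpenAI-style message dicts
    temperature=0.1,
    max_tokens=256,
)
print(response.choices[0].message.content)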
CHANGES:
1. requirements.txt: Added mistralai>=1.0.0
2. .env (see the startup sketch after this list):
   - LLM_PROVIDER=mistral (new default)
   - MISTRAL_API_KEY env var
   - MISTRAL_MODEL=mistral-large-latest
3. orchestrator.py:
   - Added Mistral client initialization
   - Set TPM=500K, RPM=500 limits
   - Added mistral provider to API call logic
   - Reuses Groq code path (OpenAI format)
   - Token budget management applies to both
4. api/app.py:
   - Updated startup to default to mistral
   - Compact prompts enabled for mistral/groq
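A condensed sketch of how items 2-4 fit together at startup, simplified from the orchestrator/app changes shown in the diffs below (error handling and the Groq/Gemini branches are omitted):

import os
from dotenv import load_dotenv
from mistralai import Mistral

load_dotenv()  # reads LLM_PROVIDER, MISTRAL_API_KEY, MISTRAL_MODEL from .env

provider = os.getenv("LLM_PROVIDER", "mistral").lower()
use_compact_prompts = provider in ["mistral", "groq"]  # smaller context windows

if provider == "mistral":
    api_key = os.getenv("MISTRAL_API_KEY")
    if not api_key:
        raise ValueError("MISTRAL_API_KEY must be set")
    client = Mistral(api_key=api_key)
    model = os.getenv("MISTRAL_MODEL", "mistral-large-latest")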
BENEFITS:
- 33x more daily tokens than Groq (no TPD exhaustion)
- 42x higher TPM limit (no minute-level throttling)
- Production-ready for heavy usage
- Backward compatible (groq/gemini still work)
ARCHITECTURE IMPACT: Minimal (~80 lines)
- Mistral uses same API format as Groq
- All existing logic preserved
- requirements.txt +1 -0
- src/api/app.py +3 -3
- src/orchestrator.py +63 -11
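For callers, the change surfaces as one optional constructor argument; a usage sketch (import path assumed from the file list above, other defaults unchanged):

from src.orchestrator import DataScienceCopilot

# Provider defaults to LLM_PROVIDER from .env (now "mistral"); explicit override shown
agent = DataScienceCopilot(provider="mistral", reasoning_effort="medium")

# Backward compatible: groq/gemini still initialize exactly as before
groq_agent = DataScienceCopilot(provider="groq")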
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 # Core Dependencies
 groq>=0.13.0  # Updated for httpx compatibility
+mistralai>=1.0.0  # Mistral AI - 1B tokens/month
 python-dotenv==1.0.0
 
 # Data Processing
--- a/src/api/app.py
+++ b/src/api/app.py
@@ -61,9 +61,9 @@ async def startup_event():
     global agent
     try:
         logger.info("Initializing DataScienceCopilot...")
-        provider = os.getenv("LLM_PROVIDER", "groq")
-        # Auto-enable compact prompts for Groq (smaller context window)
-        use_compact = provider.lower() == "groq"
+        provider = os.getenv("LLM_PROVIDER", "mistral")
+        # Auto-enable compact prompts for Mistral/Groq (smaller context windows)
+        use_compact = provider.lower() in ["mistral", "groq"]
 
         agent = DataScienceCopilot(
             reasoning_effort="medium",
--- a/src/orchestrator.py
+++ b/src/orchestrator.py
@@ -135,6 +135,7 @@ class DataScienceCopilot:
 
     def __init__(self, groq_api_key: Optional[str] = None,
                  google_api_key: Optional[str] = None,
+                 mistral_api_key: Optional[str] = None,
                  cache_db_path: Optional[str] = None,
                  reasoning_effort: str = "medium",
                  provider: Optional[str] = None,
@@ -147,6 +148,7 @@ class DataScienceCopilot:
         Args:
             groq_api_key: Groq API key (or set GROQ_API_KEY env var)
             google_api_key: Google API key (or set GOOGLE_API_KEY env var)
+            mistral_api_key: Mistral API key (or set MISTRAL_API_KEY env var)
             cache_db_path: Path to cache database
             reasoning_effort: Reasoning effort for Groq ('low', 'medium', 'high')
             provider: LLM provider - 'groq' or 'gemini' (or set LLM_PROVIDER env var)
@@ -158,12 +160,26 @@ class DataScienceCopilot:
         load_dotenv()
 
         # Determine provider
-        self.provider = provider or os.getenv("LLM_PROVIDER", "groq").lower()
+        self.provider = provider or os.getenv("LLM_PROVIDER", "mistral").lower()
 
-        # Set compact prompts: Auto-enable for Groq, manual for others
-        self.use_compact_prompts = use_compact_prompts or (self.provider == "groq")
+        # Set compact prompts: Auto-enable for Groq/Mistral, manual for others
+        self.use_compact_prompts = use_compact_prompts or (self.provider in ["groq", "mistral"])
 
-        if self.provider == "groq":
+        if self.provider == "mistral":
+            # Initialize Mistral client (OpenAI-compatible)
+            api_key = mistral_api_key or os.getenv("MISTRAL_API_KEY")
+            if not api_key:
+                raise ValueError("Mistral API key must be provided or set in MISTRAL_API_KEY env var")
+
+            from mistralai import Mistral
+            self.mistral_client = Mistral(api_key=api_key)
+            self.model = os.getenv("MISTRAL_MODEL", "mistral-large-latest")
+            self.reasoning_effort = reasoning_effort
+            self.gemini_model = None
+            self.groq_client = None
+            print(f"🤖 Initialized with Mistral provider - Model: {self.model}")
+
+        elif self.provider == "groq":
             # Initialize Groq client
             api_key = groq_api_key or os.getenv("GROQ_API_KEY")
             if not api_key:
@@ -173,6 +189,7 @@ class DataScienceCopilot:
             self.model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
             self.reasoning_effort = reasoning_effort
             self.gemini_model = None
+            self.mistral_client = None
             print(f"🤖 Initialized with Groq provider - Model: {self.model}")
 
         elif self.provider == "gemini":
@@ -198,9 +215,11 @@ class DataScienceCopilot:
                 safety_settings=safety_settings
             )
             self.groq_client = None
+            self.mistral_client = None
             print(f"🤖 Initialized with Gemini provider - Model: {self.model}")
 
         else:
+            raise ValueError(f"Invalid provider: {self.provider}. Must be 'mistral', 'groq', or 'gemini'")
             raise ValueError(f"Unsupported provider: {self.provider}. Choose 'groq' or 'gemini'")
 
         # Initialize cache
@@ -253,7 +272,11 @@ class DataScienceCopilot:
         self.api_calls_made = 0
 
         # Provider-specific limits
-        if self.provider == "groq":
+        if self.provider == "mistral":
+            self.tpm_limit = 500000  # 500K tokens/minute (very generous)
+            self.rpm_limit = 500  # 500 requests/minute
+            self.min_api_call_interval = 0.1  # Minimal delay
+        elif self.provider == "groq":
             self.tpm_limit = 12000  # Tokens per minute
             self.rpm_limit = 30  # Requests per minute
             self.min_api_call_interval = 0.5  # Wait between calls
@@ -1726,8 +1749,8 @@ You are a DOER. Complete workflows based on user intent."""
             messages = [messages[0], messages[1]] + messages[-4:]
             print(f"⚠️ Emergency pruning (conversation > 8K tokens)")
 
-        # 💰 Token budget management (Groq)
-        if self.provider == "groq":
+        # 💰 Token budget management (TPM limit)
+        if self.provider in ["mistral", "groq"]:
             # Reset minute counter if needed
             elapsed = time.time() - self.minute_start_time
             if elapsed > 60:
@@ -1764,7 +1787,36 @@ You are a DOER. Complete workflows based on user intent."""
         response_message = None
 
         # Call LLM with function calling (provider-specific)
-        if self.provider == "groq":
+        if self.provider == "mistral":
+            try:
+                response = self.mistral_client.chat.complete(
+                    model=self.model,
+                    messages=messages,
+                    tools=tools_to_use,
+                    tool_choice="auto",
+                    temperature=0.1,
+                    max_tokens=4096
+                )
+
+                self.api_calls_made += 1
+                self.last_api_call_time = time.time()
+
+                # Track tokens used (for TPM budget management)
+                if hasattr(response, 'usage') and response.usage:
+                    tokens_used = response.usage.total_tokens
+                    self.tokens_this_minute += tokens_used
+                    print(f"📊 Tokens: {tokens_used} this call | {self.tokens_this_minute}/{self.tpm_limit} this minute")
+
+                response_message = response.choices[0].message
+                tool_calls = response_message.tool_calls
+                final_content = response_message.content
+
+            except Exception as mistral_error:
+                error_str = str(mistral_error)
+                print(f"❌ MISTRAL ERROR: {error_str[:300]}")
+                raise
+
+        elif self.provider == "groq":
             try:
                 response = self.groq_client.chat.completions.create(
                     model=self.model,
@@ -2406,9 +2458,9 @@ You are a DOER. Complete workflows based on user intent."""
                 self._update_workflow_state(tool_name, tool_result)
 
                 # ⚡ CRITICAL FIX: Add tool result back to messages so LLM sees it in next iteration!
-                if self.provider == "groq":
-                    # For Groq, add tool message with the result
-                    # **COMPRESS RESULT** for small context models
+                if self.provider in ["mistral", "groq"]:
+                    # For Mistral/Groq, add tool message with the result
+                    # **COMPRESS RESULT** for small context models
                     clean_tool_result = self._make_json_serializable(tool_result)
 
                     # Smart compression: Keep only what LLM needs for next decision