Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import uuid | |
| import time | |
| import base64 | |
| import sys | |
| import inspect | |
| import secrets | |
| import asyncio | |
| import contextvars | |
| import functools | |
| import re | |
| import threading | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Any, Dict, Union, Tuple, Iterator | |
| from pathlib import Path | |
| from contextlib import contextmanager | |
| from concurrent.futures import ThreadPoolExecutor | |
| from enum import Enum | |
| import requests | |
| from flask import ( | |
| Flask, | |
| request, | |
| Response, | |
| jsonify, | |
| stream_with_context, | |
| render_template, | |
| redirect, | |
| session, | |
| ) | |
| from curl_cffi import requests as curl_requests | |
| from werkzeug.middleware.proxy_fix import ProxyFix | |
| from patchright.async_api import async_playwright, Browser, BrowserContext | |
class PlaywrightStatsigManager:
    """Capture and cache grok.com ``x-statsig-id`` headers using Playwright.

    Adapted from Grok3API driver.py. The capture works by:

    1. Patching ``window.fetch`` to intercept grok.com's own API calls.
    2. Triggering a real request on grok.com so the page generates the header.
    3. Reading the captured ``x-statsig-id`` back and caching it for reuse.

    All Playwright work is scheduled on one private event loop that runs in a
    daemon thread. The browser context is cached between captures, and
    Playwright objects only work on the loop they were created on, so running
    each capture through a fresh ``asyncio.run`` loop would leave the cached
    context attached to a closed loop.
    """

    _USER_AGENT = (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36"
    )
    # Header value sent when no real id could be captured.
    _FALLBACK_STATSIG_ID = (
        "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk="
    )
    _PROMPT_SELECTOR = "div.relative.z-10 textarea"

    def __init__(self, proxy_url: Optional[str] = None):
        """Create a manager; no browser or thread is started until first use.

        Args:
            proxy_url: Optional proxy server URL handed to the browser context.
        """
        self._cached_statsig_id: Optional[str] = None
        self._cache_timestamp: Optional[int] = None
        self._cache_duration = 300  # seconds a captured id is reused
        self._context: Optional[BrowserContext] = None
        self._playwright = None
        self._lock = threading.Lock()  # serialises browser operations
        self._base_url = "https://grok.com/"
        self._proxy_url = proxy_url
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._loop_thread: Optional[threading.Thread] = None
        self._loop_lock = threading.Lock()  # guards loop creation/shutdown

    def _ensure_loop(self) -> asyncio.AbstractEventLoop:
        """Return the private event loop, starting its thread on first use."""
        with self._loop_lock:
            if self._loop is None or self._loop.is_closed():
                self._loop = asyncio.new_event_loop()
                self._loop_thread = threading.Thread(
                    target=self._loop.run_forever,
                    name="statsig-playwright-loop",
                    daemon=True,
                )
                self._loop_thread.start()
            return self._loop

    def _shutdown_loop(self) -> None:
        """Stop the private event loop and its thread, if they exist."""
        with self._loop_lock:
            loop, thread = self._loop, self._loop_thread
            self._loop = None
            self._loop_thread = None
        if loop is None:
            return
        loop.call_soon_threadsafe(loop.stop)
        if thread is not None:
            thread.join(timeout=5)
        if not loop.is_running():
            loop.close()

    def _run_async(self, coro):
        """Run ``coro`` on the private loop and block until it finishes.

        Returns the coroutine's result and re-raises its exception. Must not
        be called from the loop thread itself, which would deadlock.
        """
        if (
            self._loop_thread is not None
            and threading.current_thread() is self._loop_thread
        ):
            coro.close()
            raise RuntimeError("_run_async called from its own event loop thread")
        loop = self._ensure_loop()
        return asyncio.run_coroutine_threadsafe(coro, loop).result()

    async def _ensure_browser(self):
        """Launch the persistent browser context if it is not running yet."""
        if self._context:
            return
        # Reuse a Playwright driver left over from a failed launch attempt
        # instead of starting (and leaking) another one.
        if self._playwright is None:
            self._playwright = await async_playwright().start()
        context_options = {
            "viewport": {"width": 1920, "height": 1080},
            "user_agent": self._USER_AGENT,
        }
        if self._proxy_url:
            context_options["proxy"] = {"server": self._proxy_url}
        self._context = await self._playwright.chromium.launch_persistent_context(
            user_data_dir="./data/chrome",
            headless=True,
            no_viewport=True,
            channel="chrome",
            args=[
                "--no-first-run",
                "--force-color-profile=srgb",
                "--metrics-recording-only",
                "--password-store=basic",
                "--no-default-browser-check",
                "--no-sandbox",
                "--disable-dev-shm-usage",
                "--disable-gpu",
                "--disable-web-security",
                "--disable-features=VizDisplayCompositor",
                f"--user-agent={self._USER_AGENT}",
            ],
            **context_options,
        )

    async def check_real_ip(self) -> str:
        """Return the browser's public IP as reported by the ipify API.

        Returns:
            The IP address, or one of the sentinel strings ``"unknown"``
            (response could not be parsed), ``"error"`` (page interaction
            failed) or ``"failed"`` (browser/page could not be created).
        """
        try:
            await self._ensure_browser()
            page = await self._context.new_page()  # type: ignore
            try:
                print("Checking real IP address via ipify API")
                await page.goto("https://api.ipify.org?format=json", timeout=30000)
                ip_info = await page.evaluate(
                    """
                    () => {
                        try {
                            const bodyText = document.body.textContent || document.body.innerText;
                            return JSON.parse(bodyText);
                        } catch (e) {
                            return null;
                        }
                    }
                    """
                )
                if ip_info and ip_info.get("ip"):
                    ip_address = ip_info["ip"]
                    print(f"Playwright real IP address: {ip_address}")
                    return ip_address
                print("Failed to parse IP from ipify response")
                return "unknown"
            except Exception as e:
                print(f"Error checking IP address: {e}")
                return "error"
            finally:
                await page.close()
        except Exception as e:
            print(f"Failed to check real IP address: {e}")
            return "failed"

    async def _cleanup(self):
        """Close the browser context and stop the Playwright driver."""
        try:
            if self._context:
                await self._context.close()
        finally:
            # Stop the driver even when closing the context failed.
            self._context = None
            if self._playwright:
                try:
                    await self._playwright.stop()
                finally:
                    self._playwright = None

    def cleanup(self):
        """Synchronously release the browser, the driver and the loop thread."""
        with self._lock:
            if self._context or self._playwright:
                self._run_async(self._cleanup())
            self._shutdown_loop()

    async def _patch_fetch_for_statsig(self, page):
        """Patch ``window.fetch`` in ``page`` to record the x-statsig-id header.

        The patched fetch stores the header of requests to the
        conversations/new endpoint in ``window.__xStatsigId``.
        """
        result = await page.evaluate(
            """
            (() => {
                if (window.__fetchPatched) {
                    return "fetch already patched";
                }
                window.__fetchPatched = false;
                const originalFetch = window.fetch;
                window.__xStatsigId = null;
                window.fetch = async function(...args) {
                    console.log("Intercepted fetch call with args:", args);
                    const response = await originalFetch.apply(this, args);
                    try {
                        const req = args[0];
                        const opts = args[1] || {};
                        const url = typeof req === 'string' ? req : req.url;
                        const headers = opts.headers || {};
                        const targetUrl = "https://grok.com/rest/app-chat/conversations/new";
                        if (url === targetUrl) {
                            let id = null;
                            if (headers["x-statsig-id"]) {
                                id = headers["x-statsig-id"];
                            } else if (typeof opts.headers?.get === "function") {
                                id = opts.headers.get("x-statsig-id");
                            }
                            if (id) {
                                window.__xStatsigId = id;
                                console.log("Captured x-statsig-id:", id);
                            } else {
                                console.warn("x-statsig-id not found in headers");
                            }
                        } else {
                            console.log("Skipped fetch, URL doesn't match target:", url);
                        }
                    } catch (e) {
                        console.warn("Error capturing x-statsig-id:", e);
                    }
                    return response;
                };
                window.__fetchPatched = true;
                return "fetch successfully patched";
            })()
            """
        )
        print(f"Fetch patching result: {result}")

    async def _initiate_answer(self, page):
        """Submit a one-character prompt so grok.com issues a real request.

        Raises:
            Exception: whatever the page interaction raised; diagnostics are
                printed first.
        """
        import random
        import string

        try:
            await page.wait_for_selector(self._PROMPT_SELECTOR, timeout=2000)
            random_char = random.choice(string.ascii_lowercase)
            await page.fill(self._PROMPT_SELECTOR, random_char)
            await page.press(self._PROMPT_SELECTOR, "Enter")
            print(f"Triggered request with character: {random_char}")
        except Exception as e:
            print(f"Error triggering answer: {e}")
            try:
                title = await page.title()
                print(f"Page title: {title}, URL: {page.url}")
            except Exception as diag_error:
                # Diagnostics are best-effort and must not replace the
                # original failure that is re-raised below.
                print(f"Could not read page diagnostics: {diag_error}")
            raise

    async def _capture_statsig_id_async(
        self, restart_session: bool = False
    ) -> Optional[str]:
        """Capture an x-statsig-id from a real grok.com interaction.

        Args:
            restart_session: Accepted for interface compatibility; not used
                by this implementation.

        Returns:
            The captured header value, or ``None`` when a captcha or an
            authentication error page is shown, nothing was captured, or any
            error occurred.
        """
        try:
            await self._ensure_browser()
            page = await self._context.new_page()  # type: ignore
            try:
                print("Navigating to grok.com")
                await page.goto(
                    self._base_url, wait_until="domcontentloaded", timeout=30000
                )
                await self._patch_fetch_for_statsig(page)
                captcha_visible = await page.evaluate(
                    """
                    (() => {
                        const elements = document.querySelectorAll("p");
                        for (const el of elements) {
                            if (el.textContent.includes("Making sure you're human")) {
                                const style = window.getComputedStyle(el);
                                if (style.visibility !== 'hidden' && style.display !== 'none') {
                                    return true;
                                }
                            }
                        }
                        return false;
                    })()
                    """
                )
                if captcha_visible:
                    print("Captcha detected, cannot capture x-statsig-id")
                    return None
                await self._initiate_answer(page)
                try:
                    # Wait for whichever appears first: an answer bubble, the
                    # error container, or the captcha text.
                    await page.locator("div.message-bubble p[dir='auto']").or_(
                        page.locator("div.w-full.max-w-\\[48rem\\]")
                    ).or_(
                        page.locator("p", has_text="Making sure you're human")
                    ).wait_for(
                        timeout=20000
                    )
                except Exception:
                    # A timeout is not fatal: the header may already have
                    # been captured. Cancellation is no longer swallowed.
                    print("No response elements found within timeout")
                error_elements = await page.query_selector_all(
                    "div.w-full.max-w-\\[48rem\\]"
                )
                if error_elements:
                    print("Authentication error detected")
                    return None
                captcha_elements = await page.query_selector_all(
                    "p:has-text('Making sure you\\'re human')"
                )
                if captcha_elements:
                    print("Captcha appeared during request")
                    return None
                statsig_id = await page.evaluate("window.__xStatsigId")
                if statsig_id:
                    print(f"Successfully captured x-statsig-id: {statsig_id[:30]}...")
                    return statsig_id
                print("No x-statsig-id was captured")
                return None
            finally:
                await page.close()
        except Exception as e:
            print(f"Error capturing x-statsig-id: {e}")
            return None

    def capture_statsig_id(self, restart_session: bool = False) -> Optional[str]:
        """Capture x-statsig-id (blocking wrapper, one capture at a time)."""
        with self._lock:
            return self._run_async(self._capture_statsig_id_async(restart_session))

    def check_real_ip_sync(self) -> str:
        """Check the real IP address (blocking wrapper)."""
        with self._lock:
            return self._run_async(self.check_real_ip())

    def generate_xai_request_id(self) -> str:
        """Generate an x-xai-request-id (a random UUID4 string)."""
        return str(uuid.uuid4())

    def get_dynamic_headers(
        self, method: str = "POST", pathname: str = "/rest/app-chat/conversations/new"
    ) -> Dict[str, str]:
        """Return ``x-statsig-id`` and ``x-xai-request-id`` headers.

        A cached id younger than the cache duration is reused; otherwise a
        fresh one is captured. When capture fails a fixed fallback value is
        sent instead and nothing is cached.

        Args:
            method: Accepted for interface compatibility; not used here.
            pathname: Accepted for interface compatibility; not used here.
        """
        headers = {}
        current_time = int(time.time())
        cache_is_fresh = (
            self._cached_statsig_id
            and self._cache_timestamp
            and (current_time - self._cache_timestamp) < self._cache_duration
        )
        if cache_is_fresh:
            print("Using cached x-statsig-id")
            headers["x-statsig-id"] = self._cached_statsig_id
        else:
            print("Capturing fresh x-statsig-id")
            statsig_id = self.capture_statsig_id()
            if statsig_id:
                self._cached_statsig_id = statsig_id
                self._cache_timestamp = current_time
                headers["x-statsig-id"] = statsig_id
            else:
                print("Failed to capture x-statsig-id, using fallback")
                headers["x-statsig-id"] = self._FALLBACK_STATSIG_ID
        headers["x-xai-request-id"] = self.generate_xai_request_id()
        print(f"Generated dynamic headers: {list(headers.keys())}")
        return headers
# Process-wide manager instance; stays None until first initialised.
_global_statsig_manager: Optional[PlaywrightStatsigManager] = None


def initialize_statsig_manager(proxy_url: Optional[str] = None) -> None:
    """Create the shared manager with ``proxy_url`` unless one already exists.

    When a manager has already been created the call is a no-op and
    ``proxy_url`` is ignored.
    """
    global _global_statsig_manager
    if _global_statsig_manager is not None:
        return
    _global_statsig_manager = PlaywrightStatsigManager(proxy_url=proxy_url)
def get_statsig_manager() -> PlaywrightStatsigManager:
    """Return the shared manager, creating a proxy-less one on first use."""
    global _global_statsig_manager
    manager = _global_statsig_manager
    if manager is None:
        manager = PlaywrightStatsigManager()
        _global_statsig_manager = manager
    return manager
class ModelType(Enum):
    """Model identifiers accepted by this service (values are the public names)."""

    # Grok 3 family
    GROK_3 = "grok-3"
    GROK_3_SEARCH = "grok-3-search"
    GROK_3_IMAGEGEN = "grok-3-imageGen"
    GROK_3_DEEPSEARCH = "grok-3-deepsearch"
    GROK_3_DEEPERSEARCH = "grok-3-deepersearch"
    GROK_3_REASONING = "grok-3-reasoning"

    # Grok 4 family
    GROK_4 = "grok-4"
    GROK_4_REASONING = "grok-4-reasoning"
    GROK_4_IMAGEGEN = "grok-4-imageGen"
    GROK_4_DEEPSEARCH = "grok-4-deepsearch"
class TokenType(Enum):
    """Privilege tier attached to a token."""

    NORMAL = "normal"
    SUPER = "super"
class ResponseState(Enum):
    """Phases a response can be in while it is being processed."""

    IDLE = "idle"
    THINKING = "thinking"
    GENERATING_IMAGE = "generating_image"
    COMPLETE = "complete"
# --- Request limits and retry tuning -------------------------------------
MESSAGE_LENGTH_LIMIT = 40000
MAX_FILE_ATTACHMENTS = 4
DEFAULT_REQUEST_TIMEOUT = 120000  # presumably milliseconds -- TODO confirm
MAX_RETRY_ATTEMPTS = 3
BASE_RETRY_DELAY = 1.0  # first backoff delay, in seconds (fed to time.sleep)

# --- Static headers sent with every upstream request ---------------------
# The per-request x-statsig-id / x-xai-request-id headers are added
# separately on top of a copy of this mapping.
BASE_HEADERS = {
    "Accept": "*/*",
    "Accept-Language": "zh-CN,zh;q=0.9",
    "Accept-Encoding": "gzip, deflate, br, zstd",
    "Content-Type": "text/plain;charset=UTF-8",
    "Connection": "keep-alive",
    "Origin": "https://grok.com",
    "Priority": "u=1, i",
    "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
    "Sec-Ch-Ua": '"Not(A:Brand";v="99", "Google Chrome";v="133", "Chromium";v="133"',
    "Sec-Ch-Ua-Mobile": "?0",
    "Sec-Ch-Ua-Platform": '"macOS"',
    "Sec-Fetch-Dest": "empty",
    "Sec-Fetch-Mode": "cors",
    "Sec-Fetch-Site": "same-origin",
    "Baggage": "sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
}
def get_dynamic_headers(
    method: str = "POST", pathname: str = "/rest/app-chat/conversations/new"
) -> Dict[str, str]:
    """Build the full header set for an upstream request.

    Starts from a copy of ``BASE_HEADERS`` and overlays the dynamic
    ``x-statsig-id`` / ``x-xai-request-id`` pair obtained from the shared
    statsig manager. If anything in that step raises, a fresh copy of the
    base headers is returned with a random request id and the fixed fallback
    statsig id instead.

    Args:
        method: HTTP method, forwarded to the statsig manager.
        pathname: Request path, forwarded to the statsig manager.

    Returns:
        Dictionary containing the static and the dynamic headers.
    """
    try:
        combined = BASE_HEADERS.copy()
        manager = get_statsig_manager()
        combined.update(manager.get_dynamic_headers(method, pathname))
        print(f"Generated dynamic headers for {method} {pathname}")
        return combined
    except Exception as exc:
        print(f"Error generating dynamic headers: {exc}")

    # Best-effort fallback: never let header generation fail the request.
    fallback = BASE_HEADERS.copy()
    fallback["x-xai-request-id"] = str(uuid.uuid4())
    fallback["x-statsig-id"] = (
        "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk="
    )
    return fallback
class GrokApiException(Exception):
    """Root of the Grok API error hierarchy.

    Carries a machine-readable ``error_code`` next to the human-readable
    message.
    """

    def __init__(self, message: str, error_code: str = "UNKNOWN_ERROR"):
        self.error_code = error_code
        super().__init__(message)


class TokenException(GrokApiException):
    """Raised for token-related failures."""


class ValidationException(GrokApiException):
    """Raised when input validation fails."""


class RateLimitException(GrokApiException):
    """Raised for rate-limiting failures."""
@dataclass
class TokenCredential:
    """A cookie-style SSO token together with its privilege tier.

    The class is a dataclass: ``from_raw_token`` constructs it positionally
    and validation lives in ``__post_init__``, both of which rely on the
    generated ``__init__``.
    """

    sso_token: str  # cookie string that must contain an "sso=" entry
    token_type: TokenType

    def __post_init__(self):
        """Validate the token format.

        Raises:
            ValidationException: if the token is empty/blank or has no
                ``sso=`` entry.
        """
        if not self.sso_token or not self.sso_token.strip():
            raise ValidationException("SSO token cannot be empty")
        if "sso=" not in self.sso_token:
            raise ValidationException("Invalid SSO token format")

    @classmethod
    def from_raw_token(
        cls, raw_token: str, token_type: TokenType = TokenType.NORMAL
    ) -> "TokenCredential":
        """Create a credential from a bare SSO value.

        The value is stripped and expanded to the
        ``sso-rw=<value>;sso=<value>`` cookie form.

        Raises:
            ValidationException: if ``raw_token`` is empty or blank.
        """
        if not raw_token or not raw_token.strip():
            raise ValidationException("Raw token cannot be empty")
        value = raw_token.strip()
        return cls(f"sso-rw={value};sso={value}", token_type)

    def extract_sso_value(self) -> str:
        """Return the value of the first ``sso=`` entry in the token.

        Raises:
            TokenException: if the token cannot be split as expected.
        """
        try:
            return self.sso_token.split("sso=")[1].split(";")[0]
        except (IndexError, AttributeError) as e:
            raise TokenException(f"Failed to parse SSO token: {self.sso_token}") from e
@dataclass
class GeneratedImage:
    """A generated image reference plus the data needed to fetch it.

    Declared as a dataclass so that the ``field(default_factory=list)``
    default and ``__post_init__`` validation actually take effect.
    """

    url: str
    base_url: str = "https://assets.grok.com"
    cookies: List[Dict[str, Any]] = field(default_factory=list)

    def __post_init__(self):
        """Validate image data.

        Raises:
            ValidationException: if ``url`` is empty.
        """
        if not self.url:
            raise ValidationException("Image URL cannot be empty")
@dataclass(frozen=True)
class ProcessingState:
    """Immutable state for response processing.

    Frozen so the documented immutability is enforced; the ``with_*``
    methods return modified copies instead of mutating in place.
    """

    is_thinking: bool = False
    is_generating_image: bool = False
    image_generation_phase: int = 0

    def with_thinking(self, thinking: bool) -> "ProcessingState":
        """Return a new state with ``is_thinking`` replaced."""
        return ProcessingState(
            thinking, self.is_generating_image, self.image_generation_phase
        )

    def with_image_generation(
        self, generating: bool, phase: int = 0
    ) -> "ProcessingState":
        """Return a new state with the image-generation flag and phase replaced.

        ``phase`` is not carried over from the current state: it becomes 0
        unless passed explicitly.
        """
        return ProcessingState(self.is_thinking, generating, phase)
@dataclass
class ModelResponse:
    """Model response parsed from the upstream API payload.

    Scalar fields are coerced to ``str``/``bool``; list fields are passed
    through from the payload. The message text has xaiArtifact markup
    rewritten to markdown code fences.
    """

    response_id: str
    message: str
    sender: str
    create_time: str
    parent_response_id: str
    manual: bool
    partial: bool
    shared: bool
    query: str
    query_type: str
    web_search_results: List[Any] = field(default_factory=list)
    xpost_ids: List[Any] = field(default_factory=list)
    xposts: List[Any] = field(default_factory=list)
    generated_images: List["GeneratedImage"] = field(default_factory=list)
    image_attachments: List[Any] = field(default_factory=list)
    file_attachments: List[Any] = field(default_factory=list)
    card_attachments_json: List[Any] = field(default_factory=list)
    file_uris: List[Any] = field(default_factory=list)
    file_attachments_metadata: List[Any] = field(default_factory=list)
    is_control: bool = False
    steps: List[Any] = field(default_factory=list)
    media_types: List[Any] = field(default_factory=list)

    @classmethod
    def from_api_response(
        cls, data: Dict[str, Any], enable_artifact_files: bool = False
    ) -> "ModelResponse":
        """Create a ModelResponse from an API ``modelResponse`` dict.

        Keys that are missing *or* explicitly ``null`` fall back to the
        empty default, so a JSON ``null`` never turns into the literal
        string ``"None"`` or a ``None`` list.

        Args:
            data: The ``modelResponse`` mapping from the API.
            enable_artifact_files: Accepted for interface compatibility;
                not used by this implementation.

        Returns:
            The parsed response. If parsing raises, a placeholder response
            from sender ``"system"`` with the message
            ``"Error processing response"`` is returned instead.
        """

        def text(key: str) -> str:
            value = data.get(key)
            return "" if value is None else str(value)

        def flag(key: str) -> bool:
            return bool(data.get(key, False))

        def items(key: str) -> Any:
            value = data.get(key)
            return [] if value is None else value

        try:
            generated_images = [
                GeneratedImage(url=str(image_url))
                for image_url in items("generatedImageUrls")
                if image_url
            ]
            return cls(
                response_id=text("responseId"),
                message=cls._transform_xai_artifacts(text("message")),
                sender=text("sender"),
                create_time=text("createTime"),
                parent_response_id=text("parentResponseId"),
                manual=flag("manual"),
                partial=flag("partial"),
                shared=flag("shared"),
                query=text("query"),
                query_type=text("queryType"),
                web_search_results=items("webSearchResults"),
                xpost_ids=items("xpostIds"),
                xposts=items("xposts"),
                generated_images=generated_images,
                image_attachments=items("imageAttachments"),
                file_attachments=items("fileAttachments"),
                card_attachments_json=items("cardAttachmentsJson"),
                file_uris=items("fileUris"),
                file_attachments_metadata=items("fileAttachmentsMetadata"),
                is_control=flag("isControl"),
                steps=items("steps"),
                media_types=items("mediaTypes"),
            )
        except Exception as e:
            # Deliberate best-effort: a malformed payload yields a
            # placeholder response rather than aborting the caller.
            print(f"Failed to create ModelResponse: {e}")
            return cls(
                response_id="",
                message="Error processing response",
                sender="system",
                create_time=str(int(time.time())),
                parent_response_id="",
                manual=False,
                partial=False,
                shared=False,
                query="",
                query_type="",
            )

    @staticmethod
    def _transform_xai_artifacts(text: str) -> str:
        """Rewrite xaiArtifact markup into plain markdown code blocks.

        Handles, in order:

        1. ``<xaiArtifact ...>body</xaiArtifact>``: with a ``contentType``
           attribute the body becomes a fenced block whose language is the
           part after the last ``/``; without one the bare body is kept. An
           empty body with a ``contentType`` is removed.
        2. Self-closing and unbalanced ``xaiArtifact`` tags: removed.
        3. Fences opened with ``x-<lang>src`` or ``x-<lang>``: renamed to
           ``<lang>``.
        """
        if not text:
            return text

        def replace_artifact_with_content(match):
            content = (match.group(1) or "").strip()
            content_type_match = re.search(r'contentType="([^"]+)"', match.group(0))
            if not content_type_match:
                return content
            if not content:
                return ""
            # "text/python" -> "python"; a value without "/" is used as is.
            lang = content_type_match.group(1).strip().split("/")[-1]
            return f"```{lang}\n{content}\n```"

        text = re.sub(
            r"<xaiArtifact[^>]*?>(.*?)</xaiArtifact>",
            replace_artifact_with_content,
            text,
            flags=re.DOTALL,
        )
        # Drop whatever tag fragments are left after the paired pass.
        text = re.sub(r"<xaiArtifact[^>]*?/>", "", text)
        text = re.sub(r"<xaiArtifact[^>]*>", "", text)
        text = re.sub(r"</xaiArtifact>", "", text)
        text = re.sub(
            r"```x-([a-zA-Z0-9_+-]+)src\b", lambda m: f"```{m.group(1)}", text
        )
        text = re.sub(
            r"```x-([a-zA-Z0-9_+-]+)\b(?![a-zA-Z0-9_-]*src)",
            lambda m: f"```{m.group(1)}",
            text,
        )
        return text
@dataclass
class GrokResponse:
    """Complete Grok API response: the model response plus envelope metadata."""

    model_response: ModelResponse
    is_thinking: bool = False
    is_soft_stop: bool = False
    response_id: str = ""
    conversation_id: Optional[str] = None
    title: Optional[str] = None
    conversation_create_time: Optional[str] = None
    conversation_modify_time: Optional[str] = None
    temporary: Optional[bool] = None
    error: Optional[str] = None
    error_code: Optional[Union[int, str]] = None

    @classmethod
    def from_api_response(
        cls, data: Dict[str, Any], enable_artifact_files: bool = False
    ) -> "GrokResponse":
        """Create a GrokResponse from one decoded API payload.

        ``result``, ``result.response`` and ``response.modelResponse`` are
        treated as empty when missing or ``null``, so a partial payload is
        parsed instead of being reported as a parsing error.

        Args:
            data: Decoded payload; reads ``error``, ``error_code`` and
                ``result``.
            enable_artifact_files: Forwarded to
                ``ModelResponse.from_api_response``.

        Returns:
            The parsed response. If parsing raises, a response with
            ``error_code="RESPONSE_PARSING_ERROR"`` and the exception text
            as ``error`` is returned instead.
        """
        try:
            result = data.get("result") or {}
            response_data = result.get("response") or {}
            model_response = ModelResponse.from_api_response(
                response_data.get("modelResponse") or {}, enable_artifact_files
            )
            response_id = response_data.get("responseId")
            return cls(
                model_response=model_response,
                is_thinking=bool(response_data.get("isThinking", False)),
                is_soft_stop=bool(response_data.get("isSoftStop", False)),
                response_id="" if response_id is None else str(response_id),
                conversation_id=response_data.get("conversationId"),
                # Prefer the freshly generated title; empty values become None.
                title=result.get("newTitle") or result.get("title") or None,
                conversation_create_time=response_data.get("createTime"),
                conversation_modify_time=response_data.get("modifyTime"),
                temporary=response_data.get("temporary"),
                error=data.get("error"),
                error_code=data.get("error_code"),
            )
        except Exception as e:
            # Deliberate best-effort: surface the failure through the error
            # fields rather than raising into the streaming code.
            return cls(
                model_response=ModelResponse.from_api_response({}),
                error=str(e),
                error_code="RESPONSE_PARSING_ERROR",
            )
class ConfigurationManager:
    """Environment-driven configuration store with dotted-path access."""

    def __init__(self):
        """Create ``/data``, load settings from the environment and validate them."""
        self.data_dir = Path("/data")
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self._config = self._load_configuration()
        self._validate_configuration()

    def _load_configuration(self) -> Dict[str, Any]:
        """Assemble the configuration tree from environment variables."""
        env = os.environ.get

        # Map every public model name to its first two dash-separated
        # parts, e.g. "grok-3-search" -> "grok-3".
        model_map = {}
        for model in ModelType:
            parts = model.value.split("-")
            model_map[model.value] = parts[0] + "-" + parts[1]

        api_section = {
            "IS_TEMP_CONVERSATION": self._get_bool_env("IS_TEMP_CONVERSATION", True),
            "IS_CUSTOM_SSO": self._get_bool_env("IS_CUSTOM_SSO", False),
            "BASE_URL": "https://grok.com",
            "API_KEY": env("API_KEY", "sk-123456"),
            "PICGO_KEY": env("PICGO_KEY"),
            "TUMY_KEY": env("TUMY_KEY"),
            "RETRY_TIME": 1000,
            "PROXY": env("PROXY"),
        }
        admin_section = {
            "MANAGER_SWITCH": env("MANAGER_SWITCH"),
            "PASSWORD": env("ADMINPASSWORD"),
        }
        server_section = {
            "CF_CLEARANCE": env("CF_CLEARANCE"),
            "PORT": int(env("PORT", 5200)),
        }
        retry_section = {
            "RETRYSWITCH": False,
            "MAX_ATTEMPTS": MAX_RETRY_ATTEMPTS,
        }
        return {
            "MODELS": model_map,
            "API": api_section,
            "ADMIN": admin_section,
            "SERVER": server_section,
            "RETRY": retry_section,
            "TOKEN_STATUS_FILE": str(self.data_dir / "token_status.json"),
            "SHOW_THINKING": self._get_bool_env("SHOW_THINKING", False),
            "SHOW_SEARCH_RESULTS": self._get_bool_env("SHOW_SEARCH_RESULTS", True),
            "IS_SUPER_GROK": self._get_bool_env("IS_SUPER_GROK", False),
            "FILTERED_TAGS": self._get_list_env("FILTERED_TAGS", ["xaiArtifact"]),
        }

    def _get_bool_env(self, key: str, default: bool = False) -> bool:
        """Read a boolean: only the text "true" (any letter case) counts as True."""
        raw = os.environ.get(key, str(default))
        return raw.lower() == "true"

    def _get_list_env(self, key: str, default: List[str]) -> List[str]:
        """Read a comma-separated list; unset or empty yields ``default``."""
        raw = os.environ.get(key)
        if not raw:
            return default
        entries = []
        for piece in raw.split(","):
            piece = piece.strip()
            if piece:
                entries.append(piece)
        return entries

    def _validate_configuration(self) -> None:
        """Print every configuration problem found, or a success line."""
        problems = []

        if not os.environ.get("API_KEY"):
            problems.append("Missing required environment variable: API_KEY")

        if not self._config["API"]["IS_CUSTOM_SSO"]:
            has_sso = bool(os.environ.get("SSO", ""))
            has_super_sso = bool(os.environ.get("SSO_SUPER", ""))
            if not (has_sso or has_super_sso):
                problems.append(
                    "No SSO tokens configured. Set SSO or SSO_SUPER environment variables."
                )

        proxy = self._config["API"]["PROXY"]
        if proxy and not proxy.startswith(("http://", "https://", "socks5://")):
            problems.append(f"Invalid proxy format: {proxy}")

        if not problems:
            print("Configuration validation passed")
            return
        for problem in problems:
            print(f"Configuration issue: {problem}")

    def get(self, key_path: str, default: Any = None) -> Any:
        """Look up a value by dotted path, e.g. ``"API.PROXY"``.

        Returns ``default`` as soon as a path segment is missing or the
        current node is not a dict.
        """
        node = self._config
        for part in key_path.split("."):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    def set(self, key_path: str, value: Any) -> None:
        """Store a value by dotted path, creating missing intermediate dicts."""
        *parents, leaf = key_path.split(".")
        node = self._config
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[leaf] = value

    # NOTE(review): decorator lines appear to be missing from this listing;
    # the two accessors below may originally have been properties -- confirm
    # against callers before changing how they are invoked.
    def models(self) -> Dict[str, str]:
        """Return the supported-models mapping."""
        return self._config["MODELS"]

    def data_directory(self) -> Path:
        """Return the data directory path."""
        return self.data_dir
| class UtilityFunctions: | |
| """Collection of utility functions for common operations.""" | |
def get_proxy_configuration(proxy_url: Optional[str]) -> Dict[str, Any]:
    """Translate a proxy URL into request keyword arguments.

    Args:
        proxy_url: Proxy URL, or ``None``/empty for no proxy.

    Returns:
        ``{}`` without a proxy. For ``socks5://`` URLs a dict with ``proxy``
        and, when the URL carries ``user:password@``, ``proxy_auth`` as a
        ``(username, password)`` tuple. For every other scheme a
        ``proxies`` dict mapping both ``https`` and ``http`` to the URL.
    """
    if not proxy_url:
        return {}

    scheme, has_scheme, remainder = proxy_url.partition("://")
    # The user-info part ends at the LAST "@", so a password containing "@"
    # stays intact.
    userinfo, has_auth, host = remainder.rpartition("@") if has_scheme else ("", "", "")

    # Never write proxy credentials to the log.
    loggable = f"{scheme}://***@{host}" if has_auth else proxy_url
    print(f"Using proxy: {loggable}")

    if not proxy_url.startswith("socks5://"):
        return {"proxies": {"https": proxy_url, "http": proxy_url}}

    proxy_config: Dict[str, Any] = {"proxy": proxy_url}
    if has_auth and ":" in userinfo:
        username, password = userinfo.split(":", 1)
        proxy_config["proxy_auth"] = (username, password)
    return proxy_config
def organize_search_results(search_results: Dict[str, Any]) -> str:
    """Render search results as collapsible ``<details>`` blocks.

    Args:
        search_results: Mapping whose ``"results"`` entry is a list of
            dicts with optional ``title``, ``url`` and ``preview`` keys.

    Returns:
        One block per result (numbered from 0) joined by blank lines, or
        ``""`` when the mapping is empty or has no ``"results"`` key.
    """
    if not search_results or "results" not in search_results:
        return ""

    def render(position: int, entry: Dict[str, Any]) -> str:
        heading = entry.get("title", "未知标题")
        link = entry.get("url", "#")
        snippet = entry.get("preview", "无预览内容")
        return (
            f"\r\n<details><summary>资料[{position}]: {heading}</summary>\r\n"
            f"{snippet}\r\n\n[Link]({link})\r\n</details>"
        )

    blocks = [
        render(position, entry)
        for position, entry in enumerate(search_results["results"])
    ]
    return "\n\n".join(blocks)
def parse_error_response(response_text: str) -> Dict[str, Any]:
    """Normalise an upstream error body into a fixed three-key dict.

    Two formats are tried in order: the whole text as JSON, then (if the
    text contains ``" - "``) everything after the first ``" - "`` as JSON.

    Args:
        response_text: Raw error text.

    Returns:
        A dict with ``error_code``, ``error`` and ``details`` (always a
        list). Empty/blank input yields code ``"EMPTY_RESPONSE"``; anything
        that cannot be interpreted yields code ``"Unknown"`` with the raw
        text as the message.
    """
    if not response_text or not response_text.strip():
        return {
            "error_code": "EMPTY_RESPONSE",
            "error": "Empty or invalid response received",
            "details": [],
        }

    def describe(source: Dict[str, Any]) -> Dict[str, Any]:
        """Read code/message/details from one error-carrying mapping."""
        details = source.get("details")
        return {
            "error_code": source.get("code"),
            "error": source.get("message") or response_text,
            "details": details if isinstance(details, list) else [],
        }

    def normalise(payload: Any) -> Optional[Dict[str, Any]]:
        """Return the normalised error for ``payload``, or None if it has none."""
        if not isinstance(payload, dict):
            return None
        if "error" in payload:
            error = payload["error"]
            if isinstance(error, dict):
                return describe(error)
            return {"error_code": "Unknown", "error": str(error), "details": []}
        if "message" in payload:
            return describe(payload)
        return None

    unknown = {
        "error_code": "Unknown",
        "error": response_text or "Unknown error occurred",
        "details": [],
    }

    # First attempt: the whole body is JSON.
    try:
        payload = json.loads(response_text)
    except json.JSONDecodeError:
        payload = None
    if isinstance(payload, dict):
        # A JSON object always settles the result, recognised or not.
        return normalise(payload) or {
            "error_code": "Unknown",
            "error": response_text,
            "details": [],
        }

    # Second attempt: "<prefix> - <json>".
    if " - " in response_text:
        try:
            payload = json.loads(response_text.split(" - ", 1)[1])
        except json.JSONDecodeError as e:
            print(f"Error parsing error response: {e}")
            return unknown
        normalised = normalise(payload)
        if normalised is not None:
            return normalised

    return unknown
def create_retry_decorator(
    max_attempts: int = MAX_RETRY_ATTEMPTS, base_delay: float = BASE_RETRY_DELAY
):
    """Create retry decorator with exponential backoff.

    Args:
        max_attempts: Total number of calls made before giving up.
        base_delay: Seconds slept before the first retry. The delay doubles
            after every failed attempt and is capped at 60 seconds.

    Returns:
        A decorator. The decorated callable returns the first successful
        result, or re-raises the exception of the final failed attempt.
    """

    def retry_decorator(func):
        # wraps() keeps the decorated callable's __name__/__doc__, which
        # anything that identifies callables by name relies on.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            last_error = None
            while attempts < max_attempts:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    attempts += 1
                    last_error = e
                    if attempts >= max_attempts:
                        print(
                            f"All retries failed ({max_attempts} attempts): {e}",
                            "RetryMechanism",
                        )
                        # Bare raise keeps the original traceback intact.
                        raise
                    delay = min(base_delay * (2 ** (attempts - 1)), 60)
                    print(
                        f"Retry {attempts}/{max_attempts}, delay {delay}s: {e}",
                        "RetryMechanism",
                    )
                    time.sleep(delay)
            # Only reachable when max_attempts <= 0, i.e. func was never called.
            raise last_error or Exception("Retry mechanism failed unexpectedly")

        return wrapper

    return retry_decorator
async def run_in_thread_pool(func, *args, **kwargs):
    """Run synchronous function in thread pool for async compatibility.

    Args:
        func: Blocking callable to execute off the event loop.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        Any exception raised by ``func`` is propagated unchanged.
    """
    # Keep the try body to the loop lookup only. Guarding the awaited call as
    # well would catch a RuntimeError raised by func itself and then run func
    # a second time in the fallback pool.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is None:
        # No running loop: execute in a temporary pool and wait for the result.
        with ThreadPoolExecutor() as executor:
            return executor.submit(func, *args, **kwargs).result()

    # Run inside a copy of the current context so context variables set by the
    # caller are visible to func in the worker thread.
    ctx = contextvars.copy_context()
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(None, func_call)
def create_structured_error_response(
    error_data: Union[str, Dict[str, Any]], status_code: int = 500
) -> Tuple[Dict[str, Any], int]:
    """Create structured error response.

    Args:
        error_data: Either raw error text (parsed with
            ``UtilityFunctions.parse_error_response``) or a mapping with the
            keys ``error``, ``error_code`` and ``details``.
        status_code: HTTP status returned alongside the payload.

    Returns:
        A ``(payload, status_code)`` tuple where ``payload`` has the shape
        ``{"error": {"message", "type", "code", "details"}}``.
    """
    if isinstance(error_data, str):
        error_data = UtilityFunctions.parse_error_response(error_data)

    raw_message = error_data.get("error", "Unknown error")
    # The message may be a non-string (e.g. a nested dict); stringify before
    # checking for blankness instead of calling .strip() on it directly.
    if not raw_message or not str(raw_message).strip():
        message = "An error occurred while processing the request"
    else:
        message = str(raw_message)

    # A missing code used to be rendered as the literal string "None"; use the
    # same "Unknown" marker the error parser emits for unidentified errors.
    raw_code = error_data.get("error_code")
    code = "Unknown" if raw_code is None else str(raw_code)

    raw_details = error_data.get("details")
    if not raw_details:
        details = []
    elif isinstance(raw_details, (list, tuple)):
        details = list(raw_details)
    else:
        # Wrap scalars/mappings instead of exploding them into chars or keys.
        details = [raw_details]

    return {
        "error": {
            "message": message,
            "type": "server_error",
            "code": code,
            "details": details,
        }
    }, status_code
@dataclass
class TokenEntry:
    """Represents a single token entry with usage tracking.

    Declared as a dataclass because the rest of the file constructs entries by
    field name (``TokenEntry(credential=..., max_request_count=..., ...)``).
    """

    # Credential object this entry tracks usage for.
    credential: "TokenCredential"
    # Number of requests allowed before the entry stops being available.
    max_request_count: int
    # Requests consumed since the last reset.
    request_count: int
    # Creation time, epoch milliseconds.
    added_time: int
    # Epoch milliseconds of the first use since the last reset; None if unused.
    start_call_time: Optional[int] = None

    def is_available(self) -> bool:
        """Check if token is available for use."""
        return self.request_count < self.max_request_count

    def can_be_reset(self, expiration_time_ms: int, current_time_ms: int) -> bool:
        """Check if token can be reset based on expiration time.

        Returns False for an entry that has not been used since its last
        reset; otherwise True once ``expiration_time_ms`` has elapsed since
        the first use.
        """
        if not self.start_call_time:
            return False
        return current_time_ms - self.start_call_time >= expiration_time_ms

    def use_token(self) -> None:
        """Mark token as used, stamping the first-use time if not yet set."""
        if not self.start_call_time:
            self.start_call_time = int(time.time() * 1000)
        self.request_count += 1

    def reset_usage(self) -> None:
        """Reset token usage counters."""
        self.request_count = 0
        self.start_call_time = None
@dataclass
class ModelLimits:
    """Configuration for model request limits.

    Declared as a dataclass because the limits tables in this file build
    instances positionally: ``ModelLimits(request_frequency, expiration_time_ms)``.
    """

    # Number of requests a token may serve per window.
    request_frequency: int
    # Length of the window, in milliseconds.
    expiration_time_ms: int
class ThreadSafeTokenManager:
    """Thread-safe token management with proper synchronization.

    Two structures are kept in step:

    * ``_token_storage``: model name -> list of live ``TokenEntry`` objects.
    * ``_token_status``: SSO value -> model name -> JSON-serialisable status
      dict, persisted to the file named by the ``TOKEN_STATUS_FILE`` config key.

    Public methods acquire ``self._lock``; the underscore-prefixed helpers do
    not acquire it themselves.
    """

    def __init__(self, config: ConfigurationManager):
        """Initialize token manager with configuration and load saved status."""
        self.config = config
        # RLock is re-entrant: nested acquisition by one thread cannot deadlock.
        self._lock = threading.RLock()
        self._token_storage: Dict[str, List[TokenEntry]] = {}
        self._token_status: Dict[str, Dict[str, Dict[str, Any]]] = {}
        # Tuples of (token string, model, removal time in ms, token type).
        self._expired_tokens: List[Tuple[str, str, int, TokenType]] = []
        # ModelLimits(request_frequency, expiration_time_ms) per model.
        self._super_limits = {
            "grok-3": ModelLimits(100, 3 * 60 * 60 * 1000),
            "grok-3-deepsearch": ModelLimits(30, 24 * 60 * 60 * 1000),
            "grok-3-deepersearch": ModelLimits(10, 3 * 60 * 60 * 1000),
            "grok-3-reasoning": ModelLimits(30, 3 * 60 * 60 * 1000),
            "grok-4": ModelLimits(20, 3 * 60 * 60 * 1000),
        }
        self._normal_limits = {
            "grok-4": ModelLimits(5, int(11.5 * 60 * 60 * 1000)),
            "grok-3": ModelLimits(20, 3 * 60 * 60 * 1000),
            "grok-3-deepsearch": ModelLimits(10, 24 * 60 * 60 * 1000),
            "grok-3-deepersearch": ModelLimits(3, 24 * 60 * 60 * 1000),
            "grok-3-reasoning": ModelLimits(8, 24 * 60 * 60 * 1000),
        }
        self._reset_timer_started = False
        self._load_token_status()

    def _normalize_model_name(self, model: str) -> str:
        """Normalize model name for consistent lookup.

        ``grok-<x>-<anything>`` collapses to ``grok-<x>`` unless the name
        contains ``deepsearch``, ``deepersearch`` or ``reasoning``.
        """
        special_keywords = ("deepsearch", "deepersearch", "reasoning")
        if model.startswith("grok-") and not any(
            keyword in model for keyword in special_keywords
        ):
            parts = model.split("-")
            return f"{parts[0]}-{parts[1]}" if len(parts) >= 2 else model
        return model

    def _get_model_limits(self, token_type: TokenType) -> Dict[str, ModelLimits]:
        """Get model limits based on token type."""
        if token_type == TokenType.SUPER:
            return self._super_limits
        return self._normal_limits

    def _save_token_status(self) -> None:
        """Save token status to persistent storage (failures are only logged)."""
        try:
            status_file = Path(self.config.get("TOKEN_STATUS_FILE"))
            with open(status_file, "w", encoding="utf-8") as f:
                json.dump(self._token_status, f, indent=2, ensure_ascii=False)
            print("Token status saved to file")
        except Exception as e:
            print(f"Failed to save token status: {e}")

    def _load_token_status(self) -> None:
        """Load token status from persistent storage and reconstruct token storage."""
        try:
            status_file = Path(self.config.get("TOKEN_STATUS_FILE"))
            if status_file.exists():
                with open(status_file, "r", encoding="utf-8") as f:
                    self._token_status = json.load(f)
                print("Token status loaded from file")
                self._reconstruct_token_storage()
        except Exception as e:
            print(f"Failed to load token status: {e}")
            self._token_status = {}

    def _reconstruct_token_storage(self) -> None:
        """Reconstruct _token_storage from _token_status.

        Records flagged ``is_expired`` are skipped; all others become one
        ``TokenEntry`` per (token, model) pair.
        """
        try:
            reconstructed_count = 0
            for sso_value, models_data in self._token_status.items():
                for model, model_data in models_data.items():
                    is_super = model_data.get("isSuper", False)
                    token_type = TokenType.SUPER if is_super else TokenType.NORMAL
                    credential = TokenCredential.from_raw_token(sso_value, token_type)
                    token_entry = TokenEntry(
                        credential=credential,
                        max_request_count=model_data.get("max_request_count", 20),
                        request_count=model_data.get("request_count", 0),
                        added_time=model_data.get(
                            "added_time", int(time.time() * 1000)
                        ),
                        start_call_time=model_data.get("start_call_time"),
                    )
                    if model_data.get("is_expired", False):
                        print(f"Skipping expired token for {model}")
                        continue
                    entries = self._token_storage.setdefault(model, [])
                    already_present = any(
                        entry.credential.sso_token == credential.sso_token
                        for entry in entries
                    )
                    if not already_present:
                        entries.append(token_entry)
                        reconstructed_count += 1
            if reconstructed_count > 0:
                print(f"Reconstructed {reconstructed_count} token entries")
        except Exception as e:
            print(f"Failed to reconstruct token storage: {e}")

    def record_token_failure(
        self, model: str, token_string: str, failure_reason: str, status_code: int
    ) -> None:
        """Record a token failure and potentially mark as expired.

        A token is flagged expired once it has accumulated at least three
        failures and the current failure carries status 401 or 403. Tokens
        with no status record for the model are ignored.
        """
        with self._lock:
            try:
                normalized_model = self._normalize_model_name(model)
                # The credential is only used to derive the SSO value.
                credential = TokenCredential(token_string, TokenType.NORMAL)
                sso_value = credential.extract_sso_value()
                status = self._token_status.get(sso_value, {}).get(normalized_model)
                if status is None:
                    return
                status["failed_request_count"] = (
                    status.get("failed_request_count", 0) + 1
                )
                status["last_failure_time"] = int(time.time() * 1000)
                status["last_failure_reason"] = f"{status_code}: {failure_reason}"
                failure_threshold = 3
                if status[
                    "failed_request_count"
                ] >= failure_threshold and status_code in [401, 403]:
                    status["is_expired"] = True
                    status["isValid"] = False
                    print(
                        f"Token marked as expired after {status['failed_request_count']} failures: {failure_reason}",
                        "TokenManager",
                    )
                self._save_token_status()
                print(
                    f"Recorded token failure for {model}: {failure_reason} (total failures: {status['failed_request_count']})",
                    "TokenManager",
                )
            except Exception as e:
                print(f"Failed to record token failure: {e}")

    def _is_token_expired(self, token_entry: TokenEntry, model: str) -> bool:
        """Check if a token is marked as expired (False when unknown or on error)."""
        try:
            sso_value = token_entry.credential.extract_sso_value()
            if (
                sso_value in self._token_status
                and model in self._token_status[sso_value]
            ):
                return self._token_status[sso_value][model].get("is_expired", False)
            return False
        except Exception as e:
            print(f"Failed to check token expiration: {e}")
            return False

    def add_token(
        self, credential: TokenCredential, is_initialization: bool = False
    ) -> bool:
        """Add token to the management system.

        Registers the credential for every model in the limits table matching
        its token type. Existing entries and status records are left as they are.

        Args:
            credential: Token to register.
            is_initialization: When True the status file is not rewritten.

        Returns:
            True on success, False if registration raised.
        """
        with self._lock:
            try:
                model_limits = self._get_model_limits(credential.token_type)
                sso_value = credential.extract_sso_value()
                for model, limits in model_limits.items():
                    entries = self._token_storage.setdefault(model, [])
                    token_entry = next(
                        (
                            entry
                            for entry in entries
                            if entry.credential.sso_token == credential.sso_token
                        ),
                        None,
                    )
                    if token_entry is None:
                        token_entry = TokenEntry(
                            credential=credential,
                            max_request_count=limits.request_frequency,
                            request_count=0,
                            added_time=int(time.time() * 1000),
                        )
                        entries.append(token_entry)
                    model_status = self._token_status.setdefault(sso_value, {})
                    if model not in model_status:
                        # token_entry is always bound here (found or created),
                        # so a pre-existing entry without a status record no
                        # longer hits an unbound/stale variable.
                        model_status[model] = {
                            "isValid": True,
                            "invalidatedTime": None,
                            "totalRequestCount": 0,
                            "isSuper": credential.token_type == TokenType.SUPER,
                            "max_request_count": limits.request_frequency,
                            "request_count": token_entry.request_count,
                            "added_time": token_entry.added_time,
                            "start_call_time": token_entry.start_call_time,
                            "failed_request_count": 0,
                            "is_expired": False,
                            "last_failure_time": None,
                            "last_failure_reason": None,
                        }
                if not is_initialization:
                    self._save_token_status()
                print(
                    f"Token added successfully for type: {credential.token_type.value}",
                    "TokenManager",
                )
                return True
            except Exception as e:
                print(f"Failed to add token: {e}")
                return False

    def get_token_for_model(self, model: str) -> Optional[str]:
        """Get available token for specified model.

        Picks the first stored entry that is neither flagged expired nor out
        of quota, counts one request against it, persists the status and
        starts the background reset thread on first use.

        Returns:
            The token string, or None when no usable token exists.
        """
        with self._lock:
            normalized_model = self._normalize_model_name(model)
            tokens = self._token_storage.get(normalized_model)
            if not tokens:
                return None
            for token_entry in tokens:
                if self._is_token_expired(token_entry, normalized_model):
                    continue
                if not token_entry.is_available():
                    continue
                token_entry.use_token()
                try:
                    sso_value = token_entry.credential.extract_sso_value()
                    status = self._token_status.get(sso_value, {}).get(
                        normalized_model
                    )
                    if status is not None:
                        status["totalRequestCount"] = (
                            status.get("totalRequestCount", 0) + 1
                        )
                        status["request_count"] = token_entry.request_count
                        status["start_call_time"] = token_entry.start_call_time
                        if token_entry.request_count >= token_entry.max_request_count:
                            status["isValid"] = False
                            status["invalidatedTime"] = int(time.time() * 1000)
                except Exception as e:
                    print(f"Failed to update token status: {e}")
                if not self._reset_timer_started:
                    self._start_reset_timer()
                self._save_token_status()
                return token_entry.credential.sso_token
            return None

    def remove_token_from_model(self, model: str, token_string: str) -> bool:
        """Remove specific token from model.

        The removed token is queued in ``_expired_tokens`` so the reset thread
        can reactivate it later.
        """
        with self._lock:
            normalized_model = self._normalize_model_name(model)
            if normalized_model not in self._token_storage:
                return False
            tokens = self._token_storage[normalized_model]
            for i, token_entry in enumerate(tokens):
                if token_entry.credential.sso_token == token_string:
                    removed_entry = tokens.pop(i)
                    self._expired_tokens.append(
                        (
                            token_string,
                            normalized_model,
                            int(time.time() * 1000),
                            removed_entry.credential.token_type,
                        )
                    )
                    print(f"Token removed from model {model}")
                    return True
            return False

    def get_token_count_for_model(self, model: str) -> int:
        """Get the number of stored token entries for model (used up or not)."""
        with self._lock:
            normalized_model = self._normalize_model_name(model)
            return len(self._token_storage.get(normalized_model, []))

    def get_remaining_capacity(self) -> Dict[str, int]:
        """Get remaining request capacity for each model."""
        with self._lock:
            capacity_map = {}
            for model, tokens in self._token_storage.items():
                total_capacity = sum(entry.max_request_count for entry in tokens)
                used_requests = sum(entry.request_count for entry in tokens)
                capacity_map[model] = max(0, total_capacity - used_requests)
            return capacity_map

    def reduce_token_request_count(self, model: str, count: int) -> bool:
        """Reduce token request count (for error recovery).

        The reduction is applied to the first stored entry for the model.

        Returns:
            False when the model has no entries, True otherwise.
        """
        with self._lock:
            normalized_model = self._normalize_model_name(model)
            tokens = self._token_storage.get(normalized_model)
            if not tokens:
                return False
            token_entry = tokens[0]
            original_count = token_entry.request_count
            token_entry.request_count = max(0, token_entry.request_count - count)
            reduction = original_count - token_entry.request_count
            try:
                sso_value = token_entry.credential.extract_sso_value()
                if (
                    sso_value in self._token_status
                    and normalized_model in self._token_status[sso_value]
                ):
                    status = self._token_status[sso_value][normalized_model]
                    status["totalRequestCount"] = max(
                        0, status["totalRequestCount"] - reduction
                    )
            except Exception as e:
                print(
                    f"Failed to update token status during reduction: {e}",
                    "TokenManager",
                )
            return True

    def _reactivate_due_tokens(self, current_time: int) -> None:
        """Re-add removed tokens whose model window has elapsed (lock held by caller)."""
        reactivated = []
        for token_info in self._expired_tokens:
            token, model, expired_time, token_type = token_info
            model_limits = self._get_model_limits(token_type)
            if model not in model_limits:
                continue
            limits = model_limits[model]
            if current_time - expired_time < limits.expiration_time_ms:
                continue
            try:
                credential = TokenCredential(token, token_type)
                self._reactivate_token(model, credential, limits)
                reactivated.append(token_info)
            except Exception as e:
                print(
                    f"Failed to reactivate token: {e}",
                    "TokenManager",
                )
        for token_info in reactivated:
            self._expired_tokens.remove(token_info)

    def _reset_elapsed_usage(self, current_time: int) -> None:
        """Zero the counters of entries whose window has elapsed (lock held by caller)."""
        for model, tokens in self._token_storage.items():
            for token_entry in tokens:
                model_limits = self._get_model_limits(token_entry.credential.token_type)
                if model not in model_limits:
                    continue
                if not token_entry.can_be_reset(
                    model_limits[model].expiration_time_ms, current_time
                ):
                    continue
                token_entry.reset_usage()
                try:
                    sso_value = token_entry.credential.extract_sso_value()
                    status = self._token_status.get(sso_value, {}).get(model)
                    if status is not None:
                        status["isValid"] = True
                        status["invalidatedTime"] = None
                        status["totalRequestCount"] = 0
                        status["request_count"] = token_entry.request_count
                        status["start_call_time"] = token_entry.start_call_time
                except Exception as e:
                    print(
                        f"Failed to update status during reset: {e}",
                        "TokenManager",
                    )

    def _start_reset_timer(self) -> None:
        """Start the token reset timer.

        Spawns a daemon thread that, once per hour, reactivates removed tokens
        and resets exhausted quotas under the lock, then persists the status.
        """

        def reset_expired_tokens():
            while True:
                try:
                    current_time = int(time.time() * 1000)
                    with self._lock:
                        self._reactivate_due_tokens(current_time)
                        self._reset_elapsed_usage(current_time)
                        self._save_token_status()
                except Exception as e:
                    print(f"Error in token reset timer: {e}")
                time.sleep(3600)

        timer_thread = threading.Thread(target=reset_expired_tokens, daemon=True)
        timer_thread.start()
        self._reset_timer_started = True

    def _reactivate_token(
        self, model: str, credential: TokenCredential, limits: ModelLimits
    ) -> None:
        """Reactivate an expired token.

        Ensures an entry for the credential exists under ``model`` and, when
        the token has a status record, marks the model status valid again.
        """
        entries = self._token_storage.get(model, [])
        token_entry = next(
            (
                entry
                for entry in entries
                if entry.credential.sso_token == credential.sso_token
            ),
            None,
        )
        if token_entry is None:
            token_entry = TokenEntry(
                credential=credential,
                max_request_count=limits.request_frequency,
                request_count=0,
                added_time=int(time.time() * 1000),
            )
            self._token_storage.setdefault(model, []).append(token_entry)
        try:
            sso_value = credential.extract_sso_value()
            if sso_value in self._token_status:
                status = self._token_status[sso_value].setdefault(model, {})
                status["isValid"] = True
                status["invalidatedTime"] = None
                status["totalRequestCount"] = 0
                status["isSuper"] = credential.token_type == TokenType.SUPER
                # token_entry is bound on both paths above, so this no longer
                # fails when the entry already existed.
                status["max_request_count"] = token_entry.max_request_count
                status["request_count"] = token_entry.request_count
                status["added_time"] = token_entry.added_time
                status["start_call_time"] = token_entry.start_call_time
        except Exception as e:
            print(f"Failed to update reactivated token status: {e}")

    def delete_token(self, token_string: str) -> bool:
        """Delete token completely from the system.

        Removes the token from every model list, from the status map and from
        the pending-reactivation queue.

        Returns:
            True when an entry or a status record was removed, else False.
        """
        with self._lock:
            try:
                removed = False
                credential = TokenCredential(token_string, TokenType.NORMAL)
                sso_value = credential.extract_sso_value()
                for model, tokens in self._token_storage.items():
                    remaining = [
                        entry
                        for entry in tokens
                        if entry.credential.sso_token != token_string
                    ]
                    if len(remaining) < len(tokens):
                        removed = True
                    # Rebinding an existing key is safe while iterating items().
                    self._token_storage[model] = remaining
                # A token present only in the status map (e.g. skipped as
                # expired on load) must also count, so its removal is saved.
                if self._token_status.pop(sso_value, None) is not None:
                    removed = True
                self._expired_tokens = [
                    token_info
                    for token_info in self._expired_tokens
                    if token_info[0] != token_string
                ]
                if removed:
                    self._save_token_status()
                    print("Token deleted successfully")
                return removed
            except Exception as e:
                print(f"Failed to delete token: {e}")
                return False

    def get_all_tokens(self) -> List[str]:
        """Get all distinct token strings in the system."""
        with self._lock:
            all_tokens = {
                entry.credential.sso_token
                for tokens in self._token_storage.values()
                for entry in tokens
            }
            return list(all_tokens)

    def get_token_status_map(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
        """Get complete token status mapping.

        The outer dict is a copy; the nested per-token dicts are the live ones.
        """
        with self._lock:
            return dict(self._token_status)

    def get_token_health_summary(self) -> Dict[str, Any]:
        """Get summary of token health across all models.

        Per model each record counts as expired, rate limited or healthy, in
        that precedence. A token as a whole is expired if any of its models is,
        else rate limited if any is, else healthy.
        """
        with self._lock:
            summary = {
                "total_tokens": 0,
                "healthy_tokens": 0,
                "expired_tokens": 0,
                "rate_limited_tokens": 0,
                "tokens_with_failures": 0,
                "total_failures": 0,
                "by_model": {},
            }
            by_model = summary["by_model"]
            for models_data in self._token_status.values():
                is_expired = False
                is_rate_limited = False
                has_failures = False
                total_token_failures = 0
                for model, model_data in models_data.items():
                    model_summary = by_model.setdefault(
                        model,
                        {
                            "total": 0,
                            "healthy": 0,
                            "expired": 0,
                            "rate_limited": 0,
                            "with_failures": 0,
                        },
                    )
                    model_summary["total"] += 1
                    if model_data.get("is_expired", False):
                        model_summary["expired"] += 1
                        is_expired = True
                    elif not model_data.get("isValid", True):
                        model_summary["rate_limited"] += 1
                        is_rate_limited = True
                    else:
                        model_summary["healthy"] += 1
                    failure_count = model_data.get("failed_request_count", 0)
                    if failure_count > 0:
                        model_summary["with_failures"] += 1
                        has_failures = True
                        total_token_failures += failure_count
                if is_expired:
                    summary["expired_tokens"] += 1
                elif is_rate_limited:
                    summary["rate_limited_tokens"] += 1
                else:
                    summary["healthy_tokens"] += 1
                if has_failures:
                    summary["tokens_with_failures"] += 1
                summary["total_failures"] += total_token_failures
            summary["total_tokens"] = len(self._token_status)
            return summary
@dataclass
class ImageTypeInfo:
    """Information about an image type.

    Declared as a dataclass because the image processor builds instances
    positionally: ``ImageTypeInfo(mime_type, file_name, extension)``.
    """

    # MIME type, e.g. "image/png".
    mime_type: str
    # File name used for the upload, e.g. "image.png".
    file_name: str
    # File extension without the leading dot.
    extension: str
class ImageProcessor:
    """Handles image processing and type detection."""

    # Leading magic bytes -> (extension, MIME type).
    IMAGE_SIGNATURES = {
        b"\xff\xd8\xff": ("jpg", "image/jpeg"),
        b"\x89PNG\r\n\x1a\n": ("png", "image/png"),
        b"GIF89a": ("gif", "image/gif"),
        b"GIF87a": ("gif", "image/gif"),
    }

    # The methods take ``cls`` and are called on the class itself, so they must
    # be classmethods; as plain functions the first argument binds to ``cls``.
    @classmethod
    def is_base64_image(cls, s: str) -> bool:
        """Check if string is a valid base64 image by examining binary signatures.

        Returns False for input that is not strictly valid base64.
        """
        try:
            decoded = base64.b64decode(s, validate=True)
        except Exception:
            return False
        return decoded.startswith(tuple(cls.IMAGE_SIGNATURES))

    @classmethod
    def get_extension_and_mime_from_header(cls, data: bytes) -> Tuple[str, str]:
        """Detect image format from binary header.

        Returns:
            ``(extension, mime_type)``; ``("jpg", "image/jpeg")`` when no known
            signature matches.
        """
        for signature, (extension, mime_type) in cls.IMAGE_SIGNATURES.items():
            if data.startswith(signature):
                return extension, mime_type
        return "jpg", "image/jpeg"

    @classmethod
    def get_image_type_info(cls, base64_string: str) -> "ImageTypeInfo":
        """Enhanced image type detection with binary signature support.

        For a ``data:image...;base64,`` URI the MIME type is read from the URI
        prefix; otherwise the payload is decoded and sniffed by signature.
        JPEG is assumed when neither approach yields a type.
        """
        mime_type = "image/jpeg"
        extension = "jpg"
        if "data:image" in base64_string:
            matches = re.search(
                r"data:([a-zA-Z0-9]+\/[a-zA-Z0-9-.+]+);base64,", base64_string
            )
            if matches:
                mime_type = matches.group(1)
                extension = mime_type.split("/")[1]
        else:
            try:
                image_data = base64.b64decode(base64_string, validate=True)
            except Exception:
                # Not decodable: deliberately keep the JPEG defaults.
                image_data = None
            if image_data is not None:
                extension, mime_type = cls.get_extension_and_mime_from_header(
                    image_data
                )
        file_name = f"image.{extension}"
        return ImageTypeInfo(mime_type, file_name, extension)
class FileUploadManager:
    """Handles file and image upload operations."""

    _UPLOAD_PATH = "/rest/app-chat/upload-file"
    _UPLOAD_URL = "https://grok.com/rest/app-chat/upload-file"

    def __init__(
        self, config: ConfigurationManager, token_manager: ThreadSafeTokenManager
    ):
        """Initialize file upload manager."""
        self.config = config
        self.token_manager = token_manager

    def _post_upload(self, upload_data: Dict[str, str], model: str):
        """POST one upload payload and return the raw HTTP response.

        Shared by the text and image uploads, which previously duplicated this
        token/cookie/proxy/request sequence.

        Raises:
            TokenException: When no token is available for ``model``.
        """
        auth_token = self.token_manager.get_token_for_model(model)
        if not auth_token:
            raise TokenException(f"No available tokens for model: {model}")
        cf_clearance = self.config.get("SERVER.CF_CLEARANCE", "")
        cookie = f"{auth_token};{cf_clearance}" if cf_clearance else auth_token
        proxy_config = UtilityFunctions.get_proxy_configuration(
            self.config.get("API.PROXY")
        )
        return curl_requests.post(
            self._UPLOAD_URL,
            headers={
                **get_dynamic_headers("POST", self._UPLOAD_PATH),
                "Cookie": cookie,
            },
            json=upload_data,
            impersonate="chrome133a",
            timeout=60,
            **proxy_config,
        )

    def upload_text_file(self, content: str, model: str) -> str:
        """Upload text content as a file attachment.

        Returns:
            The ``fileMetadataId`` reported by the upload endpoint.

        Raises:
            GrokApiException: On any failure (missing token, non-200 status,
                missing id, transport error); the original error is chained.
        """
        try:
            content_base64 = base64.b64encode(content.encode("utf-8")).decode("utf-8")
            upload_data = {
                "fileName": "message.txt",
                "fileMimeType": "text/plain",
                "content": content_base64,
            }
            print("Uploading text file")
            response = self._post_upload(upload_data, model)
            if response.status_code != 200:
                raise GrokApiException(
                    f"File upload failed with status: {response.status_code}",
                    "UPLOAD_FAILED",
                )
            result = response.json()
            file_metadata_id = result.get("fileMetadataId", "")
            if not file_metadata_id:
                raise GrokApiException(
                    "No file metadata ID in response", "INVALID_RESPONSE"
                )
            print(f"Text file uploaded successfully: {file_metadata_id}", "FileUpload")
            return file_metadata_id
        except Exception as error:
            print(f"Text file upload failed: {error}")
            raise GrokApiException(
                f"Text file upload failed: {error}", "UPLOAD_ERROR"
            ) from error

    def upload_image(self, image_data: str, model: str) -> str:
        """Upload image with enhanced format support.

        Accepts either a ``data:image...`` URI or a bare base64 payload.

        Returns:
            The ``fileMetadataId`` on success, or ``""`` on any failure; this
            method is best-effort and does not raise.
        """
        try:
            if "data:image" in image_data:
                # Everything after the first comma is the base64 payload.
                image_buffer = image_data.split(",", 1)[1]
            else:
                image_buffer = image_data
            image_info = ImageProcessor.get_image_type_info(image_data)
            upload_data = {
                "fileName": image_info.file_name,
                "fileMimeType": image_info.mime_type,
                "content": image_buffer,
            }
            print("Uploading image file")
            response = self._post_upload(upload_data, model)
            if response.status_code != 200:
                print(
                    f"Image upload failed with status: {response.status_code}",
                    "ImageUpload",
                )
                return ""
            result = response.json()
            file_metadata_id = result.get("fileMetadataId", "")
            if file_metadata_id:
                print(f"Image uploaded successfully: {file_metadata_id}", "ImageUpload")
            return file_metadata_id
        except Exception as error:
            print(f"Image upload failed: {error}")
            return ""
@dataclass
class ProcessedMessage:
    """Result of message processing.

    Declared as a dataclass because the message processor builds instances by
    field name and relies on the ``upload_content`` default.
    """

    # Final prompt text to send.
    content: str
    # Identifiers of uploaded files to attach to the request.
    file_attachments: List[str]
    # True when the conversation was long enough to be sent as a file.
    requires_file_upload: bool
    upload_content: str = ""
class MessageContentProcessor:
    """Processes message content and handles complex formats."""

    def __init__(self, file_upload_manager: "FileUploadManager"):
        """Initialize message processor."""
        self.file_upload_manager = file_upload_manager

    def remove_think_tags_and_images(self, text: str) -> str:
        """Remove think tags and base64 images from text.

        ``<think>...</think>`` sections are dropped and inline base64 markdown
        images are replaced by the ``[图片]`` placeholder.
        """
        text = re.sub(r"<think>[\s\S]*?<\/think>", "", text).strip()
        text = re.sub(r"!\[image\]\(data:.*?base64,.*?\)", "[图片]", text)
        return text

    def _render_part(self, part: Dict[str, Any]) -> str:
        """Render one typed content part as text ("" for unknown part types)."""
        part_type = part.get("type")
        if part_type == "image_url":
            return "[图片]"
        if part_type == "text":
            # Treat a missing or null text field as empty.
            return self.remove_think_tags_and_images(part.get("text") or "")
        return ""

    def process_content_item(self, content_item: Any) -> str:
        """Process individual content item (text or image).

        Accepts a plain string, a single typed part (dict) or a list of typed
        parts; list parts are joined with newlines, skipping empty ones.
        Any other input yields an empty string.
        """
        if isinstance(content_item, str):
            return self.remove_think_tags_and_images(content_item)
        if isinstance(content_item, dict):
            return self._render_part(content_item)
        if isinstance(content_item, list):
            rendered = [
                self._render_part(item)
                for item in content_item
                if isinstance(item, dict)
            ]
            return "\n".join(filter(None, rendered))
        return ""

    def extract_image_attachments(self, content_item: Any, model: str) -> List[str]:
        """Extract and upload image attachments from content.

        Returns:
            File ids of the images that were uploaded successfully, in order.
        """
        if isinstance(content_item, dict):
            candidates = [content_item]
        elif isinstance(content_item, list):
            candidates = content_item
        else:
            return []

        attachments = []
        for item in candidates:
            if not (isinstance(item, dict) and item.get("type") == "image_url"):
                continue
            image_ref = item.get("image_url", {})
            if isinstance(image_ref, dict):
                image_url = image_ref.get("url", "")
            elif isinstance(image_ref, str):
                # Tolerate a bare string URL instead of crashing on .get().
                image_url = image_ref
            else:
                image_url = ""
            if not image_url:
                continue
            file_id = self.file_upload_manager.upload_image(image_url, model)
            if file_id:
                attachments.append(file_id)
        return attachments

    def process_messages(
        self, messages: List[Dict[str, Any]], model: str
    ) -> "ProcessedMessage":
        """Process list of messages into a single formatted string.

        Every role other than ``assistant`` is treated as ``user``. Consecutive
        messages of the same role are merged into one ``ROLE: text`` line.
        Images are uploaded for the final message only. When the transcript
        reaches ``MESSAGE_LENGTH_LIMIT`` characters it is uploaded as a text
        file and replaced by a short instruction.

        Raises:
            ValidationException: When nothing is left to send.
        """
        # [role, text] pairs, one per output line.
        segments: List[List[str]] = []
        all_file_attachments: List[str] = []
        last_role = None
        last_content = ""
        final_index = len(messages) - 1

        for index, message in enumerate(messages):
            role = "assistant" if message.get("role") == "assistant" else "user"
            # Positional check: an equality test would also flag earlier
            # messages with identical content and re-upload their images.
            is_last_message = index == final_index
            if is_last_message and "content" in message:
                all_file_attachments.extend(
                    self.extract_image_attachments(message["content"], model)
                )
            text_content = self.process_content_item(message.get("content", ""))
            if not (text_content or (is_last_message and all_file_attachments)):
                continue
            if role == last_role and text_content:
                # Merge by rewriting the last segment, not by searching for the
                # role header, which may also occur inside the message text.
                last_content += "\n" + text_content
                segments[-1][1] = last_content
            else:
                segments.append([role, text_content or "[图片]"])
                last_content = text_content
                last_role = role

        formatted_messages = "".join(
            f"{segment_role.upper()}: {segment_text}\n"
            for segment_role, segment_text in segments
        )
        # Compare the transcript length itself against the limit.
        requires_file_upload = len(formatted_messages) >= MESSAGE_LENGTH_LIMIT

        if requires_file_upload:
            last_message = messages[-1] if messages else {}
            final_role = (
                "assistant" if last_message.get("role") == "assistant" else "user"
            )
            last_text = self.process_content_item(last_message.get("content", ""))
            final_content = f"{final_role.upper()}: {last_text or '[图片]'}"
            try:
                file_id = self.file_upload_manager.upload_text_file(
                    formatted_messages, model
                )
                if file_id:
                    all_file_attachments.insert(0, file_id)
                    formatted_messages = "基于txt文件内容进行回复:"
            except Exception as e:
                # Best effort: fall back to sending only the final message.
                print(f"Failed to upload conversation file: {e}", "MessageProcessor")
                formatted_messages = final_content

        if not formatted_messages.strip():
            if requires_file_upload:
                formatted_messages = "基于txt文件内容进行回复:"
            else:
                raise ValidationException("Message content is empty after processing")

        return ProcessedMessage(
            content=formatted_messages.strip(),
            file_attachments=all_file_attachments[:MAX_FILE_ATTACHMENTS],
            requires_file_upload=requires_file_upload,
        )
@dataclass
class ChatRequestConfig:
    """Settings for one chat request, consumed by ``build_request_payload``.

    Declared as a dataclass because instances are created with keyword
    arguments; without the decorator the class has no ``__init__`` that
    accepts them.
    """

    # Sent as "modelName" in the request payload.
    model_name: str
    # Sent as "message".
    message: str
    # Sent as "fileAttachments".
    file_attachments: List[str]
    # Drives "disableSearch" (inverted) and the search tool overrides.
    enable_search: bool
    # Drives "enableImageGeneration" and the "imageGen" tool override.
    enable_image_generation: bool
    # Sent as "isReasoning".
    enable_reasoning: bool
    # Sent as "deepsearchPreset".
    deepsearch_preset: str
    # Sent as "temporary".
    temporary_conversation: bool
class GrokApiClient:
    """Validates chat requests, builds Grok payloads and posts them."""

    def __init__(
        self, config: ConfigurationManager, token_manager: ThreadSafeTokenManager
    ):
        """Keep the collaborators and create the upload and message helpers."""
        self.config = config
        self.token_manager = token_manager
        self.file_upload_manager = FileUploadManager(config, token_manager)
        self.message_processor = MessageContentProcessor(self.file_upload_manager)

    def validate_model_and_request(
        self, model: str, request_data: Dict[str, Any]
    ) -> str:
        """Return the configured name for ``model`` or reject the request.

        Raises:
            ValidationException: For an unknown model, or for a streaming
                image-generation request when neither image-host key is set.
        """
        if model not in self.config.models:
            raise ValidationException(f"Unsupported model: {model}")

        wants_streamed_image = model in (
            "grok-4-imageGen",
            "grok-3-imageGen",
        ) and request_data.get("stream", False)
        if wants_streamed_image:
            # Lazy ``or`` keeps the original lookup order: the second key is
            # only read when the first one is unset.
            has_image_host = self.config.get("API.PICGO_KEY") or self.config.get(
                "API.TUMY_KEY"
            )
            if not has_image_host:
                raise ValidationException(
                    "Image generation models require PICGO or TUMY API key for streaming"
                )
        return self.config.models[model]

    def determine_search_and_generation_settings(self, model: str) -> tuple:
        """Map a model name to its feature switches.

        Returns:
            ``(enable_search, enable_image_generation, enable_reasoning,
            deepsearch_preset)``.
        """
        search_models = (
            "grok-4-deepsearch",
            "grok-3-search",
            "grok-3-deepsearch",
            "grok-3-deepersearch",
        )
        image_models = ("grok-4-imageGen", "grok-3-imageGen")
        reasoning_models = (
            "grok-3-reasoning",
            "grok-4-reasoning",
            "grok-3-deepsearch",
            "grok-3-deepersearch",
            "grok-4-deepsearch",
        )

        if model == "grok-3-deepsearch":
            preset = "default"
        elif model == "grok-3-deepersearch":
            preset = "deeper"
        else:
            preset = ""

        return (
            model in search_models,
            model in image_models,
            model in reasoning_models,
            preset,
        )

    def validate_message_requirements(
        self, model: str, messages: List[Dict[str, Any]]
    ) -> None:
        """Require a non-empty, user-terminated history for single-turn models.

        Raises:
            ValidationException: If the history is empty or does not end with
                a ``user`` message.
        """
        single_turn_models = ("grok-4-imageGen", "grok-3-imageGen", "grok-3-deepsearch")
        if model not in single_turn_models:
            return
        if not messages:
            raise ValidationException("Messages cannot be empty")
        if messages[-1].get("role") != "user":
            raise ValidationException(
                f"Model {model} requires the last message to be from user"
            )

    def prepare_chat_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate ``request_data`` and turn it into a Grok request payload.

        Any failure is printed and re-raised unchanged.
        """
        try:
            model = str(request_data.get("model"))
            messages = request_data.get("messages", [])

            normalized_model = self.validate_model_and_request(model, request_data)
            self.validate_message_requirements(model, messages)

            settings = self.determine_search_and_generation_settings(model)
            search_on, image_gen_on, reasoning_on, preset = settings

            # These models only ever receive the latest message.
            if model in ("grok-4-imageGen", "grok-3-imageGen", "grok-3-deepsearch"):
                messages = [messages[-1]]

            processed = self.message_processor.process_messages(messages, model)
            temporary = self.config.get("API.IS_TEMP_CONVERSATION", False)
            request_config = ChatRequestConfig(
                model_name=normalized_model,
                message=processed.content,
                file_attachments=processed.file_attachments,
                enable_search=search_on,
                enable_image_generation=image_gen_on,
                enable_reasoning=reasoning_on,
                deepsearch_preset=preset,
                temporary_conversation=temporary,
            )
            return self.build_request_payload(request_config)
        except Exception as e:
            print(f"Failed to prepare chat request: {e}")
            raise

    def build_request_payload(self, config: ChatRequestConfig) -> Dict[str, Any]:
        """Render ``config`` as the JSON-ready body of a new-conversation call."""
        search = config.enable_search
        image_gen = config.enable_image_generation
        tool_overrides = {
            "imageGen": image_gen,
            "webSearch": search,
            "xSearch": search,
            "xMediaSearch": search,
            "trendsSearch": search,
            "xPostAnalyze": search,
        }
        return {
            "temporary": config.temporary_conversation,
            "modelName": config.model_name,
            "message": config.message,
            "fileAttachments": config.file_attachments,
            "imageAttachments": [],
            "disableSearch": not search,
            "enableImageGeneration": image_gen,
            "returnImageBytes": False,
            "returnRawGrokInXaiRequest": False,
            "enableImageStreaming": False,
            "imageGenerationCount": 1,
            "forceConcise": False,
            "toolOverrides": tool_overrides,
            "enableSideBySide": True,
            "sendFinalMetadata": True,
            "customPersonality": "",
            "deepsearchPreset": config.deepsearch_preset,
            "isReasoning": config.enable_reasoning,
            "disableTextFollowUps": True,
        }

    def make_request(
        self, payload: Dict[str, Any], model: str, stream: bool = False
    ) -> Tuple[requests.Response, str]:
        """POST ``payload`` to the new-conversation endpoint.

        Returns:
            The HTTP response together with the auth token that was used.

        Raises:
            TokenException: If no token can be obtained for ``model``.
            GrokApiException: If the HTTP call itself raises.
        """
        auth_token = self.token_manager.get_token_for_model(model)
        if not auth_token:
            # Distinguish "none configured" from "all currently exhausted".
            if self.token_manager.get_token_count_for_model(model) == 0:
                raise TokenException(
                    f"No tokens available for model: {model}. Please add tokens or check configuration."
                )
            raise TokenException(
                f"All tokens for model {model} are currently rate limited. Please try again later."
            )

        cf_clearance = self.config.get("SERVER.CF_CLEARANCE", "")
        cookie = auth_token if not cf_clearance else f"{auth_token};{cf_clearance}"
        proxy_config = UtilityFunctions.get_proxy_configuration(
            self.config.get("API.PROXY")
        )

        print(f"Making request to Grok API for model: {model}")
        try:
            url = f"{self.config.get('API.BASE_URL')}/rest/app-chat/conversations/new"
            request_headers = {
                **get_dynamic_headers("POST", "/rest/app-chat/conversations/new"),
                "Cookie": cookie,
            }
            response = curl_requests.post(
                url,
                headers=request_headers,
                data=json.dumps(payload),
                impersonate="chrome133a",
                stream=stream,
                timeout=120,
                **proxy_config,
            )
            print(f"Response status: {response.status_code}")
            return response, auth_token  # type: ignore
        except Exception as e:
            print(f"HTTP request failed: {e}")
            raise GrokApiException(f"HTTP request failed: {e}", "REQUEST_FAILED") from e
@dataclass
class ProcessingResult:
    """Outcome of processing one model response fragment.

    Declared as a dataclass because the response processors build instances
    with keyword arguments such as ``ProcessingResult(new_state=...)``.
    """

    # Text to emit for this fragment, if any.
    token: Optional[str] = None
    # Image reference reported by the image-generation path, if any.
    image_url: Optional[str] = None
    # State the caller should carry into the next fragment; None leaves the
    # caller's current state untouched.
    new_state: Optional["ProcessingState"] = None
    # True when the caller should drop this fragment entirely.
    should_skip: bool = False
class ModelResponseProcessor:
    """Stateless response processor for the different model types.

    Each call receives the current ``ProcessingState`` and returns a
    ``ProcessingResult`` carrying the text to emit and the state to use next.
    """

    def __init__(self, config: ConfigurationManager):
        """Initialize response processor."""
        self.config = config

    def process_response(
        self, response_data: Dict[str, Any], model: str, current_state: ProcessingState
    ) -> ProcessingResult:
        """Process one response fragment for ``model``.

        Image-generation markers are handled first for every model. The rest
        is delegated to the model-specific handler, and unknown models get
        their token passed through. Any exception is printed and answered
        with a plain token pass-through, so one bad fragment cannot abort the
        stream.
        """
        try:
            if response_data.get("doImgGen") or response_data.get(
                "imageAttachmentInfo"
            ):
                new_state = current_state.with_image_generation(True)
                return ProcessingResult(new_state=new_state)

            if current_state.is_generating_image:
                cached_response = response_data.get("cachedImageGenerationResponse")
                if cached_response and not current_state.image_generation_phase:
                    image_url = cached_response.get("imageUrl")
                    if image_url:
                        new_state = current_state.with_image_generation(True, 1)
                        return ProcessingResult(
                            image_url=image_url, new_state=new_state
                        )

            handlers = {
                "grok-3": self._process_grok_3_response,
                "grok-3-search": self._process_grok_3_search_response,
                "grok-3-deepsearch": self._process_deepsearch_response,
                "grok-3-deepersearch": self._process_deepsearch_response,
                "grok-4-deepsearch": self._process_deepsearch_response,
                "grok-3-reasoning": self._process_reasoning_response,
                "grok-4": self._process_grok_4_response,
                "grok-4-reasoning": self._process_grok_4_reasoning_response,
            }
            handler = handlers.get(model)
            if handler is None:
                return ProcessingResult(token=self._plain_token(response_data))
            return handler(response_data, current_state)
        except Exception as e:
            print(f"Error processing {model} response: {e}", "ResponseProcessor")
            return ProcessingResult(token=self._plain_token(response_data))

    def _plain_token(self, response_data: Dict[str, Any]) -> Optional[str]:
        """Return the fragment's token as text, or None when it is falsy."""
        token = response_data.get("token")
        return self._transform_artifacts(token) if token else None

    def _tagged_token(self, tag: str, response_data: Dict[str, Any]) -> str:
        """Return ``tag`` followed by the fragment's token text.

        The token is coerced through ``_transform_artifacts`` first. Plain
        ``tag + token`` raised TypeError for a null or non-string token, and
        the caller's broad ``except`` then dropped both the tag and the
        thinking-state transition.
        """
        return tag + self._transform_artifacts(response_data.get("token"))

    def _process_grok_3_response(
        self, response_data: Dict[str, Any], current_state: ProcessingState
    ) -> ProcessingResult:
        """Process Grok-3 model response: plain token pass-through."""
        return ProcessingResult(
            token=self._plain_token(response_data), new_state=current_state
        )

    def _process_grok_3_search_response(
        self, response_data: Dict[str, Any], current_state: ProcessingState
    ) -> ProcessingResult:
        """Process Grok-3 search model response.

        Search results, when present and enabled by ``SHOW_SEARCH_RESULTS``,
        are emitted inside a ``<think>`` block instead of the token.
        """
        if response_data.get("webSearchResults") and self.config.get(
            "SHOW_SEARCH_RESULTS", True
        ):
            search_results = UtilityFunctions.organize_search_results(
                response_data["webSearchResults"]
            )
            token = f"\r\n<think>{search_results}</think>\r\n"
            return ProcessingResult(
                token=self._transform_artifacts(token), new_state=current_state
            )
        return ProcessingResult(
            token=self._plain_token(response_data), new_state=current_state
        )

    def _process_deepsearch_response(
        self, response_data: Dict[str, Any], current_state: ProcessingState
    ) -> ProcessingResult:
        """Process deep search model responses with thinking state management.

        Fragments carrying a ``messageStepId`` are treated as thinking steps.
        They are skipped unless ``SHOW_THINKING`` is set, otherwise wrapped in
        ``<think>``/``</think>`` as the state flips.
        """
        show_thinking = self.config.get("SHOW_THINKING", False)
        has_step_id = bool(response_data.get("messageStepId"))
        message_tag = response_data.get("messageTag", "")
        thinking = current_state.is_thinking

        if has_step_id and not show_thinking:
            return ProcessingResult(should_skip=True, new_state=current_state)

        if has_step_id and not thinking:
            # First thinking step: open the block.
            return ProcessingResult(
                token=self._tagged_token("<think>", response_data),
                new_state=current_state.with_thinking(True),
            )

        if not has_step_id and thinking and message_tag == "final":
            # First final fragment after thinking: close the block.
            return ProcessingResult(
                token=self._tagged_token("</think>", response_data),
                new_state=current_state.with_thinking(False),
            )

        if (has_step_id and thinking and message_tag == "assistant") or (
            message_tag == "final"
        ):
            return ProcessingResult(
                token=self._plain_token(response_data), new_state=current_state
            )

        if thinking and isinstance(response_data.get("token"), dict):
            token_dict = response_data.get("token", {})
            if token_dict.get("action") == "webSearch":
                # Show the search query the model issued.
                query = token_dict.get("action_input", {}).get("query", "")
                processed_token = self._transform_artifacts(query) if query else None
                return ProcessingResult(token=processed_token, new_state=current_state)

        if thinking and response_data.get("webSearchResults"):
            search_results = UtilityFunctions.organize_search_results(
                response_data["webSearchResults"]
            )
            return ProcessingResult(
                token=self._transform_artifacts(search_results),
                new_state=current_state,
            )

        return ProcessingResult(new_state=current_state)

    def _process_reasoning_response(
        self, response_data: Dict[str, Any], current_state: ProcessingState
    ) -> ProcessingResult:
        """Process reasoning model responses.

        ``isThinking`` fragments are skipped unless ``SHOW_THINKING`` is set;
        otherwise ``<think>``/``</think>`` is emitted when the flag flips.
        """
        show_thinking = self.config.get("SHOW_THINKING", False)
        is_thinking = response_data.get("isThinking", False)

        if is_thinking and not show_thinking:
            return ProcessingResult(should_skip=True, new_state=current_state)

        if is_thinking and not current_state.is_thinking:
            return ProcessingResult(
                token=self._tagged_token("<think>", response_data),
                new_state=current_state.with_thinking(True),
            )

        if not is_thinking and current_state.is_thinking:
            return ProcessingResult(
                token=self._tagged_token("</think>", response_data),
                new_state=current_state.with_thinking(False),
            )

        return ProcessingResult(
            token=self._plain_token(response_data), new_state=current_state
        )

    def _process_grok_4_response(
        self, response_data: Dict[str, Any], current_state: ProcessingState
    ) -> ProcessingResult:
        """Process Grok-4 model response: thinking fragments are always skipped."""
        if response_data.get("isThinking"):
            return ProcessingResult(should_skip=True, new_state=current_state)
        return ProcessingResult(
            token=self._plain_token(response_data), new_state=current_state
        )

    def _process_grok_4_reasoning_response(
        self, response_data: Dict[str, Any], current_state: ProcessingState
    ) -> ProcessingResult:
        """Process Grok-4 reasoning model response.

        Like the Grok-3 reasoning handler, but the block only opens on an
        ``assistant``-tagged fragment and only closes on a ``final`` one.
        """
        show_thinking = self.config.get("SHOW_THINKING", False)
        is_thinking = response_data.get("isThinking", False)
        message_tag = response_data.get("messageTag", "")

        if is_thinking and not show_thinking:
            return ProcessingResult(should_skip=True, new_state=current_state)

        if is_thinking and not current_state.is_thinking and message_tag == "assistant":
            return ProcessingResult(
                token=self._tagged_token("<think>", response_data),
                new_state=current_state.with_thinking(True),
            )

        if not is_thinking and current_state.is_thinking and message_tag == "final":
            return ProcessingResult(
                token=self._tagged_token("</think>", response_data),
                new_state=current_state.with_thinking(False),
            )

        return ProcessingResult(
            token=self._plain_token(response_data), new_state=current_state
        )

    def _transform_artifacts(self, text: Any) -> str:
        """Artifact transformation now handled at streaming level - return unchanged.

        Falsy input becomes ``""``; non-string input is converted with ``str``.
        """
        if not text:
            return ""
        return text if isinstance(text, str) else str(text)
class ResponseImageHandler:
    """Handles image responses and uploads to image hosting services."""

    def __init__(self, config: ConfigurationManager):
        """Initialize image handler."""
        self.config = config

    def handle_image_response(self, image_url: str) -> str:
        """Download a generated image and return markdown that displays it.

        With ``API.PICGO_KEY`` or ``API.TUMY_KEY`` configured the image is
        re-hosted there (PicGo takes precedence). Otherwise it is inlined as a
        base64 ``data:`` URI.

        Args:
            image_url: Asset path, appended to ``https://assets.grok.com/``.

        Returns:
            A markdown image, or the upload helper's failure message.

        Raises:
            GrokApiException: If the image cannot be downloaded.
        """
        image_data = self._download_image(image_url)

        picgo_key = self.config.get("API.PICGO_KEY")
        tumy_key = self.config.get("API.TUMY_KEY")
        if not picgo_key and not tumy_key:
            base64_image = base64.b64encode(image_data).decode("utf-8")
            content_type = "image/jpeg"
            # Previously an empty string was returned here, which threw the
            # encoded image away.
            return f"![image](data:{content_type};base64,{base64_image})"

        print("Uploading to image hosting service")
        if picgo_key:
            return self._upload_to_picgo(image_data, picgo_key)
        return self._upload_to_tumy(image_data, tumy_key)

    def _download_image(self, image_url: str) -> bytes:
        """Fetch the image bytes, retrying with a growing delay.

        The final failure is raised from outside the ``try`` body, so it is
        no longer caught by the retry handler, counted twice and replaced by
        a generic "no data" error after an extra sleep.
        """
        max_retries = 2
        image_data = None
        for attempt in range(1, max_retries + 1):
            try:
                proxy_config = UtilityFunctions.get_proxy_configuration(
                    self.config.get("API.PROXY")
                )
                response = curl_requests.get(
                    f"https://assets.grok.com/{image_url}",
                    headers=get_dynamic_headers("GET", f"/assets/{image_url}"),
                    impersonate="chrome133a",
                    timeout=60,
                    **proxy_config,
                )
            except Exception as error:
                if attempt == max_retries:
                    print(f"Image retrieval failed: {error}")
                    raise
            else:
                if response.status_code == 200:
                    image_data = response.content
                    break
                if attempt == max_retries:
                    raise GrokApiException(
                        f"Failed to retrieve image after {max_retries} attempts: {response.status_code}",
                        "IMAGE_RETRIEVAL_FAILED",
                    )
            # API.RETRY_TIME is divided by 1000 before sleeping; the delay
            # scales with the attempt number.
            time.sleep(self.config.get("API.RETRY_TIME", 1000) / 1000 * attempt)

        if not image_data:
            raise GrokApiException("No image data retrieved", "NO_IMAGE_DATA")
        return image_data

    def _upload_to_picgo(self, image_data: bytes, api_key: str) -> str:
        """Upload image to PicGo service.

        Returns:
            A markdown image pointing at the hosted URL, or a failure message.
            Errors are reported in the returned text rather than raised.
        """
        try:
            files = {"source": ("image.jpg", image_data, "image/jpeg")}
            headers = {"X-API-Key": api_key}
            response = requests.post(
                "https://www.picgo.net/api/1/upload",
                files=files,
                headers=headers,
                timeout=30,
            )
            if response.status_code == 200:
                result = response.json()
                image_url = result.get("image", {}).get("url", "")
                if image_url:
                    print("Image uploaded to PicGo successfully", "ImageHandler")
                    return f"![image]({image_url})"
            print(f"PicGo upload failed: {response.status_code}", "ImageHandler")
            return "Image generation failed - PicGo upload error"
        except Exception as e:
            print(f"PicGo upload exception: {e}")
            return "Image generation failed - PicGo service error"

    def _upload_to_tumy(self, image_data: bytes, api_key: str) -> str:
        """Upload image to TuMy service.

        Returns:
            A markdown image pointing at the hosted URL, or a failure message.
            Errors are reported in the returned text rather than raised.
        """
        try:
            files = {"file": ("image.jpg", image_data, "image/jpeg")}
            headers = {
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key}",
            }
            response = requests.post(
                "https://tu.my/api/v1/upload", files=files, headers=headers, timeout=30
            )
            if response.status_code == 200:
                result = response.json()
                image_url = result.get("data", {}).get("links", {}).get("url", "")
                if image_url:
                    print("Image uploaded to TuMy successfully", "ImageHandler")
                    return f"![image]({image_url})"
            print(f"TuMy upload failed: {response.status_code}", "ImageHandler")
            return "Image generation failed - TuMy upload error"
        except Exception as e:
            print(f"TuMy upload exception: {e}")
            return "Image generation failed - TuMy service error"
class StreamingTagFilter:
    """Configurable streaming filter for removing specified XML/HTML tags.

    Only the tags themselves (opening, closing, self-closing) are removed;
    the text between them is kept. Input may arrive in arbitrary chunks: a
    trailing fragment that could still become a filtered tag is held back in
    ``buffer`` until more data arrives or ``finalize`` is called.
    """

    def __init__(self, filtered_tags: Optional[List[str]] = None):
        """
        Initialize filter with configurable tag list.

        Args:
            filtered_tags: List of tag names to filter (case-insensitive),
                e.g., ["xaiArtifact", "someOtherTag"]. ``None`` or an empty
                list selects the default ``["xaiArtifact"]``.
        """
        self.buffer = ""
        self.filtered_tags = [tag.lower() for tag in (filtered_tags or ["xaiArtifact"])]

    def _starts_with_special_tag(self, text: str) -> bool:
        """Check whether ``text`` is, or could still grow into, a filtered tag."""
        if not text.startswith("<"):
            return False
        lowered = text.lower()
        opener = "</" if lowered.startswith("</") else "<"
        remainder = lowered[len(opener):]
        # Either the text so far is a prefix of a tag name, or it already
        # contains the full name (attributes may follow).
        return any(
            tag.startswith(remainder) or lowered.startswith(opener + tag)
            for tag in self.filtered_tags
        )

    def _extract_tag_name(self, tag_text: str) -> str:
        """Extract the lower-cased tag name from a complete tag."""
        if not tag_text.startswith("<"):
            return ""
        content = tag_text[2:] if tag_text.startswith("</") else tag_text[1:]
        content = content.rstrip("/>")
        words = content.split()
        return words[0].lower() if words else content.lower()

    def filter_chunk(self, chunk: str) -> str:
        """Append ``chunk`` to the buffer and return the text that is safe to emit.

        Complete filtered tags are dropped. Other ``<...>`` spans are emitted
        only up to the next ``<``, so a filtered tag that follows a stray
        ``<`` (as in ``"a <b and <xaiArtifact>"``) is still recognised rather
        than swallowed into the preceding span.
        """
        if not chunk:
            return ""
        self.buffer += chunk
        text = self.buffer
        size = len(text)
        pieces: List[str] = []
        pos = 0
        while pos < size:
            lt = text.find("<", pos)
            if lt == -1:
                pieces.append(text[pos:])
                pos = size
                break
            if lt > pos:
                pieces.append(text[pos:lt])
            pos = lt

            gt = text.find(">", lt + 1)
            if gt == -1:
                # Unterminated "<...": hold it back only if it may still turn
                # into a filtered tag, otherwise the "<" is literal text.
                if self._starts_with_special_tag(text[lt:]):
                    break
                pieces.append("<")
                pos = lt + 1
                continue

            if self._extract_tag_name(text[lt : gt + 1]) in self.filtered_tags:
                pos = gt + 1
                continue

            # Not a filtered tag: emit up to the next "<" (or the end of the
            # tag) and rescan from there.
            nested = text.find("<", lt + 1, gt)
            end = nested if nested != -1 else gt + 1
            pieces.append(text[lt:end])
            pos = end

        self.buffer = text[pos:]
        return "".join(pieces)

    def finalize(self) -> str:
        """Get any remaining content from buffer when stream ends."""
        remaining = self.buffer
        self.buffer = ""
        return remaining
@dataclass
class StreamingContext:
    """Context for streaming response processing.

    Declared as a dataclass so that ``field(default_factory=...)`` gives each
    context its own ``ProcessingState``; on a plain class the attribute would
    be a shared ``Field`` object.
    """

    # Model name passed to the processor and echoed in emitted chunks.
    model: str
    # Converts raw response fragments into ProcessingResult objects.
    processor: ModelResponseProcessor
    # Turns image references into displayable content.
    image_handler: ResponseImageHandler
    # Strips configured tags from the emitted text.
    tag_filter: StreamingTagFilter
    # Initial processing state for the response.
    state: ProcessingState = field(default_factory=ProcessingState)
class StreamProcessor:
    """Handles streaming response processing.

    Both entry points take no ``self``, so they are declared as static
    methods and work whether called on the class or on an instance.
    """

    @staticmethod
    def process_non_stream_response(
        response: requests.Response, context: StreamingContext
    ) -> str:
        """Process non-streaming response and return complete content.

        Returns:
            The concatenated filtered text. If a fragment carries an image,
            the image content alone is returned. If a line reports an error,
            a JSON ``RateLimitError`` marker is returned.

        Raises:
            Exception: Any failure outside per-line processing is printed and
                re-raised.
        """
        print("Processing non-streaming response")
        parts: List[str] = []
        current_state = context.state
        try:
            for chunk in response.iter_lines():
                if not chunk:
                    continue
                try:
                    line_data = json.loads(chunk.decode("utf-8").strip())
                    if line_data.get("error"):
                        print(
                            f"API error: {json.dumps(line_data, indent=2)}",
                            "StreamProcessor",
                        )
                        return json.dumps({"error": "RateLimitError"}) + "\n\n"

                    response_data = line_data.get("result", {}).get("response")
                    if not response_data:
                        continue

                    result = context.processor.process_response(
                        response_data, context.model, current_state
                    )
                    if result.new_state:
                        current_state = result.new_state
                    if result.should_skip:
                        continue
                    if result.token:
                        # Apply streaming tag filter to remove configured tags
                        filtered_token = context.tag_filter.filter_chunk(result.token)
                        if filtered_token:
                            parts.append(filtered_token)
                    if result.image_url:
                        # An image ends processing; text gathered so far is
                        # not included in the return value.
                        return context.image_handler.handle_image_response(
                            result.image_url
                        )
                except json.JSONDecodeError:
                    # Lines that are not JSON are ignored.
                    continue
                except Exception as e:
                    print(f"Error processing response line: {e}", "StreamProcessor")
                    continue

            # Finalize the tag filter to get any remaining content
            final_content = context.tag_filter.finalize()
            if final_content:
                parts.append(final_content)
            return "".join(parts)
        except Exception as e:
            print(f"Non-stream response processing failed: {e}", "StreamProcessor")
            raise

    @staticmethod
    def process_stream_response(
        response: requests.Response, context: StreamingContext
    ) -> Iterator[str]:
        """Process streaming response and yield formatted chunks.

        Yields:
            ``data: <json>`` server-sent-event strings, one per emitted token
            or image, followed by ``data: [DONE]``. A line that reports an
            error yields a JSON ``RateLimitError`` marker and ends the stream.
            A failure outside per-line processing yields one error event.
        """
        print("Processing streaming response")
        current_state = context.state
        try:
            for chunk in response.iter_lines():
                if not chunk:
                    continue
                try:
                    line_data = json.loads(chunk.decode("utf-8").strip())
                    if line_data.get("error"):
                        print(
                            f"API error: {json.dumps(line_data, indent=2)}",
                            "StreamProcessor",
                        )
                        yield json.dumps({"error": "RateLimitError"}) + "\n\n"
                        return

                    response_data = line_data.get("result", {}).get("response")
                    if not response_data:
                        continue

                    result = context.processor.process_response(
                        response_data, context.model, current_state
                    )
                    if result.new_state:
                        current_state = result.new_state
                    if result.should_skip:
                        continue
                    if result.token:
                        # Apply streaming tag filter to remove configured tags
                        filtered_token = context.tag_filter.filter_chunk(result.token)
                        # Only yield if we have content after filtering
                        if filtered_token:
                            formatted_response = (
                                MessageProcessor.create_chat_completion_chunk(
                                    filtered_token, context.model
                                )
                            )
                            yield f"data: {json.dumps(formatted_response)}\n\n"
                    if result.image_url:
                        image_content = context.image_handler.handle_image_response(
                            result.image_url
                        )
                        formatted_response = (
                            MessageProcessor.create_chat_completion_chunk(
                                image_content, context.model
                            )
                        )
                        yield f"data: {json.dumps(formatted_response)}\n\n"
                except json.JSONDecodeError:
                    # Lines that are not JSON are ignored.
                    continue
                except Exception as e:
                    print(f"Error processing stream line: {e}", "StreamProcessor")
                    continue

            # Finalize the tag filter to get any remaining content
            final_content = context.tag_filter.finalize()
            if final_content:
                formatted_response = MessageProcessor.create_chat_completion_chunk(
                    final_content, context.model
                )
                yield f"data: {json.dumps(formatted_response)}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            print(f"Stream processing failed: {e}")
            error_response = MessageProcessor.create_error_response(str(e))
            yield f"data: {json.dumps(error_response)}\n\n"
class MessageProcessor:
    """Creates properly formatted chat completion responses.

    All builders are static: they take no ``self``, and the decorator makes
    them safe to call on an instance as well as on the class.
    """

    @staticmethod
    def create_chat_completion(message: str, model: str) -> Dict[str, Any]:
        """Create a complete chat completion response.

        Args:
            message: Assistant text placed in the single choice.
            model: Model name echoed in the response.

        Returns:
            A ``chat.completion`` dict with a fresh id and timestamp.
        """
        choice = {
            "index": 0,
            "message": {"role": "assistant", "content": message},
            "finish_reason": "stop",
        }
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [choice],
            "usage": None,
        }

    @staticmethod
    def create_chat_completion_chunk(message: str, model: str) -> Dict[str, Any]:
        """Create a streaming chat completion chunk.

        Args:
            message: Text placed in the chunk's ``delta.content``.
            model: Model name echoed in the chunk.

        Returns:
            A ``chat.completion.chunk`` dict with a fresh id and timestamp.
        """
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{"index": 0, "delta": {"content": message}}],
        }

    @staticmethod
    def create_error_response(error_message: str) -> Dict[str, Any]:
        """Create an error response.

        Args:
            error_message: Text reported under ``error.message``.

        Returns:
            A chunk-shaped dict whose ``error`` has type ``server_error``.
        """
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "error": {"message": error_message, "type": "server_error"},
        }
class AuthenticationService:
    """Handles authentication and authorization."""

    def __init__(
        self, config: "ConfigurationManager", token_manager: "ThreadSafeTokenManager"
    ):
        """Initialize authentication service."""
        self.config = config
        self.token_manager = token_manager

    def validate_api_key(self, auth_header: Optional[str]) -> str:
        """Extract the credential from an ``Authorization`` header value.

        Only a leading ``Bearer `` scheme is removed, matched
        case-insensitively. The previous ``str.replace`` removed that text
        anywhere in the header and missed other casings.

        Raises:
            ValidationException: If the header or the credential is missing.
        """
        if not auth_header:
            raise ValidationException("Authorization header missing")
        scheme = "bearer "
        auth_token = auth_header.strip()
        if auth_token[: len(scheme)].lower() == scheme:
            auth_token = auth_token[len(scheme):]
        auth_token = auth_token.strip()
        if not auth_token:
            raise ValidationException("API key missing")
        return auth_token

    def process_authentication(self, auth_header: Optional[str]) -> bool:
        """Authenticate a request.

        With ``API.IS_CUSTOM_SSO`` enabled the credential is registered with
        the token manager as a normal token. Otherwise it must equal
        ``API.API_KEY``.

        Returns:
            True when the request is accepted.

        Raises:
            ValidationException: For a missing or invalid credential, or any
                unexpected failure while processing it.
        """
        try:
            auth_token = self.validate_api_key(auth_header)

            if self.config.get("API.IS_CUSTOM_SSO", False):
                try:
                    credential = TokenCredential.from_raw_token(
                        auth_token, TokenType.NORMAL
                    )
                    success = self.token_manager.add_token(credential)
                    if not success:
                        # A failed add is only reported; the request is
                        # still accepted.
                        print("Failed to add custom SSO token", "AuthenticationService")
                    return True
                except Exception as e:
                    print(
                        f"Failed to process custom SSO token: {e}",
                        "AuthenticationService",
                    )
                    raise ValidationException(f"Invalid SSO token format: {e}") from e

            expected_key = self.config.get("API.API_KEY", "sk-123456")
            # Constant-time comparison, so response timing does not reveal
            # how much of the key matched. Compared as bytes because
            # compare_digest rejects non-ASCII str values.
            key_matches = isinstance(expected_key, str) and secrets.compare_digest(
                auth_token.encode("utf-8"), expected_key.encode("utf-8")
            )
            if not key_matches:
                print("Invalid API key provided", "AuthenticationService")
                raise ValidationException("Invalid API key")
            return True
        except ValidationException:
            raise
        except Exception as e:
            print(f"Authentication processing failed: {e}", "AuthenticationService")
            raise ValidationException(f"Authentication failed: {e}") from e
def create_app(config: ConfigurationManager) -> Flask:
    """Create and configure the Flask application.

    Builds the service objects (token manager, authentication service, Grok
    client, response processor, image handler) and defines the request and
    error handlers as closures over them.

    Args:
        config: Configuration source for API, admin, retry and server settings.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)
    # A fresh random key is generated on every call, so session cookies signed
    # by a previous process no longer validate after a restart.
    app.config["SECRET_KEY"] = secrets.token_urlsafe(32)
    app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)

    token_manager = ThreadSafeTokenManager(config)
    auth_service = AuthenticationService(config, token_manager)
    grok_client = GrokApiClient(config, token_manager)
    response_processor = ModelResponseProcessor(config)
    image_handler = ResponseImageHandler(config)

    # NOTE(review): no @app.route / @app.errorhandler registrations are visible
    # for the handlers below; presumably they are applied by decorators that
    # were lost from this copy of the file -- confirm against the original.

    def _bearer_token(header_value: str) -> str:
        """Return the header value with one leading ``Bearer `` prefix removed."""
        prefix = "Bearer "
        if header_value.startswith(prefix):
            return header_value[len(prefix):]
        return header_value

    def _secrets_match(supplied: Any, expected: Any) -> bool:
        """Compare two secrets in constant time; non-string values never match."""
        if not isinstance(supplied, str) or not isinstance(expected, str):
            return False
        # Bytes are compared because compare_digest rejects non-ASCII str.
        return secrets.compare_digest(
            supplied.encode("utf-8"), expected.encode("utf-8")
        )

    def _json_object() -> Optional[Dict[str, Any]]:
        """Return the JSON body as a dict, or None if absent/malformed/not an object."""
        body = request.get_json(silent=True)
        return body if isinstance(body, dict) else None

    def _api_key_is_valid() -> bool:
        """Check the request's bearer token against the configured API key."""
        supplied = _bearer_token(request.headers.get("Authorization", ""))
        return _secrets_match(supplied, config.get("API.API_KEY", "sk-123456"))

    def create_error_response(
        error_data: Union[str, Dict[str, Any]], status_code: int
    ) -> Tuple[Dict[str, Any], int]:
        """Create consistent error responses via the shared utility."""
        return UtilityFunctions.create_structured_error_response(
            error_data, status_code
        )

    def handle_validation_error(e: ValidationException) -> Tuple[Dict[str, Any], int]:
        """Map a validation failure to a 400 response."""
        return create_error_response(
            {"error": str(e), "error_code": "VALIDATION_ERROR"}, 400
        )

    def handle_token_error(e: TokenException) -> Tuple[Dict[str, Any], int]:
        """Map a token-related failure to a 429 response."""
        return create_error_response(
            {"error": str(e), "error_code": "TOKEN_ERROR"}, 429
        )

    def handle_rate_limit_error(e: RateLimitException) -> Tuple[Dict[str, Any], int]:
        """Map a rate-limit failure to a 429 response."""
        return create_error_response(
            {"error": str(e), "error_code": "RATE_LIMIT_ERROR"}, 429
        )

    def handle_grok_api_error(e: GrokApiException) -> Tuple[Dict[str, Any], int]:
        """Map an upstream API failure to a 500 response carrying its error code."""
        return create_error_response({"error": str(e), "error_code": e.error_code}, 500)

    def handle_internal_error(e) -> Tuple[Dict[str, Any], int]:
        """Log an unexpected error and return a generic 500 response."""
        print(f"Internal server error: {e}")
        return create_error_response("Internal server error", 500)

    def chat_completions():
        """Handle a chat completion request.

        Validates and authenticates the request, forwards it through the Grok
        client with retries, and returns either a server-sent-event stream or
        a single JSON completion.

        Raises:
            ValidationException: For a non-JSON, malformed, empty or non-object
                body, failed authentication, or missing ``model``/``messages``.
            RateLimitException: On HTTP 429 with no alternative token.
            GrokApiException: For upstream failures and unexpected errors.
        """
        response_status_code = 500
        try:
            if not request.is_json:
                raise ValidationException("Request must be JSON")
            # silent=True: a malformed body becomes a validation error instead
            # of an exception that would be reported as an internal error.
            data = request.get_json(silent=True)
            if data is None:
                raise ValidationException("Request body is missing or not valid JSON")
            if not data:
                raise ValidationException("Request body is empty")
            if not isinstance(data, dict):
                raise ValidationException("Request body must be a JSON object")

            auth_service.process_authentication(request.headers.get("Authorization"))

            model = data.get("model")
            messages = data.get("messages", [])
            stream = data.get("stream", False)
            if not model:
                raise ValidationException("Model parameter is required")
            if not messages:
                raise ValidationException("Messages parameter is required")

            print(f"Processing chat completion request for model: {model}")
            payload = grok_client.prepare_chat_request(data)

            response = None
            used_token = None
            retry_count = 0
            max_retries = config.get("RETRY.MAX_ATTEMPTS", MAX_RETRY_ATTEMPTS)
            while retry_count < max_retries:
                try:
                    response, used_token = grok_client.make_request(
                        payload, model, stream
                    )
                    response_status_code = response.status_code
                    if response.status_code == 200:
                        break
                    if response.status_code == 429:
                        if token_manager.get_token_count_for_model(model) > 1:
                            print("Rate limited, retrying with different token")
                            retry_count += 1
                            continue
                        raise RateLimitException(
                            "Rate limit exceeded and no alternative tokens available"
                        )

                    error_text = (
                        response.text
                        if response.text
                        else f"HTTP {response.status_code} error"
                    )
                    print(
                        f"API request failed with status {response.status_code}: {error_text}"
                    )
                    if used_token:
                        token_manager.record_token_failure(
                            model, used_token, error_text, response.status_code
                        )
                    error_data = UtilityFunctions.parse_error_response(error_text)
                    raise GrokApiException(
                        error_data.get(
                            "error",
                            f"API request failed with status {response.status_code}",
                        ),
                        error_data.get("error_code", "API_ERROR"),
                    )
                except requests.exceptions.RequestException as e:
                    retry_count += 1
                    if retry_count >= max_retries:
                        raise GrokApiException(
                            f"Request failed after {max_retries} attempts: {e}",
                            "REQUEST_FAILED",
                        )
                    # Exponential backoff: the delay doubles on each attempt.
                    delay = BASE_RETRY_DELAY * (2 ** (retry_count - 1))
                    print(f"Request failed, retrying in {delay}s: {e}")
                    time.sleep(delay)

            # Reached with a non-200 response when every retry was rate limited,
            # or with no response at all when max_retries is not positive.
            if response is None or response.status_code != 200:
                error_msg = f"Request failed with status: {response_status_code}"
                raise GrokApiException(error_msg, "REQUEST_FAILED")

            # Get configured filtered tags
            filtered_tags = config.get("FILTERED_TAGS", ["xaiArtifact"])
            context = StreamingContext(
                model=model,
                processor=response_processor,
                image_handler=image_handler,
                tag_filter=StreamingTagFilter(filtered_tags),
            )

            if stream:

                def generate():
                    """Yield SSE chunks; report a mid-stream failure as a final event."""
                    try:
                        yield from StreamProcessor.process_stream_response(
                            response, context
                        )
                    except Exception as e:
                        print(f"Stream processing error: {e}")
                        error_response = MessageProcessor.create_error_response(str(e))
                        yield f"data: {json.dumps(error_response)}\n\n"

                return Response(
                    stream_with_context(generate()),
                    content_type="text/event-stream",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
                        "X-Accel-Buffering": "no",
                    },
                )

            message_content = StreamProcessor.process_non_stream_response(
                response, context
            )
            formatted_response = MessageProcessor.create_chat_completion(
                message_content, model
            )
            return jsonify(formatted_response)
        except (
            ValidationException,
            TokenException,
            RateLimitException,
            GrokApiException,
        ):
            raise
        except Exception as e:
            print(f"Unexpected error in chat completions: {e}")
            raise GrokApiException("Internal server error", "INTERNAL_ERROR") from e

    def list_models():
        """List the configured models, each stamped with the current time."""
        current_time = int(time.time())
        models_data = [
            {
                "id": model_key,
                "object": "model",
                "created": current_time,
                "owned_by": "grok",
            }
            for model_key in config.models
        ]
        return jsonify({"object": "list", "data": models_data})

    def health_check():
        """Health check endpoint."""
        return jsonify(
            {"status": "healthy", "timestamp": int(time.time()), "version": "2.0.0"}
        )

    def index():
        """Index page with basic information."""
        return jsonify(
            {"message": "Grok API Gateway", "version": "2.0.0", "status": "running"}
        )

    def check_admin_auth() -> bool:
        """Check the admin password sent as a form field or query parameter.

        Returns:
            True only when the manager switch is on, both the supplied and the
            configured password are non-empty, and they are equal.
        """
        if not config.get("ADMIN.MANAGER_SWITCH"):
            return False
        password = request.form.get("password") or request.args.get("password")
        expected_password = config.get("ADMIN.PASSWORD")
        if not password or not expected_password:
            return False
        return _secrets_match(password, expected_password)

    def add_token():
        """Add a token (admin only).

        ``tokens`` is read from the form or, failing that, from a JSON body. It
        may be a JSON object (or a string encoding one) with ``token`` and
        optional ``type`` keys, or a plain token string.
        """
        if not check_admin_auth():
            return jsonify({"error": "Unauthorized"}), 401
        try:
            token_data = request.form.get("tokens")
            if not token_data:
                # Parsed silently so a non-JSON request yields the 400 below
                # rather than an exception.
                json_body = _json_object()
                token_data = json_body.get("tokens") if json_body else None
            if not token_data:
                return jsonify({"error": "Token data required"}), 400

            if isinstance(token_data, str):
                try:
                    parsed = json.loads(token_data)
                except json.JSONDecodeError:
                    parsed = None
                # A string that is not a JSON object is itself the raw token.
                token_dict = (
                    parsed
                    if isinstance(parsed, dict)
                    else {"token": token_data, "type": "normal"}
                )
            elif isinstance(token_data, dict):
                token_dict = token_data
            else:
                return jsonify({"error": "Token data required"}), 400

            token_string = token_dict.get("token", "")
            token_type = (
                TokenType.SUPER
                if token_dict.get("type") == "super"
                else TokenType.NORMAL
            )
            if not token_string or not isinstance(token_string, str):
                return jsonify({"error": "Token string required"}), 400

            credential = TokenCredential(token_string, token_type)
            if token_manager.add_token(credential):
                return jsonify({"message": "Token added successfully"})
            return jsonify({"error": "Failed to add token"}), 500
        except Exception as e:
            print(f"Error adding token: {e}")
            return jsonify({"error": f"Failed to add token: {e}"}), 500

    def tokens_info():
        """Get token status, remaining capacity and health (admin only)."""
        if not check_admin_auth():
            return jsonify({"error": "Unauthorized"}), 401
        try:
            return jsonify(
                {
                    "token_status": token_manager.get_token_status_map(),
                    "remaining_capacity": token_manager.get_remaining_capacity(),
                    "health_summary": token_manager.get_token_health_summary(),
                }
            )
        except Exception as e:
            print(f"Error getting token info: {e}")
            return jsonify({"error": f"Failed to get token info: {e}"}), 500

    def check_session_auth() -> bool:
        """Check session-based authentication for web manager."""
        return session.get("is_logged_in", False)

    def manager_login():
        """Render the manager login page and process login submissions."""
        if not config.get("ADMIN.MANAGER_SWITCH"):
            return redirect("/")
        if request.method == "POST":
            password = request.form.get("password")
            expected_password = config.get("ADMIN.PASSWORD")
            # Both values must be non-empty: with an unset ADMIN.PASSWORD a
            # request without a password field would otherwise compare
            # None == None and be logged in.
            if (
                password
                and expected_password
                and _secrets_match(password, expected_password)
            ):
                session["is_logged_in"] = True
                return redirect("/manager")
            return render_template("login.html", error=True)
        return render_template("login.html", error=False)

    def manager():
        """Main manager dashboard."""
        if not check_session_auth():
            return redirect("/manager/login")
        return render_template("manager.html")

    def get_manager_tokens():
        """Get token status and health via the manager API."""
        if not check_session_auth():
            return jsonify({"error": "Unauthorized"}), 401
        try:
            return jsonify(
                {
                    "token_status": token_manager.get_token_status_map(),
                    "health_summary": token_manager.get_token_health_summary(),
                }
            )
        except Exception as e:
            print(f"Error getting manager tokens: {e}")
            return jsonify({"error": str(e)}), 500

    def add_manager_token():
        """Add a token via the manager API from the JSON field ``sso``."""
        if not check_session_auth():
            return jsonify({"error": "Unauthorized"}), 401
        try:
            data = _json_object()
            if not data:
                return jsonify({"error": "JSON data required"}), 400
            sso = data.get("sso")
            if not isinstance(sso, str) or not sso.strip():
                return (
                    jsonify({"error": "SSO token is required and cannot be empty"}),
                    400,
                )
            credential = TokenCredential.from_raw_token(sso.strip(), TokenType.NORMAL)
            if token_manager.add_token(credential):
                return jsonify({"success": True})
            return jsonify({"error": "Failed to add token"}), 500
        except Exception as e:
            print(f"Error adding manager token: {e}")
            return jsonify({"error": str(e)}), 500

    def delete_manager_token():
        """Delete a token via the manager API from the JSON field ``sso``."""
        if not check_session_auth():
            return jsonify({"error": "Unauthorized"}), 401
        try:
            data = _json_object()
            if not data:
                return jsonify({"error": "JSON data required"}), 400
            sso = data.get("sso")
            # Stripped to mirror add_manager_token, which strips before storing.
            sso = sso.strip() if isinstance(sso, str) else ""
            if not sso:
                return jsonify({"error": "SSO token is required"}), 400
            token_string = f"sso-rw={sso};sso={sso}"
            if token_manager.delete_token(token_string):
                return jsonify({"success": True})
            return jsonify({"error": "Token not found or failed to delete"}), 404
        except Exception as e:
            print(f"Error deleting manager token: {e}")
            return jsonify({"error": str(e)}), 500

    def set_cf_clearance():
        """Store ``cf_clearance`` from the JSON body in ``SERVER.CF_CLEARANCE``."""
        if not check_session_auth():
            return jsonify({"error": "Unauthorized"}), 401
        try:
            data = _json_object()
            if not data:
                return jsonify({"error": "JSON data required"}), 400
            cf_clearance = data.get("cf_clearance")
            if not cf_clearance:
                return jsonify({"error": "cf_clearance is required"}), 400
            config.set("SERVER.CF_CLEARANCE", cf_clearance)
            return jsonify({"success": True})
        except Exception as e:
            print(f"Error setting CF clearance: {e}")
            return jsonify({"error": str(e)}), 500

    def get_tokens():
        """Legacy endpoint returning the token status map (API key auth)."""
        if config.get("API.IS_CUSTOM_SSO", False):
            return (
                jsonify(
                    {"error": "Custom SSO mode cannot get polling SSO token status"}
                ),
                403,
            )
        if not _api_key_is_valid():
            return jsonify({"error": "Unauthorized"}), 401
        try:
            return jsonify(token_manager.get_token_status_map())
        except Exception as e:
            print(f"Error getting tokens: {e}")
            return jsonify({"error": str(e)}), 500

    def add_token_api():
        """Add a token from the JSON field ``sso`` (API key auth)."""
        if config.get("API.IS_CUSTOM_SSO", False):
            return jsonify({"error": "Custom SSO mode cannot add SSO tokens"}), 403
        if not _api_key_is_valid():
            return jsonify({"error": "Unauthorized"}), 401
        try:
            data = _json_object()
            if not data:
                return jsonify({"error": "JSON data required"}), 400
            sso = data.get("sso")
            if not isinstance(sso, str) or not sso.strip():
                return (
                    jsonify({"error": "SSO token is required and cannot be empty"}),
                    400,
                )
            credential = TokenCredential.from_raw_token(sso.strip(), TokenType.NORMAL)
            if token_manager.add_token(credential):
                return jsonify({"message": "Token added successfully"})
            return jsonify({"error": "Failed to add token"}), 500
        except Exception as e:
            print(f"Error adding token via API: {e}")
            return jsonify({"error": str(e)}), 500

    return app
def initialize_application(
    config: ConfigurationManager, token_manager: ThreadSafeTokenManager
) -> None:
    """Load tokens from the environment and start the statsig manager.

    Comma-separated entries in ``SSO`` are added as normal tokens and entries
    in ``SSO_SUPER`` as super tokens; a failure on one entry is logged and the
    rest are still processed. The statsig manager is then initialised with the
    ``API.PROXY`` setting and its real-IP check is run on a best-effort basis.

    Args:
        config: Configuration source for ``API.IS_CUSTOM_SSO`` and ``API.PROXY``.
        token_manager: Token pool that receives the environment tokens.
    """
    tokens_added = 0

    # An unset variable yields a single empty entry, which the guard skips.
    for raw_entry in os.environ.get("SSO", "").split(","):
        candidate = raw_entry.strip()
        if not candidate:
            continue
        try:
            normal_credential = TokenCredential(candidate, TokenType.NORMAL)
            if token_manager.add_token(normal_credential, is_initialization=True):
                tokens_added += 1
        except Exception as e:
            print(f"Failed to add normal token: {e}", "Initialization")

    for raw_entry in os.environ.get("SSO_SUPER", "").split(","):
        candidate = raw_entry.strip()
        if not candidate:
            continue
        try:
            super_credential = TokenCredential(candidate, TokenType.SUPER)
            if token_manager.add_token(super_credential, is_initialization=True):
                tokens_added += 1
        except Exception as e:
            print(f"Failed to add super token: {e}", "Initialization")

    if tokens_added <= 0:
        print("No tokens loaded during initialization")
        if not config.get("API.IS_CUSTOM_SSO", False):
            print(
                "Set SSO or SSO_SUPER environment variables, or enable IS_CUSTOM_SSO",
                "Initialization",
            )
    else:
        print(f"Successfully loaded {tokens_added} tokens")

    proxy_url = config.get("API.PROXY")
    initialize_statsig_manager(proxy_url=proxy_url)
    if not proxy_url:
        print("StatsigManager initialized without proxy")
    else:
        print(f"StatsigManager initialized with proxy: {proxy_url}", "Initialization")

    # The IP check is diagnostic only; any failure is logged and ignored.
    try:
        real_ip = get_statsig_manager().check_real_ip_sync()
        lookup_failed = not real_ip or real_ip in ("error", "failed", "unknown")
        if lookup_failed:
            print(f"Failed to get real IP address: {real_ip}")
        else:
            print(f"Playwright real IP address: {real_ip}")
    except Exception as e:
        print(f"Error checking real IP address: {e}")

    print("Application initialization completed")
def cleanup_resources():
    """Clean up browser resources before shutdown.

    Safe to call repeatedly: the module-level manager reference is cleared
    before cleanup runs, so later calls (for example from both ``atexit`` and
    ``main``'s ``finally``) are no-ops. Cleanup errors are logged, not raised.
    """
    global _global_statsig_manager
    manager = _global_statsig_manager
    if not manager:
        return
    # Clear first so a failing or re-entrant cleanup is never attempted twice.
    _global_statsig_manager = None
    try:
        manager.cleanup()
        print("Browser resources cleaned up successfully")
    except Exception as e:
        print(f"Error cleaning up browser resources: {e}")
def main():
    """Run the gateway: load configuration, initialise, and serve.

    Browser resources are released exactly once from the ``finally`` clause
    on every exit path; the ``atexit`` hook is kept as a fallback. Exits with
    status 1 when start-up or serving fails with an unexpected exception.
    """
    import atexit

    try:
        config = ConfigurationManager()
        # NOTE(review): create_app(config) constructs its own
        # ThreadSafeTokenManager, so tokens loaded into this instance reach the
        # app only if the manager shares state across instances -- confirm.
        token_manager = ThreadSafeTokenManager(config)
        initialize_application(config, token_manager)
        app = create_app(config)
        port = config.get("SERVER.PORT", 5200)
        print(f"Starting Grok API Gateway on port {port}")
        atexit.register(cleanup_resources)
        app.run(
            host="0.0.0.0",
            port=port,
            debug=False,
            threaded=False,
            processes=1,
        )
    except KeyboardInterrupt:
        print("Application stopped by user")
    except Exception as e:
        print(f"Application failed to start: {e}")
        # The finally clause below still runs before the interpreter exits.
        sys.exit(1)
    finally:
        cleanup_resources()


if __name__ == "__main__":
    main()