Spaces:
Paused
Paused
| import aiohttp | |
| import asyncio | |
| import json | |
| import re | |
| from typing import Dict, Any, List, Union, Optional, AsyncGenerator | |
| import time | |
| # Global cache for project IDs: {api_key: project_id} | |
| PROJECT_ID_CACHE: Dict[str, str] = {} | |
| class DirectVertexClient: | |
| """ | |
| A client that connects to Vertex AI using direct URLs instead of the SDK. | |
| Mimics the interface of genai.Client for seamless integration. | |
| """ | |
| def __init__(self, api_key: str): | |
| self.api_key = api_key | |
| self.project_id: Optional[str] = None | |
| self.base_url = "https://aiplatform.googleapis.com/v1" | |
| self.session: Optional[aiohttp.ClientSession] = None | |
| # Mimic the model_name attribute that might be accessed | |
| self.model_name = "direct_vertex_client" | |
| # Create nested structure to mimic genai.Client interface | |
| self.aio = self._AioNamespace(self) | |
| class _AioNamespace: | |
| def __init__(self, parent): | |
| self.parent = parent | |
| self.models = self._ModelsNamespace(parent) | |
| class _ModelsNamespace: | |
| def __init__(self, parent): | |
| self.parent = parent | |
| async def generate_content(self, model: str, contents: Any, config: Dict[str, Any]) -> Any: | |
| """Non-streaming content generation""" | |
| return await self.parent._generate_content(model, contents, config, stream=False) | |
| async def generate_content_stream(self, model: str, contents: Any, config: Dict[str, Any]): | |
| """Streaming content generation - returns an async generator""" | |
| # This needs to be an async method that returns the generator | |
| # to match the SDK's interface where you await the method call | |
| return self.parent._generate_content_stream(model, contents, config) | |
| async def _ensure_session(self): | |
| """Ensure aiohttp session is created""" | |
| if self.session is None: | |
| self.session = aiohttp.ClientSession() | |
| async def close(self): | |
| """Clean up resources""" | |
| if self.session: | |
| await self.session.close() | |
| self.session = None | |
| async def discover_project_id(self) -> None: | |
| """Discover project ID by triggering an intentional error""" | |
| # Check cache first | |
| if self.api_key in PROJECT_ID_CACHE: | |
| self.project_id = PROJECT_ID_CACHE[self.api_key] | |
| print(f"INFO: Using cached project ID: {self.project_id}") | |
| return | |
| await self._ensure_session() | |
| # Use a non-existent model to trigger error | |
| error_url = f"{self.base_url}/publishers/google/models/gemini-2.7-pro-preview-05-06:streamGenerateContent?key={self.api_key}" | |
| try: | |
| # Send minimal request to trigger error | |
| payload = { | |
| "contents": [{"role": "user", "parts": [{"text": "test"}]}] | |
| } | |
| async with self.session.post(error_url, json=payload) as response: | |
| response_text = await response.text() | |
| try: | |
| # Try to parse as JSON first | |
| error_data = json.loads(response_text) | |
| # Handle array response format | |
| if isinstance(error_data, list) and len(error_data) > 0: | |
| error_data = error_data[0] | |
| if "error" in error_data: | |
| error_message = error_data["error"].get("message", "") | |
| # Extract project ID from error message | |
| # Pattern: "projects/39982734461/locations/..." | |
| match = re.search(r'projects/(\d+)/locations/', error_message) | |
| if match: | |
| self.project_id = match.group(1) | |
| PROJECT_ID_CACHE[self.api_key] = self.project_id | |
| print(f"INFO: Discovered project ID: {self.project_id}") | |
| return | |
| except json.JSONDecodeError: | |
| # If not JSON, try to find project ID in raw text | |
| match = re.search(r'projects/(\d+)/locations/', response_text) | |
| if match: | |
| self.project_id = match.group(1) | |
| PROJECT_ID_CACHE[self.api_key] = self.project_id | |
| print(f"INFO: Discovered project ID from raw response: {self.project_id}") | |
| return | |
| raise Exception(f"Failed to discover project ID. Status: {response.status}, Response: {response_text[:500]}") | |
| except Exception as e: | |
| print(f"ERROR: Failed to discover project ID: {e}") | |
| raise | |
| def _convert_contents(self, contents: Any) -> List[Dict[str, Any]]: | |
| """Convert SDK Content objects to REST API format""" | |
| if isinstance(contents, list): | |
| return [self._convert_content_item(item) for item in contents] | |
| else: | |
| return [self._convert_content_item(contents)] | |
| def _convert_content_item(self, content: Any) -> Dict[str, Any]: | |
| """Convert a single content item to REST API format""" | |
| if isinstance(content, dict): | |
| return content | |
| # Handle SDK Content objects | |
| result = {} | |
| if hasattr(content, 'role'): | |
| result['role'] = content.role | |
| if hasattr(content, 'parts'): | |
| result['parts'] = [] | |
| for part in content.parts: | |
| if isinstance(part, dict): | |
| result['parts'].append(part) | |
| elif hasattr(part, 'text'): | |
| result['parts'].append({'text': part.text}) | |
| elif hasattr(part, 'inline_data'): | |
| result['parts'].append({ | |
| 'inline_data': { | |
| 'mime_type': part.inline_data.mime_type, | |
| 'data': part.inline_data.data | |
| } | |
| }) | |
| return result | |
| def _convert_safety_settings(self, safety_settings: Any) -> List[Dict[str, str]]: | |
| """Convert SDK SafetySetting objects to REST API format""" | |
| if not safety_settings: | |
| return [] | |
| result = [] | |
| for setting in safety_settings: | |
| if isinstance(setting, dict): | |
| result.append(setting) | |
| elif hasattr(setting, 'category') and hasattr(setting, 'threshold'): | |
| # Convert SDK SafetySetting to dict | |
| result.append({ | |
| 'category': setting.category, | |
| 'threshold': setting.threshold | |
| }) | |
| return result | |
| def _convert_tools(self, tools: Any) -> List[Dict[str, Any]]: | |
| """Convert SDK Tool objects to REST API format""" | |
| if not tools: | |
| return [] | |
| result = [] | |
| for tool in tools: | |
| if isinstance(tool, dict): | |
| result.append(tool) | |
| else: | |
| # Convert SDK Tool object to dict | |
| result.append(self._convert_tool_item(tool)) | |
| return result | |
| def _convert_tool_item(self, tool: Any) -> Dict[str, Any]: | |
| """Convert a single tool item to REST API format""" | |
| if isinstance(tool, dict): | |
| return tool | |
| tool_dict = {} | |
| # Convert all non-private attributes | |
| if hasattr(tool, '__dict__'): | |
| for attr_name, attr_value in tool.__dict__.items(): | |
| if not attr_name.startswith('_'): | |
| # Convert attribute names from snake_case to camelCase for REST API | |
| rest_api_name = self._to_camel_case(attr_name) | |
| # Special handling for known types | |
| if attr_name == 'google_search' and attr_value is not None: | |
| tool_dict[rest_api_name] = {} # GoogleSearch is empty object in REST | |
| elif attr_name == 'function_declarations' and attr_value is not None: | |
| tool_dict[rest_api_name] = attr_value | |
| elif attr_value is not None: | |
| # Recursively convert any other SDK objects | |
| tool_dict[rest_api_name] = self._convert_sdk_object(attr_value) | |
| return tool_dict | |
| def _to_camel_case(self, snake_str: str) -> str: | |
| """Convert snake_case to camelCase""" | |
| components = snake_str.split('_') | |
| return components[0] + ''.join(x.title() for x in components[1:]) | |
| def _convert_sdk_object(self, obj: Any) -> Any: | |
| """Generic SDK object converter""" | |
| if isinstance(obj, (str, int, float, bool, type(None))): | |
| return obj | |
| elif isinstance(obj, dict): | |
| return {k: self._convert_sdk_object(v) for k, v in obj.items()} | |
| elif isinstance(obj, list): | |
| return [self._convert_sdk_object(item) for item in obj] | |
| elif hasattr(obj, '__dict__'): | |
| # Convert SDK object to dict | |
| result = {} | |
| for key, value in obj.__dict__.items(): | |
| if not key.startswith('_'): | |
| result[self._to_camel_case(key)] = self._convert_sdk_object(value) | |
| return result | |
| else: | |
| return obj | |
| async def _generate_content(self, model: str, contents: Any, config: Dict[str, Any], stream: bool = False) -> Any: | |
| """Internal method for content generation""" | |
| if not self.project_id: | |
| raise ValueError("Project ID not discovered. Call discover_project_id() first.") | |
| await self._ensure_session() | |
| # Build URL | |
| endpoint = "streamGenerateContent" if stream else "generateContent" | |
| url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:{endpoint}?key={self.api_key}" | |
| # Convert contents to REST API format | |
| payload = { | |
| "contents": self._convert_contents(contents) | |
| } | |
| # Extract specific config sections | |
| if "system_instruction" in config: | |
| # System instruction should be a content object | |
| if isinstance(config["system_instruction"], dict): | |
| payload["systemInstruction"] = config["system_instruction"] | |
| else: | |
| payload["systemInstruction"] = self._convert_content_item(config["system_instruction"]) | |
| if "safety_settings" in config: | |
| payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"]) | |
| if "tools" in config: | |
| payload["tools"] = self._convert_tools(config["tools"]) | |
| # All other config goes under generationConfig | |
| generation_config = {} | |
| for key, value in config.items(): | |
| if key not in ["system_instruction", "safety_settings", "tools"]: | |
| generation_config[key] = value | |
| if generation_config: | |
| payload["generationConfig"] = generation_config | |
| try: | |
| async with self.session.post(url, json=payload) as response: | |
| if response.status != 200: | |
| error_data = await response.json() | |
| error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status}") | |
| raise Exception(f"Vertex AI API error: {error_msg}") | |
| # Get the JSON response | |
| response_data = await response.json() | |
| # Convert dict to object with attributes for compatibility | |
| return self._dict_to_obj(response_data) | |
| except Exception as e: | |
| print(f"ERROR: Direct Vertex API call failed: {e}") | |
| raise | |
| def _dict_to_obj(self, data): | |
| """Convert a dict to an object with attributes""" | |
| if isinstance(data, dict): | |
| # Create a simple object that allows attribute access | |
| class AttrDict: | |
| def __init__(self, d): | |
| for key, value in d.items(): | |
| setattr(self, key, self._convert_value(value)) | |
| def _convert_value(self, value): | |
| if isinstance(value, dict): | |
| return AttrDict(value) | |
| elif isinstance(value, list): | |
| return [self._convert_value(item) for item in value] | |
| else: | |
| return value | |
| return AttrDict(data) | |
| elif isinstance(data, list): | |
| return [self._dict_to_obj(item) for item in data] | |
| else: | |
| return data | |
| async def _generate_content_stream(self, model: str, contents: Any, config: Dict[str, Any]) -> AsyncGenerator: | |
| """Internal method for streaming content generation""" | |
| if not self.project_id: | |
| raise ValueError("Project ID not discovered. Call discover_project_id() first.") | |
| await self._ensure_session() | |
| # Build URL for streaming | |
| url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:streamGenerateContent?key={self.api_key}" | |
| # Convert contents to REST API format | |
| payload = { | |
| "contents": self._convert_contents(contents) | |
| } | |
| # Extract specific config sections | |
| if "system_instruction" in config: | |
| # System instruction should be a content object | |
| if isinstance(config["system_instruction"], dict): | |
| payload["systemInstruction"] = config["system_instruction"] | |
| else: | |
| payload["systemInstruction"] = self._convert_content_item(config["system_instruction"]) | |
| if "safety_settings" in config: | |
| payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"]) | |
| if "tools" in config: | |
| payload["tools"] = self._convert_tools(config["tools"]) | |
| # All other config goes under generationConfig | |
| generation_config = {} | |
| for key, value in config.items(): | |
| if key not in ["system_instruction", "safety_settings", "tools"]: | |
| generation_config[key] = value | |
| if generation_config: | |
| payload["generationConfig"] = generation_config | |
| try: | |
| async with self.session.post(url, json=payload) as response: | |
| if response.status != 200: | |
| error_data = await response.json() | |
| # Handle array response format | |
| if isinstance(error_data, list) and len(error_data) > 0: | |
| error_data = error_data[0] | |
| error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status}") if isinstance(error_data, dict) else str(error_data) | |
| raise Exception(f"Vertex AI API error: {error_msg}") | |
| # The Vertex AI streaming endpoint returns JSON array elements | |
| # We need to parse these as they arrive | |
| buffer = "" | |
| async for chunk in response.content.iter_any(): | |
| decoded_chunk = chunk.decode('utf-8') | |
| buffer += decoded_chunk | |
| # Try to extract complete JSON objects from the buffer | |
| while True: | |
| # Skip whitespace and array brackets | |
| buffer = buffer.lstrip() | |
| if buffer.startswith('['): | |
| buffer = buffer[1:].lstrip() | |
| continue | |
| if buffer.startswith(']'): | |
| # End of array | |
| return | |
| # Skip comma and whitespace between objects | |
| if buffer.startswith(','): | |
| buffer = buffer[1:].lstrip() | |
| continue | |
| # Look for a complete JSON object | |
| if buffer.startswith('{'): | |
| # Find the matching closing brace | |
| brace_count = 0 | |
| in_string = False | |
| escape_next = False | |
| for i, char in enumerate(buffer): | |
| if escape_next: | |
| escape_next = False | |
| continue | |
| if char == '\\' and in_string: | |
| escape_next = True | |
| continue | |
| if char == '"' and not in_string: | |
| in_string = True | |
| elif char == '"' and in_string: | |
| in_string = False | |
| elif char == '{' and not in_string: | |
| brace_count += 1 | |
| elif char == '}' and not in_string: | |
| brace_count -= 1 | |
| if brace_count == 0: | |
| # Found complete object | |
| obj_str = buffer[:i+1] | |
| buffer = buffer[i+1:] | |
| try: | |
| chunk_data = json.loads(obj_str) | |
| converted_obj = self._dict_to_obj(chunk_data) | |
| yield converted_obj | |
| except json.JSONDecodeError as e: | |
| print(f"ERROR: DirectVertexClient - Failed to parse JSON: {e}") | |
| break | |
| else: | |
| # No complete object found, need more data | |
| break | |
| else: | |
| # No more objects to process in current buffer | |
| break | |
| except Exception as e: | |
| print(f"ERROR: Direct Vertex streaming API call failed: {e}") | |
| raise |