import binascii
import json
import logging
from time import time
from typing import AsyncGenerator, List, Optional

from api.chat import ChatConfig
from api.telemetry import TelemetryAPI
from auth.jwt_handler import JWTHandler
from config.constants import (
    APP_LANGUAGE,
    APP_NAME,
    APP_VERSION,
    DISPLAY_NAME,
    ENCRYPTION_KEY,
    HADWARE_INFO,
    INFERENCE_URL,
    SYSTEM_INFO,
)
from utils.encrypt import encrypt
from utils.http import HTTPClient
from utils.compression import decompress_chunks
from protos import request_pb2, response_pb2

logger = logging.getLogger(__name__)


class ChatAPI:
    """Client that builds protobuf chat requests and yields response chunks.

    The instance caches a JWT obtained through ``JWTHandler`` and refreshes it
    when it is missing or older than ``_JWT_MAX_AGE_SECONDS``.
    """

    # A cached JWT younger than this many seconds is reused without renewal.
    _JWT_MAX_AGE_SECONDS = 2500

    # Frame type that ``_process_chat_response`` treats as the end-of-message
    # frame (its payload is parsed as JSON rather than as a protobuf).
    _END_OF_MESSAGE_TYPE = 3

    # Numeric role values written into ``ChatMessage.role``.
    _ROLE_USER = 1
    _ROLE_ASSISTANT = 2
    _ROLE_SYSTEM = 3

    def __init__(self, api_key: str, http_client: Optional[HTTPClient] = None):
        """Store the API key and the HTTP client used for all requests.

        Args:
            api_key: Key sent in the client info of every request and used to
                obtain the JWT.
            http_client: Client used for JWT retrieval and chat requests. When
                omitted, a new ``HTTPClient`` is created for this instance.
        """
        self.api_key = api_key
        self.jwt_token = None
        self.jwt_token_timestamp = 0
        # Build the default lazily: a ``HTTPClient()`` default in the signature
        # would be created once at import time and shared by every instance.
        self.http_client = http_client if http_client is not None else HTTPClient()

    async def renew_jwt_token(self):
        """Renew the cached JWT if it is missing or too old.

        Does nothing while a token exists and is younger than
        ``_JWT_MAX_AGE_SECONDS``. Otherwise it requests a new token, sends
        telemetry, and stores the token only if one was returned.
        """
        current_time = time()

        token_age = current_time - self.jwt_token_timestamp
        if self.jwt_token and token_age < self._JWT_MAX_AGE_SECONDS:
            return

        jwt_handler = JWTHandler(api_key=self.api_key, http_client=self.http_client)
        jwt_token = await jwt_handler.get_jwt_token()

        tele = TelemetryAPI(api_key=self.api_key)
        await tele.do_telemetry()

        if jwt_token:
            self.jwt_token = jwt_token
            # The timestamp is the time sampled before the network calls.
            self.jwt_token_timestamp = current_time

    async def _create_chat_request(
        self,
        messages: List[dict],
        config: ChatConfig,
        system_prompt: str = "You are a helpful assistant.",
    ) -> request_pb2.ChatRequestMessage:
        """Build the protobuf request for a chat call.

        Args:
            messages: Chat messages as dicts with ``"role"`` and ``"content"``.
            config: Model selection and sampling parameters.
            system_prompt: Default system prompt; a message with role
                ``"system"`` in ``messages`` overrides it.

        Returns:
            The populated ``ChatRequestMessage``.
        """
        try:
            await self.renew_jwt_token()
        except Exception as exc:
            # Best effort: a renewal failure must not abort request building;
            # a previously cached token (if any) is still used.
            logger.warning("JWT renewal failed: %s", exc)

        msg = request_pb2.ChatRequestMessage()

        self._set_client_info(msg)

        msg.system_prompt = system_prompt
        msg.model_id = config.model_id.value
        # The meaning of these two fields is not documented in this file; the
        # values are kept exactly as they were.
        msg.idk13.idk13nn = 1
        msg.idk_id = 5
        self._set_model_config(msg, config)

        # NOTE: _set_tool_config is currently not applied to requests.

        self._add_messages(msg, messages)

        return msg

    def _set_client_info(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Set client information in the request message."""
        client_info = msg.client_info
        client_info.api_key = self.api_key
        # ``jwt_token`` is None until a renewal succeeds, and assigning None to
        # a protobuf string field is expected to raise TypeError; leave the
        # field at its default in that case.
        if self.jwt_token:
            client_info.user_jwt = self.jwt_token
        client_info.locale = APP_LANGUAGE
        client_info.extension_name = APP_NAME
        client_info.ide_name = APP_NAME
        client_info.extension_version = APP_VERSION
        client_info.os = SYSTEM_INFO
        client_info.ide_version = DISPLAY_NAME
        client_info.hardware = HADWARE_INFO

    def _set_model_config(
        self, msg: request_pb2.ChatRequestMessage, config: ChatConfig
    ) -> None:
        """Copy sampling parameters from ``config`` into the request."""
        model_config = msg.model_config
        model_config.parallel_stream = 1
        model_config.max_tokens = config.max_tokens
        model_config.temperature = config.temperature
        model_config.top_k = config.top_k
        model_config.top_P = config.top_p

    def _set_special_tokens(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Append the fixed list of special tokens to the model config."""
        msg.model_config.special_tokens.extend(
            [
                "<|user|>",
                "<|bot|>",
                "<|context_request|>",
                "<|endoftext|>",
                "<|end_of_turn|>",
            ]
        )

    def _set_tool_config(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Set tool configuration: mode ``auto`` with one placeholder tool."""
        msg.tool_use.mode = "auto"
        msg.tool_config.tool_name = "do_not_call"
        msg.tool_config.description = "Do not call this tool."
        msg.tool_config.schema = '{"$schema":"https://json-schema.org/draft/2020-12/schema","properties":{},"additionalProperties":false,"type":"object"}'

    def _add_messages(
        self, msg: request_pb2.ChatRequestMessage, messages: List[dict]
    ) -> None:
        """Add chat messages to the request.

        Roles map to ``user`` -> 1, ``assistant`` -> 2, ``system`` -> 3; an
        unknown role is treated as ``user``. A system message is not appended:
        it replaces ``msg.system_prompt`` instead.
        """
        role_map = {
            "user": self._ROLE_USER,
            "assistant": self._ROLE_ASSISTANT,
            "system": self._ROLE_SYSTEM,
        }

        for chat_msg in messages:
            role = role_map.get(chat_msg["role"], self._ROLE_USER)
            content = chat_msg["content"]

            if role == self._ROLE_SYSTEM:
                # Override the system prompt with a plain string, or with the
                # first text part of a multipart system message.
                if isinstance(content, str):
                    msg.system_prompt = content
                elif isinstance(content, list):
                    for item in content:
                        if (
                            item.get("type", "") == "text"
                            and "text" in item
                            and isinstance(item["text"], str)
                        ):
                            msg.system_prompt = item["text"]
                            break
                continue

            if isinstance(content, list):
                pb_msg = self._create_multipart_message(role, content)
            else:
                pb_msg = self._create_text_message(role, content)

            msg.chat_messages.append(pb_msg)

    def _create_multipart_message(
        self, role: int, content: List[dict]
    ) -> request_pb2.ChatMessage:
        """Create a message with multiple parts (text and images).

        Text parts are joined with a single space. Only ``image_url`` parts
        whose URL is a base64 ``data:image/...`` URI are converted; any other
        image URL or part type is skipped.
        """
        text_parts = []
        image_parts = []

        for item in content:
            # ``.get`` matches ``_add_messages``: a part without "type" is
            # skipped instead of raising KeyError.
            part_type = item.get("type", "")
            if part_type == "text":
                text_parts.append(item["text"])
            elif part_type == "image_url":
                image_url = item["image_url"]["url"]
                if image_url.startswith("data:image/") and "base64," in image_url:
                    prefix, image_data = image_url.split("base64,", 1)
                    # prefix looks like "data:image/png;" -> "image/png"
                    mime_type = prefix.split("data:")[1].split(";")[0]
                    image_parts.append(
                        request_pb2.ImagePart(
                            image_data=image_data, image_mime_type=mime_type
                        )
                    )

        return self._create_message(role, " ".join(text_parts), image_parts)

    def _create_text_message(self, role: int, content: str) -> request_pb2.ChatMessage:
        """Create a simple text message."""
        return self._create_message(role, content)

    def _create_message(
        self,
        role: int,
        content: str,
        image_parts: Optional[List[request_pb2.ImagePart]] = None,
    ) -> request_pb2.ChatMessage:
        """Create a chat message with common attributes.

        Args:
            role: Numeric role value (see ``_add_messages``).
            content: Message text.
            image_parts: Optional images to attach; ``None`` or an empty list
                attaches nothing.
        """
        pb_msg = request_pb2.ChatMessage(role=role, content=content)

        if role == self._ROLE_USER:
            # Set only on user messages; the field's meaning is not documented
            # in this file.
            pb_msg.idk2 = 1
            # NOTE: prompt caching (cache_control.prompt_caching) is not set.

        if image_parts:
            pb_msg.image_parts.extend(image_parts)

        return pb_msg

    async def _process_chat_response(self, type: int, data: bytes) -> tuple[str, int]:
        """Process a single chat response chunk and return ``(message, count)``.

        Args:
            type: Frame type from ``decompress_chunks``. The parameter name
                shadows the builtin but is kept for caller compatibility.
            data: Raw frame payload.

        Returns:
            For an end-of-message frame: the encrypted string form of the JSON
            payload with count 0, or ``("", 0)`` if the payload is empty.
            For any other frame: the parsed message and count, or ``("", 0)``
            if the message is empty or the payload cannot be parsed.

        Raises:
            ValueError: If an end-of-message payload is not valid JSON
                (``json.JSONDecodeError`` is a subclass).
        """
        if type == self._END_OF_MESSAGE_TYPE:
            response = json.loads(data)
            if not response:
                return ("", 0)
            return (encrypt(str(response), ENCRYPTION_KEY), 0)

        try:
            chat_response = response_pb2.ChatResponse()
            chat_response.ParseFromString(data)
        except Exception:
            # Best effort: an unparsable frame yields an empty chunk. Catching
            # Exception rather than using a bare except lets KeyboardInterrupt
            # and SystemExit propagate.
            return ("", 0)

        if chat_response.message:
            return (chat_response.message, chat_response.count)
        return ("", 0)

    async def _handle_stream_response(
        self, chunk_iterator
    ) -> AsyncGenerator[tuple[str, int], None]:
        """Handle streaming response chunks, yielding one result per frame."""
        async for chunk in chunk_iterator:
            for frame_type, data in decompress_chunks(chunk):
                yield await self._process_chat_response(frame_type, data)

    async def _handle_response(self, chunk) -> AsyncGenerator[tuple[str, int], None]:
        """Handle a non-streaming response body, yielding one result per frame."""
        for frame_type, data in decompress_chunks(chunk):
            yield await self._process_chat_response(frame_type, data)

    async def send_message(
        self,
        messages: List[dict],
        config: Optional[ChatConfig] = None,
        system_prompt: str = "You are a helpful assistant.",
        stream: bool = False,
    ) -> AsyncGenerator[tuple[str, int], None]:
        """Send chat messages and yield response chunks.

        Args:
            messages: Chat messages as dicts with ``"role"`` and ``"content"``.
            config: Model configuration; a default ``ChatConfig()`` is used
                when omitted.
            system_prompt: Default system prompt for the request.
            stream: When true, use the streaming POST and yield chunks as they
                arrive; otherwise issue a single POST and yield from its body.

        Yields:
            ``(message, count)`` tuples.

        Raises:
            Exception: In non-streaming mode, if the response status is not 200.
        """
        if config is None:
            config = ChatConfig()

        request = await self._create_chat_request(messages, config, system_prompt)

        headers = {
            "User-Agent": "connect-go/1.16.2 (go1.23.2 X:nocoverageredesign)",
            "Connect-Accept-Encoding": "gzip",
            "Connect-Content-Encoding": "gzip",
            "Connect-Protocol-Version": "1",
            "Content-Type": "application/connect+proto",
        }
        url = f"{INFERENCE_URL}/exa.api_server_pb.ApiServerService/GetChatMessage"
        request_data = request.SerializeToString()

        if stream:
            stream_iterator = self.http_client.stream_post(
                url=url,
                data=request_data,
                headers=headers,
                compress=True,
            )
            async for result in self._handle_stream_response(stream_iterator):
                yield result
            return

        response = await self.http_client.post(
            url=url,
            data=request_data,
            headers=headers,
            compress=True,
        )

        if response.status_code != 200:
            raise Exception(f"Chat request failed: {response.status_code}")

        if response.headers.get("connect-content-encoding") == "gzip":
            async for result in self._handle_response(response.content):
                yield result
        else:
            # Without the gzip header the body is parsed directly as a single
            # ChatResponse.
            chat_response = response_pb2.ChatResponse()
            chat_response.ParseFromString(response.content)
            if chat_response.message:
                yield (chat_response.message, chat_response.count)