from abc import ABC, abstractmethod from openai import OpenAI, AsyncOpenAI from dotenv import load_dotenv import os import logging from tenacity import retry, stop_after_attempt, wait_fixed import httpx from sentence_transformers import SentenceTransformer # Optional torch import for CUDA detection try: import torch _TORCH_AVAILABLE = True except Exception: torch = None _TORCH_AVAILABLE = False from .utils.logger_utils import setup_logger LOGGER_NAME = "MODEL_SERVICE_LOGGER" # GENERATION ENV VARIABLES (defaults) OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", 'http://0.0.0.0:8000/v1') OPENAI_TOKEN = os.getenv("OPENAI_TOKEN", 'no-need') MODEL_NAME = os.getenv('MODEL_NAME', "meta-llama/Llama-3.2-3B-Instruct") # EMBED ENV VARIABLES (defaults) OPENAI_EMBED_BASE_URL = os.getenv("OPENAI_EMBED_BASE_URL", 'http://0.0.0.0:8001/v1') OPENAI_EMBED_TOKEN = os.getenv("OPENAI_EMBED_TOKEN", 'no-need') EMBED_MODEL_NAME = os.getenv('EMBED_MODEL_NAME', "Alibaba-NLP/gte-Qwen2-1.5B-instruct") # Additional ENV defaults requested MAX_TOKENS = int(os.getenv("MAX_TOKENS", 2048)) TEMPERATURE = float(os.getenv("TEMPERATURE", 0.2)) TOP_P = float(os.getenv("TOP_P", 0.95)) FREQUENCY_PENALTY = float(os.getenv("FREQUENCY_PENALTY", 0)) PRESENCE_PENALTY = float(os.getenv("PRESENCE_PENALTY", 0)) EMBEDDING_MODEL_URL = os.getenv("EMBEDDING_MODEL_URL", "") EMBEDDING_MODEL_API_KEY = os.getenv("EMBEDDING_MODEL_API_KEY", "no_need") EMBEDDING_NUMBER_DIMENSIONS = int(os.getenv("EMBEDDING_NUMBER_DIMENSIONS", 1024)) STOP_AFTER_ATTEMPT = int(os.getenv("STOP_AFTER_ATTEMPT", 5)) WAIT_BETWEEN_RETRIES = int(os.getenv("WAIT_BETWEEN_RETRIES", 2)) REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", 240)) # Note: module-level clients remain for backward compatibility but instances will create their own if timeout is overridden. long_timeout_client = httpx.Client(timeout=REQUEST_TIMEOUT) long_timeout_async_client = httpx.AsyncClient(timeout=REQUEST_TIMEOUT) class ModelServiceInterface(ABC): """ Abstract base class defining the interface for model services. All model services should implement these methods. """ # accept model_kwargs so variables can be overridden at runtime def __init__(self, model_name: str = None, model_kwargs: dict = None): setup_logger(LOGGER_NAME) self.logger = logging.getLogger(LOGGER_NAME) model_kwargs = model_kwargs or {} # allow overriding via model_kwargs; fall back to module-level defaults self.openai_base_url = model_kwargs.get("OPENAI_BASE_URL", OPENAI_BASE_URL) self.openai_token = model_kwargs.get("OPENAI_TOKEN", OPENAI_TOKEN) # model_name param takes precedence, then model_kwargs then default env self.model_name = model_name or model_kwargs.get("MODEL_NAME", MODEL_NAME) # embed defaults (may be overridden by subclasses or model_kwargs) self.openai_embed_base_url = model_kwargs.get("OPENAI_EMBED_BASE_URL", OPENAI_EMBED_BASE_URL) self.openai_embed_token = model_kwargs.get("OPENAI_EMBED_TOKEN", OPENAI_EMBED_TOKEN) self.embed_model_name = model_kwargs.get("EMBED_MODEL_NAME", EMBED_MODEL_NAME) # other configurable parameters self.max_tokens = int(model_kwargs.get("MAX_TOKENS", MAX_TOKENS)) self.temperature = float(model_kwargs.get("TEMPERATURE", TEMPERATURE)) self.top_p = float(model_kwargs.get("TOP_P", TOP_P)) self.frequency_penalty = float(model_kwargs.get("FREQUENCY_PENALTY", FREQUENCY_PENALTY)) self.presence_penalty = float(model_kwargs.get("PRESENCE_PENALTY", PRESENCE_PENALTY)) self.embedding_model_url = model_kwargs.get("EMBEDDING_MODEL_URL", EMBEDDING_MODEL_URL) self.embedding_model_api_key = model_kwargs.get("EMBEDDING_MODEL_API_KEY", EMBEDDING_MODEL_API_KEY) self.embedding_number_dimensions = int(model_kwargs.get("EMBEDDING_NUMBER_DIMENSIONS", EMBEDDING_NUMBER_DIMENSIONS)) self.stop_after_attempt = int(model_kwargs.get("STOP_AFTER_ATTEMPT", STOP_AFTER_ATTEMPT)) self.wait_between_retries = int(model_kwargs.get("WAIT_BETWEEN_RETRIES", WAIT_BETWEEN_RETRIES)) request_timeout = int(model_kwargs.get("REQUEST_TIMEOUT", REQUEST_TIMEOUT)) # create per-instance httpx clients in case REQUEST_TIMEOUT was overridden self.long_timeout_client = httpx.Client(timeout=request_timeout) self.long_timeout_async_client = httpx.AsyncClient(timeout=request_timeout) # Initialize query client (shared by all implementations) self.client = OpenAI( base_url=self.openai_base_url, api_key=self.openai_token, http_client=self.long_timeout_client, ) self.async_client = AsyncOpenAI( base_url=self.openai_base_url, api_key=self.openai_token, http_client=self.long_timeout_async_client, ) @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES)) def query(self, prompt: str, model_name: str) -> str: """Query the model with a prompt.""" if model_name is None: model_name = self.model_name completion = self.client.chat.completions.create( model=model_name, messages=[ {"role": "user", "content": prompt} ] ) return completion.choices[0].message.content @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES)) def query_with_instructions(self, prompt: str, instructions: str, model_name: str) -> str: """Query the model with additional system instructions.""" if model_name is None: model_name = self.model_name completion = self.client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": instructions}, {"role": "user", "content": prompt} ] ) return completion.choices[0].message.content @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES)) async def query_async(self, prompt: str, model_name: str ) -> str: """Async version of query.""" if model_name is None: model_name = self.model_name completion = await self.async_client.chat.completions.create( model=model_name, messages=[ {"role": "user", "content": prompt} ] ) return completion.choices[0].message.content @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES)) async def query_with_instructions_async(self, prompt: str, instructions: str, model_name: str) -> str: """Async version of query with instructions.""" if model_name is None: model_name = self.model_name completion = await self.async_client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": instructions}, {"role": "user", "content": prompt} ] ) return completion.choices[0].message.content @abstractmethod def embed(self, text_to_embed: str) -> list: """Embed text using the configured embedding model.""" pass @abstractmethod async def embed_async(self, text_to_embed: str) -> list: """Async version of embed.""" pass @abstractmethod def embed_chunk_code(self, code_to_embed: str) -> list: """Embed code chunk for storage/indexing.""" pass @abstractmethod def embed_query(self, query_to_embed: str) -> list: """Embed query for retrieval.""" pass @abstractmethod def embed_batch(self, texts_to_embed: list[str]) -> list[list]: """Embed multiple texts in a batch for better performance.""" pass @abstractmethod def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]: """Embed multiple code chunks in a batch for storage/indexing.""" pass class OpenAIModelService(ModelServiceInterface): """ Model service that uses OpenAI client for both queries and embeddings. """ def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None): # forward model_kwargs to base so it can set instance-wide config super().__init__(model_name=model_name, model_kwargs=model_kwargs) # allow override of embed model name via param or model_kwargs model_kwargs = model_kwargs or {} self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name) # embed client should use the instance-level embed base/token self.embed_client = OpenAI( base_url=model_kwargs.get("OPENAI_EMBED_BASE_URL", self.openai_embed_base_url), api_key=model_kwargs.get("OPENAI_EMBED_TOKEN", self.openai_embed_token), http_client=self.long_timeout_client, ) self.async_embed_client = AsyncOpenAI( base_url=model_kwargs.get("OPENAI_EMBED_BASE_URL", self.openai_embed_base_url), api_key=model_kwargs.get("OPENAI_EMBED_TOKEN", self.openai_embed_token), http_client=self.long_timeout_async_client, ) @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES)) def embed(self, text_to_embed: str) -> list: """Embed text using OpenAI embeddings API.""" response = self.embed_client.embeddings.create( input=text_to_embed, model=self.embed_model_name, ) return response.data[0].embedding @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES)) async def embed_async(self, text_to_embed: str) -> list: """Async version of embed using OpenAI embeddings API.""" response = await self.async_embed_client.embeddings.create( input=text_to_embed, model=self.embed_model_name, ) return response.data[0].embedding def embed_chunk_code(self, code_to_embed: str) -> list: """Embed code chunk using OpenAI embeddings API (same as embed).""" return self.embed(code_to_embed) def embed_query(self, query_to_embed: str) -> list: """Embed query using OpenAI embeddings API (same as embed).""" return self.embed(query_to_embed) @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES)) def embed_batch(self, texts_to_embed: list[str]) -> list[list]: """Embed multiple texts in a batch using OpenAI embeddings API.""" if not texts_to_embed: return [] response = self.embed_client.embeddings.create( input=texts_to_embed, model=self.embed_model_name, ) return [item.embedding for item in response.data] def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]: """Embed multiple code chunks in a batch using OpenAI embeddings API.""" return self.embed_batch(codes_to_embed) class SentenceTransformersModelService(ModelServiceInterface): """ Model service that uses OpenAI client for queries and SentenceTransformers for embeddings. Optimized for high-throughput batch embedding with GPU support. """ def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None, skip_embedder: bool = False): super().__init__(model_name=model_name, model_kwargs=model_kwargs) model_kwargs = model_kwargs or {} # embed_model_name may be overridden by model_kwargs self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name) self.skip_embedder = skip_embedder self.embedding_model = None if skip_embedder: self.logger.info('Skipping embedder initialization (keyword-only mode)') self.device = "cpu" self.encode_batch_size = 32 return # Debug GPU detection self.logger.info(f'PyTorch available: {_TORCH_AVAILABLE}') if _TORCH_AVAILABLE: self.logger.info(f'CUDA available: {torch.cuda.is_available()}') self.logger.info(f'CUDA device count: {torch.cuda.device_count()}') if torch.cuda.is_available(): self.logger.info(f'CUDA device name: {torch.cuda.get_device_name(0)}') # Select device: prefer CUDA if available self.device = "cuda" if (_TORCH_AVAILABLE and torch.cuda.is_available()) else "cpu" self.logger.info(f'Initializing SentenceTransformer on device: {self.device}') # Set batch size based on device and available memory # Larger batch sizes significantly improve GPU throughput self.encode_batch_size = int(model_kwargs.get("ENCODE_BATCH_SIZE", 64 if self.device == "cuda" else 32)) # Show CUDA memory info if available if self.device == "cuda" and _TORCH_AVAILABLE: try: gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) self.logger.info(f'GPU memory available: {gpu_memory:.2f} GB') # Adjust batch size based on available GPU memory if gpu_memory > 16: self.encode_batch_size = max(self.encode_batch_size, 128) elif gpu_memory > 8: self.encode_batch_size = max(self.encode_batch_size, 64) except Exception as e: self.logger.warning(f'Could not get GPU memory info: {e}') self.logger.info(f'Using encode batch size: {self.encode_batch_size}') # Initialize embedding model on the chosen device with performance optimizations self.embedding_model = SentenceTransformer( self.embed_model_name, trust_remote_code=True, device=self.device ) # Enable half precision for faster inference on CUDA if self.device == "cuda" and _TORCH_AVAILABLE: try: # Check if model supports half precision self.embedding_model.half() self.logger.info('Enabled half precision (FP16) for faster GPU inference') except Exception as e: self.logger.warning(f'Could not enable half precision: {e}') def _check_embedder(self): """Check if embedder is available, raise error if not.""" if self.skip_embedder or self.embedding_model is None: raise RuntimeError( "Embedding model not initialized. This model service was created with skip_embedder=True " "(keyword-only mode). To use embeddings, set index_type to 'hybrid' or 'embedding-only'." ) def embed(self, text_to_embed: str) -> list: """Embed text using SentenceTransformers.""" self._check_embedder() embeddings = self.embedding_model.encode( [text_to_embed], convert_to_numpy=True, show_progress_bar=False ) return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0]) async def embed_async(self, text_to_embed: str) -> list: """ Async version of embed using SentenceTransformers. Note: SentenceTransformers doesn't have native async support, so this runs synchronously but maintains the async interface. """ return self.embed(text_to_embed) def embed_chunk_code(self, code_to_embed: str) -> list: """Embed code chunk using SentenceTransformers (no special prompt).""" self._check_embedder() self.logger.debug(f'Embedding code using {self.embed_model_name}') embeddings = self.embedding_model.encode( [code_to_embed], convert_to_numpy=True, show_progress_bar=False ) return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0]) def embed_query(self, query_to_embed: str) -> list: """Embed query using SentenceTransformers with retrieval prompt.""" self._check_embedder() self.logger.debug(f'Embedding query using {self.embed_model_name}') embeddings = self.embedding_model.encode( [query_to_embed], prompt='Given this prompt, retrieve relevant content\n Query:', convert_to_numpy=True, show_progress_bar=False ) return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0]) def embed_batch(self, texts_to_embed: list[str]) -> list[list]: """Embed multiple texts in a batch using SentenceTransformers with optimized settings.""" if not texts_to_embed: return [] self._check_embedder() self.logger.info(f'Batch embedding {len(texts_to_embed)} texts using {self.embed_model_name}') embeddings = self.embedding_model.encode( texts_to_embed, batch_size=self.encode_batch_size, convert_to_numpy=True, show_progress_bar=len(texts_to_embed) > 100, # Only show progress for large batches normalize_embeddings=True # Normalize for better similarity computation ) return [emb.tolist() if hasattr(emb, 'tolist') else list(emb) for emb in embeddings] def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]: """Embed multiple code chunks in a batch using SentenceTransformers with optimized settings.""" if not codes_to_embed: return [] self._check_embedder() self.logger.info(f'Batch embedding {len(codes_to_embed)} code chunks using {self.embed_model_name}') embeddings = self.embedding_model.encode( codes_to_embed, batch_size=self.encode_batch_size, convert_to_numpy=True, show_progress_bar=len(codes_to_embed) > 100, # Only show progress for large batches normalize_embeddings=True # Normalize for better similarity computation ) return [emb.tolist() if hasattr(emb, 'tolist') else list(emb) for emb in embeddings] def create_model_service(skip_embedder: bool = False, **kwargs) -> ModelServiceInterface: """ Factory function to create the appropriate ModelService based on embedder_type. Args: skip_embedder (bool): If True, skip loading the embedding model (for keyword-only search). **kwargs: Additional arguments including 'embedder_type' ('openai' or 'sentence-transformers') and optional 'model_kwargs' dict which can override any env var defaults. Returns: ModelServiceInterface: An instance of the appropriate ModelService """ model_kwargs = kwargs.pop('model_kwargs', None) embedder_type = kwargs.pop('embedder_type', 'openai') if embedder_type == 'openai': return OpenAIModelService(model_kwargs=model_kwargs, **kwargs) elif embedder_type == 'sentence-transformers': return SentenceTransformersModelService(model_kwargs=model_kwargs, skip_embedder=skip_embedder, **kwargs) else: logging.getLogger(LOGGER_NAME).warning( f'Unknown embedder type: {embedder_type}, defaulting to OpenAI' ) return OpenAIModelService(model_kwargs=model_kwargs, **kwargs)