from abc import ABC, abstractmethod
import logging
import os

import httpx
from dotenv import load_dotenv
from openai import OpenAI, AsyncOpenAI
from sentence_transformers import SentenceTransformer
from tenacity import retry, stop_after_attempt, wait_fixed

# torch is optional: it is only used for GPU detection and half-precision
# support in the SentenceTransformers service.
try:
    import torch
    _TORCH_AVAILABLE = True
except Exception:
    torch = None
    _TORCH_AVAILABLE = False

from .utils.logger_utils import setup_logger

# Load environment variables from a local .env file (if present) before
# the configuration constants below read them.
load_dotenv()

LOGGER_NAME = "MODEL_SERVICE_LOGGER"

# Chat model endpoint (any OpenAI-compatible server, e.g. vLLM).
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "http://0.0.0.0:8000/v1")
OPENAI_TOKEN = os.getenv("OPENAI_TOKEN", "no-need")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")

# Embedding model endpoint (OpenAI-compatible).
OPENAI_EMBED_BASE_URL = os.getenv("OPENAI_EMBED_BASE_URL", "http://0.0.0.0:8001/v1")
OPENAI_EMBED_TOKEN = os.getenv("OPENAI_EMBED_TOKEN", "no-need")
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME", "Alibaba-NLP/gte-Qwen2-1.5B-instruct")

# Generation parameters.
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 2048))
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.2))
TOP_P = float(os.getenv("TOP_P", 0.95))
FREQUENCY_PENALTY = float(os.getenv("FREQUENCY_PENALTY", 0))
PRESENCE_PENALTY = float(os.getenv("PRESENCE_PENALTY", 0))

# Standalone embedding endpoint settings.
EMBEDDING_MODEL_URL = os.getenv("EMBEDDING_MODEL_URL", "")
EMBEDDING_MODEL_API_KEY = os.getenv("EMBEDDING_MODEL_API_KEY", "no_need")
EMBEDDING_NUMBER_DIMENSIONS = int(os.getenv("EMBEDDING_NUMBER_DIMENSIONS", 1024))

# Retry and timeout settings.
STOP_AFTER_ATTEMPT = int(os.getenv("STOP_AFTER_ATTEMPT", 5))
WAIT_BETWEEN_RETRIES = int(os.getenv("WAIT_BETWEEN_RETRIES", 2))
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", 240))
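# All of the above can be overridden via environment variables, e.g. in a
# .env file picked up by load_dotenv() (illustrative values, not defaults
# from any particular deployment):
#
#   OPENAI_BASE_URL=http://localhost:8000/v1
#   MODEL_NAME=meta-llama/Llama-3.2-3B-Instruct
#   EMBED_MODEL_NAME=Alibaba-NLP/gte-Qwen2-1.5B-instruct
#   REQUEST_TIMEOUT=240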

# Module-level clients with a long timeout; service instances create their
# own clients in __init__ using the (possibly overridden) request timeout.
long_timeout_client = httpx.Client(timeout=REQUEST_TIMEOUT)
long_timeout_async_client = httpx.AsyncClient(timeout=REQUEST_TIMEOUT)


class ModelServiceInterface(ABC):
    """
    Abstract base class defining the interface for model services.
    All model services should implement these methods.
    """

    def __init__(self, model_name: str = None, model_kwargs: dict = None):
        setup_logger(LOGGER_NAME)
        self.logger = logging.getLogger(LOGGER_NAME)

        model_kwargs = model_kwargs or {}

        # Chat endpoint configuration (entries in model_kwargs override
        # the module-level env defaults).
        self.openai_base_url = model_kwargs.get("OPENAI_BASE_URL", OPENAI_BASE_URL)
        self.openai_token = model_kwargs.get("OPENAI_TOKEN", OPENAI_TOKEN)
        self.model_name = model_name or model_kwargs.get("MODEL_NAME", MODEL_NAME)

        # Embedding endpoint configuration.
        self.openai_embed_base_url = model_kwargs.get("OPENAI_EMBED_BASE_URL", OPENAI_EMBED_BASE_URL)
        self.openai_embed_token = model_kwargs.get("OPENAI_EMBED_TOKEN", OPENAI_EMBED_TOKEN)
        self.embed_model_name = model_kwargs.get("EMBED_MODEL_NAME", EMBED_MODEL_NAME)

        # Generation parameters.
        self.max_tokens = int(model_kwargs.get("MAX_TOKENS", MAX_TOKENS))
        self.temperature = float(model_kwargs.get("TEMPERATURE", TEMPERATURE))
        self.top_p = float(model_kwargs.get("TOP_P", TOP_P))
        self.frequency_penalty = float(model_kwargs.get("FREQUENCY_PENALTY", FREQUENCY_PENALTY))
        self.presence_penalty = float(model_kwargs.get("PRESENCE_PENALTY", PRESENCE_PENALTY))
        self.embedding_model_url = model_kwargs.get("EMBEDDING_MODEL_URL", EMBEDDING_MODEL_URL)
        self.embedding_model_api_key = model_kwargs.get("EMBEDDING_MODEL_API_KEY", EMBEDDING_MODEL_API_KEY)
        self.embedding_number_dimensions = int(model_kwargs.get("EMBEDDING_NUMBER_DIMENSIONS", EMBEDDING_NUMBER_DIMENSIONS))

        # Retry and timeout settings.
        self.stop_after_attempt = int(model_kwargs.get("STOP_AFTER_ATTEMPT", STOP_AFTER_ATTEMPT))
        self.wait_between_retries = int(model_kwargs.get("WAIT_BETWEEN_RETRIES", WAIT_BETWEEN_RETRIES))
        request_timeout = int(model_kwargs.get("REQUEST_TIMEOUT", REQUEST_TIMEOUT))

        # Per-instance HTTP clients with a long timeout for slow model servers.
        self.long_timeout_client = httpx.Client(timeout=request_timeout)
        self.long_timeout_async_client = httpx.AsyncClient(timeout=request_timeout)

        self.client = OpenAI(
            base_url=self.openai_base_url,
            api_key=self.openai_token,
            http_client=self.long_timeout_client,
        )
        self.async_client = AsyncOpenAI(
            base_url=self.openai_base_url,
            api_key=self.openai_token,
            http_client=self.long_timeout_async_client,
        )

    # Note: the retry policy reads the module-level env settings because
    # tenacity decorators are evaluated at class-definition time.
    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    def query(self, prompt: str, model_name: str = None) -> str:
        """Query the model with a prompt."""
        if model_name is None:
            model_name = self.model_name
        completion = self.client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "user", "content": prompt}
            ],
            # Apply the configured generation parameters.
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        return completion.choices[0].message.content

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    def query_with_instructions(self, prompt: str, instructions: str, model_name: str = None) -> str:
        """Query the model with additional system instructions."""
        if model_name is None:
            model_name = self.model_name
        completion = self.client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": instructions},
                {"role": "user", "content": prompt}
            ],
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        return completion.choices[0].message.content

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    async def query_async(self, prompt: str, model_name: str = None) -> str:
        """Async version of query."""
        if model_name is None:
            model_name = self.model_name
        completion = await self.async_client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "user", "content": prompt}
            ],
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        return completion.choices[0].message.content

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    async def query_with_instructions_async(self, prompt: str, instructions: str, model_name: str = None) -> str:
        """Async version of query with instructions."""
        if model_name is None:
            model_name = self.model_name
        completion = await self.async_client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": instructions},
                {"role": "user", "content": prompt}
            ],
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        return completion.choices[0].message.content

    @abstractmethod
    def embed(self, text_to_embed: str) -> list:
        """Embed text using the configured embedding model."""
        pass

    @abstractmethod
    async def embed_async(self, text_to_embed: str) -> list:
        """Async version of embed."""
        pass

    @abstractmethod
    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed code chunk for storage/indexing."""
        pass

    @abstractmethod
    def embed_query(self, query_to_embed: str) -> list:
        """Embed query for retrieval."""
        pass

    @abstractmethod
    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """Embed multiple texts in a batch for better performance."""
        pass

    @abstractmethod
    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch for storage/indexing."""
        pass

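# A subclass only needs to provide the embedding methods; querying is
# inherited from the base class. A minimal sketch (hypothetical class,
# returns zero vectors, for illustration only):
#
#   class NullEmbedModelService(ModelServiceInterface):
#       def embed(self, text_to_embed): return [0.0] * 1024
#       async def embed_async(self, text_to_embed): return self.embed(text_to_embed)
#       def embed_chunk_code(self, code_to_embed): return self.embed(code_to_embed)
#       def embed_query(self, query_to_embed): return self.embed(query_to_embed)
#       def embed_batch(self, texts): return [self.embed(t) for t in texts]
#       def embed_chunk_code_batch(self, codes): return self.embed_batch(codes)
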
class OpenAIModelService(ModelServiceInterface):
    """
    Model service that uses the OpenAI client for both queries and embeddings.
    """

    def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None):
        super().__init__(model_name=model_name, model_kwargs=model_kwargs)

        model_kwargs = model_kwargs or {}
        self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name)

        # Dedicated clients for the embedding endpoint; the base class has
        # already applied any model_kwargs overrides to these settings.
        self.embed_client = OpenAI(
            base_url=model_kwargs.get("OPENAI_EMBED_BASE_URL", self.openai_embed_base_url),
            api_key=model_kwargs.get("OPENAI_EMBED_TOKEN", self.openai_embed_token),
            http_client=self.long_timeout_client,
        )
        self.async_embed_client = AsyncOpenAI(
            base_url=model_kwargs.get("OPENAI_EMBED_BASE_URL", self.openai_embed_base_url),
            api_key=model_kwargs.get("OPENAI_EMBED_TOKEN", self.openai_embed_token),
            http_client=self.long_timeout_async_client,
        )

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    def embed(self, text_to_embed: str) -> list:
        """Embed text using the OpenAI embeddings API."""
        response = self.embed_client.embeddings.create(
            input=text_to_embed,
            model=self.embed_model_name,
        )
        return response.data[0].embedding

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    async def embed_async(self, text_to_embed: str) -> list:
        """Async version of embed using the OpenAI embeddings API."""
        response = await self.async_embed_client.embeddings.create(
            input=text_to_embed,
            model=self.embed_model_name,
        )
        return response.data[0].embedding

    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed a code chunk using the OpenAI embeddings API (same as embed)."""
        return self.embed(code_to_embed)

    def embed_query(self, query_to_embed: str) -> list:
        """Embed a query using the OpenAI embeddings API (same as embed)."""
        return self.embed(query_to_embed)

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """Embed multiple texts in a batch using the OpenAI embeddings API."""
        if not texts_to_embed:
            return []
        response = self.embed_client.embeddings.create(
            input=texts_to_embed,
            model=self.embed_model_name,
        )
        return [item.embedding for item in response.data]

    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch using the OpenAI embeddings API."""
        return self.embed_batch(codes_to_embed)


class SentenceTransformersModelService(ModelServiceInterface):
    """
    Model service that uses the OpenAI client for queries and
    SentenceTransformers for embeddings. Optimized for high-throughput
    batch embedding with GPU support.
    """

    def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None, skip_embedder: bool = False):
        super().__init__(model_name=model_name, model_kwargs=model_kwargs)
        model_kwargs = model_kwargs or {}

        self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name)
        self.skip_embedder = skip_embedder
        self.embedding_model = None

        if skip_embedder:
            self.logger.info('Skipping embedder initialization (keyword-only mode)')
            self.device = "cpu"
            self.encode_batch_size = 32
            return

        # Log the GPU environment before choosing a device.
        self.logger.info(f'PyTorch available: {_TORCH_AVAILABLE}')
        if _TORCH_AVAILABLE:
            self.logger.info(f'CUDA available: {torch.cuda.is_available()}')
            self.logger.info(f'CUDA device count: {torch.cuda.device_count()}')
            if torch.cuda.is_available():
                self.logger.info(f'CUDA device name: {torch.cuda.get_device_name(0)}')

        self.device = "cuda" if (_TORCH_AVAILABLE and torch.cuda.is_available()) else "cpu"
        self.logger.info(f'Initializing SentenceTransformer on device: {self.device}')

        # Default batch size: larger on GPU, scaled up further below when
        # more GPU memory is available.
        self.encode_batch_size = int(model_kwargs.get("ENCODE_BATCH_SIZE", 64 if self.device == "cuda" else 32))

        if self.device == "cuda" and _TORCH_AVAILABLE:
            try:
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                self.logger.info(f'GPU memory available: {gpu_memory:.2f} GB')
                if gpu_memory > 16:
                    self.encode_batch_size = max(self.encode_batch_size, 128)
                elif gpu_memory > 8:
                    self.encode_batch_size = max(self.encode_batch_size, 64)
            except Exception as e:
                self.logger.warning(f'Could not get GPU memory info: {e}')

        self.logger.info(f'Using encode batch size: {self.encode_batch_size}')

        self.embedding_model = SentenceTransformer(
            self.embed_model_name,
            trust_remote_code=True,
            device=self.device
        )

        # Half precision speeds up GPU inference at negligible quality cost.
        if self.device == "cuda" and _TORCH_AVAILABLE:
            try:
                self.embedding_model.half()
                self.logger.info('Enabled half precision (FP16) for faster GPU inference')
            except Exception as e:
                self.logger.warning(f'Could not enable half precision: {e}')

    def _check_embedder(self):
        """Check if the embedder is available; raise an error if not."""
        if self.skip_embedder or self.embedding_model is None:
            raise RuntimeError(
                "Embedding model not initialized. This model service was created with skip_embedder=True "
                "(keyword-only mode). To use embeddings, set index_type to 'hybrid' or 'embedding-only'."
            )

    def embed(self, text_to_embed: str) -> list:
        """Embed text using SentenceTransformers."""
        self._check_embedder()
        embeddings = self.embedding_model.encode(
            [text_to_embed],
            convert_to_numpy=True,
            show_progress_bar=False
        )
        return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])

    async def embed_async(self, text_to_embed: str) -> list:
        """
        Async version of embed using SentenceTransformers.
        Note: SentenceTransformers has no native async support, so this
        runs synchronously but maintains the async interface.
        """
        return self.embed(text_to_embed)

    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed a code chunk using SentenceTransformers (no special prompt)."""
        self._check_embedder()
        self.logger.debug(f'Embedding code using {self.embed_model_name}')
        embeddings = self.embedding_model.encode(
            [code_to_embed],
            convert_to_numpy=True,
            show_progress_bar=False
        )
        return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])

    def embed_query(self, query_to_embed: str) -> list:
        """Embed a query using SentenceTransformers with a retrieval prompt."""
        self._check_embedder()
        self.logger.debug(f'Embedding query using {self.embed_model_name}')
        embeddings = self.embedding_model.encode(
            [query_to_embed],
            prompt='Given this prompt, retrieve relevant content\n Query:',
            convert_to_numpy=True,
            show_progress_bar=False
        )
        return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])

    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """Embed multiple texts in a batch using SentenceTransformers with optimized settings."""
        if not texts_to_embed:
            return []
        self._check_embedder()
        self.logger.info(f'Batch embedding {len(texts_to_embed)} texts using {self.embed_model_name}')
        embeddings = self.embedding_model.encode(
            texts_to_embed,
            batch_size=self.encode_batch_size,
            convert_to_numpy=True,
            show_progress_bar=len(texts_to_embed) > 100,
            normalize_embeddings=True
        )
        return [emb.tolist() if hasattr(emb, 'tolist') else list(emb) for emb in embeddings]

    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch using SentenceTransformers with optimized settings."""
        if not codes_to_embed:
            return []
        self._check_embedder()
        self.logger.info(f'Batch embedding {len(codes_to_embed)} code chunks using {self.embed_model_name}')
        embeddings = self.embedding_model.encode(
            codes_to_embed,
            batch_size=self.encode_batch_size,
            convert_to_numpy=True,
            show_progress_bar=len(codes_to_embed) > 100,
            normalize_embeddings=True
        )
        return [emb.tolist() if hasattr(emb, 'tolist') else list(emb) for emb in embeddings]
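
# The query/chunk split above follows the asymmetric retrieval pattern:
# queries get a retrieval prompt while indexed code chunks are embedded raw.
# A minimal sketch of the intended pairing (hypothetical inputs; model
# weights are downloaded on first use):
#
#   svc = SentenceTransformersModelService()
#   doc_vec = svc.embed_chunk_code("def add(a, b):\n    return a + b")
#   query_vec = svc.embed_query("function that adds two numbers")
#   # cosine similarity between query_vec and doc_vec ranks this chunk
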

def create_model_service(skip_embedder: bool = False, **kwargs) -> ModelServiceInterface:
    """
    Factory function to create the appropriate ModelService based on embedder_type.

    Args:
        skip_embedder (bool): If True, skip loading the embedding model (for keyword-only search).
        **kwargs: Additional arguments, including 'embedder_type' ('openai' or
            'sentence-transformers') and an optional 'model_kwargs' dict that can
            override any env var defaults.

    Returns:
        ModelServiceInterface: An instance of the appropriate ModelService.
    """
    model_kwargs = kwargs.pop('model_kwargs', None)
    embedder_type = kwargs.pop('embedder_type', 'openai')

    if embedder_type == 'openai':
        return OpenAIModelService(model_kwargs=model_kwargs, **kwargs)
    elif embedder_type == 'sentence-transformers':
        return SentenceTransformersModelService(model_kwargs=model_kwargs, skip_embedder=skip_embedder, **kwargs)
    else:
        logging.getLogger(LOGGER_NAME).warning(
            f'Unknown embedder type: {embedder_type}, defaulting to OpenAI'
        )
        return OpenAIModelService(model_kwargs=model_kwargs, **kwargs)
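
# A minimal usage sketch, assuming OpenAI-compatible chat and embedding
# servers are reachable at the configured base URLs (the prompts below
# are hypothetical):
if __name__ == "__main__":
    service = create_model_service(embedder_type="openai")
    reply = service.query_with_instructions(
        prompt="What does this module do?",
        instructions="You are a concise technical assistant.",
    )
    print(reply)
    vector = service.embed("hello world")
    print(f"Embedding dimension: {len(vector)}")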