File size: 19,567 Bytes
a100cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
from abc import ABC, abstractmethod
from openai import OpenAI, AsyncOpenAI
from dotenv import load_dotenv
import os
import logging
from tenacity import retry, stop_after_attempt, wait_fixed
import httpx 
from sentence_transformers import SentenceTransformer

# Optional torch import for CUDA detection
try:
    import torch
    _TORCH_AVAILABLE = True
except Exception:
    torch = None
    _TORCH_AVAILABLE = False

from .utils.logger_utils import setup_logger

# Name of the logger configured (via setup_logger) and used by every model-service instance.
LOGGER_NAME = "MODEL_SERVICE_LOGGER"
# Every setting below is read from the environment once, at import time, and is the
# fallback for the same-named key in the `model_kwargs` dict accepted by
# ModelServiceInterface.__init__. A non-numeric value for any int()/float() setting
# raises ValueError when this module is imported.
# NOTE(review): `load_dotenv` is imported at the top of this module but never called
# here, so these reads only see a .env file if something loaded it before this import
# — confirm that is the intended startup order.
# GENERATION ENV VARIABLES (defaults)
# OpenAI-compatible chat endpoint, token and default chat model.
# 'no-need' is presumably a placeholder for servers that do not check the key.
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", 'http://0.0.0.0:8000/v1')
OPENAI_TOKEN = os.getenv("OPENAI_TOKEN", 'no-need')
MODEL_NAME = os.getenv('MODEL_NAME', "meta-llama/Llama-3.2-3B-Instruct")
# EMBED ENV VARIABLES (defaults)
# OpenAI-compatible embedding endpoint and token, plus the embedding model id
# (also used as the SentenceTransformer model id by the local embedder).
OPENAI_EMBED_BASE_URL = os.getenv("OPENAI_EMBED_BASE_URL", 'http://0.0.0.0:8001/v1')
OPENAI_EMBED_TOKEN = os.getenv("OPENAI_EMBED_TOKEN", 'no-need')
EMBED_MODEL_NAME = os.getenv('EMBED_MODEL_NAME', "Alibaba-NLP/gte-Qwen2-1.5B-instruct")

# Sampling parameters and alternative embedding-endpoint settings. They are copied onto
# each service instance (overridable via model_kwargs); whether any request uses them
# is decided by the service classes, not here.
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 2048))
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.2))
TOP_P = float(os.getenv("TOP_P", 0.95))
FREQUENCY_PENALTY = float(os.getenv("FREQUENCY_PENALTY", 0))
PRESENCE_PENALTY = float(os.getenv("PRESENCE_PENALTY", 0))
EMBEDDING_MODEL_URL = os.getenv("EMBEDDING_MODEL_URL", "")
EMBEDDING_MODEL_API_KEY = os.getenv("EMBEDDING_MODEL_API_KEY", "no_need")
EMBEDDING_NUMBER_DIMENSIONS = int(os.getenv("EMBEDDING_NUMBER_DIMENSIONS", 1024))

# Retry policy (attempt count, fixed wait in seconds) and HTTP timeout in seconds.
STOP_AFTER_ATTEMPT = int(os.getenv("STOP_AFTER_ATTEMPT", 5))
WAIT_BETWEEN_RETRIES = int(os.getenv("WAIT_BETWEEN_RETRIES", 2))
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", 240))

# Note: these module-level clients are kept only for backward compatibility with code
# that imports them. Service instances always create their own httpx clients, using the
# (possibly overridden) per-instance REQUEST_TIMEOUT, and do not use these.
long_timeout_client = httpx.Client(timeout=REQUEST_TIMEOUT)
long_timeout_async_client = httpx.AsyncClient(timeout=REQUEST_TIMEOUT)


class ModelServiceInterface(ABC):
    """
    Abstract base class defining the interface for model services.
    All model services should implement these methods.

    The base class owns the OpenAI-compatible chat clients (sync and async)
    and the per-instance configuration; subclasses provide the embedding
    methods. Each setting is resolved as: explicit constructor argument
    (where one exists), then ``model_kwargs[<SETTING_NAME>]``, then the
    module-level default read from the environment.
    """

    # accept model_kwargs so variables can be overridden at runtime
    def __init__(self, model_name: str = None, model_kwargs: dict = None):
        """
        Configure logging, resolve settings and create the chat clients.

        Args:
            model_name: Default chat model for the query methods. When truthy it
                takes precedence over ``model_kwargs["MODEL_NAME"]``.
            model_kwargs: Optional overrides keyed by the module-level setting
                names (``"OPENAI_BASE_URL"``, ``"STOP_AFTER_ATTEMPT"``,
                ``"REQUEST_TIMEOUT"``, ...). Numeric settings are coerced with
                ``int``/``float``.
        """
        setup_logger(LOGGER_NAME)
        self.logger = logging.getLogger(LOGGER_NAME)

        model_kwargs = model_kwargs or {}

        # allow overriding via model_kwargs; fall back to module-level defaults
        self.openai_base_url = model_kwargs.get("OPENAI_BASE_URL", OPENAI_BASE_URL)
        self.openai_token = model_kwargs.get("OPENAI_TOKEN", OPENAI_TOKEN)
        # model_name param takes precedence, then model_kwargs then default env
        self.model_name = model_name or model_kwargs.get("MODEL_NAME", MODEL_NAME)

        # embed defaults (may be overridden by subclasses or model_kwargs)
        self.openai_embed_base_url = model_kwargs.get("OPENAI_EMBED_BASE_URL", OPENAI_EMBED_BASE_URL)
        self.openai_embed_token = model_kwargs.get("OPENAI_EMBED_TOKEN", OPENAI_EMBED_TOKEN)
        self.embed_model_name = model_kwargs.get("EMBED_MODEL_NAME", EMBED_MODEL_NAME)

        # other configurable parameters
        # NOTE(review): the sampling parameters and EMBEDDING_MODEL_* settings are
        # stored here but are not sent by the query methods of this class; confirm
        # whether callers or subclasses elsewhere consume them.
        self.max_tokens = int(model_kwargs.get("MAX_TOKENS", MAX_TOKENS))
        self.temperature = float(model_kwargs.get("TEMPERATURE", TEMPERATURE))
        self.top_p = float(model_kwargs.get("TOP_P", TOP_P))
        self.frequency_penalty = float(model_kwargs.get("FREQUENCY_PENALTY", FREQUENCY_PENALTY))
        self.presence_penalty = float(model_kwargs.get("PRESENCE_PENALTY", PRESENCE_PENALTY))
        self.embedding_model_url = model_kwargs.get("EMBEDDING_MODEL_URL", EMBEDDING_MODEL_URL)
        self.embedding_model_api_key = model_kwargs.get("EMBEDDING_MODEL_API_KEY", EMBEDDING_MODEL_API_KEY)
        self.embedding_number_dimensions = int(model_kwargs.get("EMBEDDING_NUMBER_DIMENSIONS", EMBEDDING_NUMBER_DIMENSIONS))

        # Retry settings used by _retry_policy() on every query call.
        self.stop_after_attempt = int(model_kwargs.get("STOP_AFTER_ATTEMPT", STOP_AFTER_ATTEMPT))
        self.wait_between_retries = int(model_kwargs.get("WAIT_BETWEEN_RETRIES", WAIT_BETWEEN_RETRIES))
        request_timeout = int(model_kwargs.get("REQUEST_TIMEOUT", REQUEST_TIMEOUT))

        # create per-instance httpx clients in case REQUEST_TIMEOUT was overridden
        self.long_timeout_client = httpx.Client(timeout=request_timeout)
        self.long_timeout_async_client = httpx.AsyncClient(timeout=request_timeout)

        # Initialize query client (shared by all implementations)
        self.client = OpenAI(
            base_url=self.openai_base_url,
            api_key=self.openai_token,
            http_client=self.long_timeout_client,
        )
        self.async_client = AsyncOpenAI(
            base_url=self.openai_base_url,
            api_key=self.openai_token,
            http_client=self.long_timeout_async_client,
        )

    def _retry_policy(self):
        """
        Return a tenacity ``retry`` decorator built from this instance's settings.

        Built per call, not as a class-level decorator, so that
        ``STOP_AFTER_ATTEMPT`` / ``WAIT_BETWEEN_RETRIES`` overrides passed via
        ``model_kwargs`` actually take effect. tenacity selects its async retry
        loop automatically when the decorated function is a coroutine function.
        """
        return retry(
            stop=stop_after_attempt(self.stop_after_attempt),
            wait=wait_fixed(self.wait_between_retries),
        )

    def _chat(self, messages: list[dict], model_name: str = None) -> str:
        """Send *messages* to the chat endpoint with retries; return the first choice's text."""
        @self._retry_policy()
        def _attempt():
            completion = self.client.chat.completions.create(
                model=self.model_name if model_name is None else model_name,
                messages=messages,
            )
            return completion.choices[0].message.content

        return _attempt()

    async def _chat_async(self, messages: list[dict], model_name: str = None) -> str:
        """Async counterpart of :meth:`_chat` using the async client."""
        @self._retry_policy()
        async def _attempt():
            completion = await self.async_client.chat.completions.create(
                model=self.model_name if model_name is None else model_name,
                messages=messages,
            )
            return completion.choices[0].message.content

        return await _attempt()

    def query(self, prompt: str, model_name: str = None) -> str:
        """
        Query the model with a prompt.

        Args:
            prompt: Content of the single user message.
            model_name: Chat model to use; ``None`` selects ``self.model_name``.

        Returns:
            Text content of the first completion choice.

        Raises:
            Whatever tenacity raises once all attempts fail (``RetryError`` by default).
        """
        return self._chat([{"role": "user", "content": prompt}], model_name)

    def query_with_instructions(self, prompt: str, instructions: str, model_name: str = None) -> str:
        """
        Query the model with additional system instructions.

        Args:
            prompt: Content of the user message.
            instructions: Content of the system message sent before the prompt.
            model_name: Chat model to use; ``None`` selects ``self.model_name``.

        Returns:
            Text content of the first completion choice.
        """
        return self._chat(
            [
                {"role": "system", "content": instructions},
                {"role": "user", "content": prompt},
            ],
            model_name,
        )

    async def query_async(self, prompt: str, model_name: str = None) -> str:
        """Async version of query (same arguments, return value and retries)."""
        return await self._chat_async([{"role": "user", "content": prompt}], model_name)

    async def query_with_instructions_async(self, prompt: str, instructions: str, model_name: str = None) -> str:
        """Async version of query with instructions (same arguments, return value and retries)."""
        return await self._chat_async(
            [
                {"role": "system", "content": instructions},
                {"role": "user", "content": prompt},
            ],
            model_name,
        )

    @abstractmethod
    def embed(self, text_to_embed: str) -> list:
        """Embed text using the configured embedding model; return one vector."""
        pass

    @abstractmethod
    async def embed_async(self, text_to_embed: str) -> list:
        """Async version of embed."""
        pass

    @abstractmethod
    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed code chunk for storage/indexing; return one vector."""
        pass

    @abstractmethod
    def embed_query(self, query_to_embed: str) -> list:
        """Embed query for retrieval; return one vector."""
        pass

    @abstractmethod
    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """Embed multiple texts in a batch; return one vector per input text."""
        pass

    @abstractmethod
    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch for storage/indexing; one vector per chunk."""
        pass


class OpenAIModelService(ModelServiceInterface):
    """
    Model service that uses OpenAI client for both queries and embeddings.

    Chat queries are inherited from ``ModelServiceInterface``. Embeddings are
    requested from the endpoint configured by ``OPENAI_EMBED_BASE_URL`` /
    ``OPENAI_EMBED_TOKEN`` through the instance's long-timeout httpx clients.
    Embedding calls are retried using the instance's ``stop_after_attempt`` /
    ``wait_between_retries`` settings, so ``model_kwargs`` overrides apply.
    """

    def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None):
        """
        Args:
            model_name: Default chat model (see ``ModelServiceInterface``).
            embed_model_name: Embedding model id. When truthy it overrides
                ``model_kwargs["EMBED_MODEL_NAME"]`` and the env default.
            model_kwargs: Optional setting overrides, forwarded to the base class.
        """
        # forward model_kwargs to base so it can set instance-wide config
        super().__init__(model_name=model_name, model_kwargs=model_kwargs)

        # The base initialiser has already resolved the embed model name, base URL
        # and token from model_kwargs (falling back to env defaults). Only the
        # explicit constructor argument can still take precedence here.
        self.embed_model_name = embed_model_name or self.embed_model_name

        # embed client should use the instance-level embed base/token
        self.embed_client = OpenAI(
            base_url=self.openai_embed_base_url,
            api_key=self.openai_embed_token,
            http_client=self.long_timeout_client,
        )
        self.async_embed_client = AsyncOpenAI(
            base_url=self.openai_embed_base_url,
            api_key=self.openai_embed_token,
            http_client=self.long_timeout_async_client,
        )

    def _embed_retry_policy(self):
        """Return a tenacity ``retry`` decorator built from this instance's retry settings."""
        return retry(
            stop=stop_after_attempt(self.stop_after_attempt),
            wait=wait_fixed(self.wait_between_retries),
        )

    def embed(self, text_to_embed: str) -> list:
        """Embed text using OpenAI embeddings API; return the first (only) vector."""
        @self._embed_retry_policy()
        def _attempt():
            response = self.embed_client.embeddings.create(
                input=text_to_embed,
                model=self.embed_model_name,
            )
            return response.data[0].embedding

        return _attempt()

    async def embed_async(self, text_to_embed: str) -> list:
        """Async version of embed using OpenAI embeddings API."""
        @self._embed_retry_policy()
        async def _attempt():
            response = await self.async_embed_client.embeddings.create(
                input=text_to_embed,
                model=self.embed_model_name,
            )
            return response.data[0].embedding

        return await _attempt()

    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed code chunk using OpenAI embeddings API (same as embed)."""
        return self.embed(code_to_embed)

    def embed_query(self, query_to_embed: str) -> list:
        """Embed query using OpenAI embeddings API (same as embed)."""
        return self.embed(query_to_embed)

    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """
        Embed multiple texts in a single OpenAI embeddings API request.

        Returns ``[]`` for empty input without contacting the server; otherwise
        one vector per item of ``response.data``, in the order returned.
        """
        if not texts_to_embed:
            return []

        @self._embed_retry_policy()
        def _attempt():
            response = self.embed_client.embeddings.create(
                input=texts_to_embed,
                model=self.embed_model_name,
            )
            return [item.embedding for item in response.data]

        return _attempt()

    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch using OpenAI embeddings API (same as embed_batch)."""
        return self.embed_batch(codes_to_embed)


class SentenceTransformersModelService(ModelServiceInterface):
    """
    Model service that uses OpenAI client for queries and SentenceTransformers for embeddings.
    Optimized for high-throughput batch embedding with GPU support.

    Chat queries are inherited from ``ModelServiceInterface``. The embedding
    model is loaded locally, on CUDA when PyTorch reports it available and
    otherwise on CPU, unless ``skip_embedder`` is set. In that case the embed
    methods raise ``RuntimeError``; the batch methods still return ``[]`` for
    empty input.

    NOTE(review): the batch methods pass ``normalize_embeddings=True`` while the
    single-text methods do not, so the same text can yield vectors with
    different norms depending on the path used. That is harmless for cosine
    similarity but not for dot-product / L2 search; confirm the index metric
    before mixing the two paths.
    """

    def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None, skip_embedder: bool = False):
        """
        Args:
            model_name: Default chat model (see ``ModelServiceInterface``).
            embed_model_name: SentenceTransformer model id. When truthy it
                overrides ``model_kwargs["EMBED_MODEL_NAME"]``.
            model_kwargs: Optional setting overrides. Besides the base-class keys,
                ``"ENCODE_BATCH_SIZE"`` fixes the encode batch size.
            skip_embedder: Do not load the embedding model (keyword-only mode).
        """
        super().__init__(model_name=model_name, model_kwargs=model_kwargs)
        model_kwargs = model_kwargs or {}
        # embed_model_name may be overridden by model_kwargs
        self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name)
        self.skip_embedder = skip_embedder
        self.embedding_model = None

        if skip_embedder:
            self.logger.info('Skipping embedder initialization (keyword-only mode)')
            self.device = "cpu"
            self.encode_batch_size = 32
            return

        self._log_torch_environment()

        # Select device: prefer CUDA if available
        self.device = "cuda" if (_TORCH_AVAILABLE and torch.cuda.is_available()) else "cpu"
        self.logger.info(f'Initializing SentenceTransformer on device: {self.device}')

        self.encode_batch_size = self._choose_encode_batch_size(model_kwargs)
        self.logger.info(f'Using encode batch size: {self.encode_batch_size}')

        # trust_remote_code=True runs code shipped with the model repository;
        # only point EMBED_MODEL_NAME at models you trust.
        self.embedding_model = SentenceTransformer(
            self.embed_model_name,
            trust_remote_code=True,
            device=self.device
        )
        self._enable_half_precision_if_cuda()

    def _log_torch_environment(self):
        """Log PyTorch/CUDA availability to help diagnose device selection."""
        self.logger.info(f'PyTorch available: {_TORCH_AVAILABLE}')
        if not _TORCH_AVAILABLE:
            return
        self.logger.info(f'CUDA available: {torch.cuda.is_available()}')
        self.logger.info(f'CUDA device count: {torch.cuda.device_count()}')
        if torch.cuda.is_available():
            self.logger.info(f'CUDA device name: {torch.cuda.get_device_name(0)}')

    def _choose_encode_batch_size(self, model_kwargs: dict) -> int:
        """
        Return the encode batch size for ``self.device``.

        An explicit ``ENCODE_BATCH_SIZE`` in *model_kwargs* is used as-is. A caller
        may deliberately pick a small size to fit a memory budget, so it is never
        raised. Otherwise the size starts at 64 on CUDA or 32 on CPU. On GPUs with
        more than 16 GB it is raised to at least 128; above 8 GB, to at least 64.
        """
        override = model_kwargs.get("ENCODE_BATCH_SIZE")
        if override is not None:
            batch_size = int(override)
        else:
            # Larger batch sizes significantly improve GPU throughput
            batch_size = 64 if self.device == "cuda" else 32

        if self.device == "cuda" and _TORCH_AVAILABLE:
            try:
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                self.logger.info(f'GPU memory available: {gpu_memory:.2f} GB')
                if override is None:
                    if gpu_memory > 16:
                        batch_size = max(batch_size, 128)
                    elif gpu_memory > 8:
                        batch_size = max(batch_size, 64)
            except Exception as e:
                self.logger.warning(f'Could not get GPU memory info: {e}')
        return batch_size

    def _enable_half_precision_if_cuda(self):
        """Best-effort switch of a CUDA model to FP16; failures are logged and ignored."""
        if self.device != "cuda" or not _TORCH_AVAILABLE:
            return
        try:
            self.embedding_model.half()
            self.logger.info('Enabled half precision (FP16) for faster GPU inference')
        except Exception as e:
            self.logger.warning(f'Could not enable half precision: {e}')

    def _check_embedder(self):
        """Check if embedder is available, raise error if not."""
        if self.skip_embedder or self.embedding_model is None:
            raise RuntimeError(
                "Embedding model not initialized. This model service was created with skip_embedder=True "
                "(keyword-only mode). To use embeddings, set index_type to 'hybrid' or 'embedding-only'."
            )

    @staticmethod
    def _to_list(vector) -> list:
        """Convert one encoded vector (numpy array or sequence) to a plain list."""
        return vector.tolist() if hasattr(vector, 'tolist') else list(vector)

    def _encode_one(self, text: str, **encode_kwargs) -> list:
        """Encode a single text (unnormalized, no progress bar) and return its vector."""
        embeddings = self.embedding_model.encode(
            [text],
            convert_to_numpy=True,
            show_progress_bar=False,
            **encode_kwargs
        )
        return self._to_list(embeddings[0])

    def _encode_many(self, texts: list[str]) -> list[list]:
        """Encode *texts* in batches of ``encode_batch_size`` with L2-normalized output."""
        embeddings = self.embedding_model.encode(
            texts,
            batch_size=self.encode_batch_size,
            convert_to_numpy=True,
            show_progress_bar=len(texts) > 100,  # Only show progress for large batches
            normalize_embeddings=True  # Normalize for better similarity computation
        )
        return [self._to_list(emb) for emb in embeddings]

    def embed(self, text_to_embed: str) -> list:
        """Embed text using SentenceTransformers."""
        self._check_embedder()
        return self._encode_one(text_to_embed)

    async def embed_async(self, text_to_embed: str) -> list:
        """
        Async version of embed using SentenceTransformers.
        Note: SentenceTransformers doesn't have native async support,
        so this runs synchronously but maintains the async interface.
        The encode call therefore blocks the running event loop until it finishes.
        """
        return self.embed(text_to_embed)

    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed code chunk using SentenceTransformers (no special prompt)."""
        self._check_embedder()
        self.logger.debug(f'Embedding code using {self.embed_model_name}')
        return self._encode_one(code_to_embed)

    def embed_query(self, query_to_embed: str) -> list:
        """Embed query using SentenceTransformers with retrieval prompt."""
        self._check_embedder()
        self.logger.debug(f'Embedding query using {self.embed_model_name}')
        return self._encode_one(
            query_to_embed,
            prompt='Given this prompt, retrieve relevant content\n Query:',
        )

    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """Embed multiple texts in a batch (normalized); ``[]`` for empty input."""
        if not texts_to_embed:
            return []
        self._check_embedder()
        self.logger.info(f'Batch embedding {len(texts_to_embed)} texts using {self.embed_model_name}')
        return self._encode_many(texts_to_embed)

    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch (normalized); ``[]`` for empty input."""
        if not codes_to_embed:
            return []
        self._check_embedder()
        self.logger.info(f'Batch embedding {len(codes_to_embed)} code chunks using {self.embed_model_name}')
        return self._encode_many(codes_to_embed)


def create_model_service(skip_embedder: bool = False, **kwargs) -> ModelServiceInterface:
    """
    Build the model service selected by ``kwargs['embedder_type']``.

    Args:
        skip_embedder (bool): Passed to the sentence-transformers service so it does
            not load its embedding model (keyword-only search). Ignored by the
            OpenAI service.
        **kwargs: May include 'embedder_type' ('openai', the default, or
            'sentence-transformers') and 'model_kwargs', a dict that can override
            any env var defaults. All remaining keywords go straight to the
            selected service's constructor.
    Returns:
        ModelServiceInterface: The constructed service. An unrecognised embedder
        type logs a warning and falls back to the OpenAI service.
    """
    chosen_type = kwargs.pop('embedder_type', 'openai')
    overrides = kwargs.pop('model_kwargs', None)

    if chosen_type == 'sentence-transformers':
        return SentenceTransformersModelService(
            model_kwargs=overrides, skip_embedder=skip_embedder, **kwargs
        )

    if chosen_type != 'openai':
        logging.getLogger(LOGGER_NAME).warning(
            f'Unknown embedder type: {chosen_type}, defaulting to OpenAI'
        )
    return OpenAIModelService(model_kwargs=overrides, **kwargs)