| import threading |
| import time |
| from typing import Any, Dict, Generator, List, Optional |
|
|
| from huggingface_hub import hf_hub_download |
| from llama_cpp import Llama |
|
|
| from src.core.config import settings |
|
|
|
|
class ModelEngine:
    """Own one llama.cpp model and serialize streaming generation on it.

    The model file is fetched with ``hf_hub_download`` and loaded into a
    ``Llama`` instance when the engine is constructed. A single
    ``threading.Lock`` ensures only one ``generate_stream`` call drives
    ``self.llm`` at a time. Other callers wait for the lock in a polling loop
    so they can still react to their abort event while queued.
    """

    def __init__(self):
        # Stays ``None`` if loading fails; generate_stream checks for this.
        self.llm = None
        # Serializes use of ``self.llm`` across concurrent requests.
        self.lock = threading.Lock()
        self._load_model()

    def _load_model(self):
        """Fetch the configured model file and load it with llama.cpp.

        Best-effort: any exception is reported on stdout and swallowed. In that
        case ``self.llm`` stays ``None``, so construction succeeds and
        ``generate_stream`` raises ``RuntimeError("Model not loaded")`` instead.
        """
        try:
            print(f"Downloading/Loading model: {settings.REPO_ID}...")
            model_path = hf_hub_download(
                repo_id=settings.REPO_ID, filename=settings.FILENAME
            )
            self.llm = Llama(
                model_path=model_path,
                n_ctx=settings.CONTEXT_SIZE,
                n_threads=settings.N_THREADS,
                n_gpu_layers=settings.N_GPU_LAYERS,
                verbose=True,
            )
            print("Model loaded successfully!")
        except Exception as e:
            # Deliberately non-fatal: see docstring.
            print(f"CRITICAL ERROR loading model: {e}")

    def generate_stream(
        self,
        messages: List[Dict[str, str]],
        abort_event: Optional[threading.Event] = None,
        **kwargs,
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream chat-completion chunks for *messages*, one request at a time.

        Args:
            messages: Chat messages forwarded to
                ``Llama.create_chat_completion``.
            abort_event: Optional event. If it is set while this request waits
                for the model lock, or between generated chunks, generation
                stops early.
            **kwargs: Optional overrides ``max_tokens``, ``temperature``,
                ``top_p`` and ``stop``. A value of ``None`` is treated like a
                missing key and falls back to the default.

        Yields:
            The chunks produced by the llama.cpp stream, unchanged.

        Raises:
            RuntimeError: If the model failed to load. This is a generator, so
                the error surfaces on the first ``next()`` call.
            TypeError / ValueError: If ``max_tokens`` or ``temperature`` cannot
                be converted to ``int`` / ``float``. This is raised before the
                request queues for the lock.
        """
        if self.llm is None:
            raise RuntimeError("Model not loaded")

        # An explicit None (e.g. unset optional request fields passed through
        # **kwargs) must not reach int()/float() or llama.cpp.
        max_tokens = kwargs.get("max_tokens")
        if max_tokens is None:
            max_tokens = settings.DEFAULT_MAX_TOKENS
        temperature = kwargs.get("temperature")
        if temperature is None:
            temperature = settings.DEFAULT_TEMP
        top_p = kwargs.get("top_p")
        if top_p is None:
            top_p = 0.95
        stop = kwargs.get("stop")
        if stop is None:
            stop = []

        # Convert before queueing so invalid input fails without waiting.
        max_tokens = int(max_tokens)
        temperature = float(temperature)

        # Poll with a timeout instead of blocking forever, so a queued request
        # can give up as soon as its abort_event is set.
        acquired = False
        while not acquired:
            if abort_event is not None and abort_event.is_set():
                print("Request aborted while waiting in queue.")
                return
            acquired = self.lock.acquire(timeout=0.5)

        stream = None
        try:
            if abort_event is not None and abort_event.is_set():
                return

            stream = self.llm.create_chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop,
                stream=True,
                top_p=top_p,
            )

            for chunk in stream:
                if abort_event is not None and abort_event.is_set():
                    print("Request aborted during generation.")
                    break
                yield chunk
        finally:
            try:
                # Close the llama.cpp stream while we still hold the lock.
                # Otherwise it is only finalized when garbage-collected, which
                # may happen after the next request has started using self.llm.
                close = getattr(stream, "close", None)
                if close is not None:
                    close()
            finally:
                self.lock.release()
|
|
|
|
# Module-wide shared engine. Constructing it runs ModelEngine's model
# download/load step as soon as this module is imported.
engine: ModelEngine = ModelEngine()
|
|