| import threading |
| from typing import Any, Dict, Generator, List |
|
|
| from huggingface_hub import hf_hub_download |
| from llama_cpp import Llama |
|
|
| from src.core.config import settings |
|
|
|
|
class ModelEngine:
    """Owns a single llama.cpp model and serializes chat generation on it.

    The model file is resolved via ``hf_hub_download`` and loaded eagerly in
    ``__init__``. Loading is best-effort: on failure the error is printed and
    recorded, the engine is still constructed, and every later
    ``generate_stream`` call raises ``RuntimeError`` chained to that load error.
    """

    def __init__(self):
        self.llm = None
        # Exception from the last failed load attempt (or None). Kept so that
        # "Model not loaded" errors carry the real reason as __cause__.
        self._load_error = None
        # Held for the whole lifetime of a stream so only one generation runs
        # on self.llm at a time (presumably because the Llama instance is not
        # safe for concurrent use — confirm against llama-cpp-python docs).
        self.lock = threading.Lock()
        self._load_model()

    def _load_model(self):
        """Resolve the model file from ``settings`` and construct ``self.llm``.

        Never raises: any exception is printed and stored in
        ``self._load_error``, leaving ``self.llm`` as ``None``.
        """
        try:
            print(f"Downloading/Loading model: {settings.REPO_ID}...")
            model_path = hf_hub_download(
                repo_id=settings.REPO_ID, filename=settings.FILENAME
            )
            self.llm = Llama(
                model_path=model_path,
                n_ctx=settings.CONTEXT_SIZE,
                n_threads=settings.N_THREADS,
                n_gpu_layers=settings.N_GPU_LAYERS,
                verbose=True,
            )
        except Exception as e:
            # Deliberately best-effort: keep the process alive and surface the
            # failure on first use (see generate_stream) instead of at import.
            self._load_error = e
            print(f"CRITICAL ERROR loading model: {e}")
        else:
            self._load_error = None
            print("Model loaded successfully!")

    def generate_stream(
        self, messages: List[Dict[str, str]], max_tokens: int, temperature: float
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream chat-completion chunks for ``messages``.

        This is a generator function: nothing runs, including the
        loaded-model check, until the first ``next()``. The engine lock is
        held from the start of generation until the stream is exhausted,
        raises, or this generator is closed; concurrent callers wait.

        Args:
            messages: Chat messages passed unchanged to
                ``Llama.create_chat_completion``.
            max_tokens: Generation limit; coerced with ``int()``.
            temperature: Sampling temperature; coerced with ``float()``.

        Yields:
            Each chunk produced by ``create_chat_completion(stream=True)``.

        Raises:
            RuntimeError: If the model is not loaded; chained to the original
                load exception when one was recorded.
        """
        if self.llm is None:
            raise RuntimeError("Model not loaded") from self._load_error

        with self.lock:
            stream = self.llm.create_chat_completion(
                messages=messages,
                max_tokens=int(max_tokens),
                temperature=float(temperature),
                stream=True,
            )
            # Unlike a plain for-loop, ``yield from`` forwards close()/throw()
            # to the underlying stream. An abandoned request therefore shuts
            # down llama's generator while the lock is still held, instead of
            # leaving it to be finalized later by the GC, outside the lock.
            yield from stream
|
|
|
|
# Process-wide shared engine. Constructing it runs ModelEngine.__init__, which
# calls _load_model -> hf_hub_download + Llama(...), so importing this module
# performs the model fetch/load (or records the failure) as a side effect.
engine: ModelEngine = ModelEngine()
|
|