File size: 2,350 Bytes
f9aca5d 010db11 54f6dae 42fa16e 54f6dae f9aca5d 54f6dae f9aca5d 42fa16e 54f6dae 42fa16e 010db11 54f6dae f9aca5d 54f6dae f9aca5d 54f6dae 010db11 54f6dae 42fa16e f9aca5d a8a31d7 010db11 f9aca5d a8a31d7 f9aca5d a8a31d7 f9aca5d 010db11 f9aca5d 010db11 f9aca5d 54f6dae 010db11 54f6dae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | import threading
import time
from typing import Any, Dict, Generator, List, Optional
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from src.core.config import settings
class ModelEngine:
    """Wrapper around a single ``Llama`` model with serialized streaming.

    The model file is resolved with ``hf_hub_download`` and loaded into a
    ``Llama`` object when the engine is constructed.  Streaming generation
    is guarded by ``self.lock`` so only one request drives the model at a
    time; other requests wait for the lock and can be cancelled meanwhile.
    """

    def __init__(self):
        """Create the lock and try to load the model right away."""
        # Stays None when loading fails; generate_stream() checks for that.
        self.llm = None
        # Held for the whole duration of one streaming generation.
        self.lock = threading.Lock()
        self._load_model()

    def _load_model(self):
        """Resolve the model file and initialise ``self.llm``.

        Repository, file name, context size, thread count and GPU layer
        count are all read from ``settings``.  Any exception is printed and
        swallowed, which leaves ``self.llm`` as ``None``.
        """
        try:
            print(f"Downloading/Loading model: {settings.REPO_ID}...")
            model_path = hf_hub_download(
                repo_id=settings.REPO_ID, filename=settings.FILENAME
            )
            self.llm = Llama(
                model_path=model_path,
                n_ctx=settings.CONTEXT_SIZE,
                n_threads=settings.N_THREADS,
                n_gpu_layers=settings.N_GPU_LAYERS,
                verbose=True,
            )
            print("Model loaded successfully!")
        except Exception as e:
            # Deliberate best-effort: construction must not raise; the
            # failure surfaces later as RuntimeError in generate_stream().
            print(f"CRITICAL ERROR loading model: {e}")

    def generate_stream(
        self,
        messages: List[Dict[str, str]],
        abort_event: Optional[threading.Event] = None,  # new argument
        **kwargs,
    ) -> Generator:
        """Yield chat-completion chunks for ``messages`` as they are produced.

        Args:
            messages: Chat messages, forwarded unchanged to
                ``create_chat_completion``.
            abort_event: Optional event.  Once it is set the request stops
                waiting for the lock, or stops yielding chunks if generation
                has already started.
            **kwargs: Optional sampling settings ``max_tokens``,
                ``temperature``, ``stop`` and ``top_p``.  A missing key or an
                explicit ``None`` selects the default
                (``settings.DEFAULT_MAX_TOKENS``, ``settings.DEFAULT_TEMP``,
                ``[]`` and ``0.95`` respectively).

        Yields:
            Chunks exactly as returned by the model's streaming call.

        Raises:
            RuntimeError: If the model is not loaded.  Because this is a
                generator, it is raised when iteration starts.
        """
        if not self.llm:
            raise RuntimeError("Model not loaded")

        # An explicit None means "not provided".  Passing it on to
        # int()/float() would raise TypeError, so fall back to the defaults.
        max_tokens = kwargs.get("max_tokens")
        if max_tokens is None:
            max_tokens = settings.DEFAULT_MAX_TOKENS
        temperature = kwargs.get("temperature")
        if temperature is None:
            temperature = settings.DEFAULT_TEMP
        stop = kwargs.get("stop")
        if stop is None:
            stop = []
        top_p = kwargs.get("top_p")
        if top_p is None:
            top_p = 0.95

        # Wait for the lock in short slices so that a queued request can
        # notice an abort instead of blocking until the lock is free.
        while True:
            if abort_event and abort_event.is_set():
                print("Request aborted while waiting in queue.")
                return
            if self.lock.acquire(timeout=0.5):
                break

        try:
            # The abort may have arrived during the final wait slice.
            if abort_event and abort_event.is_set():
                return
            stream = self.llm.create_chat_completion(
                messages=messages,
                max_tokens=int(max_tokens),
                temperature=float(temperature),
                stop=stop,
                stream=True,
                top_p=top_p,
            )
            for chunk in stream:
                if abort_event and abort_event.is_set():
                    print("Request aborted during generation.")
                    break
                yield chunk
        finally:
            # Runs on normal completion, abort, error, or generator close.
            self.lock.release()
# Module-level engine instance. Constructing it runs ModelEngine.__init__,
# which calls _load_model(), so importing this module triggers the model
# download/load.
engine = ModelEngine()
|