File size: 2,350 Bytes
f9aca5d
010db11
 
54f6dae
 
 
 
42fa16e
54f6dae
 
 
 
 
f9aca5d
54f6dae
 
 
 
f9aca5d
42fa16e
 
 
54f6dae
 
42fa16e
 
 
010db11
54f6dae
f9aca5d
54f6dae
f9aca5d
54f6dae
010db11
 
 
 
 
 
54f6dae
42fa16e
f9aca5d
a8a31d7
 
 
 
010db11
 
 
 
 
 
 
 
 
 
 
 
f9aca5d
 
 
 
a8a31d7
f9aca5d
a8a31d7
f9aca5d
010db11
f9aca5d
010db11
 
 
 
f9aca5d
54f6dae
010db11
 
 
54f6dae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import threading
import time
from typing import Any, Dict, Generator, List, Optional

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from src.core.config import settings


class ModelEngine:
    """Load a llama.cpp chat model once and stream completions from it.

    The model file is fetched with ``hf_hub_download`` using
    ``settings.REPO_ID`` / ``settings.FILENAME`` and opened with
    ``llama_cpp.Llama``. A single ``threading.Lock`` is held for the whole
    lifetime of a stream, so only one generation uses ``self.llm`` at a time;
    other callers wait in :meth:`generate_stream` until the lock is free.
    """

    # How often (in seconds) a queued request re-checks its abort event while
    # waiting for the lock.
    _QUEUE_POLL_SECONDS = 0.5

    def __init__(self):
        # Stays None if _load_model fails; generate_stream checks for that.
        self.llm = None
        self.lock = threading.Lock()
        self._load_model()

    def _load_model(self):
        """Fetch the model file via ``hf_hub_download`` and load it with ``Llama``.

        Any exception is printed and swallowed, leaving ``self.llm`` as
        ``None``; every later :meth:`generate_stream` call then raises
        ``RuntimeError("Model not loaded")`` instead of the constructor failing.
        """
        try:
            print(f"Downloading/Loading model: {settings.REPO_ID}...")
            model_path = hf_hub_download(
                repo_id=settings.REPO_ID, filename=settings.FILENAME
            )
            self.llm = Llama(
                model_path=model_path,
                n_ctx=settings.CONTEXT_SIZE,
                n_threads=settings.N_THREADS,
                n_gpu_layers=settings.N_GPU_LAYERS,
                verbose=True,
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"CRITICAL ERROR loading model: {e}")

    def _wait_for_lock(self, abort_event: Optional[threading.Event]) -> bool:
        """Acquire ``self.lock``, polling *abort_event*; return False if aborted first."""
        while True:
            if abort_event is not None and abort_event.is_set():
                return False
            # Timed acquire so an abort is noticed even while another
            # request holds the lock.
            if self.lock.acquire(timeout=self._QUEUE_POLL_SECONDS):
                return True

    def generate_stream(
        self,
        messages: List[Dict[str, str]],
        abort_event: Optional[threading.Event] = None,  # Allows cancelling the request.
        **kwargs,
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream chat-completion chunks for *messages*.

        Args:
            messages: Chat history passed to ``Llama.create_chat_completion``.
            abort_event: Optional event. If it is set while waiting for the
                lock, nothing is generated; if it is set during generation, no
                further chunks are yielded. In both cases the generator simply
                finishes (no exception).
            **kwargs: Optional sampling overrides. ``max_tokens`` (converted
                with ``int``) and ``temperature`` (converted with ``float``)
                default to ``settings.DEFAULT_MAX_TOKENS`` /
                ``settings.DEFAULT_TEMP``; ``stop`` defaults to ``[]`` and
                ``top_p`` to ``0.95``. An explicit ``None`` also selects the
                default.

        Yields:
            Each chunk produced by the llama.cpp stream, unchanged.

        Raises:
            RuntimeError: if the model is not loaded. Because this is a
                generator, it is raised on the first ``next()``, not at call
                time.
            ValueError / TypeError: if ``max_tokens`` or ``temperature`` cannot
                be converted; raised before waiting for the lock.
        """
        if not self.llm:
            raise RuntimeError("Model not loaded")

        # Resolve sampling options up front. ``is None`` (not ``or``) keeps
        # legitimate falsy values such as temperature=0.
        max_tokens = kwargs.get("max_tokens")
        if max_tokens is None:
            max_tokens = settings.DEFAULT_MAX_TOKENS
        temperature = kwargs.get("temperature")
        if temperature is None:
            temperature = settings.DEFAULT_TEMP
        stop = kwargs.get("stop")
        if stop is None:
            stop = []
        top_p = kwargs.get("top_p")
        if top_p is None:
            top_p = 0.95
        max_tokens = int(max_tokens)
        temperature = float(temperature)

        if not self._wait_for_lock(abort_event):
            print("Request aborted while waiting in queue.")
            return

        stream = None
        try:
            if abort_event is not None and abort_event.is_set():
                return

            stream = self.llm.create_chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop,
                stream=True,
                top_p=top_p,
            )

            for chunk in stream:
                if abort_event is not None and abort_event.is_set():
                    print("Request aborted during generation.")
                    break

                yield chunk

        finally:
            try:
                # Finalise the underlying stream while we still own the lock.
                # Otherwise its cleanup would run only when this frame is torn
                # down, after the lock may already belong to another request.
                close = getattr(stream, "close", None)
                if close is not None:
                    close()
            finally:
                self.lock.release()


# Process-wide singleton: constructing it at import time calls
# ModelEngine._load_model, so the model is fetched/loaded once per import.
engine: ModelEngine = ModelEngine()