File size: 6,118 Bytes
bdca525
 
 
 
b4ecb60
5160420
bdca525
b4ecb60
bdca525
 
 
 
 
 
 
 
 
 
 
b4ecb60
5160420
5f43529
 
 
b4ecb60
0304bfe
5f43529
3a14fb3
5f43529
 
 
 
b4ecb60
bdca525
3a14fb3
a251128
b4ecb60
5160420
b4ecb60
 
5160420
0304bfe
b4ecb60
bdca525
5160420
b4ecb60
bdca525
b4ecb60
bdca525
 
d24a753
5160420
e9e9e0c
d99243b
 
 
 
 
e9e9e0c
 
bdca525
8cc5c82
5f43529
5160420
bdca525
 
 
 
5f43529
 
 
 
 
 
bdca525
 
 
 
 
d24a753
bdca525
 
 
 
 
d24a753
bdca525
d24a753
 
 
 
8cc5c82
 
 
 
 
 
d24a753
8cc5c82
 
 
5f43529
0304bfe
e9e9e0c
 
8cc5c82
 
0304bfe
 
8cc5c82
 
28295c6
a251128
e96d38d
8cc5c82
5f43529
e9e9e0c
 
5f43529
 
b4ecb60
5160420
d24a753
5f43529
 
d24a753
b4ecb60
5f43529
5160420
b4ecb60
bdca525
 
 
5f43529
 
 
5160420
d24a753
 
b4ecb60
bdca525
b4ecb60
5160420
d24a753
 
 
 
e9e9e0c
 
 
 
 
d24a753
bdca525
b709bb5
d24a753
3328c8a
d24a753
 
 
3328c8a
b709bb5
3328c8a
 
b4ecb60
bdca525
 
 
 
 
 
8cc5c82
5160420
b4ecb60
bdca525
 
 
 
 
 
d24a753
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
from typing import Dict
from llama_cpp import Llama
from huggingface_hub import hf_hub_download


class LlamaCppGemmaModel:
    """
    Gemma chat model backed by llama.cpp.

    Replicates the API of the original HuggingFaceGemmaModel but performs
    inference through llama-cpp-python. Handles model selection, on-demand
    GGUF download from Hugging Face, loading (with a process-wide cache),
    and streaming chat-style text generation.

    Available models (keys of AVAILABLE_MODELS; repo_id and filename must
    match the GGUF file published on Hugging Face):
    - gemma-3b: 3B parameters, instruction-tuned (Q5_K_M quantization)
    - gemma-2b: 2B parameters, instruction-tuned (Q4_K_M quantization)

    All model files are stored under the "models/" directory.
    """

    # Process-wide cache of loaded Llama instances, keyed by
    # f"{name}_{n_ctx}_{n_gpu_layers}" so different configurations coexist.
    _model_cache: Dict[str, object] = {}

    AVAILABLE_MODELS: Dict[str, Dict] = {
        "gemma-3b": {
            "model_path": "models/gemma-3-1b-it-Q5_K_M.gguf",
            "repo_id": "bartowski/google_gemma-3-1b-it-GGUF",
            "filename": "google_gemma-3-1b-it-Q5_K_M.gguf",  # Better quantization
            "description": "3B parameters, instruction-tuned (Q5_K_M)",
            "type": "instruct",
        },
        "gemma-2b": {
            "model_path": "models/gemma-2b-it.gguf",
            "repo_id": "MaziyarPanahi/gemma-2b-it-GGUF",
            "filename": "gemma-2b-it.Q4_K_M.gguf",
            "description": "2B parameters, instruction-tuned",
            "type": "instruct",
        },
    }

    def __init__(self, name: str = "gemma-3b"):
        """
        Initialize the model instance.

        Args:
            name (str): The model name (should match one of the
                AVAILABLE_MODELS keys).
        """
        self.name = name
        self.model = None  # Instance of Llama from llama.cpp (set by load_model)
        self.messages = []  # Chat history as {"role": ..., "content": ...} dicts

        # Sampling parameters used by generate_response.
        self.max_tokens = 512
        self.temperature = 0.7
        self.top_p = 0.95
        self.top_k = 40
        self.repeat_penalty = 1.1

    def _seed_system_prompt(self, system_prompt: str) -> None:
        """Prepend the system message once, if the history has none yet.

        Guarding on an existing system turn keeps repeated load_model calls
        from stacking duplicate system prompts.
        """
        if not any(m.get("role") == "system" for m in self.messages):
            self.messages.insert(0, {"role": "system", "content": system_prompt})

    def _download_model(self, model_info: Dict, model_path: str) -> None:
        """
        Fetch the model's GGUF file from Hugging Face into model_path.

        Args:
            model_info (Dict): Entry from AVAILABLE_MODELS with repo_id/filename.
            model_path (str): Local destination path for the GGUF file.

        Raises:
            ValueError: If repo_id or filename is missing from model_info.
        """
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        repo_id = model_info.get("repo_id")
        filename = model_info.get("filename")

        if repo_id is None or filename is None:
            raise ValueError(
                "Repository ID or filename is missing for model download."
            )

        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=os.path.dirname(model_path),
            local_dir_use_symlinks=False,  # keep a real file next to model_path
        )

        if downloaded_path != model_path:
            # Normalize to the configured path. os.replace (unlike os.rename)
            # also overwrites an existing destination on Windows.
            os.replace(downloaded_path, model_path)

    def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0, system_prompt=""):
        """
        Load the model, downloading the GGUF file first if it is missing.

        Loaded Llama instances are cached per (name, n_ctx, n_gpu_layers),
        so repeated calls with the same configuration reuse the instance.

        Args:
            n_ctx (int): Context window size.
            n_gpu_layers (int): Number of layers to offload to GPU
                (if supported; 0 for CPU-only).
            system_prompt (str): System message seeded into the chat history.

        Returns:
            LlamaCppGemmaModel: self, to allow call chaining.

        Raises:
            ValueError: If the model name is unknown, or download metadata
                (repo_id/filename) is missing.
        """
        cache_key = f"{self.name}_{n_ctx}_{n_gpu_layers}"
        cached = LlamaCppGemmaModel._model_cache.get(cache_key)
        if cached is not None:
            self.model = cached
            # BUG FIX: the cached path previously returned without seeding the
            # chat history, so cache-hit conversations had no system prompt.
            self._seed_system_prompt(system_prompt)
            return self

        model_info = self.AVAILABLE_MODELS.get(self.name)
        if not model_info:
            raise ValueError(f"Model {self.name} is not available.")

        model_path = model_info["model_path"]

        # If the model file doesn't exist, download it.
        if not os.path.exists(model_path):
            self._download_model(model_info, model_path)

        # Capped at 2 threads — presumably targets small/shared CPU hosts.
        _threads = min(2, os.cpu_count() or 1)

        self.model = Llama(
            model_path=model_path,
            n_threads=_threads,
            n_threads_batch=_threads,
            n_ctx=n_ctx,
            n_gpu_layers=n_gpu_layers,
            n_batch=8,
            verbose=False,
            chat_format="gemma",
        )

        self._seed_system_prompt(system_prompt)

        # Cache the model for future use.
        LlamaCppGemmaModel._model_cache[cache_key] = self.model
        return self

    def generate_response(
        self,
        prompt: str,
    ):
        """
        Stream a chat response for *prompt*, updating the conversation history.

        Sampling uses the instance attributes set in __init__: max_tokens,
        temperature, top_p, top_k and repeat_penalty.

        Args:
            prompt (str): User message appended to the chat history.

        Yields:
            str: The cumulative response text so far (grows with each chunk).
        """
        if self.model is None:
            self.load_model()

        self.messages.append({"role": "user", "content": prompt})

        response_stream = self.model.create_chat_completion(
            messages=self.messages,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            repeat_penalty=self.repeat_penalty,
            stream=True,
        )

        outputs = ""
        for chunk in response_stream:
            delta = chunk["choices"][0]["delta"]
            if "content" in delta:
                outputs += delta["content"]
                yield outputs

        # BUG FIX: record the assistant turn only after streaming completes,
        # so an exception mid-stream no longer leaves an empty assistant
        # message dangling in the history.
        self.messages.append({"role": "assistant", "content": outputs})

    def get_model_info(self) -> Dict:
        """
        Return information about the model.

        Returns:
            Dict: A dictionary containing the model name and load status.
        """
        return {"name": self.name, "loaded": self.model is not None}

    def get_model_name(self) -> str:
        """
        Return the name of the model.

        Returns:
            str: Model name.
        """
        return self.name