import os
from typing import Dict
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
class LlamaCppGemmaModel:
    """
    Gemma chat model backed by llama.cpp.

    Replicates the API of the original HuggingFaceGemmaModel but uses
    llama.cpp for inference. Handles model selection, on-demand download
    of the GGUF file, loading (with a process-wide model cache), and
    streaming chat generation.

    Available models (repo_id and filename must match the GGUF file on
    Hugging Face):
      - gemma-3b: 3B-class, instruction-tuned (Q5_K_M quantization)
      - gemma-2b: 2B parameters, instruction-tuned (Q4_K_M quantization)

    All model files are stored in the "models/" directory.
    """

    # Process-wide cache of loaded Llama instances, keyed by
    # "<name>_<n_ctx>_<n_gpu_layers>" so different configurations of the
    # same model do not collide.
    _model_cache: Dict[str, object] = {}

    AVAILABLE_MODELS: Dict[str, Dict] = {
        "gemma-3b": {
            "model_path": "models/gemma-3-1b-it-Q5_K_M.gguf",
            "repo_id": "bartowski/google_gemma-3-1b-it-GGUF",
            "filename": "google_gemma-3-1b-it-Q5_K_M.gguf",  # Better quantization
            "description": "3B parameters, instruction-tuned (Q5_K_M)",
            "type": "instruct",
        },
        "gemma-2b": {
            "model_path": "models/gemma-2b-it.gguf",
            "repo_id": "MaziyarPanahi/gemma-2b-it-GGUF",
            "filename": "gemma-2b-it.Q4_K_M.gguf",
            "description": "2B parameters, instruction-tuned",
            "type": "instruct",
        },
    }

    def __init__(self, name: str = "gemma-3b"):
        """
        Initialize the model instance.

        Args:
            name (str): The model name (should match one of the
                AVAILABLE_MODELS keys).
        """
        self.name = name
        self.model = None  # Instance of Llama from llama.cpp (lazy-loaded)
        self.messages = []  # Chat history: list of {"role", "content"} dicts

        # Sampling parameters used by generate_response.
        self.max_tokens = 512
        self.temperature = 0.7
        self.top_p = 0.95
        self.top_k = 40
        self.repeat_penalty = 1.1

    def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0, system_prompt=""):
        """
        Load the model, downloading the GGUF file first if necessary.

        Loaded Llama instances are cached at class level so repeated calls
        with the same configuration reuse the same underlying model.

        Args:
            n_ctx (int): Context window size.
            n_gpu_layers (int): Number of layers to offload to GPU
                (if supported; 0 for CPU-only).
            system_prompt (str): System message seeded into the chat history.

        Returns:
            LlamaCppGemmaModel: self, to allow fluent chaining.

        Raises:
            ValueError: If the model name is unknown, or the download
                metadata (repo_id/filename) is missing.
        """
        cache_key = f"{self.name}_{n_ctx}_{n_gpu_layers}"
        cached = LlamaCppGemmaModel._model_cache.get(cache_key)
        if cached is not None:
            self.model = cached
            # Bug fix: the cache-hit path previously returned without
            # seeding the system prompt, leaving the chat history
            # inconsistent with a fresh load.
            self.messages.append({"role": "system", "content": system_prompt})
            return self

        model_info = self.AVAILABLE_MODELS.get(self.name)
        if not model_info:
            raise ValueError(f"Model {self.name} is not available.")
        model_path = model_info["model_path"]
        self._ensure_model_file(model_info, model_path)

        # Cap threads at 2 — presumably tuned for a small shared-CPU
        # environment (e.g. a free Hugging Face Space); TODO confirm.
        _threads = min(2, os.cpu_count() or 1)
        self.model = Llama(
            model_path=model_path,
            n_threads=_threads,
            n_threads_batch=_threads,
            n_ctx=n_ctx,
            n_gpu_layers=n_gpu_layers,
            n_batch=8,
            verbose=False,
            chat_format="gemma",  # use llama.cpp's built-in Gemma chat template
        )
        self.messages.append({"role": "system", "content": system_prompt})
        # Cache the loaded model for future instances with the same config.
        LlamaCppGemmaModel._model_cache[cache_key] = self.model
        return self

    @staticmethod
    def _ensure_model_file(model_info: Dict, model_path: str) -> None:
        """
        Download the GGUF file from Hugging Face if it is not already on disk.

        Args:
            model_info (Dict): Entry from AVAILABLE_MODELS with repo_id/filename.
            model_path (str): Canonical local path the file must end up at.

        Raises:
            ValueError: If repo_id or filename is missing from model_info.
        """
        if os.path.exists(model_path):
            return
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        repo_id = model_info.get("repo_id")
        filename = model_info.get("filename")
        if repo_id is None or filename is None:
            raise ValueError(
                "Repository ID or filename is missing for model download."
            )
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=os.path.dirname(model_path),
            # Deprecated (ignored) on recent huggingface_hub; kept so older
            # versions write a real file instead of a symlink.
            local_dir_use_symlinks=False,
        )
        # The filename in the repo may differ from our canonical local path.
        if downloaded_path != model_path:
            os.rename(downloaded_path, model_path)

    def generate_response(
        self,
        prompt: str,
    ):
        """
        Stream a chat response for *prompt* using the llama.cpp model.

        Sampling is controlled by the instance attributes max_tokens,
        temperature, top_p, top_k and repeat_penalty (set in __init__),
        not by parameters of this method. The user message and the
        (incrementally built) assistant reply are appended to
        self.messages. Lazily loads the model (with defaults) if it has
        not been loaded yet.

        Args:
            prompt (str): Input prompt text.

        Yields:
            str: The cumulative generated text so far (not just the delta).
        """
        if self.model is None:
            self.load_model()
        self.messages.append({"role": "user", "content": prompt})
        response_stream = self.model.create_chat_completion(
            messages=self.messages,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            repeat_penalty=self.repeat_penalty,
            stream=True,
        )
        # Placeholder assistant message, filled in as chunks arrive so the
        # history stays usable mid-stream.
        self.messages.append({"role": "assistant", "content": ""})
        outputs = ""
        for chunk in response_stream:
            delta = chunk["choices"][0]["delta"]
            if "content" in delta:
                outputs += delta["content"]
                self.messages[-1]["content"] += delta["content"]
            yield outputs

    def get_model_info(self) -> Dict:
        """
        Return information about the model.

        Returns:
            Dict: A dictionary containing the model name and load status.
        """
        return {"name": self.name, "loaded": self.model is not None}

    def get_model_name(self) -> str:
        """
        Return the name of the model.

        Returns:
            str: Model name.
        """
        return self.name