aadya1762 committed on
Commit
bdca525
·
1 Parent(s): 5160420

use llama.cpp

Browse files
Files changed (4) hide show
  1. app.py +4 -4
  2. gemmademo/_chat.py +6 -6
  3. gemmademo/_model.py +113 -131
  4. requirements.txt +1 -3
app.py CHANGED
@@ -6,7 +6,7 @@
6
 
7
  import streamlit as st
8
  from gemmademo import (
9
- HuggingFaceGemmaModel,
10
  StreamlitChat,
11
  PromptManager,
12
  huggingface_login,
@@ -51,7 +51,7 @@ def main():
51
 
52
  # Model selection
53
  st.subheader("Model Selection")
54
- model_options = list(HuggingFaceGemmaModel.AVAILABLE_MODELS.keys())
55
  selected_model = st.selectbox(
56
  "Select Gemma Model",
57
  model_options,
@@ -82,10 +82,10 @@ def main():
82
  # Main content area
83
  if st.session_state.authenticated:
84
  # Initialize model with the selected configuration
85
- model_name = HuggingFaceGemmaModel.AVAILABLE_MODELS[
86
  st.session_state.selected_model
87
  ]["name"]
88
- model = HuggingFaceGemmaModel(name=model_name)
89
 
90
  # Load model (will use cached version if available)
91
  with st.spinner(f"Loading {model_name}..."):
 
6
 
7
  import streamlit as st
8
  from gemmademo import (
9
+ LlamaCppGemmaModel,
10
  StreamlitChat,
11
  PromptManager,
12
  huggingface_login,
 
51
 
52
  # Model selection
53
  st.subheader("Model Selection")
54
+ model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
55
  selected_model = st.selectbox(
56
  "Select Gemma Model",
57
  model_options,
 
82
  # Main content area
83
  if st.session_state.authenticated:
84
  # Initialize model with the selected configuration
85
+ model_name = LlamaCppGemmaModel.AVAILABLE_MODELS[
86
  st.session_state.selected_model
87
  ]["name"]
88
+ model = LlamaCppGemmaModel(name=model_name)
89
 
90
  # Load model (will use cached version if available)
91
  with st.spinner(f"Loading {model_name}..."):
gemmademo/_chat.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from ._model import HuggingFaceGemmaModel
3
  from ._prompts import PromptManager
4
 
5
 
@@ -8,13 +8,13 @@ class StreamlitChat:
8
  A class that handles the chat interface for the Gemma model.
9
 
10
  Features:
11
- A Streamlit-based chatbot UI.
12
- Maintains chat history across reruns.
13
- Uses Gemma (Hugging Face) model for generating responses.
14
- Formats user inputs before sending them to the model.
15
  """
16
 
17
- def __init__(self, model: HuggingFaceGemmaModel, prompt_manager: PromptManager):
18
  self.model = model
19
  self.prompt_manager = prompt_manager
20
 
 
1
  import streamlit as st
2
+ from ._model import LlamaCppGemmaModel
3
  from ._prompts import PromptManager
4
 
5
 
 
8
  A class that handles the chat interface for the Gemma model.
9
 
10
  Features:
11
+ - A Streamlit-based chatbot UI.
12
+ - Maintains chat history across reruns.
13
+ - Uses Gemma (via llama.cpp) model for generating responses.
14
+ - Formats user inputs before sending them to the model.
15
  """
16
 
17
+ def __init__(self, model: LlamaCppGemmaModel, prompt_manager: PromptManager):
18
  self.model = model
19
  self.prompt_manager = prompt_manager
20
 
gemmademo/_model.py CHANGED
@@ -1,191 +1,173 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
- import torch
3
- from typing import Dict, Optional
4
  import streamlit as st
 
 
5
 
6
- torch.classes.__path__ = (
7
- []
8
- ) # add this line to manually set it to empty. If not done, this throws a warning.
9
 
10
-
11
- def load_model(name: str, device_map: str = "cpu"):
12
- """
13
- Model loading function that loads the model without caching
14
  """
15
- import torch._dynamo
16
-
17
- torch._dynamo.config.suppress_errors = True # Already in your code
18
- torch._dynamo.config.cache_size_limit = 64 # Increase cache limit
19
- torch._dynamo.config.force_inference_mode = True # Reduce recompilations
20
- torch._dynamo.config.suppress_errors = True
21
-
22
- tokenizer = AutoTokenizer.from_pretrained(name)
23
-
24
- model = AutoModelForCausalLM.from_pretrained(
25
- name,
26
- torch_dtype=torch.bfloat16,
27
- low_cpu_mem_usage=True,
28
- device_map=device_map,
29
- use_safetensors=True,
30
- use_flash_attention_2=False,
31
- use_cache=True,
32
- load_in_8bit=True,
33
- )
34
-
35
- pipe = pipeline(
36
- "text-generation",
37
- model=model,
38
- tokenizer=tokenizer,
39
- device_map=device_map,
40
- torch_dtype=torch.bfloat16,
41
- do_sample=True,
42
- temperature=0.7,
43
- max_new_tokens=512,
44
- pad_token_id=tokenizer.eos_token_id,
45
- eos_token_id=tokenizer.eos_token_id,
46
- return_full_text=False,
47
- )
48
-
49
- return tokenizer, model, pipe
50
-
51
-
52
- class HuggingFaceGemmaModel:
53
- """
54
- A class for the Hugging Face Gemma model. Handles model selection, loading, and inference.
55
- Uses transformers pipeline for better text generation and formatting.
56
-
57
- Example
58
- -------
59
- Select Gemma 2B, 7B etc.
60
-
61
- Additional Information:
62
- ----------------------
63
- Complete Information: https://huggingface.co/google/gemma-2b
64
-
65
- Available Models:
66
- - google/gemma-2b (2B parameters, base)
67
- - google/gemma-2b-it (2B parameters, instruction-tuned)
68
- - google/gemma-7b (7B parameters, base)
69
- - google/gemma-7b-it (7B parameters, instruction-tuned)
70
  """
71
 
72
  AVAILABLE_MODELS: Dict[str, Dict] = {
73
  "gemma-2b": {
74
- "name": "google/gemma-2b",
 
 
75
  "description": "2B parameters, base model",
76
  "type": "base",
77
  },
78
  "gemma-2b-it": {
79
- "name": "google/gemma-2b-it",
 
 
80
  "description": "2B parameters, instruction-tuned",
81
  "type": "instruct",
82
  },
83
  "gemma-7b": {
84
- "name": "google/gemma-7b",
 
 
85
  "description": "7B parameters, base model",
86
  "type": "base",
87
  },
88
  "gemma-7b-it": {
89
- "name": "google/gemma-7b-it",
 
 
90
  "description": "7B parameters, instruction-tuned",
91
  "type": "instruct",
92
  },
 
 
 
 
 
 
 
93
  }
94
 
95
- def __init__(self, name: str = "google/gemma-2b"):
96
- self.name = name
97
- self.model = None
98
- self.tokenizer = None
99
- self.pipeline = None
100
-
101
- def load_model(self, device_map: str = "cpu"):
102
  """
103
- Load the model using session state
104
 
105
  Args:
106
- device_map: Device mapping strategy (should be "cpu" for CPU-only inference)
107
  """
108
- # Create a unique key for this model in session state
109
- model_key = f"gemma_model_{self.name}"
110
- tokenizer_key = f"gemma_tokenizer_{self.name}"
111
- pipeline_key = f"gemma_pipeline_{self.name}"
112
 
113
- # Check if model is already loaded in session state
114
- if (
115
- model_key not in st.session_state
116
- or tokenizer_key not in st.session_state
117
- or pipeline_key not in st.session_state
118
- ):
119
 
120
- # Show loading indicator
121
- with st.spinner(f"Loading {self.name}..."):
122
- tokenizer, model, pipe = load_model(self.name, device_map)
 
123
 
124
- # Store in session state
125
- st.session_state[tokenizer_key] = tokenizer
126
- st.session_state[model_key] = model
127
- st.session_state[pipeline_key] = pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- # Get model from session state
130
- self.tokenizer = st.session_state[tokenizer_key]
 
 
 
 
 
 
 
131
  self.model = st.session_state[model_key]
132
- self.pipeline = st.session_state[pipeline_key]
133
-
134
  return self
135
 
136
  def generate_response(
137
  self,
138
  prompt: str,
139
- max_length: int = 512,
140
  temperature: float = 0.7,
141
- num_return_sequences: int = 1,
142
  **kwargs,
143
  ) -> str:
144
  """
145
- Generate a response using the text generation pipeline
146
 
147
  Args:
148
- prompt: Input text
149
- max_length: Maximum number of new tokens to generate
150
- temperature: Sampling temperature (higher = more creative)
151
- num_return_sequences: Number of responses to generate
152
- **kwargs: Additional generation parameters for the pipeline
153
 
154
  Returns:
155
- str: Generated response
156
  """
157
- if not self.pipeline:
158
  self.load_model()
159
 
160
- # Update generation config with any provided kwargs
161
- generation_config = {
162
- "max_new_tokens": max_length,
163
- "temperature": temperature,
164
- "num_return_sequences": num_return_sequences,
165
- "do_sample": True,
166
  **kwargs,
167
- }
168
-
169
- # Generate response using the pipeline
170
- outputs = self.pipeline(prompt, **generation_config)
171
-
172
- # Extract the generated text
173
- if num_return_sequences == 1:
174
- response = outputs[0]["generated_text"]
175
- else:
176
- # Join multiple sequences if requested
177
- response = "\n---\n".join(output["generated_text"] for output in outputs)
178
-
179
- return response.strip()
180
 
181
  def get_model_info(self) -> Dict:
182
- """Return information about the model"""
 
 
 
 
 
183
  return {
184
  "name": self.name,
185
  "loaded": self.model is not None,
186
- "pipeline_ready": self.pipeline is not None,
187
  }
188
 
189
  def get_model_name(self) -> str:
190
- """Return the name of the model"""
 
 
 
 
 
191
  return self.name
 
1
+ import os
2
+ from typing import Dict
 
3
  import streamlit as st
4
+ from llama_cpp import Llama
5
+ from huggingface_hub import hf_hub_download
6
 
 
 
 
7
 
8
+ class LlamaCppGemmaModel:
 
 
 
9
  """
10
+ A class for the Gemma model using llama.cpp. This class replicates the API of the original
11
+ HuggingFaceGemmaModel but uses llama.cpp for inference. It handles model selection, loading,
12
+ downloading (if necessary), and text generation.
13
+
14
+ Available Models (ensure the repo_id and filename match the GGUF file on Hugging Face):
15
+ - gemma-2b: 2B parameters, base model
16
+ - gemma-2b-it: 2B parameters, instruction-tuned
17
+ - gemma-7b: 7B parameters, base model
18
+ - gemma-7b-it: 7B parameters, instruction-tuned
19
+ - gemma-7b-gguf: 7B parameters in GGUF format
20
+
21
+ All models will be stored in the "models/" directory.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  """
23
 
24
  AVAILABLE_MODELS: Dict[str, Dict] = {
25
  "gemma-2b": {
26
+ "model_path": "models/gemma-2b.gguf",
27
+ "repo_id": "google/gemma-2b", # update to the actual repo id
28
+ "filename": "gemma-2b.gguf", # update to the actual filename
29
  "description": "2B parameters, base model",
30
  "type": "base",
31
  },
32
  "gemma-2b-it": {
33
+ "model_path": "models/gemma-2b-it.gguf",
34
+ "repo_id": "google/gemma-2b-it", # update to the actual repo id
35
+ "filename": "gemma-2b-it.gguf", # update to the actual filename
36
  "description": "2B parameters, instruction-tuned",
37
  "type": "instruct",
38
  },
39
  "gemma-7b": {
40
+ "model_path": "models/gemma-7b.gguf",
41
+ "repo_id": "google/gemma-7b", # update to the actual repo id
42
+ "filename": "gemma-7b.gguf", # update to the actual filename
43
  "description": "7B parameters, base model",
44
  "type": "base",
45
  },
46
  "gemma-7b-it": {
47
+ "model_path": "models/gemma-7b-it.gguf",
48
+ "repo_id": "google/gemma-7b-it", # update to the actual repo id
49
+ "filename": "gemma-7b-it.gguf", # update to the actual filename
50
  "description": "7B parameters, instruction-tuned",
51
  "type": "instruct",
52
  },
53
+ "gemma-7b-gguf": {
54
+ "model_path": "models/gemma-7b.gguf",
55
+ "repo_id": "google/gemma-7b-GGUF", # repository for the GGUF model
56
+ "filename": "gemma-7b.gguf", # updated filename for GGUF model
57
+ "description": "7B parameters in GGUF format",
58
+ "type": "base",
59
+ },
60
  }
61
 
62
+ def __init__(self, name: str = "gemma-2b"):
 
 
 
 
 
 
63
  """
64
+ Initialize the model instance.
65
 
66
  Args:
67
+ name (str): The model name (should match one of the AVAILABLE_MODELS keys).
68
  """
69
+ self.name = name
70
+ self.model = None # Instance of Llama from llama.cpp
 
 
71
 
72
+ def load_model(self, n_threads: int = 2, n_ctx: int = 2048, n_gpu_layers: int = 0):
73
+ """
74
+ Load the model and cache it in Streamlit's session state.
75
+ If the model file does not exist, it will be downloaded to the models/ directory.
 
 
76
 
77
+ Args:
78
+ n_threads (int): Number of CPU threads to use.
79
+ n_ctx (int): Context window size.
80
+ n_gpu_layers (int): Number of layers to offload to GPU (if supported; 0 for CPU-only).
81
 
82
+ Returns:
83
+ self: Loaded model instance.
84
+ """
85
+ model_info = self.AVAILABLE_MODELS.get(self.name)
86
+ if not model_info:
87
+ raise ValueError(f"Model {self.name} is not available.")
88
+
89
+ model_path = model_info["model_path"]
90
+ # If the model file doesn't exist, download it.
91
+ if not os.path.exists(model_path):
92
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
93
+ repo_id = model_info.get("repo_id")
94
+ filename = model_info.get("filename")
95
+ if repo_id is None or filename is None:
96
+ raise ValueError(
97
+ "Repository ID or filename is missing for model download."
98
+ )
99
+ with st.spinner(f"Downloading {self.name}..."):
100
+ downloaded_path = hf_hub_download(
101
+ repo_id=repo_id,
102
+ filename=filename,
103
+ local_dir=os.path.dirname(model_path),
104
+ local_dir_use_symlinks=False,
105
+ )
106
+ # If the downloaded file is not at the expected location, rename it.
107
+ if downloaded_path != model_path:
108
+ os.rename(downloaded_path, model_path)
109
 
110
+ model_key = f"gemma_model_{self.name}"
111
+ if model_key not in st.session_state:
112
+ with st.spinner(f"Loading {self.name}..."):
113
+ st.session_state[model_key] = Llama(
114
+ model_path=model_path,
115
+ n_threads=n_threads,
116
+ n_ctx=n_ctx,
117
+ n_gpu_layers=n_gpu_layers,
118
+ )
119
  self.model = st.session_state[model_key]
 
 
120
  return self
121
 
122
  def generate_response(
123
  self,
124
  prompt: str,
125
+ max_tokens: int = 512,
126
  temperature: float = 0.7,
 
127
  **kwargs,
128
  ) -> str:
129
  """
130
+ Generate a response using the llama.cpp model.
131
 
132
  Args:
133
+ prompt (str): Input prompt text.
134
+ max_tokens (int): Maximum number of tokens to generate.
135
+ temperature (float): Sampling temperature (higher = more creative).
136
+ **kwargs: Additional generation parameters.
 
137
 
138
  Returns:
139
+ str: Generated response text.
140
  """
141
+ if self.model is None:
142
  self.load_model()
143
 
144
+ # Call the llama.cpp model with the provided parameters.
145
+ response = self.model(
146
+ prompt,
147
+ max_tokens=max_tokens,
148
+ temperature=temperature,
 
149
  **kwargs,
150
+ )
151
+ generated_text = response["choices"][0]["text"]
152
+ return generated_text.strip()
 
 
 
 
 
 
 
 
 
 
153
 
154
  def get_model_info(self) -> Dict:
155
+ """
156
+ Return information about the model.
157
+
158
+ Returns:
159
+ Dict: A dictionary containing the model name and load status.
160
+ """
161
  return {
162
  "name": self.name,
163
  "loaded": self.model is not None,
 
164
  }
165
 
166
  def get_model_name(self) -> str:
167
+ """
168
+ Return the name of the model.
169
+
170
+ Returns:
171
+ str: Model name.
172
+ """
173
  return self.name
requirements.txt CHANGED
@@ -1,10 +1,8 @@
1
  streamlit>=1.30.0
2
  transformers>=4.36.0
3
- torch>=2.1.0
4
  huggingface-hub>=0.19.0
5
  accelerate>=0.25.0
6
  safetensors>=0.4.0
7
  sentencepiece>=0.1.99
8
  protobuf>=4.25.0
9
- intel_extension_for_pytorch
10
- https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-manylinux_2_24_x86_64.whl
 
1
  streamlit>=1.30.0
2
  transformers>=4.36.0
 
3
  huggingface-hub>=0.19.0
4
  accelerate>=0.25.0
5
  safetensors>=0.4.0
6
  sentencepiece>=0.1.99
7
  protobuf>=4.25.0
8
+ llama-cpp-python