aadya1762 committed on
Commit
5160420
·
1 Parent(s): 06a0392

Increase cache limit -> fewer recompilations by PyTorch

Browse files
Files changed (5) hide show
  1. app.py +39 -19
  2. gemmademo/__init__.py +6 -1
  3. gemmademo/_chat.py +16 -9
  4. gemmademo/_model.py +48 -42
  5. gemmademo/_utils.py +1 -0
app.py CHANGED
@@ -5,11 +5,17 @@
5
  # Add a button to clear the chat history.
6
 
7
  import streamlit as st
8
- from gemmademo import HuggingFaceGemmaModel, StreamlitChat, PromptManager, huggingface_login
 
 
 
 
 
9
  import os
10
  import sys
11
  import subprocess
12
 
 
13
  def main():
14
  # Page configuration
15
  st.set_page_config(page_title="Gemma Chat Demo", layout="wide")
@@ -25,7 +31,7 @@ def main():
25
  # Sidebar for login and configuration
26
  with st.sidebar:
27
  st.title("Gemma Chat Configuration")
28
-
29
  # Login section
30
  st.subheader("Login")
31
  if not st.session_state.authenticated:
@@ -42,31 +48,31 @@ def main():
42
  if st.button("Logout"):
43
  st.session_state.authenticated = False
44
  st.rerun()
45
-
46
  # Model selection
47
  st.subheader("Model Selection")
48
  model_options = list(HuggingFaceGemmaModel.AVAILABLE_MODELS.keys())
49
  selected_model = st.selectbox(
50
  "Select Gemma Model",
51
  model_options,
52
- index=model_options.index(st.session_state.selected_model)
53
  )
54
  if selected_model != st.session_state.selected_model:
55
  st.session_state.selected_model = selected_model
56
  st.rerun()
57
-
58
  # Task selection
59
  st.subheader("Task Selection")
60
  task_options = ["Question Answering", "Text Generation", "Code Completion"]
61
  selected_task = st.selectbox(
62
  "Select Task",
63
  task_options,
64
- index=task_options.index(st.session_state.selected_task)
65
  )
66
  if selected_task != st.session_state.selected_task:
67
  st.session_state.selected_task = selected_task
68
  st.rerun()
69
-
70
  # Clear chat history button
71
  if st.button("Clear Chat History"):
72
  if "chat_instance" in st.session_state:
@@ -76,37 +82,51 @@ def main():
76
  # Main content area
77
  if st.session_state.authenticated:
78
  # Initialize model with the selected configuration
79
- model_name = HuggingFaceGemmaModel.AVAILABLE_MODELS[st.session_state.selected_model]["name"]
 
 
80
  model = HuggingFaceGemmaModel(name=model_name)
81
-
82
  # Load model (will use cached version if available)
83
  with st.spinner(f"Loading {model_name}..."):
84
  model.load_model(device_map="auto")
85
-
86
  # Initialize prompt manager with selected task
87
  prompt_manager = PromptManager(task=st.session_state.selected_task)
88
-
89
  # Initialize chat interface
90
  chat = StreamlitChat(model=model, prompt_manager=prompt_manager)
91
  st.session_state.chat_instance = chat
92
-
93
  # Run the chat interface
94
  chat.run()
95
  else:
96
- st.info("Please login with your Hugging Face token in the sidebar to start chatting.")
 
 
 
97
 
98
  if __name__ == "__main__":
99
  # Check if the script is being run directly with Python
100
  # If so, launch Streamlit programmatically
101
- if not os.environ.get('STREAMLIT_RUN_APP'):
102
- os.environ['STREAMLIT_RUN_APP'] = '1'
103
  # Get the current script path
104
  script_path = os.path.abspath(__file__)
105
  # Launch streamlit run with port 7860 and headless mode
106
- cmd = [sys.executable, "-m", "streamlit", "run", script_path,
107
- "--server.port", "7860",
108
- "--server.address", "0.0.0.0",
109
- "--server.headless", "true"]
 
 
 
 
 
 
 
 
 
110
  subprocess.run(cmd)
111
  else:
112
  # Normal Streamlit execution
 
5
  # Add a button to clear the chat history.
6
 
7
  import streamlit as st
8
+ from gemmademo import (
9
+ HuggingFaceGemmaModel,
10
+ StreamlitChat,
11
+ PromptManager,
12
+ huggingface_login,
13
+ )
14
  import os
15
  import sys
16
  import subprocess
17
 
18
+
19
  def main():
20
  # Page configuration
21
  st.set_page_config(page_title="Gemma Chat Demo", layout="wide")
 
31
  # Sidebar for login and configuration
32
  with st.sidebar:
33
  st.title("Gemma Chat Configuration")
34
+
35
  # Login section
36
  st.subheader("Login")
37
  if not st.session_state.authenticated:
 
48
  if st.button("Logout"):
49
  st.session_state.authenticated = False
50
  st.rerun()
51
+
52
  # Model selection
53
  st.subheader("Model Selection")
54
  model_options = list(HuggingFaceGemmaModel.AVAILABLE_MODELS.keys())
55
  selected_model = st.selectbox(
56
  "Select Gemma Model",
57
  model_options,
58
+ index=model_options.index(st.session_state.selected_model),
59
  )
60
  if selected_model != st.session_state.selected_model:
61
  st.session_state.selected_model = selected_model
62
  st.rerun()
63
+
64
  # Task selection
65
  st.subheader("Task Selection")
66
  task_options = ["Question Answering", "Text Generation", "Code Completion"]
67
  selected_task = st.selectbox(
68
  "Select Task",
69
  task_options,
70
+ index=task_options.index(st.session_state.selected_task),
71
  )
72
  if selected_task != st.session_state.selected_task:
73
  st.session_state.selected_task = selected_task
74
  st.rerun()
75
+
76
  # Clear chat history button
77
  if st.button("Clear Chat History"):
78
  if "chat_instance" in st.session_state:
 
82
  # Main content area
83
  if st.session_state.authenticated:
84
  # Initialize model with the selected configuration
85
+ model_name = HuggingFaceGemmaModel.AVAILABLE_MODELS[
86
+ st.session_state.selected_model
87
+ ]["name"]
88
  model = HuggingFaceGemmaModel(name=model_name)
89
+
90
  # Load model (will use cached version if available)
91
  with st.spinner(f"Loading {model_name}..."):
92
  model.load_model(device_map="auto")
93
+
94
  # Initialize prompt manager with selected task
95
  prompt_manager = PromptManager(task=st.session_state.selected_task)
96
+
97
  # Initialize chat interface
98
  chat = StreamlitChat(model=model, prompt_manager=prompt_manager)
99
  st.session_state.chat_instance = chat
100
+
101
  # Run the chat interface
102
  chat.run()
103
  else:
104
+ st.info(
105
+ "Please login with your Hugging Face token in the sidebar to start chatting."
106
+ )
107
+
108
 
109
  if __name__ == "__main__":
110
  # Check if the script is being run directly with Python
111
  # If so, launch Streamlit programmatically
112
+ if not os.environ.get("STREAMLIT_RUN_APP"):
113
+ os.environ["STREAMLIT_RUN_APP"] = "1"
114
  # Get the current script path
115
  script_path = os.path.abspath(__file__)
116
  # Launch streamlit run with port 7860 and headless mode
117
+ cmd = [
118
+ sys.executable,
119
+ "-m",
120
+ "streamlit",
121
+ "run",
122
+ script_path,
123
+ "--server.port",
124
+ "7860",
125
+ "--server.address",
126
+ "0.0.0.0",
127
+ "--server.headless",
128
+ "true",
129
+ ]
130
  subprocess.run(cmd)
131
  else:
132
  # Normal Streamlit execution
gemmademo/__init__.py CHANGED
@@ -3,4 +3,9 @@ from ._model import HuggingFaceGemmaModel
3
  from ._prompts import PromptManager
4
  from ._utils import huggingface_login
5
 
6
- __all__ = ["StreamlitChat", "HuggingFaceGemmaModel", "PromptManager", "huggingface_login"]
 
 
 
 
 
 
3
  from ._prompts import PromptManager
4
  from ._utils import huggingface_login
5
 
6
+ __all__ = [
7
+ "StreamlitChat",
8
+ "HuggingFaceGemmaModel",
9
+ "PromptManager",
10
+ "huggingface_login",
11
+ ]
gemmademo/_chat.py CHANGED
@@ -2,23 +2,25 @@ import streamlit as st
2
  from ._model import HuggingFaceGemmaModel
3
  from ._prompts import PromptManager
4
 
 
5
  class StreamlitChat:
6
  """
7
  A class that handles the chat interface for the Gemma model.
8
-
9
  Features:
10
  ✅ A Streamlit-based chatbot UI.
11
  ✅ Maintains chat history across reruns.
12
  ✅ Uses Gemma (Hugging Face) model for generating responses.
13
  ✅ Formats user inputs before sending them to the model.
14
  """
 
15
  def __init__(self, model: HuggingFaceGemmaModel, prompt_manager: PromptManager):
16
  self.model = model
17
  self.prompt_manager = prompt_manager
18
 
19
  def run(self):
20
  self._chat()
21
-
22
  def _chat(self):
23
  st.title("Using model : " + self.model.get_model_name())
24
  self._build_states()
@@ -27,25 +29,30 @@ class StreamlitChat:
27
  for message in st.session_state.messages:
28
  with st.chat_message(message["role"]):
29
  st.markdown(message["content"])
30
-
31
  # React to user input
32
  if prompt := st.chat_input("What is up?"):
33
- prompt = prompt.replace("\n", " \n") # Only double spaced backslash is rendered in streamlit for newlines.
 
 
34
  with st.chat_message("User"):
35
  st.markdown(prompt)
36
  st.session_state.messages.append({"role": "User", "content": prompt})
37
-
38
- prompt = self.prompt_manager.get_prompt(user_input=st.session_state.messages[-1]["content"])
39
- response = self.model.generate_response(prompt).replace("\n", " \n") # Only double spaced backslash is rendered in streamlit for newlines.
 
 
 
 
40
  with st.chat_message("Gemma"):
41
  st.markdown(response)
42
  st.session_state.messages.append({"role": "Gemma", "content": response})
43
 
44
-
45
  def _build_states(self):
46
  # Initialize chat history
47
  if "messages" not in st.session_state:
48
  st.session_state.messages = []
49
-
50
  def clear_history(self):
51
  st.session_state.messages = []
 
2
  from ._model import HuggingFaceGemmaModel
3
  from ._prompts import PromptManager
4
 
5
+
6
  class StreamlitChat:
7
  """
8
  A class that handles the chat interface for the Gemma model.
9
+
10
  Features:
11
  ✅ A Streamlit-based chatbot UI.
12
  ✅ Maintains chat history across reruns.
13
  ✅ Uses Gemma (Hugging Face) model for generating responses.
14
  ✅ Formats user inputs before sending them to the model.
15
  """
16
+
17
  def __init__(self, model: HuggingFaceGemmaModel, prompt_manager: PromptManager):
18
  self.model = model
19
  self.prompt_manager = prompt_manager
20
 
21
  def run(self):
22
  self._chat()
23
+
24
  def _chat(self):
25
  st.title("Using model : " + self.model.get_model_name())
26
  self._build_states()
 
29
  for message in st.session_state.messages:
30
  with st.chat_message(message["role"]):
31
  st.markdown(message["content"])
32
+
33
  # React to user input
34
  if prompt := st.chat_input("What is up?"):
35
+ prompt = prompt.replace(
36
+ "\n", " \n"
37
+ ) # Only double spaced backslash is rendered in streamlit for newlines.
38
  with st.chat_message("User"):
39
  st.markdown(prompt)
40
  st.session_state.messages.append({"role": "User", "content": prompt})
41
+
42
+ prompt = self.prompt_manager.get_prompt(
43
+ user_input=st.session_state.messages[-1]["content"]
44
+ )
45
+ response = self.model.generate_response(prompt).replace(
46
+ "\n", " \n"
47
+ ) # Only double spaced backslash is rendered in streamlit for newlines.
48
  with st.chat_message("Gemma"):
49
  st.markdown(response)
50
  st.session_state.messages.append({"role": "Gemma", "content": response})
51
 
 
52
  def _build_states(self):
53
  # Initialize chat history
54
  if "messages" not in st.session_state:
55
  st.session_state.messages = []
56
+
57
  def clear_history(self):
58
  st.session_state.messages = []
gemmademo/_model.py CHANGED
@@ -3,17 +3,24 @@ import torch
3
  from typing import Dict, Optional
4
  import streamlit as st
5
 
6
- torch.classes.__path__ = [] # add this line to manually set it to empty.
 
 
 
7
 
8
  def load_model(name: str, device_map: str = "cpu"):
9
  """
10
  Model loading function that loads the model without caching
11
  """
12
  import torch._dynamo
 
 
 
 
13
  torch._dynamo.config.suppress_errors = True
14
 
15
  tokenizer = AutoTokenizer.from_pretrained(name)
16
-
17
  model = AutoModelForCausalLM.from_pretrained(
18
  name,
19
  torch_dtype=torch.bfloat16,
@@ -24,7 +31,7 @@ def load_model(name: str, device_map: str = "cpu"):
24
  use_cache=True,
25
  load_in_8bit=True,
26
  )
27
-
28
  pipe = pipeline(
29
  "text-generation",
30
  model=model,
@@ -36,11 +43,12 @@ def load_model(name: str, device_map: str = "cpu"):
36
  max_new_tokens=512,
37
  pad_token_id=tokenizer.eos_token_id,
38
  eos_token_id=tokenizer.eos_token_id,
39
- return_full_text=False
40
  )
41
-
42
  return tokenizer, model, pipe
43
 
 
44
  class HuggingFaceGemmaModel:
45
  """
46
  A class for the Hugging Face Gemma model. Handles model selection, loading, and inference.
@@ -49,7 +57,7 @@ class HuggingFaceGemmaModel:
49
  Example
50
  -------
51
  Select Gemma 2B, 7B etc.
52
-
53
  Additional Information:
54
  ----------------------
55
  Complete Information: https://huggingface.co/google/gemma-2b
@@ -60,40 +68,40 @@ class HuggingFaceGemmaModel:
60
  - google/gemma-7b (7B parameters, base)
61
  - google/gemma-7b-it (7B parameters, instruction-tuned)
62
  """
63
-
64
  AVAILABLE_MODELS: Dict[str, Dict] = {
65
  "gemma-2b": {
66
  "name": "google/gemma-2b",
67
  "description": "2B parameters, base model",
68
- "type": "base"
69
  },
70
  "gemma-2b-it": {
71
  "name": "google/gemma-2b-it",
72
  "description": "2B parameters, instruction-tuned",
73
- "type": "instruct"
74
  },
75
  "gemma-7b": {
76
  "name": "google/gemma-7b",
77
  "description": "7B parameters, base model",
78
- "type": "base"
79
  },
80
  "gemma-7b-it": {
81
  "name": "google/gemma-7b-it",
82
  "description": "7B parameters, instruction-tuned",
83
- "type": "instruct"
84
- }
85
  }
86
-
87
  def __init__(self, name: str = "google/gemma-2b"):
88
  self.name = name
89
  self.model = None
90
  self.tokenizer = None
91
  self.pipeline = None
92
-
93
  def load_model(self, device_map: str = "cpu"):
94
  """
95
  Load the model using session state
96
-
97
  Args:
98
  device_map: Device mapping strategy (should be "cpu" for CPU-only inference)
99
  """
@@ -101,85 +109,83 @@ class HuggingFaceGemmaModel:
101
  model_key = f"gemma_model_{self.name}"
102
  tokenizer_key = f"gemma_tokenizer_{self.name}"
103
  pipeline_key = f"gemma_pipeline_{self.name}"
104
-
105
  # Check if model is already loaded in session state
106
- if (model_key not in st.session_state or
107
- tokenizer_key not in st.session_state or
108
- pipeline_key not in st.session_state):
109
-
 
 
110
  # Show loading indicator
111
  with st.spinner(f"Loading {self.name}..."):
112
  tokenizer, model, pipe = load_model(self.name, device_map)
113
-
114
  # Store in session state
115
  st.session_state[tokenizer_key] = tokenizer
116
  st.session_state[model_key] = model
117
  st.session_state[pipeline_key] = pipe
118
-
119
  # Get model from session state
120
  self.tokenizer = st.session_state[tokenizer_key]
121
  self.model = st.session_state[model_key]
122
  self.pipeline = st.session_state[pipeline_key]
123
-
124
  return self
125
-
126
  def generate_response(
127
- self,
128
- prompt: str,
129
  max_length: int = 512,
130
  temperature: float = 0.7,
131
  num_return_sequences: int = 1,
132
- **kwargs
133
  ) -> str:
134
  """
135
  Generate a response using the text generation pipeline
136
-
137
  Args:
138
  prompt: Input text
139
  max_length: Maximum number of new tokens to generate
140
  temperature: Sampling temperature (higher = more creative)
141
  num_return_sequences: Number of responses to generate
142
  **kwargs: Additional generation parameters for the pipeline
143
-
144
  Returns:
145
  str: Generated response
146
  """
147
  if not self.pipeline:
148
  self.load_model()
149
-
150
  # Update generation config with any provided kwargs
151
  generation_config = {
152
  "max_new_tokens": max_length,
153
  "temperature": temperature,
154
  "num_return_sequences": num_return_sequences,
155
  "do_sample": True,
156
- **kwargs
157
  }
158
-
159
  # Generate response using the pipeline
160
- outputs = self.pipeline(
161
- prompt,
162
- **generation_config
163
- )
164
-
165
  # Extract the generated text
166
  if num_return_sequences == 1:
167
  response = outputs[0]["generated_text"]
168
  else:
169
  # Join multiple sequences if requested
170
  response = "\n---\n".join(output["generated_text"] for output in outputs)
171
-
172
  return response.strip()
173
-
174
  def get_model_info(self) -> Dict:
175
  """Return information about the model"""
176
  return {
177
  "name": self.name,
178
  "loaded": self.model is not None,
179
- "pipeline_ready": self.pipeline is not None
180
  }
181
-
182
  def get_model_name(self) -> str:
183
  """Return the name of the model"""
184
  return self.name
185
-
 
3
  from typing import Dict, Optional
4
  import streamlit as st
5
 
6
+ torch.classes.__path__ = (
7
+ []
8
+ ) # add this line to manually set it to empty. If not done, this throws a warning.
9
+
10
 
11
  def load_model(name: str, device_map: str = "cpu"):
12
  """
13
  Model loading function that loads the model without caching
14
  """
15
  import torch._dynamo
16
+
17
+ torch._dynamo.config.suppress_errors = True # Already in your code
18
+ torch._dynamo.config.cache_size_limit = 64 # Increase cache limit
19
+ torch._dynamo.config.force_inference_mode = True # Reduce recompilations
20
  torch._dynamo.config.suppress_errors = True
21
 
22
  tokenizer = AutoTokenizer.from_pretrained(name)
23
+
24
  model = AutoModelForCausalLM.from_pretrained(
25
  name,
26
  torch_dtype=torch.bfloat16,
 
31
  use_cache=True,
32
  load_in_8bit=True,
33
  )
34
+
35
  pipe = pipeline(
36
  "text-generation",
37
  model=model,
 
43
  max_new_tokens=512,
44
  pad_token_id=tokenizer.eos_token_id,
45
  eos_token_id=tokenizer.eos_token_id,
46
+ return_full_text=False,
47
  )
48
+
49
  return tokenizer, model, pipe
50
 
51
+
52
  class HuggingFaceGemmaModel:
53
  """
54
  A class for the Hugging Face Gemma model. Handles model selection, loading, and inference.
 
57
  Example
58
  -------
59
  Select Gemma 2B, 7B etc.
60
+
61
  Additional Information:
62
  ----------------------
63
  Complete Information: https://huggingface.co/google/gemma-2b
 
68
  - google/gemma-7b (7B parameters, base)
69
  - google/gemma-7b-it (7B parameters, instruction-tuned)
70
  """
71
+
72
  AVAILABLE_MODELS: Dict[str, Dict] = {
73
  "gemma-2b": {
74
  "name": "google/gemma-2b",
75
  "description": "2B parameters, base model",
76
+ "type": "base",
77
  },
78
  "gemma-2b-it": {
79
  "name": "google/gemma-2b-it",
80
  "description": "2B parameters, instruction-tuned",
81
+ "type": "instruct",
82
  },
83
  "gemma-7b": {
84
  "name": "google/gemma-7b",
85
  "description": "7B parameters, base model",
86
+ "type": "base",
87
  },
88
  "gemma-7b-it": {
89
  "name": "google/gemma-7b-it",
90
  "description": "7B parameters, instruction-tuned",
91
+ "type": "instruct",
92
+ },
93
  }
94
+
95
  def __init__(self, name: str = "google/gemma-2b"):
96
  self.name = name
97
  self.model = None
98
  self.tokenizer = None
99
  self.pipeline = None
100
+
101
  def load_model(self, device_map: str = "cpu"):
102
  """
103
  Load the model using session state
104
+
105
  Args:
106
  device_map: Device mapping strategy (should be "cpu" for CPU-only inference)
107
  """
 
109
  model_key = f"gemma_model_{self.name}"
110
  tokenizer_key = f"gemma_tokenizer_{self.name}"
111
  pipeline_key = f"gemma_pipeline_{self.name}"
112
+
113
  # Check if model is already loaded in session state
114
+ if (
115
+ model_key not in st.session_state
116
+ or tokenizer_key not in st.session_state
117
+ or pipeline_key not in st.session_state
118
+ ):
119
+
120
  # Show loading indicator
121
  with st.spinner(f"Loading {self.name}..."):
122
  tokenizer, model, pipe = load_model(self.name, device_map)
123
+
124
  # Store in session state
125
  st.session_state[tokenizer_key] = tokenizer
126
  st.session_state[model_key] = model
127
  st.session_state[pipeline_key] = pipe
128
+
129
  # Get model from session state
130
  self.tokenizer = st.session_state[tokenizer_key]
131
  self.model = st.session_state[model_key]
132
  self.pipeline = st.session_state[pipeline_key]
133
+
134
  return self
135
+
136
  def generate_response(
137
+ self,
138
+ prompt: str,
139
  max_length: int = 512,
140
  temperature: float = 0.7,
141
  num_return_sequences: int = 1,
142
+ **kwargs,
143
  ) -> str:
144
  """
145
  Generate a response using the text generation pipeline
146
+
147
  Args:
148
  prompt: Input text
149
  max_length: Maximum number of new tokens to generate
150
  temperature: Sampling temperature (higher = more creative)
151
  num_return_sequences: Number of responses to generate
152
  **kwargs: Additional generation parameters for the pipeline
153
+
154
  Returns:
155
  str: Generated response
156
  """
157
  if not self.pipeline:
158
  self.load_model()
159
+
160
  # Update generation config with any provided kwargs
161
  generation_config = {
162
  "max_new_tokens": max_length,
163
  "temperature": temperature,
164
  "num_return_sequences": num_return_sequences,
165
  "do_sample": True,
166
+ **kwargs,
167
  }
168
+
169
  # Generate response using the pipeline
170
+ outputs = self.pipeline(prompt, **generation_config)
171
+
 
 
 
172
  # Extract the generated text
173
  if num_return_sequences == 1:
174
  response = outputs[0]["generated_text"]
175
  else:
176
  # Join multiple sequences if requested
177
  response = "\n---\n".join(output["generated_text"] for output in outputs)
178
+
179
  return response.strip()
180
+
181
  def get_model_info(self) -> Dict:
182
  """Return information about the model"""
183
  return {
184
  "name": self.name,
185
  "loaded": self.model is not None,
186
+ "pipeline_ready": self.pipeline is not None,
187
  }
188
+
189
  def get_model_name(self) -> str:
190
  """Return the name of the model"""
191
  return self.name
 
gemmademo/_utils.py CHANGED
@@ -3,4 +3,5 @@ def huggingface_login(token: str):
3
  Login to Hugging Face using the token
4
  """
5
  from huggingface_hub import login
 
6
  login(token=token)
 
3
  Login to Hugging Face using the token
4
  """
5
  from huggingface_hub import login
6
+
7
  login(token=token)