aadya1762 committed
Commit b4ecb60 · 0 Parent(s)

initial commit
app.py ADDED
@@ -0,0 +1,101 @@
+ # Interface all the functions from gemmademo.
+ # Implement login functionality in the sidebar.
+ # Implement a task selector in the sidebar.
+ # Add a button to clear the chat history.
+
+ import streamlit as st
+ from gemmademo import HuggingFaceGemmaModel, StreamlitChat, PromptManager, huggingface_login
+ import os
+
+ def main():
+     # Page configuration
+     st.set_page_config(page_title="Gemma Chat Demo", layout="wide")
+
+     # Initialize session state variables
+     if "authenticated" not in st.session_state:
+         st.session_state.authenticated = False
+     if "selected_model" not in st.session_state:
+         st.session_state.selected_model = "gemma-2b-it"
+     if "selected_task" not in st.session_state:
+         st.session_state.selected_task = "Question Answering"
+
+     # Sidebar for login and configuration
+     with st.sidebar:
+         st.title("Gemma Chat Configuration")
+
+         # Login section
+         st.subheader("Login")
+         if not st.session_state.authenticated:
+             hf_token = st.text_input("Hugging Face Token", type="password")
+             if st.button("Login"):
+                 try:
+                     huggingface_login(hf_token)
+                     st.session_state.authenticated = True
+                     st.success("Successfully logged in!")
+                 except Exception as e:
+                     st.error(f"Login failed: {e}")
+         else:
+             st.success("Logged in to Hugging Face")
+             if st.button("Logout"):
+                 st.session_state.authenticated = False
+                 st.rerun()
+
+         # Model selection
+         st.subheader("Model Selection")
+         model_options = list(HuggingFaceGemmaModel.AVAILABLE_MODELS.keys())
+         selected_model = st.selectbox(
+             "Select Gemma Model",
+             model_options,
+             index=model_options.index(st.session_state.selected_model),
+         )
+         if selected_model != st.session_state.selected_model:
+             st.session_state.selected_model = selected_model
+             st.rerun()
+
+         # Task selection
+         st.subheader("Task Selection")
+         task_options = ["Question Answering", "Text Generation", "Code Completion"]
+         selected_task = st.selectbox(
+             "Select Task",
+             task_options,
+             index=task_options.index(st.session_state.selected_task),
+         )
+         if selected_task != st.session_state.selected_task:
+             st.session_state.selected_task = selected_task
+             st.rerun()
+
+         # Clear chat history button
+         if st.button("Clear Chat History"):
+             if "chat_instance" in st.session_state:
+                 st.session_state.chat_instance.clear_history()
+                 st.rerun()
+
+     # Main content area
+     if st.session_state.authenticated:
+         # Initialize model with the selected configuration
+         model_name = HuggingFaceGemmaModel.AVAILABLE_MODELS[st.session_state.selected_model]["name"]
+         model = HuggingFaceGemmaModel(name=model_name)
+
+         # Load model (cached in session state after the first load)
+         with st.spinner(f"Loading {model_name}..."):
+             model.load_model(device_map="auto")
+
+         # Initialize prompt manager with the selected task
+         prompt_manager = PromptManager(task=st.session_state.selected_task)
+
+         # Initialize chat interface
+         chat = StreamlitChat(model=model, prompt_manager=prompt_manager)
+         st.session_state.chat_instance = chat
+
+         # Run the chat interface
+         chat.run()
+     else:
+         st.info("Please log in with your Hugging Face token in the sidebar to start chatting.")
+
+ if __name__ == "__main__":
+     from streamlit import runtime
+     if runtime.exists():
+         # Already running under a Streamlit server: render the app.
+         main()
+     else:
+         # Invoked as `python app.py`: relaunch the script under Streamlit.
+         os.system(f"streamlit run {__file__}")
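+
+ # How to run (assumed local workflow):
+ #   pip install -r requirements.txt
+ #   streamlit run app.py    # or `python app.py`, which relaunches itself under Streamlit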
gemmademo/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from ._chat import StreamlitChat
+ from ._model import HuggingFaceGemmaModel
+ from ._prompts import PromptManager
+ from ._utils import huggingface_login
+
+ __all__ = ["StreamlitChat", "HuggingFaceGemmaModel", "PromptManager", "huggingface_login"]
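+
+ # Illustrative use of the package API (must run inside a Streamlit script,
+ # since HuggingFaceGemmaModel caches weights via st.session_state):
+ #   from gemmademo import HuggingFaceGemmaModel, PromptManager
+ #   model = HuggingFaceGemmaModel(name="google/gemma-2b-it").load_model()
+ #   model.generate_response(PromptManager(task="Question Answering").get_prompt("Hi"))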
gemmademo/_chat.py ADDED
@@ -0,0 +1,50 @@
+ import streamlit as st
+ from ._model import HuggingFaceGemmaModel
+ from ._prompts import PromptManager
+
+ class StreamlitChat:
+     """
+     A class that handles the chat interface for the Gemma model.
+
+     Features:
+     ✅ A Streamlit-based chatbot UI.
+     ✅ Maintains chat history across reruns.
+     ✅ Uses a Gemma (Hugging Face) model for generating responses.
+     ✅ Formats user inputs before sending them to the model.
+     """
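+     # Typical wiring (hypothetical; mirrors what app.py does):
+     #   chat = StreamlitChat(model=model, prompt_manager=prompt_manager)
+     #   chat.run()
+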
+     def __init__(self, model: HuggingFaceGemmaModel, prompt_manager: PromptManager):
+         self.model = model
+         self.prompt_manager = prompt_manager
+
+     def run(self):
+         self._chat()
+
+     def _chat(self):
+         st.title("Using model: " + self.model.get_model_name())
+         self._build_states()
+
+         # Display chat messages from history on app rerun
+         for message in st.session_state.messages:
+             with st.chat_message(message["role"]):
+                 st.markdown(message["content"])
+
+         # React to user input
+         if prompt := st.chat_input("What is up?"):
+             with st.chat_message("User"):
+                 st.markdown(prompt)
+             st.session_state.messages.append({"role": "User", "content": prompt})
+
+             # Format the raw input for the selected task, then generate a reply
+             task_prompt = self.prompt_manager.get_prompt(user_input=st.session_state.messages[-1]["content"])
+             response = self.model.generate_response(task_prompt)
+             with st.chat_message("Gemma"):
+                 st.markdown(response)
+             st.session_state.messages.append({"role": "Gemma", "content": response})
+
+     def _build_states(self):
+         # Initialize chat history
+         if "messages" not in st.session_state:
+             st.session_state.messages = []
+
+     def clear_history(self):
+         st.session_state.messages = []
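+
+ # Note: clear_history() only resets the message list in st.session_state;
+ # the "Clear Chat History" button in app.py calls st.rerun() afterwards so
+ # the emptied history is rendered immediately.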
gemmademo/_model.py ADDED
@@ -0,0 +1,184 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+ import torch
+ from typing import Dict
+ import streamlit as st
+
+ torch.classes.__path__ = []  # workaround: keep Streamlit's module watcher from crashing on torch.classes
+
+ def load_model(name: str, device_map: str = "cpu"):
+     """
+     Load the tokenizer, model, and generation pipeline. No caching happens
+     here; HuggingFaceGemmaModel caches the results in Streamlit session state.
+     """
+     tokenizer = AutoTokenizer.from_pretrained(name)
+
+     # 8-bit quantization requires bitsandbytes and a CUDA device;
+     # fall back to plain bfloat16 weights when no GPU is available.
+     quantization_config = BitsAndBytesConfig(load_in_8bit=True) if torch.cuda.is_available() else None
+
+     model = AutoModelForCausalLM.from_pretrained(
+         name,
+         torch_dtype=torch.bfloat16,
+         low_cpu_mem_usage=True,
+         device_map=device_map,
+         use_safetensors=True,
+         attn_implementation="eager",  # replaces the deprecated use_flash_attention_2 flag
+         use_cache=True,
+         quantization_config=quantization_config,
+     )
+
+     model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
+
+
+     # The model is already instantiated and placed, so no device or dtype
+     # arguments are passed to the pipeline itself.
+     pipe = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         do_sample=True,
+         temperature=0.7,
+         max_new_tokens=512,
+         pad_token_id=tokenizer.eos_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+         return_full_text=False,
+     )
+
+     return tokenizer, model, pipe
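+
+ # Illustrative call (Gemma weights are gated, so a prior Hugging Face login
+ # with an authorized token is assumed):
+ #   tokenizer, model, pipe = load_model("google/gemma-2b-it", device_map="auto")
+ #   pipe("Write a haiku about autumn.")[0]["generated_text"]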
+
+ class HuggingFaceGemmaModel:
+     """
+     A class for the Hugging Face Gemma model. Handles model selection, loading, and inference.
+     Uses a transformers pipeline for better text generation and formatting.
+
+     Example
+     -------
+     Select Gemma 2B, 7B, etc.
+
+     Additional Information
+     ----------------------
+     Complete information: https://huggingface.co/google/gemma-2b
+
+     Available Models:
+     - google/gemma-2b (2B parameters, base)
+     - google/gemma-2b-it (2B parameters, instruction-tuned)
+     - google/gemma-7b (7B parameters, base)
+     - google/gemma-7b-it (7B parameters, instruction-tuned)
+     """
+
+     AVAILABLE_MODELS: Dict[str, Dict] = {
+         "gemma-2b": {
+             "name": "google/gemma-2b",
+             "description": "2B parameters, base model",
+             "type": "base",
+         },
+         "gemma-2b-it": {
+             "name": "google/gemma-2b-it",
+             "description": "2B parameters, instruction-tuned",
+             "type": "instruct",
+         },
+         "gemma-7b": {
+             "name": "google/gemma-7b",
+             "description": "7B parameters, base model",
+             "type": "base",
+         },
+         "gemma-7b-it": {
+             "name": "google/gemma-7b-it",
+             "description": "7B parameters, instruction-tuned",
+             "type": "instruct",
+         },
+     }
+
+     def __init__(self, name: str = "google/gemma-2b"):
+         self.name = name
+         self.model = None
+         self.tokenizer = None
+         self.pipeline = None
+
+     def load_model(self, device_map: str = "cpu"):
+         """
+         Load the model, caching it in Streamlit session state.
+
+         Args:
+             device_map: Device mapping strategy ("cpu" for CPU-only
+                 inference, "auto" to let accelerate place the weights)
+         """
+         # Create unique keys for this model in session state
+         model_key = f"gemma_model_{self.name}"
+         tokenizer_key = f"gemma_tokenizer_{self.name}"
+         pipeline_key = f"gemma_pipeline_{self.name}"
+
+         # Check if the model is already loaded in session state
+         if (model_key not in st.session_state or
+                 tokenizer_key not in st.session_state or
+                 pipeline_key not in st.session_state):
+
+             # Show a loading indicator
+             with st.spinner(f"Loading {self.name}..."):
+                 tokenizer, model, pipe = load_model(self.name, device_map)
+
+             # Store in session state
+             st.session_state[tokenizer_key] = tokenizer
+             st.session_state[model_key] = model
+             st.session_state[pipeline_key] = pipe
+
+         # Get the cached objects from session state
+         self.tokenizer = st.session_state[tokenizer_key]
+         self.model = st.session_state[model_key]
+         self.pipeline = st.session_state[pipeline_key]
+
+         return self
+
+     def generate_response(
+         self,
+         prompt: str,
+         max_length: int = 512,
+         temperature: float = 0.7,
+         num_return_sequences: int = 1,
+         **kwargs
+     ) -> str:
+         """
+         Generate a response using the text-generation pipeline.
+
+         Args:
+             prompt: Input text
+             max_length: Maximum number of new tokens to generate (mapped to max_new_tokens)
+             temperature: Sampling temperature (higher = more creative)
+             num_return_sequences: Number of responses to generate
+             **kwargs: Additional generation parameters for the pipeline
+
+         Returns:
+             str: Generated response
+         """
+         if not self.pipeline:
+             self.load_model()
+
+         # Update generation config with any provided kwargs
+         generation_config = {
+             "max_new_tokens": max_length,
+             "temperature": temperature,
+             "num_return_sequences": num_return_sequences,
+             "do_sample": True,
+             **kwargs,
+         }
+
+         # Generate response using the pipeline
+         outputs = self.pipeline(prompt, **generation_config)
+
+         # Extract the generated text
+         if num_return_sequences == 1:
+             response = outputs[0]["generated_text"]
+         else:
+             # Join multiple sequences if requested
+             response = "\n---\n".join(output["generated_text"] for output in outputs)
+
+         return response.strip()
+
+     def get_model_info(self) -> Dict:
+         """Return information about the model"""
+         return {
+             "name": self.name,
+             "loaded": self.model is not None,
+             "pipeline_ready": self.pipeline is not None,
+         }
+
+     def get_model_name(self) -> str:
+         """Return the name of the model"""
+         return self.name
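+
+ # Sketch of direct inference (argument values are illustrative); extra
+ # generation kwargs such as top_p are forwarded to the pipeline:
+ #   model = HuggingFaceGemmaModel(name="google/gemma-2b-it").load_model(device_map="auto")
+ #   model.generate_response("Explain attention in one sentence.", temperature=0.2, top_p=0.9)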
gemmademo/_prompts.py ADDED
@@ -0,0 +1,43 @@
+ class PromptManager:
+     def __init__(self, task):
+         self.task = task
+
+     def get_prompt(self, user_input):
+         if self.task == "Question Answering":
+             return self.get_question_answering_prompt(user_input)
+         elif self.task == "Text Generation":
+             return self.get_text_generation_prompt(user_input)
+         elif self.task == "Code Completion":
+             return self.get_code_completion_prompt(user_input)
+         else:
+             raise ValueError(f"Task {self.task} not supported")
+
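+     # Example (illustrative): with task="Question Answering",
+     # get_prompt("Why is the sky blue?") wraps the question in the QA
+     # template built below.
+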
+     def get_question_answering_prompt(self, user_input):
+         """
+         Format user input for the question answering task
+         """
+         prompt = f"""You are a helpful AI assistant. Answer the following question accurately and concisely.
+ Question: {user_input}
+
+ Answer:"""
+         return prompt
+
+     def get_text_generation_prompt(self, user_input):
+         """
+         Format user input for the text generation task
+         """
+         prompt = f"""Continue the following text in a coherent and engaging way:
+ {user_input}
+
+ Continuation:"""
+         return prompt
+
+     def get_code_completion_prompt(self, user_input):
+         """
+         Format user input for the code completion task
+         """
+         prompt = f"""Complete the following code snippet with proper syntax and best practices:
+ {user_input}
+
+ Completed code:"""
+         return prompt
gemmademo/_utils.py ADDED
@@ -0,0 +1,6 @@
+ from huggingface_hub import login
+
+ def huggingface_login(token: str):
+     """
+     Log in to Hugging Face using the given access token
+     """
+     login(token=token)
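+
+ # Example (the token value is a placeholder):
+ #   huggingface_login("hf_...")
+ # The token needs read access to the gated google/gemma-* repositories.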
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ streamlit>=1.30.0
+ transformers>=4.36.0
+ torch>=2.1.0
+ huggingface-hub>=0.19.0
+ accelerate>=0.25.0
+ bitsandbytes>=0.41.0
+ safetensors>=0.4.0
+ sentencepiece>=0.1.99
+ protobuf>=4.25.0