aadya1762 committed on
Commit
8cc5c82
·
1 Parent(s): 73503c4

port to gradio

Browse files
Files changed (4) hide show
  1. README.md +1 -2
  2. app.py +21 -95
  3. gemmademo/_chat.py +18 -40
  4. gemmademo/_model.py +26 -47
README.md CHANGED
@@ -3,8 +3,7 @@ title: Gemma Chat Interface
3
  emoji: 🤖
4
  colorFrom: indigo
5
  colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.43.1
8
  python_version: 3.12
9
  app_file: app.py
10
  pinned: false
 
3
  emoji: 🤖
4
  colorFrom: indigo
5
  colorTo: blue
6
+ sdk: gradio
 
7
  python_version: 3.12
8
  app_file: app.py
9
  pinned: false
app.py CHANGED
@@ -1,100 +1,26 @@
1
- # Interface all the functions from gemmademo.
2
- # Implement a task selector in the side bar.
3
- # Add a button to clear the chat history.
4
- import streamlit as st
5
- from gemmademo import (
6
- LlamaCppGemmaModel,
7
- StreamlitChat,
8
- PromptManager,
9
- huggingface_login,
10
- )
11
- import os
12
- import sys
13
- import subprocess
14
-
15
 
16
  def main():
17
- # Page configuration
18
- st.set_page_config(page_title="Gemma Chat Demo", layout="wide")
19
-
20
- # Initialize session state variables
21
- if "selected_model" not in st.session_state:
22
- st.session_state.selected_model = "gemma-2b-it"
23
- if "selected_task" not in st.session_state:
24
- st.session_state.selected_task = "Question Answering"
25
-
26
- # Sidebar for login and configuration
27
- with st.sidebar:
28
- st.title("Gemma Chat Configuration")
29
-
30
- # Login section
31
- huggingface_login(os.getenv("HF_TOKEN"))
32
- # Model selection
33
- st.subheader("Model Selection")
34
- model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
35
- selected_model = st.selectbox(
36
- "Select Gemma Model",
37
- model_options,
38
- index=model_options.index(st.session_state.selected_model),
39
- )
40
- if selected_model != st.session_state.selected_model:
41
- st.session_state.selected_model = selected_model
42
- st.rerun()
43
-
44
- # Task selection
45
- st.subheader("Task Selection")
46
- task_options = ["Question Answering", "Text Generation", "Code Completion"]
47
- selected_task = st.selectbox(
48
- "Select Task",
49
- task_options,
50
- index=task_options.index(st.session_state.selected_task),
51
- )
52
- if selected_task != st.session_state.selected_task:
53
- st.session_state.selected_task = selected_task
54
- st.rerun()
55
-
56
- # Main content area
57
- # Initialize model with the selected configuration
58
- model_name = st.session_state.selected_model
59
- model = LlamaCppGemmaModel(name=model_name)
60
-
61
- # Load model (will use cached version if available)
62
- with st.spinner(f"Loading {model_name}..."):
63
  model.load_model()
64
-
65
- # Initialize prompt manager with selected task
66
- prompt_manager = PromptManager(task=st.session_state.selected_task)
67
-
68
- # Initialize chat interface
69
- chat = StreamlitChat(model=model, prompt_manager=prompt_manager)
70
- st.session_state.chat_instance = chat
71
-
72
- # Run the chat interface
73
- chat.run()
74
-
 
75
 
76
  if __name__ == "__main__":
77
- # Check if the script is being run directly with Python
78
- # If so, launch Streamlit programmatically
79
- if not os.environ.get("STREAMLIT_RUN_APP"):
80
- os.environ["STREAMLIT_RUN_APP"] = "1"
81
- # Get the current script path
82
- script_path = os.path.abspath(__file__)
83
- # Launch streamlit run with port 7860 and headless mode
84
- cmd = [
85
- sys.executable,
86
- "-m",
87
- "streamlit",
88
- "run",
89
- script_path,
90
- "--server.port",
91
- "7860",
92
- "--server.address",
93
- "0.0.0.0",
94
- "--server.headless",
95
- "true",
96
- ]
97
- subprocess.run(cmd)
98
- else:
99
- # Normal Streamlit execution
100
- main()
 
1
+ import gradio as gr
2
+ from gemmademo import LlamaCppGemmaModel, GradioChat, PromptManager
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def main():
5
+ # Model and task selection
6
+ model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
7
+ task_options = ["Question Answering", "Text Generation", "Code Completion"]
8
+
9
+ def update_chat(model_name, task_name):
10
+ model = LlamaCppGemmaModel(name=model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  model.load_model()
12
+ prompt_manager = PromptManager(task=task_name)
13
+ chat = GradioChat(model=model, prompt_manager=prompt_manager)
14
+ chat.run()
15
+
16
+ gr.Interface(
17
+ fn=update_chat,
18
+ inputs=[
19
+ gr.Dropdown(choices=model_options, value="gemma-2b-it", label="Select Gemma Model"),
20
+ gr.Dropdown(choices=task_options, value="Question Answering", label="Select Task"),
21
+ ],
22
+ outputs=[],
23
+ ).launch()
24
 
25
  if __name__ == "__main__":
26
+ main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gemmademo/_chat.py CHANGED
@@ -1,17 +1,17 @@
1
- import streamlit as st
2
  from ._model import LlamaCppGemmaModel
3
  from ._prompts import PromptManager
4
 
5
 
6
- class StreamlitChat:
7
  """
8
  A class that handles the chat interface for the Gemma model.
9
 
10
  Features:
11
- - A Streamlit-based chatbot UI.
12
- - Maintains chat history across reruns.
13
- - Uses Gemma (Hugging Face) model for generating responses.
14
- - Formats user inputs before sending them to the model.
15
  """
16
 
17
  def __init__(self, model: LlamaCppGemmaModel, prompt_manager: PromptManager):
@@ -22,37 +22,15 @@ class StreamlitChat:
22
  self._chat()
23
 
24
  def _chat(self):
25
- st.title("Using model : " + self.model.get_model_name())
26
- self._build_states()
27
-
28
- # Display chat messages from history on app rerun
29
- for message in st.session_state.messages:
30
- with st.chat_message(message["role"]):
31
- st.markdown(message["content"])
32
-
33
- # React to user input
34
- if prompt := st.chat_input("What is up?"):
35
- prompt = prompt.replace(
36
- "\n", " \n"
37
- ) # Only double spaced backslash is rendered in streamlit for newlines.
38
- with st.chat_message("User"):
39
- st.markdown(prompt)
40
- st.session_state.messages.append({"role": "User", "content": prompt})
41
-
42
- prompt = self.prompt_manager.get_prompt(
43
- user_input=st.session_state.messages[-1]["content"]
44
- )
45
- response = self.model.generate_response(prompt).replace(
46
- "\n", " \n"
47
- ) # Only double spaced backslash is rendered in streamlit for newlines.
48
- with st.chat_message("Gemma"):
49
- st.markdown(response)
50
- st.session_state.messages.append({"role": "Gemma", "content": response})
51
-
52
- def _build_states(self):
53
- # Initialize chat history
54
- if "messages" not in st.session_state:
55
- st.session_state.messages = []
56
-
57
- def clear_history(self):
58
- st.session_state.messages = []
 
1
+ import gradio as gr
2
  from ._model import LlamaCppGemmaModel
3
  from ._prompts import PromptManager
4
 
5
 
6
+ class GradioChat:
7
  """
8
  A class that handles the chat interface for the Gemma model.
9
 
10
  Features:
11
+ - A Gradio-based chatbot UI.
12
+ - Maintains chat history automatically.
13
+ - Uses Gemma (Hugging Face) model for generating responses.
14
+ - Formats user inputs before sending them to the model.
15
  """
16
 
17
  def __init__(self, model: LlamaCppGemmaModel, prompt_manager: PromptManager):
 
22
  self._chat()
23
 
24
  def _chat(self):
25
+ def chat_fn(history, message):
26
+ prompt = self.prompt_manager.get_prompt(user_input=message)
27
+ response = self.model.generate_response(prompt)
28
+ return response
29
+
30
+ chat_interface = gr.ChatInterface(
31
+ chat_fn,
32
+ chatbot=gr.Chatbot(label="Using model: " + self.model.get_model_name()),
33
+ textbox=gr.Textbox(placeholder="What is up?", container=False),
34
+ )
35
+
36
+ chat_interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gemmademo/_model.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  from typing import Dict
3
- import streamlit as st
4
  from llama_cpp import Llama
5
  from huggingface_hub import hf_hub_download
6
 
@@ -51,7 +50,7 @@ class LlamaCppGemmaModel:
51
  },
52
  }
53
 
54
- def __init__(self, name: str = "gemma-2b",):
55
  """
56
  Initialize the model instance.
57
 
@@ -63,60 +62,46 @@ class LlamaCppGemmaModel:
63
 
64
  def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0):
65
  """
66
- Load the model and cache it in Streamlit's session state.
67
- If the model file does not exist, it will be downloaded to the models/ directory.
68
 
69
  Args:
70
- n_threads (int): Number of CPU threads to use.
71
  n_ctx (int): Context window size.
72
  n_gpu_layers (int): Number of layers to offload to GPU (if supported; 0 for CPU-only).
73
-
74
- Returns:
75
- self: Loaded model instance.
76
  """
77
  model_info = self.AVAILABLE_MODELS.get(self.name)
78
  if not model_info:
79
  raise ValueError(f"Model {self.name} is not available.")
80
 
81
  model_path = model_info["model_path"]
 
82
  # If the model file doesn't exist, download it.
83
  if not os.path.exists(model_path):
84
  os.makedirs(os.path.dirname(model_path), exist_ok=True)
85
  repo_id = model_info.get("repo_id")
86
  filename = model_info.get("filename")
 
87
  if repo_id is None or filename is None:
88
- raise ValueError(
89
- "Repository ID or filename is missing for model download."
90
- )
91
- with st.spinner(f"Downloading {self.name}..."):
92
- downloaded_path = hf_hub_download(
93
- repo_id=repo_id,
94
- filename=filename,
95
- local_dir=os.path.dirname(model_path),
96
- local_dir_use_symlinks=False,
97
- )
98
- # If the downloaded file is not at the expected location, rename it.
99
- if downloaded_path != model_path:
100
- os.rename(downloaded_path, model_path)
101
-
102
- model_key = f"gemma_model_{self.name}"
103
- if model_key not in st.session_state:
104
- with st.spinner(f"Loading {self.name}..."):
105
- st.session_state[model_key] = Llama(
106
- model_path=model_path,
107
- n_threads=os.cpu_count(),
108
- n_ctx=n_ctx,
109
- n_gpu_layers=n_gpu_layers,
110
- )
111
- self.model = st.session_state[model_key]
112
  return self
113
 
114
- def generate_response(
115
- self,
116
- prompt: str,
117
- max_tokens: int = 512,
118
- temperature: float = 0.7,
119
- ) -> str:
120
  """
121
  Generate a response using the llama.cpp model.
122
 
@@ -124,7 +109,6 @@ class LlamaCppGemmaModel:
124
  prompt (str): Input prompt text.
125
  max_tokens (int): Maximum number of tokens to generate.
126
  temperature (float): Sampling temperature (higher = more creative).
127
- **kwargs: Additional generation parameters.
128
 
129
  Returns:
130
  str: Generated response text.
@@ -132,14 +116,12 @@ class LlamaCppGemmaModel:
132
  if self.model is None:
133
  self.load_model()
134
 
135
- # Call the llama.cpp model with the provided parameters.
136
  response = self.model(
137
  prompt,
138
  max_tokens=max_tokens,
139
  temperature=temperature,
140
  )
141
- generated_text = response["choices"][0]["text"]
142
- return generated_text.strip()
143
 
144
  def get_model_info(self) -> Dict:
145
  """
@@ -148,10 +130,7 @@ class LlamaCppGemmaModel:
148
  Returns:
149
  Dict: A dictionary containing the model name and load status.
150
  """
151
- return {
152
- "name": self.name,
153
- "loaded": self.model is not None,
154
- }
155
 
156
  def get_model_name(self) -> str:
157
  """
@@ -160,4 +139,4 @@ class LlamaCppGemmaModel:
160
  Returns:
161
  str: Model name.
162
  """
163
- return self.name
 
1
  import os
2
  from typing import Dict
 
3
  from llama_cpp import Llama
4
  from huggingface_hub import hf_hub_download
5
 
 
50
  },
51
  }
52
 
53
+ def __init__(self, name: str = "gemma-2b"):
54
  """
55
  Initialize the model instance.
56
 
 
62
 
63
  def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0):
64
  """
65
+ Load the model. If the model file does not exist, it will be downloaded.
 
66
 
67
  Args:
 
68
  n_ctx (int): Context window size.
69
  n_gpu_layers (int): Number of layers to offload to GPU (if supported; 0 for CPU-only).
 
 
 
70
  """
71
  model_info = self.AVAILABLE_MODELS.get(self.name)
72
  if not model_info:
73
  raise ValueError(f"Model {self.name} is not available.")
74
 
75
  model_path = model_info["model_path"]
76
+
77
  # If the model file doesn't exist, download it.
78
  if not os.path.exists(model_path):
79
  os.makedirs(os.path.dirname(model_path), exist_ok=True)
80
  repo_id = model_info.get("repo_id")
81
  filename = model_info.get("filename")
82
+
83
  if repo_id is None or filename is None:
84
+ raise ValueError("Repository ID or filename is missing for model download.")
85
+
86
+ downloaded_path = hf_hub_download(
87
+ repo_id=repo_id,
88
+ filename=filename,
89
+ local_dir=os.path.dirname(model_path),
90
+ local_dir_use_symlinks=False,
91
+ )
92
+
93
+ if downloaded_path != model_path:
94
+ os.rename(downloaded_path, model_path)
95
+
96
+ self.model = Llama(
97
+ model_path=model_path,
98
+ n_threads=os.cpu_count(),
99
+ n_ctx=n_ctx,
100
+ n_gpu_layers=n_gpu_layers,
101
+ )
 
 
 
 
 
 
102
  return self
103
 
104
+ def generate_response(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
 
 
 
 
 
105
  """
106
  Generate a response using the llama.cpp model.
107
 
 
109
  prompt (str): Input prompt text.
110
  max_tokens (int): Maximum number of tokens to generate.
111
  temperature (float): Sampling temperature (higher = more creative).
 
112
 
113
  Returns:
114
  str: Generated response text.
 
116
  if self.model is None:
117
  self.load_model()
118
 
 
119
  response = self.model(
120
  prompt,
121
  max_tokens=max_tokens,
122
  temperature=temperature,
123
  )
124
+ return response["choices"][0]["text"].strip()
 
125
 
126
  def get_model_info(self) -> Dict:
127
  """
 
130
  Returns:
131
  Dict: A dictionary containing the model name and load status.
132
  """
133
+ return {"name": self.name, "loaded": self.model is not None}
 
 
 
134
 
135
  def get_model_name(self) -> str:
136
  """
 
139
  Returns:
140
  str: Model name.
141
  """
142
+ return self.name