Beibars003 committed
Commit 29654e2 · verified · 1 Parent(s): 58d0dcf

Update app.py

Files changed (1)
  1. app.py +250 -71

app.py CHANGED
@@ -1,82 +1,261 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
- import torch
- from transformers import AutoTokenizer, Gemma3ForCausalLM
-
- model_path = "SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest"
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- model = Gemma3ForCausalLM.from_pretrained(
-     model_path,
-     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-     device_map="auto"
  )
- model.eval()
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(message, history, system_message, max_tokens, temperature, top_p):
-     messages = [{"role": "system", "content": system_message}]
-
-     # Rebuild full chat history
-     for user_msg, assistant_msg in history:
-         if user_msg:
-             messages.append({"role": "user", "content": user_msg})
-         if assistant_msg:
-             messages.append({"role": "assistant", "content": assistant_msg})
-     messages.append({"role": "user", "content": message})
-
-     # Convert chat to single prompt
-     prompt = ""
-     for msg in messages:
-         role = msg["role"]
-         content = msg["content"]
-         if role == "system":
-             prompt += f"[SYSTEM] {content}\n"
-         elif role == "user":
-             prompt += f"[USER] {content}\n"
-         elif role == "assistant":
-             prompt += f"[ASSISTANT] {content}\n"
-     prompt += "[ASSISTANT]"
-
-     # Tokenize
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     input_len = inputs["input_ids"].shape[-1]
-
-     # Generate tokens (with streaming behavior)
-     generated_text = ""
-     with torch.no_grad():
-         output_ids = model.generate(
-             **inputs,
-             max_new_tokens=max_tokens,
-             do_sample=True,
-             temperature=temperature,
-             top_p=top_p,
-             repetition_penalty=1.2,
-             pad_token_id=tokenizer.eos_token_id
          )
-     output = output_ids[0][input_len:]
-     for i in range(output.shape[0]):
-         token = output[i].unsqueeze(0)
-         text_piece = tokenizer.decode(token, skip_special_tokens=True)
-         generated_text += text_piece
-         yield generated_text
-
- # Gradio UI
  demo = gr.ChatInterface(
      respond,
-     additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
      additional_inputs=[
-         gr.Textbox(value="You are a helpful assistant.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
      ],
      theme="Ocean",
  )

  if __name__ == "__main__":
-     demo.launch()
+ # Importing required libraries
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ import os
+ import json
+ import subprocess
+ import sys
+ from typing import List, Tuple
+ from llama_cpp import Llama
+ from llama_cpp_agent import LlamaCppAgent
+ from llama_cpp_agent.providers import LlamaCppPythonProvider
+ from llama_cpp_agent.chat_history import BasicChatHistory
+ from llama_cpp_agent.chat_history.messages import Roles
+ from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
+ from huggingface_hub import hf_hub_download
  import gradio as gr
+ from logger import logging
+ from exception import CustomExceptionHandling
+
+
+ # Load the Environment Variables from .env file
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+
+ # Download gguf model files
+ if not os.path.exists("./models"):
+     os.makedirs("./models")
+
+ hf_hub_download(
+     repo_id="bartowski/google_gemma-3-1b-it-GGUF",
+     filename="google_gemma-3-1b-it-Q4_K_M.gguf",
+     local_dir="./models",
+ )
+ hf_hub_download(
+     repo_id="bartowski/google_gemma-3-1b-it-GGUF",
+     filename="google_gemma-3-1b-it-Q5_K_M.gguf",
+     local_dir="./models",
  )
+
+
+ # Define the prompt markers for Gemma 3
+ gemma_3_prompt_markers = {
+     Roles.system: PromptMarkers("", "\n"),  # System prompt should be included within user message
+     Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
+     Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"),
+     Roles.tool: PromptMarkers("", ""),  # If you need tool support
+ }
+
+ # Create the formatter
+ gemma_3_formatter = MessagesFormatter(
+     pre_prompt="",  # No pre-prompt
+     prompt_markers=gemma_3_prompt_markers,
+     include_sys_prompt_in_first_user_message=True,  # Include system prompt in first user message
+     default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
+     strip_prompt=False,  # Don't strip whitespace from the prompt
+     bos_token="<bos>",  # Beginning of sequence token for Gemma 3
+     eos_token="<eos>",  # End of sequence token for Gemma 3
+ )
+
+
+ # Set the title and description
+ title = "Gemma Llama.cpp"
+ description = """Google released **[Gemma 3](https://blog.google/technology/developers/gemma-3/)**, a family of multimodal models that offers advanced capabilities like a large context window and multilingual support.
+ This interactive chat interface allows you to experiment with the [`gemma-3-1b-it`](https://huggingface.co/google/gemma-3-1b-it) text model using various prompts and generation parameters.
+ Users can select different model variants (GGUF format) and system prompts, and observe generated responses in real time.
+ Key generation parameters, such as `temperature`, `max_tokens`, `top_k`, and others, are exposed below for tuning model behavior.
+ For a detailed technical walkthrough, please refer to the accompanying **[blog post](https://sitammeur.medium.com/build-your-own-gemma-3-chatbot-with-gradio-and-llama-cpp-46457b22a28e)**."""
+
+
+ llm = None
+ llm_model = None
+
+ def respond(
+     message: str,
+     history: List[Tuple[str, str]],
+     model: str = "google_gemma-3-1b-it-Q4_K_M.gguf",  # Set default model
+     system_message: str = "You are a helpful assistant.",
+     max_tokens: int = 1024,
+     temperature: float = 0.7,
+     top_p: float = 0.95,
+     top_k: int = 40,
+     repeat_penalty: float = 1.1,
+ ):
+     """
+     Respond to a message using the Gemma3 model via Llama.cpp.
+     Args:
+     - message (str): The message to respond to.
+     - history (List[Tuple[str, str]]): The chat history.
+     - model (str): The model to use.
+     - system_message (str): The system message to use.
+     - max_tokens (int): The maximum number of tokens to generate.
+     - temperature (float): The temperature of the model.
+     - top_p (float): The top-p of the model.
+     - top_k (int): The top-k of the model.
+     - repeat_penalty (float): The repetition penalty of the model.
+     Returns:
+     str: The response to the message.
+     """
+     try:
+         # Load the global variables
+         global llm
+         global llm_model
+
+         # Ensure model is not None
+         if model is None:
+             model = "google_gemma-3-1b-it-Q4_K_M.gguf"
+
+         # Load the model
+         if llm is None or llm_model != model:
+             # Check if model file exists
+             model_path = f"models/{model}"
+             if not os.path.exists(model_path):
+                 yield f"Error: Model file not found at {model_path}. Please check your model path."
+                 return
+
+             llm = Llama(
+                 model_path=f"models/{model}",
+                 flash_attn=False,
+                 n_gpu_layers=0,
+                 n_batch=8,
+                 n_ctx=2048,
+                 n_threads=8,
+                 n_threads_batch=8,
+             )
+             llm_model = model
+         provider = LlamaCppPythonProvider(llm)
+
+         # Create the agent
+         agent = LlamaCppAgent(
+             provider,
+             system_prompt=f"{system_message}",
+             custom_messages_formatter=gemma_3_formatter,
+             debug_output=True,
          )
+
+         # Set the settings like temperature, top-k, top-p, max tokens, etc.
+         settings = provider.get_provider_default_settings()
+         settings.temperature = temperature
+         settings.top_k = top_k
+         settings.top_p = top_p
+         settings.max_tokens = max_tokens
+         settings.repeat_penalty = repeat_penalty
+         settings.stream = True
+
+         messages = BasicChatHistory()
+
+         # Add the chat history
+         for msn in history:
+             user = {"role": Roles.user, "content": msn[0]}
+             assistant = {"role": Roles.assistant, "content": msn[1]}
+             messages.add_message(user)
+             messages.add_message(assistant)
+
+         # Get the response stream
+         stream = agent.get_chat_response(
+             message,
+             llm_sampling_settings=settings,
+             chat_history=messages,
+             returns_streaming_generator=True,
+             print_output=False,
+         )
+
+         # Log the success
+         logging.info("Response stream generated successfully")
+
+         # Generate the response
+         outputs = ""
+         for output in stream:
+             outputs += output
+             yield outputs
+
+     # Handle exceptions that may occur during the process
+     except Exception as e:
+         # Custom exception handling
+         raise CustomExceptionHandling(e, sys) from e
+
+
+ # Create a chat interface
  demo = gr.ChatInterface(
      respond,
+     examples=[["What is the capital of France?"], ["Tell me something about artificial intelligence."], ["What is gravity?"]],
+     additional_inputs_accordion=gr.Accordion(
+         label="⚙️ Parameters", open=False, render=False
+     ),
      additional_inputs=[
+         gr.Dropdown(
+             choices=[
+                 "google_gemma-3-1b-it-Q4_K_M.gguf",
+                 "google_gemma-3-1b-it-Q5_K_M.gguf",
+             ],
+             value="google_gemma-3-1b-it-Q4_K_M.gguf",
+             label="Model",
+             info="Select the AI model to use for chat",
+         ),
+         gr.Textbox(
+             value="You are a helpful assistant.",
+             label="System Prompt",
+             info="Define the AI assistant's personality and behavior",
+             lines=2,
+         ),
+         gr.Slider(
+             minimum=512,
+             maximum=2048,
+             value=1024,
+             step=1,
+             label="Max Tokens",
+             info="Maximum length of response (higher = longer replies)",
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=2.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature",
+             info="Creativity level (higher = more creative, lower = more focused)",
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p",
+             info="Nucleus sampling threshold",
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=100,
+             value=40,
+             step=1,
+             label="Top-k",
+             info="Limit vocabulary choices to top K tokens",
+         ),
+         gr.Slider(
+             minimum=1.0,
+             maximum=2.0,
+             value=1.1,
+             step=0.1,
+             label="Repetition Penalty",
+             info="Penalize repeated words (higher = less repetition)",
+         ),
      ],
      theme="Ocean",
+     submit_btn="Send",
+     stop_btn="Stop",
+     title=title,
+     description=description,
+     chatbot=gr.Chatbot(scale=1, show_copy_button=True, resizable=True),
+     flagging_mode="never",
+     editable=True,
+     cache_examples=False,
  )

+
+ # Launch the chat interface
  if __name__ == "__main__":
+     demo.launch(
+         share=False,
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_api=False,
+     )
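
For reference, the prompt markers and formatter added above render conversations in Gemma's chat-template style. Below is a minimal sketch of the prompt produced for a single exchange, assuming the system prompt is folded into the first user turn as configured (the exact whitespace handling is up to llama-cpp-agent's formatter, so treat this as illustrative rather than authoritative):

# Illustrative only: approximate prompt built by gemma_3_formatter for one turn.
# "<end_of_turn>" and "<start_of_turn>" also serve as the default stop sequences.
prompt = (
    "<bos>"                                    # bos_token
    "<start_of_turn>user\n"                    # user turn opens
    "You are a helpful assistant.\n"           # system message, merged into the first user turn
    "What is the capital of France?<end_of_turn>\n"
    "<start_of_turn>model\n"                   # model turn opens; generation continues from here
)

Folding the system prompt into the first user turn matches Gemma's chat template, which has no dedicated system role.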