SII-InnoMegrez commited on
Commit
bc9db78
Β·
verified Β·
1 Parent(s): 6b8c84c

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +6 -6
  2. app.py +133 -263
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
- title: Megrez 3B Omni
3
- emoji: 🐠
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.3.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: Megrez-3B-Omni Chat Demo
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Megrez2-3x7B-A3B-Preview
3
+ emoji: πŸ‘€
4
+ colorFrom: purple
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.30.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: Megrez2 Chat Demo
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,297 +1,167 @@
1
- # -*- encoding: utf-8 -*-
2
- # File: app.py
3
- # Description: None
4
-
5
-
6
- from copy import deepcopy
7
- from typing import Dict, List
8
- from PIL import Image
9
- import io
10
- import subprocess
11
  import requests
12
  import json
13
- import base64
14
- import gradio as gr
15
- import librosa
16
- import os
17
-
18
- IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
19
- VIDEO_EXTENSIONS = (".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v")
20
- AUDIO_EXTENSIONS = (".mp3", ".wav", "flac", ".m4a", ".wma")
21
-
22
- DEFAULT_SAMPLING_PARAMS = {
23
- "top_p": 0.8,
24
- "top_k": 100,
25
- "temperature": 0.7,
26
- "do_sample": True,
27
- "num_beams": 1,
28
- "repetition_penalty": 1.2,
29
- }
30
- MAX_NEW_TOKENS = 1024
31
-
32
 
33
 
34
- def load_image_to_base64(image_path):
35
- """Load image and convert to base64 string"""
36
- with Image.open(image_path) as img:
37
- if img.mode != 'RGB':
38
- img = img.convert('RGB')
39
- img_byte_arr = io.BytesIO()
40
- img.save(img_byte_arr, format='PNG')
41
- img_byte_arr = img_byte_arr.getvalue()
42
- return base64.b64encode(img_byte_arr).decode('utf-8')
43
 
44
- def wav_to_bytes_with_ffmpeg(wav_file_path):
45
- process = subprocess.Popen(
46
- ['ffmpeg', '-i', wav_file_path, '-f', 'wav', '-'],
47
- stdout=subprocess.PIPE,
48
- stderr=subprocess.PIPE
49
- )
50
- out, _ = process.communicate()
51
- return base64.b64encode(out).decode('utf-8')
52
-
53
- def parse_sse_response(response):
54
- for line in response.iter_lines():
55
- print(line)
56
- if line:
57
- line = line.decode('utf-8')
58
- if line.startswith('data: '):
59
- data = line[6:].strip() # Remove 'data: ' prefix
60
- if data == '[DONE]':
61
  break
 
62
  try:
63
- json_data = json.loads(data)
64
- print(f"{json_data['text']}")
65
- yield json_data['text']
66
- except json.JSONDecodeError:
67
- print(f"Failed to parse JSON: {data}")
68
- raise gr.Error(f"Failed to parse JSON: {data}")
69
-
70
- def history2messages(history: List[Dict]) -> List[Dict]:
71
- """
72
- Transform gradio history to chat messages.
73
- """
74
- messages = []
75
- cur_message = dict()
76
- for item in history:
77
- if item["role"] == "assistant":
78
- if len(cur_message) > 0:
79
- messages.append(deepcopy(cur_message))
80
- cur_message = dict()
81
- messages.append(deepcopy(item))
82
- continue
83
-
84
- if "role" not in cur_message:
85
- cur_message["role"] = "user"
86
- if "content" not in cur_message:
87
- cur_message["content"] = dict()
88
-
89
- if "metadata" not in item or item["metadata"] is None:
90
- item["metadata"] = {"title": ""}
91
- if item["metadata"]["title"] == "":
92
- cur_message["content"]["text"] = item["content"]
93
- elif item["metadata"]["title"] == "image":
94
- cur_message["content"]["image"] = load_image_to_base64(item["content"][0])
95
- elif item["metadata"]["title"] == "audio":
96
- cur_message["content"]["audio"] = wav_to_bytes_with_ffmpeg(item["content"][0])
97
- if len(cur_message) > 0:
98
- messages.append(cur_message)
99
- return messages
100
-
101
- def check_messages(history, message, audio):
102
- if not isinstance(message, dict):
103
- raise gr.Error("ζΆˆζ―ζ ΌεΌι”™θ――")
104
-
105
- has_text = message.get("text", "") and message["text"].strip()
106
- has_files = len(message.get("files", [])) > 0
107
- has_audio = audio is not None
108
-
109
- if not (has_text or has_files or has_audio):
110
- raise gr.Error("θ―·θΎ“ε…₯ζ–‡ε­—ζˆ–δΈŠδΌ ιŸ³ι’‘/ε›Ύη‰‡εŽε†ε‘ι€γ€‚")
111
-
112
- audios = []
113
- images = []
114
-
115
- for file_msg in message["files"]:
116
- if file_msg.endswith(AUDIO_EXTENSIONS) or file_msg.endswith(VIDEO_EXTENSIONS):
117
- duration = librosa.get_duration(filename=file_msg)
118
- if duration > 30:
119
- raise gr.Error("ιŸ³ι’‘ζ—Άι•ΏδΈθƒ½θΆ…θΏ‡30秒。")
120
- if duration == 0:
121
- raise gr.Error("ιŸ³ι’‘ζ—Άι•ΏδΈθƒ½δΈΊ0秒。")
122
- audios.append(file_msg)
123
- elif file_msg.endswith(IMAGE_EXTENSIONS):
124
- images.append(file_msg)
125
- else:
126
- filename = file_msg.split("/")[-1]
127
- raise gr.Error(f"Unsupported file type: {filename}. It should be an image or audio file.")
128
-
129
- if len(audios) > 1:
130
- raise gr.Error("Please upload only one audio file.")
131
-
132
- if len(images) > 1:
133
- raise gr.Error("Please upload only one image file.")
134
-
135
- if audio is not None:
136
- if len(audios) > 0:
137
- raise gr.Error("Please upload only one audio file or record audio.")
138
- audios.append(audio)
139
-
140
- # Append the message to the history
141
- for image in images:
142
- history.append({"role": "user", "content": (image,), "metadata": {"title": "image"}})
143
-
144
- for audio in audios:
145
- history.append({"role": "user", "content": (audio,), "metadata": {"title": "audio"}})
146
-
147
- if message["text"]:
148
- history.append({"role": "user", "content": message["text"]})
149
-
150
- return history, gr.MultimodalTextbox(value=None, interactive=False), None
151
-
152
- def bot(
153
- history: list,
154
- top_p: float,
155
- top_k: int,
156
- temperature: float,
157
- repetition_penalty: float,
158
- max_new_tokens: int = MAX_NEW_TOKENS,
159
- regenerate: bool = False,
160
- ):
161
-
162
- if history and regenerate:
163
- history = history[:-1]
164
-
165
- if not history:
166
- return history
167
 
168
- msgs = history2messages(history)
169
- print(msgs)
170
-
171
- API_URL = os.getenv("API_URL", "http://8.141.126.196:28000/v1/chat")
172
-
173
  payload = {
174
- "messages": msgs,
175
- "sampling_params": {
176
- "top_p": top_p,
177
- "top_k": top_k,
178
- "temperature": temperature,
179
- "repetition_penalty": repetition_penalty,
180
- "max_new_tokens": max_new_tokens,
181
- "num_beams": 3,
182
- }
183
  }
184
 
185
- response = requests.get(
186
- API_URL,
187
- json=payload,
188
- headers={'Accept': 'text/event-stream'},
189
- stream=True
190
- )
191
- response_text = ""
 
 
 
 
 
 
 
 
 
 
 
192
 
193
- for text in parse_sse_response(response):
194
- response_text += text
195
- yield history + [{"role": "assistant", "content": response_text}]
196
-
197
- return response_text
198
-
199
- def change_state(state):
200
- return gr.update(visible=not state), not state
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  def reset_user_input():
203
  return gr.update(value="")
204
 
 
 
 
 
 
 
 
205
  if __name__ == "__main__":
206
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
207
  gr.Markdown(
208
  f"""
209
- # πŸͺ Chat with <a href="https://github.com/infinigence/Infini-Megrez-Omni">Megrez-3B-Omni</a>
210
  """
211
  )
212
- chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", height='48vh')
213
-
214
- sampling_params_group_hidden_state = gr.State(False)
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
- with gr.Row(equal_height=True):
218
- chat_input = gr.MultimodalTextbox(
219
- file_count="multiple",
220
- placeholder="Enter your prompt or upload image/audio here, then press ENTER...",
221
- show_label=False,
222
- scale=8,
223
- file_types=["image", "audio"],
 
 
 
224
  interactive=True,
225
- # stop_btn=True,
226
  )
227
- with gr.Row(equal_height=True):
228
- audio_input = gr.Audio(
229
- sources=["microphone", "upload"],
230
- type="filepath",
231
- scale=1,
232
- max_length=30
 
 
 
 
 
 
 
 
 
233
  )
234
- with gr.Row(equal_height=True):
235
- with gr.Column(scale=1, min_width=150):
236
- with gr.Row(equal_height=True):
237
- regenerate_btn = gr.Button("Regenerate", variant="primary")
238
- clear_btn = gr.ClearButton(
239
- [chat_input, audio_input, chatbot],
240
- )
241
-
242
- with gr.Row():
243
- sampling_params_toggle_btn = gr.Button("Sampling Parameters")
244
-
245
- with gr.Group(visible=False) as sampling_params_group:
246
- with gr.Row():
247
- temperature = gr.Slider(
248
- minimum=0, maximum=1.2, value=DEFAULT_SAMPLING_PARAMS["temperature"], label="Temperature"
249
- )
250
- repetition_penalty = gr.Slider(
251
- minimum=0,
252
- maximum=2,
253
- value=DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
254
- label="Repetition Penalty",
255
- )
256
-
257
- with gr.Row():
258
- top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS["top_p"], label="Top-p")
259
- top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS["top_k"], label="Top-k")
260
-
261
- with gr.Row():
262
- max_new_tokens = gr.Slider(
263
- minimum=1,
264
- maximum=MAX_NEW_TOKENS,
265
- value=MAX_NEW_TOKENS,
266
- label="Max New Tokens",
267
- interactive=True,
268
- )
269
 
270
- sampling_params_toggle_btn.click(
271
- change_state,
272
- sampling_params_group_hidden_state,
273
- [sampling_params_group, sampling_params_group_hidden_state],
274
  )
275
-
276
- chat_msg = chat_input.submit(
277
- check_messages,
278
- [chatbot, chat_input, audio_input],
279
- [chatbot, chat_input, audio_input],
280
  )
 
 
281
 
282
- bot_msg = chat_msg.then(
283
- bot,
284
- inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],
285
- outputs=chatbot,
286
- api_name="bot_response",
287
  )
288
-
289
- bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
290
-
291
- regenerate_btn.click(
292
- bot,
293
- inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],
294
- outputs=chatbot,
295
  )
296
 
297
- demo.launch(server_name="0.0.0.0")
 
1
+ from argparse import ArgumentParser
2
+ import gradio as gr
 
 
 
 
 
 
 
 
3
  import requests
4
  import json
5
+ import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
 
 
 
 
 
 
 
 
 
8
 
9
+ def get_streaming_response(response: requests.Response):
10
+ for chunk in response.iter_lines():
11
+ if chunk:
12
+ data = chunk.decode("utf-8")
13
+ if data.startswith('data: '):
14
+ json_str = data[6:]
15
+
16
+ if json_str == '[DONE]':
 
 
 
 
 
 
 
 
 
17
  break
18
+
19
  try:
20
+ chunk = json.loads(json_str)
21
+ delta = chunk.get('choices', [{}])[0].get('delta', {})
22
+ new_text = delta.get('content', '')
23
+
24
+ if new_text:
25
+ yield new_text
26
+ except (json.JSONDecodeError, IndexError):
27
+ print(f"Skipping malformed SSE line: {json_str}")
28
+ continue
29
+
30
+ def _chat_stream(model, tokenizer, query, history, temperature, top_p, max_output_tokens):
31
+ conversation = []
32
+ for query_h, response_h in history:
33
+ conversation.append({"role": "user", "content": query_h})
34
+ conversation.append({"role": "assistant", "content": response_h})
35
+ conversation.append({"role": "user", "content": query})
36
+
37
+ headers = {
38
+ "Content-Type": "application/json"
39
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
 
 
 
41
  payload = {
42
+ "model": "megrez-moe-waic",
43
+ "messages": conversation,
44
+ "max_tokens": max_output_tokens,
45
+ "temperature": max(temperature, 0),
46
+ "top_p": top_p,
47
+ "stream": True
 
 
 
48
  }
49
 
50
+ try:
51
+ API_URL = "http://8.152.0.142:8080/v1/chat/completions"
52
+ response = requests.post(API_URL, headers=headers, data=json.dumps(payload), timeout=60, stream=True)
53
+ response.raise_for_status()
54
+ for chunk in get_streaming_response(response):
55
+ yield chunk
56
+ time.sleep(0.01)
57
+
58
+ except requests.exceptions.RequestException as e:
59
+ print(f"API request failed: {e}")
60
+ yield f"Error: Could not connect to the API. Details: {e}"
61
+ except (KeyError, IndexError) as e:
62
+ print(f"Failed to parse API response: {response.text}")
63
+ yield f"Error: Invalid response format from the API. Details: {e}"
64
+
65
+ def predict(_query, _chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
66
+ print(f"User: {_query}")
67
+ _chatbot.append((_query, ""))
68
 
69
+ full_response = ""
70
+ stream = _chat_stream(None, None, _query, history=_task_history, temperature=_temperature, top_p=_top_p, max_output_tokens=_max_output_tokens)
71
+
72
+ for new_text in stream:
73
+ full_response += new_text
74
+ _chatbot[-1] = (_query, full_response)
75
+ yield _chatbot
76
+
77
+ print(f"History: {_task_history}")
78
+ _task_history.append((_query, full_response))
79
+ print(f"Megrez (from API): {full_response}")
80
+
81
+ def regenerate(_chatbot, _task_history, _temperature, _top_p, _max_output_tokens):
82
+ if not _task_history:
83
+ yield _chatbot
84
+ return
85
+ item = _task_history.pop(-1)
86
+ _chatbot.pop(-1)
87
+ yield from predict(item[0], _chatbot, _task_history, _temperature, _top_p, _max_output_tokens)
88
 
89
  def reset_user_input():
90
  return gr.update(value="")
91
 
92
+ def reset_state(_chatbot, _task_history):
93
+ _task_history.clear()
94
+ _chatbot.clear()
95
+ return _chatbot
96
+
97
+
98
+
99
  if __name__ == "__main__":
100
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
101
  gr.Markdown(
102
  f"""
103
+ # 🎱 Chat with Megrez2 <a href="https://github.com/infinigence/Infini-Megrez">
104
  """
105
  )
 
 
 
106
 
107
+ chatbot = gr.Chatbot(label="Megrez2", elem_classes="control-height", height='48vh', show_copy_button=True,
108
+ latex_delimiters=[
109
+ {"left": "$$", "right": "$$", "display": True},
110
+ {"left": "$", "right": "$", "display": False},
111
+ {"left": "\\(", "right": "\\)", "display": False},
112
+ {"left": "\\[", "right": "\\]", "display": True},
113
+ ])
114
+ with gr.Row():
115
+ with gr.Column(scale=20):
116
+ query = gr.Textbox(show_label=False, container=False, placeholder="Enter your prompt here and press ENTER")
117
+ with gr.Column(scale=1, min_width=100):
118
+ submit_btn = gr.Button("πŸš€ Send", variant="primary")
119
+ task_history = gr.State([])
120
 
121
+ with gr.Row():
122
+ empty_btn = gr.Button("πŸ—‘οΈ Clear History")
123
+ regen_btn = gr.Button("πŸ”„ Regenerate")
124
+
125
+ with gr.Accordion("Parameters", open=False) as parameter_row:
126
+ temperature = gr.Slider(
127
+ minimum=0.0,
128
+ maximum=1.2,
129
+ value=0.7,
130
+ step=0.1,
131
  interactive=True,
132
+ label="Temperature",
133
  )
134
+ top_p = gr.Slider(
135
+ minimum=0.0,
136
+ maximum=1.0,
137
+ value=0.9,
138
+ step=0.1,
139
+ interactive=True,
140
+ label="Top P",
141
+ )
142
+ max_output_tokens = gr.Slider(
143
+ minimum=16,
144
+ maximum=32768,
145
+ value=4096,
146
+ step=1024,
147
+ interactive=True,
148
+ label="Max output tokens",
149
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ submit_btn.click(
152
+ predict, [query, chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
 
 
153
  )
154
+ query.submit(
155
+ predict, [query, chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
 
 
 
156
  )
157
+ submit_btn.click(reset_user_input, [], [query])
158
+ query.submit(reset_user_input, [], [query])
159
 
160
+ empty_btn.click(
161
+ reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True
 
 
 
162
  )
163
+ regen_btn.click(
164
+ regenerate, [chatbot, task_history, temperature, top_p, max_output_tokens], [chatbot], show_progress=True
 
 
 
 
 
165
  )
166
 
167
+ demo.launch(ssr_mode=False, share=True)