NMCxyz committed on
Commit
9942354
·
verified ·
1 Parent(s): b4f5cb4

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-312.pyc +0 -0
  2. omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-38.pyc +0 -0
  3. omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-39.pyc +0 -0
  4. omni_speech/serve/__init__.py +0 -0
  5. omni_speech/serve/controller.py +298 -0
  6. omni_speech/serve/gradio_web_server.py +348 -0
  7. omni_speech/serve/model_worker.py +292 -0
  8. omni_speech/train/__pycache__/omni_trainer.cpython-310.pyc +0 -0
  9. omni_speech/train/__pycache__/omni_trainer.cpython-312.pyc +0 -0
  10. omni_speech/train/__pycache__/run_train.cpython-310.pyc +0 -0
  11. omni_speech/train/__pycache__/run_train.cpython-312.pyc +0 -0
  12. omni_speech/train/__pycache__/run_train.cpython-38.pyc +0 -0
  13. omni_speech/train/__pycache__/train.cpython-312.pyc +0 -0
  14. omni_speech/train/__pycache__/train_mem.cpython-312.pyc +0 -0
  15. omni_speech/train/__pycache__/train_multiturn.cpython-312.pyc +0 -0
  16. omni_speech/train/__pycache__/train_raw.cpython-312.pyc +0 -0
  17. omni_speech/train/__pycache__/train_test.cpython-312.pyc +0 -0
  18. omni_speech/train/__pycache__/trainer.cpython-310.pyc +0 -0
  19. omni_speech/train/__pycache__/trainer.cpython-312.pyc +0 -0
  20. omni_speech/train/export.py +512 -0
  21. omni_speech/train/omni_trainer.py +345 -0
  22. omni_speech/train/train.py +420 -0
  23. omni_speech/train/train_mem.py +4 -0
  24. omni_speech/train/train_minicpmo.py +660 -0
  25. omni_speech/train/train_minicpmo_test.py +729 -0
  26. omni_speech/train/train_multiturn.py +515 -0
  27. omni_speech/train/trainer.py +249 -0
  28. scripts/continue.sh +65 -0
  29. scripts/ds_config_zero2.json +54 -0
  30. scripts/ds_config_zero3.json +59 -0
  31. scripts/export.sh +39 -0
  32. scripts/finetune.sh +42 -0
  33. scripts/finetune_llm_speech_decoder.sh +85 -0
  34. scripts/finetune_lora.sh +43 -0
  35. scripts/finetune_minicpmo.sh +65 -0
  36. scripts/finetune_minicpmo_asr.sh +63 -0
  37. scripts/finetune_speech_decoder.sh +42 -0
  38. scripts/minicpmp_config.json +163 -0
  39. scripts/pretrain_minicpmo_test.sh +89 -0
  40. scripts/pretrained.sh +44 -0
  41. scripts/pretrained_minicpmo.sh +74 -0
  42. scripts/test_llama.sh +41 -0
  43. scripts/test_qwen.sh +41 -0
  44. scripts/wandb/debug-internal.log +7 -0
  45. scripts/wandb/debug.log +25 -0
  46. scripts/wandb/latest-run/files/output.log +559 -0
  47. scripts/wandb/latest-run/files/requirements.txt +341 -0
  48. scripts/wandb/latest-run/files/wandb-metadata.json +171 -0
  49. scripts/wandb/latest-run/logs/debug-core.log +7 -0
  50. scripts/wandb/latest-run/logs/debug-internal.log +7 -0
omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-312.pyc ADDED
Binary file (2.07 kB). View file
 
omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-38.pyc ADDED
Binary file (1.19 kB). View file
 
omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-39.pyc ADDED
Binary file (1.23 kB). View file
 
omni_speech/serve/__init__.py ADDED
File without changes
omni_speech/serve/controller.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from omni_speech.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from omni_speech.utils import build_logger, server_error_msg
23
+
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
class DispatchMethod(Enum):
    """Strategy for routing an incoming request to a registered worker."""
    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        """Parse a CLI-style name ("lottery" / "shortest_queue") into a member.

        Raises:
            ValueError: if *name* is not a recognized dispatch method.
        """
        if name == "lottery":
            return cls.LOTTERY
        elif name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        else:
            # FIX: the original f-string had no placeholder, so the message
            # never said which value was rejected.
            raise ValueError(f"Invalid dispatch method: {name}")
40
+
41
+
42
@dataclasses.dataclass
class WorkerInfo:
    """Controller-side record for one registered worker.

    Fields:
        model_names: models this worker can serve.
        speed: relative throughput weight used by lottery dispatch.
        queue_length: last reported number of queued requests.
        check_heart_beat: whether the worker expires on missed heartbeats.
        last_heart_beat: unix timestamp of the last heartbeat.
    """
    model_names: List[str]
    speed: int
    queue_length: int
    check_heart_beat: bool
    # FIX: annotated `str` in the original, but it is always assigned the
    # float returned by time.time().
    last_heart_beat: float
49
+
50
+
51
def heart_beat_controller(controller):
    """Daemon loop: wake up once per expiration window and drop workers
    whose last heartbeat is older than CONTROLLER_HEART_BEAT_EXPIRATION."""
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        controller.remove_stable_workers_by_expiration()
55
+
56
+
57
class Controller:
    """Registry of model workers plus the request-dispatch logic.

    Workers register themselves over HTTP, send periodic heartbeats, and are
    expired by a background daemon thread when heartbeats go stale.
    """

    def __init__(self, dispatch_method: str):
        # Dict[str -> WorkerInfo], keyed by worker base URL.
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        # Daemon thread that expires workers with stale heartbeats.
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,), daemon=True)
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        """Add (or refresh) a worker.

        Returns False when no status was supplied and the worker cannot be
        reached for one, True on success.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
            if not worker_status:
                return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """POST /worker_get_status on the worker; returns its JSON status dict
        or None on any network/HTTP failure."""
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        """Forget a worker entirely."""
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        """Re-register every known worker, dropping those that no longer respond."""
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        """Union of model names served by all registered workers."""
        model_names = set()
        for w_name, w_info in self.worker_info.items():
            model_names.update(w_info.model_names)
        return list(model_names)

    def get_worker_address(self, model_name: str):
        """Pick a worker serving *model_name*; returns "" when none qualifies."""
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            # Speed-weighted random choice, returned without a liveness probe.
            # FIX: the original's probing retry loop sat after an
            # unconditional `if True: ... return` and was dead code; removed.
            pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
            return worker_names[pt]
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            # Optimistically bump the queue; the next heartbeat corrects it.
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heartbeat; returns False when the worker is unknown
        (the worker should then re-register)."""
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        # NOTE(review): "stable" is likely a typo for "stale"; the name is
        # kept because heart_beat_controller() calls it.
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        to_delete = []
        for worker_name, w_info in self.worker_info.items():
            if w_info.check_heart_beat and w_info.last_heart_beat < expire:
                to_delete.append(worker_name)

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        """Proxy a streaming generate request to a worker.

        Yields b"\\0"-delimited JSON chunks; error_code 2 = no worker,
        error_code 3 = worker unreachable/timeout.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"
            # FIX: the original fell through after the error chunk and tried
            # to POST to an empty address.
            return

        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     json=params, stream=True, timeout=5)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException as e:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"

    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in self.worker_info:
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }
237
+
238
+
239
app = FastAPI()


@app.post("/register_worker")
async def register_worker(request: Request):
    """Worker registration endpoint; body carries name, heartbeat flag, status."""
    payload = await request.json()
    controller.register_worker(
        payload["worker_name"], payload["check_heart_beat"],
        payload.get("worker_status", None))


@app.post("/refresh_all_workers")
async def refresh_all_workers():
    """Force the controller to re-probe every registered worker."""
    models = controller.refresh_all_workers()


@app.post("/list_models")
async def list_models():
    """All model names currently served by some worker."""
    models = controller.list_models()
    return {"models": models}


@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    """Dispatch: resolve a model name to a worker address ("" when none)."""
    payload = await request.json()
    addr = controller.get_worker_address(payload["model"])
    return {"address": addr}


@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    """Heartbeat sink; replies whether the worker is still known."""
    payload = await request.json()
    exist = controller.receive_heart_beat(
        payload["worker_name"], payload["queue_length"])
    return {"exist": exist}


@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    """Stream generation through the controller (hierarchical mode)."""
    params = await request.json()
    generator = controller.worker_api_generate_stream(params)
    return StreamingResponse(generator)


@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    """Aggregate status so the controller can pose as a worker."""
    return controller.worker_api_get_status()
286
+
287
+
288
if __name__ == "__main__":
    # CLI entry point: parse flags, build the controller, serve the API.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21001)
    parser.add_argument("--dispatch-method", type=str,
                        choices=["lottery", "shortest_queue"],
                        default="shortest_queue")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    controller = Controller(args.dispatch_method)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
omni_speech/serve/gradio_web_server.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import torch
7
+ import torchaudio
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import requests
12
+ import soundfile as sf
13
+
14
+ from omni_speech.conversation import default_conversation, conv_templates
15
+ from omni_speech.constants import LOGDIR
16
+ from omni_speech.utils import build_logger, server_error_msg
17
+ from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
18
+
19
+
20
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
21
+
22
+ vocoder = None
23
+
24
+ headers = {"User-Agent": "LLaMA-Omni Client"}
25
+
26
+ no_change_btn = gr.Button()
27
+ enable_btn = gr.Button(interactive=True)
28
+ disable_btn = gr.Button(interactive=False)
29
+
30
+
31
def get_conv_log_filename():
    """Path of today's conversation log file under LOGDIR (YYYY-MM-DD-conv.json)."""
    today = datetime.datetime.now()
    return os.path.join(
        LOGDIR, f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json")
35
+
36
+
37
def get_model_list():
    """Ask the controller to refresh its worker pool, then fetch the model list."""
    refresh = requests.post(args.controller_url + "/refresh_all_workers")
    assert refresh.status_code == 200
    listing = requests.post(args.controller_url + "/list_models")
    models = listing.json()["models"]
    logger.info(f"Models: {models}")
    return models
44
+
45
+
46
+ get_window_url_params = """
47
+ function() {
48
+ const params = new URLSearchParams(window.location.search);
49
+ url_params = Object.fromEntries(params);
50
+ console.log(url_params);
51
+ return url_params;
52
+ }
53
+ """
54
+
55
+
56
def load_demo(url_params, request: gr.Request):
    """Initial page load: preselect the model named in the URL query, if any."""
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown(visible=True)
    if "model" in url_params:
        requested = url_params["model"]
        if requested in models:
            dropdown_update = gr.Dropdown(value=requested, visible=True)

    return default_conversation.copy(), dropdown_update
67
+
68
+
69
def load_demo_refresh_model_list(request: gr.Request):
    """Page load in "reload" mode: re-query the controller for live models."""
    logger.info(f"load_demo. ip: {request.client.host}")
    models = get_model_list()
    dropdown_update = gr.Dropdown(
        choices=models,
        value=models[0] if len(models) > 0 else ""
    )
    return default_conversation.copy(), dropdown_update
78
+
79
+
80
def clear_history(request: gr.Request):
    """Reset conversation state and blank every input/output widget."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    return (fresh_state, None, "", "", None)
84
+
85
+
86
def add_speech(state, speech, request: gr.Request):
    """Start a fresh conversation whose first user turn is (prompt, audio)."""
    # The text side of the turn is a fixed instruction prefixed with the
    # <speech> placeholder token.
    instruction = "Please directly answer the questions in the user's speech."
    turn = ('<speech>\n' + instruction, speech)
    state = default_conversation.copy()
    state.append_message(state.roles[0], turn)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return state
95
+
96
+
97
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, chunk_size, request: gr.Request):
    """Send the conversation to a worker and stream back (text, units, audio).

    Generator for the Gradio event: yields (state, text_output, unit_output,
    audio_output) tuples as chunks arrive; audio is vocoded from discrete
    units in batches of at least *chunk_size*.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, "", "", None)
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: re-home messages onto the llama_3 template.
        template_name = "llama_3"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, "", "", None)
        return

    # Construct prompt
    prompt = state.get_prompt()

    # Resample the recorded audio to 16 kHz and scale int16 range to [-1, 1].
    sr, audio = state.messages[0][1][1]
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
    audio = torch.tensor(audio.astype(np.float32)).unsqueeze(0)
    audio = resampler(audio).squeeze(0).numpy()
    audio /= 32768.0
    audio = audio.tolist()

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1500),
        "stop": state.sep2,
        "audio": audio,
    }

    yield (state, "", "", None)

    # FIX: initialize before the try block. The original left these unbound
    # when the stream produced no successful chunk, raising NameError in the
    # post-loop flush below. (Also removed an unused `cur_dir` local.)
    output = ""
    output_unit = []
    num_generated_units = 0
    wav_list = []

    try:
        # Stream output
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=10)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt):].strip()
                    output_unit = list(map(int, data["unit"].strip().split()))
                    state.messages[-1][-1] = (output, data["unit"].strip())

                    # vocoder: only synthesize once enough new units accumulated
                    new_units = output_unit[num_generated_units:]
                    if len(new_units) >= chunk_size:
                        num_generated_units = len(output_unit)
                        x = {"code": torch.LongTensor(new_units).view(1, -1).cuda()}
                        wav = vocoder(x, True)
                        wav_list.append(wav.detach().cpu().numpy())

                    if len(wav_list) > 0:
                        wav_full = np.concatenate(wav_list)
                        return_value = (16000, wav_full)
                    else:
                        return_value = None

                    yield (state, state.messages[-1][-1][0], state.messages[-1][-1][1], return_value)
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, "", "", None)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, "", "", None)
        return

    # Flush any trailing units (fewer than chunk_size) through the vocoder.
    if num_generated_units < len(output_unit):
        new_units = output_unit[num_generated_units:]
        num_generated_units = len(output_unit)
        x = {
            "code": torch.LongTensor(new_units).view(1, -1).cuda()
        }
        wav = vocoder(x, True)
        wav_list.append(wav.detach().cpu().numpy())

    if len(wav_list) > 0:
        wav_full = np.concatenate(wav_list)
        return_value = (16000, wav_full)
    else:
        return_value = None

    yield (state, state.messages[-1][-1][0], state.messages[-1][-1][1], return_value)

    finish_tstamp = time.time()
    logger.info(f"{output}")
    logger.info(f"{output_unit}")
212
+
213
+
214
+ title_markdown = ("""
215
+ # 🎧 LLaMA-Omni: Seamless Speech Interaction with Large Language Models
216
+ """)
217
+
218
+ block_css = """
219
+
220
+ #buttons button {
221
+ min-width: min(120px,100%);
222
+ }
223
+
224
+ """
225
+
226
def build_demo(embed_mode, vocoder, cur_dir=None, concurrency_count=10):
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Note: the `vocoder` parameter shadows the module-level global of the same
    name; the handlers themselves use the global.
    """
    with gr.Blocks(title="LLaMA-Omni Speech Chatbot", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        # Model picker row.
        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(
                choices=models,
                value=models[0] if len(models) > 0 else "",
                interactive=True,
                show_label=False,
                container=False)

        # Speech input plus sampling parameters.
        with gr.Row():
            audio_input_box = gr.Audio(sources=["upload", "microphone"], label="Speech Input")
            with gr.Accordion("Parameters", open=True) as parameter_row:
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature",)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max Output Tokens",)
                chunk_size = gr.Slider(minimum=10, maximum=500, value=40, step=10, interactive=True, label="Chunk Size",)

        if cur_dir is None:
            cur_dir = os.path.dirname(os.path.abspath(__file__))
        gr.Examples(examples=[
            [f"{cur_dir}/examples/vicuna_1.wav"],
            [f"{cur_dir}/examples/vicuna_2.wav"],
            [f"{cur_dir}/examples/vicuna_3.wav"],
            [f"{cur_dir}/examples/vicuna_4.wav"],
            [f"{cur_dir}/examples/vicuna_5.wav"],
            [f"{cur_dir}/examples/helpful_base_1.wav"],
            [f"{cur_dir}/examples/helpful_base_2.wav"],
            [f"{cur_dir}/examples/helpful_base_3.wav"],
            [f"{cur_dir}/examples/helpful_base_4.wav"],
            [f"{cur_dir}/examples/helpful_base_5.wav"],
        ], inputs=[audio_input_box])

        with gr.Row():
            submit_btn = gr.Button(value="Send", variant="primary")
            clear_btn = gr.Button(value="Clear")

        text_output_box = gr.Textbox(label="Text Output", type="text")
        unit_output_box = gr.Textbox(label="Unit Output", type="text")
        audio_output_box = gr.Audio(label="Speech Output")

        url_params = gr.JSON(visible=False)

        # Send: record the speech turn, then stream the model's reply.
        submit_btn.click(
            add_speech,
            [state, audio_input_box],
            [state]
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens, chunk_size],
            [state, text_output_box, unit_output_box, audio_output_box],
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            clear_history,
            None,
            [state, audio_input_box, text_output_box, unit_output_box, audio_output_box],
            queue=False
        )

        # Populate the model dropdown either once (from URL params) or on
        # every page load, depending on --model-list-mode.
        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                js=get_window_url_params
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
310
+
311
+
312
def build_vocoder(args):
    """Load the CodeHiFiGAN vocoder onto the GPU; no-op when --vocoder is unset."""
    global vocoder
    if args.vocoder is None:
        return None
    with open(args.vocoder_cfg) as f:
        cfg = json.load(f)
    vocoder = CodeHiFiGANVocoder(args.vocoder, cfg).cuda()
319
+
320
+
321
if __name__ == "__main__":
    # CLI entry point: parse flags, fetch models, load vocoder, launch UI.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=16)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    parser.add_argument("--vocoder", type=str)
    parser.add_argument("--vocoder-cfg", type=str)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()
    build_vocoder(args)

    logger.info(args)
    demo = build_demo(args.embed, vocoder, concurrency_count=args.concurrency_count)
    demo.queue(api_open=False).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )
omni_speech/serve/model_worker.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import threading
9
+ import uuid
10
+
11
+ from fastapi import FastAPI, Request, BackgroundTasks
12
+ from fastapi.responses import StreamingResponse
13
+ import requests
14
+ import torch
15
+ import uvicorn
16
+ import whisper
17
+ import numpy as np
18
+ from functools import partial
19
+
20
+ from transformers import PreTrainedTokenizer
21
+
22
+ from omni_speech.constants import WORKER_HEART_BEAT_INTERVAL
23
+ from omni_speech.utils import (build_logger, server_error_msg,
24
+ pretty_print_semaphore)
25
+ from omni_speech.model.builder import load_pretrained_model
26
+ from omni_speech.constants import SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
27
+ from omni_speech.datasets.preprocess import tokenizer_speech_token
28
+ from transformers import TextIteratorStreamer
29
+ from threading import Thread
30
+
31
+
32
+ GB = 1 << 30
33
+
34
+ worker_id = str(uuid.uuid4())[:6]
35
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
36
+ global_counter = 0
37
+
38
+ model_semaphore = None
39
+
40
+
41
def heart_beat_worker(controller):
    """Daemon loop: report a heartbeat every WORKER_HEART_BEAT_INTERVAL seconds."""
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
46
+
47
+
48
def load_speech(audio, input_type, mel_size, speech_normalize):
    """Convert a JSON-decoded float list into the model's speech input.

    "raw": float32 waveform tensor, optionally layer-normalized over its
    full shape; "mel": Whisper log-mel features transposed to (frames, mel).
    An unknown input_type returns the raw numpy array unchanged.
    """
    speech = np.array(audio, dtype=np.float32)
    if input_type == "mel":
        speech = whisper.pad_or_trim(speech)
        speech = whisper.log_mel_spectrogram(speech, n_mels=mel_size).permute(1, 0)
    elif input_type == "raw":
        speech = torch.from_numpy(speech)
        if speech_normalize:
            speech = torch.nn.functional.layer_norm(speech, speech.shape)
    return speech
58
+
59
+
60
def build_unit_tokenizer(vocab_size):
    """Build a BertTokenizer whose vocabulary is the unit ids 0..vocab_size.

    Writes a temporary vocab file, constructs the tokenizer, then deletes it.
    """
    import os
    from transformers import BertTokenizer
    vocab_path = "unit_vocab.txt"
    with open(vocab_path, "w") as f:
        f.write("\n".join(str(i) for i in range(vocab_size + 1)) + "\n")
    tokenizer = BertTokenizer(vocab_file=vocab_path)
    os.remove(vocab_path)
    return tokenizer
69
+
70
+
71
+ class ModelWorker:
72
+ def __init__(self, controller_addr, worker_addr,
73
+ worker_id, no_register,
74
+ model_path, model_base, model_name,
75
+ load_8bit, load_4bit, device, input_type, mel_size, s2s, is_lora, use_flash_attn=False):
76
+ self.controller_addr = controller_addr
77
+ self.worker_addr = worker_addr
78
+ self.worker_id = worker_id
79
+ self.device = device
80
+ self.model_name = model_name
81
+ self.input_type = input_type
82
+ self.mel_size = mel_size
83
+ self.tokenizer, self.model, self.context_len = load_pretrained_model(
84
+ model_path, model_base, is_lora=is_lora, s2s=s2s, load_8bit=load_8bit, load_4bit=load_4bit, device=self.device, use_flash_attn=use_flash_attn)
85
+ self.unit_tokenizer = build_unit_tokenizer(self.model.config.unit_vocab_size)
86
+
87
+ if not no_register:
88
+ self.register_to_controller()
89
+ self.heart_beat_thread = threading.Thread(
90
+ target=heart_beat_worker, args=(self,), daemon=True)
91
+ self.heart_beat_thread.start()
92
+
93
+ def register_to_controller(self):
94
+ logger.info("Register to controller")
95
+
96
+ url = self.controller_addr + "/register_worker"
97
+ data = {
98
+ "worker_name": self.worker_addr,
99
+ "check_heart_beat": True,
100
+ "worker_status": self.get_status()
101
+ }
102
+ r = requests.post(url, json=data)
103
+ assert r.status_code == 200
104
+
105
+ def send_heart_beat(self):
106
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
107
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
108
+ f"global_counter: {global_counter}")
109
+
110
+ url = self.controller_addr + "/receive_heart_beat"
111
+
112
+ while True:
113
+ try:
114
+ ret = requests.post(url, json={
115
+ "worker_name": self.worker_addr,
116
+ "queue_length": self.get_queue_length()}, timeout=5)
117
+ exist = ret.json()["exist"]
118
+ break
119
+ except requests.exceptions.RequestException as e:
120
+ logger.error(f"heart beat error: {e}")
121
+ time.sleep(5)
122
+
123
+ if not exist:
124
+ self.register_to_controller()
125
+
126
+ def get_queue_length(self):
127
+ if model_semaphore is None:
128
+ return 0
129
+ else:
130
+ return args.limit_model_concurrency - model_semaphore._value + (len(
131
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
132
+
133
+ def get_status(self):
134
+ return {
135
+ "model_names": [self.model_name],
136
+ "speed": 1,
137
+ "queue_length": self.get_queue_length(),
138
+ }
139
+
140
@torch.inference_mode()
def generate_stream(self, params):
    """Stream generation results for one request.

    params keys: "prompt" (str, required), and optionally "audio",
    "temperature", "top_p", "max_new_tokens", "stop".
    Yields null-terminated JSON chunks: {"text", "unit", "error_code"}.
    """
    tokenizer, model = self.tokenizer, self.model

    prompt = params["prompt"]
    ori_prompt = prompt
    audio = params.get("audio", None)
    if audio is not None and len(audio) > 0:
        # Featurize the audio and move it onto the model device.
        speech = load_speech(audio, self.input_type, self.mel_size, self.model.config.speech_normalize)
        speech_length = torch.LongTensor([speech.shape[0]]).unsqueeze(0).to(self.device)
        speech_tensor = speech.unsqueeze(0).to(self.device, dtype=torch.float16)
        speech_args = {"speech": speech_tensor, "speech_lengths": speech_length}
    else:
        speech_args = {}

    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
    max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
    stop_str = params.get("stop", None)
    # Near-zero temperature is treated as greedy decoding.
    do_sample = True if temperature > 0.001 else False

    input_ids = tokenizer_speech_token(prompt, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
    streamer_unit = TextIteratorStreamer(self.unit_tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15)

    # max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

    if max_new_tokens < 1:
        yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
        return

    # Run generation in a background thread; the streamers feed this generator.
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        streamer_unit=streamer_unit,
        streaming_unit_gen=True,
        use_cache=True,
        **speech_args
    ))
    thread.start()

    generated_text = ori_prompt
    for new_text in streamer:
        generated_text += new_text
        generated_unit = " ".join(map(str, streamer_unit.token_cache))
        # BUG FIX: the original called endswith(stop_str) unconditionally, which
        # raises TypeError whenever the client omits "stop" (stop_str is None).
        if stop_str and generated_text.endswith(stop_str):
            generated_text = generated_text[:-len(stop_str)]
        yield json.dumps({"text": generated_text, "unit": generated_unit, "error_code": 0}).encode() + b"\0"
194
+
195
def generate_stream_gate(self, params):
    """Wrap generate_stream so any failure surfaces as a JSON error chunk."""

    def _error_chunk():
        return json.dumps({"text": server_error_msg, "error_code": 1}).encode() + b"\0"

    try:
        yield from self.generate_stream(params)
    except ValueError as e:
        print("Caught ValueError:", e)
        yield _error_chunk()
    except torch.cuda.CudaError as e:
        print("Caught torch.cuda.CudaError:", e)
        yield _error_chunk()
    except Exception as e:
        print("Caught Unknown Error", e)
        yield _error_chunk()
220
+
221
+
222
+ app = FastAPI()
223
+
224
+
225
def release_model_semaphore(fn=None):
    """Free one concurrency slot; then run the optional callback (heartbeat)."""
    model_semaphore.release()
    if fn is None:
        return
    fn()
229
+
230
+
231
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """HTTP entry point: throttle with the global semaphore, then stream."""
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    # Lazily create the semaphore so it binds to the running event loop.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()

    background_tasks = BackgroundTasks()
    # Release the slot (and heartbeat again) after the response finishes streaming.
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(worker.generate_stream_gate(params), background=background_tasks)
245
+
246
+
247
@app.post("/worker_get_status")
async def get_status(request: Request):
    """Expose the worker's model names / speed / queue length to the controller."""
    return worker.get_status()
250
+
251
+
252
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Network endpoints: where this worker serves, and the controller it registers with.
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    # Model loading options (path/base/name, device, quantization, attention impl).
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--use-flash-attn", action="store_true")
    # Speech front-end: mel features of the given size, or raw waveform input.
    parser.add_argument("--input-type", type=str, default="mel")
    parser.add_argument("--mel-size", type=int, default=128)
    # s2s: speech-to-speech mode; is-lora: load LoRA adapter weights.
    parser.add_argument("--s2s", action="store_true", default=False)
    parser.add_argument("--is-lora", action="store_true", default=False)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Build the worker (loads the model) and serve it with uvicorn.
    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device,
                         args.input_type,
                         args.mel_size,
                         args.s2s,
                         args.is_lora,
                         use_flash_attn=args.use_flash_attn)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
omni_speech/train/__pycache__/omni_trainer.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
omni_speech/train/__pycache__/omni_trainer.cpython-312.pyc ADDED
Binary file (13.2 kB). View file
 
omni_speech/train/__pycache__/run_train.cpython-310.pyc ADDED
Binary file (11.5 kB). View file
 
omni_speech/train/__pycache__/run_train.cpython-312.pyc ADDED
Binary file (22.3 kB). View file
 
omni_speech/train/__pycache__/run_train.cpython-38.pyc ADDED
Binary file (12.3 kB). View file
 
omni_speech/train/__pycache__/train.cpython-312.pyc ADDED
Binary file (18.9 kB). View file
 
omni_speech/train/__pycache__/train_mem.cpython-312.pyc ADDED
Binary file (348 Bytes). View file
 
omni_speech/train/__pycache__/train_multiturn.cpython-312.pyc ADDED
Binary file (25.4 kB). View file
 
omni_speech/train/__pycache__/train_raw.cpython-312.pyc ADDED
Binary file (19.9 kB). View file
 
omni_speech/train/__pycache__/train_test.cpython-312.pyc ADDED
Binary file (17.8 kB). View file
 
omni_speech/train/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (7.29 kB). View file
 
omni_speech/train/__pycache__/trainer.cpython-312.pyc ADDED
Binary file (13.2 kB). View file
 
omni_speech/train/export.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import copy
19
+ from dataclasses import dataclass, field
20
+ import json
21
+ import logging
22
+ import pathlib
23
+ from typing import Dict, Optional, Sequence, List
24
+
25
+ import torch
26
+
27
+ import transformers
28
+ import tokenizers
29
+
30
+ from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
31
+ from torch.utils.data import Dataset
32
+ from omni_speech.train.omni_trainer import OmniTrainer
33
+ from audiomentations import AddBackgroundNoise, PolarityInversion
34
+
35
+ from omni_speech import conversation as conversation_lib
36
+ from omni_speech.model import *
37
+ from omni_speech.utils import *
38
+ from omni_speech.datasets.preprocess import *
39
+ import whisper
40
+ import time
41
+
42
@dataclass
class ModelArguments:
    """CLI options describing the model architecture and which parts to tune."""

    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="llama_3")
    freeze_backbone: bool = field(default=False)
    tune_speech_projector: bool = field(default=False)
    tune_speech_encoder: bool = field(default=False)
    tune_speech_generator_only: bool = field(default=False)
    speech_encoder_type: Optional[str] = field(default=None)
    speech_encoder: Optional[str] = field(default=None)
    pretrain_speech_projector: Optional[str] = field(default=None)
    speech_projector_type: Optional[str] = field(default='linear')
    speech_generator_type: Optional[str] = field(default='ctc')
    # CTC speech-generator hyperparameters.
    ctc_decoder_config: str = field(default="(2,4096,32,11008)")
    ctc_upsample_factor: int = field(default=25)
    ctc_loss_weight: float = field(default=1.0)
    unit_vocab_size: int = field(default=1000)
    # Speech-encoder output geometry.
    speech_encoder_ds_rate: int = field(default=5)
    speech_encoder_hidden_size: int = field(default=1280)
61
+
62
+
63
@dataclass
class DataArguments:
    """CLI options for dataset paths and audio preprocessing."""

    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    dev_path: str = field(default=None,
                          metadata={"help": "Path to the dev data."})
    is_multimodal: bool = False
    # "mel" -> log-mel spectrogram features; "raw" -> waveform tensor.
    input_type: str = field(default="mel")
    speech_normalize: bool = False
    mel_size: int = 128
    has_tgt_units: bool = False
    augment_prob: float = field(
        default=0.0,
        metadata={"help": "The probability of applying augmentation transform."}
    )
    augment_path: str = field(default=None,
                              metadata={"help": "Path to the augment data."})
80
+
81
+
82
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Extends HF TrainingArguments with OmniSpeech-specific knobs.
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    freeze_speech_projector: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # bitsandbytes quantization options (used when bits is 4 or 8).
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    # LoRA configuration (applied only when lora_enable is True).
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    # Optional separate learning rate for the speech projector.
    speech_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
114
+
115
+
116
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Samples come from a JSON list; audio is loaded and featurized lazily on
    access, with optional background-noise augmentation. Failed reads are
    retried, then the next sample is tried as a fallback.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        if self.data_args.augment_prob != 0.0:
            # Background-noise augmentation, applied with probability augment_prob.
            with open(self.data_args.augment_path, "r") as f:
                augment_path_list = f.read().splitlines()
            self.transform = AddBackgroundNoise(
                sounds_path=augment_path_list,
                min_snr_db=5.0,
                max_snr_db=30.0,
                noise_transform=PolarityInversion(),
                p=self.data_args.augment_prob
            )

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # TODO: define number of retries somewhere else
        num_base_retries = 3

        # Try the requested sample first (transient cloud-disk errors).
        for attempt_idx in range(num_base_retries):
            try:
                return self._get_item(i)
            except Exception as e:
                # sleep 1s in case it is a cloud disk issue
                print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
                time.sleep(1)

        # Fall back to the next sample, in case this file is corrupted.
        next_index = min(i + 1, len(self.list_data_dict) - 1)
        for attempt_idx in range(num_base_retries):
            try:
                return self._get_item(next_index)
            except Exception as e:
                print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)

        # Final attempt on the original index; let the exception propagate.
        return self._get_item(i)

    def process_speech(self, speech_file):
        """Load one audio file and return (features, length) per data_args.input_type."""
        speech = whisper.load_audio(speech_file)
        if self.data_args.augment_prob != 0.0:
            speech = self.transform(speech, sample_rate=16000)
        if self.data_args.input_type == "raw":
            speech = torch.from_numpy(speech)
            # BUG FIX: was `self.model_config.data_args.speech_normalize` -- this
            # class has no `model_config` attribute, so the "raw" path always
            # raised AttributeError. The flag lives on self.data_args.
            if self.data_args.speech_normalize:
                speech = torch.nn.functional.layer_norm(speech, speech.shape)
        elif self.data_args.input_type == "mel":
            speech = whisper.pad_or_trim(speech)
            speech = whisper.log_mel_spectrogram(speech, n_mels=self.data_args.mel_size).permute(1, 0)
        speech_lengths = torch.LongTensor([speech.shape[0]])
        return speech, speech_lengths

    def _get_item(self, i) -> Dict[str, torch.Tensor]:
        """Build one training example: tokenized conversation plus optional speech/units."""
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        for item in sources:
            # Inline any tool definitions as a leading pseudo-turn.
            if 'tools' in item:
                tools_dict = {
                    "from": "tools",
                    "value": item["tools"]
                }
                item["conversations"].insert(0, tools_dict)

        if self.data_args.has_tgt_units:
            tgt_units = [e["tgt_units"] for e in sources]
            tgt_units = torch.tensor(tgt_units, dtype=torch.long)
        else:
            tgt_units = None

        if 'speech' in sources[0]:
            speech_file = self.list_data_dict[i]['speech']
            # A sample may carry one audio file or a list of them.
            if type(speech_file) is list:
                speech = [self.process_speech(f) for f in speech_file]
            else:
                speech = [self.process_speech(speech_file)]

            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_speech=('speech' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # speech exist in the data
        if 'speech' in self.list_data_dict[i]:
            data_dict['speech'] = speech

        if tgt_units is not None:
            data_dict['tgt_units'] = tgt_units[0]

        data_dict["id"] = self.list_data_dict[i].get("id", i)

        return data_dict
239
+
240
+
241
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def pad_sequence(self, input_ids, batch_first, padding_value):
        # For left padding: reverse each sequence, right-pad, then reverse back.
        left_pad = self.tokenizer.padding_side == "left"
        if left_pad:
            input_ids = [torch.flip(seq, [0]) for seq in input_ids]
        padded = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=batch_first, padding_value=padding_value)
        if left_pad:
            padded = torch.flip(padded, [1])
        return padded

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        max_len = self.tokenizer.model_max_length
        input_ids = [inst["input_ids"][:max_len] for inst in instances]
        labels = [inst["labels"][:max_len] for inst in instances]

        if self.tokenizer.pad_token_id is None:
            # self.tokenizer.pad_token_id = self.tokenizer.eos_token_id  # FIXME: this could only be triggered for llama3 model.
            self.tokenizer.pad_token_id = 0  # This gets the best result. Don't know why.

        input_ids = self.pad_sequence(input_ids, batch_first=True,
                                      padding_value=self.tokenizer.pad_token_id)
        labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        if labels.dtype == torch.int32:
            labels = labels.long()
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'speech' in instances[0]:
            # Each instance carries a list of (features, length) pairs; flatten
            # across the batch so the model sees one flat list of utterances.
            per_instance = [inst['speech'] for inst in instances]
            batch["speech"] = [pair[0] for sample in per_instance for pair in sample]
            batch['speech_lengths'] = [pair[1] for sample in per_instance for pair in sample]

        if 'tgt_units' in instances[0]:
            tgt_units = [inst['tgt_units'] for inst in instances]
            batch['tgt_units'] = self.pad_sequence(
                tgt_units, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        return batch
289
+
290
+
291
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                          data_path=data_args.data_path,
                                          data_args=data_args)
    # Dev set is optional; Trainer accepts eval_dataset=None.
    dev_dataset = None
    if data_args.dev_path is not None:
        dev_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                            data_path=data_args.dev_path,
                                            data_args=data_args)
    return dict(
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    )
307
+
308
+
309
def train(attn_implementation="flash_attention_2"):
    """Assemble the OmniSpeech model, tokenizer and datasets, then save.

    NOTE(review): trainer.train() is commented out below, so as written this
    function only builds the model/trainer and saves weights -- presumably the
    intended "export" behavior of this file; confirm before re-enabling training.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    # Optional 4/8-bit quantized loading via bitsandbytes.
    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["speech_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
            )
        ))

    # Pick the model class: 2S (speech-to-speech, with target units) vs text-only
    # output, crossed with the llama/qwen backbone family.
    if data_args.has_tgt_units:
        if model_args.version == "llama_3":
            model = OmniSpeech2SLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif model_args.version == "qwen":
            model = OmniSpeech2SQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            raise ValueError("--currently only support llama or qwen model!")
    else:
        if model_args.version == "llama_3":
            model = OmniSpeechLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif model_args.version == "qwen":
            model = OmniSpeechQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            raise ValueError("--currently only support llama or qwen model!")
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Fallback for older transformers: force embedding outputs to require grad.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        model = get_peft_model(model, lora_config)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    model.resize_token_embeddings(len(tokenizer))
    model.config.max_length = training_args.model_max_length

    # Select the conversation template; fall back to llama_3.
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["llama_3"]

    if model_args.speech_encoder is not None:
        model.get_model().initialize_speech_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        data_args.is_multimodal = True

        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        model.config.tune_speech_projector = training_args.tune_speech_projector = model_args.tune_speech_projector

        model.config.speech_normalize = data_args.speech_normalize

        # Speech encoder is always frozen here.
        for p in model.get_speech_encoder().parameters():
            p.requires_grad = False

        # Projector-only tuning: freeze everything, then unfreeze the projector.
        if model_args.tune_speech_projector:
            model.requires_grad_(False)
            for p in model.get_speech_projector().parameters():
                p.requires_grad = True

        model.config.freeze_speech_projector = training_args.freeze_speech_projector
        if training_args.freeze_speech_projector:
            for p in model.get_speech_projector().parameters():
                p.requires_grad = False

        if training_args.bits in [4, 8]:
            model.get_model().speech_projector.to(dtype=compute_dtype, device=training_args.device)

        model.config.speech_projector_lr = training_args.speech_projector_lr

    if data_args.has_tgt_units:
        model.initialize_speech_generator(model_args=model_args)

    # Dtype fix-ups for quantized training (LoRA layers, norms, embeddings).
    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)

    print("Training Layers:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            # NOTE(review): this prints param.grad, which is None before any
            # backward pass -- likely meant param.shape or just the name.
            print(name, param.grad)

    trainer = OmniTrainer(model=model,
                          tokenizer=tokenizer,
                          args=training_args,
                          **data_module)

    # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    #     trainer.train(resume_from_checkpoint=True)
    # else:
    #     trainer.train()
    # trainer.save_state()

    model.config.use_cache = True

    # Save: LoRA adapters + non-LoRA trainables separately, or full model.
    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)
508
+
509
+
510
# Script entry point.
if __name__ == "__main__":
    train()
512
+
omni_speech/train/omni_trainer.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from torch.utils.data import Sampler
6
+
7
+ from transformers import Trainer
8
+ from transformers.trainer import (
9
+ is_sagemaker_mp_enabled,
10
+ get_parameter_names,
11
+ has_length,
12
+ ALL_LAYERNORM_LAYERS,
13
+ logger,
14
+ )
15
+ from typing import List, Optional
16
+ from omni_speech.utils import *
17
+
18
+
19
def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `chunks` chunks of roughly equal lengths.
    """
    # When the indices don't divide evenly, fall back to simple striding.
    if len(indices) % num_chunks != 0:
        return [indices[i::num_chunks] for i in range(num_chunks)]

    per_chunk = len(indices) // num_chunks
    chunks = [[] for _ in range(num_chunks)]
    totals = [0] * num_chunks
    for idx in indices:
        # Greedy balancing: always feed the currently-lightest chunk.
        target = totals.index(min(totals))
        chunks[target].append(idx)
        totals[target] += lengths[idx]
        # A full chunk is removed from consideration.
        if len(chunks[target]) == per_chunk:
            totals[target] = float("inf")

    return chunks
39
+
40
+
41
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    # Group sample indices by length while keeping the two modalities in
    # separate megabatches. Sign convention (from the mm_/lang_ split below):
    # positive lengths appear to mark multimodal samples, negative mark
    # language-only -- confirm against the dataset's modality_lengths.
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    # Length-group each modality independently, then cut into megabatches.
    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    # The two (possibly partial) trailing megabatches are merged into one
    # leftover batch appended at the end.
    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    # Shuffle megabatch order with the caller's generator for reproducibility.
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]
67
+
68
+
69
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    """Shuffle indices, then sort within each megabatch by descending length and
    balance the megabatch across `world_size` even chunks."""
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    shuffled = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [
        sorted(shuffled[start: start + megabatch_size].tolist(),
               key=lambda idx: lengths[idx], reverse=True)
        for start in range(0, len(lengths), megabatch_size)
    ]
    balanced = [split_to_even_chunks(mb, lengths, world_size) for mb in megabatches]
    return [i for mb in balanced for chunk in mb for i in chunk]
78
+
79
+
80
class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        # Pick the grouping strategy, then yield the computed index order.
        grouping = (
            get_modality_length_grouped_indices
            if self.group_by_modality
            else get_length_grouped_indices
        )
        return iter(grouping(self.lengths, self.batch_size, self.world_size,
                             generator=self.generator))
112
+
113
+
114
class OmniTrainer(Trainer):
    """Trainer for OmniSpeech models.

    Extends the HF ``Trainer`` with:
      * modality/length-grouped batch sampling (``args.group_by_modality_length``),
      * an optional dedicated learning rate for the speech projector
        (``args.speech_projector_lr``), and
      * projector-only checkpointing when ``args.tune_speech_projector`` is set.
    """

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Return a modality-grouped sampler when requested, else the default one."""
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                # Effective per-step sample count spans all ranks and accumulation steps.
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()

    def create_optimizer(self):
        """
        Setup the optimizer.

        Mirrors ``Trainer.create_optimizer`` but, when ``args.speech_projector_lr``
        is set, places the speech-projector parameters in their own parameter
        groups that use that learning rate instead of the global one.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            # Sets give O(1) membership tests below; the name lists can be large.
            decay_parameters = set(get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS))
            decay_parameters = {name for name in decay_parameters if "bias" not in name}
            if self.args.speech_projector_lr is not None:
                projector_parameters = {name for name, _ in opt_model.named_parameters() if "speech_projector" in name}
                # Four groups: (decay / no-decay) x (base LR / projector LR).
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.speech_projector_lr,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.speech_projector_lr,
                    },
                ]
            else:
                # Standard two-group split: weight-decayed vs. not.
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # De-duplicate shared parameters by data pointer before counting.
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        # Keep embedding weights in 32-bit optimizer state.
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def _save_checkpoint(self, model, trial, metrics=None):
        """Checkpoint only the speech projector in projector-only training; else defer to Trainer."""
        if getattr(self.args, 'tune_speech_projector', False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save Adapter
            keys_to_match = ['speech_projector']

            weight_to_save = get_speech_projector_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            # Write from a single process (rank 0, or -1 when not distributed).
            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, 'speech_projector.bin'))
        else:
            super(OmniTrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Skip full-model saves in projector-only mode (handled by _save_checkpoint)."""
        if getattr(self.args, 'tune_speech_projector', False):
            pass
        else:
            super(OmniTrainer, self)._save(output_dir, state_dict)
omni_speech/train/train.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import copy
19
+ from dataclasses import dataclass, field
20
+ import json
21
+ import logging
22
+ import pathlib
23
+ from typing import Dict, Optional, Sequence, List
24
+
25
+ import torch
26
+
27
+ import transformers
28
+ import tokenizers
29
+
30
+ from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
31
+ from torch.utils.data import Dataset
32
+ from omni_speech.train.omni_trainer import OmniTrainer
33
+
34
+ from omni_speech import conversation as conversation_lib
35
+ from omni_speech.model import *
36
+ from omni_speech.utils import *
37
+ from omni_speech.datasets.preprocess import *
38
+ import whisper
39
+
40
@dataclass
class ModelArguments:
    """Model checkpoint selection plus speech encoder/projector/generator configuration."""

    model_name_or_path: Optional[str] = "facebook/opt-125m"
    version: Optional[str] = "llama_3"          # model family / conversation template key
    freeze_backbone: bool = False
    tune_speech_projector: bool = False
    tune_speech_encoder: bool = False
    tune_speech_generator_only: bool = False
    speech_encoder_type: Optional[str] = None
    speech_encoder: Optional[str] = None        # path/name of the speech encoder weights
    pretrain_speech_projector: Optional[str] = None
    speech_projector_type: Optional[str] = 'linear'
    speech_generator_type: Optional[str] = 'ctc'
    # CTC speech-generator hyperparameters.
    ctc_decoder_config: str = "(2,4096,32,11008)"
    ctc_upsample_factor: int = 1
    ctc_loss_weight: float = 1.0
    unit_vocab_size: int = 1000
    # Encoder output downsampling rate and hidden width.
    speech_encoder_ds_rate: int = 5
    speech_encoder_hidden_size: int = 1280
59
+
60
+
61
@dataclass
class DataArguments:
    """Dataset paths and audio featurization options."""

    data_path: str = field(
        default=None,
        metadata={"help": "Path to the training data."})
    dev_path: str = field(
        default=None,
        metadata={"help": "Path to the dev data."})
    is_multimodal: bool = False       # set True by train() once speech modules are initialized
    input_type: str = "mel"           # "mel" (log-mel features) or "raw" (waveform)
    speech_normalize: bool = False    # layer-normalize raw waveforms
    mel_size: int = 128               # number of mel bins
    has_tgt_units: bool = False       # target discrete speech units present (speech-to-speech)
72
+
73
+
74
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Extends HF TrainingArguments with quantization, LoRA and speech-projector options.
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    freeze_speech_projector: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # bitsandbytes k-bit quantization settings (used when bits is 4 or 8).
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    # LoRA settings (only consulted when lora_enable is True).
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    # Optional dedicated LR for the speech projector (see OmniTrainer.create_optimizer).
    speech_projector_lr: Optional[float] = None
    # Use the modality/length-grouped batch sampler.
    group_by_modality_length: bool = field(default=False)
106
+
107
+
108
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Samples are stored as a JSON list; audio is loaded and featurized lazily
    in ``__getitem__`` so the corpus never has to fit in memory at once.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'speech' in sources[0]:
            speech_file = self.list_data_dict[i]['speech']
            speech = whisper.load_audio(speech_file)

            if self.data_args.input_type == "raw":
                speech = torch.from_numpy(speech)
                # BUG FIX: was `self.model_config.data_args.speech_normalize`;
                # this class has no `model_config` attribute, which raised
                # AttributeError for every raw-waveform sample.
                if self.data_args.speech_normalize:
                    speech = torch.nn.functional.layer_norm(speech, speech.shape)
            elif self.data_args.input_type == "mel":
                speech = whisper.pad_or_trim(speech)
                speech = whisper.log_mel_spectrogram(speech, n_mels=self.data_args.mel_size).permute(1, 0)
            # Length in frames (mel) or samples (raw), recorded for the collator.
            speech_lengths = torch.LongTensor([speech.shape[0]])

            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_speech=('speech' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # speech exist in the data
        if 'speech' in self.list_data_dict[i]:
            data_dict['speech'] = speech
            data_dict['speech_lengths'] = speech_lengths
        return data_dict
162
+
163
+
164
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad and truncate text fields, attach speech features when present."""
        pad_id = self.tokenizer.pad_token_id
        max_len = self.tokenizer.model_max_length

        input_ids = torch.nn.utils.rnn.pad_sequence(
            [example["input_ids"] for example in instances],
            batch_first=True,
            padding_value=pad_id)
        labels = torch.nn.utils.rnn.pad_sequence(
            [example["labels"] for example in instances],
            batch_first=True,
            padding_value=IGNORE_INDEX)

        # Truncate to the tokenizer limit; the attention mask is derived from
        # the already-truncated ids.
        input_ids = input_ids[:, :max_len]
        labels = labels[:, :max_len]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(pad_id),
        )

        if 'speech' in instances[0]:
            speech = [example['speech'] for example in instances]
            speech_lengths = [example['speech_lengths'] for example in instances]
            # Stack only when every clip shares one shape; otherwise pass lists through.
            uniform = all(s is not None and s.shape == speech[0].shape for s in speech)
            batch['speech'] = torch.stack(speech) if uniform else speech
            batch['speech_lengths'] = torch.stack(speech_lengths) if uniform else speech_lengths

        return batch
199
+
200
+
201
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Returns a dict with ``train_dataset``, ``eval_dataset`` and
    ``data_collator`` keys, suitable for splatting into a ``Trainer``.

    BUG FIX: ``dev_dataset`` was only assigned inside the
    ``data_args.dev_path is not None`` branch but unconditionally referenced
    in the return, raising ``NameError`` whenever no dev set was configured.
    It now defaults to None (Trainer accepts a missing eval dataset).
    """
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                          data_path=data_args.data_path,
                                          data_args=data_args)
    dev_dataset = None
    if data_args.dev_path is not None:
        dev_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                            data_path=data_args.dev_path,
                                            data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=dev_dataset,
                data_collator=data_collator)
215
+
216
+
217
def train(attn_implementation="flash_attention_2"):
    """Fine-tuning entry point: parse args, build model/tokenizer/data, train, save."""

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Compute dtype for k-bit quantization: fp16 / bf16 / fp32 per CLI flags.
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    # Extra from_pretrained kwargs when loading 4/8-bit via bitsandbytes.
    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["speech_projector"],  # keep the projector un-quantized
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
            )
        ))

    # Model class depends on (a) whether discrete speech target units exist
    # ("2S" = speech-to-speech variants) and (b) the LLM family.
    if data_args.has_tgt_units:
        if model_args.version == "llama_3":
            model = OmniSpeech2SLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif model_args.version == "qwen":
            model = OmniSpeech2SQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            raise ValueError("--currently only support llama or qwen model!")
    else:
        if model_args.version == "llama_3":
            model = OmniSpeechLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif model_args.version == "qwen":
            model = OmniSpeechQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            raise ValueError("--currently only support llama or qwen model!")
    # KV caching is useless during training and conflicts with checkpointing.
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Fallback: force grads through the embedding output so checkpointing works.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        model = get_peft_model(model, lora_config)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    model.resize_token_embeddings(len(tokenizer))
    model.config.max_length = training_args.model_max_length

    # Select the conversation template; fall back to llama_3.
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["llama_3"]

    if model_args.speech_encoder is not None:
        model.get_model().initialize_speech_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        data_args.is_multimodal = True

        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        # Mirror the projector-tuning flag onto both training args and model config.
        model.config.tune_speech_projector = training_args.tune_speech_projector = model_args.tune_speech_projector

        model.config.speech_normalize = data_args.speech_normalize

        # The speech encoder always stays frozen in this script.
        for p in model.get_speech_encoder().parameters():
            p.requires_grad = False

        if model_args.tune_speech_projector:
            # Projector-only training: freeze everything, re-enable the projector.
            model.requires_grad_(False)
            for p in model.get_speech_projector().parameters():
                p.requires_grad = True

        model.config.freeze_speech_projector = training_args.freeze_speech_projector
        if training_args.freeze_speech_projector:
            for p in model.get_speech_projector().parameters():
                p.requires_grad = False

        if training_args.bits in [4, 8]:
            model.get_model().speech_projector.to(dtype=compute_dtype, device=training_args.device)

        model.config.speech_projector_lr = training_args.speech_projector_lr

    if data_args.has_tgt_units:
        model.initialize_speech_generator(model_args=model_args)

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        # Dtype fixups under k-bit: LoRA layers in bf16, norms in fp32,
        # lm_head / embeddings back to bf16 if they drifted to fp32.
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)

    # NOTE(review): `param.grad` is always None before the first backward pass,
    # so this prints `None` for every trainable parameter — presumably only the
    # parameter names are of interest here; confirm intent.
    print("Training Layers:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.grad)

    trainer = OmniTrainer(model=model,
                          tokenizer=tokenizer,
                          args=training_args,
                          **data_module)

    # Resume automatically when a checkpoint directory already exists.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Re-enable KV caching for inference after training.
    model.config.use_cache = True

    if training_args.lora_enable:
        # Save LoRA weights and remaining trainable (non-LoRA) weights separately.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)
416
+
417
+
418
if __name__ == "__main__":
    # Script entry point: run fine-tuning with the default attention backend.
    train()
420
+
omni_speech/train/train_mem.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from omni_speech.train.train_multiturn import train
2
+
3
if __name__ == "__main__":
    # Thin launcher: delegates to the multiturn trainer with FlashAttention 2.
    train(attn_implementation="flash_attention_2")
omni_speech/train/train_minicpmo.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import logging
4
+ import os
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Dict, List, Optional, Union, Literal, Tuple
8
+ from types import MethodType
9
+ from torchvision import transforms
10
+ from copy import deepcopy
11
+
12
+ import torch
13
+ import transformers
14
+ from accelerate.utils import DistributedType
15
+ from deepspeed import zero
16
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
17
+ import pathlib
18
+
19
+ from transformers import AutoModel, AutoTokenizer, AutoProcessor
20
+ from transformers.integrations import deepspeed
21
+
22
+ from omni_speech.datasets.dataset import SupervisedDataset, data_collator
23
+ from omni_speech.model import *
24
+ from trainer import CPMTrainer
25
+ from transformers import Trainer
26
+ import librosa
27
+ from datasets import load_dataset
28
+ import numpy as np
29
+ from PIL import Image
30
+ from functools import partial
31
+ from audiomentations import AddBackgroundNoise, PolarityInversion
32
+
33
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
34
+
35
@dataclass
class ModelArguments:
    """Checkpoint paths for the MiniCPM-o model and its sub-components."""

    model_name_or_path: Optional[str] = "openbmb/MiniCPM-o-2_6"
    tokenizer_path: Optional[str] = None          # defaults to the model's own tokenizer
    audio_encoder_path: Optional[str] = None
    pretrained_llm_path: Optional[str] = None
41
+
42
+
43
@dataclass
class DataArguments:
    """Dataset paths, optional subsampling caps, and audio-augmentation settings."""

    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    # Probability of applying the audiomentations transform to a sample.
    augment_prob: float = field(
        default=0.0,
        metadata={"help": "The probability of applying augmentation transform."}
    )
    # Presumably a background-noise corpus directory for AddBackgroundNoise — confirm.
    augment_path: str = field(default=None,
                              metadata={"help": "Path to the augment data."})
71
+
72
+
73
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Extends HF TrainingArguments with MiniCPM-o module-tuning switches.
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Which sub-modules to train: vision tower, audio stack, language model.
    tune_vision: Optional[bool] = field(default=True)
    tune_speech: Optional[bool] = field(default=True)
    tune_llm: Optional[bool] = field(default=True)
    llm_type: str = field(default="qwen")
    use_lora: Optional[bool] = field(default=False)
    # NOTE(review): presumably the max number of image slices per sample — confirm.
    max_slice_nums: Optional[int] = field(default=9)
    config_path: Optional[str] = field(default=None)
    chunk_input: Optional[bool] = field(default=True)
    # Whether to (re)initialize the vision / speech modules at startup.
    init_vision: Optional[bool] = field(default=False)
    init_speech: Optional[bool] = field(default=True)
93
+
94
+
95
@dataclass
class LoraArguments:
    """LoRA / QLoRA adapter configuration (consumed when use_lora is set)."""

    lora_r: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    # Regex matched against module names: attention q/k/v projections of the LLM.
    lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False                 # quantized-base LoRA (QLoRA)
    lora_modules_to_save: str = ""
    # Optional layer replication / selection knobs forwarded to peft's LoraConfig.
    lora_layer_replication: Optional[List[Tuple[int, int]]] = None
    lora_layers_to_transform: Optional[List[int]] = None
    lora_layers_pattern: Optional[str] = None
108
+
109
# Process rank; assigned from training_args inside train(). While it is
# still None, rank0_print stays silent.
local_rank = None


def rank0_print(*args):
    """Print *args* only on the rank-0 process; a no-op elsewhere."""
    if local_rank != 0:
        return
    print(*args)
113
+
114
def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
    """Persist the model via ``trainer.save_model`` on the main process only.

    ``bias`` is accepted for call-site compatibility but is not used here.
    """
    is_main_process = trainer.args.should_save and trainer.args.local_rank == 0
    if is_main_process:
        trainer.save_model(output_dir)
118
+
119
+ # class CollateFn:
120
+ # def __init__(self, processor, prompt="Please transcribe this audio into text.", system_prompt="You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."):
121
+ # self.prompt = prompt
122
+ # self.system_prompt = system_prompt
123
+ # self.processor = processor
124
+
125
+ # def __call__(self, examples):
126
+ # prompts_lists = []
127
+ # input_images_list = []
128
+ # input_audios_list = []
129
+ # audio_parts_list = []
130
+
131
+ # for msgs in examples:
132
+ # msgs = msgs["conversations"]
133
+ # if isinstance(msgs, str):
134
+ # msgs = json.loads(msgs)
135
+ # copy_msgs = deepcopy(msgs)
136
+
137
+ # assert len(msgs) > 0, "msgs is empty"
138
+
139
+ # system_turn = {'role': 'system', 'content': self.system_prompt}
140
+ # if copy_msgs[0]["role"] != 'system':
141
+ # copy_msgs.insert(0, system_turn)
142
+
143
+ # images = []
144
+ # audios = []
145
+ # audio_parts = []
146
+ # for i, msg in enumerate(copy_msgs):
147
+ # role = msg["role"]
148
+ # content = msg["content"]
149
+ # assert role in ["system", "user", "assistant"]
150
+ # if i == 0:
151
+ # assert role in ["user", "system"], "The role of first msg should be user"
152
+ # content = [content, self.prompt]
153
+ # cur_msgs = []
154
+
155
+ # for c in content:
156
+ # if os.path.exists(c):
157
+ # c, _ = librosa.load(c, sr=16000, mono=True)
158
+
159
+ # if isinstance(c, Image.Image):
160
+ # images.append(c)
161
+ # cur_msgs.append("(<image>./</image>)")
162
+ # elif isinstance(c, np.ndarray): # audio
163
+ # audios.append(c)
164
+ # audio_parts.append(i)
165
+ # cur_msgs.append("(<audio>./</audio>)")
166
+ # elif isinstance(c, str):
167
+ # cur_msgs.append(c)
168
+ # else:
169
+ # msg["content"] = "\n".join(cur_msgs)
170
+
171
+ # prompts_lists.append(
172
+ # self.processor.tokenizer.apply_chat_template(
173
+ # copy_msgs,
174
+ # tokenize=False,
175
+ # add_generation_prompt=False,
176
+ # )
177
+ # )
178
+ # input_images_list.append(images)
179
+ # input_audios_list.append(audios)
180
+ # audio_parts_list.append(audio_parts)
181
+
182
+ # inputs = self.processor(
183
+ # prompts_lists,
184
+ # input_images_list,
185
+ # input_audios_list,
186
+ # audio_parts_list,
187
+ # return_tensors="pt",
188
+ # max_length=32768,
189
+ # return_labels=True,
190
+ # )
191
+
192
+ # return inputs
193
+
194
def collate_fn(examples, processor, chunk_input, max_len, prompt=None, system_prompt="You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.", transform=None):
    """Collate raw conversation examples into model-ready processor inputs.

    Each example is a dict with a ``conversations`` list (optionally a JSON
    string) and an optional ``tools`` JSON string. Audio file paths inside
    message content are loaded at 16 kHz (and optionally augmented via
    ``transform``); images / waveforms are replaced by placeholder tags and
    collected into per-sample lists, then everything is rendered through the
    tokenizer chat template and fed to ``processor``.

    BUG FIX: the previous version only wrapped *user* content in a list, so
    string content of system/assistant/tool turns was iterated character by
    character, mangling those messages. Non-user string content is now
    wrapped as a single part, and the file-existence check is applied to
    strings only.
    """
    prompts_lists = []
    input_images_list = []
    input_audios_list = []
    audio_parts_list = []

    for msgs in examples:
        raw_msgs = deepcopy(msgs)
        msgs = msgs["conversations"]
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        copy_msgs = deepcopy(msgs)

        assert len(msgs) > 0, "msgs is empty"

        # Prepend the default system turn when the sample lacks one.
        system_turn = {'role': 'system', 'content': system_prompt}
        if copy_msgs[0]["role"] != 'system':
            copy_msgs.insert(0, system_turn)

        # Optional function-calling tool schema, JSON-encoded per sample.
        fc = None
        if "tools" in raw_msgs and raw_msgs["tools"] != "":
            fc = json.loads(raw_msgs["tools"])

        images = []
        audios = []
        audio_parts = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["system", "user", "assistant", "tool"]
            if i == 0:
                assert role in ["user", "system"], "The role of first msg should be user or system"

            # Normalize content into a list of parts (see BUG FIX above).
            if role == "user":
                content = [content, prompt] if prompt is not None else [content]
            elif not isinstance(content, list):
                content = [content]

            cur_msgs = []
            for c in content:
                # Strings pointing at an existing file are treated as audio paths.
                if isinstance(c, str) and os.path.exists(c):
                    c, _ = librosa.load(c, sr=16000, mono=True)
                    if transform is not None:
                        c = transform(c, sample_rate=16000)

                if isinstance(c, str):
                    cur_msgs.append(c)
                elif isinstance(c, np.ndarray):  # raw audio waveform
                    audios.append(c)
                    audio_parts.append(i)
                    cur_msgs.append("(<audio>./</audio>)")
                elif isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("(<image>./</image>)")

            msg["content"] = "\n".join(cur_msgs)

            # Tool calls arrive JSON-encoded; decode and normalize to a list.
            if "tool_calls" in msg and msg["tool_calls"] is not None:
                assert isinstance(msg["tool_calls"], str), f"Invalid type: {type(msg['tool_calls'])}"
                msg["tool_calls"] = json.loads(msg["tool_calls"])
                if not isinstance(msg["tool_calls"], list):
                    msg["tool_calls"] = [msg["tool_calls"]]

        qwen_template = processor.tokenizer.apply_chat_template(
            copy_msgs,
            tokenize=False,
            add_generation_prompt=False,
            tools=fc,
        )

        prompts_lists.append(qwen_template)
        input_images_list.append(images)
        input_audios_list.append(audios)
        audio_parts_list.append(audio_parts)

    inputs = processor(
        prompts_lists,
        input_images_list,
        input_audios_list,
        audio_parts_list,
        chunk_input=chunk_input,
        return_tensors="pt",
        return_labels=True,
    )

    return inputs
301
+
302
def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    processor: transformers.ProcessorMixin,
    data_args,
    transform,
    data_collator=None,
    llm_type="qwen",
    slice_config=None,
    patch_size=14,
    query_nums=64,
    batch_vision=False,
    max_length=2048,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Builds a ``SupervisedDataset`` for the train split (and for the eval
    split when ``data_args.eval_data_path`` is set) and returns them with a
    collator bound to ``max_length``.
    """
    rank0_print("Loading data...")

    def _load_split(path):
        # One SupervisedDataset per JSON split, sharing the same configuration.
        with open(path, "r") as f:
            records = json.load(f)
        return SupervisedDataset(
            records,
            transform,
            tokenizer,
            processor,
            slice_config=slice_config,
            llm_type=llm_type,
            patch_size=patch_size,
            query_nums=query_nums,
            batch_vision=batch_vision,
            max_length=max_length,
        )

    train_dataset = _load_split(data_args.data_path)
    eval_dataset = _load_split(data_args.eval_data_path) if data_args.eval_data_path else None

    return dict(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=partial(data_collator, max_length=max_length),
    )
356
+
357
+
358
def build_transform():
    """Return the image preprocessing pipeline: tensor conversion followed
    by Inception-style normalization (mean/std 0.5 per channel)."""
    # Values mirror timm.data.IMAGENET_INCEPTION_MEAN / IMAGENET_INCEPTION_STD.
    inception_mean = (0.5, 0.5, 0.5)
    inception_std = (0.5, 0.5, 0.5)
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=inception_mean, std=inception_std),
    ]
    return transforms.Compose(steps)
369
+
370
def get_parameter_number(model):
    """Count total and trainable parameters of *model*.

    Under DeepSpeed ZeRO-3 a partitioned weight reports ``numel() == 0``;
    in that case the true element count is read from ``param.ds_numel``.

    Returns a dict: ``{'Total': int, 'Trainable': int}``.
    """
    totals = {'Total': 0, 'Trainable': 0}
    for param in model.parameters():
        count = param.numel()
        if count == 0 and hasattr(param, "ds_numel"):
            count = param.ds_numel
        totals['Total'] += count
        if param.requires_grad:
            totals['Trainable'] += count
    return totals
383
+
384
+
385
# Default rank for rank0_print; overwritten from training_args inside train().
local_rank = 0
386
+
387
+
388
def train(attn_implementation="flash_attention_2"):
    """Fine-tune MiniCPM-o end to end.

    Parses model / data / training / LoRA arguments from the command line,
    loads the model, tokenizer and processor, optionally freezes sub-modules
    or applies LoRA, builds the JSON train/eval datasets and an HF
    ``Trainer``, trains (resuming from the latest checkpoint in
    ``output_dir`` when one exists) and saves the final model.
    """
    global local_rank
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )

    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    if getattr(training_args, "deepspeed", None):
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    # Pick the parameter dtype from the mixed-precision flags.
    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    local_rank = training_args.local_rank
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    device_map = None
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            # FIX: the original message said "are not incompatible" (double negative).
            logging.warning(
                "FSDP or ZeRO3 are incompatible with QLoRA."
            )

    # Extra kwargs for MiniCPMO.from_pretrained, read from an optional JSON file.
    minipcmo_config = {}
    if training_args.config_path is not None:
        minipcmo_config = json.load(open(training_args.config_path, "r"))

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
    )

    if model_args.pretrained_llm_path is None:
        print("Finetuning model!!!")
        model = MiniCPMO.from_pretrained(
            model_args.model_name_or_path,
            torch_dtype=compute_dtype,
            device_map=device_map,
            attn_implementation=attn_implementation,
            init_vision=training_args.init_vision,
            init_audio=training_args.init_speech,
            init_tts=False,
            processor_path=model_args.tokenizer_path,
            **minipcmo_config
        )
    else:
        # Replace the LLM backbone with a separately pretrained checkpoint.
        print("Load pretrained LLM from scratch!!!")
        model = MiniCPMO.from_pretrained(
            model_args.model_name_or_path,
            pretrained_llm_path=model_args.pretrained_llm_path,
            init_vision=training_args.init_vision,
            init_audio=training_args.init_speech,
            pretrained_encoder_path=model_args.audio_encoder_path,
            processor_path=model_args.tokenizer_path,
            **minipcmo_config
        )

    model.config.chunk_input = training_args.chunk_input

    # use_cache must be off for training with gradient checkpointing.
    model.llm.config.use_cache = False
    model.config.max_length = training_args.model_max_length

    # Freeze sub-modules that are not being tuned.
    if not training_args.tune_vision and training_args.init_vision:
        model.vpm.requires_grad_(False)
    if not training_args.tune_speech and training_args.init_speech:
        model.apm.requires_grad_(False)
    if not training_args.tune_llm:
        model.llm.requires_grad_(False)

    if training_args.use_lora:
        if training_args.use_lora and training_args.tune_llm:
            raise ValueError("The model cannot simultaneously adjust LLM parameters and apply LoRA.")

        rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
        for name, param in model.llm.named_parameters():
            param.requires_grad = False
        modules_to_save = ['embed_tokens','resampler']
        if training_args.tune_vision:
            modules_to_save.append('vpm')
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            layers_to_transform=lora_args.lora_layers_to_transform,
            modules_to_save=modules_to_save,
        )
        # PEFT expects get_input_embeddings on the wrapped model; delegate to the LLM.
        if not hasattr(model, 'get_input_embeddings'):
            def get_input_embeddings(self):
                return self.llm.get_input_embeddings()
            model.get_input_embeddings = MethodType(get_input_embeddings, model)
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )
        model = get_peft_model(model, lora_config)
        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

    rank0_print(get_parameter_number(model))

    llm_type = training_args.llm_type

    rank0_print(f'llm_type={llm_type}')

    # Propagate slice / vision settings from the training args into the config.
    if hasattr(model.config, "slice_config"):
        model.config.slice_config.max_slice_nums = training_args.max_slice_nums
        slice_config = model.config.slice_config.to_dict()
    else:
        model.config.max_slice_nums = training_args.max_slice_nums
        slice_config = model.config.to_dict()

    if hasattr(model.config, "batch_vision_input"):
        batch_vision = model.config.batch_vision_input
    else:
        batch_vision = False

    transform_func = build_transform()

    if model_args.tokenizer_path is not None:
        processor = AutoProcessor.from_pretrained(model_args.tokenizer_path, trust_remote_code=True)
    else:
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    processor.tokenizer = tokenizer

    raw_datasets = load_dataset(
        "json",
        data_files={
            "train": data_args.data_path,
            "validation": data_args.eval_data_path,
        },
        cache_dir=training_args.cache_dir,
    )

    train_ds = raw_datasets["train"]
    if data_args.max_train_samples is not None:
        train_ds = train_ds.select(range(data_args.max_train_samples))
    eval_ds = raw_datasets["validation"]
    if data_args.max_eval_samples is not None:
        eval_ds = eval_ds.select(range(data_args.max_eval_samples))

    # When only the speech encoder is tuned, default to an ASR instruction.
    init_prompt = None
    if not training_args.tune_llm and training_args.tune_speech:  # asr finetuning
        init_prompt = "Please transcribe this audio into text."

    # Optional background-noise augmentation applied to loaded audio.
    transform = None
    if data_args.augment_prob != 0.0 and data_args.augment_path is not None:
        with open(data_args.augment_path, "r") as f:
            augment_path_list = f.read().splitlines()
        transform = AddBackgroundNoise(
            sounds_path=augment_path_list,
            min_snr_db=5.0,
            max_snr_db=30.0,
            noise_transform=PolarityInversion(),
            p=data_args.augment_prob
        )

    custom_collate_fn = partial(collate_fn, processor=processor, chunk_input=training_args.chunk_input, max_len=training_args.model_max_length, prompt=init_prompt, transform=transform)

    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}

    # FIX: list trainable parameter names; gradients are all None before training,
    # so the old `print(name, param.grad)` only produced noise.
    print("Training Layers:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=custom_collate_fn
    )

    # Resume automatically when a checkpoint already exists in output_dir.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_state()

    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        bias=lora_args.lora_bias)
657
+
658
+
659
+ if __name__ == "__main__":
660
+ train()
omni_speech/train/train_minicpmo_test.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import logging
4
+ import os
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Dict, List, Optional, Union, Literal, Tuple
8
+ from types import MethodType
9
+ from torchvision import transforms
10
+ from copy import deepcopy
11
+
12
+ import torch
13
+ import transformers
14
+ from accelerate.utils import DistributedType
15
+ from deepspeed import zero
16
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
17
+ import pathlib
18
+
19
+ from transformers import AutoModel, AutoTokenizer, AutoProcessor
20
+ from transformers.integrations import deepspeed
21
+
22
+ from omni_speech.datasets.dataset import SupervisedDataset, data_collator
23
+ from omni_speech.model import *
24
+ from trainer import CPMTrainer
25
+ from transformers import Trainer
26
+ import librosa
27
+ from datasets import load_dataset
28
+ import numpy as np
29
+ from PIL import Image
30
+ from functools import partial
31
+ from audiomentations import AddBackgroundNoise, PolarityInversion
32
+
33
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
34
+
35
@dataclass
class ModelArguments:
    """Paths identifying the base model and optional pretrained components."""

    # HF repo id or local path of the base MiniCPM-o checkpoint.
    model_name_or_path: Optional[str] = "openbmb/MiniCPM-o-2_6"
    # Optional separate tokenizer/processor path; falls back to the model path.
    tokenizer_path: Optional[str] = None
    # Optional standalone audio encoder checkpoint.
    audio_encoder_path: Optional[str] = None
    # When set, the LLM backbone is loaded from this separate checkpoint.
    pretrained_llm_path: Optional[str] = None
41
+
42
+
43
@dataclass
class DataArguments:
    """Dataset locations plus optional subsampling / augmentation settings."""

    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    augment_prob: float = field(
        default=0.0, metadata={"help": "The probability of applying augmentation transform."}
    )
    augment_path: str = field(default=None, metadata={"help": "Path to the augment data."})
71
+
72
+
73
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF TrainingArguments extended with MiniCPM-o specific toggles."""

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Which sub-modules receive gradients (vision tower, audio encoder, LLM backbone).
    tune_vision: Optional[bool] = field(default=True)
    tune_speech: Optional[bool] = field(default=True)
    tune_llm: Optional[bool] = field(default=True)
    llm_type: str = field(default="qwen")
    use_lora: Optional[bool] = field(default=False)
    # Maximum number of image slices fed to the vision encoder.
    max_slice_nums: Optional[int] = field(default=9)
    # Optional JSON file with extra kwargs forwarded to MiniCPMO.from_pretrained.
    config_path: Optional[str] = field(default=None)
    chunk_input: Optional[bool] = field(default=True)
    # Whether to initialize the vision / speech branches at load time.
    init_vision: Optional[bool] = field(default=False)
    init_speech: Optional[bool] = field(default=True)
93
+
94
+
95
@dataclass
class LoraArguments:
    """Hyper-parameters controlling LoRA / QLoRA fine-tuning."""

    lora_r: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    # Regex matching the q/k/v projection modules of every LLM decoder layer.
    lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False
    lora_modules_to_save: str = ""
    lora_layer_replication: Optional[List[Tuple[int, int]]] = None
    lora_layers_to_transform: Optional[List[int]] = None
    lora_layers_pattern: Optional[str] = None
108
+
109
# Process rank; assigned inside train(). While still None, rank0_print is silent.
local_rank = None
def rank0_print(*args):
    """Print *args* only on the rank-0 (main) process."""
    if local_rank == 0:
        print(*args)
113
+
114
+
115
def print_trainable_parameters_by_module(model):
    """Print a per-module breakdown of the trainable parameters of *model*.

    Trainable parameters are grouped by the first two dotted components of
    their name (e.g. ``llm.model``). For each group the trainable count,
    its share of all trainable parameters, and up to three example
    parameter names are printed, followed by an overall summary line.

    FIX: guards the summary percentages against division by zero (the
    original raised ``ZeroDivisionError`` on a model with no parameters).
    Comments were also translated from Vietnamese to English.
    """
    print("\n" + "="*50)
    print("TRAINABLE PARAMETERS BY MODULE")
    print("="*50)

    # Trainable parameters bucketed by second-level module name.
    module_params = {}
    all_params = 0
    trainable_params = 0

    for name, param in model.named_parameters():
        all_params += param.numel()

        # Group key: first two dotted components of the parameter name.
        parts = name.split('.')
        module_name = '.'.join(parts[:2]) if len(parts) >= 2 else parts[0]

        if param.requires_grad:
            trainable_params += param.numel()
            bucket = module_params.setdefault(module_name, {'count': 0, 'names': []})
            bucket['count'] += param.numel()
            bucket['names'].append(name)

    # Largest modules first.
    sorted_modules = sorted(module_params.items(), key=lambda x: x[1]['count'], reverse=True)

    for module_name, info in sorted_modules:
        param_count = info['count']
        percentage = 100 * param_count / trainable_params if trainable_params else 0.0
        print(f"{module_name:<30} {param_count:,} params ({percentage:.2f}%)")

        # Show up to three example parameter names per module.
        for param_name in info['names'][:3]:
            print(f"  - {param_name}")

        if len(info['names']) > 3:
            print(f"  ... and {len(info['names']) - 3} more parameters")

    print("\n" + "-"*50)
    overall = 100 * trainable_params / all_params if all_params else 0.0
    print(f"Total trainable parameters: {trainable_params:,} / {all_params:,} ({overall:.2f}%)")
    print("="*50 + "\n")
168
+
169
+
170
def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
    """Collects the state dict and dump to disk.

    Saves only on the main process (``should_save`` and ``local_rank == 0``).
    ``bias`` is unused here; kept for call-site compatibility.
    """
    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer.save_model(output_dir,)
174
+
175
+ # class CollateFn:
176
+ # def __init__(self, processor, prompt="Please transcribe this audio into text.", system_prompt="You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."):
177
+ # self.prompt = prompt
178
+ # self.system_prompt = system_prompt
179
+ # self.processor = processor
180
+
181
+ # def __call__(self, examples):
182
+ # prompts_lists = []
183
+ # input_images_list = []
184
+ # input_audios_list = []
185
+ # audio_parts_list = []
186
+
187
+ # for msgs in examples:
188
+ # msgs = msgs["conversations"]
189
+ # if isinstance(msgs, str):
190
+ # msgs = json.loads(msgs)
191
+ # copy_msgs = deepcopy(msgs)
192
+
193
+ # assert len(msgs) > 0, "msgs is empty"
194
+
195
+ # system_turn = {'role': 'system', 'content': self.system_prompt}
196
+ # if copy_msgs[0]["role"] != 'system':
197
+ # copy_msgs.insert(0, system_turn)
198
+
199
+ # images = []
200
+ # audios = []
201
+ # audio_parts = []
202
+ # for i, msg in enumerate(copy_msgs):
203
+ # role = msg["role"]
204
+ # content = msg["content"]
205
+ # assert role in ["system", "user", "assistant"]
206
+ # if i == 0:
207
+ # assert role in ["user", "system"], "The role of first msg should be user"
208
+ # content = [content, self.prompt]
209
+ # cur_msgs = []
210
+
211
+ # for c in content:
212
+ # if os.path.exists(c):
213
+ # c, _ = librosa.load(c, sr=16000, mono=True)
214
+
215
+ # if isinstance(c, Image.Image):
216
+ # images.append(c)
217
+ # cur_msgs.append("(<image>./</image>)")
218
+ # elif isinstance(c, np.ndarray): # audio
219
+ # audios.append(c)
220
+ # audio_parts.append(i)
221
+ # cur_msgs.append("(<audio>./</audio>)")
222
+ # elif isinstance(c, str):
223
+ # cur_msgs.append(c)
224
+ # else:
225
+ # msg["content"] = "\n".join(cur_msgs)
226
+
227
+ # prompts_lists.append(
228
+ # self.processor.tokenizer.apply_chat_template(
229
+ # copy_msgs,
230
+ # tokenize=False,
231
+ # add_generation_prompt=False,
232
+ # )
233
+ # )
234
+ # input_images_list.append(images)
235
+ # input_audios_list.append(audios)
236
+ # audio_parts_list.append(audio_parts)
237
+
238
+ # inputs = self.processor(
239
+ # prompts_lists,
240
+ # input_images_list,
241
+ # input_audios_list,
242
+ # audio_parts_list,
243
+ # return_tensors="pt",
244
+ # max_length=32768,
245
+ # return_labels=True,
246
+ # )
247
+
248
+ # return inputs
249
+
250
def collate_fn(examples, processor, chunk_input, max_len, prompt=None, system_prompt="You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.", transform=None):
    """Collate raw conversation examples into model-ready processor inputs.

    Each example is a dict with a ``conversations`` list (optionally a JSON
    string) and an optional ``tools`` JSON string. Audio file paths inside
    message content are loaded at 16 kHz (and optionally augmented via
    ``transform``); images / waveforms are replaced by placeholder tags and
    collected into per-sample lists, then everything is rendered through the
    tokenizer chat template and fed to ``processor``.

    BUG FIX: the previous version only wrapped *user* content in a list, so
    string content of system/assistant/tool turns was iterated character by
    character, mangling those messages. Non-user string content is now
    wrapped as a single part, and the file-existence check is applied to
    strings only.
    """
    prompts_lists = []
    input_images_list = []
    input_audios_list = []
    audio_parts_list = []

    for msgs in examples:
        raw_msgs = deepcopy(msgs)
        msgs = msgs["conversations"]
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        copy_msgs = deepcopy(msgs)

        assert len(msgs) > 0, "msgs is empty"

        # Prepend the default system turn when the sample lacks one.
        system_turn = {'role': 'system', 'content': system_prompt}
        if copy_msgs[0]["role"] != 'system':
            copy_msgs.insert(0, system_turn)

        # Optional function-calling tool schema, JSON-encoded per sample.
        fc = None
        if "tools" in raw_msgs and raw_msgs["tools"] != "":
            fc = json.loads(raw_msgs["tools"])

        images = []
        audios = []
        audio_parts = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["system", "user", "assistant", "tool"]
            if i == 0:
                assert role in ["user", "system"], "The role of first msg should be user or system"

            # Normalize content into a list of parts (see BUG FIX above).
            if role == "user":
                content = [content, prompt] if prompt is not None else [content]
            elif not isinstance(content, list):
                content = [content]

            cur_msgs = []
            for c in content:
                # Strings pointing at an existing file are treated as audio paths.
                if isinstance(c, str) and os.path.exists(c):
                    c, _ = librosa.load(c, sr=16000, mono=True)
                    if transform is not None:
                        c = transform(c, sample_rate=16000)

                if isinstance(c, str):
                    cur_msgs.append(c)
                elif isinstance(c, np.ndarray):  # raw audio waveform
                    audios.append(c)
                    audio_parts.append(i)
                    cur_msgs.append("(<audio>./</audio>)")
                elif isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("(<image>./</image>)")

            msg["content"] = "\n".join(cur_msgs)

            # Tool calls arrive JSON-encoded; decode and normalize to a list.
            if "tool_calls" in msg and msg["tool_calls"] is not None:
                assert isinstance(msg["tool_calls"], str), f"Invalid type: {type(msg['tool_calls'])}"
                msg["tool_calls"] = json.loads(msg["tool_calls"])
                if not isinstance(msg["tool_calls"], list):
                    msg["tool_calls"] = [msg["tool_calls"]]

        qwen_template = processor.tokenizer.apply_chat_template(
            copy_msgs,
            tokenize=False,
            add_generation_prompt=False,
            tools=fc,
        )

        prompts_lists.append(qwen_template)
        input_images_list.append(images)
        input_audios_list.append(audios)
        audio_parts_list.append(audio_parts)

    inputs = processor(
        prompts_lists,
        input_images_list,
        input_audios_list,
        audio_parts_list,
        chunk_input=chunk_input,
        return_tensors="pt",
        return_labels=True,
    )

    return inputs
357
+
358
def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    processor: transformers.ProcessorMixin,
    data_args,
    transform,
    data_collator=None,
    llm_type="qwen",
    slice_config=None,
    patch_size=14,
    query_nums=64,
    batch_vision=False,
    max_length=2048,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Builds a ``SupervisedDataset`` for training (and optionally for eval when
    ``data_args.eval_data_path`` is set) from JSON files, and binds
    ``max_length`` onto the provided ``data_collator``.

    Returns:
        dict with keys ``train_dataset``, ``eval_dataset`` (may be None) and
        ``data_collator`` suitable for splatting into a HF ``Trainer``.
    """
    dataset_cls = SupervisedDataset

    rank0_print("Loading data...")

    # Fix: the original used json.load(open(...)) which leaks the file
    # handle until GC; context managers close it deterministically.
    with open(data_args.data_path, "r") as f:
        train_json = json.load(f)
    train_dataset = dataset_cls(
        train_json,
        transform,
        tokenizer,
        processor,
        slice_config=slice_config,
        llm_type=llm_type,
        patch_size=patch_size,
        query_nums=query_nums,
        batch_vision=batch_vision,
        max_length=max_length,
    )

    if data_args.eval_data_path:
        with open(data_args.eval_data_path, "r") as f:
            eval_json = json.load(f)
        eval_dataset = dataset_cls(
            eval_json,
            transform,
            tokenizer,
            processor,
            slice_config=slice_config,
            llm_type=llm_type,
            patch_size=patch_size,
            query_nums=query_nums,
            batch_vision=batch_vision,
            max_length=max_length,
        )
    else:
        eval_dataset = None

    # NOTE(review): if data_collator is left as None, partial(None, ...) will
    # fail only when the collator is first invoked — confirm callers always
    # pass one.
    return dict(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=partial(data_collator, max_length=max_length),
    )
412
+
413
+
414
def build_transform():
    """Return the image preprocessing pipeline: tensor conversion followed
    by Inception-style normalization (mean/std 0.5 per channel)."""
    # Values mirror timm.data.IMAGENET_INCEPTION_MEAN / _STD.
    inception_mean = (0.5, 0.5, 0.5)
    inception_std = (0.5, 0.5, 0.5)
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=inception_mean, std=inception_std),
    ]
    return transforms.Compose(steps)
425
+
426
def get_parameter_number(model):
    """Count total and trainable parameters of *model*.

    Returns:
        dict: ``{'Total': <all params>, 'Trainable': <requires_grad params>}``.
    """
    total = 0
    trainable = 0
    for p in model.parameters():
        n = p.numel()
        # Under DeepSpeed ZeRO-3, partitioned weights report numel() == 0;
        # the real element count is exposed via ds_numel.
        if n == 0 and hasattr(p, "ds_numel"):
            n = p.ds_numel
        total += n
        if p.requires_grad:
            trainable += n
    return {'Total': total, 'Trainable': trainable}
439
+
440
+
441
# Process rank used for rank-0-only logging; overwritten in train() from
# training_args.local_rank.
local_rank = 0
442
+
443
+
444
def train(attn_implementation="flash_attention_2"):
    """Fine-tune MiniCPMO: parse CLI dataclasses, build the model, select
    trainable sub-modules (full LLM vs. LoRA), load JSON train/eval data and
    run the HF ``Trainer``.

    Args:
        attn_implementation: attention backend forwarded to
            ``MiniCPMO.from_pretrained`` (default flash-attention 2).
    """
    global local_rank
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )

    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    # Force the accelerate distributed backend to DeepSpeed when a DS config
    # was supplied on the command line.
    if getattr(training_args, "deepspeed", None) :
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    local_rank = training_args.local_rank
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    device_map = None
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            # NOTE(review): double negative in this message — presumably it
            # should read "are not compatible with QLoRA".
            logging.warning(
                "FSDP or ZeRO3 are not incompatible with QLoRA."
            )

    # Optional extra kwargs merged into MiniCPMO.from_pretrained below.
    minipcmo_config = {}
    if training_args.config_path is not None:
        minipcmo_config = json.load(open(training_args.config_path, "r"))

    # if model_args.tokenizer_path is not None:
    #     tokenizer = AutoTokenizer.from_pretrained(
    #         model_args.tokenizer_path, trust_remote_code=True
    #     )
    # else:
    #     tokenizer = AutoTokenizer.from_pretrained(
    #         model_args.model_name_or_path, trust_remote_code=True
    #     )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
    )

    # Two load paths: plain fine-tuning of an existing checkpoint, or
    # re-initializing the LLM weights from a separate pretrained LLM.
    if model_args.pretrained_llm_path is None:
        print("Finetuning model!!!")
        model = MiniCPMO.from_pretrained(
            model_args.model_name_or_path,
            torch_dtype=compute_dtype,
            device_map=device_map,
            attn_implementation=attn_implementation,
            init_vision=training_args.init_vision,
            init_audio=training_args.init_speech,
            init_tts=False,
            processor_path=model_args.tokenizer_path,
            **minipcmo_config
        )
    else:
        print("Load pretrained LLM from scratch!!!")
        # # Create the config object as needed
        # config = MiniCPMOConfig(
        #     model_name_or_path=model_args.model_name_or_path,
        #     pretrained_llm_path=model_args.pretrained_llm_path,
        #     init_vision=training_args.init_vision,
        #     init_audio=training_args.init_speech,
        #     pretrained_encoder_path=model_args.audio_encoder_path,
        #     processor_path=model_args.tokenizer_path,
        #     **minipcmo_config
        # )

        # # Initialize model
        # model = MiniCPMO(config)

        model = MiniCPMO.from_pretrained(
            model_args.model_name_or_path,
            pretrained_llm_path=model_args.pretrained_llm_path,
            init_vision=training_args.init_vision,
            init_audio=training_args.init_speech,
            init_tts=False,
            pretrained_encoder_path=model_args.audio_encoder_path,
            processor_path=model_args.tokenizer_path,
            **minipcmo_config
        )

    # tokenizer.audio_start_id = tokenizer.convert_tokens_to_ids("<|box_start|>")
    # tokenizer.audio_end_id = tokenizer.convert_tokens_to_ids("<|box_end|>")
    # tokenizer.audio_start = "<|box_start|>"
    # tokenizer.audio_end = "<|box_end|>"
    # tokenizer.im_start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>")
    # tokenizer.im_end_id = tokenizer.convert_tokens_to_ids("<|vision_end|>")
    # tokenizer.im_start = "<|vision_start|>"
    # tokenizer.im_end = "<|vision_end|>"
    # tokenizer.slice_start_id = tokenizer.convert_tokens_to_ids("<|quad_start|>")
    # tokenizer.slice_end_id = tokenizer.convert_tokens_to_ids("<|quad_end|>")
    # tokenizer.slice_start = "<|quad_start|>"
    # tokenizer.slice_end = "<|quad_end|>"
    # tokenizer.unk_token = "<unk>"

    # print("Audio Start Token:", tokenizer.audio_start)
    # print("Audio End Token:", tokenizer.audio_end)
    # print(tokenizer.audio_start_id)
    # print(tokenizer.audio_end_id)
    # print("Start Token:", tokenizer.im_start)
    # print("End Token:", tokenizer.im_end)
    # print(tokenizer.im_start_id)
    # print(tokenizer.im_end_id)
    # print("Slice Start Token:", tokenizer.slice_start)
    # print("Slice End Token:", tokenizer.slice_end)
    # print(tokenizer.slice_start_id)
    # print(tokenizer.slice_end_id)

    model.config.chunk_input = training_args.chunk_input
    # model.llm.resize_token_embeddings(len(tokenizer))
    # model.resize_token_embeddings(len(tokenizer))

    # Caching is incompatible with gradient checkpointing during training.
    model.llm.config.use_cache = False
    model.config.max_length = training_args.model_max_length

    # if not training_args.tune_vision and training_args.init_vision:
    #     model.vpm.requires_grad_(False)
    # if not training_args.tune_speech and training_args.init_speech:
    #     model.apm.requires_grad_(False)
    # if not training_args.tune_llm:
    #     model.llm.requires_grad_(False)
    # Freeze everything first, then selectively re-enable below.
    model.requires_grad_(False)


    if training_args.tune_llm:
        model.llm.requires_grad_(True)
        print("Enabled training for LLM")
        model.audio_projection_layer.requires_grad_(True)
        print("Enabled training for audio_projection_layer")


    if training_args.use_lora:
        # LoRA and full-LLM tuning are mutually exclusive.
        if training_args.use_lora and training_args.tune_llm:
            raise ValueError("The model cannot simultaneously adjust LLM parameters and apply LoRA.")

        rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
        for name, param in model.llm.named_parameters():
            param.requires_grad = False
        modules_to_save = ['embed_tokens','resampler']
        if training_args.tune_vision:
            modules_to_save.append('vpm')
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            layers_to_transform=lora_args.lora_layers_to_transform,
            modules_to_save=modules_to_save,
        )
        # peft requires get_input_embeddings(); patch it through to the
        # wrapped LLM when the wrapper does not expose it.
        if not hasattr(model, 'get_input_embeddings'):
            def get_input_embeddings(self):
                return self.llm.get_input_embeddings()
            model.get_input_embeddings = MethodType(get_input_embeddings, model)
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )
        model = get_peft_model(model, lora_config)
        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

    rank0_print(get_parameter_number(model))


    print_trainable_parameters_by_module(model)

    llm_type = training_args.llm_type

    rank0_print(f'llm_type={llm_type}')

    # Load data
    # NOTE(review): slice_config / batch_vision / transform_func are computed
    # here but only consumed by the commented-out make_supervised_data_module
    # call below — dead unless that path is restored.
    if hasattr(model.config, "slice_config"):
        model.config.slice_config.max_slice_nums = training_args.max_slice_nums
        slice_config = model.config.slice_config.to_dict()
    else:
        model.config.max_slice_nums = training_args.max_slice_nums
        slice_config = model.config.to_dict()

    if hasattr(model.config, "batch_vision_input"):
        batch_vision = model.config.batch_vision_input
    else:
        batch_vision = False

    transform_func = build_transform()

    if model_args.tokenizer_path is not None:
        processor = AutoProcessor.from_pretrained(model_args.tokenizer_path, trust_remote_code=True)
    else:
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    processor.tokenizer = tokenizer

    raw_datasets = load_dataset(
        "json",
        data_files={
            "train": data_args.data_path,
            "validation": data_args.eval_data_path,
        },
        cache_dir=training_args.cache_dir,
    )

    train_ds = raw_datasets["train"]
    if data_args.max_train_samples is not None:
        train_ds = train_ds.select(range(data_args.max_train_samples))
    eval_ds = raw_datasets["validation"]
    if data_args.max_eval_samples is not None:
        eval_ds = eval_ds.select(range(data_args.max_eval_samples))

    # data_module = make_supervised_data_module(
    #     tokenizer=tokenizer,
    #     processor=processor,
    #     data_args=data_args,
    #     transform=transform_func,
    #     data_collator=data_collator,
    #     slice_config=slice_config,
    #     llm_type=llm_type,
    #     patch_size=model.config.patch_size,
    #     query_nums=model.config.query_num,
    #     batch_vision=batch_vision,
    #     max_length=training_args.model_max_length,
    # )

    init_prompt = None
    if not training_args.tune_llm and training_args.tune_speech: # asr finetuning
        init_prompt = "Please transcribe this audio into text."

    # Optional waveform augmentation: mix in background noise with
    # probability augment_prob.
    transform = None
    if data_args.augment_prob != 0.0 and data_args.augment_path is not None:
        with open(data_args.augment_path, "r") as f:
            augment_path_list = f.read().splitlines()
        transform = AddBackgroundNoise(
            sounds_path=augment_path_list,
            min_snr_db=5.0,
            max_snr_db=30.0,
            noise_transform=PolarityInversion(),
            p=data_args.augment_prob
        )

    custom_collate_fn = partial(collate_fn, processor = processor, chunk_input=training_args.chunk_input, max_len=training_args.model_max_length, prompt=init_prompt, transform=transform)

    training_args.gradient_checkpointing_kwargs={"use_reentrant":False}

    # print("Training Layers:")
    # for name, param in model.named_parameters():
    #     if param.requires_grad:
    #         print(name, param.grad)

    # trainer = CPMTrainer(
    #     model=model,
    #     tokenizer=tokenizer,
    #     args=training_args,
    #     **data_module,
    # )
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=custom_collate_fn
    )

    # Resume automatically when checkpoints already exist in output_dir.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_state()

    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        bias=lora_args.lora_bias)
726
+
727
+
728
# Script entry point.
if __name__ == "__main__":
    train()
omni_speech/train/train_multiturn.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import copy
19
+ from dataclasses import dataclass, field
20
+ import json
21
+ import logging
22
+ import pathlib
23
+ from typing import Dict, Optional, Sequence, List
24
+
25
+ import torch
26
+
27
+ import transformers
28
+ import tokenizers
29
+
30
+ from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
31
+ from torch.utils.data import Dataset
32
+ from omni_speech.train.omni_trainer import OmniTrainer
33
+ from audiomentations import AddBackgroundNoise, PolarityInversion
34
+
35
+ from omni_speech import conversation as conversation_lib
36
+ from omni_speech.model import *
37
+ from omni_speech.utils import *
38
+ from omni_speech.datasets.preprocess import *
39
+ import whisper
40
+ import time
41
+
42
@dataclass
class ModelArguments:
    """CLI options describing the model architecture and which speech
    components to load or tune."""

    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    # Conversation template / model family; train() accepts "llama_3" or "qwen".
    version: Optional[str] = field(default="llama_3")
    freeze_backbone: bool = field(default=False)
    tune_speech_projector: bool = field(default=False)
    tune_speech_encoder: bool = field(default=False)
    tune_speech_generator_only: bool = field(default=False)
    speech_encoder_type: Optional[str] = field(default=None)
    speech_encoder: Optional[str] = field(default=None)
    pretrain_speech_projector: Optional[str] = field(default=None)
    speech_projector_type: Optional[str] = field(default='linear')
    speech_generator_type: Optional[str] = field(default='ctc')
    # ctc_decoder_config: str = "(2,4096,32,11008)" # num layers, hidden sizes, attn heads, ff dimensions of LLaMA
    ctc_decoder_config: str = "(2,4096,32,22016)"
    ctc_upsample_factor: int = 25
    ctc_loss_weight: float = 1.0
    unit_vocab_size: int = 1000
    speech_encoder_ds_rate: int = 5
    speech_encoder_hidden_size: int = 1280
62
+
63
+
64
@dataclass
class DataArguments:
    """CLI options for dataset paths, speech input format and augmentation."""

    data_path: Optional[str] = field(default=None,
                           metadata={"help": "Path to the training data."})
    dev_path: Optional[str] = field(default=None,
                          metadata={"help": "Path to the dev data."})
    is_multimodal: bool = False
    # "mel" -> log-mel spectrogram via whisper; "raw" -> raw waveform tensor.
    input_type: str = field(default="mel")
    speech_normalize: bool = False
    mel_size: int = 128
    # When True, samples carry target speech units ("tgt_units").
    has_tgt_units: bool = False
    augment_prob: float = field(
        default=0.0,
        metadata={"help": "The probability of applying augmentation transform."}
    )
    augment_path: Optional[str] = field(default=None,
                              metadata={"help": "Path to the augment data."})
81
+
82
+
83
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF TrainingArguments extended with quantization, LoRA and
    speech-projector options."""

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    freeze_speech_projector: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    # 16 = no quantization; 4/8 enable bitsandbytes loading in train().
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    # Separate LR for the speech projector; None -> use the global LR.
    speech_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
115
+
116
+
117
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Loads the sample list eagerly from a JSON file but decodes audio and
    tokenizes lazily, per item, in ``__getitem__``.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        # Optional background-noise augmentation, applied in process_speech.
        if self.data_args.augment_prob != 0.0:
            with open(self.data_args.augment_path, "r") as f:
                augment_path_list = f.read().splitlines()
            self.transform = AddBackgroundNoise(
                sounds_path=augment_path_list,
                min_snr_db=5.0,
                max_snr_db=30.0,
                noise_transform=PolarityInversion(),
                p=self.data_args.augment_prob
            )

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Fetch sample *i* with retries: first retry the same index (cloud
        disk hiccups), then fall back to the next index (file corruption),
        and finally re-raise if nothing works."""
        # TODO: define number of retries somewhere else
        num_base_retries = 3
        num_final_retries = 300

        # try the current sample first
        for attempt_idx in range(num_base_retries):
            try:
                sample = self._get_item(i)
                return sample
            except Exception as e:
                # sleep 1s in case it is a cloud disk issue
                print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
                time.sleep(1)

        # try other samples, in case it is file corruption issue
        for attempt_idx in range(num_base_retries):
            try:
                next_index = min(i + 1, len(self.list_data_dict) - 1)
                # sample_idx = random.choice(range(len(self)))
                sample = self._get_item(next_index)
                return sample
            except Exception as e:
                # no need to sleep
                print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)
                pass

        try:
            sample = self._get_item(i)
            return sample
        except Exception as e:
            raise e

    def process_speech(self, speech_file):
        """Load one audio file and convert it to the configured input format.

        Returns (speech, speech_lengths) where speech is either a raw
        waveform tensor or a (frames, mel) log-mel spectrogram.
        """
        speech = whisper.load_audio(speech_file)
        if self.data_args.augment_prob != 0.0:
            speech = self.transform(speech, sample_rate=16000)
        if self.data_args.input_type == "raw":
            speech = torch.from_numpy(speech)
            # NOTE(review): self.model_config is never assigned on this class;
            # this line will raise AttributeError whenever input_type == "raw".
            # It presumably should read self.data_args.speech_normalize.
            if self.model_config.data_args.speech_normalize:
                speech = torch.nn.functional.layer_norm(speech, speech.shape)
        elif self.data_args.input_type == "mel":
            speech = whisper.pad_or_trim(speech)
            speech = whisper.log_mel_spectrogram(speech, n_mels=self.data_args.mel_size).permute(1, 0)
        speech_lengths = torch.LongTensor([speech.shape[0]])
        return speech, speech_lengths

    def _get_item(self, i) -> Dict[str, torch.Tensor]:
        """Build the training dict for sample *i*: tokenized conversation
        (input_ids/labels), optional speech features and target units."""
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
        # Promote a sample-level "tools" entry to a pseudo conversation turn
        # so the preprocessor can see it.
        for item in sources:
            if 'tools' in item:
                tools_dict = {
                    "from": "tools",
                    "value": item["tools"]
                }
                item["conversations"].insert(0,tools_dict)

        if self.data_args.has_tgt_units:
            # pad_list = [0]
            # tgt_units = [e["tgt_units"] if "tgt_units" in e else pad_list for e in sources]
            tgt_units = [e["tgt_units"] for e in sources]
            tgt_units = torch.tensor(tgt_units, dtype=torch.long)
        else:
            tgt_units = None

        if 'speech' in sources[0]:
            # NOTE(review): this numpy import is unused in this scope.
            import numpy as np
            speech_file = self.list_data_dict[i]['speech']
            if type(speech_file) is list:
                speech = [self.process_speech(f) for f in speech_file]
            else:
                speech = [self.process_speech(speech_file)]

            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_speech=('speech' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # speech exist in the data
        if 'speech' in self.list_data_dict[i]:
            data_dict['speech'] = speech

        if tgt_units is not None:
            data_dict['tgt_units'] = tgt_units[0]

        data_dict["id"] = self.list_data_dict[i].get("id", i)

        return data_dict
242
+
243
+
244
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Truncates to ``tokenizer.model_max_length``, pads input_ids with the pad
    token and labels with IGNORE_INDEX, and flattens per-sample speech lists
    into batch-level lists.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def pad_sequence(self, input_ids, batch_first, padding_value):
        """Pad a list of 1-D tensors, honoring the tokenizer's padding side.

        torch's pad_sequence only right-pads, so left padding is emulated by
        flipping each sequence, right-padding, then flipping the batch back.
        """
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        # input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "id"))
        input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids]
        labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels]
        if self.tokenizer.pad_token_id is None:
            # self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model.
            self.tokenizer.pad_token_id = 0 # This gets the best result. Don't know why.
        input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        batch = dict(input_ids=input_ids, labels=labels.long() if labels.dtype == torch.int32 else labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id))
        # batch = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ids=ids)

        # Each instance carries a list of (features, lengths) pairs; flatten
        # all pairs across the batch into two parallel lists.
        if 'speech' in instances[0]:
            speechs = [instance['speech'] for instance in instances]

            speech = [sp[0] for sp_list in speechs for sp in sp_list]
            speech_lengths = [sp[1] for sp_list in speechs for sp in sp_list]

            batch["speech"] = speech
            # print(len(speech)) # sum(len(audio) for audio in each batch)
            # print(speech[0].shape) # seq_len, dim
            batch['speech_lengths'] = speech_lengths
            # print(speech_lengths[0].shape) # seq_len

        if 'tgt_units' in instances[0]:
            tgt_units = [instance['tgt_units'] for instance in instances]
            tgt_units = self.pad_sequence(tgt_units, batch_first=True, padding_value=self.tokenizer.pad_token_id)
            batch['tgt_units'] = tgt_units
            # print(batch['tgt_units'])
            # print("---------------")
            # print(batch['input_ids'])

        return batch
292
+
293
+
294
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Build the train/eval datasets and their collator for SFT.

    Returns a dict that can be splatted directly into a Trainer:
    ``train_dataset``, ``eval_dataset`` (None without ``dev_path``) and
    ``data_collator``.
    """
    module = {
        "train_dataset": LazySupervisedDataset(
            tokenizer=tokenizer,
            data_path=data_args.data_path,
            data_args=data_args,
        ),
        "eval_dataset": None,
        "data_collator": DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }
    # Only build an eval split when a dev file was provided.
    if data_args.dev_path is not None:
        module["eval_dataset"] = LazySupervisedDataset(
            tokenizer=tokenizer,
            data_path=data_args.dev_path,
            data_args=data_args,
        )
    return module
310
+
311
+
312
def train(attn_implementation="flash_attention_2"):
    """Fine-tune an OmniSpeech model: parse CLI dataclasses, load the right
    model class (LLaMA/Qwen, with or without a speech generator), configure
    quantization/LoRA/speech modules, then train with ``OmniTrainer``.

    Args:
        attn_implementation: attention backend forwarded to
            ``from_pretrained`` (default flash-attention 2).
    """

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    # Extra kwargs for 4/8-bit bitsandbytes loading; the speech projector is
    # kept un-quantized.
    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["speech_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
            )
        ))

    # has_tgt_units selects the speech-to-speech ("2S") model variants;
    # otherwise the text-output variants are used.
    if data_args.has_tgt_units:
        if model_args.version == "llama_3":
            model = OmniSpeech2SLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif model_args.version == "qwen":
            model = OmniSpeech2SQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            raise ValueError("--currently only support llama or qwen model!")
    else:
        if model_args.version == "llama_3":
            model = OmniSpeechLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif model_args.version == "qwen":
            model = OmniSpeechQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            raise ValueError("--currently only support llama or qwen model!")
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Fallback: force grads on the embedding output so checkpointed
            # segments have a grad path.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        model = get_peft_model(model, lora_config)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    model.resize_token_embeddings(len(tokenizer))
    model.config.max_length = training_args.model_max_length

    # Select the conversation template; unknown versions fall back to llama_3.
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["llama_3"]

    if model_args.speech_encoder is not None:
        model.get_model().initialize_speech_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        data_args.is_multimodal = True

        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        model.config.tune_speech_projector = training_args.tune_speech_projector = model_args.tune_speech_projector

        model.config.speech_normalize = data_args.speech_normalize

        # Speech encoder is always frozen here.
        for p in model.get_speech_encoder().parameters():
            p.requires_grad = False

        # tune_speech_projector: freeze everything, then train only the
        # projector.
        if model_args.tune_speech_projector:
            model.requires_grad_(False)
            for p in model.get_speech_projector().parameters():
                p.requires_grad = True

        model.config.freeze_speech_projector = training_args.freeze_speech_projector
        if training_args.freeze_speech_projector:
            for p in model.get_speech_projector().parameters():
                p.requires_grad = False

        if training_args.bits in [4, 8]:
            model.get_model().speech_projector.to(dtype=compute_dtype, device=training_args.device)

        model.config.speech_projector_lr = training_args.speech_projector_lr

    if data_args.has_tgt_units:
        model.initialize_speech_generator(model_args=model_args)

    # Normalize dtypes of LoRA layers / norms / embeddings for k-bit training.
    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)

    print("Training Layers:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.grad)

    trainer = OmniTrainer(model=model,
                          tokenizer=tokenizer,
                          args=training_args,
                          **data_module)

    # Resume automatically when checkpoints already exist in output_dir.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    # LoRA runs save adapter weights + non-LoRA trainables separately;
    # full-tune runs use the safe HF save helper.
    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)
511
+
512
+
513
# Script entry point.
if __name__ == "__main__":
    train()
515
+
omni_speech/train/trainer.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import deepspeed
4
+ from transformers import Trainer
5
+ from transformers.trainer_pt_utils import nested_detach
6
+ from transformers.utils import is_sagemaker_mp_enabled
7
+ from transformers.trainer import *
8
+ from transformers.integrations import is_deepspeed_zero3_enabled
9
+
10
+
11
class CPMTrainer(Trainer):
    """HF `Trainer` subclass for MiniCPM-style speech/vision-language models.

    Differences from the stock Trainer:
      * the model is called as ``model(data=inputs, ...)`` instead of
        ``model(**inputs)``, matching the MiniCPM forward signature;
      * LoRA runs go through ``_enable_peft_forward_hooks`` and the PEFT
        ``base_model`` (controlled by ``self.args.use_lora``);
      * the cross-entropy loss is computed here, outside the model, when
        ``labels`` are present in the batch.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        # Pull the labels out so they are NOT forwarded to the model; the loss
        # is computed locally below instead of inside model.forward.
        if "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        # NOTE: this deliberately calls self.model (the unwrapped model), not the
        # `model` argument (which may be the DDP/DeepSpeed wrapper).
        if not self.args.use_lora:
            outputs = self.model(data = inputs, use_cache=False)
        else:
            # PEFT forward hooks must be active so adapter layers participate;
            # base_model is the PEFT-wrapped underlying model.
            with self.model._enable_peft_forward_hooks(**inputs):
                outputs = self.model.base_model(data = inputs, use_cache=False)

        if labels is not None:
            # Flatten the tokens
            # NOTE(review): logits and labels are compared position-for-position
            # with no causal shift here — presumably the dataset/collator already
            # shifted the labels by one; confirm against the data pipeline.
            # nn.CrossEntropyLoss defaults to ignore_index=-100, the HF padding
            # convention for masked label positions.
            loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits.view(-1,
                                         self.model.config.vocab_size).contiguous()
            labels = labels.view(-1).long().contiguous()
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
        else:
            # No labels in the batch: the model itself must have produced a loss.
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        # A batch "has labels" only if every configured label column is present.
        has_labels = (
            False
            if len(self.label_names) == 0
            else all(inputs.get(k) is not None for k in self.label_names)
        )
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = (
            True if len(self.label_names) == 0 and return_loss else False
        )

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(
                    self.model.config, "keys_to_ignore_at_inference", []
                )
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name)
                                         for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                # SageMaker model-parallel path: forward runs via smp helpers and
                # per-microbatch results are reduced/concatenated afterwards.
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(
                            v
                            for k, v in raw_outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(
                            v for k, v in raw_outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    # Reuse the custom compute_loss above so eval loss matches
                    # the training-time loss definition.
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(
                            model, inputs, return_outputs=True
                        )
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(
                            v
                            for k, v in outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(
                            v for k, v in outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        # Free the batch and the CUDA caching-allocator blocks before backward.
        # NOTE(review): empty_cache() on every step trades throughput for lower
        # peak memory — presumably intentional for these large models.
        del inputs
        torch.cuda.empty_cache()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # accelerate handles mixed precision / DeepSpeed gradient scaling.
            self.accelerator.backward(loss)

        # Reported loss is normalized by the accumulation factor, matching the
        # stock Trainer's logging convention.
        return loss.detach() / self.args.gradient_accumulation_steps

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # PeftModel is only a "supported class" when peft is importable.
        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                # The wrapper (DDP/compiled module) hides a supported model inside.
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:

            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
scripts/continue.sh ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# continue.sh — resume/continue ASR supervised fine-tuning of MiniCPM-o from an
# existing checkpoint (MODEL points at a previously fine-tuned checkpoint dir,
# TOKENIZER_PATH at the original base model's tokenizer).

# GPUS_PER_NODE=8
# NNODES=1
# NODE_RANK=0
# MASTER_ADDR=localhost
# MASTER_PORT=6001

MODEL="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/minicpmo_sft_asr"
TOKENIZER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/train_asr_mixed_500k.jsonl"
EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/dev_asr_mixed.jsonl"

# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
# if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
LLM_TYPE="qwen"
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096


# DISTRIBUTED_ARGS="
#     --nproc_per_node $GPUS_PER_NODE \
#     --nnodes $NNODES \
#     --node_rank $NODE_RANK \
#     --master_addr $MASTER_ADDR \
#     --master_port $MASTER_PORT
# "

# NOTE(review): "--deepspeed zero2.json" is a path relative to the working
# directory; the repo ships scripts/ds_config_zero2.json — confirm zero2.json
# exists where this script is launched from.
# tune_speech=true / tune_llm=false: only the speech modules are trained here.
deepspeed ../omni_speech/train/train_minicpmo.py \
    --deepspeed zero2.json \
    --model_name_or_path $MODEL \
    --tokenizer_path $TOKENIZER_PATH \
    --llm_type $LLM_TYPE \
    --data_path $DATA \
    --eval_data_path $EVAL_DATA \
    --remove_unused_columns false \
    --label_names "labels" \
    --prediction_loss_only false \
    --bf16 true \
    --do_train \
    --do_eval \
    --tune_speech true \
    --tune_llm false \
    --model_max_length $MODEL_MAX_Length \
    --eval_steps 2000 \
    --output_dir ../checkpoints/minicpmo_sft_asr \
    --num_train_epochs 2 \
    --logging_strategy "steps" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --save_steps 5000 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --max_grad_norm 20. \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --gradient_checkpointing true
scripts/ds_config_zero2.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+
11
+ "bf16": {
12
+ "enabled": "auto"
13
+ },
14
+
15
+ "optimizer": {
16
+ "type": "AdamW",
17
+ "params": {
18
+ "lr": "auto",
19
+ "betas": "auto",
20
+ "eps": "auto",
21
+ "weight_decay": "auto"
22
+ }
23
+ },
24
+
25
+ "scheduler": {
26
+ "type": "WarmupLR",
27
+ "params": {
28
+ "warmup_min_lr": "auto",
29
+ "warmup_max_lr": "auto",
30
+ "warmup_num_steps": "auto"
31
+ }
32
+ },
33
+
34
+ "zero_optimization": {
35
+ "stage": 2,
36
+ "offload_optimizer": {
37
+ "device": "none",
38
+ "pin_memory": true
39
+ },
40
+ "allgather_partitions": true,
41
+ "allgather_bucket_size": 2e8,
42
+ "overlap_comm": true,
43
+ "reduce_scatter": true,
44
+ "reduce_bucket_size": 2e8,
45
+ "contiguous_gradients": true
46
+ },
47
+
48
+ "gradient_accumulation_steps": "auto",
49
+ "gradient_clipping": "auto",
50
+ "steps_per_print": 100,
51
+ "train_batch_size": "auto",
52
+ "train_micro_batch_size_per_gpu": "auto",
53
+ "wall_clock_breakdown": false
54
+ }
scripts/ds_config_zero3.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+
23
+ "scheduler": {
24
+ "type": "WarmupLR",
25
+ "params": {
26
+ "warmup_min_lr": "auto",
27
+ "warmup_max_lr": "auto",
28
+ "warmup_num_steps": "auto"
29
+ }
30
+ },
31
+
32
+ "zero_optimization": {
33
+ "stage": 3,
34
+ "offload_optimizer": {
35
+ "device": "none",
36
+ "pin_memory": true
37
+ },
38
+ "offload_param": {
39
+ "device": "none",
40
+ "pin_memory": true
41
+ },
42
+ "overlap_comm": true,
43
+ "contiguous_gradients": true,
44
+ "sub_group_size": 1e9,
45
+ "reduce_bucket_size": "auto",
46
+ "stage3_prefetch_bucket_size": "auto",
47
+ "stage3_param_persistence_threshold": "auto",
48
+ "stage3_max_live_parameters": 1e9,
49
+ "stage3_max_reuse_distance": 1e9,
50
+ "stage3_gather_16bit_weights_on_model_save": true
51
+ },
52
+
53
+ "gradient_accumulation_steps": "auto",
54
+ "gradient_clipping": "auto",
55
+ "steps_per_print": 100,
56
+ "train_batch_size": "auto",
57
+ "train_micro_batch_size_per_gpu": "auto",
58
+ "wall_clock_breakdown": false
59
+ }
scripts/export.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# export.sh — run omni_speech/train/export.py through the deepspeed launcher to
# export/process a trained speech-decoder checkpoint. Training-style flags are
# passed because export.py reuses the training argument parser; --output_dir
# points at a throwaway location (../checkpoints/tmp).

MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc_speech_decoder_fixed_all/checkpoint-4000
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
PROMPT_VERSION=qwen
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/moss/moss_100K_phase3_tgt_units_processed.jsonl
# DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250103.jsonl
CACHE_DIR="../output/cached_sft_speech_decoder_20250103"

# NOTE(review): "--deepspeed zero2.json" is resolved relative to the working
# directory — confirm the config file exists there (repo ships scripts/ds_config_zero2.json).
# --master_port 29501 avoids clashing with a concurrently running default launcher.
deepspeed --master_port 29501 ../omni_speech/train/export.py \
    --deepspeed zero2.json \
    --model_name_or_path $MODEL_PATH \
    --version $PROMPT_VERSION \
    --data_path $DATA_PATH \
    --cache_dir $CACHE_DIR \
    --speech_encoder $SPEECH_ENCODER \
    --mel_size 80 \
    --speech_encoder_hidden_size 1024 \
    --speech_encoder_type whisper \
    --tune_speech_generator_only True \
    --bf16 True \
    --output_dir ../checkpoints/tmp \
    --num_train_epochs 8 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 1e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --logging_steps 10 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 8 \
    --has_tgt_units True
scripts/finetune.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# finetune.sh — full SFT of the Whisper-medium + Qwen2.5-3B omni-speech model,
# starting from a base LLM checkpoint and a pretrained speech projector
# (SPEECH_ADAPTER produced by the ASR pretraining stage).

MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
SPEECH_ADAPTER=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr/speech_projector.bin
PROMPT_VERSION=qwen
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/train_20250112_fc_mixed_vfva_text_fake_audios.jsonl
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250112_fc_mixed_vfva_text_fake_audios.jsonl
CACHE_DIR="../output/cached_sft_20250112"

# NOTE(review): "--deepspeed zero2.json" is a working-directory-relative path —
# confirm it exists where the script is launched (repo ships scripts/ds_config_zero2.json).
deepspeed ../omni_speech/train/train_mem.py \
    --deepspeed zero2.json \
    --model_name_or_path $MODEL_PATH \
    --version $PROMPT_VERSION \
    --data_path $DATA_PATH \
    --dev_path $DEV_PATH \
    --cache_dir $CACHE_DIR \
    --speech_encoder $SPEECH_ENCODER \
    --mel_size 80 \
    --speech_encoder_hidden_size 1024 \
    --speech_encoder_type whisper \
    --pretrain_speech_projector $SPEECH_ADAPTER \
    --bf16 True \
    --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc-mixed-vfva-text \
    --num_train_epochs 2 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --eval_steps 2000 \
    --save_steps 6000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 8192 \
    --gradient_checkpointing True \
    --dataloader_num_workers 8
scripts/finetune_llm_speech_decoder.sh ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# finetune_llm_speech_decoder.sh — joint SFT of the LLM together with the speech
# decoder on target-unit data (has_tgt_units=True, CTC auxiliary loss weighted
# by ctc_loss_weight). The commented-out second invocation below is an
# alternate, speech-decoder-only configuration kept for reference.

# it currently supports for batch = 1 only.

MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
SPEECH_ADAPTER=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr/speech_projector.bin
PROMPT_VERSION=qwen
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/train_20250106_fc_mixed_tgt_units.jsonl
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250106_fc_mixed_tgt_units.jsonl
CACHE_DIR="../output/cached_sft_speech_decoder_all_20250103"

# NOTE(review): "--deepspeed zero2.json" is a working-directory-relative path —
# confirm it exists where the script is launched (repo ships scripts/ds_config_zero2.json).
deepspeed ../omni_speech/train/train_mem.py \
    --deepspeed zero2.json \
    --model_name_or_path $MODEL_PATH \
    --version $PROMPT_VERSION \
    --data_path $DATA_PATH \
    --dev_path $DEV_PATH \
    --cache_dir $CACHE_DIR \
    --speech_encoder $SPEECH_ENCODER \
    --mel_size 80 \
    --speech_encoder_hidden_size 1024 \
    --speech_encoder_type whisper \
    --pretrain_speech_projector $SPEECH_ADAPTER \
    --bf16 True \
    --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc_speech_decoder_fixed_all \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --eval_steps 2000 \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 1024 \
    --gradient_checkpointing True \
    --dataloader_num_workers 8 \
    --has_tgt_units True \
    --ctc_loss_weight 2.0


# MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc
# SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
# PROMPT_VERSION=qwen
# DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/moss/moss_100K_phase3_tgt_units_processed.jsonl
# # DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250106_fc_mixed_tgt_units.jsonl
# CACHE_DIR="../output/cached_sft_speech_decoder_all_20250103"

# deepspeed ../omni_speech/train/train_mem.py \
#     --deepspeed zero2.json \
#     --model_name_or_path $MODEL_PATH \
#     --version $PROMPT_VERSION \
#     --data_path $DATA_PATH \
#     --cache_dir $CACHE_DIR \
#     --speech_encoder $SPEECH_ENCODER \
#     --mel_size 80 \
#     --speech_encoder_hidden_size 1024 \
#     --speech_encoder_type whisper \
#     --bf16 True \
#     --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc_speech_decoder_all \
#     --num_train_epochs 5 \
#     --per_device_train_batch_size 1 \
#     --per_device_eval_batch_size 1 \
#     --gradient_accumulation_steps 4 \
#     --evaluation_strategy "no" \
#     --save_strategy "steps" \
#     --save_steps 10000 \
#     --save_total_limit 1 \
#     --learning_rate 1e-4 \
#     --weight_decay 0. \
#     --warmup_ratio 0.03 \
#     --logging_steps 1 \
#     --tf32 True \
#     --model_max_length 2048 \
#     --gradient_checkpointing True \
#     --dataloader_num_workers 8 \
#     --has_tgt_units True \
#     --ctc_loss_weight 10.0
scripts/finetune_lora.sh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# finetune_lora.sh — LoRA SFT variant (--lora_enable True) of the omni-speech
# model; otherwise mirrors finetune.sh, using the pretrained speech projector
# and a small train/dev subset (train_tmp/dev_tmp).

MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
SPEECH_ADAPTER=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr/speech_projector.bin
PROMPT_VERSION=qwen
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/train_tmp.jsonl
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_tmp.jsonl
CACHE_DIR="../output/cached_sft"

# NOTE(review): "--deepspeed zero2.json" is a working-directory-relative path —
# confirm it exists where the script is launched (repo ships scripts/ds_config_zero2.json).
deepspeed ../omni_speech/train/train_mem.py \
    --deepspeed zero2.json \
    --lora_enable True \
    --model_name_or_path $MODEL_PATH \
    --version $PROMPT_VERSION \
    --data_path $DATA_PATH \
    --dev_path $DEV_PATH \
    --cache_dir $CACHE_DIR \
    --speech_encoder $SPEECH_ENCODER \
    --mel_size 80 \
    --speech_encoder_hidden_size 1024 \
    --speech_encoder_type whisper \
    --pretrain_speech_projector $SPEECH_ADAPTER \
    --bf16 True \
    --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-lora \
    --num_train_epochs 18 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --eval_steps 1000 \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --optim adamw_torch \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 8
scripts/finetune_minicpmo.sh ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# finetune_minicpmo.sh — SFT of MiniCPM-o on mixed function-calling/text data,
# starting from the ASR-fine-tuned checkpoint; trains both speech modules and
# the LLM (tune_speech=true, tune_llm=true).

# GPUS_PER_NODE=8
# NNODES=1
# NODE_RANK=0
# MASTER_ADDR=localhost
# MASTER_PORT=6001

MODEL="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/minicpmo_sft_asr"
TOKENIZER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/sft/train_20250219_fc_mixed_text_filter_a_um.jsonl"
EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/sft/dev_20250219_fc_mixed_text_filter_a_um.jsonl"

# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
# if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
LLM_TYPE="qwen"
MODEL_MAX_Length=8192 # if conduct multi-images sft, please set MODEL_MAX_Length=4096


# DISTRIBUTED_ARGS="
#     --nproc_per_node $GPUS_PER_NODE \
#     --nnodes $NNODES \
#     --node_rank $NODE_RANK \
#     --master_addr $MASTER_ADDR \
#     --master_port $MASTER_PORT
# "

# NOTE(review): save_strategy is "no" while save_steps/save_total_limit are set,
# so no intermediate checkpoints will be written — confirm this is intended.
# "--deepspeed zero2.json" is a working-directory-relative path (repo ships
# scripts/ds_config_zero2.json).
deepspeed ../omni_speech/train/train_minicpmo.py \
    --deepspeed zero2.json \
    --model_name_or_path $MODEL \
    --tokenizer_path $TOKENIZER_PATH \
    --llm_type $LLM_TYPE \
    --data_path $DATA \
    --eval_data_path $EVAL_DATA \
    --remove_unused_columns false \
    --label_names "labels" \
    --prediction_loss_only false \
    --bf16 true \
    --do_train \
    --do_eval \
    --tune_speech true \
    --tune_llm true \
    --model_max_length $MODEL_MAX_Length \
    --eval_steps 1000 \
    --output_dir ../checkpoints/minicpmo_sft_vi_fc_fixed \
    --num_train_epochs 1 \
    --logging_strategy "steps" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "steps" \
    --save_strategy "no" \
    --save_steps 4000 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --max_grad_norm 20. \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --gradient_checkpointing true
scripts/finetune_minicpmo_asr.sh ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# finetune_minicpmo_asr.sh — ASR adaptation stage for MiniCPM-o: starts from the
# base MiniCPM-o-2_6 model and trains only the speech modules
# (tune_speech=true, tune_llm=false) on mixed ASR data.

# GPUS_PER_NODE=8
# NNODES=1
# NODE_RANK=0
# MASTER_ADDR=localhost
# MASTER_PORT=6001

MODEL="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/train_asr_mixed_500k.jsonl"
EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/dev_asr_mixed.jsonl"

# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
# if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
LLM_TYPE="qwen"
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096


# DISTRIBUTED_ARGS="
#     --nproc_per_node $GPUS_PER_NODE \
#     --nnodes $NNODES \
#     --node_rank $NODE_RANK \
#     --master_addr $MASTER_ADDR \
#     --master_port $MASTER_PORT
# "

# NOTE(review): "--deepspeed zero2.json" is a working-directory-relative path —
# confirm it exists where the script is launched (repo ships scripts/ds_config_zero2.json).
deepspeed ../omni_speech/train/train_minicpmo.py \
    --deepspeed zero2.json \
    --model_name_or_path $MODEL \
    --llm_type $LLM_TYPE \
    --data_path $DATA \
    --eval_data_path $EVAL_DATA \
    --remove_unused_columns false \
    --label_names "labels" \
    --prediction_loss_only false \
    --bf16 true \
    --do_train \
    --do_eval \
    --tune_speech true \
    --tune_llm false \
    --model_max_length $MODEL_MAX_Length \
    --eval_steps 4000 \
    --output_dir ../checkpoints/minicpmo_sft_asr_new \
    --num_train_epochs 1 \
    --logging_strategy "steps" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --save_steps 10000 \
    --save_total_limit 1 \
    --learning_rate 2e-4 \
    --max_grad_norm 20. \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --gradient_checkpointing true
scripts/finetune_speech_decoder.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# finetune_speech_decoder.sh — trains ONLY the speech generator/decoder
# (--tune_speech_generator_only True) on target-unit data, starting from an
# SFT'd omni-speech checkpoint.

# it currently supports for batch = 1 only.

MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc-mixed-vfva-text
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
PROMPT_VERSION=qwen
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/20250114_tgt_unit_preprocessed_combined_mix_text_filtered.jsonl
# DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250103.jsonl
CACHE_DIR="../output/cached_sft_speech_decoder_20250114"

# NOTE(review): save_strategy is "no" while save_steps/save_total_limit are set,
# so no intermediate checkpoints will be written — confirm this is intended.
# "--deepspeed zero2.json" is a working-directory-relative path (repo ships
# scripts/ds_config_zero2.json).
deepspeed ../omni_speech/train/train_mem.py \
    --deepspeed zero2.json \
    --model_name_or_path $MODEL_PATH \
    --version $PROMPT_VERSION \
    --data_path $DATA_PATH \
    --cache_dir $CACHE_DIR \
    --speech_encoder $SPEECH_ENCODER \
    --mel_size 80 \
    --speech_encoder_hidden_size 1024 \
    --speech_encoder_type whisper \
    --tune_speech_generator_only True \
    --bf16 True \
    --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc-mixed-vfva-text_speech-decoder \
    --num_train_epochs 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --save_steps 3000 \
    --save_total_limit 1 \
    --learning_rate 2e-4 \
    --max_grad_norm 200. \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 4096 \
    --gradient_checkpointing True \
    --dataloader_num_workers 8 \
    --has_tgt_units True
scripts/minicpmp_config.json ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "batch_vision_input": true,
3
+ "drop_vision_last_layer": false,
4
+ "image_size": 448,
5
+ "audio_chunk_length": 1.0,
6
+ "audio_config": {
7
+ "_name_or_path": "openai/whisper-medium",
8
+ "architectures": [
9
+ "MiniCPMWhisperEncoder"
10
+ ],
11
+ "begin_suppress_tokens": [
12
+ 220,
13
+ 50257
14
+ ],
15
+ "bos_token_id": 50257,
16
+ "d_model": 1024,
17
+ "decoder_attention_heads": 16,
18
+ "decoder_ffn_dim": 4096,
19
+ "decoder_layers": 24,
20
+ "decoder_start_token_id": 50258,
21
+ "encoder_attention_heads": 16,
22
+ "encoder_ffn_dim": 4096,
23
+ "encoder_layers": 24,
24
+ "eos_token_id": 50257,
25
+ "forced_decoder_ids": [
26
+ [
27
+ 1,
28
+ 50259
29
+ ],
30
+ [
31
+ 2,
32
+ 50359
33
+ ],
34
+ [
35
+ 3,
36
+ 50363
37
+ ]
38
+ ],
39
+ "max_length": 448,
40
+ "model_type": "whisper",
41
+ "num_hidden_layers": 24,
42
+ "pad_token_id": 50257,
43
+ "suppress_tokens": [
44
+ 1,
45
+ 2,
46
+ 7,
47
+ 8,
48
+ 9,
49
+ 10,
50
+ 14,
51
+ 25,
52
+ 26,
53
+ 27,
54
+ 28,
55
+ 29,
56
+ 31,
57
+ 58,
58
+ 59,
59
+ 60,
60
+ 61,
61
+ 62,
62
+ 63,
63
+ 90,
64
+ 91,
65
+ 92,
66
+ 93,
67
+ 359,
68
+ 503,
69
+ 522,
70
+ 542,
71
+ 873,
72
+ 893,
73
+ 902,
74
+ 918,
75
+ 922,
76
+ 931,
77
+ 1350,
78
+ 1853,
79
+ 1982,
80
+ 2460,
81
+ 2627,
82
+ 3246,
83
+ 3253,
84
+ 3268,
85
+ 3536,
86
+ 3846,
87
+ 3961,
88
+ 4183,
89
+ 4667,
90
+ 6585,
91
+ 6647,
92
+ 7273,
93
+ 9061,
94
+ 9383,
95
+ 10428,
96
+ 10929,
97
+ 11938,
98
+ 12033,
99
+ 12331,
100
+ 12562,
101
+ 13793,
102
+ 14157,
103
+ 14635,
104
+ 15265,
105
+ 15618,
106
+ 16553,
107
+ 16604,
108
+ 18362,
109
+ 18956,
110
+ 20075,
111
+ 21675,
112
+ 22520,
113
+ 26130,
114
+ 26161,
115
+ 26435,
116
+ 28279,
117
+ 29464,
118
+ 31650,
119
+ 32302,
120
+ 32470,
121
+ 36865,
122
+ 42863,
123
+ 47425,
124
+ 49870,
125
+ 50254,
126
+ 50258,
127
+ 50358,
128
+ 50359,
129
+ 50360,
130
+ 50361,
131
+ 50362
132
+ ],
133
+ "torch_dtype": "float32"
134
+ },
135
+ "audio_pool_step": 2,
136
+ "chunk_input": true,
137
+ "model_type": "minicpmo",
138
+ "patch_size": 14,
139
+ "query_num": 64,
140
+ "slice_config": {
141
+ "max_slice_nums": 9,
142
+ "model_type": "minicpmv"
143
+ },
144
+ "slice_mode": true,
145
+ "torch_dtype": "bfloat16",
146
+ "transformers_version": "4.44.2",
147
+ "tts_config": {
148
+ "model_type": "conditional_chattts",
149
+ "llm_dim": 3584
150
+ },
151
+ "use_cache": false,
152
+ "use_image_id": true,
153
+ "vision_batch_size": 16,
154
+ "vision_config": {
155
+ "hidden_size": 1152,
156
+ "image_size": 980,
157
+ "intermediate_size": 4304,
158
+ "model_type": "siglip_vision_model",
159
+ "num_attention_heads": 16,
160
+ "num_hidden_layers": 27,
161
+ "patch_size": 14
162
+ }
163
+ }
scripts/pretrain_minicpmo_test.sh ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # GPUS_PER_NODE=8
4
+ # NNODES=1
5
+ # NODE_RANK=0
6
+ # MASTER_ADDR=localhost
7
+ # MASTER_PORT=6001
8
+
9
+ # MODEL="/data1/speech/anhnmt2/cuongnm/EOT/Qwen2.5-0.5B-Instruct"
10
+ PRETRAINED_LLM="/data1/speech/anhnmt2/cuongnm/EOT/Qwen2.5-0.5B-Instruct"
11
+ MODEL="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct"
12
+ # PRETRAINED_LLM="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct"
13
+ TOKENIZER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
14
+ AUDIO_ENCODER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
15
+ # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
16
+ # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
17
+ # See the section for finetuning in README for more information.
18
+ # DATA="/data1/speech/anhnmt2/cuongnm/datasets/asr/train_asr_mixed_balanced_1M5_train.json "
19
+ # EVAL_DATA="/data1/speech/anhnmt2/cuongnm/datasets/asr/train_asr_mixed_balanced_1M5_dev.json "
20
+ # DATA="/data1/speech/anhnmt2/dataset/s2s/english/minicpmo/train_asr_eng_100000_new_dataloader.jsonl"
21
+ # EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/english/minicpmo/dev_asr_eng_1000_new_dataloader.jsonl"
22
+ DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/train_asr_mixed_500k.jsonl"
23
+ EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/dev_asr_mixed.jsonl"
24
+ CONFIG_PATH="minicpmp_config.json"
25
+ AUGMENT_PATH="/data1/speech/anhnmt2/dataset/s2s/augment/noise_list_non_speech.txt"
26
+
27
+ # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
28
+ # if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
29
+ LLM_TYPE="qwen"
30
+ MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
31
+ CACHE_DIR="../output/cached_sft_20252502"
32
+
33
+
34
+ # DISTRIBUTED_ARGS="
35
+ # --nproc_per_node $GPUS_PER_NODE \
36
+ # --nnodes $NNODES \
37
+ # --node_rank $NODE_RANK \
38
+ # --master_addr $MASTER_ADDR \
39
+ # --master_port $MASTER_PORT
40
+ # "
41
+ DEEPSPEED_CMD="/home/anhnmt2/.local/bin/deepspeed"
42
+
43
+ # Kiểm tra file thực thi DeepSpeed
44
+ if [ ! -x "$DEEPSPEED_CMD" ]; then
45
+ echo "Error: DeepSpeed executable not found at $DEEPSPEED_CMD."
46
+ echo "Try reinstalling with: pip install deepspeed"
47
+ exit 1
48
+ fi
49
+
50
+
51
+ CUDA_LAUNCH_BLOCKING=1 "$DEEPSPEED_CMD" --master_port 29501 ../omni_speech/train/train_minicpmo_test.py \
52
+ --deepspeed zero2.json \
53
+ --model_name_or_path $MODEL \
54
+ --pretrained_llm_path $PRETRAINED_LLM \
55
+ --tokenizer_path $TOKENIZER_PATH \
56
+ --cache_dir $CACHE_DIR \
57
+ --audio_encoder_path $AUDIO_ENCODER_PATH \
58
+ --llm_type $LLM_TYPE \
59
+ --data_path $DATA \
60
+ --eval_data_path $EVAL_DATA \
61
+ --config_path $CONFIG_PATH \
62
+ --remove_unused_columns false \
63
+ --prediction_loss_only false \
64
+ --bf16 true \
65
+ --do_train \
66
+ --do_eval \
67
+ --tune_speech false \
68
+ --tune_llm false \
69
+ --model_max_length $MODEL_MAX_Length \
70
+ --eval_steps 3000 \
71
+ --output_dir ../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector \
72
+ --num_train_epochs 3 \
73
+ --logging_strategy "steps" \
74
+ --per_device_train_batch_size 8 \
75
+ --per_device_eval_batch_size 8 \
76
+ --gradient_accumulation_steps 4 \
77
+ --evaluation_strategy "steps" \
78
+ --save_strategy "steps" \
79
+ --save_steps 5000 \
80
+ --save_total_limit 1 \
81
+ --learning_rate 5e-5 \
82
+ --weight_decay 0. \
83
+ --warmup_ratio 0.03 \
84
+ --lr_scheduler_type "cosine" \
85
+ --logging_steps 1 \
86
+ --tf32 true \
87
+ --gradient_checkpointing true
88
+ # --augment_prob 0.2 \
89
+ # --augment_path $AUGMENT_PATH
scripts/pretrained.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct
4
+ SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
5
+ PROMPT_VERSION=qwen
6
+ DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/asr/dataset/train_asr_eng_5M.jsonl
7
+ DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/asr/dataset/dev_asr_libri_spgi.jsonl
8
+ CACHE_DIR="../output/cached_asr_full"
9
+ AUGMENT_PATH="/data1/speech/anhnmt2/dataset/s2s/augment/noise_list_non_speech.txt"
10
+
11
+ deepspeed ../omni_speech/train/train_mem.py \
12
+ --deepspeed zero2.json \
13
+ --model_name_or_path $MODEL_PATH \
14
+ --version $PROMPT_VERSION \
15
+ --data_path $DATA_PATH \
16
+ --dev_path $DEV_PATH \
17
+ --cache_dir $CACHE_DIR \
18
+ --speech_encoder $SPEECH_ENCODER \
19
+ --mel_size 80 \
20
+ --speech_encoder_hidden_size 1024 \
21
+ --speech_encoder_type whisper \
22
+ --bf16 True \
23
+ --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr-5M \
24
+ --num_train_epochs 4 \
25
+ --tune_speech_projector True \
26
+ --per_device_train_batch_size 16 \
27
+ --per_device_eval_batch_size 4 \
28
+ --gradient_accumulation_steps 2 \
29
+ --evaluation_strategy "steps" \
30
+ --save_strategy "steps" \
31
+ --eval_steps 2000 \
32
+ --save_steps 2000 \
33
+ --save_total_limit 1 \
34
+ --learning_rate 1e-3 \
35
+ --weight_decay 0. \
36
+ --warmup_ratio 0.03 \
37
+ --lr_scheduler_type "cosine" \
38
+ --logging_steps 1 \
39
+ --tf32 True \
40
+ --model_max_length 4096 \
41
+ --gradient_checkpointing True \
42
+ --dataloader_num_workers 8
43
+ # --augment_prob 0.2 \
44
+ # --augment_path $AUGMENT_PATH \
scripts/pretrained_minicpmo.sh ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # GPUS_PER_NODE=8
4
+ # NNODES=1
5
+ # NODE_RANK=0
6
+ # MASTER_ADDR=localhost
7
+ # MASTER_PORT=6001
8
+
9
+ MODEL="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct"
10
+ PRETRAINED_LLM="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct"
11
+ TOKENIZER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
12
+ AUDIO_ENCODER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
13
+ # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
14
+ # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
15
+ # See the section for finetuning in README for more information.
16
+ DATA="/data1/speech/anhnmt2/dataset/s2s/english/minicpmo/train_asr_eng_100000_new_dataloader.jsonl"
17
+ EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/english/minicpmo/dev_asr_eng_1000_new_dataloader.jsonl"
18
+ CONFIG_PATH="minicpmp_config.json"
19
+ AUGMENT_PATH="/data1/speech/anhnmt2/dataset/s2s/augment/noise_list_non_speech.txt"
20
+
21
+ # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
22
+ # if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
23
+ LLM_TYPE="qwen"
24
+ MODEL_MAX_Length=4096 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
25
+ CACHE_DIR="../output/cached_sft_20252502"
26
+
27
+
28
+ # DISTRIBUTED_ARGS="
29
+ # --nproc_per_node $GPUS_PER_NODE \
30
+ # --nnodes $NNODES \
31
+ # --node_rank $NODE_RANK \
32
+ # --master_addr $MASTER_ADDR \
33
+ # --master_port $MASTER_PORT
34
+ # "
35
+
36
+ deepspeed --master_port 29501 ../omni_speech/train/train_minicpmo.py \
37
+ --deepspeed zero2.json \
38
+ --model_name_or_path $MODEL \
39
+ --pretrained_llm_path $PRETRAINED_LLM \
40
+ --tokenizer_path $TOKENIZER_PATH \
41
+ --cache_dir $CACHE_DIR \
42
+ --audio_encoder_path $AUDIO_ENCODER_PATH \
43
+ --llm_type $LLM_TYPE \
44
+ --data_path $DATA \
45
+ --eval_data_path $EVAL_DATA \
46
+ --config_path $CONFIG_PATH \
47
+ --remove_unused_columns false \
48
+ --prediction_loss_only false \
49
+ --bf16 true \
50
+ --do_train \
51
+ --do_eval \
52
+ --tune_speech true \
53
+ --tune_llm false \
54
+ --model_max_length $MODEL_MAX_Length \
55
+ --eval_steps 1000 \
56
+ --output_dir ../checkpoints/minicpmo_whisper-medium_Qwen2.5-3B_pretrained-asr \
57
+ --num_train_epochs 1 \
58
+ --logging_strategy "steps" \
59
+ --per_device_train_batch_size 1 \
60
+ --per_device_eval_batch_size 1 \
61
+ --gradient_accumulation_steps 4 \
62
+ --evaluation_strategy "steps" \
63
+ --save_strategy "no" \
64
+ --save_steps 2000 \
65
+ --save_total_limit 1 \
66
+ --learning_rate 2e-4 \
67
+ --weight_decay 0. \
68
+ --warmup_ratio 0.03 \
69
+ --lr_scheduler_type "cosine" \
70
+ --logging_steps 1 \
71
+ --tf32 true \
72
+ --gradient_checkpointing true
73
+ # --augment_prob 0.2 \
74
+ # --augment_path $AUGMENT_PATH
scripts/test_llama.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Llama-3.1-8B-Instruct
4
+ SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
5
+ PROMPT_VERSION=llama_3
6
+ DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/new/train_asr_eng_50000.jsonl
7
+ DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/new/dev_asr_eng_5000.jsonl
8
+ CACHE_DIR="../output/cached_asr"
9
+
10
+ deepspeed ../omni_speech/train/train.py \
11
+ --deepspeed zero2.json \
12
+ --model_name_or_path $MODEL_PATH \
13
+ --version $PROMPT_VERSION \
14
+ --data_path $DATA_PATH \
15
+ --dev_path $DEV_PATH \
16
+ --cache_dir $CACHE_DIR \
17
+ --speech_encoder $SPEECH_ENCODER \
18
+ --mel_size 80 \
19
+ --speech_encoder_hidden_size 1024 \
20
+ --speech_encoder_type whisper \
21
+ --bf16 True \
22
+ --output_dir ../checkpoints/llama-omni-pretrained-asr-test \
23
+ --num_train_epochs 10 \
24
+ --tune_speech_projector True \
25
+ --per_device_train_batch_size 4 \
26
+ --per_device_eval_batch_size 2 \
27
+ --gradient_accumulation_steps 4 \
28
+ --evaluation_strategy "steps" \
29
+ --save_strategy "steps" \
30
+ --eval_steps 2000 \
31
+ --save_steps 2000 \
32
+ --save_total_limit 1 \
33
+ --learning_rate 1e-3 \
34
+ --optim adamw_torch \
35
+ --weight_decay 0. \
36
+ --warmup_ratio 0.03 \
37
+ --logging_steps 1 \
38
+ --tf32 True \
39
+ --model_max_length 2048 \
40
+ --gradient_checkpointing True \
41
+ --dataloader_num_workers 8
scripts/test_qwen.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-1.5B-Instruct
4
+ SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
5
+ PROMPT_VERSION=qwen
6
+ DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/new/dev_asr_eng_5000_multiturn.jsonl
7
+ DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/new/dev_asr_eng_5000_multiturn.jsonl
8
+ CACHE_DIR="../output/cached_asr"
9
+
10
+ deepspeed ../omni_speech/train/train_multiturn.py \
11
+ --deepspeed zero2.json \
12
+ --model_name_or_path $MODEL_PATH \
13
+ --version $PROMPT_VERSION \
14
+ --data_path $DATA_PATH \
15
+ --dev_path $DEV_PATH \
16
+ --cache_dir $CACHE_DIR \
17
+ --speech_encoder $SPEECH_ENCODER \
18
+ --mel_size 80 \
19
+ --speech_encoder_hidden_size 1024 \
20
+ --speech_encoder_type whisper \
21
+ --bf16 True \
22
+ --output_dir ../checkpoints/llama-omni-pretrained-asr-qwen \
23
+ --num_train_epochs 10 \
24
+ --tune_speech_projector True \
25
+ --per_device_train_batch_size 4 \
26
+ --per_device_eval_batch_size 2 \
27
+ --gradient_accumulation_steps 4 \
28
+ --evaluation_strategy "steps" \
29
+ --save_strategy "steps" \
30
+ --eval_steps 2000 \
31
+ --save_steps 2000 \
32
+ --save_total_limit 1 \
33
+ --learning_rate 1e-3 \
34
+ --optim adamw_torch \
35
+ --weight_decay 0. \
36
+ --warmup_ratio 0.03 \
37
+ --logging_steps 1 \
38
+ --tf32 True \
39
+ --model_max_length 2048 \
40
+ --gradient_checkpointing True \
41
+ --dataloader_num_workers 8
scripts/wandb/debug-internal.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2025-04-10T17:19:28.842729448+07:00","level":"INFO","msg":"stream: starting","core version":"0.19.8","symlink path":"/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/run-20250410_171928-pfaibe0c/logs/debug-core.log"}
2
+ {"time":"2025-04-10T17:19:28.960322418+07:00","level":"INFO","msg":"created new stream","id":"pfaibe0c"}
3
+ {"time":"2025-04-10T17:19:28.960351593+07:00","level":"INFO","msg":"stream: started","id":"pfaibe0c"}
4
+ {"time":"2025-04-10T17:19:28.960375959+07:00","level":"INFO","msg":"writer: Do: started","stream_id":"pfaibe0c"}
5
+ {"time":"2025-04-10T17:19:28.960456552+07:00","level":"INFO","msg":"handler: started","stream_id":"pfaibe0c"}
6
+ {"time":"2025-04-10T17:19:28.961574927+07:00","level":"INFO","msg":"sender: started","stream_id":"pfaibe0c"}
7
+ {"time":"2025-04-10T17:19:29.497777718+07:00","level":"INFO","msg":"Starting system monitor"}
scripts/wandb/debug.log ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Current SDK version is 0.19.8
2
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Configure stats pid to 1734298
3
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Loading settings from /home/anhnmt2/.config/wandb/settings
4
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Loading settings from /data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/settings
5
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Loading settings from environment variables
6
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:setup_run_log_directory():647] Logging user logs to /data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/run-20250410_171928-pfaibe0c/logs/debug.log
7
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:setup_run_log_directory():648] Logging internal logs to /data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/run-20250410_171928-pfaibe0c/logs/debug-internal.log
8
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:init():761] calling init triggers
9
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:init():766] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:init():784] starting backend
12
+ 2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:init():788] sending inform_init request
13
+ 2025-04-10 17:19:28,834 INFO MainThread:1734298 [backend.py:_multiprocessing_setup():101] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
14
+ 2025-04-10 17:19:28,834 INFO MainThread:1734298 [wandb_init.py:init():798] backend started and connected
15
+ 2025-04-10 17:19:28,836 INFO MainThread:1734298 [wandb_init.py:init():891] updated telemetry
16
+ 2025-04-10 17:19:28,852 INFO MainThread:1734298 [wandb_init.py:init():915] communicating run to backend with 90.0 second timeout
17
+ 2025-04-10 17:19:29,493 INFO MainThread:1734298 [wandb_init.py:init():990] starting run threads in backend
18
+ 2025-04-10 17:19:29,890 INFO MainThread:1734298 [wandb_run.py:_console_start():2375] atexit reg
19
+ 2025-04-10 17:19:29,891 INFO MainThread:1734298 [wandb_run.py:_redirect():2227] redirect: wrap_raw
20
+ 2025-04-10 17:19:29,891 INFO MainThread:1734298 [wandb_run.py:_redirect():2292] Wrapping output streams.
21
+ 2025-04-10 17:19:29,891 INFO MainThread:1734298 [wandb_run.py:_redirect():2315] Redirects installed.
22
+ 2025-04-10 17:19:29,895 INFO MainThread:1734298 [wandb_init.py:init():1032] run started, returning control to user process
23
+ 2025-04-10 17:19:29,898 INFO MainThread:1734298 [wandb_run.py:_config_callback():1261] config_cb None None {'use_cache': False, 'query_num': 64, 'image_size': 448, 'drop_vision_last_layer': False, 'batch_vision_input': True, 'use_image_id': True, 'vision_batch_size': 16, 'audio_pool_step': 2, 'audio_chunk_length': 1.0, 'stream_input': False, 'init_vision': False, 'init_audio': True, 'init_tts': False, 'processor_path': '/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6', 'pretrained_encoder_path': '/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6', 'pretrained_llm_path': '/data1/speech/anhnmt2/cuongnm/EOT/Qwen2.5-0.5B-Instruct', 'chunk_input': True, 'slice_config': {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': None, 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': None, 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 
'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'model_type': 'minicpmv', 'patch_size': 14, 'max_slice_nums': 9, 'scale_resolution': 448}, 'slice_mode': True, 'vision_config': {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': None, 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': None, 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'model_type': 'siglip_vision_model', 'hidden_size': 1152, 'intermediate_size': 4304, 'num_hidden_layers': 27, 'num_attention_heads': 16, 'num_channels': 3, 'patch_size': 14, 'image_size': 980, 'attention_dropout': 0.0, 'layer_norm_eps': 1e-06, 'hidden_act': 'gelu_pytorch_tanh'}, 'audio_config': {'vocab_size': 51865, 
'num_mel_bins': 80, 'd_model': 1024, 'encoder_layers': 24, 'encoder_attention_heads': 16, 'decoder_layers': 24, 'decoder_attention_heads': 16, 'decoder_ffn_dim': 4096, 'encoder_ffn_dim': 4096, 'dropout': 0.0, 'attention_dropout': 0.0, 'activation_dropout': 0.0, 'activation_function': 'gelu', 'init_std': 0.02, 'encoder_layerdrop': 0.0, 'decoder_layerdrop': 0.0, 'use_cache': True, 'num_hidden_layers': 24, 'scale_embedding': False, 'max_source_positions': 1500, 'max_target_positions': 448, 'classifier_proj_size': 256, 'use_weighted_layer_sum': False, 'apply_spec_augment': False, 'mask_time_prob': 0.05, 'mask_time_length': 10, 'mask_time_min_masks': 2, 'mask_feature_prob': 0.0, 'mask_feature_length': 10, 'mask_feature_min_masks': 0, 'median_filter_width': 7, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float32', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': True, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 448, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 
10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257], 'architectures': ['MiniCPMWhisperEncoder'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 50257, 'pad_token_id': 50257, 'eos_token_id': 50257, 'sep_token_id': None, 'decoder_start_token_id': 50258, 'task_specific_params': None, 'problem_type': None, '_name_or_path': 'openai/whisper-medium', 'forced_decoder_ids': [[1, 50259], [2, 50359], [3, 50363]], 'model_type': 'whisper'}, 'tts_config': {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': None, 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': True, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 20, 'top_p': 0.7, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': None, 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 
'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'model_type': 'conditional_chattts', 'llm_dim': 3584, 'hidden_size': 768, 'intermediate_size': 3072, 'num_attention_heads': 12, 'num_hidden_layers': 20, 'max_position_embeddings': 4096, 'num_audio_tokens': 626, 'num_text_tokens': 21178, 'num_mel_bins': 100, 'num_vq': 4, 'use_speaker_embedding': True, 'use_llm_hidden_state': False, 'spk_emb_token_id': 21143, 'num_spk_embs': 1, 'audio_bos_token_id': 21132, 'text_eos_token_id': 21133, 'use_text': True, 'streaming': True, 'streaming_text_chunk_size': 10, 'streaming_text_reserved_len': 300, 'streaming_audio_chunk_size': 50, 'attn_implementation': 'sdpa', 'use_mlp': True, 'aug_loss_weight': True}, 'patch_size': 14, 'vocab_size': 152064, 'max_position_embeddings': 32768, 'hidden_size': 3584, 'intermediate_size': 18944, 'num_hidden_layers': 28, 'num_attention_heads': 28, 'use_sliding_window': False, 'sliding_window': None, 'max_window_layers': 28, 'num_key_value_heads': 4, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-06, 'rope_theta': 1000000.0, 'rope_scaling': None, 'attention_dropout': 0.0, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float32', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 2048, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 
'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['Qwen2ForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 151643, 'pad_token_id': None, 'eos_token_id': 151645, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct', 'transformers_version': '4.45.0', 'model_type': 'minicpmo', 'output_dir': '../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector', 'overwrite_output_dir': False, 'do_train': True, 'do_eval': True, 'do_predict': False, 'eval_strategy': 'steps', 'prediction_loss_only': False, 'per_device_train_batch_size': 8, 'per_device_eval_batch_size': 8, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 4, 'eval_accumulation_steps': None, 'eval_delay': 0, 'torch_empty_cache_steps': None, 'learning_rate': 5e-05, 'weight_decay': 0.0, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 3.0, 'max_steps': -1, 'lr_scheduler_type': 'cosine', 'lr_scheduler_kwargs': {}, 'warmup_ratio': 0.03, 'warmup_steps': 0, 'log_level': 'passive', 'log_level_replica': 'warning', 'log_on_each_node': True, 'logging_dir': '../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector/runs/Apr10_17-18-52_dgx-a100-5', 'logging_strategy': 'steps', 'logging_first_step': False, 'logging_steps': 1.0, 'logging_nan_inf_filter': True, 'save_strategy': 'steps', 'save_steps': 5000, 'save_total_limit': 1, 'save_safetensors': True, 
'save_on_each_node': False, 'save_only_model': False, 'restore_callback_states_from_checkpoint': False, 'no_cuda': False, 'use_cpu': False, 'use_mps_device': False, 'seed': 42, 'data_seed': None, 'jit_mode_eval': False, 'use_ipex': False, 'bf16': True, 'fp16': False, 'fp16_opt_level': 'O1', 'half_precision_backend': 'auto', 'bf16_full_eval': False, 'fp16_full_eval': False, 'tf32': True, 'local_rank': 0, 'ddp_backend': None, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 3000, 'dataloader_num_workers': 0, 'dataloader_prefetch_factor': None, 'past_index': -1, 'run_name': '../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector', 'disable_tqdm': False, 'remove_unused_columns': False, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'fsdp': [], 'fsdp_min_num_params': 0, 'fsdp_config': {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, 'fsdp_transformer_layer_cls_to_wrap': None, 'accelerator_config': {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}, 'deepspeed': 'zero2.json', 'label_smoothing_factor': 0.0, 'optim': 'adamw_torch', 'optim_args': None, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'ddp_bucket_cap_mb': None, 'ddp_broadcast_buffers': None, 'dataloader_pin_memory': True, 'dataloader_persistent_workers': False, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': False, 'resume_from_checkpoint': None, 'hub_model_id': None, 'hub_strategy': 'every_save', 'hub_token': '<HUB_TOKEN>', 'hub_private_repo': False, 'hub_always_push': False, 'gradient_checkpointing': True, 'gradient_checkpointing_kwargs': 
{'use_reentrant': False}, 'include_inputs_for_metrics': False, 'eval_do_concat_batches': True, 'fp16_backend': 'auto', 'evaluation_strategy': 'steps', 'push_to_hub_model_id': None, 'push_to_hub_organization': None, 'push_to_hub_token': '<PUSH_TO_HUB_TOKEN>', 'mp_parameters': '', 'auto_find_batch_size': False, 'full_determinism': False, 'torchdynamo': None, 'ray_scope': 'last', 'ddp_timeout': 1800, 'torch_compile': False, 'torch_compile_backend': None, 'torch_compile_mode': None, 'dispatch_batches': None, 'split_batches': None, 'include_tokens_per_second': False, 'include_num_input_tokens_seen': False, 'neftune_noise_alpha': None, 'optim_target_modules': None, 'batch_eval_metrics': False, 'eval_on_start': False, 'use_liger_kernel': False, 'eval_use_gather_object': False, 'cache_dir': '../output/cached_sft_20252502', 'model_max_length': 2048, 'tune_vision': True, 'tune_speech': False, 'tune_llm': False, 'llm_type': 'qwen', 'use_lora': False, 'max_slice_nums': 9, 'config_path': 'minicpmp_config.json', 'init_speech': True}
24
+ 2025-04-10 17:19:29,901 INFO MainThread:1734298 [wandb_config.py:__setitem__():154] config set model/num_parameters = 802971264 - <bound method Run._config_callback of <wandb.sdk.wandb_run.Run object at 0x1553c5eda240>>
25
+ 2025-04-10 17:19:29,901 INFO MainThread:1734298 [wandb_run.py:_config_callback():1261] config_cb model/num_parameters 802971264 None
scripts/wandb/latest-run/files/output.log ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 0%| | 0/43233 [00:00<?, ?it/s]DEBUG:numba.core.byteflow:bytecode dump:
2
+ > 0 NOP(arg=None, lineno=1141)
3
+ 2 RESUME(arg=0, lineno=1141)
4
+ 4 LOAD_FAST(arg=0, lineno=1144)
5
+ 6 LOAD_CONST(arg=1, lineno=1144)
6
+ 8 BINARY_SUBSCR(arg=None, lineno=1144)
7
+ 12 STORE_FAST(arg=3, lineno=1144)
8
+ 14 LOAD_FAST(arg=1, lineno=1145)
9
+ 16 UNARY_NEGATIVE(arg=None, lineno=1145)
10
+ 18 LOAD_FAST(arg=3, lineno=1145)
11
+ 20 SWAP(arg=2, lineno=1145)
12
+ 22 COPY(arg=2, lineno=1145)
13
+ 24 COMPARE_OP(arg=26, lineno=1145)
14
+ 28 POP_JUMP_IF_FALSE(arg=5, lineno=1145)
15
+ 30 LOAD_FAST(arg=1, lineno=1145)
16
+ 32 COMPARE_OP(arg=26, lineno=1145)
17
+ 36 POP_JUMP_IF_FALSE(arg=5, lineno=1145)
18
+ 38 JUMP_FORWARD(arg=2, lineno=1145)
19
+ > 40 POP_TOP(arg=None, lineno=1145)
20
+ 42 JUMP_FORWARD(arg=2, lineno=1145)
21
+ > 44 LOAD_CONST(arg=1, lineno=1146)
22
+ 46 STORE_FAST(arg=3, lineno=1146)
23
+ > 48 LOAD_FAST(arg=0, lineno=1148)
24
+ 50 LOAD_CONST(arg=2, lineno=1148)
25
+ 52 BINARY_SUBSCR(arg=None, lineno=1148)
26
+ 56 STORE_FAST(arg=4, lineno=1148)
27
+ 58 LOAD_FAST(arg=1, lineno=1149)
28
+ 60 UNARY_NEGATIVE(arg=None, lineno=1149)
29
+ 62 LOAD_FAST(arg=4, lineno=1149)
30
+ 64 SWAP(arg=2, lineno=1149)
31
+ 66 COPY(arg=2, lineno=1149)
32
+ 68 COMPARE_OP(arg=26, lineno=1149)
33
+ 72 POP_JUMP_IF_FALSE(arg=5, lineno=1149)
34
+ 74 LOAD_FAST(arg=1, lineno=1149)
35
+ 76 COMPARE_OP(arg=26, lineno=1149)
36
+ 80 POP_JUMP_IF_FALSE(arg=5, lineno=1149)
37
+ 82 JUMP_FORWARD(arg=2, lineno=1149)
38
+ > 84 POP_TOP(arg=None, lineno=1149)
39
+ 86 JUMP_FORWARD(arg=2, lineno=1149)
40
+ > 88 LOAD_CONST(arg=1, lineno=1150)
41
+ 90 STORE_FAST(arg=4, lineno=1150)
42
+ > 92 LOAD_FAST(arg=2, lineno=1152)
43
+ 94 POP_JUMP_IF_FALSE(arg=43, lineno=1152)
44
+ 96 LOAD_GLOBAL(arg=1, lineno=1153)
45
+ 106 LOAD_ATTR(arg=2, lineno=1153)
46
+ 126 LOAD_FAST(arg=3, lineno=1153)
47
+ 128 CALL(arg=1, lineno=1153)
48
+ 136 LOAD_GLOBAL(arg=1, lineno=1153)
49
+ 146 LOAD_ATTR(arg=2, lineno=1153)
50
+ 166 LOAD_FAST(arg=4, lineno=1153)
51
+ 168 CALL(arg=1, lineno=1153)
52
+ 176 COMPARE_OP(arg=55, lineno=1153)
53
+ 180 RETURN_VALUE(arg=None, lineno=1153)
54
+ > 182 LOAD_GLOBAL(arg=1, lineno=1155)
55
+ 192 LOAD_ATTR(arg=4, lineno=1155)
56
+ 212 LOAD_FAST(arg=3, lineno=1155)
57
+ 214 CALL(arg=1, lineno=1155)
58
+ 222 LOAD_GLOBAL(arg=1, lineno=1155)
59
+ 232 LOAD_ATTR(arg=4, lineno=1155)
60
+ 252 LOAD_FAST(arg=4, lineno=1155)
61
+ 254 CALL(arg=1, lineno=1155)
62
+ 262 COMPARE_OP(arg=55, lineno=1155)
63
+ 266 RETURN_VALUE(arg=None, lineno=1155)
64
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
65
+ DEBUG:numba.core.byteflow:stack: []
66
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
67
+ DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1141)
68
+ DEBUG:numba.core.byteflow:stack []
69
+ DEBUG:numba.core.byteflow:dispatch pc=2, inst=RESUME(arg=0, lineno=1141)
70
+ DEBUG:numba.core.byteflow:stack []
71
+ DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=0, lineno=1144)
72
+ DEBUG:numba.core.byteflow:stack []
73
+ DEBUG:numba.core.byteflow:dispatch pc=6, inst=LOAD_CONST(arg=1, lineno=1144)
74
+ DEBUG:numba.core.byteflow:stack ['$x4.0']
75
+ DEBUG:numba.core.byteflow:dispatch pc=8, inst=BINARY_SUBSCR(arg=None, lineno=1144)
76
+ DEBUG:numba.core.byteflow:stack ['$x4.0', '$const6.1']
77
+ DEBUG:numba.core.byteflow:dispatch pc=12, inst=STORE_FAST(arg=3, lineno=1144)
78
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2']
79
+ DEBUG:numba.core.byteflow:dispatch pc=14, inst=LOAD_FAST(arg=1, lineno=1145)
80
+ DEBUG:numba.core.byteflow:stack []
81
+ DEBUG:numba.core.byteflow:dispatch pc=16, inst=UNARY_NEGATIVE(arg=None, lineno=1145)
82
+ DEBUG:numba.core.byteflow:stack ['$threshold14.3']
83
+ DEBUG:numba.core.byteflow:dispatch pc=18, inst=LOAD_FAST(arg=3, lineno=1145)
84
+ DEBUG:numba.core.byteflow:stack ['$16unary_negative.4']
85
+ DEBUG:numba.core.byteflow:dispatch pc=20, inst=SWAP(arg=2, lineno=1145)
86
+ DEBUG:numba.core.byteflow:stack ['$16unary_negative.4', '$x018.5']
87
+ DEBUG:numba.core.byteflow:dispatch pc=22, inst=COPY(arg=2, lineno=1145)
88
+ DEBUG:numba.core.byteflow:stack ['$x018.5', '$16unary_negative.4']
89
+ DEBUG:numba.core.byteflow:dispatch pc=24, inst=COMPARE_OP(arg=26, lineno=1145)
90
+ DEBUG:numba.core.byteflow:stack ['$x018.5', '$16unary_negative.4', '$x018.5']
91
+ DEBUG:numba.core.byteflow:dispatch pc=28, inst=POP_JUMP_IF_FALSE(arg=5, lineno=1145)
92
+ DEBUG:numba.core.byteflow:stack ['$x018.5', '$24compare_op.6']
93
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=30, stack=('$x018.5',), blockstack=(), npush=0), Edge(pc=40, stack=('$x018.5',), blockstack=(), npush=0)]
94
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=30 nstack_initial=1), State(pc_initial=40 nstack_initial=1)])
95
+ DEBUG:numba.core.byteflow:stack: ['$phi30.0']
96
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=30 nstack_initial=1)
97
+ DEBUG:numba.core.byteflow:dispatch pc=30, inst=LOAD_FAST(arg=1, lineno=1145)
98
+ DEBUG:numba.core.byteflow:stack ['$phi30.0']
99
+ DEBUG:numba.core.byteflow:dispatch pc=32, inst=COMPARE_OP(arg=26, lineno=1145)
100
+ DEBUG:numba.core.byteflow:stack ['$phi30.0', '$threshold30.1']
101
+ DEBUG:numba.core.byteflow:dispatch pc=36, inst=POP_JUMP_IF_FALSE(arg=5, lineno=1145)
102
+ DEBUG:numba.core.byteflow:stack ['$32compare_op.2']
103
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=38, stack=(), blockstack=(), npush=0), Edge(pc=48, stack=(), blockstack=(), npush=0)]
104
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=40 nstack_initial=1), State(pc_initial=38 nstack_initial=0), State(pc_initial=48 nstack_initial=0)])
105
+ DEBUG:numba.core.byteflow:stack: ['$phi40.0']
106
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=40 nstack_initial=1)
107
+ DEBUG:numba.core.byteflow:dispatch pc=40, inst=POP_TOP(arg=None, lineno=1145)
108
+ DEBUG:numba.core.byteflow:stack ['$phi40.0']
109
+ DEBUG:numba.core.byteflow:dispatch pc=42, inst=JUMP_FORWARD(arg=2, lineno=1145)
110
+ DEBUG:numba.core.byteflow:stack []
111
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=48, stack=(), blockstack=(), npush=0)]
112
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=38 nstack_initial=0), State(pc_initial=48 nstack_initial=0), State(pc_initial=48 nstack_initial=0)])
113
+ DEBUG:numba.core.byteflow:stack: []
114
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=38 nstack_initial=0)
115
+ DEBUG:numba.core.byteflow:dispatch pc=38, inst=JUMP_FORWARD(arg=2, lineno=1145)
116
+ DEBUG:numba.core.byteflow:stack []
117
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=44, stack=(), blockstack=(), npush=0)]
118
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=48 nstack_initial=0), State(pc_initial=48 nstack_initial=0), State(pc_initial=44 nstack_initial=0)])
119
+ DEBUG:numba.core.byteflow:stack: []
120
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=48 nstack_initial=0)
121
+ DEBUG:numba.core.byteflow:dispatch pc=48, inst=LOAD_FAST(arg=0, lineno=1148)
122
+ DEBUG:numba.core.byteflow:stack []
123
+ DEBUG:numba.core.byteflow:dispatch pc=50, inst=LOAD_CONST(arg=2, lineno=1148)
124
+ DEBUG:numba.core.byteflow:stack ['$x48.0']
125
+ DEBUG:numba.core.byteflow:dispatch pc=52, inst=BINARY_SUBSCR(arg=None, lineno=1148)
126
+ DEBUG:numba.core.byteflow:stack ['$x48.0', '$const50.1']
127
+ DEBUG:numba.core.byteflow:dispatch pc=56, inst=STORE_FAST(arg=4, lineno=1148)
128
+ DEBUG:numba.core.byteflow:stack ['$52binary_subscr.2']
129
+ DEBUG:numba.core.byteflow:dispatch pc=58, inst=LOAD_FAST(arg=1, lineno=1149)
130
+ DEBUG:numba.core.byteflow:stack []
131
+ DEBUG:numba.core.byteflow:dispatch pc=60, inst=UNARY_NEGATIVE(arg=None, lineno=1149)
132
+ DEBUG:numba.core.byteflow:stack ['$threshold58.3']
133
+ DEBUG:numba.core.byteflow:dispatch pc=62, inst=LOAD_FAST(arg=4, lineno=1149)
134
+ DEBUG:numba.core.byteflow:stack ['$60unary_negative.4']
135
+ DEBUG:numba.core.byteflow:dispatch pc=64, inst=SWAP(arg=2, lineno=1149)
136
+ DEBUG:numba.core.byteflow:stack ['$60unary_negative.4', '$x162.5']
137
+ DEBUG:numba.core.byteflow:dispatch pc=66, inst=COPY(arg=2, lineno=1149)
138
+ DEBUG:numba.core.byteflow:stack ['$x162.5', '$60unary_negative.4']
139
+ DEBUG:numba.core.byteflow:dispatch pc=68, inst=COMPARE_OP(arg=26, lineno=1149)
140
+ DEBUG:numba.core.byteflow:stack ['$x162.5', '$60unary_negative.4', '$x162.5']
141
+ DEBUG:numba.core.byteflow:dispatch pc=72, inst=POP_JUMP_IF_FALSE(arg=5, lineno=1149)
142
+ DEBUG:numba.core.byteflow:stack ['$x162.5', '$68compare_op.6']
143
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=74, stack=('$x162.5',), blockstack=(), npush=0), Edge(pc=84, stack=('$x162.5',), blockstack=(), npush=0)]
144
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=48 nstack_initial=0), State(pc_initial=44 nstack_initial=0), State(pc_initial=74 nstack_initial=1), State(pc_initial=84 nstack_initial=1)])
145
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=44 nstack_initial=0), State(pc_initial=74 nstack_initial=1), State(pc_initial=84 nstack_initial=1)])
146
+ DEBUG:numba.core.byteflow:stack: []
147
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=44 nstack_initial=0)
148
+ DEBUG:numba.core.byteflow:dispatch pc=44, inst=LOAD_CONST(arg=1, lineno=1146)
149
+ DEBUG:numba.core.byteflow:stack []
150
+ DEBUG:numba.core.byteflow:dispatch pc=46, inst=STORE_FAST(arg=3, lineno=1146)
151
+ DEBUG:numba.core.byteflow:stack ['$const44.0']
152
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=48, stack=(), blockstack=(), npush=0)]
153
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=74 nstack_initial=1), State(pc_initial=84 nstack_initial=1), State(pc_initial=48 nstack_initial=0)])
154
+ DEBUG:numba.core.byteflow:stack: ['$phi74.0']
155
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=74 nstack_initial=1)
156
+ DEBUG:numba.core.byteflow:dispatch pc=74, inst=LOAD_FAST(arg=1, lineno=1149)
157
+ DEBUG:numba.core.byteflow:stack ['$phi74.0']
158
+ DEBUG:numba.core.byteflow:dispatch pc=76, inst=COMPARE_OP(arg=26, lineno=1149)
159
+ DEBUG:numba.core.byteflow:stack ['$phi74.0', '$threshold74.1']
160
+ DEBUG:numba.core.byteflow:dispatch pc=80, inst=POP_JUMP_IF_FALSE(arg=5, lineno=1149)
161
+ DEBUG:numba.core.byteflow:stack ['$76compare_op.2']
162
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=82, stack=(), blockstack=(), npush=0), Edge(pc=92, stack=(), blockstack=(), npush=0)]
163
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=84 nstack_initial=1), State(pc_initial=48 nstack_initial=0), State(pc_initial=82 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
164
+ DEBUG:numba.core.byteflow:stack: ['$phi84.0']
165
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=84 nstack_initial=1)
166
+ DEBUG:numba.core.byteflow:dispatch pc=84, inst=POP_TOP(arg=None, lineno=1149)
167
+ DEBUG:numba.core.byteflow:stack ['$phi84.0']
168
+ DEBUG:numba.core.byteflow:dispatch pc=86, inst=JUMP_FORWARD(arg=2, lineno=1149)
169
+ DEBUG:numba.core.byteflow:stack []
170
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=92, stack=(), blockstack=(), npush=0)]
171
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=48 nstack_initial=0), State(pc_initial=82 nstack_initial=0), State(pc_initial=92 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
172
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=82 nstack_initial=0), State(pc_initial=92 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
173
+ DEBUG:numba.core.byteflow:stack: []
174
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=82 nstack_initial=0)
175
+ DEBUG:numba.core.byteflow:dispatch pc=82, inst=JUMP_FORWARD(arg=2, lineno=1149)
176
+ DEBUG:numba.core.byteflow:stack []
177
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=88, stack=(), blockstack=(), npush=0)]
178
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=92 nstack_initial=0), State(pc_initial=92 nstack_initial=0), State(pc_initial=88 nstack_initial=0)])
179
+ DEBUG:numba.core.byteflow:stack: []
180
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=92 nstack_initial=0)
181
+ DEBUG:numba.core.byteflow:dispatch pc=92, inst=LOAD_FAST(arg=2, lineno=1152)
182
+ DEBUG:numba.core.byteflow:stack []
183
+ DEBUG:numba.core.byteflow:dispatch pc=94, inst=POP_JUMP_IF_FALSE(arg=43, lineno=1152)
184
+ DEBUG:numba.core.byteflow:stack ['$zero_pos92.0']
185
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=96, stack=(), blockstack=(), npush=0), Edge(pc=182, stack=(), blockstack=(), npush=0)]
186
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=92 nstack_initial=0), State(pc_initial=88 nstack_initial=0), State(pc_initial=96 nstack_initial=0), State(pc_initial=182 nstack_initial=0)])
187
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=88 nstack_initial=0), State(pc_initial=96 nstack_initial=0), State(pc_initial=182 nstack_initial=0)])
188
+ DEBUG:numba.core.byteflow:stack: []
189
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=88 nstack_initial=0)
190
+ DEBUG:numba.core.byteflow:dispatch pc=88, inst=LOAD_CONST(arg=1, lineno=1150)
191
+ DEBUG:numba.core.byteflow:stack []
192
+ DEBUG:numba.core.byteflow:dispatch pc=90, inst=STORE_FAST(arg=4, lineno=1150)
193
+ DEBUG:numba.core.byteflow:stack ['$const88.0']
194
+ DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=92, stack=(), blockstack=(), npush=0)]
195
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=96 nstack_initial=0), State(pc_initial=182 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
196
+ DEBUG:numba.core.byteflow:stack: []
197
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=96 nstack_initial=0)
198
+ DEBUG:numba.core.byteflow:dispatch pc=96, inst=LOAD_GLOBAL(arg=1, lineno=1153)
199
+ DEBUG:numba.core.byteflow:stack []
200
+ DEBUG:numba.core.byteflow:dispatch pc=106, inst=LOAD_ATTR(arg=2, lineno=1153)
201
+ DEBUG:numba.core.byteflow:stack ['$null$96.1', '$96load_global.0']
202
+ DEBUG:numba.core.byteflow:dispatch pc=126, inst=LOAD_FAST(arg=3, lineno=1153)
203
+ DEBUG:numba.core.byteflow:stack ['$null$96.1', '$106load_attr.2']
204
+ DEBUG:numba.core.byteflow:dispatch pc=128, inst=CALL(arg=1, lineno=1153)
205
+ DEBUG:numba.core.byteflow:stack ['$null$96.1', '$106load_attr.2', '$x0126.3']
206
+ DEBUG:numba.core.byteflow:dispatch pc=136, inst=LOAD_GLOBAL(arg=1, lineno=1153)
207
+ DEBUG:numba.core.byteflow:stack ['$128call.4']
208
+ DEBUG:numba.core.byteflow:dispatch pc=146, inst=LOAD_ATTR(arg=2, lineno=1153)
209
+ DEBUG:numba.core.byteflow:stack ['$128call.4', '$null$136.6', '$136load_global.5']
210
+ DEBUG:numba.core.byteflow:dispatch pc=166, inst=LOAD_FAST(arg=4, lineno=1153)
211
+ DEBUG:numba.core.byteflow:stack ['$128call.4', '$null$136.6', '$146load_attr.7']
212
+ DEBUG:numba.core.byteflow:dispatch pc=168, inst=CALL(arg=1, lineno=1153)
213
+ DEBUG:numba.core.byteflow:stack ['$128call.4', '$null$136.6', '$146load_attr.7', '$x1166.8']
214
+ DEBUG:numba.core.byteflow:dispatch pc=176, inst=COMPARE_OP(arg=55, lineno=1153)
215
+ DEBUG:numba.core.byteflow:stack ['$128call.4', '$168call.9']
216
+ DEBUG:numba.core.byteflow:dispatch pc=180, inst=RETURN_VALUE(arg=None, lineno=1153)
217
+ DEBUG:numba.core.byteflow:stack ['$176compare_op.10']
218
+ DEBUG:numba.core.byteflow:end state. edges=[]
219
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=182 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
220
+ DEBUG:numba.core.byteflow:stack: []
221
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=182 nstack_initial=0)
222
+ DEBUG:numba.core.byteflow:dispatch pc=182, inst=LOAD_GLOBAL(arg=1, lineno=1155)
223
+ DEBUG:numba.core.byteflow:stack []
224
+ DEBUG:numba.core.byteflow:dispatch pc=192, inst=LOAD_ATTR(arg=4, lineno=1155)
225
+ DEBUG:numba.core.byteflow:stack ['$null$182.1', '$182load_global.0']
226
+ DEBUG:numba.core.byteflow:dispatch pc=212, inst=LOAD_FAST(arg=3, lineno=1155)
227
+ DEBUG:numba.core.byteflow:stack ['$null$182.1', '$192load_attr.2']
228
+ DEBUG:numba.core.byteflow:dispatch pc=214, inst=CALL(arg=1, lineno=1155)
229
+ DEBUG:numba.core.byteflow:stack ['$null$182.1', '$192load_attr.2', '$x0212.3']
230
+ DEBUG:numba.core.byteflow:dispatch pc=222, inst=LOAD_GLOBAL(arg=1, lineno=1155)
231
+ DEBUG:numba.core.byteflow:stack ['$214call.4']
232
+ DEBUG:numba.core.byteflow:dispatch pc=232, inst=LOAD_ATTR(arg=4, lineno=1155)
233
+ DEBUG:numba.core.byteflow:stack ['$214call.4', '$null$222.6', '$222load_global.5']
234
+ DEBUG:numba.core.byteflow:dispatch pc=252, inst=LOAD_FAST(arg=4, lineno=1155)
235
+ DEBUG:numba.core.byteflow:stack ['$214call.4', '$null$222.6', '$232load_attr.7']
236
+ DEBUG:numba.core.byteflow:dispatch pc=254, inst=CALL(arg=1, lineno=1155)
237
+ DEBUG:numba.core.byteflow:stack ['$214call.4', '$null$222.6', '$232load_attr.7', '$x1252.8']
238
+ DEBUG:numba.core.byteflow:dispatch pc=262, inst=COMPARE_OP(arg=55, lineno=1155)
239
+ DEBUG:numba.core.byteflow:stack ['$214call.4', '$254call.9']
240
+ DEBUG:numba.core.byteflow:dispatch pc=266, inst=RETURN_VALUE(arg=None, lineno=1155)
241
+ DEBUG:numba.core.byteflow:stack ['$262compare_op.10']
242
+ DEBUG:numba.core.byteflow:end state. edges=[]
243
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=92 nstack_initial=0)])
244
+ DEBUG:numba.core.byteflow:-------------------------Prune PHIs-------------------------
245
+ DEBUG:numba.core.byteflow:Used_phis: defaultdict(<class 'set'>,
246
+ {State(pc_initial=0 nstack_initial=0): set(),
247
+ State(pc_initial=30 nstack_initial=1): {'$phi30.0'},
248
+ State(pc_initial=38 nstack_initial=0): set(),
249
+ State(pc_initial=40 nstack_initial=1): set(),
250
+ State(pc_initial=44 nstack_initial=0): set(),
251
+ State(pc_initial=48 nstack_initial=0): set(),
252
+ State(pc_initial=74 nstack_initial=1): {'$phi74.0'},
253
+ State(pc_initial=82 nstack_initial=0): set(),
254
+ State(pc_initial=84 nstack_initial=1): set(),
255
+ State(pc_initial=88 nstack_initial=0): set(),
256
+ State(pc_initial=92 nstack_initial=0): set(),
257
+ State(pc_initial=96 nstack_initial=0): set(),
258
+ State(pc_initial=182 nstack_initial=0): set()})
259
+ DEBUG:numba.core.byteflow:defmap: {'$phi30.0': State(pc_initial=0 nstack_initial=0),
260
+ '$phi40.0': State(pc_initial=0 nstack_initial=0),
261
+ '$phi74.0': State(pc_initial=48 nstack_initial=0),
262
+ '$phi84.0': State(pc_initial=48 nstack_initial=0)}
263
+ DEBUG:numba.core.byteflow:phismap: defaultdict(<class 'set'>,
264
+ {'$phi30.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
265
+ '$phi40.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
266
+ '$phi74.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))},
267
+ '$phi84.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))}})
268
+ DEBUG:numba.core.byteflow:changing phismap: defaultdict(<class 'set'>,
269
+ {'$phi30.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
270
+ '$phi40.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
271
+ '$phi74.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))},
272
+ '$phi84.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))}})
273
+ DEBUG:numba.core.byteflow:keep phismap: {'$phi30.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
274
+ '$phi74.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))}}
275
+ DEBUG:numba.core.byteflow:new_out: defaultdict(<class 'dict'>,
276
+ {State(pc_initial=0 nstack_initial=0): {'$phi30.0': '$x018.5'},
277
+ State(pc_initial=48 nstack_initial=0): {'$phi74.0': '$x162.5'}})
278
+ DEBUG:numba.core.byteflow:----------------------DONE Prune PHIs-----------------------
279
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=0 nstack_initial=0):
280
+ AdaptBlockInfo(insts=((0, {}), (2, {}), (4, {'res': '$x4.0'}), (6, {'res': '$const6.1'}), (8, {'index': '$const6.1', 'target': '$x4.0', 'res': '$8binary_subscr.2'}), (12, {'value': '$8binary_subscr.2'}), (14, {'res': '$threshold14.3'}), (16, {'value': '$threshold14.3', 'res': '$16unary_negative.4'}), (18, {'res': '$x018.5'}), (24, {'lhs': '$16unary_negative.4', 'rhs': '$x018.5', 'res': '$24compare_op.6'}), (28, {'pred': '$24compare_op.6'})), outgoing_phis={'$phi30.0': '$x018.5'}, blockstack=(), active_try_block=None, outgoing_edgepushed={30: ('$x018.5',), 40: ('$x018.5',)})
281
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=30 nstack_initial=1):
282
+ AdaptBlockInfo(insts=((30, {'res': '$threshold30.1'}), (32, {'lhs': '$phi30.0', 'rhs': '$threshold30.1', 'res': '$32compare_op.2'}), (36, {'pred': '$32compare_op.2'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={38: (), 48: ()})
283
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=38 nstack_initial=0):
284
+ AdaptBlockInfo(insts=((38, {}),), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={44: ()})
285
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=40 nstack_initial=1):
286
+ AdaptBlockInfo(insts=((42, {}),), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={48: ()})
287
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=44 nstack_initial=0):
288
+ AdaptBlockInfo(insts=((44, {'res': '$const44.0'}), (46, {'value': '$const44.0'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={48: ()})
289
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=48 nstack_initial=0):
290
+ AdaptBlockInfo(insts=((48, {'res': '$x48.0'}), (50, {'res': '$const50.1'}), (52, {'index': '$const50.1', 'target': '$x48.0', 'res': '$52binary_subscr.2'}), (56, {'value': '$52binary_subscr.2'}), (58, {'res': '$threshold58.3'}), (60, {'value': '$threshold58.3', 'res': '$60unary_negative.4'}), (62, {'res': '$x162.5'}), (68, {'lhs': '$60unary_negative.4', 'rhs': '$x162.5', 'res': '$68compare_op.6'}), (72, {'pred': '$68compare_op.6'})), outgoing_phis={'$phi74.0': '$x162.5'}, blockstack=(), active_try_block=None, outgoing_edgepushed={74: ('$x162.5',), 84: ('$x162.5',)})
291
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=74 nstack_initial=1):
292
+ AdaptBlockInfo(insts=((74, {'res': '$threshold74.1'}), (76, {'lhs': '$phi74.0', 'rhs': '$threshold74.1', 'res': '$76compare_op.2'}), (80, {'pred': '$76compare_op.2'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={82: (), 92: ()})
293
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=82 nstack_initial=0):
294
+ AdaptBlockInfo(insts=((82, {}),), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={88: ()})
295
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=84 nstack_initial=1):
296
+ AdaptBlockInfo(insts=((86, {}),), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={92: ()})
297
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=88 nstack_initial=0):
298
+ AdaptBlockInfo(insts=((88, {'res': '$const88.0'}), (90, {'value': '$const88.0'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={92: ()})
299
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=92 nstack_initial=0):
300
+ AdaptBlockInfo(insts=((92, {'res': '$zero_pos92.0'}), (94, {'pred': '$zero_pos92.0'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={96: (), 182: ()})
301
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=96 nstack_initial=0):
302
+ AdaptBlockInfo(insts=((96, {'idx': 0, 'res': '$96load_global.0'}), (106, {'item': '$96load_global.0', 'res': '$106load_attr.2'}), (126, {'res': '$x0126.3'}), (128, {'func': '$106load_attr.2', 'args': ['$x0126.3'], 'kw_names': None, 'res': '$128call.4'}), (136, {'idx': 0, 'res': '$136load_global.5'}), (146, {'item': '$136load_global.5', 'res': '$146load_attr.7'}), (166, {'res': '$x1166.8'}), (168, {'func': '$146load_attr.7', 'args': ['$x1166.8'], 'kw_names': None, 'res': '$168call.9'}), (176, {'lhs': '$128call.4', 'rhs': '$168call.9', 'res': '$176compare_op.10'}), (180, {'retval': '$176compare_op.10', 'castval': '$180return_value.11'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={})
303
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=182 nstack_initial=0):
304
+ AdaptBlockInfo(insts=((182, {'idx': 0, 'res': '$182load_global.0'}), (192, {'item': '$182load_global.0', 'res': '$192load_attr.2'}), (212, {'res': '$x0212.3'}), (214, {'func': '$192load_attr.2', 'args': ['$x0212.3'], 'kw_names': None, 'res': '$214call.4'}), (222, {'idx': 0, 'res': '$222load_global.5'}), (232, {'item': '$222load_global.5', 'res': '$232load_attr.7'}), (252, {'res': '$x1252.8'}), (254, {'func': '$232load_attr.7', 'args': ['$x1252.8'], 'kw_names': None, 'res': '$254call.9'}), (262, {'lhs': '$214call.4', 'rhs': '$254call.9', 'res': '$262compare_op.10'}), (266, {'retval': '$262compare_op.10', 'castval': '$266return_value.11'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={})
305
+ DEBUG:numba.core.interpreter:label 0:
306
+ x = arg(0, name=x) ['x']
307
+ threshold = arg(1, name=threshold) ['threshold']
308
+ zero_pos = arg(2, name=zero_pos) ['zero_pos']
309
+ $const6.1 = const(int, 0) ['$const6.1']
310
+ x0 = getitem(value=x, index=$const6.1, fn=<built-in function getitem>) ['$const6.1', 'x', 'x0']
311
+ $16unary_negative.4 = unary(fn=<built-in function neg>, value=threshold) ['$16unary_negative.4', 'threshold']
312
+ $24compare_op.6 = $16unary_negative.4 <= x0 ['$16unary_negative.4', '$24compare_op.6', 'x0']
313
+ bool28 = global(bool: <class 'bool'>) ['bool28']
314
+ $28pred = call bool28($24compare_op.6, func=bool28, args=(Var($24compare_op.6, audio.py:1145),), kws=(), vararg=None, varkwarg=None, target=None) ['$24compare_op.6', '$28pred', 'bool28']
315
+ $phi30.0 = x0 ['$phi30.0', 'x0']
316
+ branch $28pred, 30, 40 ['$28pred']
317
+ label 30:
318
+ $32compare_op.2 = $phi30.0 <= threshold ['$32compare_op.2', '$phi30.0', 'threshold']
319
+ bool36 = global(bool: <class 'bool'>) ['bool36']
320
+ $36pred = call bool36($32compare_op.2, func=bool36, args=(Var($32compare_op.2, audio.py:1145),), kws=(), vararg=None, varkwarg=None, target=None) ['$32compare_op.2', '$36pred', 'bool36']
321
+ branch $36pred, 38, 48 ['$36pred']
322
+ label 38:
323
+ jump 44 []
324
+ label 40:
325
+ jump 48 []
326
+ label 44:
327
+ x0 = const(int, 0) ['x0']
328
+ jump 48 []
329
+ label 48:
330
+ $const50.1 = const(int, -1) ['$const50.1']
331
+ x1 = getitem(value=x, index=$const50.1, fn=<built-in function getitem>) ['$const50.1', 'x', 'x1']
332
+ $60unary_negative.4 = unary(fn=<built-in function neg>, value=threshold) ['$60unary_negative.4', 'threshold']
333
+ $68compare_op.6 = $60unary_negative.4 <= x1 ['$60unary_negative.4', '$68compare_op.6', 'x1']
334
+ bool72 = global(bool: <class 'bool'>) ['bool72']
335
+ $72pred = call bool72($68compare_op.6, func=bool72, args=(Var($68compare_op.6, audio.py:1149),), kws=(), vararg=None, varkwarg=None, target=None) ['$68compare_op.6', '$72pred', 'bool72']
336
+ $phi74.0 = x1 ['$phi74.0', 'x1']
337
+ branch $72pred, 74, 84 ['$72pred']
338
+ label 74:
339
+ $76compare_op.2 = $phi74.0 <= threshold ['$76compare_op.2', '$phi74.0', 'threshold']
340
+ bool80 = global(bool: <class 'bool'>) ['bool80']
341
+ $80pred = call bool80($76compare_op.2, func=bool80, args=(Var($76compare_op.2, audio.py:1149),), kws=(), vararg=None, varkwarg=None, target=None) ['$76compare_op.2', '$80pred', 'bool80']
342
+ branch $80pred, 82, 92 ['$80pred']
343
+ label 82:
344
+ jump 88 []
345
+ label 84:
346
+ jump 92 []
347
+ label 88:
348
+ x1 = const(int, 0) ['x1']
349
+ jump 92 []
350
+ label 92:
351
+ bool94 = global(bool: <class 'bool'>) ['bool94']
352
+ $94pred = call bool94(zero_pos, func=bool94, args=(Var(zero_pos, audio.py:1141),), kws=(), vararg=None, varkwarg=None, target=None) ['$94pred', 'bool94', 'zero_pos']
353
+ branch $94pred, 96, 182 ['$94pred']
354
+ label 96:
355
+ $96load_global.0 = global(np: <module 'numpy' from '/home/anhnmt2/.local/lib/python3.12/site-packages/numpy/__init__.py'>) ['$96load_global.0']
356
+ $106load_attr.2 = getattr(value=$96load_global.0, attr=signbit) ['$106load_attr.2', '$96load_global.0']
357
+ $128call.4 = call $106load_attr.2(x0, func=$106load_attr.2, args=[Var(x0, audio.py:1144)], kws=(), vararg=None, varkwarg=None, target=None) ['$106load_attr.2', '$128call.4', 'x0']
358
+ $136load_global.5 = global(np: <module 'numpy' from '/home/anhnmt2/.local/lib/python3.12/site-packages/numpy/__init__.py'>) ['$136load_global.5']
359
+ $146load_attr.7 = getattr(value=$136load_global.5, attr=signbit) ['$136load_global.5', '$146load_attr.7']
360
+ $168call.9 = call $146load_attr.7(x1, func=$146load_attr.7, args=[Var(x1, audio.py:1148)], kws=(), vararg=None, varkwarg=None, target=None) ['$146load_attr.7', '$168call.9', 'x1']
361
+ $176compare_op.10 = $128call.4 != $168call.9 ['$128call.4', '$168call.9', '$176compare_op.10']
362
+ $180return_value.11 = cast(value=$176compare_op.10) ['$176compare_op.10', '$180return_value.11']
363
+ return $180return_value.11 ['$180return_value.11']
364
+ label 182:
365
+ $182load_global.0 = global(np: <module 'numpy' from '/home/anhnmt2/.local/lib/python3.12/site-packages/numpy/__init__.py'>) ['$182load_global.0']
366
+ $192load_attr.2 = getattr(value=$182load_global.0, attr=sign) ['$182load_global.0', '$192load_attr.2']
367
+ $214call.4 = call $192load_attr.2(x0, func=$192load_attr.2, args=[Var(x0, audio.py:1144)], kws=(), vararg=None, varkwarg=None, target=None) ['$192load_attr.2', '$214call.4', 'x0']
368
+ $222load_global.5 = global(np: <module 'numpy' from '/home/anhnmt2/.local/lib/python3.12/site-packages/numpy/__init__.py'>) ['$222load_global.5']
369
+ $232load_attr.7 = getattr(value=$222load_global.5, attr=sign) ['$222load_global.5', '$232load_attr.7']
370
+ $254call.9 = call $232load_attr.7(x1, func=$232load_attr.7, args=[Var(x1, audio.py:1148)], kws=(), vararg=None, varkwarg=None, target=None) ['$232load_attr.7', '$254call.9', 'x1']
371
+ $262compare_op.10 = $214call.4 != $254call.9 ['$214call.4', '$254call.9', '$262compare_op.10']
372
+ $266return_value.11 = cast(value=$262compare_op.10) ['$262compare_op.10', '$266return_value.11']
373
+ return $266return_value.11 ['$266return_value.11']
374
+
375
+ DEBUG:numba.core.byteflow:bytecode dump:
376
+ > 0 NOP(arg=None, lineno=1039)
377
+ 2 RESUME(arg=0, lineno=1039)
378
+ 4 LOAD_FAST(arg=0, lineno=1042)
379
+ 6 LOAD_CONST(arg=1, lineno=1042)
380
+ 8 BINARY_SUBSCR(arg=None, lineno=1042)
381
+ 12 LOAD_FAST(arg=0, lineno=1042)
382
+ 14 LOAD_CONST(arg=2, lineno=1042)
383
+ 16 BINARY_SUBSCR(arg=None, lineno=1042)
384
+ 20 COMPARE_OP(arg=68, lineno=1042)
385
+ 24 LOAD_FAST(arg=0, lineno=1042)
386
+ 26 LOAD_CONST(arg=1, lineno=1042)
387
+ 28 BINARY_SUBSCR(arg=None, lineno=1042)
388
+ 32 LOAD_FAST(arg=0, lineno=1042)
389
+ 34 LOAD_CONST(arg=3, lineno=1042)
390
+ 36 BINARY_SUBSCR(arg=None, lineno=1042)
391
+ 40 COMPARE_OP(arg=92, lineno=1042)
392
+ 44 BINARY_OP(arg=1, lineno=1042)
393
+ 48 RETURN_VALUE(arg=None, lineno=1042)
394
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
395
+ DEBUG:numba.core.byteflow:stack: []
396
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
397
+ DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1039)
398
+ DEBUG:numba.core.byteflow:stack []
399
+ DEBUG:numba.core.byteflow:dispatch pc=2, inst=RESUME(arg=0, lineno=1039)
400
+ DEBUG:numba.core.byteflow:stack []
401
+ DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=0, lineno=1042)
402
+ DEBUG:numba.core.byteflow:stack []
403
+ DEBUG:numba.core.byteflow:dispatch pc=6, inst=LOAD_CONST(arg=1, lineno=1042)
404
+ DEBUG:numba.core.byteflow:stack ['$x4.0']
405
+ DEBUG:numba.core.byteflow:dispatch pc=8, inst=BINARY_SUBSCR(arg=None, lineno=1042)
406
+ DEBUG:numba.core.byteflow:stack ['$x4.0', '$const6.1']
407
+ DEBUG:numba.core.byteflow:dispatch pc=12, inst=LOAD_FAST(arg=0, lineno=1042)
408
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2']
409
+ DEBUG:numba.core.byteflow:dispatch pc=14, inst=LOAD_CONST(arg=2, lineno=1042)
410
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$x12.3']
411
+ DEBUG:numba.core.byteflow:dispatch pc=16, inst=BINARY_SUBSCR(arg=None, lineno=1042)
412
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$x12.3', '$const14.4']
413
+ DEBUG:numba.core.byteflow:dispatch pc=20, inst=COMPARE_OP(arg=68, lineno=1042)
414
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$16binary_subscr.5']
415
+ DEBUG:numba.core.byteflow:dispatch pc=24, inst=LOAD_FAST(arg=0, lineno=1042)
416
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6']
417
+ DEBUG:numba.core.byteflow:dispatch pc=26, inst=LOAD_CONST(arg=1, lineno=1042)
418
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$x24.7']
419
+ DEBUG:numba.core.byteflow:dispatch pc=28, inst=BINARY_SUBSCR(arg=None, lineno=1042)
420
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$x24.7', '$const26.8']
421
+ DEBUG:numba.core.byteflow:dispatch pc=32, inst=LOAD_FAST(arg=0, lineno=1042)
422
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9']
423
+ DEBUG:numba.core.byteflow:dispatch pc=34, inst=LOAD_CONST(arg=3, lineno=1042)
424
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$x32.10']
425
+ DEBUG:numba.core.byteflow:dispatch pc=36, inst=BINARY_SUBSCR(arg=None, lineno=1042)
426
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$x32.10', '$const34.11']
427
+ DEBUG:numba.core.byteflow:dispatch pc=40, inst=COMPARE_OP(arg=92, lineno=1042)
428
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$36binary_subscr.12']
429
+ DEBUG:numba.core.byteflow:dispatch pc=44, inst=BINARY_OP(arg=1, lineno=1042)
430
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$40compare_op.13']
431
+ DEBUG:numba.core.byteflow:dispatch pc=48, inst=RETURN_VALUE(arg=None, lineno=1042)
432
+ DEBUG:numba.core.byteflow:stack ['$binop_and_44.14']
433
+ DEBUG:numba.core.byteflow:end state. edges=[]
434
+ DEBUG:numba.core.byteflow:-------------------------Prune PHIs-------------------------
435
+ DEBUG:numba.core.byteflow:Used_phis: defaultdict(<class 'set'>, {State(pc_initial=0 nstack_initial=0): set()})
436
+ DEBUG:numba.core.byteflow:defmap: {}
437
+ DEBUG:numba.core.byteflow:phismap: defaultdict(<class 'set'>, {})
438
+ DEBUG:numba.core.byteflow:changing phismap: defaultdict(<class 'set'>, {})
439
+ DEBUG:numba.core.byteflow:keep phismap: {}
440
+ DEBUG:numba.core.byteflow:new_out: defaultdict(<class 'dict'>, {})
441
+ DEBUG:numba.core.byteflow:----------------------DONE Prune PHIs-----------------------
442
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=0 nstack_initial=0):
443
+ AdaptBlockInfo(insts=((0, {}), (2, {}), (4, {'res': '$x4.0'}), (6, {'res': '$const6.1'}), (8, {'index': '$const6.1', 'target': '$x4.0', 'res': '$8binary_subscr.2'}), (12, {'res': '$x12.3'}), (14, {'res': '$const14.4'}), (16, {'index': '$const14.4', 'target': '$x12.3', 'res': '$16binary_subscr.5'}), (20, {'lhs': '$8binary_subscr.2', 'rhs': '$16binary_subscr.5', 'res': '$20compare_op.6'}), (24, {'res': '$x24.7'}), (26, {'res': '$const26.8'}), (28, {'index': '$const26.8', 'target': '$x24.7', 'res': '$28binary_subscr.9'}), (32, {'res': '$x32.10'}), (34, {'res': '$const34.11'}), (36, {'index': '$const34.11', 'target': '$x32.10', 'res': '$36binary_subscr.12'}), (40, {'lhs': '$28binary_subscr.9', 'rhs': '$36binary_subscr.12', 'res': '$40compare_op.13'}), (44, {'op': '&', 'lhs': '$20compare_op.6', 'rhs': '$40compare_op.13', 'res': '$binop_and_44.14'}), (48, {'retval': '$binop_and_44.14', 'castval': '$48return_value.15'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={})
444
+ DEBUG:numba.core.interpreter:label 0:
445
+ x = arg(0, name=x) ['x']
446
+ $const6.1 = const(int, 0) ['$const6.1']
447
+ $8binary_subscr.2 = getitem(value=x, index=$const6.1, fn=<built-in function getitem>) ['$8binary_subscr.2', '$const6.1', 'x']
448
+ $const14.4 = const(int, -1) ['$const14.4']
449
+ $16binary_subscr.5 = getitem(value=x, index=$const14.4, fn=<built-in function getitem>) ['$16binary_subscr.5', '$const14.4', 'x']
450
+ $20compare_op.6 = $8binary_subscr.2 > $16binary_subscr.5 ['$16binary_subscr.5', '$20compare_op.6', '$8binary_subscr.2']
451
+ $const26.8 = const(int, 0) ['$const26.8']
452
+ $28binary_subscr.9 = getitem(value=x, index=$const26.8, fn=<built-in function getitem>) ['$28binary_subscr.9', '$const26.8', 'x']
453
+ $const34.11 = const(int, 1) ['$const34.11']
454
+ $36binary_subscr.12 = getitem(value=x, index=$const34.11, fn=<built-in function getitem>) ['$36binary_subscr.12', '$const34.11', 'x']
455
+ $40compare_op.13 = $28binary_subscr.9 >= $36binary_subscr.12 ['$28binary_subscr.9', '$36binary_subscr.12', '$40compare_op.13']
456
+ $binop_and_44.14 = $20compare_op.6 & $40compare_op.13 ['$20compare_op.6', '$40compare_op.13', '$binop_and_44.14']
457
+ $48return_value.15 = cast(value=$binop_and_44.14) ['$48return_value.15', '$binop_and_44.14']
458
+ return $48return_value.15 ['$48return_value.15']
459
+
460
+ DEBUG:numba.core.byteflow:bytecode dump:
461
+ > 0 NOP(arg=None, lineno=1045)
462
+ 2 RESUME(arg=0, lineno=1045)
463
+ 4 LOAD_FAST(arg=0, lineno=1048)
464
+ 6 LOAD_CONST(arg=1, lineno=1048)
465
+ 8 BINARY_SUBSCR(arg=None, lineno=1048)
466
+ 12 LOAD_FAST(arg=0, lineno=1048)
467
+ 14 LOAD_CONST(arg=2, lineno=1048)
468
+ 16 BINARY_SUBSCR(arg=None, lineno=1048)
469
+ 20 COMPARE_OP(arg=2, lineno=1048)
470
+ 24 LOAD_FAST(arg=0, lineno=1048)
471
+ 26 LOAD_CONST(arg=1, lineno=1048)
472
+ 28 BINARY_SUBSCR(arg=None, lineno=1048)
473
+ 32 LOAD_FAST(arg=0, lineno=1048)
474
+ 34 LOAD_CONST(arg=3, lineno=1048)
475
+ 36 BINARY_SUBSCR(arg=None, lineno=1048)
476
+ 40 COMPARE_OP(arg=26, lineno=1048)
477
+ 44 BINARY_OP(arg=1, lineno=1048)
478
+ 48 RETURN_VALUE(arg=None, lineno=1048)
479
+ DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
480
+ DEBUG:numba.core.byteflow:stack: []
481
+ DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
482
+ DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1045)
483
+ DEBUG:numba.core.byteflow:stack []
484
+ DEBUG:numba.core.byteflow:dispatch pc=2, inst=RESUME(arg=0, lineno=1045)
485
+ DEBUG:numba.core.byteflow:stack []
486
+ DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=0, lineno=1048)
487
+ DEBUG:numba.core.byteflow:stack []
488
+ DEBUG:numba.core.byteflow:dispatch pc=6, inst=LOAD_CONST(arg=1, lineno=1048)
489
+ DEBUG:numba.core.byteflow:stack ['$x4.0']
490
+ DEBUG:numba.core.byteflow:dispatch pc=8, inst=BINARY_SUBSCR(arg=None, lineno=1048)
491
+ DEBUG:numba.core.byteflow:stack ['$x4.0', '$const6.1']
492
+ DEBUG:numba.core.byteflow:dispatch pc=12, inst=LOAD_FAST(arg=0, lineno=1048)
493
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2']
494
+ DEBUG:numba.core.byteflow:dispatch pc=14, inst=LOAD_CONST(arg=2, lineno=1048)
495
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$x12.3']
496
+ DEBUG:numba.core.byteflow:dispatch pc=16, inst=BINARY_SUBSCR(arg=None, lineno=1048)
497
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$x12.3', '$const14.4']
498
+ DEBUG:numba.core.byteflow:dispatch pc=20, inst=COMPARE_OP(arg=2, lineno=1048)
499
+ DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$16binary_subscr.5']
500
+ DEBUG:numba.core.byteflow:dispatch pc=24, inst=LOAD_FAST(arg=0, lineno=1048)
501
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6']
502
+ DEBUG:numba.core.byteflow:dispatch pc=26, inst=LOAD_CONST(arg=1, lineno=1048)
503
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$x24.7']
504
+ DEBUG:numba.core.byteflow:dispatch pc=28, inst=BINARY_SUBSCR(arg=None, lineno=1048)
505
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$x24.7', '$const26.8']
506
+ DEBUG:numba.core.byteflow:dispatch pc=32, inst=LOAD_FAST(arg=0, lineno=1048)
507
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9']
508
+ DEBUG:numba.core.byteflow:dispatch pc=34, inst=LOAD_CONST(arg=3, lineno=1048)
509
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$x32.10']
510
+ DEBUG:numba.core.byteflow:dispatch pc=36, inst=BINARY_SUBSCR(arg=None, lineno=1048)
511
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$x32.10', '$const34.11']
512
+ DEBUG:numba.core.byteflow:dispatch pc=40, inst=COMPARE_OP(arg=26, lineno=1048)
513
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$36binary_subscr.12']
514
+ DEBUG:numba.core.byteflow:dispatch pc=44, inst=BINARY_OP(arg=1, lineno=1048)
515
+ DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$40compare_op.13']
516
+ DEBUG:numba.core.byteflow:dispatch pc=48, inst=RETURN_VALUE(arg=None, lineno=1048)
517
+ DEBUG:numba.core.byteflow:stack ['$binop_and_44.14']
518
+ DEBUG:numba.core.byteflow:end state. edges=[]
519
+ DEBUG:numba.core.byteflow:-------------------------Prune PHIs-------------------------
520
+ DEBUG:numba.core.byteflow:Used_phis: defaultdict(<class 'set'>, {State(pc_initial=0 nstack_initial=0): set()})
521
+ DEBUG:numba.core.byteflow:defmap: {}
522
+ DEBUG:numba.core.byteflow:phismap: defaultdict(<class 'set'>, {})
523
+ DEBUG:numba.core.byteflow:changing phismap: defaultdict(<class 'set'>, {})
524
+ DEBUG:numba.core.byteflow:keep phismap: {}
525
+ DEBUG:numba.core.byteflow:new_out: defaultdict(<class 'dict'>, {})
526
+ DEBUG:numba.core.byteflow:----------------------DONE Prune PHIs-----------------------
527
+ DEBUG:numba.core.byteflow:block_infos State(pc_initial=0 nstack_initial=0):
528
+ AdaptBlockInfo(insts=((0, {}), (2, {}), (4, {'res': '$x4.0'}), (6, {'res': '$const6.1'}), (8, {'index': '$const6.1', 'target': '$x4.0', 'res': '$8binary_subscr.2'}), (12, {'res': '$x12.3'}), (14, {'res': '$const14.4'}), (16, {'index': '$const14.4', 'target': '$x12.3', 'res': '$16binary_subscr.5'}), (20, {'lhs': '$8binary_subscr.2', 'rhs': '$16binary_subscr.5', 'res': '$20compare_op.6'}), (24, {'res': '$x24.7'}), (26, {'res': '$const26.8'}), (28, {'index': '$const26.8', 'target': '$x24.7', 'res': '$28binary_subscr.9'}), (32, {'res': '$x32.10'}), (34, {'res': '$const34.11'}), (36, {'index': '$const34.11', 'target': '$x32.10', 'res': '$36binary_subscr.12'}), (40, {'lhs': '$28binary_subscr.9', 'rhs': '$36binary_subscr.12', 'res': '$40compare_op.13'}), (44, {'op': '&', 'lhs': '$20compare_op.6', 'rhs': '$40compare_op.13', 'res': '$binop_and_44.14'}), (48, {'retval': '$binop_and_44.14', 'castval': '$48return_value.15'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={})
529
+ DEBUG:numba.core.interpreter:label 0:
530
+ x = arg(0, name=x) ['x']
531
+ $const6.1 = const(int, 0) ['$const6.1']
532
+ $8binary_subscr.2 = getitem(value=x, index=$const6.1, fn=<built-in function getitem>) ['$8binary_subscr.2', '$const6.1', 'x']
533
+ $const14.4 = const(int, -1) ['$const14.4']
534
+ $16binary_subscr.5 = getitem(value=x, index=$const14.4, fn=<built-in function getitem>) ['$16binary_subscr.5', '$const14.4', 'x']
535
+ $20compare_op.6 = $8binary_subscr.2 < $16binary_subscr.5 ['$16binary_subscr.5', '$20compare_op.6', '$8binary_subscr.2']
536
+ $const26.8 = const(int, 0) ['$const26.8']
537
+ $28binary_subscr.9 = getitem(value=x, index=$const26.8, fn=<built-in function getitem>) ['$28binary_subscr.9', '$const26.8', 'x']
538
+ $const34.11 = const(int, 1) ['$const34.11']
539
+ $36binary_subscr.12 = getitem(value=x, index=$const34.11, fn=<built-in function getitem>) ['$36binary_subscr.12', '$const34.11', 'x']
540
+ $40compare_op.13 = $28binary_subscr.9 <= $36binary_subscr.12 ['$28binary_subscr.9', '$36binary_subscr.12', '$40compare_op.13']
541
+ $binop_and_44.14 = $20compare_op.6 & $40compare_op.13 ['$20compare_op.6', '$40compare_op.13', '$binop_and_44.14']
542
+ $48return_value.15 = cast(value=$binop_and_44.14) ['$48return_value.15', '$binop_and_44.14']
543
+ return $48return_value.15 ['$48return_value.15']
544
+
545
+ 0%| | 14/43233 [00:59<27:10:42, 2.26s/it]Traceback (most recent call last):
546
+ {'loss': 6.6172, 'grad_norm': 10.827251434326172, 'learning_rate': 3.8550501156515035e-08, 'epoch': 0.0}
547
+ {'loss': 5.794, 'grad_norm': 14.017024040222168, 'learning_rate': 7.710100231303007e-08, 'epoch': 0.0}
548
+ {'loss': 6.8788, 'grad_norm': 13.020977020263672, 'learning_rate': 1.1565150346954511e-07, 'epoch': 0.0}
549
+ {'loss': 6.6162, 'grad_norm': 18.2950439453125, 'learning_rate': 1.5420200462606014e-07, 'epoch': 0.0}
550
+ {'loss': 6.9646, 'grad_norm': 14.263402938842773, 'learning_rate': 1.9275250578257518e-07, 'epoch': 0.0}
551
+ {'loss': 6.631, 'grad_norm': 13.121792793273926, 'learning_rate': 2.3130300693909022e-07, 'epoch': 0.0}
552
+ {'loss': 7.2093, 'grad_norm': 18.358381271362305, 'learning_rate': 2.6985350809560526e-07, 'epoch': 0.0}
553
+ {'loss': 7.7485, 'grad_norm': 24.542631149291992, 'learning_rate': 3.084040092521203e-07, 'epoch': 0.0}
554
+ {'loss': 6.2489, 'grad_norm': 12.371420860290527, 'learning_rate': 3.469545104086353e-07, 'epoch': 0.0}
555
+ {'loss': 6.6981, 'grad_norm': 22.148744583129883, 'learning_rate': 3.8550501156515036e-07, 'epoch': 0.0}
556
+ {'loss': 6.9253, 'grad_norm': 25.32149314880371, 'learning_rate': 4.240555127216654e-07, 'epoch': 0.0}
557
+ {'loss': 6.7832, 'grad_norm': 36.084407806396484, 'learning_rate': 4.6260601387818044e-07, 'epoch': 0.0}
558
+ {'loss': 6.9372, 'grad_norm': 12.946453094482422, 'learning_rate': 5.011565150346955e-07, 'epoch': 0.0}
559
+ {'loss': 6.6468, 'grad_norm': 14.272549629211426, 'learning_rate': 5.397070161912105e-07, 'epoch': 0.0}
scripts/wandb/latest-run/files/requirements.txt ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tomlkit==0.12.0
2
+ python-dotenv==1.0.1
3
+ SQLAlchemy==2.0.36
4
+ psutil==6.1.0
5
+ anyio==4.8.0
6
+ onnxruntime==1.20.1
7
+ antlr4-python3-runtime==4.9.3
8
+ httpx-sse==0.4.0
9
+ annotated-types==0.7.0
10
+ tqdm==4.66.5
11
+ simplejson==3.19.3
12
+ csvw==3.5.1
13
+ pooch==1.8.2
14
+ trl==0.9.6
15
+ more-itertools==10.5.0
16
+ jiter==0.6.1
17
+ markdown2==2.5.1
18
+ segments==2.2.1
19
+ opentelemetry-instrumentation-asgi==0.50b0
20
+ Deprecated==1.2.15
21
+ pyasn1_modules==0.4.1
22
+ bcrypt==4.2.1
23
+ opentelemetry-util-http==0.50b0
24
+ intervaltree==3.1.0
25
+ hjson==3.1.0
26
+ modelscope==1.18.1
27
+ fastapi==0.112.4
28
+ pyarrow==17.0.0
29
+ sounddevice==0.5.1
30
+ modelscope_studio==0.4.0.9
31
+ build==1.2.2.post1
32
+ oauthlib==3.2.2
33
+ gunicorn==23.0.0
34
+ pyasn1==0.6.1
35
+ matplotlib==3.9.2
36
+ speechbrain==0.5.16
37
+ joblib==1.4.2
38
+ tyro==0.8.13
39
+ rsa==4.9
40
+ numba==0.60.0
41
+ fastprogress==1.0.3
42
+ wrapt==1.17.0
43
+ PyPika==0.48.9
44
+ dacite==1.8.1
45
+ googleapis-common-protos==1.66.0
46
+ openai==1.68.0
47
+ tabulate==0.9.0
48
+ monotonic==1.6
49
+ lazy_loader==0.4
50
+ google-auth==2.37.0
51
+ fairseq==0.12.3
52
+ opentelemetry-semantic-conventions==0.50b0
53
+ sacrebleu==2.4.3
54
+ requests-toolbelt==1.0.0
55
+ ruff==0.7.0
56
+ bitsandbytes==0.43.1
57
+ tenacity==9.0.0
58
+ uvloop==0.21.0
59
+ Pygments==2.18.0
60
+ langchain==0.3.18
61
+ typer==0.12.5
62
+ uritemplate==4.1.1
63
+ rich==13.9.3
64
+ lion-pytorch==0.2.3
65
+ pydub==0.25.1
66
+ fastcore==1.7.28
67
+ encodec==0.1.1
68
+ cytoolz==1.0.1
69
+ huggingface-hub==0.26.1
70
+ python-dateutil==2.9.0.post0
71
+ duckduckgo_search==7.3.2
72
+ rfc3986==1.5.0
73
+ wavedrom==2.0.3.post3
74
+ sentence-transformers==3.3.1
75
+ httpx==0.28.1
76
+ colorlog==6.9.0
77
+ xxhash==3.5.0
78
+ termcolor==2.5.0
79
+ importlib_resources==6.4.5
80
+ lilcom==1.8.1
81
+ llamafactory==0.9.2.dev0
82
+ lhotse==1.31.0
83
+ kiwisolver==1.4.7
84
+ watchfiles==1.0.3
85
+ marshmallow==3.23.1
86
+ overrides==7.7.0
87
+ langchain-text-splitters==0.3.6
88
+ lxml==5.3.0
89
+ blinker==1.8.2
90
+ whisper==1.1.10
91
+ triton==3.1.0
92
+ python-multipart==0.0.12
93
+ isodate==0.7.2
94
+ wandb==0.19.8
95
+ nvidia-ml-py==12.560.30
96
+ h11==0.14.0
97
+ zipp==3.20.2
98
+ transformers==4.45.0
99
+ websocket-client==1.8.0
100
+ opentelemetry-instrumentation==0.50b0
101
+ pydantic==2.9.2
102
+ latex2mathml==3.77.0
103
+ numpy-rms==0.4.2
104
+ opentelemetry-exporter-otlp-proto-grpc==1.29.0
105
+ humanfriendly==10.0
106
+ decorator==5.1.1
107
+ fonttools==4.54.1
108
+ fire==0.7.0
109
+ ninja==1.11.1.1
110
+ shortuuid==1.0.13
111
+ tiktoken==0.8.0
112
+ aliyun-python-sdk-kms==2.16.5
113
+ einops==0.8.0
114
+ threadpoolctl==3.5.0
115
+ docker-pycreds==0.4.0
116
+ Flask==3.0.3
117
+ opentelemetry-sdk==1.29.0
118
+ opentelemetry-exporter-otlp-proto-common==1.29.0
119
+ pylatexenc==2.10
120
+ orjson==3.10.10
121
+ durationpy==0.9
122
+ addict==2.4.0
123
+ py-cpuinfo==9.0.0
124
+ contourpy==1.3.0
125
+ crcmod==1.7
126
+ pydantic-settings==2.6.1
127
+ pyproject_hooks==1.2.0
128
+ future==1.0.0
129
+ jsonschema-specifications==2024.10.1
130
+ coloredlogs==15.0.1
131
+ timm==0.6.13
132
+ deepspeed==0.14.5
133
+ referencing==0.35.1
134
+ binpacking==1.5.2
135
+ peft==0.12.0
136
+ language-tags==1.2.0
137
+ speechtokenizer==1.0.1
138
+ shellingham==1.5.4
139
+ primp==0.12.1
140
+ tavily-python==0.5.1
141
+ uvicorn==0.32.0
142
+ opentelemetry-proto==1.29.0
143
+ typing-inspect==0.9.0
144
+ backoff==2.2.1
145
+ sortedcontainers==2.4.0
146
+ gitdb==4.0.12
147
+ aiofiles==23.2.1
148
+ jsonschema==4.23.0
149
+ svgwrite==1.4.3
150
+ protobuf==5.29.1
151
+ starlette==0.38.6
152
+ transformers-stream-generator==0.0.5
153
+ sentry-sdk==2.22.0
154
+ toolz==1.0.0
155
+ einops-exts==0.0.4
156
+ WhisperSpeech==0.8
157
+ hydra-core==1.3.2
158
+ portalocker==2.10.1
159
+ jieba==0.42.1
160
+ pandas==2.2.3
161
+ requests==2.32.3
162
+ flash-attn==2.6.3
163
+ msgpack==1.1.0
164
+ chroma-hnswlib==0.7.6
165
+ librosa==0.10.2.post1
166
+ sniffio==1.3.1
167
+ smmap==5.0.2
168
+ opentelemetry-api==1.29.0
169
+ websockets==14.2
170
+ kubernetes==31.0.0
171
+ audioread==3.0.1
172
+ docstring_parser==0.16
173
+ scipy==1.12.0
174
+ aliyun-python-sdk-core==2.16.0
175
+ accelerate==1.0.0
176
+ dill==0.3.8
177
+ llama-omni==1.0.0
178
+ mdurl==0.1.2
179
+ chromadb==0.5.23
180
+ oss2==2.19.0
181
+ rdflib==7.1.1
182
+ bibtexparser==2.0.0b8
183
+ rpds-py==0.22.3
184
+ soundfile==0.12.1
185
+ langdetect==1.0.9
186
+ duckdb==1.2.0
187
+ numpy==1.26.3
188
+ dataclasses-json==0.6.7
189
+ tokenizers==0.20.3
190
+ cpm-kernels==1.0.11
191
+ einx==0.3.0
192
+ langchain-core==0.3.34
193
+ clldutils==3.24.0
194
+ openai-whisper==20240930
195
+ setuptools==69.5.1
196
+ requests-oauthlib==2.0.0
197
+ langchain-community==0.3.17
198
+ langsmith==0.2.3
199
+ colorama==0.4.6
200
+ omegaconf==2.3.0
201
+ asgiref==3.8.1
202
+ pydantic_core==2.23.4
203
+ ffmpy==0.4.0
204
+ multiprocess==0.70.16
205
+ mmh3==5.0.1
206
+ babel==2.16.0
207
+ phonemizer==3.3.0
208
+ pycryptodome==3.21.0
209
+ gradio==4.44.1
210
+ google-genai==1.5.0
211
+ tzdata==2024.2
212
+ llvmlite==0.43.0
213
+ cachetools==5.5.0
214
+ seaborn==0.13.2
215
+ httptools==0.6.4
216
+ GitPython==3.1.44
217
+ markdown-it-py==3.0.0
218
+ beartype==0.20.2
219
+ whisper_normalizer==0.0.10
220
+ dlinfo==1.2.1
221
+ vocos==0.1.0
222
+ itsdangerous==2.2.0
223
+ bitarray==3.0.0
224
+ opentelemetry-instrumentation-fastapi==0.50b0
225
+ setproctitle==1.3.5
226
+ cycler==0.12.1
227
+ vector-quantize-pytorch==1.18.5
228
+ jmespath==0.10.0
229
+ mypy-extensions==1.0.0
230
+ flatbuffers==24.3.25
231
+ scikit-learn==1.5.2
232
+ pytz==2024.2
233
+ pyparsing==3.2.0
234
+ posthog==3.7.4
235
+ rouge==1.0.1
236
+ semantic-version==2.10.0
237
+ httpcore==1.0.6
238
+ soxr==0.5.0.post1
239
+ importlib_metadata==8.5.0
240
+ audiomentations==0.36.1
241
+ shtab==1.7.1
242
+ Unidecode==1.3.8
243
+ click==8.1.8
244
+ tensorboardX==2.6.2.2
245
+ greenlet==3.1.1
246
+ nltk==3.9.1
247
+ gradio_client==1.3.0
248
+ datasets==2.21.0
249
+ attrdict==2.0.1
250
+ llamafactory==0.9.2.dev0
251
+ ms-swift==2.6.0.dev0
252
+ Brotli==1.0.9
253
+ Cython==3.0.10
254
+ HyperPyYAML==1.2.2
255
+ Markdown==3.6
256
+ MarkupSafe==2.1.3
257
+ PySocks==1.7.1
258
+ PyYAML==6.0.1
259
+ absl-py==2.1.0
260
+ aiohttp==3.9.5
261
+ aiosignal==1.3.1
262
+ anaconda-anon-usage==0.4.4
263
+ archspec==0.2.3
264
+ attrs==23.2.0
265
+ boltons==23.0.0
266
+ certifi==2024.6.2
267
+ cffi==1.16.0
268
+ charset-normalizer==2.0.4
269
+ click==8.1.7
270
+ conda==24.5.0
271
+ conda-content-trust==0.2.0
272
+ conda-libmamba-solver==24.1.0
273
+ conda-package-handling==2.2.0
274
+ conda_package_streaming==0.9.0
275
+ cryptography==42.0.5
276
+ distro==1.9.0
277
+ filelock==3.13.1
278
+ frozendict==2.4.2
279
+ frozenlist==1.4.1
280
+ fsspec==2024.6.0
281
+ grpcio==1.64.1
282
+ huggingface-hub==0.23.3
283
+ idna==3.7
284
+ Jinja2==3.1.4
285
+ jiwer==3.0.4
286
+ jsonargparse==4.29.0
287
+ jsonpatch==1.33
288
+ jsonpointer==2.1
289
+ kaldialign==0.9.1
290
+ libmambapy==1.5.8
291
+ lightning==2.2.5
292
+ lightning-utilities==0.11.2
293
+ llvmlite==0.42.0
294
+ menuinst==2.0.2
295
+ mkl-fft==1.3.8
296
+ mkl-random==1.2.4
297
+ mkl-service==2.4.0
298
+ mpmath==1.3.0
299
+ multidict==6.0.5
300
+ networkx==3.2.1
301
+ numba==0.59.1
302
+ numpy==1.26.4
303
+ packaging==23.2
304
+ pillow==10.3.0
305
+ pip==24.0
306
+ platformdirs==3.10.0
307
+ pluggy==1.0.0
308
+ protobuf==4.25.3
309
+ pycosat==0.6.6
310
+ pycparser==2.21
311
+ pytorch-lightning==2.2.5
312
+ rapidfuzz==3.9.3
313
+ regex==2024.5.15
314
+ requests==2.31.0
315
+ ruamel.yaml==0.18.6
316
+ ruamel.yaml.clib==0.2.8
317
+ safetensors==0.4.3
318
+ scipy==1.13.1
319
+ sentencepiece==0.2.0
320
+ setuptools==69.5.1
321
+ six==1.16.0
322
+ sympy==1.12
323
+ tensorboard==2.17.0
324
+ tensorboard-data-server==0.7.2
325
+ tokenizers==0.19.1
326
+ torch==2.2.1
327
+ torch-complex==0.4.3
328
+ torchaudio==2.2.1
329
+ torchmetrics==1.4.0.post0
330
+ torchvision==0.17.1
331
+ tqdm==4.66.2
332
+ transformers==4.41.2
333
+ truststore==0.8.0
334
+ typeguard==2.13.3
335
+ typing_extensions==4.11.0
336
+ urllib3==2.1.0
337
+ Werkzeug==3.0.3
338
+ wheel==0.43.0
339
+ yarl==1.9.4
340
+ zstandard==0.22.0
341
+ warprnnt_pytorch==0.1
scripts/wandb/latest-run/files/wandb-metadata.json ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.15.0-1029-nvidia-x86_64-with-glibc2.31",
3
+ "python": "CPython 3.12.3",
4
+ "startedAt": "2025-04-10T10:19:28.834922Z",
5
+ "args": [
6
+ "--local_rank=0",
7
+ "--deepspeed",
8
+ "zero2.json",
9
+ "--model_name_or_path",
10
+ "/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct",
11
+ "--pretrained_llm_path",
12
+ "/data1/speech/anhnmt2/cuongnm/EOT/Qwen2.5-0.5B-Instruct",
13
+ "--tokenizer_path",
14
+ "/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6",
15
+ "--cache_dir",
16
+ "../output/cached_sft_20252502",
17
+ "--audio_encoder_path",
18
+ "/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6",
19
+ "--llm_type",
20
+ "qwen",
21
+ "--data_path",
22
+ "/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/train_asr_mixed_500k.jsonl",
23
+ "--eval_data_path",
24
+ "/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/dev_asr_mixed.jsonl",
25
+ "--config_path",
26
+ "minicpmp_config.json",
27
+ "--remove_unused_columns",
28
+ "false",
29
+ "--prediction_loss_only",
30
+ "false",
31
+ "--bf16",
32
+ "true",
33
+ "--do_train",
34
+ "--do_eval",
35
+ "--tune_speech",
36
+ "false",
37
+ "--tune_llm",
38
+ "false",
39
+ "--model_max_length",
40
+ "2048",
41
+ "--eval_steps",
42
+ "3000",
43
+ "--output_dir",
44
+ "../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector",
45
+ "--num_train_epochs",
46
+ "3",
47
+ "--logging_strategy",
48
+ "steps",
49
+ "--per_device_train_batch_size",
50
+ "8",
51
+ "--per_device_eval_batch_size",
52
+ "8",
53
+ "--gradient_accumulation_steps",
54
+ "4",
55
+ "--evaluation_strategy",
56
+ "steps",
57
+ "--save_strategy",
58
+ "steps",
59
+ "--save_steps",
60
+ "5000",
61
+ "--save_total_limit",
62
+ "1",
63
+ "--learning_rate",
64
+ "5e-5",
65
+ "--weight_decay",
66
+ "0.",
67
+ "--warmup_ratio",
68
+ "0.03",
69
+ "--lr_scheduler_type",
70
+ "cosine",
71
+ "--logging_steps",
72
+ "1",
73
+ "--tf32",
74
+ "true",
75
+ "--gradient_checkpointing",
76
+ "true"
77
+ ],
78
+ "program": "/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/../omni_speech/train/train_minicpmo_test.py",
79
+ "codePath": "omni_speech/train/train_minicpmo_test.py",
80
+ "git": {
81
+ "remote": "https://bitbucket.org/vinbdi-slp/half-streaming-speech-nlp.git",
82
+ "commit": "3876ef3c080c3ca44ad5ea0bd316241f0323ada6"
83
+ },
84
+ "email": "cuong220103@gmail.com",
85
+ "root": "/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts",
86
+ "host": "dgx-a100-5",
87
+ "executable": "/opt/conda/bin/python3",
88
+ "cpu_count": 128,
89
+ "cpu_count_logical": 256,
90
+ "gpu": "NVIDIA A100-SXM4-40GB",
91
+ "gpu_count": 1,
92
+ "disk": {
93
+ "/": {
94
+ "total": "1900954378240",
95
+ "used": "286067507200"
96
+ }
97
+ },
98
+ "memory": {
99
+ "total": "1081975545856"
100
+ },
101
+ "cpu": {
102
+ "count": 128,
103
+ "countLogical": 256
104
+ },
105
+ "gpu_nvidia": [
106
+ {
107
+ "name": "NVIDIA A100-SXM4-40GB",
108
+ "memoryTotal": "42949672960",
109
+ "cudaCores": 6912,
110
+ "architecture": "Ampere"
111
+ }
112
+ ],
113
+ "slurm": {
114
+ "cluster_name": "slurm",
115
+ "conf": "/cm/shared/apps/slurm/var/etc/slurm/slurm.conf",
116
+ "cpus_on_node": "24",
117
+ "cpus_per_task": "24",
118
+ "gpus_on_node": "1",
119
+ "gpus_per_node": "1",
120
+ "gtids": "0",
121
+ "job_cpus_per_node": "24",
122
+ "job_end_time": "1775042326",
123
+ "job_gid": "1400",
124
+ "job_group": "speech",
125
+ "job_id": "5154",
126
+ "job_name": "bash",
127
+ "job_nodelist": "dgx-a100-5",
128
+ "job_num_nodes": "1",
129
+ "job_partition": "defq",
130
+ "job_qos": "normal",
131
+ "job_start_time": "1743506326",
132
+ "job_uid": "1407",
133
+ "job_user": "anhnmt2",
134
+ "jobid": "5154",
135
+ "launch_node_ipaddr": "192.168.100.102",
136
+ "localid": "0",
137
+ "mpi_type": "pmix",
138
+ "nnodes": "1",
139
+ "nodeid": "0",
140
+ "nodelist": "dgx-a100-5",
141
+ "nprocs": "1",
142
+ "ntasks": "1",
143
+ "ntasks_per_node": "1",
144
+ "pmix_mapping_serv": "(vector,(0,1,1))",
145
+ "pmixp_abort_agent_port": "37119",
146
+ "prio_process": "0",
147
+ "procid": "0",
148
+ "pty_port": "45373",
149
+ "pty_win_col": "137",
150
+ "pty_win_row": "10",
151
+ "srun_comm_host": "192.168.100.102",
152
+ "srun_comm_port": "43475",
153
+ "step_gpus": "4",
154
+ "step_id": "0",
155
+ "step_launcher_port": "43475",
156
+ "step_nodelist": "dgx-a100-5",
157
+ "step_num_nodes": "1",
158
+ "step_num_tasks": "1",
159
+ "step_tasks_per_node": "1",
160
+ "stepid": "0",
161
+ "submit_dir": "/data1/speech/anhnmt2/ASR/speechgpt/slurm/submit",
162
+ "submit_host": "login-1",
163
+ "task_pid": "268175",
164
+ "tasks_per_node": "1",
165
+ "topology_addr": "dgx-a100-5",
166
+ "topology_addr_pattern": "node",
167
+ "umask": "0022",
168
+ "working_cluster": "slurm:bcm10-headnode:6817:9984:109"
169
+ },
170
+ "cudaVersion": "12.2"
171
+ }
scripts/wandb/latest-run/logs/debug-core.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2025-04-10T17:19:28.173097267+07:00","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpm4_vxj8m/port-1734298.txt","pid":1734298,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false}
2
+ {"time":"2025-04-10T17:19:28.173483898+07:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":44091,"Zone":""}}
3
+ {"time":"2025-04-10T17:19:28.173583196+07:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":1734298}
4
+ {"time":"2025-04-10T17:19:28.338675346+07:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:60304"}
5
+ {"time":"2025-04-10T17:19:28.838813222+07:00","level":"INFO","msg":"handleInformInit: received","streamId":"pfaibe0c","id":"127.0.0.1:60304"}
6
+ {"time":"2025-04-10T17:19:28.960357084+07:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"pfaibe0c","id":"127.0.0.1:60304"}
7
+ {"time":"2025-04-10T17:20:36.908864225+07:00","level":"INFO","msg":"received shutdown signal","signal":15}
scripts/wandb/latest-run/logs/debug-internal.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2025-04-10T17:19:28.842729448+07:00","level":"INFO","msg":"stream: starting","core version":"0.19.8","symlink path":"/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/run-20250410_171928-pfaibe0c/logs/debug-core.log"}
2
+ {"time":"2025-04-10T17:19:28.960322418+07:00","level":"INFO","msg":"created new stream","id":"pfaibe0c"}
3
+ {"time":"2025-04-10T17:19:28.960351593+07:00","level":"INFO","msg":"stream: started","id":"pfaibe0c"}
4
+ {"time":"2025-04-10T17:19:28.960375959+07:00","level":"INFO","msg":"writer: Do: started","stream_id":"pfaibe0c"}
5
+ {"time":"2025-04-10T17:19:28.960456552+07:00","level":"INFO","msg":"handler: started","stream_id":"pfaibe0c"}
6
+ {"time":"2025-04-10T17:19:28.961574927+07:00","level":"INFO","msg":"sender: started","stream_id":"pfaibe0c"}
7
+ {"time":"2025-04-10T17:19:29.497777718+07:00","level":"INFO","msg":"Starting system monitor"}