Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-312.pyc +0 -0
- omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-38.pyc +0 -0
- omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-39.pyc +0 -0
- omni_speech/serve/__init__.py +0 -0
- omni_speech/serve/controller.py +298 -0
- omni_speech/serve/gradio_web_server.py +348 -0
- omni_speech/serve/model_worker.py +292 -0
- omni_speech/train/__pycache__/omni_trainer.cpython-310.pyc +0 -0
- omni_speech/train/__pycache__/omni_trainer.cpython-312.pyc +0 -0
- omni_speech/train/__pycache__/run_train.cpython-310.pyc +0 -0
- omni_speech/train/__pycache__/run_train.cpython-312.pyc +0 -0
- omni_speech/train/__pycache__/run_train.cpython-38.pyc +0 -0
- omni_speech/train/__pycache__/train.cpython-312.pyc +0 -0
- omni_speech/train/__pycache__/train_mem.cpython-312.pyc +0 -0
- omni_speech/train/__pycache__/train_multiturn.cpython-312.pyc +0 -0
- omni_speech/train/__pycache__/train_raw.cpython-312.pyc +0 -0
- omni_speech/train/__pycache__/train_test.cpython-312.pyc +0 -0
- omni_speech/train/__pycache__/trainer.cpython-310.pyc +0 -0
- omni_speech/train/__pycache__/trainer.cpython-312.pyc +0 -0
- omni_speech/train/export.py +512 -0
- omni_speech/train/omni_trainer.py +345 -0
- omni_speech/train/train.py +420 -0
- omni_speech/train/train_mem.py +4 -0
- omni_speech/train/train_minicpmo.py +660 -0
- omni_speech/train/train_minicpmo_test.py +729 -0
- omni_speech/train/train_multiturn.py +515 -0
- omni_speech/train/trainer.py +249 -0
- scripts/continue.sh +65 -0
- scripts/ds_config_zero2.json +54 -0
- scripts/ds_config_zero3.json +59 -0
- scripts/export.sh +39 -0
- scripts/finetune.sh +42 -0
- scripts/finetune_llm_speech_decoder.sh +85 -0
- scripts/finetune_lora.sh +43 -0
- scripts/finetune_minicpmo.sh +65 -0
- scripts/finetune_minicpmo_asr.sh +63 -0
- scripts/finetune_speech_decoder.sh +42 -0
- scripts/minicpmp_config.json +163 -0
- scripts/pretrain_minicpmo_test.sh +89 -0
- scripts/pretrained.sh +44 -0
- scripts/pretrained_minicpmo.sh +74 -0
- scripts/test_llama.sh +41 -0
- scripts/test_qwen.sh +41 -0
- scripts/wandb/debug-internal.log +7 -0
- scripts/wandb/debug.log +25 -0
- scripts/wandb/latest-run/files/output.log +559 -0
- scripts/wandb/latest-run/files/requirements.txt +341 -0
- scripts/wandb/latest-run/files/wandb-metadata.json +171 -0
- scripts/wandb/latest-run/logs/debug-core.log +7 -0
- scripts/wandb/latest-run/logs/debug-internal.log +7 -0
omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-312.pyc
ADDED
|
Binary file (2.07 kB). View file
|
|
|
omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-38.pyc
ADDED
|
Binary file (1.19 kB). View file
|
|
|
omni_speech/model/speech_projector/__pycache__/speech_projector.cpython-39.pyc
ADDED
|
Binary file (1.23 kB). View file
|
|
|
omni_speech/serve/__init__.py
ADDED
|
File without changes
|
omni_speech/serve/controller.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A controller manages distributed workers.
|
| 3 |
+
It sends worker addresses to clients.
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
import asyncio
|
| 7 |
+
import dataclasses
|
| 8 |
+
from enum import Enum, auto
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import time
|
| 12 |
+
from typing import List, Union
|
| 13 |
+
import threading
|
| 14 |
+
|
| 15 |
+
from fastapi import FastAPI, Request
|
| 16 |
+
from fastapi.responses import StreamingResponse
|
| 17 |
+
import numpy as np
|
| 18 |
+
import requests
|
| 19 |
+
import uvicorn
|
| 20 |
+
|
| 21 |
+
from omni_speech.constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
| 22 |
+
from omni_speech.utils import build_logger, server_error_msg
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = build_logger("controller", "controller.log")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DispatchMethod(Enum):
|
| 29 |
+
LOTTERY = auto()
|
| 30 |
+
SHORTEST_QUEUE = auto()
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def from_str(cls, name):
|
| 34 |
+
if name == "lottery":
|
| 35 |
+
return cls.LOTTERY
|
| 36 |
+
elif name == "shortest_queue":
|
| 37 |
+
return cls.SHORTEST_QUEUE
|
| 38 |
+
else:
|
| 39 |
+
raise ValueError(f"Invalid dispatch method")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclasses.dataclass
|
| 43 |
+
class WorkerInfo:
|
| 44 |
+
model_names: List[str]
|
| 45 |
+
speed: int
|
| 46 |
+
queue_length: int
|
| 47 |
+
check_heart_beat: bool
|
| 48 |
+
last_heart_beat: str
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def heart_beat_controller(controller):
|
| 52 |
+
while True:
|
| 53 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
| 54 |
+
controller.remove_stable_workers_by_expiration()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Controller:
|
| 58 |
+
def __init__(self, dispatch_method: str):
|
| 59 |
+
# Dict[str -> WorkerInfo]
|
| 60 |
+
self.worker_info = {}
|
| 61 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
| 62 |
+
|
| 63 |
+
self.heart_beat_thread = threading.Thread(
|
| 64 |
+
target=heart_beat_controller, args=(self,), daemon=True)
|
| 65 |
+
self.heart_beat_thread.start()
|
| 66 |
+
|
| 67 |
+
logger.info("Init controller")
|
| 68 |
+
|
| 69 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
| 70 |
+
worker_status: dict):
|
| 71 |
+
if worker_name not in self.worker_info:
|
| 72 |
+
logger.info(f"Register a new worker: {worker_name}")
|
| 73 |
+
else:
|
| 74 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
| 75 |
+
|
| 76 |
+
if not worker_status:
|
| 77 |
+
worker_status = self.get_worker_status(worker_name)
|
| 78 |
+
if not worker_status:
|
| 79 |
+
return False
|
| 80 |
+
|
| 81 |
+
self.worker_info[worker_name] = WorkerInfo(
|
| 82 |
+
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
|
| 83 |
+
check_heart_beat, time.time())
|
| 84 |
+
|
| 85 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
| 86 |
+
return True
|
| 87 |
+
|
| 88 |
+
def get_worker_status(self, worker_name: str):
|
| 89 |
+
try:
|
| 90 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
| 91 |
+
except requests.exceptions.RequestException as e:
|
| 92 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
if r.status_code != 200:
|
| 96 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
return r.json()
|
| 100 |
+
|
| 101 |
+
def remove_worker(self, worker_name: str):
|
| 102 |
+
del self.worker_info[worker_name]
|
| 103 |
+
|
| 104 |
+
def refresh_all_workers(self):
|
| 105 |
+
old_info = dict(self.worker_info)
|
| 106 |
+
self.worker_info = {}
|
| 107 |
+
|
| 108 |
+
for w_name, w_info in old_info.items():
|
| 109 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
| 110 |
+
logger.info(f"Remove stale worker: {w_name}")
|
| 111 |
+
|
| 112 |
+
def list_models(self):
|
| 113 |
+
model_names = set()
|
| 114 |
+
|
| 115 |
+
for w_name, w_info in self.worker_info.items():
|
| 116 |
+
model_names.update(w_info.model_names)
|
| 117 |
+
|
| 118 |
+
return list(model_names)
|
| 119 |
+
|
| 120 |
+
def get_worker_address(self, model_name: str):
|
| 121 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
| 122 |
+
worker_names = []
|
| 123 |
+
worker_speeds = []
|
| 124 |
+
for w_name, w_info in self.worker_info.items():
|
| 125 |
+
if model_name in w_info.model_names:
|
| 126 |
+
worker_names.append(w_name)
|
| 127 |
+
worker_speeds.append(w_info.speed)
|
| 128 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
| 129 |
+
norm = np.sum(worker_speeds)
|
| 130 |
+
if norm < 1e-4:
|
| 131 |
+
return ""
|
| 132 |
+
worker_speeds = worker_speeds / norm
|
| 133 |
+
if True: # Directly return address
|
| 134 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
| 135 |
+
p=worker_speeds)
|
| 136 |
+
worker_name = worker_names[pt]
|
| 137 |
+
return worker_name
|
| 138 |
+
|
| 139 |
+
# Check status before returning
|
| 140 |
+
while True:
|
| 141 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
| 142 |
+
p=worker_speeds)
|
| 143 |
+
worker_name = worker_names[pt]
|
| 144 |
+
|
| 145 |
+
if self.get_worker_status(worker_name):
|
| 146 |
+
break
|
| 147 |
+
else:
|
| 148 |
+
self.remove_worker(worker_name)
|
| 149 |
+
worker_speeds[pt] = 0
|
| 150 |
+
norm = np.sum(worker_speeds)
|
| 151 |
+
if norm < 1e-4:
|
| 152 |
+
return ""
|
| 153 |
+
worker_speeds = worker_speeds / norm
|
| 154 |
+
continue
|
| 155 |
+
return worker_name
|
| 156 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
| 157 |
+
worker_names = []
|
| 158 |
+
worker_qlen = []
|
| 159 |
+
for w_name, w_info in self.worker_info.items():
|
| 160 |
+
if model_name in w_info.model_names:
|
| 161 |
+
worker_names.append(w_name)
|
| 162 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
| 163 |
+
if len(worker_names) == 0:
|
| 164 |
+
return ""
|
| 165 |
+
min_index = np.argmin(worker_qlen)
|
| 166 |
+
w_name = worker_names[min_index]
|
| 167 |
+
self.worker_info[w_name].queue_length += 1
|
| 168 |
+
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
| 169 |
+
return w_name
|
| 170 |
+
else:
|
| 171 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
| 172 |
+
|
| 173 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
| 174 |
+
if worker_name not in self.worker_info:
|
| 175 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
| 176 |
+
return False
|
| 177 |
+
|
| 178 |
+
self.worker_info[worker_name].queue_length = queue_length
|
| 179 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
| 180 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
| 181 |
+
return True
|
| 182 |
+
|
| 183 |
+
def remove_stable_workers_by_expiration(self):
|
| 184 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
| 185 |
+
to_delete = []
|
| 186 |
+
for worker_name, w_info in self.worker_info.items():
|
| 187 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
| 188 |
+
to_delete.append(worker_name)
|
| 189 |
+
|
| 190 |
+
for worker_name in to_delete:
|
| 191 |
+
self.remove_worker(worker_name)
|
| 192 |
+
|
| 193 |
+
def worker_api_generate_stream(self, params):
|
| 194 |
+
worker_addr = self.get_worker_address(params["model"])
|
| 195 |
+
if not worker_addr:
|
| 196 |
+
logger.info(f"no worker: {params['model']}")
|
| 197 |
+
ret = {
|
| 198 |
+
"text": server_error_msg,
|
| 199 |
+
"error_code": 2,
|
| 200 |
+
}
|
| 201 |
+
yield json.dumps(ret).encode() + b"\0"
|
| 202 |
+
|
| 203 |
+
try:
|
| 204 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
| 205 |
+
json=params, stream=True, timeout=5)
|
| 206 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
| 207 |
+
if chunk:
|
| 208 |
+
yield chunk + b"\0"
|
| 209 |
+
except requests.exceptions.RequestException as e:
|
| 210 |
+
logger.info(f"worker timeout: {worker_addr}")
|
| 211 |
+
ret = {
|
| 212 |
+
"text": server_error_msg,
|
| 213 |
+
"error_code": 3,
|
| 214 |
+
}
|
| 215 |
+
yield json.dumps(ret).encode() + b"\0"
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# Let the controller act as a worker to achieve hierarchical
|
| 219 |
+
# management. This can be used to connect isolated sub networks.
|
| 220 |
+
def worker_api_get_status(self):
|
| 221 |
+
model_names = set()
|
| 222 |
+
speed = 0
|
| 223 |
+
queue_length = 0
|
| 224 |
+
|
| 225 |
+
for w_name in self.worker_info:
|
| 226 |
+
worker_status = self.get_worker_status(w_name)
|
| 227 |
+
if worker_status is not None:
|
| 228 |
+
model_names.update(worker_status["model_names"])
|
| 229 |
+
speed += worker_status["speed"]
|
| 230 |
+
queue_length += worker_status["queue_length"]
|
| 231 |
+
|
| 232 |
+
return {
|
| 233 |
+
"model_names": list(model_names),
|
| 234 |
+
"speed": speed,
|
| 235 |
+
"queue_length": queue_length,
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
app = FastAPI()
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@app.post("/register_worker")
|
| 243 |
+
async def register_worker(request: Request):
|
| 244 |
+
data = await request.json()
|
| 245 |
+
controller.register_worker(
|
| 246 |
+
data["worker_name"], data["check_heart_beat"],
|
| 247 |
+
data.get("worker_status", None))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@app.post("/refresh_all_workers")
|
| 251 |
+
async def refresh_all_workers():
|
| 252 |
+
models = controller.refresh_all_workers()
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@app.post("/list_models")
|
| 256 |
+
async def list_models():
|
| 257 |
+
models = controller.list_models()
|
| 258 |
+
return {"models": models}
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@app.post("/get_worker_address")
|
| 262 |
+
async def get_worker_address(request: Request):
|
| 263 |
+
data = await request.json()
|
| 264 |
+
addr = controller.get_worker_address(data["model"])
|
| 265 |
+
return {"address": addr}
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@app.post("/receive_heart_beat")
|
| 269 |
+
async def receive_heart_beat(request: Request):
|
| 270 |
+
data = await request.json()
|
| 271 |
+
exist = controller.receive_heart_beat(
|
| 272 |
+
data["worker_name"], data["queue_length"])
|
| 273 |
+
return {"exist": exist}
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@app.post("/worker_generate_stream")
|
| 277 |
+
async def worker_api_generate_stream(request: Request):
|
| 278 |
+
params = await request.json()
|
| 279 |
+
generator = controller.worker_api_generate_stream(params)
|
| 280 |
+
return StreamingResponse(generator)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@app.post("/worker_get_status")
|
| 284 |
+
async def worker_api_get_status(request: Request):
|
| 285 |
+
return controller.worker_api_get_status()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
if __name__ == "__main__":
|
| 289 |
+
parser = argparse.ArgumentParser()
|
| 290 |
+
parser.add_argument("--host", type=str, default="localhost")
|
| 291 |
+
parser.add_argument("--port", type=int, default=21001)
|
| 292 |
+
parser.add_argument("--dispatch-method", type=str, choices=[
|
| 293 |
+
"lottery", "shortest_queue"], default="shortest_queue")
|
| 294 |
+
args = parser.parse_args()
|
| 295 |
+
logger.info(f"args: {args}")
|
| 296 |
+
|
| 297 |
+
controller = Controller(args.dispatch_method)
|
| 298 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
omni_speech/serve/gradio_web_server.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import numpy as np
|
| 11 |
+
import requests
|
| 12 |
+
import soundfile as sf
|
| 13 |
+
|
| 14 |
+
from omni_speech.conversation import default_conversation, conv_templates
|
| 15 |
+
from omni_speech.constants import LOGDIR
|
| 16 |
+
from omni_speech.utils import build_logger, server_error_msg
|
| 17 |
+
from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
| 21 |
+
|
| 22 |
+
vocoder = None
|
| 23 |
+
|
| 24 |
+
headers = {"User-Agent": "LLaMA-Omni Client"}
|
| 25 |
+
|
| 26 |
+
no_change_btn = gr.Button()
|
| 27 |
+
enable_btn = gr.Button(interactive=True)
|
| 28 |
+
disable_btn = gr.Button(interactive=False)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_conv_log_filename():
|
| 32 |
+
t = datetime.datetime.now()
|
| 33 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
| 34 |
+
return name
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_model_list():
|
| 38 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
| 39 |
+
assert ret.status_code == 200
|
| 40 |
+
ret = requests.post(args.controller_url + "/list_models")
|
| 41 |
+
models = ret.json()["models"]
|
| 42 |
+
logger.info(f"Models: {models}")
|
| 43 |
+
return models
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
get_window_url_params = """
|
| 47 |
+
function() {
|
| 48 |
+
const params = new URLSearchParams(window.location.search);
|
| 49 |
+
url_params = Object.fromEntries(params);
|
| 50 |
+
console.log(url_params);
|
| 51 |
+
return url_params;
|
| 52 |
+
}
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_demo(url_params, request: gr.Request):
|
| 57 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
| 58 |
+
|
| 59 |
+
dropdown_update = gr.Dropdown(visible=True)
|
| 60 |
+
if "model" in url_params:
|
| 61 |
+
model = url_params["model"]
|
| 62 |
+
if model in models:
|
| 63 |
+
dropdown_update = gr.Dropdown(value=model, visible=True)
|
| 64 |
+
|
| 65 |
+
state = default_conversation.copy()
|
| 66 |
+
return state, dropdown_update
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
| 70 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
| 71 |
+
models = get_model_list()
|
| 72 |
+
state = default_conversation.copy()
|
| 73 |
+
dropdown_update = gr.Dropdown(
|
| 74 |
+
choices=models,
|
| 75 |
+
value=models[0] if len(models) > 0 else ""
|
| 76 |
+
)
|
| 77 |
+
return state, dropdown_update
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def clear_history(request: gr.Request):
|
| 81 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
| 82 |
+
state = default_conversation.copy()
|
| 83 |
+
return (state, None, "", "", None)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def add_speech(state, speech, request: gr.Request):
|
| 87 |
+
text = "Please directly answer the questions in the user's speech."
|
| 88 |
+
text = '<speech>\n' + text
|
| 89 |
+
text = (text, speech)
|
| 90 |
+
state = default_conversation.copy()
|
| 91 |
+
state.append_message(state.roles[0], text)
|
| 92 |
+
state.append_message(state.roles[1], None)
|
| 93 |
+
state.skip_next = False
|
| 94 |
+
return (state)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, chunk_size, request: gr.Request):
|
| 98 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
| 99 |
+
start_tstamp = time.time()
|
| 100 |
+
model_name = model_selector
|
| 101 |
+
|
| 102 |
+
if state.skip_next:
|
| 103 |
+
# This generate call is skipped due to invalid inputs
|
| 104 |
+
yield (state, "", "", None)
|
| 105 |
+
return
|
| 106 |
+
|
| 107 |
+
if len(state.messages) == state.offset + 2:
|
| 108 |
+
# First round of conversation
|
| 109 |
+
template_name = "llama_3"
|
| 110 |
+
new_state = conv_templates[template_name].copy()
|
| 111 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
| 112 |
+
new_state.append_message(new_state.roles[1], None)
|
| 113 |
+
state = new_state
|
| 114 |
+
|
| 115 |
+
# Query worker address
|
| 116 |
+
controller_url = args.controller_url
|
| 117 |
+
ret = requests.post(controller_url + "/get_worker_address",
|
| 118 |
+
json={"model": model_name})
|
| 119 |
+
worker_addr = ret.json()["address"]
|
| 120 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
| 121 |
+
|
| 122 |
+
# No available worker
|
| 123 |
+
if worker_addr == "":
|
| 124 |
+
state.messages[-1][-1] = server_error_msg
|
| 125 |
+
yield (state, "", "", None)
|
| 126 |
+
return
|
| 127 |
+
|
| 128 |
+
# Construct prompt
|
| 129 |
+
prompt = state.get_prompt()
|
| 130 |
+
|
| 131 |
+
sr, audio = state.messages[0][1][1]
|
| 132 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
|
| 133 |
+
audio = torch.tensor(audio.astype(np.float32)).unsqueeze(0)
|
| 134 |
+
audio = resampler(audio).squeeze(0).numpy()
|
| 135 |
+
audio /= 32768.0
|
| 136 |
+
audio = audio.tolist()
|
| 137 |
+
# Make requests
|
| 138 |
+
pload = {
|
| 139 |
+
"model": model_name,
|
| 140 |
+
"prompt": prompt,
|
| 141 |
+
"temperature": float(temperature),
|
| 142 |
+
"top_p": float(top_p),
|
| 143 |
+
"max_new_tokens": min(int(max_new_tokens), 1500),
|
| 144 |
+
"stop": state.sep2,
|
| 145 |
+
"audio": audio,
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
yield (state, "", "", None)
|
| 149 |
+
|
| 150 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 151 |
+
|
| 152 |
+
try:
|
| 153 |
+
# Stream output
|
| 154 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
| 155 |
+
headers=headers, json=pload, stream=True, timeout=10)
|
| 156 |
+
num_generated_units = 0
|
| 157 |
+
wav_list = []
|
| 158 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
| 159 |
+
if chunk:
|
| 160 |
+
data = json.loads(chunk.decode())
|
| 161 |
+
if data["error_code"] == 0:
|
| 162 |
+
output = data["text"][len(prompt):].strip()
|
| 163 |
+
output_unit = list(map(int, data["unit"].strip().split()))
|
| 164 |
+
state.messages[-1][-1] = (output, data["unit"].strip())
|
| 165 |
+
|
| 166 |
+
# vocoder
|
| 167 |
+
new_units = output_unit[num_generated_units:]
|
| 168 |
+
if len(new_units) >= chunk_size:
|
| 169 |
+
num_generated_units = len(output_unit)
|
| 170 |
+
x = {"code": torch.LongTensor(new_units).view(1, -1).cuda()}
|
| 171 |
+
wav = vocoder(x, True)
|
| 172 |
+
wav_list.append(wav.detach().cpu().numpy())
|
| 173 |
+
|
| 174 |
+
if len(wav_list) > 0:
|
| 175 |
+
wav_full = np.concatenate(wav_list)
|
| 176 |
+
return_value = (16000, wav_full)
|
| 177 |
+
else:
|
| 178 |
+
return_value = None
|
| 179 |
+
|
| 180 |
+
yield (state, state.messages[-1][-1][0], state.messages[-1][-1][1], return_value)
|
| 181 |
+
else:
|
| 182 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
| 183 |
+
state.messages[-1][-1] = output
|
| 184 |
+
yield (state, "", "", None)
|
| 185 |
+
return
|
| 186 |
+
time.sleep(0.03)
|
| 187 |
+
except requests.exceptions.RequestException as e:
|
| 188 |
+
state.messages[-1][-1] = server_error_msg
|
| 189 |
+
yield (state, "", "", None)
|
| 190 |
+
return
|
| 191 |
+
|
| 192 |
+
if num_generated_units < len(output_unit):
|
| 193 |
+
new_units = output_unit[num_generated_units:]
|
| 194 |
+
num_generated_units = len(output_unit)
|
| 195 |
+
x = {
|
| 196 |
+
"code": torch.LongTensor(new_units).view(1, -1).cuda()
|
| 197 |
+
}
|
| 198 |
+
wav = vocoder(x, True)
|
| 199 |
+
wav_list.append(wav.detach().cpu().numpy())
|
| 200 |
+
|
| 201 |
+
if len(wav_list) > 0:
|
| 202 |
+
wav_full = np.concatenate(wav_list)
|
| 203 |
+
return_value = (16000, wav_full)
|
| 204 |
+
else:
|
| 205 |
+
return_value = None
|
| 206 |
+
|
| 207 |
+
yield (state, state.messages[-1][-1][0], state.messages[-1][-1][1], return_value)
|
| 208 |
+
|
| 209 |
+
finish_tstamp = time.time()
|
| 210 |
+
logger.info(f"{output}")
|
| 211 |
+
logger.info(f"{output_unit}")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
title_markdown = ("""
|
| 215 |
+
# 🎧 LLaMA-Omni: Seamless Speech Interaction with Large Language Models
|
| 216 |
+
""")
|
| 217 |
+
|
| 218 |
+
block_css = """
|
| 219 |
+
|
| 220 |
+
#buttons button {
|
| 221 |
+
min-width: min(120px,100%);
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
def build_demo(embed_mode, vocoder, cur_dir=None, concurrency_count=10):
|
| 227 |
+
with gr.Blocks(title="LLaMA-Omni Speech Chatbot", theme=gr.themes.Default(), css=block_css) as demo:
|
| 228 |
+
state = gr.State()
|
| 229 |
+
|
| 230 |
+
if not embed_mode:
|
| 231 |
+
gr.Markdown(title_markdown)
|
| 232 |
+
|
| 233 |
+
with gr.Row(elem_id="model_selector_row"):
|
| 234 |
+
model_selector = gr.Dropdown(
|
| 235 |
+
choices=models,
|
| 236 |
+
value=models[0] if len(models) > 0 else "",
|
| 237 |
+
interactive=True,
|
| 238 |
+
show_label=False,
|
| 239 |
+
container=False)
|
| 240 |
+
|
| 241 |
+
with gr.Row():
|
| 242 |
+
audio_input_box = gr.Audio(sources=["upload", "microphone"], label="Speech Input")
|
| 243 |
+
with gr.Accordion("Parameters", open=True) as parameter_row:
|
| 244 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature",)
|
| 245 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
| 246 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max Output Tokens",)
|
| 247 |
+
chunk_size = gr.Slider(minimum=10, maximum=500, value=40, step=10, interactive=True, label="Chunk Size",)
|
| 248 |
+
|
| 249 |
+
if cur_dir is None:
|
| 250 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 251 |
+
gr.Examples(examples=[
|
| 252 |
+
[f"{cur_dir}/examples/vicuna_1.wav"],
|
| 253 |
+
[f"{cur_dir}/examples/vicuna_2.wav"],
|
| 254 |
+
[f"{cur_dir}/examples/vicuna_3.wav"],
|
| 255 |
+
[f"{cur_dir}/examples/vicuna_4.wav"],
|
| 256 |
+
[f"{cur_dir}/examples/vicuna_5.wav"],
|
| 257 |
+
[f"{cur_dir}/examples/helpful_base_1.wav"],
|
| 258 |
+
[f"{cur_dir}/examples/helpful_base_2.wav"],
|
| 259 |
+
[f"{cur_dir}/examples/helpful_base_3.wav"],
|
| 260 |
+
[f"{cur_dir}/examples/helpful_base_4.wav"],
|
| 261 |
+
[f"{cur_dir}/examples/helpful_base_5.wav"],
|
| 262 |
+
], inputs=[audio_input_box])
|
| 263 |
+
|
| 264 |
+
with gr.Row():
|
| 265 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
| 266 |
+
clear_btn = gr.Button(value="Clear")
|
| 267 |
+
|
| 268 |
+
text_output_box = gr.Textbox(label="Text Output", type="text")
|
| 269 |
+
unit_output_box = gr.Textbox(label="Unit Output", type="text")
|
| 270 |
+
audio_output_box = gr.Audio(label="Speech Output")
|
| 271 |
+
|
| 272 |
+
url_params = gr.JSON(visible=False)
|
| 273 |
+
|
| 274 |
+
submit_btn.click(
|
| 275 |
+
add_speech,
|
| 276 |
+
[state, audio_input_box],
|
| 277 |
+
[state]
|
| 278 |
+
).then(
|
| 279 |
+
http_bot,
|
| 280 |
+
[state, model_selector, temperature, top_p, max_output_tokens, chunk_size],
|
| 281 |
+
[state, text_output_box, unit_output_box, audio_output_box],
|
| 282 |
+
concurrency_limit=concurrency_count
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
clear_btn.click(
|
| 286 |
+
clear_history,
|
| 287 |
+
None,
|
| 288 |
+
[state, audio_input_box, text_output_box, unit_output_box, audio_output_box],
|
| 289 |
+
queue=False
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
if args.model_list_mode == "once":
|
| 293 |
+
demo.load(
|
| 294 |
+
load_demo,
|
| 295 |
+
[url_params],
|
| 296 |
+
[state, model_selector],
|
| 297 |
+
js=get_window_url_params
|
| 298 |
+
)
|
| 299 |
+
elif args.model_list_mode == "reload":
|
| 300 |
+
demo.load(
|
| 301 |
+
load_demo_refresh_model_list,
|
| 302 |
+
None,
|
| 303 |
+
[state, model_selector],
|
| 304 |
+
queue=False
|
| 305 |
+
)
|
| 306 |
+
else:
|
| 307 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
| 308 |
+
|
| 309 |
+
return demo
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def build_vocoder(args):
|
| 313 |
+
global vocoder
|
| 314 |
+
if args.vocoder is None:
|
| 315 |
+
return None
|
| 316 |
+
with open(args.vocoder_cfg) as f:
|
| 317 |
+
vocoder_cfg = json.load(f)
|
| 318 |
+
vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg).cuda()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == "__main__":
|
| 322 |
+
parser = argparse.ArgumentParser()
|
| 323 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
| 324 |
+
parser.add_argument("--port", type=int)
|
| 325 |
+
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
| 326 |
+
parser.add_argument("--concurrency-count", type=int, default=16)
|
| 327 |
+
parser.add_argument("--model-list-mode", type=str, default="once",
|
| 328 |
+
choices=["once", "reload"])
|
| 329 |
+
parser.add_argument("--share", action="store_true")
|
| 330 |
+
parser.add_argument("--moderate", action="store_true")
|
| 331 |
+
parser.add_argument("--embed", action="store_true")
|
| 332 |
+
parser.add_argument("--vocoder", type=str)
|
| 333 |
+
parser.add_argument("--vocoder-cfg", type=str)
|
| 334 |
+
args = parser.parse_args()
|
| 335 |
+
logger.info(f"args: {args}")
|
| 336 |
+
|
| 337 |
+
models = get_model_list()
|
| 338 |
+
build_vocoder(args)
|
| 339 |
+
|
| 340 |
+
logger.info(args)
|
| 341 |
+
demo = build_demo(args.embed, vocoder, concurrency_count=args.concurrency_count)
|
| 342 |
+
demo.queue(
|
| 343 |
+
api_open=False
|
| 344 |
+
).launch(
|
| 345 |
+
server_name=args.host,
|
| 346 |
+
server_port=args.port,
|
| 347 |
+
share=args.share
|
| 348 |
+
)
|
omni_speech/serve/model_worker.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A model worker executes the model.
|
| 3 |
+
"""
|
| 4 |
+
import argparse
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
import time
|
| 8 |
+
import threading
|
| 9 |
+
import uuid
|
| 10 |
+
|
| 11 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
| 12 |
+
from fastapi.responses import StreamingResponse
|
| 13 |
+
import requests
|
| 14 |
+
import torch
|
| 15 |
+
import uvicorn
|
| 16 |
+
import whisper
|
| 17 |
+
import numpy as np
|
| 18 |
+
from functools import partial
|
| 19 |
+
|
| 20 |
+
from transformers import PreTrainedTokenizer
|
| 21 |
+
|
| 22 |
+
from omni_speech.constants import WORKER_HEART_BEAT_INTERVAL
|
| 23 |
+
from omni_speech.utils import (build_logger, server_error_msg,
|
| 24 |
+
pretty_print_semaphore)
|
| 25 |
+
from omni_speech.model.builder import load_pretrained_model
|
| 26 |
+
from omni_speech.constants import SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
|
| 27 |
+
from omni_speech.datasets.preprocess import tokenizer_speech_token
|
| 28 |
+
from transformers import TextIteratorStreamer
|
| 29 |
+
from threading import Thread
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
GB = 1 << 30
|
| 33 |
+
|
| 34 |
+
worker_id = str(uuid.uuid4())[:6]
|
| 35 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
| 36 |
+
global_counter = 0
|
| 37 |
+
|
| 38 |
+
model_semaphore = None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def heart_beat_worker(controller):
|
| 42 |
+
|
| 43 |
+
while True:
|
| 44 |
+
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
| 45 |
+
controller.send_heart_beat()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_speech(audio, input_type, mel_size, speech_normalize):
|
| 49 |
+
speech = np.array(audio, dtype=np.float32)
|
| 50 |
+
if input_type == "raw":
|
| 51 |
+
speech = torch.from_numpy(speech)
|
| 52 |
+
if speech_normalize:
|
| 53 |
+
speech = torch.nn.functional.layer_norm(speech, speech.shape)
|
| 54 |
+
elif input_type == "mel":
|
| 55 |
+
speech = whisper.pad_or_trim(speech)
|
| 56 |
+
speech = whisper.log_mel_spectrogram(speech, n_mels=mel_size).permute(1, 0)
|
| 57 |
+
return speech
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def build_unit_tokenizer(vocab_size):
|
| 61 |
+
import os
|
| 62 |
+
from transformers import BertTokenizer
|
| 63 |
+
with open("unit_vocab.txt", "w") as f:
|
| 64 |
+
for i in range(vocab_size + 1):
|
| 65 |
+
f.write(str(i) + "\n")
|
| 66 |
+
tokenizer = BertTokenizer(vocab_file="unit_vocab.txt")
|
| 67 |
+
os.remove("unit_vocab.txt")
|
| 68 |
+
return tokenizer
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ModelWorker:
|
| 72 |
+
def __init__(self, controller_addr, worker_addr,
|
| 73 |
+
worker_id, no_register,
|
| 74 |
+
model_path, model_base, model_name,
|
| 75 |
+
load_8bit, load_4bit, device, input_type, mel_size, s2s, is_lora, use_flash_attn=False):
|
| 76 |
+
self.controller_addr = controller_addr
|
| 77 |
+
self.worker_addr = worker_addr
|
| 78 |
+
self.worker_id = worker_id
|
| 79 |
+
self.device = device
|
| 80 |
+
self.model_name = model_name
|
| 81 |
+
self.input_type = input_type
|
| 82 |
+
self.mel_size = mel_size
|
| 83 |
+
self.tokenizer, self.model, self.context_len = load_pretrained_model(
|
| 84 |
+
model_path, model_base, is_lora=is_lora, s2s=s2s, load_8bit=load_8bit, load_4bit=load_4bit, device=self.device, use_flash_attn=use_flash_attn)
|
| 85 |
+
self.unit_tokenizer = build_unit_tokenizer(self.model.config.unit_vocab_size)
|
| 86 |
+
|
| 87 |
+
if not no_register:
|
| 88 |
+
self.register_to_controller()
|
| 89 |
+
self.heart_beat_thread = threading.Thread(
|
| 90 |
+
target=heart_beat_worker, args=(self,), daemon=True)
|
| 91 |
+
self.heart_beat_thread.start()
|
| 92 |
+
|
| 93 |
+
def register_to_controller(self):
|
| 94 |
+
logger.info("Register to controller")
|
| 95 |
+
|
| 96 |
+
url = self.controller_addr + "/register_worker"
|
| 97 |
+
data = {
|
| 98 |
+
"worker_name": self.worker_addr,
|
| 99 |
+
"check_heart_beat": True,
|
| 100 |
+
"worker_status": self.get_status()
|
| 101 |
+
}
|
| 102 |
+
r = requests.post(url, json=data)
|
| 103 |
+
assert r.status_code == 200
|
| 104 |
+
|
| 105 |
+
def send_heart_beat(self):
|
| 106 |
+
logger.info(f"Send heart beat. Models: {[self.model_name]}. "
|
| 107 |
+
f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
|
| 108 |
+
f"global_counter: {global_counter}")
|
| 109 |
+
|
| 110 |
+
url = self.controller_addr + "/receive_heart_beat"
|
| 111 |
+
|
| 112 |
+
while True:
|
| 113 |
+
try:
|
| 114 |
+
ret = requests.post(url, json={
|
| 115 |
+
"worker_name": self.worker_addr,
|
| 116 |
+
"queue_length": self.get_queue_length()}, timeout=5)
|
| 117 |
+
exist = ret.json()["exist"]
|
| 118 |
+
break
|
| 119 |
+
except requests.exceptions.RequestException as e:
|
| 120 |
+
logger.error(f"heart beat error: {e}")
|
| 121 |
+
time.sleep(5)
|
| 122 |
+
|
| 123 |
+
if not exist:
|
| 124 |
+
self.register_to_controller()
|
| 125 |
+
|
| 126 |
+
def get_queue_length(self):
|
| 127 |
+
if model_semaphore is None:
|
| 128 |
+
return 0
|
| 129 |
+
else:
|
| 130 |
+
return args.limit_model_concurrency - model_semaphore._value + (len(
|
| 131 |
+
model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
|
| 132 |
+
|
| 133 |
+
def get_status(self):
|
| 134 |
+
return {
|
| 135 |
+
"model_names": [self.model_name],
|
| 136 |
+
"speed": 1,
|
| 137 |
+
"queue_length": self.get_queue_length(),
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
@torch.inference_mode()
|
| 141 |
+
def generate_stream(self, params):
|
| 142 |
+
tokenizer, model = self.tokenizer, self.model
|
| 143 |
+
|
| 144 |
+
prompt = params["prompt"]
|
| 145 |
+
ori_prompt = prompt
|
| 146 |
+
audio = params.get("audio", None)
|
| 147 |
+
if audio is not None and len(audio) > 0:
|
| 148 |
+
speech = load_speech(audio, self.input_type, self.mel_size, self.model.config.speech_normalize)
|
| 149 |
+
speech_length = torch.LongTensor([speech.shape[0]]).unsqueeze(0).to(self.device)
|
| 150 |
+
speech_tensor = speech.unsqueeze(0).to(self.device, dtype=torch.float16)
|
| 151 |
+
speech_args = {"speech": speech_tensor, "speech_lengths": speech_length}
|
| 152 |
+
else:
|
| 153 |
+
speech = None
|
| 154 |
+
speech_args = {}
|
| 155 |
+
|
| 156 |
+
temperature = float(params.get("temperature", 1.0))
|
| 157 |
+
top_p = float(params.get("top_p", 1.0))
|
| 158 |
+
max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
|
| 159 |
+
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
|
| 160 |
+
stop_str = params.get("stop", None)
|
| 161 |
+
do_sample = True if temperature > 0.001 else False
|
| 162 |
+
|
| 163 |
+
input_ids = tokenizer_speech_token(prompt, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device)
|
| 164 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
|
| 165 |
+
streamer_unit = TextIteratorStreamer(self.unit_tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15)
|
| 166 |
+
|
| 167 |
+
# max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
|
| 168 |
+
|
| 169 |
+
if max_new_tokens < 1:
|
| 170 |
+
yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
|
| 171 |
+
return
|
| 172 |
+
|
| 173 |
+
thread = Thread(target=model.generate, kwargs=dict(
|
| 174 |
+
inputs=input_ids,
|
| 175 |
+
do_sample=do_sample,
|
| 176 |
+
temperature=temperature,
|
| 177 |
+
top_p=top_p,
|
| 178 |
+
max_new_tokens=max_new_tokens,
|
| 179 |
+
streamer=streamer,
|
| 180 |
+
streamer_unit=streamer_unit,
|
| 181 |
+
streaming_unit_gen=True,
|
| 182 |
+
use_cache=True,
|
| 183 |
+
**speech_args
|
| 184 |
+
))
|
| 185 |
+
thread.start()
|
| 186 |
+
|
| 187 |
+
generated_text = ori_prompt
|
| 188 |
+
for new_text in streamer:
|
| 189 |
+
generated_text += new_text
|
| 190 |
+
generated_unit = " ".join(map(str, streamer_unit.token_cache))
|
| 191 |
+
if generated_text.endswith(stop_str):
|
| 192 |
+
generated_text = generated_text[:-len(stop_str)]
|
| 193 |
+
yield json.dumps({"text": generated_text, "unit": generated_unit, "error_code": 0}).encode() + b"\0"
|
| 194 |
+
|
| 195 |
+
def generate_stream_gate(self, params):
|
| 196 |
+
try:
|
| 197 |
+
for x in self.generate_stream(params):
|
| 198 |
+
yield x
|
| 199 |
+
except ValueError as e:
|
| 200 |
+
print("Caught ValueError:", e)
|
| 201 |
+
ret = {
|
| 202 |
+
"text": server_error_msg,
|
| 203 |
+
"error_code": 1,
|
| 204 |
+
}
|
| 205 |
+
yield json.dumps(ret).encode() + b"\0"
|
| 206 |
+
except torch.cuda.CudaError as e:
|
| 207 |
+
print("Caught torch.cuda.CudaError:", e)
|
| 208 |
+
ret = {
|
| 209 |
+
"text": server_error_msg,
|
| 210 |
+
"error_code": 1,
|
| 211 |
+
}
|
| 212 |
+
yield json.dumps(ret).encode() + b"\0"
|
| 213 |
+
except Exception as e:
|
| 214 |
+
print("Caught Unknown Error", e)
|
| 215 |
+
ret = {
|
| 216 |
+
"text": server_error_msg,
|
| 217 |
+
"error_code": 1,
|
| 218 |
+
}
|
| 219 |
+
yield json.dumps(ret).encode() + b"\0"
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
app = FastAPI()
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def release_model_semaphore(fn=None):
|
| 226 |
+
model_semaphore.release()
|
| 227 |
+
if fn is not None:
|
| 228 |
+
fn()
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@app.post("/worker_generate_stream")
|
| 232 |
+
async def generate_stream(request: Request):
|
| 233 |
+
global model_semaphore, global_counter
|
| 234 |
+
global_counter += 1
|
| 235 |
+
params = await request.json()
|
| 236 |
+
|
| 237 |
+
if model_semaphore is None:
|
| 238 |
+
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
|
| 239 |
+
await model_semaphore.acquire()
|
| 240 |
+
worker.send_heart_beat()
|
| 241 |
+
generator = worker.generate_stream_gate(params)
|
| 242 |
+
background_tasks = BackgroundTasks()
|
| 243 |
+
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
|
| 244 |
+
return StreamingResponse(generator, background=background_tasks)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@app.post("/worker_get_status")
|
| 248 |
+
async def get_status(request: Request):
|
| 249 |
+
return worker.get_status()
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
if __name__ == "__main__":
|
| 253 |
+
parser = argparse.ArgumentParser()
|
| 254 |
+
parser.add_argument("--host", type=str, default="localhost")
|
| 255 |
+
parser.add_argument("--port", type=int, default=21002)
|
| 256 |
+
parser.add_argument("--worker-address", type=str,
|
| 257 |
+
default="http://localhost:21002")
|
| 258 |
+
parser.add_argument("--controller-address", type=str,
|
| 259 |
+
default="http://localhost:21001")
|
| 260 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
| 261 |
+
parser.add_argument("--model-base", type=str, default=None)
|
| 262 |
+
parser.add_argument("--model-name", type=str)
|
| 263 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 264 |
+
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
| 265 |
+
parser.add_argument("--stream-interval", type=int, default=1)
|
| 266 |
+
parser.add_argument("--no-register", action="store_true")
|
| 267 |
+
parser.add_argument("--load-8bit", action="store_true")
|
| 268 |
+
parser.add_argument("--load-4bit", action="store_true")
|
| 269 |
+
parser.add_argument("--use-flash-attn", action="store_true")
|
| 270 |
+
parser.add_argument("--input-type", type=str, default="mel")
|
| 271 |
+
parser.add_argument("--mel-size", type=int, default=128)
|
| 272 |
+
parser.add_argument("--s2s", action="store_true", default=False)
|
| 273 |
+
parser.add_argument("--is-lora", action="store_true", default=False)
|
| 274 |
+
args = parser.parse_args()
|
| 275 |
+
logger.info(f"args: {args}")
|
| 276 |
+
|
| 277 |
+
worker = ModelWorker(args.controller_address,
|
| 278 |
+
args.worker_address,
|
| 279 |
+
worker_id,
|
| 280 |
+
args.no_register,
|
| 281 |
+
args.model_path,
|
| 282 |
+
args.model_base,
|
| 283 |
+
args.model_name,
|
| 284 |
+
args.load_8bit,
|
| 285 |
+
args.load_4bit,
|
| 286 |
+
args.device,
|
| 287 |
+
args.input_type,
|
| 288 |
+
args.mel_size,
|
| 289 |
+
args.s2s,
|
| 290 |
+
args.is_lora,
|
| 291 |
+
use_flash_attn=args.use_flash_attn)
|
| 292 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
omni_speech/train/__pycache__/omni_trainer.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
omni_speech/train/__pycache__/omni_trainer.cpython-312.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
omni_speech/train/__pycache__/run_train.cpython-310.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
omni_speech/train/__pycache__/run_train.cpython-312.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
omni_speech/train/__pycache__/run_train.cpython-38.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
omni_speech/train/__pycache__/train.cpython-312.pyc
ADDED
|
Binary file (18.9 kB). View file
|
|
|
omni_speech/train/__pycache__/train_mem.cpython-312.pyc
ADDED
|
Binary file (348 Bytes). View file
|
|
|
omni_speech/train/__pycache__/train_multiturn.cpython-312.pyc
ADDED
|
Binary file (25.4 kB). View file
|
|
|
omni_speech/train/__pycache__/train_raw.cpython-312.pyc
ADDED
|
Binary file (19.9 kB). View file
|
|
|
omni_speech/train/__pycache__/train_test.cpython-312.pyc
ADDED
|
Binary file (17.8 kB). View file
|
|
|
omni_speech/train/__pycache__/trainer.cpython-310.pyc
ADDED
|
Binary file (7.29 kB). View file
|
|
|
omni_speech/train/__pycache__/trainer.cpython-312.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
omni_speech/train/export.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
|
| 2 |
+
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
| 3 |
+
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import copy
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import pathlib
|
| 23 |
+
from typing import Dict, Optional, Sequence, List
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
import transformers
|
| 28 |
+
import tokenizers
|
| 29 |
+
|
| 30 |
+
from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
|
| 31 |
+
from torch.utils.data import Dataset
|
| 32 |
+
from omni_speech.train.omni_trainer import OmniTrainer
|
| 33 |
+
from audiomentations import AddBackgroundNoise, PolarityInversion
|
| 34 |
+
|
| 35 |
+
from omni_speech import conversation as conversation_lib
|
| 36 |
+
from omni_speech.model import *
|
| 37 |
+
from omni_speech.utils import *
|
| 38 |
+
from omni_speech.datasets.preprocess import *
|
| 39 |
+
import whisper
|
| 40 |
+
import time
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class ModelArguments:
|
| 44 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
| 45 |
+
version: Optional[str] = field(default="llama_3")
|
| 46 |
+
freeze_backbone: bool = field(default=False)
|
| 47 |
+
tune_speech_projector: bool = field(default=False)
|
| 48 |
+
tune_speech_encoder: bool = field(default=False)
|
| 49 |
+
tune_speech_generator_only: bool = field(default=False)
|
| 50 |
+
speech_encoder_type: Optional[str] = field(default=None)
|
| 51 |
+
speech_encoder: Optional[str] = field(default=None)
|
| 52 |
+
pretrain_speech_projector: Optional[str] = field(default=None)
|
| 53 |
+
speech_projector_type: Optional[str] = field(default='linear')
|
| 54 |
+
speech_generator_type: Optional[str] = field(default='ctc')
|
| 55 |
+
ctc_decoder_config: str = "(2,4096,32,11008)"
|
| 56 |
+
ctc_upsample_factor: int = 25
|
| 57 |
+
ctc_loss_weight: float = 1.0
|
| 58 |
+
unit_vocab_size: int = 1000
|
| 59 |
+
speech_encoder_ds_rate: int = 5
|
| 60 |
+
speech_encoder_hidden_size: int = 1280
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
|
| 64 |
+
class DataArguments:
|
| 65 |
+
data_path: str = field(default=None,
|
| 66 |
+
metadata={"help": "Path to the training data."})
|
| 67 |
+
dev_path: str = field(default=None,
|
| 68 |
+
metadata={"help": "Path to the dev data."})
|
| 69 |
+
is_multimodal: bool = False
|
| 70 |
+
input_type: str = field(default="mel")
|
| 71 |
+
speech_normalize: bool = False
|
| 72 |
+
mel_size: int = 128
|
| 73 |
+
has_tgt_units: bool = False
|
| 74 |
+
augment_prob: float = field(
|
| 75 |
+
default=0.0,
|
| 76 |
+
metadata={"help": "The probability of applying augmentation transform."}
|
| 77 |
+
)
|
| 78 |
+
augment_path: str = field(default=None,
|
| 79 |
+
metadata={"help": "Path to the augment data."})
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass
|
| 83 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 84 |
+
cache_dir: Optional[str] = field(default=None)
|
| 85 |
+
optim: str = field(default="adamw_torch")
|
| 86 |
+
freeze_speech_projector: bool = field(default=False)
|
| 87 |
+
model_max_length: int = field(
|
| 88 |
+
default=512,
|
| 89 |
+
metadata={
|
| 90 |
+
"help":
|
| 91 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 92 |
+
},
|
| 93 |
+
)
|
| 94 |
+
double_quant: bool = field(
|
| 95 |
+
default=True,
|
| 96 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
| 97 |
+
)
|
| 98 |
+
quant_type: str = field(
|
| 99 |
+
default="nf4",
|
| 100 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
| 101 |
+
)
|
| 102 |
+
bits: int = field(
|
| 103 |
+
default=16,
|
| 104 |
+
metadata={"help": "How many bits to use."}
|
| 105 |
+
)
|
| 106 |
+
lora_enable: bool = False
|
| 107 |
+
lora_r: int = 64
|
| 108 |
+
lora_alpha: int = 16
|
| 109 |
+
lora_dropout: float = 0.05
|
| 110 |
+
lora_weight_path: str = ""
|
| 111 |
+
lora_bias: str = "none"
|
| 112 |
+
speech_projector_lr: Optional[float] = None
|
| 113 |
+
group_by_modality_length: bool = field(default=False)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class LazySupervisedDataset(Dataset):
|
| 117 |
+
"""Dataset for supervised fine-tuning."""
|
| 118 |
+
|
| 119 |
+
def __init__(self, data_path: str,
|
| 120 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 121 |
+
data_args: DataArguments):
|
| 122 |
+
super(LazySupervisedDataset, self).__init__()
|
| 123 |
+
list_data_dict = json.load(open(data_path, "r"))
|
| 124 |
+
|
| 125 |
+
self.tokenizer = tokenizer
|
| 126 |
+
self.list_data_dict = list_data_dict
|
| 127 |
+
self.data_args = data_args
|
| 128 |
+
if self.data_args.augment_prob != 0.0:
|
| 129 |
+
with open(self.data_args.augment_path, "r") as f:
|
| 130 |
+
augment_path_list = f.read().splitlines()
|
| 131 |
+
self.transform = AddBackgroundNoise(
|
| 132 |
+
sounds_path=augment_path_list,
|
| 133 |
+
min_snr_db=5.0,
|
| 134 |
+
max_snr_db=30.0,
|
| 135 |
+
noise_transform=PolarityInversion(),
|
| 136 |
+
p=self.data_args.augment_prob
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
def __len__(self):
|
| 140 |
+
return len(self.list_data_dict)
|
| 141 |
+
|
| 142 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
| 143 |
+
# TODO: define number of retries somewhere else
|
| 144 |
+
num_base_retries = 3
|
| 145 |
+
num_final_retries = 300
|
| 146 |
+
|
| 147 |
+
# try the current sample first
|
| 148 |
+
for attempt_idx in range(num_base_retries):
|
| 149 |
+
try:
|
| 150 |
+
sample = self._get_item(i)
|
| 151 |
+
return sample
|
| 152 |
+
except Exception as e:
|
| 153 |
+
# sleep 1s in case it is a cloud disk issue
|
| 154 |
+
print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
|
| 155 |
+
time.sleep(1)
|
| 156 |
+
|
| 157 |
+
# try other samples, in case it is file corruption issue
|
| 158 |
+
for attempt_idx in range(num_base_retries):
|
| 159 |
+
try:
|
| 160 |
+
next_index = min(i + 1, len(self.list_data_dict) - 1)
|
| 161 |
+
# sample_idx = random.choice(range(len(self)))
|
| 162 |
+
sample = self._get_item(next_index)
|
| 163 |
+
return sample
|
| 164 |
+
except Exception as e:
|
| 165 |
+
# no need to sleep
|
| 166 |
+
print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)
|
| 167 |
+
pass
|
| 168 |
+
|
| 169 |
+
try:
|
| 170 |
+
sample = self._get_item(i)
|
| 171 |
+
return sample
|
| 172 |
+
except Exception as e:
|
| 173 |
+
raise e
|
| 174 |
+
|
| 175 |
+
def process_speech(self, speech_file):
|
| 176 |
+
speech = whisper.load_audio(speech_file)
|
| 177 |
+
if self.data_args.augment_prob != 0.0:
|
| 178 |
+
speech = self.transform(speech, sample_rate=16000)
|
| 179 |
+
if self.data_args.input_type == "raw":
|
| 180 |
+
speech = torch.from_numpy(speech)
|
| 181 |
+
if self.model_config.data_args.speech_normalize:
|
| 182 |
+
speech = torch.nn.functional.layer_norm(speech, speech.shape)
|
| 183 |
+
elif self.data_args.input_type == "mel":
|
| 184 |
+
speech = whisper.pad_or_trim(speech)
|
| 185 |
+
speech = whisper.log_mel_spectrogram(speech, n_mels=self.data_args.mel_size).permute(1, 0)
|
| 186 |
+
speech_lengths = torch.LongTensor([speech.shape[0]])
|
| 187 |
+
return speech, speech_lengths
|
| 188 |
+
|
| 189 |
+
def _get_item(self, i) -> Dict[str, torch.Tensor]:
|
| 190 |
+
sources = self.list_data_dict[i]
|
| 191 |
+
if isinstance(i, int):
|
| 192 |
+
sources = [sources]
|
| 193 |
+
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
|
| 194 |
+
for item in sources:
|
| 195 |
+
if 'tools' in item:
|
| 196 |
+
tools_dict = {
|
| 197 |
+
"from": "tools",
|
| 198 |
+
"value": item["tools"]
|
| 199 |
+
}
|
| 200 |
+
item["conversations"].insert(0,tools_dict)
|
| 201 |
+
|
| 202 |
+
if self.data_args.has_tgt_units:
|
| 203 |
+
tgt_units = [e["tgt_units"] for e in sources]
|
| 204 |
+
tgt_units = torch.tensor(tgt_units, dtype=torch.long)
|
| 205 |
+
else:
|
| 206 |
+
tgt_units = None
|
| 207 |
+
|
| 208 |
+
if 'speech' in sources[0]:
|
| 209 |
+
import numpy as np
|
| 210 |
+
speech_file = self.list_data_dict[i]['speech']
|
| 211 |
+
if type(speech_file) is list:
|
| 212 |
+
speech = [self.process_speech(f) for f in speech_file]
|
| 213 |
+
else:
|
| 214 |
+
speech = [self.process_speech(speech_file)]
|
| 215 |
+
|
| 216 |
+
sources = preprocess_multimodal(
|
| 217 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
| 218 |
+
self.data_args)
|
| 219 |
+
else:
|
| 220 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
| 221 |
+
data_dict = preprocess(
|
| 222 |
+
sources,
|
| 223 |
+
self.tokenizer,
|
| 224 |
+
has_speech=('speech' in self.list_data_dict[i]))
|
| 225 |
+
if isinstance(i, int):
|
| 226 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
| 227 |
+
labels=data_dict["labels"][0])
|
| 228 |
+
|
| 229 |
+
# speech exist in the data
|
| 230 |
+
if 'speech' in self.list_data_dict[i]:
|
| 231 |
+
data_dict['speech'] = speech
|
| 232 |
+
|
| 233 |
+
if tgt_units is not None:
|
| 234 |
+
data_dict['tgt_units'] = tgt_units[0]
|
| 235 |
+
|
| 236 |
+
data_dict["id"] = self.list_data_dict[i].get("id", i)
|
| 237 |
+
|
| 238 |
+
return data_dict
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
@dataclass
|
| 242 |
+
class DataCollatorForSupervisedDataset(object):
|
| 243 |
+
"""Collate examples for supervised fine-tuning."""
|
| 244 |
+
|
| 245 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 246 |
+
|
| 247 |
+
def pad_sequence(self, input_ids, batch_first, padding_value):
|
| 248 |
+
if self.tokenizer.padding_side == "left":
|
| 249 |
+
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
|
| 250 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
|
| 251 |
+
if self.tokenizer.padding_side == "left":
|
| 252 |
+
input_ids = torch.flip(input_ids, [1])
|
| 253 |
+
return input_ids
|
| 254 |
+
|
| 255 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 256 |
+
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
| 257 |
+
# input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "id"))
|
| 258 |
+
input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids]
|
| 259 |
+
labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels]
|
| 260 |
+
if self.tokenizer.pad_token_id is None:
|
| 261 |
+
# self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model.
|
| 262 |
+
self.tokenizer.pad_token_id = 0 # This gets the best result. Don't know why.
|
| 263 |
+
input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
| 264 |
+
labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
| 265 |
+
batch = dict(input_ids=input_ids, labels=labels.long() if labels.dtype == torch.int32 else labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id))
|
| 266 |
+
# batch = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ids=ids)
|
| 267 |
+
|
| 268 |
+
if 'speech' in instances[0]:
|
| 269 |
+
speechs = [instance['speech'] for instance in instances]
|
| 270 |
+
|
| 271 |
+
speech = [sp[0] for sp_list in speechs for sp in sp_list]
|
| 272 |
+
speech_lengths = [sp[1] for sp_list in speechs for sp in sp_list]
|
| 273 |
+
|
| 274 |
+
batch["speech"] = speech
|
| 275 |
+
# print(len(speech)) # sum(len(audio) for audio in each batch)
|
| 276 |
+
# print(speech[0].shape) # seq_len, dim
|
| 277 |
+
batch['speech_lengths'] = speech_lengths
|
| 278 |
+
# print(speech_lengths[0].shape) # seq_len
|
| 279 |
+
|
| 280 |
+
if 'tgt_units' in instances[0]:
|
| 281 |
+
tgt_units = [instance['tgt_units'] for instance in instances]
|
| 282 |
+
tgt_units = self.pad_sequence(tgt_units, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
| 283 |
+
batch['tgt_units'] = tgt_units
|
| 284 |
+
# print(batch['tgt_units'])
|
| 285 |
+
# print("---------------")
|
| 286 |
+
# print(batch['input_ids'])
|
| 287 |
+
|
| 288 |
+
return batch
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
|
| 292 |
+
data_args) -> Dict:
|
| 293 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 294 |
+
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
| 295 |
+
data_path=data_args.data_path,
|
| 296 |
+
data_args=data_args)
|
| 297 |
+
if data_args.dev_path is not None:
|
| 298 |
+
dev_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
| 299 |
+
data_path=data_args.dev_path,
|
| 300 |
+
data_args=data_args)
|
| 301 |
+
else:
|
| 302 |
+
dev_dataset = None
|
| 303 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
| 304 |
+
return dict(train_dataset=train_dataset,
|
| 305 |
+
eval_dataset=dev_dataset,
|
| 306 |
+
data_collator=data_collator)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def train(attn_implementation="flash_attention_2"):
|
| 310 |
+
|
| 311 |
+
parser = transformers.HfArgumentParser(
|
| 312 |
+
(ModelArguments, DataArguments, TrainingArguments))
|
| 313 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 314 |
+
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 315 |
+
|
| 316 |
+
bnb_model_from_pretrained_args = {}
|
| 317 |
+
if training_args.bits in [4, 8]:
|
| 318 |
+
from transformers import BitsAndBytesConfig
|
| 319 |
+
bnb_model_from_pretrained_args.update(dict(
|
| 320 |
+
device_map={"": training_args.device},
|
| 321 |
+
load_in_4bit=training_args.bits == 4,
|
| 322 |
+
load_in_8bit=training_args.bits == 8,
|
| 323 |
+
quantization_config=BitsAndBytesConfig(
|
| 324 |
+
load_in_4bit=training_args.bits == 4,
|
| 325 |
+
load_in_8bit=training_args.bits == 8,
|
| 326 |
+
llm_int8_skip_modules=["speech_projector"],
|
| 327 |
+
llm_int8_threshold=6.0,
|
| 328 |
+
llm_int8_has_fp16_weight=False,
|
| 329 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
| 330 |
+
bnb_4bit_use_double_quant=training_args.double_quant,
|
| 331 |
+
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
|
| 332 |
+
)
|
| 333 |
+
))
|
| 334 |
+
|
| 335 |
+
if data_args.has_tgt_units:
|
| 336 |
+
if model_args.version == "llama_3":
|
| 337 |
+
model = OmniSpeech2SLlamaForCausalLM.from_pretrained(
|
| 338 |
+
model_args.model_name_or_path,
|
| 339 |
+
cache_dir=training_args.cache_dir,
|
| 340 |
+
attn_implementation=attn_implementation,
|
| 341 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 342 |
+
**bnb_model_from_pretrained_args
|
| 343 |
+
)
|
| 344 |
+
elif model_args.version == "qwen":
|
| 345 |
+
model = OmniSpeech2SQwen2ForCausalLM.from_pretrained(
|
| 346 |
+
model_args.model_name_or_path,
|
| 347 |
+
cache_dir=training_args.cache_dir,
|
| 348 |
+
attn_implementation=attn_implementation,
|
| 349 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 350 |
+
**bnb_model_from_pretrained_args
|
| 351 |
+
)
|
| 352 |
+
else:
|
| 353 |
+
raise ValueError("--currently only support llama or qwen model!")
|
| 354 |
+
else:
|
| 355 |
+
if model_args.version == "llama_3":
|
| 356 |
+
model = OmniSpeechLlamaForCausalLM.from_pretrained(
|
| 357 |
+
model_args.model_name_or_path,
|
| 358 |
+
cache_dir=training_args.cache_dir,
|
| 359 |
+
attn_implementation=attn_implementation,
|
| 360 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 361 |
+
**bnb_model_from_pretrained_args
|
| 362 |
+
)
|
| 363 |
+
elif model_args.version == "qwen":
|
| 364 |
+
model = OmniSpeechQwen2ForCausalLM.from_pretrained(
|
| 365 |
+
model_args.model_name_or_path,
|
| 366 |
+
cache_dir=training_args.cache_dir,
|
| 367 |
+
attn_implementation=attn_implementation,
|
| 368 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 369 |
+
**bnb_model_from_pretrained_args
|
| 370 |
+
)
|
| 371 |
+
else:
|
| 372 |
+
raise ValueError("--currently only support llama or qwen model!")
|
| 373 |
+
model.config.use_cache = False
|
| 374 |
+
|
| 375 |
+
if model_args.freeze_backbone:
|
| 376 |
+
model.model.requires_grad_(False)
|
| 377 |
+
|
| 378 |
+
if training_args.bits in [4, 8]:
|
| 379 |
+
from peft import prepare_model_for_kbit_training
|
| 380 |
+
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 381 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
|
| 382 |
+
|
| 383 |
+
if training_args.gradient_checkpointing:
|
| 384 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 385 |
+
model.enable_input_require_grads()
|
| 386 |
+
else:
|
| 387 |
+
def make_inputs_require_grad(module, input, output):
|
| 388 |
+
output.requires_grad_(True)
|
| 389 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 390 |
+
|
| 391 |
+
if training_args.lora_enable:
|
| 392 |
+
from peft import LoraConfig, get_peft_model
|
| 393 |
+
lora_config = LoraConfig(
|
| 394 |
+
r=training_args.lora_r,
|
| 395 |
+
lora_alpha=training_args.lora_alpha,
|
| 396 |
+
target_modules=find_all_linear_names(model),
|
| 397 |
+
lora_dropout=training_args.lora_dropout,
|
| 398 |
+
bias=training_args.lora_bias,
|
| 399 |
+
task_type="CAUSAL_LM",
|
| 400 |
+
)
|
| 401 |
+
if training_args.bits == 16:
|
| 402 |
+
if training_args.bf16:
|
| 403 |
+
model.to(torch.bfloat16)
|
| 404 |
+
if training_args.fp16:
|
| 405 |
+
model.to(torch.float16)
|
| 406 |
+
model = get_peft_model(model, lora_config)
|
| 407 |
+
|
| 408 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 409 |
+
model_args.model_name_or_path,
|
| 410 |
+
cache_dir=training_args.cache_dir,
|
| 411 |
+
model_max_length=training_args.model_max_length,
|
| 412 |
+
padding_side="right",
|
| 413 |
+
use_fast=False,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 417 |
+
model.config.max_length = training_args.model_max_length
|
| 418 |
+
|
| 419 |
+
if model_args.version in conversation_lib.conv_templates:
|
| 420 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
|
| 421 |
+
else:
|
| 422 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["llama_3"]
|
| 423 |
+
|
| 424 |
+
if model_args.speech_encoder is not None:
|
| 425 |
+
model.get_model().initialize_speech_modules(
|
| 426 |
+
model_args=model_args,
|
| 427 |
+
fsdp=training_args.fsdp
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
data_args.is_multimodal = True
|
| 431 |
+
|
| 432 |
+
model.config.tokenizer_padding_side = tokenizer.padding_side
|
| 433 |
+
model.config.tokenizer_model_max_length = tokenizer.model_max_length
|
| 434 |
+
|
| 435 |
+
model.config.tune_speech_projector = training_args.tune_speech_projector = model_args.tune_speech_projector
|
| 436 |
+
|
| 437 |
+
model.config.speech_normalize = data_args.speech_normalize
|
| 438 |
+
|
| 439 |
+
for p in model.get_speech_encoder().parameters():
|
| 440 |
+
p.requires_grad = False
|
| 441 |
+
|
| 442 |
+
if model_args.tune_speech_projector:
|
| 443 |
+
model.requires_grad_(False)
|
| 444 |
+
for p in model.get_speech_projector().parameters():
|
| 445 |
+
p.requires_grad = True
|
| 446 |
+
|
| 447 |
+
model.config.freeze_speech_projector = training_args.freeze_speech_projector
|
| 448 |
+
if training_args.freeze_speech_projector:
|
| 449 |
+
for p in model.get_speech_projector().parameters():
|
| 450 |
+
p.requires_grad = False
|
| 451 |
+
|
| 452 |
+
if training_args.bits in [4, 8]:
|
| 453 |
+
model.get_model().speech_projector.to(dtype=compute_dtype, device=training_args.device)
|
| 454 |
+
|
| 455 |
+
model.config.speech_projector_lr = training_args.speech_projector_lr
|
| 456 |
+
|
| 457 |
+
if data_args.has_tgt_units:
|
| 458 |
+
model.initialize_speech_generator(model_args=model_args)
|
| 459 |
+
|
| 460 |
+
if training_args.bits in [4, 8]:
|
| 461 |
+
from peft.tuners.lora import LoraLayer
|
| 462 |
+
for name, module in model.named_modules():
|
| 463 |
+
if isinstance(module, LoraLayer):
|
| 464 |
+
if training_args.bf16:
|
| 465 |
+
module = module.to(torch.bfloat16)
|
| 466 |
+
if 'norm' in name:
|
| 467 |
+
module = module.to(torch.float32)
|
| 468 |
+
if 'lm_head' in name or 'embed_tokens' in name:
|
| 469 |
+
if hasattr(module, 'weight'):
|
| 470 |
+
if training_args.bf16 and module.weight.dtype == torch.float32:
|
| 471 |
+
module = module.to(torch.bfloat16)
|
| 472 |
+
|
| 473 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer,
|
| 474 |
+
data_args=data_args)
|
| 475 |
+
|
| 476 |
+
print("Training Layers:")
|
| 477 |
+
for name, param in model.named_parameters():
|
| 478 |
+
if param.requires_grad:
|
| 479 |
+
print(name, param.grad)
|
| 480 |
+
|
| 481 |
+
trainer = OmniTrainer(model=model,
|
| 482 |
+
tokenizer=tokenizer,
|
| 483 |
+
args=training_args,
|
| 484 |
+
**data_module)
|
| 485 |
+
|
| 486 |
+
# if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 487 |
+
# trainer.train(resume_from_checkpoint=True)
|
| 488 |
+
# else:
|
| 489 |
+
# trainer.train()
|
| 490 |
+
# trainer.save_state()
|
| 491 |
+
|
| 492 |
+
model.config.use_cache = True
|
| 493 |
+
|
| 494 |
+
if training_args.lora_enable:
|
| 495 |
+
state_dict = get_peft_state_maybe_zero_3(
|
| 496 |
+
model.named_parameters(), training_args.lora_bias
|
| 497 |
+
)
|
| 498 |
+
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
|
| 499 |
+
model.named_parameters()
|
| 500 |
+
)
|
| 501 |
+
if training_args.local_rank == 0 or training_args.local_rank == -1:
|
| 502 |
+
model.config.save_pretrained(training_args.output_dir)
|
| 503 |
+
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
|
| 504 |
+
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
|
| 505 |
+
else:
|
| 506 |
+
safe_save_model_for_hf_trainer(trainer=trainer,
|
| 507 |
+
output_dir=training_args.output_dir)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
if __name__ == "__main__":
|
| 511 |
+
train()
|
| 512 |
+
|
omni_speech/train/omni_trainer.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Sampler
|
| 6 |
+
|
| 7 |
+
from transformers import Trainer
|
| 8 |
+
from transformers.trainer import (
|
| 9 |
+
is_sagemaker_mp_enabled,
|
| 10 |
+
get_parameter_names,
|
| 11 |
+
has_length,
|
| 12 |
+
ALL_LAYERNORM_LAYERS,
|
| 13 |
+
logger,
|
| 14 |
+
)
|
| 15 |
+
from typing import List, Optional
|
| 16 |
+
from omni_speech.utils import *
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def split_to_even_chunks(indices, lengths, num_chunks):
|
| 20 |
+
"""
|
| 21 |
+
Split a list of indices into `chunks` chunks of roughly equal lengths.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
if len(indices) % num_chunks != 0:
|
| 25 |
+
return [indices[i::num_chunks] for i in range(num_chunks)]
|
| 26 |
+
|
| 27 |
+
num_indices_per_chunk = len(indices) // num_chunks
|
| 28 |
+
|
| 29 |
+
chunks = [[] for _ in range(num_chunks)]
|
| 30 |
+
chunks_lengths = [0 for _ in range(num_chunks)]
|
| 31 |
+
for index in indices:
|
| 32 |
+
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
|
| 33 |
+
chunks[shortest_chunk].append(index)
|
| 34 |
+
chunks_lengths[shortest_chunk] += lengths[index]
|
| 35 |
+
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
|
| 36 |
+
chunks_lengths[shortest_chunk] = float("inf")
|
| 37 |
+
|
| 38 |
+
return chunks
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
|
| 42 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
| 43 |
+
assert all(l != 0 for l in lengths), "Should not have zero length."
|
| 44 |
+
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
|
| 45 |
+
# all samples are in the same modality
|
| 46 |
+
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
|
| 47 |
+
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
|
| 48 |
+
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
|
| 49 |
+
|
| 50 |
+
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
|
| 51 |
+
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
|
| 52 |
+
megabatch_size = world_size * batch_size
|
| 53 |
+
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
|
| 54 |
+
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
|
| 55 |
+
|
| 56 |
+
last_mm = mm_megabatches[-1]
|
| 57 |
+
last_lang = lang_megabatches[-1]
|
| 58 |
+
additional_batch = last_mm + last_lang
|
| 59 |
+
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
|
| 60 |
+
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
|
| 61 |
+
megabatches = [megabatches[i] for i in megabatch_indices]
|
| 62 |
+
|
| 63 |
+
if len(additional_batch) > 0:
|
| 64 |
+
megabatches.append(sorted(additional_batch))
|
| 65 |
+
|
| 66 |
+
return [i for megabatch in megabatches for i in megabatch]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
|
| 70 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
| 71 |
+
indices = torch.randperm(len(lengths), generator=generator)
|
| 72 |
+
megabatch_size = world_size * batch_size
|
| 73 |
+
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
|
| 74 |
+
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
|
| 75 |
+
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
|
| 76 |
+
|
| 77 |
+
return [i for megabatch in megabatches for batch in megabatch for i in batch]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class LengthGroupedSampler(Sampler):
|
| 81 |
+
r"""
|
| 82 |
+
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
|
| 83 |
+
keeping a bit of randomness.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
batch_size: int,
|
| 89 |
+
world_size: int,
|
| 90 |
+
lengths: Optional[List[int]] = None,
|
| 91 |
+
generator=None,
|
| 92 |
+
group_by_modality: bool = False,
|
| 93 |
+
):
|
| 94 |
+
if lengths is None:
|
| 95 |
+
raise ValueError("Lengths must be provided.")
|
| 96 |
+
|
| 97 |
+
self.batch_size = batch_size
|
| 98 |
+
self.world_size = world_size
|
| 99 |
+
self.lengths = lengths
|
| 100 |
+
self.generator = generator
|
| 101 |
+
self.group_by_modality = group_by_modality
|
| 102 |
+
|
| 103 |
+
def __len__(self):
|
| 104 |
+
return len(self.lengths)
|
| 105 |
+
|
| 106 |
+
def __iter__(self):
|
| 107 |
+
if self.group_by_modality:
|
| 108 |
+
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
| 109 |
+
else:
|
| 110 |
+
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
| 111 |
+
return iter(indices)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class OmniTrainer(Trainer):
|
| 115 |
+
|
| 116 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
| 117 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
| 118 |
+
return None
|
| 119 |
+
|
| 120 |
+
if self.args.group_by_modality_length:
|
| 121 |
+
lengths = self.train_dataset.modality_lengths
|
| 122 |
+
return LengthGroupedSampler(
|
| 123 |
+
self.args.train_batch_size,
|
| 124 |
+
world_size=self.args.world_size * self.args.gradient_accumulation_steps,
|
| 125 |
+
lengths=lengths,
|
| 126 |
+
group_by_modality=True,
|
| 127 |
+
)
|
| 128 |
+
else:
|
| 129 |
+
return super()._get_train_sampler()
|
| 130 |
+
|
| 131 |
+
# def create_optimizer(self):
|
| 132 |
+
# from transformers.utils import (
|
| 133 |
+
# is_sagemaker_mp_enabled,
|
| 134 |
+
# )
|
| 135 |
+
# import torch.nn as nn
|
| 136 |
+
# if is_sagemaker_mp_enabled():
|
| 137 |
+
# import smdistributed.modelparallel.torch as smp
|
| 138 |
+
|
| 139 |
+
# """
|
| 140 |
+
# Setup the optimizer.
|
| 141 |
+
|
| 142 |
+
# We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
| 143 |
+
# Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
| 144 |
+
# """
|
| 145 |
+
# opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
| 146 |
+
|
| 147 |
+
# if self.optimizer is None:
|
| 148 |
+
# decay_parameters = self.get_decay_parameter_names(opt_model)
|
| 149 |
+
|
| 150 |
+
# optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
|
| 151 |
+
|
| 152 |
+
# optimizer_grouped_parameters = [
|
| 153 |
+
# # speech projector
|
| 154 |
+
# {
|
| 155 |
+
# "params": [
|
| 156 |
+
# p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "speech_projector" in n)
|
| 157 |
+
# ],
|
| 158 |
+
# "weight_decay": self.args.weight_decay,
|
| 159 |
+
# "learning_rate": optimizer_kwargs["lr"] * 20,
|
| 160 |
+
# },
|
| 161 |
+
# {
|
| 162 |
+
# "params": [
|
| 163 |
+
# p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and "speech_projector" in n)
|
| 164 |
+
# ],
|
| 165 |
+
# "weight_decay": 0.0,
|
| 166 |
+
# "learning_rate": optimizer_kwargs["lr"] * 20,
|
| 167 |
+
# },
|
| 168 |
+
|
| 169 |
+
# # non speech project
|
| 170 |
+
# {
|
| 171 |
+
# "params": [
|
| 172 |
+
# p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "speech_projector" not in n)
|
| 173 |
+
# ],
|
| 174 |
+
# "weight_decay": self.args.weight_decay,
|
| 175 |
+
# },
|
| 176 |
+
# {
|
| 177 |
+
# "params": [
|
| 178 |
+
# p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and "speech_projector" not in n)
|
| 179 |
+
# ],
|
| 180 |
+
# "weight_decay": 0.0,
|
| 181 |
+
# },
|
| 182 |
+
# ]
|
| 183 |
+
|
| 184 |
+
# # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
| 185 |
+
# # e.g. for GaLore optimizer.
|
| 186 |
+
# if "params" in optimizer_kwargs:
|
| 187 |
+
# optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
| 188 |
+
|
| 189 |
+
# # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
| 190 |
+
# # e.g. for LOMO optimizer.
|
| 191 |
+
# if "model" in optimizer_kwargs:
|
| 192 |
+
# optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
| 193 |
+
|
| 194 |
+
# # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
| 195 |
+
# # to avoid arguments conflicts.
|
| 196 |
+
# if "optimizer_dict" in optimizer_kwargs:
|
| 197 |
+
# optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
|
| 198 |
+
|
| 199 |
+
# self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
| 200 |
+
|
| 201 |
+
# if optimizer_cls.__name__ == "Adam8bit":
|
| 202 |
+
# import bitsandbytes
|
| 203 |
+
|
| 204 |
+
# manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
| 205 |
+
|
| 206 |
+
# skipped = 0
|
| 207 |
+
# for module in opt_model.modules():
|
| 208 |
+
# if isinstance(module, nn.Embedding):
|
| 209 |
+
# skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
| 210 |
+
# logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
|
| 211 |
+
# manager.register_module_override(module, "weight", {"optim_bits": 32})
|
| 212 |
+
# logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
| 213 |
+
# logger.info(f"skipped: {skipped / 2 ** 20}M params")
|
| 214 |
+
|
| 215 |
+
# if is_sagemaker_mp_enabled():
|
| 216 |
+
# self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
| 217 |
+
|
| 218 |
+
# return self.optimizer
|
| 219 |
+
|
| 220 |
+
def create_optimizer(self):
|
| 221 |
+
"""
|
| 222 |
+
Setup the optimizer.
|
| 223 |
+
|
| 224 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
| 225 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
| 226 |
+
"""
|
| 227 |
+
if is_sagemaker_mp_enabled():
|
| 228 |
+
return super().create_optimizer()
|
| 229 |
+
|
| 230 |
+
opt_model = self.model
|
| 231 |
+
|
| 232 |
+
if self.optimizer is None:
|
| 233 |
+
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
| 234 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
| 235 |
+
if self.args.speech_projector_lr is not None:
|
| 236 |
+
projector_parameters = [name for name, _ in opt_model.named_parameters() if "speech_projector" in name]
|
| 237 |
+
optimizer_grouped_parameters = [
|
| 238 |
+
{
|
| 239 |
+
"params": [
|
| 240 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
|
| 241 |
+
],
|
| 242 |
+
"weight_decay": self.args.weight_decay,
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"params": [
|
| 246 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
|
| 247 |
+
],
|
| 248 |
+
"weight_decay": 0.0,
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"params": [
|
| 252 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
|
| 253 |
+
],
|
| 254 |
+
"weight_decay": self.args.weight_decay,
|
| 255 |
+
"lr": self.args.speech_projector_lr,
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"params": [
|
| 259 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
|
| 260 |
+
],
|
| 261 |
+
"weight_decay": 0.0,
|
| 262 |
+
"lr": self.args.speech_projector_lr,
|
| 263 |
+
},
|
| 264 |
+
]
|
| 265 |
+
else:
|
| 266 |
+
optimizer_grouped_parameters = [
|
| 267 |
+
{
|
| 268 |
+
"params": [
|
| 269 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
|
| 270 |
+
],
|
| 271 |
+
"weight_decay": self.args.weight_decay,
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"params": [
|
| 275 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
| 276 |
+
],
|
| 277 |
+
"weight_decay": 0.0,
|
| 278 |
+
},
|
| 279 |
+
]
|
| 280 |
+
|
| 281 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
| 282 |
+
|
| 283 |
+
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
| 284 |
+
if optimizer_cls.__name__ == "Adam8bit":
|
| 285 |
+
import bitsandbytes
|
| 286 |
+
|
| 287 |
+
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
| 288 |
+
|
| 289 |
+
skipped = 0
|
| 290 |
+
for module in opt_model.modules():
|
| 291 |
+
if isinstance(module, nn.Embedding):
|
| 292 |
+
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
| 293 |
+
logger.info(f"skipped {module}: {skipped/2**20}M params")
|
| 294 |
+
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
| 295 |
+
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
| 296 |
+
logger.info(f"skipped: {skipped/2**20}M params")
|
| 297 |
+
|
| 298 |
+
return self.optimizer
|
| 299 |
+
|
| 300 |
+
def _save_checkpoint(self, model, trial, metrics=None):
|
| 301 |
+
if getattr(self.args, 'tune_speech_projector', False):
|
| 302 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 303 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
| 304 |
+
|
| 305 |
+
run_dir = self._get_output_dir(trial=trial)
|
| 306 |
+
output_dir = os.path.join(run_dir, checkpoint_folder)
|
| 307 |
+
|
| 308 |
+
# Only save Adapter
|
| 309 |
+
keys_to_match = ['speech_projector']
|
| 310 |
+
|
| 311 |
+
weight_to_save = get_speech_projector_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
|
| 312 |
+
|
| 313 |
+
if self.args.local_rank == 0 or self.args.local_rank == -1:
|
| 314 |
+
self.model.config.save_pretrained(output_dir)
|
| 315 |
+
torch.save(weight_to_save, os.path.join(output_dir, f'speech_projector.bin'))
|
| 316 |
+
else:
|
| 317 |
+
super(OmniTrainer, self)._save_checkpoint(model, trial, metrics)
|
| 318 |
+
|
| 319 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 320 |
+
if getattr(self.args, 'tune_speech_projector', False):
|
| 321 |
+
pass
|
| 322 |
+
else:
|
| 323 |
+
super(OmniTrainer, self)._save(output_dir, state_dict)
|
| 324 |
+
|
| 325 |
+
# def training_step(self, model, inputs):
|
| 326 |
+
# # Move inputs to device
|
| 327 |
+
# inputs = self._prepare_inputs(inputs)
|
| 328 |
+
|
| 329 |
+
# # Forward pass
|
| 330 |
+
# outputs = model(**inputs)
|
| 331 |
+
# loss = outputs.loss
|
| 332 |
+
|
| 333 |
+
# # Backward pass
|
| 334 |
+
# loss.backward()
|
| 335 |
+
|
| 336 |
+
# # Check gradients
|
| 337 |
+
# for name, param in model.module.named_parameters():
|
| 338 |
+
# if param.requires_grad:
|
| 339 |
+
# if param.grad is None:
|
| 340 |
+
# print(f"Gradients for {name} are None.")
|
| 341 |
+
# else:
|
| 342 |
+
# print(f"Gradients for {name}: {param.grad.norm()}") # Check norm of the gradients
|
| 343 |
+
|
| 344 |
+
# # Return loss for optimization
|
| 345 |
+
# return loss.detach()
|
omni_speech/train/train.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
|
| 2 |
+
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
| 3 |
+
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import copy
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import pathlib
|
| 23 |
+
from typing import Dict, Optional, Sequence, List
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
import transformers
|
| 28 |
+
import tokenizers
|
| 29 |
+
|
| 30 |
+
from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
|
| 31 |
+
from torch.utils.data import Dataset
|
| 32 |
+
from omni_speech.train.omni_trainer import OmniTrainer
|
| 33 |
+
|
| 34 |
+
from omni_speech import conversation as conversation_lib
|
| 35 |
+
from omni_speech.model import *
|
| 36 |
+
from omni_speech.utils import *
|
| 37 |
+
from omni_speech.datasets.preprocess import *
|
| 38 |
+
import whisper
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class ModelArguments:
|
| 42 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
| 43 |
+
version: Optional[str] = field(default="llama_3")
|
| 44 |
+
freeze_backbone: bool = field(default=False)
|
| 45 |
+
tune_speech_projector: bool = field(default=False)
|
| 46 |
+
tune_speech_encoder: bool = field(default=False)
|
| 47 |
+
tune_speech_generator_only: bool = field(default=False)
|
| 48 |
+
speech_encoder_type: Optional[str] = field(default=None)
|
| 49 |
+
speech_encoder: Optional[str] = field(default=None)
|
| 50 |
+
pretrain_speech_projector: Optional[str] = field(default=None)
|
| 51 |
+
speech_projector_type: Optional[str] = field(default='linear')
|
| 52 |
+
speech_generator_type: Optional[str] = field(default='ctc')
|
| 53 |
+
ctc_decoder_config: str = "(2,4096,32,11008)"
|
| 54 |
+
ctc_upsample_factor: int = 1
|
| 55 |
+
ctc_loss_weight: float = 1.0
|
| 56 |
+
unit_vocab_size: int = 1000
|
| 57 |
+
speech_encoder_ds_rate: int = 5
|
| 58 |
+
speech_encoder_hidden_size: int = 1280
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
|
| 62 |
+
class DataArguments:
|
| 63 |
+
data_path: str = field(default=None,
|
| 64 |
+
metadata={"help": "Path to the training data."})
|
| 65 |
+
dev_path: str = field(default=None,
|
| 66 |
+
metadata={"help": "Path to the dev data."})
|
| 67 |
+
is_multimodal: bool = False
|
| 68 |
+
input_type: str = field(default="mel")
|
| 69 |
+
speech_normalize: bool = False
|
| 70 |
+
mel_size: int = 128
|
| 71 |
+
has_tgt_units: bool = False
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
|
| 75 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 76 |
+
cache_dir: Optional[str] = field(default=None)
|
| 77 |
+
optim: str = field(default="adamw_torch")
|
| 78 |
+
freeze_speech_projector: bool = field(default=False)
|
| 79 |
+
model_max_length: int = field(
|
| 80 |
+
default=512,
|
| 81 |
+
metadata={
|
| 82 |
+
"help":
|
| 83 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 84 |
+
},
|
| 85 |
+
)
|
| 86 |
+
double_quant: bool = field(
|
| 87 |
+
default=True,
|
| 88 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
| 89 |
+
)
|
| 90 |
+
quant_type: str = field(
|
| 91 |
+
default="nf4",
|
| 92 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
| 93 |
+
)
|
| 94 |
+
bits: int = field(
|
| 95 |
+
default=16,
|
| 96 |
+
metadata={"help": "How many bits to use."}
|
| 97 |
+
)
|
| 98 |
+
lora_enable: bool = False
|
| 99 |
+
lora_r: int = 64
|
| 100 |
+
lora_alpha: int = 16
|
| 101 |
+
lora_dropout: float = 0.05
|
| 102 |
+
lora_weight_path: str = ""
|
| 103 |
+
lora_bias: str = "none"
|
| 104 |
+
speech_projector_lr: Optional[float] = None
|
| 105 |
+
group_by_modality_length: bool = field(default=False)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class LazySupervisedDataset(Dataset):
|
| 109 |
+
"""Dataset for supervised fine-tuning."""
|
| 110 |
+
|
| 111 |
+
def __init__(self, data_path: str,
|
| 112 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 113 |
+
data_args: DataArguments):
|
| 114 |
+
super(LazySupervisedDataset, self).__init__()
|
| 115 |
+
list_data_dict = json.load(open(data_path, "r"))
|
| 116 |
+
|
| 117 |
+
self.tokenizer = tokenizer
|
| 118 |
+
self.list_data_dict = list_data_dict
|
| 119 |
+
self.data_args = data_args
|
| 120 |
+
|
| 121 |
+
def __len__(self):
|
| 122 |
+
return len(self.list_data_dict)
|
| 123 |
+
|
| 124 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
| 125 |
+
sources = self.list_data_dict[i]
|
| 126 |
+
if isinstance(i, int):
|
| 127 |
+
sources = [sources]
|
| 128 |
+
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
|
| 129 |
+
if 'speech' in sources[0]:
|
| 130 |
+
import numpy as np
|
| 131 |
+
speech_file = self.list_data_dict[i]['speech']
|
| 132 |
+
speech = whisper.load_audio(speech_file)
|
| 133 |
+
# speech = np.random.uniform(low=-1.0, high=1.0, size=speech.shape[0]).astype(speech.dtype)
|
| 134 |
+
|
| 135 |
+
if self.data_args.input_type == "raw":
|
| 136 |
+
speech = torch.from_numpy(speech)
|
| 137 |
+
if self.model_config.data_args.speech_normalize:
|
| 138 |
+
speech = torch.nn.functional.layer_norm(speech, speech.shape)
|
| 139 |
+
elif self.data_args.input_type == "mel":
|
| 140 |
+
speech = whisper.pad_or_trim(speech)
|
| 141 |
+
speech = whisper.log_mel_spectrogram(speech, n_mels=self.data_args.mel_size).permute(1, 0)
|
| 142 |
+
speech_lengths = torch.LongTensor([speech.shape[0]])
|
| 143 |
+
|
| 144 |
+
sources = preprocess_multimodal(
|
| 145 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
| 146 |
+
self.data_args)
|
| 147 |
+
else:
|
| 148 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
| 149 |
+
data_dict = preprocess(
|
| 150 |
+
sources,
|
| 151 |
+
self.tokenizer,
|
| 152 |
+
has_speech=('speech' in self.list_data_dict[i]))
|
| 153 |
+
if isinstance(i, int):
|
| 154 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
| 155 |
+
labels=data_dict["labels"][0])
|
| 156 |
+
|
| 157 |
+
# speech exist in the data
|
| 158 |
+
if 'speech' in self.list_data_dict[i]:
|
| 159 |
+
data_dict['speech'] = speech
|
| 160 |
+
data_dict['speech_lengths'] = speech_lengths
|
| 161 |
+
return data_dict
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@dataclass
|
| 165 |
+
class DataCollatorForSupervisedDataset(object):
|
| 166 |
+
"""Collate examples for supervised fine-tuning."""
|
| 167 |
+
|
| 168 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 169 |
+
|
| 170 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 171 |
+
input_ids, labels = tuple([instance[key] for instance in instances]
|
| 172 |
+
for key in ("input_ids", "labels"))
|
| 173 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
| 174 |
+
input_ids,
|
| 175 |
+
batch_first=True,
|
| 176 |
+
padding_value=self.tokenizer.pad_token_id)
|
| 177 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
| 178 |
+
batch_first=True,
|
| 179 |
+
padding_value=IGNORE_INDEX)
|
| 180 |
+
input_ids = input_ids[:, :self.tokenizer.model_max_length]
|
| 181 |
+
labels = labels[:, :self.tokenizer.model_max_length]
|
| 182 |
+
batch = dict(
|
| 183 |
+
input_ids=input_ids,
|
| 184 |
+
labels=labels,
|
| 185 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
if 'speech' in instances[0]:
|
| 189 |
+
speech = [instance['speech'] for instance in instances]
|
| 190 |
+
speech_lengths = [instance['speech_lengths'] for instance in instances]
|
| 191 |
+
if all(x is not None and x.shape == speech[0].shape for x in speech):
|
| 192 |
+
batch['speech'] = torch.stack(speech)
|
| 193 |
+
batch['speech_lengths'] = torch.stack(speech_lengths)
|
| 194 |
+
else:
|
| 195 |
+
batch['speech'] = speech
|
| 196 |
+
batch['speech_lengths'] = speech_lengths
|
| 197 |
+
|
| 198 |
+
return batch
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
|
| 202 |
+
data_args) -> Dict:
|
| 203 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 204 |
+
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
| 205 |
+
data_path=data_args.data_path,
|
| 206 |
+
data_args=data_args)
|
| 207 |
+
if data_args.dev_path is not None:
|
| 208 |
+
dev_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
| 209 |
+
data_path=data_args.dev_path,
|
| 210 |
+
data_args=data_args)
|
| 211 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
| 212 |
+
return dict(train_dataset=train_dataset,
|
| 213 |
+
eval_dataset=dev_dataset,
|
| 214 |
+
data_collator=data_collator)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def train(attn_implementation="flash_attention_2"):
|
| 218 |
+
|
| 219 |
+
parser = transformers.HfArgumentParser(
|
| 220 |
+
(ModelArguments, DataArguments, TrainingArguments))
|
| 221 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 222 |
+
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 223 |
+
|
| 224 |
+
bnb_model_from_pretrained_args = {}
|
| 225 |
+
if training_args.bits in [4, 8]:
|
| 226 |
+
from transformers import BitsAndBytesConfig
|
| 227 |
+
bnb_model_from_pretrained_args.update(dict(
|
| 228 |
+
device_map={"": training_args.device},
|
| 229 |
+
load_in_4bit=training_args.bits == 4,
|
| 230 |
+
load_in_8bit=training_args.bits == 8,
|
| 231 |
+
quantization_config=BitsAndBytesConfig(
|
| 232 |
+
load_in_4bit=training_args.bits == 4,
|
| 233 |
+
load_in_8bit=training_args.bits == 8,
|
| 234 |
+
llm_int8_skip_modules=["speech_projector"],
|
| 235 |
+
llm_int8_threshold=6.0,
|
| 236 |
+
llm_int8_has_fp16_weight=False,
|
| 237 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
| 238 |
+
bnb_4bit_use_double_quant=training_args.double_quant,
|
| 239 |
+
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
|
| 240 |
+
)
|
| 241 |
+
))
|
| 242 |
+
|
| 243 |
+
if data_args.has_tgt_units:
|
| 244 |
+
if model_args.version == "llama_3":
|
| 245 |
+
model = OmniSpeech2SLlamaForCausalLM.from_pretrained(
|
| 246 |
+
model_args.model_name_or_path,
|
| 247 |
+
cache_dir=training_args.cache_dir,
|
| 248 |
+
attn_implementation=attn_implementation,
|
| 249 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 250 |
+
**bnb_model_from_pretrained_args
|
| 251 |
+
)
|
| 252 |
+
elif model_args.version == "qwen":
|
| 253 |
+
model = OmniSpeech2SQwen2ForCausalLM.from_pretrained(
|
| 254 |
+
model_args.model_name_or_path,
|
| 255 |
+
cache_dir=training_args.cache_dir,
|
| 256 |
+
attn_implementation=attn_implementation,
|
| 257 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 258 |
+
**bnb_model_from_pretrained_args
|
| 259 |
+
)
|
| 260 |
+
else:
|
| 261 |
+
raise ValueError("--currently only support llama or qwen model!")
|
| 262 |
+
else:
|
| 263 |
+
if model_args.version == "llama_3":
|
| 264 |
+
model = OmniSpeechLlamaForCausalLM.from_pretrained(
|
| 265 |
+
model_args.model_name_or_path,
|
| 266 |
+
cache_dir=training_args.cache_dir,
|
| 267 |
+
attn_implementation=attn_implementation,
|
| 268 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 269 |
+
**bnb_model_from_pretrained_args
|
| 270 |
+
)
|
| 271 |
+
elif model_args.version == "qwen":
|
| 272 |
+
model = OmniSpeechQwen2ForCausalLM.from_pretrained(
|
| 273 |
+
model_args.model_name_or_path,
|
| 274 |
+
cache_dir=training_args.cache_dir,
|
| 275 |
+
attn_implementation=attn_implementation,
|
| 276 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 277 |
+
**bnb_model_from_pretrained_args
|
| 278 |
+
)
|
| 279 |
+
else:
|
| 280 |
+
raise ValueError("--currently only support llama or qwen model!")
|
| 281 |
+
model.config.use_cache = False
|
| 282 |
+
|
| 283 |
+
if model_args.freeze_backbone:
|
| 284 |
+
model.model.requires_grad_(False)
|
| 285 |
+
|
| 286 |
+
if training_args.bits in [4, 8]:
|
| 287 |
+
from peft import prepare_model_for_kbit_training
|
| 288 |
+
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 289 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
|
| 290 |
+
|
| 291 |
+
if training_args.gradient_checkpointing:
|
| 292 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 293 |
+
model.enable_input_require_grads()
|
| 294 |
+
else:
|
| 295 |
+
def make_inputs_require_grad(module, input, output):
|
| 296 |
+
output.requires_grad_(True)
|
| 297 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 298 |
+
|
| 299 |
+
if training_args.lora_enable:
|
| 300 |
+
from peft import LoraConfig, get_peft_model
|
| 301 |
+
lora_config = LoraConfig(
|
| 302 |
+
r=training_args.lora_r,
|
| 303 |
+
lora_alpha=training_args.lora_alpha,
|
| 304 |
+
target_modules=find_all_linear_names(model),
|
| 305 |
+
lora_dropout=training_args.lora_dropout,
|
| 306 |
+
bias=training_args.lora_bias,
|
| 307 |
+
task_type="CAUSAL_LM",
|
| 308 |
+
)
|
| 309 |
+
if training_args.bits == 16:
|
| 310 |
+
if training_args.bf16:
|
| 311 |
+
model.to(torch.bfloat16)
|
| 312 |
+
if training_args.fp16:
|
| 313 |
+
model.to(torch.float16)
|
| 314 |
+
model = get_peft_model(model, lora_config)
|
| 315 |
+
|
| 316 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 317 |
+
model_args.model_name_or_path,
|
| 318 |
+
cache_dir=training_args.cache_dir,
|
| 319 |
+
model_max_length=training_args.model_max_length,
|
| 320 |
+
padding_side="right",
|
| 321 |
+
use_fast=False,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 325 |
+
model.config.max_length = training_args.model_max_length
|
| 326 |
+
|
| 327 |
+
if model_args.version in conversation_lib.conv_templates:
|
| 328 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
|
| 329 |
+
else:
|
| 330 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["llama_3"]
|
| 331 |
+
|
| 332 |
+
if model_args.speech_encoder is not None:
|
| 333 |
+
model.get_model().initialize_speech_modules(
|
| 334 |
+
model_args=model_args,
|
| 335 |
+
fsdp=training_args.fsdp
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
data_args.is_multimodal = True
|
| 339 |
+
|
| 340 |
+
model.config.tokenizer_padding_side = tokenizer.padding_side
|
| 341 |
+
model.config.tokenizer_model_max_length = tokenizer.model_max_length
|
| 342 |
+
|
| 343 |
+
model.config.tune_speech_projector = training_args.tune_speech_projector = model_args.tune_speech_projector
|
| 344 |
+
|
| 345 |
+
model.config.speech_normalize = data_args.speech_normalize
|
| 346 |
+
|
| 347 |
+
for p in model.get_speech_encoder().parameters():
|
| 348 |
+
p.requires_grad = False
|
| 349 |
+
|
| 350 |
+
if model_args.tune_speech_projector:
|
| 351 |
+
model.requires_grad_(False)
|
| 352 |
+
for p in model.get_speech_projector().parameters():
|
| 353 |
+
p.requires_grad = True
|
| 354 |
+
|
| 355 |
+
model.config.freeze_speech_projector = training_args.freeze_speech_projector
|
| 356 |
+
if training_args.freeze_speech_projector:
|
| 357 |
+
for p in model.get_speech_projector().parameters():
|
| 358 |
+
p.requires_grad = False
|
| 359 |
+
|
| 360 |
+
if training_args.bits in [4, 8]:
|
| 361 |
+
model.get_model().speech_projector.to(dtype=compute_dtype, device=training_args.device)
|
| 362 |
+
|
| 363 |
+
model.config.speech_projector_lr = training_args.speech_projector_lr
|
| 364 |
+
|
| 365 |
+
if data_args.has_tgt_units:
|
| 366 |
+
model.initialize_speech_generator(model_args=model_args)
|
| 367 |
+
|
| 368 |
+
if training_args.bits in [4, 8]:
|
| 369 |
+
from peft.tuners.lora import LoraLayer
|
| 370 |
+
for name, module in model.named_modules():
|
| 371 |
+
if isinstance(module, LoraLayer):
|
| 372 |
+
if training_args.bf16:
|
| 373 |
+
module = module.to(torch.bfloat16)
|
| 374 |
+
if 'norm' in name:
|
| 375 |
+
module = module.to(torch.float32)
|
| 376 |
+
if 'lm_head' in name or 'embed_tokens' in name:
|
| 377 |
+
if hasattr(module, 'weight'):
|
| 378 |
+
if training_args.bf16 and module.weight.dtype == torch.float32:
|
| 379 |
+
module = module.to(torch.bfloat16)
|
| 380 |
+
|
| 381 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer,
|
| 382 |
+
data_args=data_args)
|
| 383 |
+
|
| 384 |
+
print("Training Layers:")
|
| 385 |
+
for name, param in model.named_parameters():
|
| 386 |
+
if param.requires_grad:
|
| 387 |
+
print(name, param.grad)
|
| 388 |
+
|
| 389 |
+
trainer = OmniTrainer(model=model,
|
| 390 |
+
tokenizer=tokenizer,
|
| 391 |
+
args=training_args,
|
| 392 |
+
**data_module)
|
| 393 |
+
|
| 394 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 395 |
+
trainer.train(resume_from_checkpoint=True)
|
| 396 |
+
else:
|
| 397 |
+
trainer.train()
|
| 398 |
+
trainer.save_state()
|
| 399 |
+
|
| 400 |
+
model.config.use_cache = True
|
| 401 |
+
|
| 402 |
+
if training_args.lora_enable:
|
| 403 |
+
state_dict = get_peft_state_maybe_zero_3(
|
| 404 |
+
model.named_parameters(), training_args.lora_bias
|
| 405 |
+
)
|
| 406 |
+
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
|
| 407 |
+
model.named_parameters()
|
| 408 |
+
)
|
| 409 |
+
if training_args.local_rank == 0 or training_args.local_rank == -1:
|
| 410 |
+
model.config.save_pretrained(training_args.output_dir)
|
| 411 |
+
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
|
| 412 |
+
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
|
| 413 |
+
else:
|
| 414 |
+
safe_save_model_for_hf_trainer(trainer=trainer,
|
| 415 |
+
output_dir=training_args.output_dir)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
if __name__ == "__main__":
|
| 419 |
+
train()
|
| 420 |
+
|
omni_speech/train/train_mem.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from omni_speech.train.train_multiturn import train
|
| 2 |
+
|
| 3 |
+
if __name__ == "__main__":
|
| 4 |
+
train(attn_implementation="flash_attention_2")
|
omni_speech/train/train_minicpmo.py
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Dict, List, Optional, Union, Literal, Tuple
|
| 8 |
+
from types import MethodType
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import transformers
|
| 14 |
+
from accelerate.utils import DistributedType
|
| 15 |
+
from deepspeed import zero
|
| 16 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
| 17 |
+
import pathlib
|
| 18 |
+
|
| 19 |
+
from transformers import AutoModel, AutoTokenizer, AutoProcessor
|
| 20 |
+
from transformers.integrations import deepspeed
|
| 21 |
+
|
| 22 |
+
from omni_speech.datasets.dataset import SupervisedDataset, data_collator
|
| 23 |
+
from omni_speech.model import *
|
| 24 |
+
from trainer import CPMTrainer
|
| 25 |
+
from transformers import Trainer
|
| 26 |
+
import librosa
|
| 27 |
+
from datasets import load_dataset
|
| 28 |
+
import numpy as np
|
| 29 |
+
from PIL import Image
|
| 30 |
+
from functools import partial
|
| 31 |
+
from audiomentations import AddBackgroundNoise, PolarityInversion
|
| 32 |
+
|
| 33 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class ModelArguments:
|
| 37 |
+
model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-o-2_6")
|
| 38 |
+
tokenizer_path: Optional[str] = field(default=None)
|
| 39 |
+
audio_encoder_path: Optional[str] = field(default=None)
|
| 40 |
+
pretrained_llm_path: Optional[str] = field(default=None)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
|
| 44 |
+
class DataArguments:
|
| 45 |
+
data_path: str = field(
|
| 46 |
+
default=None, metadata={"help": "Path to the training data."}
|
| 47 |
+
)
|
| 48 |
+
eval_data_path: str = field(
|
| 49 |
+
default=None, metadata={"help": "Path to the evaluation data."}
|
| 50 |
+
)
|
| 51 |
+
max_train_samples: Optional[int] = field(
|
| 52 |
+
default=None,
|
| 53 |
+
metadata={
|
| 54 |
+
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 55 |
+
"value if set."
|
| 56 |
+
},
|
| 57 |
+
)
|
| 58 |
+
max_eval_samples: Optional[int] = field(
|
| 59 |
+
default=None,
|
| 60 |
+
metadata={
|
| 61 |
+
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
| 62 |
+
"value if set."
|
| 63 |
+
},
|
| 64 |
+
)
|
| 65 |
+
augment_prob: float = field(
|
| 66 |
+
default=0.0,
|
| 67 |
+
metadata={"help": "The probability of applying augmentation transform."}
|
| 68 |
+
)
|
| 69 |
+
augment_path: str = field(default=None,
|
| 70 |
+
metadata={"help": "Path to the augment data."})
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 75 |
+
cache_dir: Optional[str] = field(default=None)
|
| 76 |
+
optim: str = field(default="adamw_torch")
|
| 77 |
+
model_max_length: int = field(
|
| 78 |
+
default=2048,
|
| 79 |
+
metadata={
|
| 80 |
+
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 81 |
+
},
|
| 82 |
+
)
|
| 83 |
+
tune_vision: Optional[bool] = field(default=True)
|
| 84 |
+
tune_speech: Optional[bool] = field(default=True)
|
| 85 |
+
tune_llm: Optional[bool] = field(default=True)
|
| 86 |
+
llm_type: str = field(default="qwen")
|
| 87 |
+
use_lora: Optional[bool] = field(default=False)
|
| 88 |
+
max_slice_nums: Optional[int] = field(default=9)
|
| 89 |
+
config_path: Optional[str] = field(default=None)
|
| 90 |
+
chunk_input: Optional[bool] = field(default=True)
|
| 91 |
+
init_vision: Optional[bool] = field(default=False)
|
| 92 |
+
init_speech: Optional[bool] = field(default=True)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dataclass
|
| 96 |
+
class LoraArguments:
|
| 97 |
+
lora_r: int = 64
|
| 98 |
+
lora_alpha: int = 64
|
| 99 |
+
lora_dropout: float = 0.05
|
| 100 |
+
lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
|
| 101 |
+
lora_weight_path: str = ""
|
| 102 |
+
lora_bias: str = "none"
|
| 103 |
+
q_lora: bool = False
|
| 104 |
+
lora_modules_to_save: str = ""
|
| 105 |
+
lora_layer_replication: Optional[List[Tuple[int, int]]] = None
|
| 106 |
+
lora_layers_to_transform: Optional[List[int]] = None
|
| 107 |
+
lora_layers_pattern: Optional[str] = None
|
| 108 |
+
|
| 109 |
+
local_rank = None
|
| 110 |
+
def rank0_print(*args):
|
| 111 |
+
if local_rank == 0:
|
| 112 |
+
print(*args)
|
| 113 |
+
|
| 114 |
+
def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
|
| 115 |
+
"""Collects the state dict and dump to disk."""
|
| 116 |
+
if trainer.args.should_save and trainer.args.local_rank == 0:
|
| 117 |
+
trainer.save_model(output_dir,)
|
| 118 |
+
|
| 119 |
+
# class CollateFn:
|
| 120 |
+
# def __init__(self, processor, prompt="Please transcribe this audio into text.", system_prompt="You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."):
|
| 121 |
+
# self.prompt = prompt
|
| 122 |
+
# self.system_prompt = system_prompt
|
| 123 |
+
# self.processor = processor
|
| 124 |
+
|
| 125 |
+
# def __call__(self, examples):
|
| 126 |
+
# prompts_lists = []
|
| 127 |
+
# input_images_list = []
|
| 128 |
+
# input_audios_list = []
|
| 129 |
+
# audio_parts_list = []
|
| 130 |
+
|
| 131 |
+
# for msgs in examples:
|
| 132 |
+
# msgs = msgs["conversations"]
|
| 133 |
+
# if isinstance(msgs, str):
|
| 134 |
+
# msgs = json.loads(msgs)
|
| 135 |
+
# copy_msgs = deepcopy(msgs)
|
| 136 |
+
|
| 137 |
+
# assert len(msgs) > 0, "msgs is empty"
|
| 138 |
+
|
| 139 |
+
# system_turn = {'role': 'system', 'content': self.system_prompt}
|
| 140 |
+
# if copy_msgs[0]["role"] != 'system':
|
| 141 |
+
# copy_msgs.insert(0, system_turn)
|
| 142 |
+
|
| 143 |
+
# images = []
|
| 144 |
+
# audios = []
|
| 145 |
+
# audio_parts = []
|
| 146 |
+
# for i, msg in enumerate(copy_msgs):
|
| 147 |
+
# role = msg["role"]
|
| 148 |
+
# content = msg["content"]
|
| 149 |
+
# assert role in ["system", "user", "assistant"]
|
| 150 |
+
# if i == 0:
|
| 151 |
+
# assert role in ["user", "system"], "The role of first msg should be user"
|
| 152 |
+
# content = [content, self.prompt]
|
| 153 |
+
# cur_msgs = []
|
| 154 |
+
|
| 155 |
+
# for c in content:
|
| 156 |
+
# if os.path.exists(c):
|
| 157 |
+
# c, _ = librosa.load(c, sr=16000, mono=True)
|
| 158 |
+
|
| 159 |
+
# if isinstance(c, Image.Image):
|
| 160 |
+
# images.append(c)
|
| 161 |
+
# cur_msgs.append("(<image>./</image>)")
|
| 162 |
+
# elif isinstance(c, np.ndarray): # audio
|
| 163 |
+
# audios.append(c)
|
| 164 |
+
# audio_parts.append(i)
|
| 165 |
+
# cur_msgs.append("(<audio>./</audio>)")
|
| 166 |
+
# elif isinstance(c, str):
|
| 167 |
+
# cur_msgs.append(c)
|
| 168 |
+
# else:
|
| 169 |
+
# msg["content"] = "\n".join(cur_msgs)
|
| 170 |
+
|
| 171 |
+
# prompts_lists.append(
|
| 172 |
+
# self.processor.tokenizer.apply_chat_template(
|
| 173 |
+
# copy_msgs,
|
| 174 |
+
# tokenize=False,
|
| 175 |
+
# add_generation_prompt=False,
|
| 176 |
+
# )
|
| 177 |
+
# )
|
| 178 |
+
# input_images_list.append(images)
|
| 179 |
+
# input_audios_list.append(audios)
|
| 180 |
+
# audio_parts_list.append(audio_parts)
|
| 181 |
+
|
| 182 |
+
# inputs = self.processor(
|
| 183 |
+
# prompts_lists,
|
| 184 |
+
# input_images_list,
|
| 185 |
+
# input_audios_list,
|
| 186 |
+
# audio_parts_list,
|
| 187 |
+
# return_tensors="pt",
|
| 188 |
+
# max_length=32768,
|
| 189 |
+
# return_labels=True,
|
| 190 |
+
# )
|
| 191 |
+
|
| 192 |
+
# return inputs
|
| 193 |
+
|
| 194 |
+
def collate_fn(examples, processor, chunk_input, max_len, prompt=None, system_prompt="You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.", transform=None):
|
| 195 |
+
|
| 196 |
+
prompts_lists = []
|
| 197 |
+
input_images_list = []
|
| 198 |
+
input_audios_list = []
|
| 199 |
+
audio_parts_list = []
|
| 200 |
+
|
| 201 |
+
for msgs in examples:
|
| 202 |
+
raw_msgs = deepcopy(msgs)
|
| 203 |
+
msgs = msgs["conversations"]
|
| 204 |
+
if isinstance(msgs, str):
|
| 205 |
+
msgs = json.loads(msgs)
|
| 206 |
+
copy_msgs = deepcopy(msgs)
|
| 207 |
+
|
| 208 |
+
assert len(msgs) > 0, "msgs is empty"
|
| 209 |
+
|
| 210 |
+
system_turn = {'role': 'system', 'content': system_prompt}
|
| 211 |
+
if copy_msgs[0]["role"] != 'system':
|
| 212 |
+
copy_msgs.insert(0, system_turn)
|
| 213 |
+
|
| 214 |
+
fc = None
|
| 215 |
+
if "tools" in raw_msgs:
|
| 216 |
+
# if raw_msgs["tools"] != "":
|
| 217 |
+
# json_objects = raw_msgs["tools"].split("\n\n")
|
| 218 |
+
# try:
|
| 219 |
+
# fc = [json.loads(obj) for obj in json_objects]
|
| 220 |
+
# except:
|
| 221 |
+
# if len(json_objects) > 1:
|
| 222 |
+
# json_objects = json_objects[:-1]
|
| 223 |
+
# fc = [json.loads(obj) for obj in json_objects]
|
| 224 |
+
if raw_msgs["tools"] != "":
|
| 225 |
+
fc = json.loads(raw_msgs["tools"])
|
| 226 |
+
|
| 227 |
+
# print(fc)
|
| 228 |
+
# print("-----------")
|
| 229 |
+
|
| 230 |
+
images = []
|
| 231 |
+
audios = []
|
| 232 |
+
audio_parts = []
|
| 233 |
+
for i, msg in enumerate(copy_msgs):
|
| 234 |
+
role = msg["role"]
|
| 235 |
+
content = msg["content"]
|
| 236 |
+
assert role in ["system", "user", "assistant", "tool"]
|
| 237 |
+
if i == 0:
|
| 238 |
+
assert role in ["user", "system"], "The role of first msg should be user or system"
|
| 239 |
+
|
| 240 |
+
if role == "user":
|
| 241 |
+
if prompt is not None:
|
| 242 |
+
content = [content, prompt]
|
| 243 |
+
else:
|
| 244 |
+
content = [content]
|
| 245 |
+
cur_msgs = []
|
| 246 |
+
for c in content:
|
| 247 |
+
if os.path.exists(c):
|
| 248 |
+
c, _ = librosa.load(c, sr=16000, mono=True)
|
| 249 |
+
if transform is not None:
|
| 250 |
+
c = transform(c, sample_rate=16000)
|
| 251 |
+
|
| 252 |
+
if isinstance(c, Image.Image):
|
| 253 |
+
images.append(c)
|
| 254 |
+
cur_msgs.append("(<image>./</image>)")
|
| 255 |
+
elif isinstance(c, np.ndarray): # audio
|
| 256 |
+
audios.append(c)
|
| 257 |
+
audio_parts.append(i)
|
| 258 |
+
cur_msgs.append("(<audio>./</audio>)")
|
| 259 |
+
elif isinstance(c, str):
|
| 260 |
+
cur_msgs.append(c)
|
| 261 |
+
|
| 262 |
+
msg["content"] = "\n".join(cur_msgs)
|
| 263 |
+
|
| 264 |
+
if "tool_calls" in msg:
|
| 265 |
+
if msg["tool_calls"] is not None:
|
| 266 |
+
assert isinstance(msg["tool_calls"], str), f"Invalid type: {type(msg['tool_calls'])}"
|
| 267 |
+
msg["tool_calls"] = json.loads(msg["tool_calls"])
|
| 268 |
+
if type(msg["tool_calls"]) != list:
|
| 269 |
+
msg["tool_calls"] = [msg["tool_calls"]]
|
| 270 |
+
|
| 271 |
+
# print(copy_msgs)
|
| 272 |
+
# print("--------")
|
| 273 |
+
|
| 274 |
+
qwen_template = processor.tokenizer.apply_chat_template(
|
| 275 |
+
copy_msgs,
|
| 276 |
+
tokenize=False,
|
| 277 |
+
add_generation_prompt=False,
|
| 278 |
+
tools = fc,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# print(qwen_template)
|
| 282 |
+
# print("---------------")
|
| 283 |
+
|
| 284 |
+
prompts_lists.append(qwen_template)
|
| 285 |
+
input_images_list.append(images)
|
| 286 |
+
input_audios_list.append(audios)
|
| 287 |
+
audio_parts_list.append(audio_parts)
|
| 288 |
+
|
| 289 |
+
inputs = processor(
|
| 290 |
+
prompts_lists,
|
| 291 |
+
input_images_list,
|
| 292 |
+
input_audios_list,
|
| 293 |
+
audio_parts_list,
|
| 294 |
+
chunk_input=chunk_input,
|
| 295 |
+
return_tensors="pt",
|
| 296 |
+
# max_length=max_len,
|
| 297 |
+
return_labels=True,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
return inputs
|
| 301 |
+
|
| 302 |
+
def make_supervised_data_module(
|
| 303 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 304 |
+
processor: transformers.ProcessorMixin,
|
| 305 |
+
data_args,
|
| 306 |
+
transform,
|
| 307 |
+
data_collator=None,
|
| 308 |
+
llm_type="qwen",
|
| 309 |
+
slice_config=None,
|
| 310 |
+
patch_size=14,
|
| 311 |
+
query_nums=64,
|
| 312 |
+
batch_vision=False,
|
| 313 |
+
max_length=2048,
|
| 314 |
+
) -> Dict:
|
| 315 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 316 |
+
dataset_cls = SupervisedDataset
|
| 317 |
+
|
| 318 |
+
rank0_print("Loading data...")
|
| 319 |
+
|
| 320 |
+
train_json = json.load(open(data_args.data_path, "r"))
|
| 321 |
+
train_dataset = dataset_cls(
|
| 322 |
+
train_json,
|
| 323 |
+
transform,
|
| 324 |
+
tokenizer,
|
| 325 |
+
processor,
|
| 326 |
+
slice_config=slice_config,
|
| 327 |
+
llm_type=llm_type,
|
| 328 |
+
patch_size=patch_size,
|
| 329 |
+
query_nums=query_nums,
|
| 330 |
+
batch_vision=batch_vision,
|
| 331 |
+
max_length=max_length,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
if data_args.eval_data_path:
|
| 335 |
+
eval_json = json.load(open(data_args.eval_data_path, "r"))
|
| 336 |
+
eval_dataset = dataset_cls(
|
| 337 |
+
eval_json,
|
| 338 |
+
transform,
|
| 339 |
+
tokenizer,
|
| 340 |
+
processor,
|
| 341 |
+
slice_config=slice_config,
|
| 342 |
+
llm_type=llm_type,
|
| 343 |
+
patch_size=patch_size,
|
| 344 |
+
query_nums=query_nums,
|
| 345 |
+
batch_vision=batch_vision,
|
| 346 |
+
max_length=max_length,
|
| 347 |
+
)
|
| 348 |
+
else:
|
| 349 |
+
eval_dataset = None
|
| 350 |
+
|
| 351 |
+
return dict(
|
| 352 |
+
train_dataset=train_dataset,
|
| 353 |
+
eval_dataset=eval_dataset,
|
| 354 |
+
data_collator= partial(data_collator, max_length=max_length),
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def build_transform():
|
| 359 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
|
| 360 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
|
| 361 |
+
return transforms.Compose(
|
| 362 |
+
[
|
| 363 |
+
transforms.ToTensor(),
|
| 364 |
+
transforms.Normalize(
|
| 365 |
+
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
|
| 366 |
+
),
|
| 367 |
+
]
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
def get_parameter_number(model):
|
| 371 |
+
trainable_params, all_param = 0, 0
|
| 372 |
+
for param in model.parameters():
|
| 373 |
+
num_params = param.numel()
|
| 374 |
+
# if using DS Zero 3 and the weights are initialized empty
|
| 375 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
| 376 |
+
num_params = param.ds_numel
|
| 377 |
+
|
| 378 |
+
all_param += num_params
|
| 379 |
+
if param.requires_grad:
|
| 380 |
+
trainable_params += num_params
|
| 381 |
+
|
| 382 |
+
return {'Total': all_param, 'Trainable': trainable_params}
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
local_rank = 0
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def train(attn_implementation="flash_attention_2"):
|
| 389 |
+
global local_rank
|
| 390 |
+
parser = transformers.HfArgumentParser(
|
| 391 |
+
(ModelArguments, DataArguments, TrainingArguments, LoraArguments)
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
(
|
| 395 |
+
model_args,
|
| 396 |
+
data_args,
|
| 397 |
+
training_args,
|
| 398 |
+
lora_args,
|
| 399 |
+
) = parser.parse_args_into_dataclasses()
|
| 400 |
+
|
| 401 |
+
if getattr(training_args, "deepspeed", None) :
|
| 402 |
+
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
|
| 403 |
+
|
| 404 |
+
compute_dtype = (
|
| 405 |
+
torch.float16
|
| 406 |
+
if training_args.fp16
|
| 407 |
+
else (torch.bfloat16 if training_args.bf16 else torch.float32)
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
local_rank = training_args.local_rank
|
| 411 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 412 |
+
ddp = world_size != 1
|
| 413 |
+
device_map = None
|
| 414 |
+
if lora_args.q_lora:
|
| 415 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
|
| 416 |
+
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
|
| 417 |
+
logging.warning(
|
| 418 |
+
"FSDP or ZeRO3 are not incompatible with QLoRA."
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
minipcmo_config = {}
|
| 422 |
+
if training_args.config_path is not None:
|
| 423 |
+
minipcmo_config = json.load(open(training_args.config_path, "r"))
|
| 424 |
+
|
| 425 |
+
# if model_args.tokenizer_path is not None:
|
| 426 |
+
# tokenizer = AutoTokenizer.from_pretrained(
|
| 427 |
+
# model_args.tokenizer_path, trust_remote_code=True
|
| 428 |
+
# )
|
| 429 |
+
# else:
|
| 430 |
+
# tokenizer = AutoTokenizer.from_pretrained(
|
| 431 |
+
# model_args.model_name_or_path, trust_remote_code=True
|
| 432 |
+
# )
|
| 433 |
+
|
| 434 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 435 |
+
model_args.model_name_or_path, trust_remote_code=True
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
if model_args.pretrained_llm_path is None:
|
| 439 |
+
print("Finetuning model!!!")
|
| 440 |
+
model = MiniCPMO.from_pretrained(
|
| 441 |
+
model_args.model_name_or_path,
|
| 442 |
+
torch_dtype=compute_dtype,
|
| 443 |
+
device_map=device_map,
|
| 444 |
+
attn_implementation=attn_implementation,
|
| 445 |
+
init_vision=training_args.init_vision,
|
| 446 |
+
init_audio=training_args.init_speech,
|
| 447 |
+
init_tts=False,
|
| 448 |
+
processor_path=model_args.tokenizer_path,
|
| 449 |
+
**minipcmo_config
|
| 450 |
+
)
|
| 451 |
+
else:
|
| 452 |
+
print("Load pretrained LLM from scratch!!!")
|
| 453 |
+
# # Create the config object as needed
|
| 454 |
+
# config = MiniCPMOConfig(
|
| 455 |
+
# model_name_or_path=model_args.model_name_or_path,
|
| 456 |
+
# pretrained_llm_path=model_args.pretrained_llm_path,
|
| 457 |
+
# init_vision=training_args.init_vision,
|
| 458 |
+
# init_audio=training_args.init_speech,
|
| 459 |
+
# pretrained_encoder_path=model_args.audio_encoder_path,
|
| 460 |
+
# processor_path=model_args.tokenizer_path,
|
| 461 |
+
# **minipcmo_config
|
| 462 |
+
# )
|
| 463 |
+
|
| 464 |
+
# # Initialize model
|
| 465 |
+
# model = MiniCPMO(config)
|
| 466 |
+
|
| 467 |
+
model = MiniCPMO.from_pretrained(
|
| 468 |
+
model_args.model_name_or_path,
|
| 469 |
+
pretrained_llm_path=model_args.pretrained_llm_path,
|
| 470 |
+
init_vision=training_args.init_vision,
|
| 471 |
+
init_audio=training_args.init_speech,
|
| 472 |
+
pretrained_encoder_path=model_args.audio_encoder_path,
|
| 473 |
+
processor_path=model_args.tokenizer_path,
|
| 474 |
+
**minipcmo_config
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# tokenizer.audio_start_id = tokenizer.convert_tokens_to_ids("<|box_start|>")
|
| 478 |
+
# tokenizer.audio_end_id = tokenizer.convert_tokens_to_ids("<|box_end|>")
|
| 479 |
+
# tokenizer.audio_start = "<|box_start|>"
|
| 480 |
+
# tokenizer.audio_end = "<|box_end|>"
|
| 481 |
+
# tokenizer.im_start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>")
|
| 482 |
+
# tokenizer.im_end_id = tokenizer.convert_tokens_to_ids("<|vision_end|>")
|
| 483 |
+
# tokenizer.im_start = "<|vision_start|>"
|
| 484 |
+
# tokenizer.im_end = "<|vision_end|>"
|
| 485 |
+
# tokenizer.slice_start_id = tokenizer.convert_tokens_to_ids("<|quad_start|>")
|
| 486 |
+
# tokenizer.slice_end_id = tokenizer.convert_tokens_to_ids("<|quad_end|>")
|
| 487 |
+
# tokenizer.slice_start = "<|quad_start|>"
|
| 488 |
+
# tokenizer.slice_end = "<|quad_end|>"
|
| 489 |
+
# tokenizer.unk_token = "<unk>"
|
| 490 |
+
|
| 491 |
+
# print("Audio Start Token:", tokenizer.audio_start)
|
| 492 |
+
# print("Audio End Token:", tokenizer.audio_end)
|
| 493 |
+
# print(tokenizer.audio_start_id)
|
| 494 |
+
# print(tokenizer.audio_end_id)
|
| 495 |
+
# print("Start Token:", tokenizer.im_start)
|
| 496 |
+
# print("End Token:", tokenizer.im_end)
|
| 497 |
+
# print(tokenizer.im_start_id)
|
| 498 |
+
# print(tokenizer.im_end_id)
|
| 499 |
+
# print("Slice Start Token:", tokenizer.slice_start)
|
| 500 |
+
# print("Slice End Token:", tokenizer.slice_end)
|
| 501 |
+
# print(tokenizer.slice_start_id)
|
| 502 |
+
# print(tokenizer.slice_end_id)
|
| 503 |
+
|
| 504 |
+
model.config.chunk_input = training_args.chunk_input
|
| 505 |
+
# model.llm.resize_token_embeddings(len(tokenizer))
|
| 506 |
+
# model.resize_token_embeddings(len(tokenizer))
|
| 507 |
+
|
| 508 |
+
model.llm.config.use_cache = False
|
| 509 |
+
model.config.max_length = training_args.model_max_length
|
| 510 |
+
|
| 511 |
+
if not training_args.tune_vision and training_args.init_vision:
|
| 512 |
+
model.vpm.requires_grad_(False)
|
| 513 |
+
if not training_args.tune_speech and training_args.init_speech:
|
| 514 |
+
model.apm.requires_grad_(False)
|
| 515 |
+
if not training_args.tune_llm:
|
| 516 |
+
model.llm.requires_grad_(False)
|
| 517 |
+
|
| 518 |
+
if training_args.use_lora:
|
| 519 |
+
if training_args.use_lora and training_args.tune_llm:
|
| 520 |
+
raise ValueError("The model cannot simultaneously adjust LLM parameters and apply LoRA.")
|
| 521 |
+
|
| 522 |
+
rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
|
| 523 |
+
for name, param in model.llm.named_parameters():
|
| 524 |
+
param.requires_grad = False
|
| 525 |
+
modules_to_save = ['embed_tokens','resampler']
|
| 526 |
+
if training_args.tune_vision:
|
| 527 |
+
modules_to_save.append('vpm')
|
| 528 |
+
lora_config = LoraConfig(
|
| 529 |
+
r=lora_args.lora_r,
|
| 530 |
+
lora_alpha=lora_args.lora_alpha,
|
| 531 |
+
target_modules=lora_args.lora_target_modules,
|
| 532 |
+
lora_dropout=lora_args.lora_dropout,
|
| 533 |
+
bias=lora_args.lora_bias,
|
| 534 |
+
layers_to_transform=lora_args.lora_layers_to_transform,
|
| 535 |
+
modules_to_save=modules_to_save,
|
| 536 |
+
)
|
| 537 |
+
if not hasattr(model, 'get_input_embeddings'):
|
| 538 |
+
def get_input_embeddings(self):
|
| 539 |
+
return self.llm.get_input_embeddings()
|
| 540 |
+
model.get_input_embeddings = MethodType(get_input_embeddings, model)
|
| 541 |
+
if lora_args.q_lora:
|
| 542 |
+
model = prepare_model_for_kbit_training(
|
| 543 |
+
model, use_gradient_checkpointing=training_args.gradient_checkpointing
|
| 544 |
+
)
|
| 545 |
+
model = get_peft_model(model, lora_config)
|
| 546 |
+
if training_args.gradient_checkpointing:
|
| 547 |
+
model.enable_input_require_grads()
|
| 548 |
+
|
| 549 |
+
rank0_print(get_parameter_number(model))
|
| 550 |
+
|
| 551 |
+
llm_type = training_args.llm_type
|
| 552 |
+
|
| 553 |
+
rank0_print(f'llm_type={llm_type}')
|
| 554 |
+
|
| 555 |
+
# Load data
|
| 556 |
+
if hasattr(model.config, "slice_config"):
|
| 557 |
+
model.config.slice_config.max_slice_nums = training_args.max_slice_nums
|
| 558 |
+
slice_config = model.config.slice_config.to_dict()
|
| 559 |
+
else:
|
| 560 |
+
model.config.max_slice_nums = training_args.max_slice_nums
|
| 561 |
+
slice_config = model.config.to_dict()
|
| 562 |
+
|
| 563 |
+
if hasattr(model.config, "batch_vision_input"):
|
| 564 |
+
batch_vision = model.config.batch_vision_input
|
| 565 |
+
else:
|
| 566 |
+
batch_vision = False
|
| 567 |
+
|
| 568 |
+
transform_func = build_transform()
|
| 569 |
+
|
| 570 |
+
if model_args.tokenizer_path is not None:
|
| 571 |
+
processor = AutoProcessor.from_pretrained(model_args.tokenizer_path, trust_remote_code=True)
|
| 572 |
+
else:
|
| 573 |
+
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
| 574 |
+
processor.tokenizer = tokenizer
|
| 575 |
+
|
| 576 |
+
raw_datasets = load_dataset(
|
| 577 |
+
"json",
|
| 578 |
+
data_files={
|
| 579 |
+
"train": data_args.data_path,
|
| 580 |
+
"validation": data_args.eval_data_path,
|
| 581 |
+
},
|
| 582 |
+
cache_dir=training_args.cache_dir,
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
train_ds = raw_datasets["train"]
|
| 586 |
+
if data_args.max_train_samples is not None:
|
| 587 |
+
train_ds = train_ds.select(range(data_args.max_train_samples))
|
| 588 |
+
eval_ds = raw_datasets["validation"]
|
| 589 |
+
if data_args.max_eval_samples is not None:
|
| 590 |
+
eval_ds = eval_ds.select(range(data_args.max_eval_samples))
|
| 591 |
+
|
| 592 |
+
# data_module = make_supervised_data_module(
|
| 593 |
+
# tokenizer=tokenizer,
|
| 594 |
+
# processor=processor,
|
| 595 |
+
# data_args=data_args,
|
| 596 |
+
# transform=transform_func,
|
| 597 |
+
# data_collator=data_collator,
|
| 598 |
+
# slice_config=slice_config,
|
| 599 |
+
# llm_type=llm_type,
|
| 600 |
+
# patch_size=model.config.patch_size,
|
| 601 |
+
# query_nums=model.config.query_num,
|
| 602 |
+
# batch_vision=batch_vision,
|
| 603 |
+
# max_length=training_args.model_max_length,
|
| 604 |
+
# )
|
| 605 |
+
|
| 606 |
+
init_prompt = None
|
| 607 |
+
if not training_args.tune_llm and training_args.tune_speech: # asr finetuning
|
| 608 |
+
init_prompt = "Please transcribe this audio into text."
|
| 609 |
+
|
| 610 |
+
transform = None
|
| 611 |
+
if data_args.augment_prob != 0.0 and data_args.augment_path is not None:
|
| 612 |
+
with open(data_args.augment_path, "r") as f:
|
| 613 |
+
augment_path_list = f.read().splitlines()
|
| 614 |
+
transform = AddBackgroundNoise(
|
| 615 |
+
sounds_path=augment_path_list,
|
| 616 |
+
min_snr_db=5.0,
|
| 617 |
+
max_snr_db=30.0,
|
| 618 |
+
noise_transform=PolarityInversion(),
|
| 619 |
+
p=data_args.augment_prob
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
custom_collate_fn = partial(collate_fn, processor = processor, chunk_input=training_args.chunk_input, max_len=training_args.model_max_length, prompt=init_prompt, transform=transform)
|
| 623 |
+
|
| 624 |
+
training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
|
| 625 |
+
|
| 626 |
+
print("Training Layers:")
|
| 627 |
+
for name, param in model.named_parameters():
|
| 628 |
+
if param.requires_grad:
|
| 629 |
+
print(name, param.grad)
|
| 630 |
+
|
| 631 |
+
# trainer = CPMTrainer(
|
| 632 |
+
# model=model,
|
| 633 |
+
# tokenizer=tokenizer,
|
| 634 |
+
# args=training_args,
|
| 635 |
+
# **data_module,
|
| 636 |
+
# )
|
| 637 |
+
trainer = Trainer(
|
| 638 |
+
model=model,
|
| 639 |
+
tokenizer=tokenizer,
|
| 640 |
+
args=training_args,
|
| 641 |
+
train_dataset=train_ds,
|
| 642 |
+
eval_dataset=eval_ds,
|
| 643 |
+
data_collator=custom_collate_fn
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 647 |
+
trainer.train(resume_from_checkpoint=True)
|
| 648 |
+
else:
|
| 649 |
+
trainer.train()
|
| 650 |
+
|
| 651 |
+
trainer.save_state()
|
| 652 |
+
|
| 653 |
+
safe_save_model_for_hf_trainer(
|
| 654 |
+
trainer=trainer,
|
| 655 |
+
output_dir=training_args.output_dir,
|
| 656 |
+
bias=lora_args.lora_bias)
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
if __name__ == "__main__":
|
| 660 |
+
train()
|
omni_speech/train/train_minicpmo_test.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Dict, List, Optional, Union, Literal, Tuple
|
| 8 |
+
from types import MethodType
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import transformers
|
| 14 |
+
from accelerate.utils import DistributedType
|
| 15 |
+
from deepspeed import zero
|
| 16 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
| 17 |
+
import pathlib
|
| 18 |
+
|
| 19 |
+
from transformers import AutoModel, AutoTokenizer, AutoProcessor
|
| 20 |
+
from transformers.integrations import deepspeed
|
| 21 |
+
|
| 22 |
+
from omni_speech.datasets.dataset import SupervisedDataset, data_collator
|
| 23 |
+
from omni_speech.model import *
|
| 24 |
+
from trainer import CPMTrainer
|
| 25 |
+
from transformers import Trainer
|
| 26 |
+
import librosa
|
| 27 |
+
from datasets import load_dataset
|
| 28 |
+
import numpy as np
|
| 29 |
+
from PIL import Image
|
| 30 |
+
from functools import partial
|
| 31 |
+
from audiomentations import AddBackgroundNoise, PolarityInversion
|
| 32 |
+
|
| 33 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class ModelArguments:
|
| 37 |
+
model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-o-2_6")
|
| 38 |
+
tokenizer_path: Optional[str] = field(default=None)
|
| 39 |
+
audio_encoder_path: Optional[str] = field(default=None)
|
| 40 |
+
pretrained_llm_path: Optional[str] = field(default=None)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
|
| 44 |
+
class DataArguments:
|
| 45 |
+
data_path: str = field(
|
| 46 |
+
default=None, metadata={"help": "Path to the training data."}
|
| 47 |
+
)
|
| 48 |
+
eval_data_path: str = field(
|
| 49 |
+
default=None, metadata={"help": "Path to the evaluation data."}
|
| 50 |
+
)
|
| 51 |
+
max_train_samples: Optional[int] = field(
|
| 52 |
+
default=None,
|
| 53 |
+
metadata={
|
| 54 |
+
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 55 |
+
"value if set."
|
| 56 |
+
},
|
| 57 |
+
)
|
| 58 |
+
max_eval_samples: Optional[int] = field(
|
| 59 |
+
default=None,
|
| 60 |
+
metadata={
|
| 61 |
+
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
| 62 |
+
"value if set."
|
| 63 |
+
},
|
| 64 |
+
)
|
| 65 |
+
augment_prob: float = field(
|
| 66 |
+
default=0.0,
|
| 67 |
+
metadata={"help": "The probability of applying augmentation transform."}
|
| 68 |
+
)
|
| 69 |
+
augment_path: str = field(default=None,
|
| 70 |
+
metadata={"help": "Path to the augment data."})
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 75 |
+
cache_dir: Optional[str] = field(default=None)
|
| 76 |
+
optim: str = field(default="adamw_torch")
|
| 77 |
+
model_max_length: int = field(
|
| 78 |
+
default=2048,
|
| 79 |
+
metadata={
|
| 80 |
+
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 81 |
+
},
|
| 82 |
+
)
|
| 83 |
+
tune_vision: Optional[bool] = field(default=True)
|
| 84 |
+
tune_speech: Optional[bool] = field(default=True)
|
| 85 |
+
tune_llm: Optional[bool] = field(default=True)
|
| 86 |
+
llm_type: str = field(default="qwen")
|
| 87 |
+
use_lora: Optional[bool] = field(default=False)
|
| 88 |
+
max_slice_nums: Optional[int] = field(default=9)
|
| 89 |
+
config_path: Optional[str] = field(default=None)
|
| 90 |
+
chunk_input: Optional[bool] = field(default=True)
|
| 91 |
+
init_vision: Optional[bool] = field(default=False)
|
| 92 |
+
init_speech: Optional[bool] = field(default=True)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dataclass
|
| 96 |
+
class LoraArguments:
|
| 97 |
+
lora_r: int = 64
|
| 98 |
+
lora_alpha: int = 64
|
| 99 |
+
lora_dropout: float = 0.05
|
| 100 |
+
lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
|
| 101 |
+
lora_weight_path: str = ""
|
| 102 |
+
lora_bias: str = "none"
|
| 103 |
+
q_lora: bool = False
|
| 104 |
+
lora_modules_to_save: str = ""
|
| 105 |
+
lora_layer_replication: Optional[List[Tuple[int, int]]] = None
|
| 106 |
+
lora_layers_to_transform: Optional[List[int]] = None
|
| 107 |
+
lora_layers_pattern: Optional[str] = None
|
| 108 |
+
|
| 109 |
+
local_rank = None
|
| 110 |
+
def rank0_print(*args):
|
| 111 |
+
if local_rank == 0:
|
| 112 |
+
print(*args)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def print_trainable_parameters_by_module(model):
|
| 116 |
+
"""
|
| 117 |
+
In ra chi tiết các tham số trainable theo module và số lượng tham số
|
| 118 |
+
"""
|
| 119 |
+
print("\n" + "="*50)
|
| 120 |
+
print("TRAINABLE PARAMETERS BY MODULE")
|
| 121 |
+
print("="*50)
|
| 122 |
+
|
| 123 |
+
# Lưu trữ tham số theo module cấp 2
|
| 124 |
+
module_params = {}
|
| 125 |
+
all_params = 0
|
| 126 |
+
trainable_params = 0
|
| 127 |
+
|
| 128 |
+
for name, param in model.named_parameters():
|
| 129 |
+
all_params += param.numel()
|
| 130 |
+
|
| 131 |
+
# Lấy module cấp 2
|
| 132 |
+
parts = name.split('.')
|
| 133 |
+
if len(parts) >= 2:
|
| 134 |
+
module_name = f"{parts[0]}.{parts[1]}"
|
| 135 |
+
else:
|
| 136 |
+
module_name = parts[0]
|
| 137 |
+
|
| 138 |
+
if param.requires_grad:
|
| 139 |
+
trainable_params += param.numel()
|
| 140 |
+
|
| 141 |
+
if module_name not in module_params:
|
| 142 |
+
module_params[module_name] = {
|
| 143 |
+
'count': 0,
|
| 144 |
+
'names': []
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
module_params[module_name]['count'] += param.numel()
|
| 148 |
+
module_params[module_name]['names'].append(name)
|
| 149 |
+
|
| 150 |
+
# Sắp xếp và in kết quả
|
| 151 |
+
sorted_modules = sorted(module_params.items(), key=lambda x: x[1]['count'], reverse=True)
|
| 152 |
+
|
| 153 |
+
for module_name, info in sorted_modules:
|
| 154 |
+
param_count = info['count']
|
| 155 |
+
percentage = 100 * param_count / trainable_params
|
| 156 |
+
print(f"{module_name:<30} {param_count:,} params ({percentage:.2f}%)")
|
| 157 |
+
|
| 158 |
+
# In ra 3 tham số đầu tiên của module này
|
| 159 |
+
for i, param_name in enumerate(info['names'][:3]):
|
| 160 |
+
print(f" - {param_name}")
|
| 161 |
+
|
| 162 |
+
if len(info['names']) > 3:
|
| 163 |
+
print(f" ... and {len(info['names']) - 3} more parameters")
|
| 164 |
+
|
| 165 |
+
print("\n" + "-"*50)
|
| 166 |
+
print(f"Total trainable parameters: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")
|
| 167 |
+
print("="*50 + "\n")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
|
| 171 |
+
"""Collects the state dict and dump to disk."""
|
| 172 |
+
if trainer.args.should_save and trainer.args.local_rank == 0:
|
| 173 |
+
trainer.save_model(output_dir,)
|
| 174 |
+
|
| 175 |
+
# class CollateFn:
|
| 176 |
+
# def __init__(self, processor, prompt="Please transcribe this audio into text.", system_prompt="You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."):
|
| 177 |
+
# self.prompt = prompt
|
| 178 |
+
# self.system_prompt = system_prompt
|
| 179 |
+
# self.processor = processor
|
| 180 |
+
|
| 181 |
+
# def __call__(self, examples):
|
| 182 |
+
# prompts_lists = []
|
| 183 |
+
# input_images_list = []
|
| 184 |
+
# input_audios_list = []
|
| 185 |
+
# audio_parts_list = []
|
| 186 |
+
|
| 187 |
+
# for msgs in examples:
|
| 188 |
+
# msgs = msgs["conversations"]
|
| 189 |
+
# if isinstance(msgs, str):
|
| 190 |
+
# msgs = json.loads(msgs)
|
| 191 |
+
# copy_msgs = deepcopy(msgs)
|
| 192 |
+
|
| 193 |
+
# assert len(msgs) > 0, "msgs is empty"
|
| 194 |
+
|
| 195 |
+
# system_turn = {'role': 'system', 'content': self.system_prompt}
|
| 196 |
+
# if copy_msgs[0]["role"] != 'system':
|
| 197 |
+
# copy_msgs.insert(0, system_turn)
|
| 198 |
+
|
| 199 |
+
# images = []
|
| 200 |
+
# audios = []
|
| 201 |
+
# audio_parts = []
|
| 202 |
+
# for i, msg in enumerate(copy_msgs):
|
| 203 |
+
# role = msg["role"]
|
| 204 |
+
# content = msg["content"]
|
| 205 |
+
# assert role in ["system", "user", "assistant"]
|
| 206 |
+
# if i == 0:
|
| 207 |
+
# assert role in ["user", "system"], "The role of first msg should be user"
|
| 208 |
+
# content = [content, self.prompt]
|
| 209 |
+
# cur_msgs = []
|
| 210 |
+
|
| 211 |
+
# for c in content:
|
| 212 |
+
# if os.path.exists(c):
|
| 213 |
+
# c, _ = librosa.load(c, sr=16000, mono=True)
|
| 214 |
+
|
| 215 |
+
# if isinstance(c, Image.Image):
|
| 216 |
+
# images.append(c)
|
| 217 |
+
# cur_msgs.append("(<image>./</image>)")
|
| 218 |
+
# elif isinstance(c, np.ndarray): # audio
|
| 219 |
+
# audios.append(c)
|
| 220 |
+
# audio_parts.append(i)
|
| 221 |
+
# cur_msgs.append("(<audio>./</audio>)")
|
| 222 |
+
# elif isinstance(c, str):
|
| 223 |
+
# cur_msgs.append(c)
|
| 224 |
+
# else:
|
| 225 |
+
# msg["content"] = "\n".join(cur_msgs)
|
| 226 |
+
|
| 227 |
+
# prompts_lists.append(
|
| 228 |
+
# self.processor.tokenizer.apply_chat_template(
|
| 229 |
+
# copy_msgs,
|
| 230 |
+
# tokenize=False,
|
| 231 |
+
# add_generation_prompt=False,
|
| 232 |
+
# )
|
| 233 |
+
# )
|
| 234 |
+
# input_images_list.append(images)
|
| 235 |
+
# input_audios_list.append(audios)
|
| 236 |
+
# audio_parts_list.append(audio_parts)
|
| 237 |
+
|
| 238 |
+
# inputs = self.processor(
|
| 239 |
+
# prompts_lists,
|
| 240 |
+
# input_images_list,
|
| 241 |
+
# input_audios_list,
|
| 242 |
+
# audio_parts_list,
|
| 243 |
+
# return_tensors="pt",
|
| 244 |
+
# max_length=32768,
|
| 245 |
+
# return_labels=True,
|
| 246 |
+
# )
|
| 247 |
+
|
| 248 |
+
# return inputs
|
| 249 |
+
|
| 250 |
+
def collate_fn(examples, processor, chunk_input, max_len, prompt=None, system_prompt="You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.", transform=None):
|
| 251 |
+
|
| 252 |
+
prompts_lists = []
|
| 253 |
+
input_images_list = []
|
| 254 |
+
input_audios_list = []
|
| 255 |
+
audio_parts_list = []
|
| 256 |
+
|
| 257 |
+
for msgs in examples:
|
| 258 |
+
raw_msgs = deepcopy(msgs)
|
| 259 |
+
msgs = msgs["conversations"]
|
| 260 |
+
if isinstance(msgs, str):
|
| 261 |
+
msgs = json.loads(msgs)
|
| 262 |
+
copy_msgs = deepcopy(msgs)
|
| 263 |
+
|
| 264 |
+
assert len(msgs) > 0, "msgs is empty"
|
| 265 |
+
|
| 266 |
+
system_turn = {'role': 'system', 'content': system_prompt}
|
| 267 |
+
if copy_msgs[0]["role"] != 'system':
|
| 268 |
+
copy_msgs.insert(0, system_turn)
|
| 269 |
+
|
| 270 |
+
fc = None
|
| 271 |
+
if "tools" in raw_msgs:
|
| 272 |
+
# if raw_msgs["tools"] != "":
|
| 273 |
+
# json_objects = raw_msgs["tools"].split("\n\n")
|
| 274 |
+
# try:
|
| 275 |
+
# fc = [json.loads(obj) for obj in json_objects]
|
| 276 |
+
# except:
|
| 277 |
+
# if len(json_objects) > 1:
|
| 278 |
+
# json_objects = json_objects[:-1]
|
| 279 |
+
# fc = [json.loads(obj) for obj in json_objects]
|
| 280 |
+
if raw_msgs["tools"] != "":
|
| 281 |
+
fc = json.loads(raw_msgs["tools"])
|
| 282 |
+
|
| 283 |
+
# print(fc)
|
| 284 |
+
# print("-----------")
|
| 285 |
+
|
| 286 |
+
images = []
|
| 287 |
+
audios = []
|
| 288 |
+
audio_parts = []
|
| 289 |
+
for i, msg in enumerate(copy_msgs):
|
| 290 |
+
role = msg["role"]
|
| 291 |
+
content = msg["content"]
|
| 292 |
+
assert role in ["system", "user", "assistant", "tool"]
|
| 293 |
+
if i == 0:
|
| 294 |
+
assert role in ["user", "system"], "The role of first msg should be user or system"
|
| 295 |
+
|
| 296 |
+
if role == "user":
|
| 297 |
+
if prompt is not None:
|
| 298 |
+
content = [content, prompt]
|
| 299 |
+
else:
|
| 300 |
+
content = [content]
|
| 301 |
+
cur_msgs = []
|
| 302 |
+
for c in content:
|
| 303 |
+
if os.path.exists(c):
|
| 304 |
+
c, _ = librosa.load(c, sr=16000, mono=True)
|
| 305 |
+
if transform is not None:
|
| 306 |
+
c = transform(c, sample_rate=16000)
|
| 307 |
+
|
| 308 |
+
if isinstance(c, Image.Image):
|
| 309 |
+
images.append(c)
|
| 310 |
+
cur_msgs.append("(<image>./</image>)")
|
| 311 |
+
elif isinstance(c, np.ndarray): # audio
|
| 312 |
+
audios.append(c)
|
| 313 |
+
audio_parts.append(i)
|
| 314 |
+
cur_msgs.append("(<audio>./</audio>)")
|
| 315 |
+
elif isinstance(c, str):
|
| 316 |
+
cur_msgs.append(c)
|
| 317 |
+
|
| 318 |
+
msg["content"] = "\n".join(cur_msgs)
|
| 319 |
+
|
| 320 |
+
if "tool_calls" in msg:
|
| 321 |
+
if msg["tool_calls"] is not None:
|
| 322 |
+
assert isinstance(msg["tool_calls"], str), f"Invalid type: {type(msg['tool_calls'])}"
|
| 323 |
+
msg["tool_calls"] = json.loads(msg["tool_calls"])
|
| 324 |
+
if type(msg["tool_calls"]) != list:
|
| 325 |
+
msg["tool_calls"] = [msg["tool_calls"]]
|
| 326 |
+
|
| 327 |
+
# print(copy_msgs)
|
| 328 |
+
# print("--------")
|
| 329 |
+
|
| 330 |
+
qwen_template = processor.tokenizer.apply_chat_template(
|
| 331 |
+
copy_msgs,
|
| 332 |
+
tokenize=False,
|
| 333 |
+
add_generation_prompt=False,
|
| 334 |
+
tools = fc,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
# print(qwen_template)
|
| 338 |
+
# print("---------------")
|
| 339 |
+
|
| 340 |
+
prompts_lists.append(qwen_template)
|
| 341 |
+
input_images_list.append(images)
|
| 342 |
+
input_audios_list.append(audios)
|
| 343 |
+
audio_parts_list.append(audio_parts)
|
| 344 |
+
|
| 345 |
+
inputs = processor(
|
| 346 |
+
prompts_lists,
|
| 347 |
+
input_images_list,
|
| 348 |
+
input_audios_list,
|
| 349 |
+
audio_parts_list,
|
| 350 |
+
chunk_input=chunk_input,
|
| 351 |
+
return_tensors="pt",
|
| 352 |
+
# max_length=max_len,
|
| 353 |
+
return_labels=True,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
return inputs
|
| 357 |
+
|
| 358 |
+
def make_supervised_data_module(
|
| 359 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 360 |
+
processor: transformers.ProcessorMixin,
|
| 361 |
+
data_args,
|
| 362 |
+
transform,
|
| 363 |
+
data_collator=None,
|
| 364 |
+
llm_type="qwen",
|
| 365 |
+
slice_config=None,
|
| 366 |
+
patch_size=14,
|
| 367 |
+
query_nums=64,
|
| 368 |
+
batch_vision=False,
|
| 369 |
+
max_length=2048,
|
| 370 |
+
) -> Dict:
|
| 371 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 372 |
+
dataset_cls = SupervisedDataset
|
| 373 |
+
|
| 374 |
+
rank0_print("Loading data...")
|
| 375 |
+
|
| 376 |
+
train_json = json.load(open(data_args.data_path, "r"))
|
| 377 |
+
train_dataset = dataset_cls(
|
| 378 |
+
train_json,
|
| 379 |
+
transform,
|
| 380 |
+
tokenizer,
|
| 381 |
+
processor,
|
| 382 |
+
slice_config=slice_config,
|
| 383 |
+
llm_type=llm_type,
|
| 384 |
+
patch_size=patch_size,
|
| 385 |
+
query_nums=query_nums,
|
| 386 |
+
batch_vision=batch_vision,
|
| 387 |
+
max_length=max_length,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
if data_args.eval_data_path:
|
| 391 |
+
eval_json = json.load(open(data_args.eval_data_path, "r"))
|
| 392 |
+
eval_dataset = dataset_cls(
|
| 393 |
+
eval_json,
|
| 394 |
+
transform,
|
| 395 |
+
tokenizer,
|
| 396 |
+
processor,
|
| 397 |
+
slice_config=slice_config,
|
| 398 |
+
llm_type=llm_type,
|
| 399 |
+
patch_size=patch_size,
|
| 400 |
+
query_nums=query_nums,
|
| 401 |
+
batch_vision=batch_vision,
|
| 402 |
+
max_length=max_length,
|
| 403 |
+
)
|
| 404 |
+
else:
|
| 405 |
+
eval_dataset = None
|
| 406 |
+
|
| 407 |
+
return dict(
|
| 408 |
+
train_dataset=train_dataset,
|
| 409 |
+
eval_dataset=eval_dataset,
|
| 410 |
+
data_collator= partial(data_collator, max_length=max_length),
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def build_transform():
|
| 415 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
|
| 416 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
|
| 417 |
+
return transforms.Compose(
|
| 418 |
+
[
|
| 419 |
+
transforms.ToTensor(),
|
| 420 |
+
transforms.Normalize(
|
| 421 |
+
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
|
| 422 |
+
),
|
| 423 |
+
]
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
def get_parameter_number(model):
|
| 427 |
+
trainable_params, all_param = 0, 0
|
| 428 |
+
for param in model.parameters():
|
| 429 |
+
num_params = param.numel()
|
| 430 |
+
# if using DS Zero 3 and the weights are initialized empty
|
| 431 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
| 432 |
+
num_params = param.ds_numel
|
| 433 |
+
|
| 434 |
+
all_param += num_params
|
| 435 |
+
if param.requires_grad:
|
| 436 |
+
trainable_params += num_params
|
| 437 |
+
|
| 438 |
+
return {'Total': all_param, 'Trainable': trainable_params}
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
local_rank = 0
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def train(attn_implementation="flash_attention_2"):
|
| 445 |
+
global local_rank
|
| 446 |
+
parser = transformers.HfArgumentParser(
|
| 447 |
+
(ModelArguments, DataArguments, TrainingArguments, LoraArguments)
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
(
|
| 451 |
+
model_args,
|
| 452 |
+
data_args,
|
| 453 |
+
training_args,
|
| 454 |
+
lora_args,
|
| 455 |
+
) = parser.parse_args_into_dataclasses()
|
| 456 |
+
|
| 457 |
+
if getattr(training_args, "deepspeed", None) :
|
| 458 |
+
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
|
| 459 |
+
|
| 460 |
+
compute_dtype = (
|
| 461 |
+
torch.float16
|
| 462 |
+
if training_args.fp16
|
| 463 |
+
else (torch.bfloat16 if training_args.bf16 else torch.float32)
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
local_rank = training_args.local_rank
|
| 467 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 468 |
+
ddp = world_size != 1
|
| 469 |
+
device_map = None
|
| 470 |
+
if lora_args.q_lora:
|
| 471 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
|
| 472 |
+
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
|
| 473 |
+
logging.warning(
|
| 474 |
+
"FSDP or ZeRO3 are not incompatible with QLoRA."
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
minipcmo_config = {}
|
| 478 |
+
if training_args.config_path is not None:
|
| 479 |
+
minipcmo_config = json.load(open(training_args.config_path, "r"))
|
| 480 |
+
|
| 481 |
+
# if model_args.tokenizer_path is not None:
|
| 482 |
+
# tokenizer = AutoTokenizer.from_pretrained(
|
| 483 |
+
# model_args.tokenizer_path, trust_remote_code=True
|
| 484 |
+
# )
|
| 485 |
+
# else:
|
| 486 |
+
# tokenizer = AutoTokenizer.from_pretrained(
|
| 487 |
+
# model_args.model_name_or_path, trust_remote_code=True
|
| 488 |
+
# )
|
| 489 |
+
|
| 490 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 491 |
+
model_args.model_name_or_path, trust_remote_code=True
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
if model_args.pretrained_llm_path is None:
|
| 495 |
+
print("Finetuning model!!!")
|
| 496 |
+
model = MiniCPMO.from_pretrained(
|
| 497 |
+
model_args.model_name_or_path,
|
| 498 |
+
torch_dtype=compute_dtype,
|
| 499 |
+
device_map=device_map,
|
| 500 |
+
attn_implementation=attn_implementation,
|
| 501 |
+
init_vision=training_args.init_vision,
|
| 502 |
+
init_audio=training_args.init_speech,
|
| 503 |
+
init_tts=False,
|
| 504 |
+
processor_path=model_args.tokenizer_path,
|
| 505 |
+
**minipcmo_config
|
| 506 |
+
)
|
| 507 |
+
else:
|
| 508 |
+
print("Load pretrained LLM from scratch!!!")
|
| 509 |
+
# # Create the config object as needed
|
| 510 |
+
# config = MiniCPMOConfig(
|
| 511 |
+
# model_name_or_path=model_args.model_name_or_path,
|
| 512 |
+
# pretrained_llm_path=model_args.pretrained_llm_path,
|
| 513 |
+
# init_vision=training_args.init_vision,
|
| 514 |
+
# init_audio=training_args.init_speech,
|
| 515 |
+
# pretrained_encoder_path=model_args.audio_encoder_path,
|
| 516 |
+
# processor_path=model_args.tokenizer_path,
|
| 517 |
+
# **minipcmo_config
|
| 518 |
+
# )
|
| 519 |
+
|
| 520 |
+
# # Initialize model
|
| 521 |
+
# model = MiniCPMO(config)
|
| 522 |
+
|
| 523 |
+
model = MiniCPMO.from_pretrained(
|
| 524 |
+
model_args.model_name_or_path,
|
| 525 |
+
pretrained_llm_path=model_args.pretrained_llm_path,
|
| 526 |
+
init_vision=training_args.init_vision,
|
| 527 |
+
init_audio=training_args.init_speech,
|
| 528 |
+
init_tts=False,
|
| 529 |
+
pretrained_encoder_path=model_args.audio_encoder_path,
|
| 530 |
+
processor_path=model_args.tokenizer_path,
|
| 531 |
+
**minipcmo_config
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# tokenizer.audio_start_id = tokenizer.convert_tokens_to_ids("<|box_start|>")
|
| 535 |
+
# tokenizer.audio_end_id = tokenizer.convert_tokens_to_ids("<|box_end|>")
|
| 536 |
+
# tokenizer.audio_start = "<|box_start|>"
|
| 537 |
+
# tokenizer.audio_end = "<|box_end|>"
|
| 538 |
+
# tokenizer.im_start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>")
|
| 539 |
+
# tokenizer.im_end_id = tokenizer.convert_tokens_to_ids("<|vision_end|>")
|
| 540 |
+
# tokenizer.im_start = "<|vision_start|>"
|
| 541 |
+
# tokenizer.im_end = "<|vision_end|>"
|
| 542 |
+
# tokenizer.slice_start_id = tokenizer.convert_tokens_to_ids("<|quad_start|>")
|
| 543 |
+
# tokenizer.slice_end_id = tokenizer.convert_tokens_to_ids("<|quad_end|>")
|
| 544 |
+
# tokenizer.slice_start = "<|quad_start|>"
|
| 545 |
+
# tokenizer.slice_end = "<|quad_end|>"
|
| 546 |
+
# tokenizer.unk_token = "<unk>"
|
| 547 |
+
|
| 548 |
+
# print("Audio Start Token:", tokenizer.audio_start)
|
| 549 |
+
# print("Audio End Token:", tokenizer.audio_end)
|
| 550 |
+
# print(tokenizer.audio_start_id)
|
| 551 |
+
# print(tokenizer.audio_end_id)
|
| 552 |
+
# print("Start Token:", tokenizer.im_start)
|
| 553 |
+
# print("End Token:", tokenizer.im_end)
|
| 554 |
+
# print(tokenizer.im_start_id)
|
| 555 |
+
# print(tokenizer.im_end_id)
|
| 556 |
+
# print("Slice Start Token:", tokenizer.slice_start)
|
| 557 |
+
# print("Slice End Token:", tokenizer.slice_end)
|
| 558 |
+
# print(tokenizer.slice_start_id)
|
| 559 |
+
# print(tokenizer.slice_end_id)
|
| 560 |
+
|
| 561 |
+
model.config.chunk_input = training_args.chunk_input
|
| 562 |
+
# model.llm.resize_token_embeddings(len(tokenizer))
|
| 563 |
+
# model.resize_token_embeddings(len(tokenizer))
|
| 564 |
+
|
| 565 |
+
model.llm.config.use_cache = False
|
| 566 |
+
model.config.max_length = training_args.model_max_length
|
| 567 |
+
|
| 568 |
+
# if not training_args.tune_vision and training_args.init_vision:
|
| 569 |
+
# model.vpm.requires_grad_(False)
|
| 570 |
+
# if not training_args.tune_speech and training_args.init_speech:
|
| 571 |
+
# model.apm.requires_grad_(False)
|
| 572 |
+
# if not training_args.tune_llm:
|
| 573 |
+
# model.llm.requires_grad_(False)
|
| 574 |
+
model.requires_grad_(False)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
if training_args.tune_llm:
|
| 578 |
+
model.llm.requires_grad_(True)
|
| 579 |
+
print("Enabled training for LLM")
|
| 580 |
+
model.audio_projection_layer.requires_grad_(True)
|
| 581 |
+
print("Enabled training for audio_projection_layer")
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
if training_args.use_lora:
|
| 585 |
+
if training_args.use_lora and training_args.tune_llm:
|
| 586 |
+
raise ValueError("The model cannot simultaneously adjust LLM parameters and apply LoRA.")
|
| 587 |
+
|
| 588 |
+
rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
|
| 589 |
+
for name, param in model.llm.named_parameters():
|
| 590 |
+
param.requires_grad = False
|
| 591 |
+
modules_to_save = ['embed_tokens','resampler']
|
| 592 |
+
if training_args.tune_vision:
|
| 593 |
+
modules_to_save.append('vpm')
|
| 594 |
+
lora_config = LoraConfig(
|
| 595 |
+
r=lora_args.lora_r,
|
| 596 |
+
lora_alpha=lora_args.lora_alpha,
|
| 597 |
+
target_modules=lora_args.lora_target_modules,
|
| 598 |
+
lora_dropout=lora_args.lora_dropout,
|
| 599 |
+
bias=lora_args.lora_bias,
|
| 600 |
+
layers_to_transform=lora_args.lora_layers_to_transform,
|
| 601 |
+
modules_to_save=modules_to_save,
|
| 602 |
+
)
|
| 603 |
+
if not hasattr(model, 'get_input_embeddings'):
|
| 604 |
+
def get_input_embeddings(self):
|
| 605 |
+
return self.llm.get_input_embeddings()
|
| 606 |
+
model.get_input_embeddings = MethodType(get_input_embeddings, model)
|
| 607 |
+
if lora_args.q_lora:
|
| 608 |
+
model = prepare_model_for_kbit_training(
|
| 609 |
+
model, use_gradient_checkpointing=training_args.gradient_checkpointing
|
| 610 |
+
)
|
| 611 |
+
model = get_peft_model(model, lora_config)
|
| 612 |
+
if training_args.gradient_checkpointing:
|
| 613 |
+
model.enable_input_require_grads()
|
| 614 |
+
|
| 615 |
+
rank0_print(get_parameter_number(model))
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
print_trainable_parameters_by_module(model)
|
| 619 |
+
|
| 620 |
+
llm_type = training_args.llm_type
|
| 621 |
+
|
| 622 |
+
rank0_print(f'llm_type={llm_type}')
|
| 623 |
+
|
| 624 |
+
# Load data
|
| 625 |
+
if hasattr(model.config, "slice_config"):
|
| 626 |
+
model.config.slice_config.max_slice_nums = training_args.max_slice_nums
|
| 627 |
+
slice_config = model.config.slice_config.to_dict()
|
| 628 |
+
else:
|
| 629 |
+
model.config.max_slice_nums = training_args.max_slice_nums
|
| 630 |
+
slice_config = model.config.to_dict()
|
| 631 |
+
|
| 632 |
+
if hasattr(model.config, "batch_vision_input"):
|
| 633 |
+
batch_vision = model.config.batch_vision_input
|
| 634 |
+
else:
|
| 635 |
+
batch_vision = False
|
| 636 |
+
|
| 637 |
+
transform_func = build_transform()
|
| 638 |
+
|
| 639 |
+
if model_args.tokenizer_path is not None:
|
| 640 |
+
processor = AutoProcessor.from_pretrained(model_args.tokenizer_path, trust_remote_code=True)
|
| 641 |
+
else:
|
| 642 |
+
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
| 643 |
+
processor.tokenizer = tokenizer
|
| 644 |
+
|
| 645 |
+
raw_datasets = load_dataset(
|
| 646 |
+
"json",
|
| 647 |
+
data_files={
|
| 648 |
+
"train": data_args.data_path,
|
| 649 |
+
"validation": data_args.eval_data_path,
|
| 650 |
+
},
|
| 651 |
+
cache_dir=training_args.cache_dir,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
train_ds = raw_datasets["train"]
|
| 655 |
+
if data_args.max_train_samples is not None:
|
| 656 |
+
train_ds = train_ds.select(range(data_args.max_train_samples))
|
| 657 |
+
eval_ds = raw_datasets["validation"]
|
| 658 |
+
if data_args.max_eval_samples is not None:
|
| 659 |
+
eval_ds = eval_ds.select(range(data_args.max_eval_samples))
|
| 660 |
+
|
| 661 |
+
# data_module = make_supervised_data_module(
|
| 662 |
+
# tokenizer=tokenizer,
|
| 663 |
+
# processor=processor,
|
| 664 |
+
# data_args=data_args,
|
| 665 |
+
# transform=transform_func,
|
| 666 |
+
# data_collator=data_collator,
|
| 667 |
+
# slice_config=slice_config,
|
| 668 |
+
# llm_type=llm_type,
|
| 669 |
+
# patch_size=model.config.patch_size,
|
| 670 |
+
# query_nums=model.config.query_num,
|
| 671 |
+
# batch_vision=batch_vision,
|
| 672 |
+
# max_length=training_args.model_max_length,
|
| 673 |
+
# )
|
| 674 |
+
|
| 675 |
+
init_prompt = None
|
| 676 |
+
if not training_args.tune_llm and training_args.tune_speech: # asr finetuning
|
| 677 |
+
init_prompt = "Please transcribe this audio into text."
|
| 678 |
+
|
| 679 |
+
transform = None
|
| 680 |
+
if data_args.augment_prob != 0.0 and data_args.augment_path is not None:
|
| 681 |
+
with open(data_args.augment_path, "r") as f:
|
| 682 |
+
augment_path_list = f.read().splitlines()
|
| 683 |
+
transform = AddBackgroundNoise(
|
| 684 |
+
sounds_path=augment_path_list,
|
| 685 |
+
min_snr_db=5.0,
|
| 686 |
+
max_snr_db=30.0,
|
| 687 |
+
noise_transform=PolarityInversion(),
|
| 688 |
+
p=data_args.augment_prob
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
custom_collate_fn = partial(collate_fn, processor = processor, chunk_input=training_args.chunk_input, max_len=training_args.model_max_length, prompt=init_prompt, transform=transform)
|
| 692 |
+
|
| 693 |
+
training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
|
| 694 |
+
|
| 695 |
+
# print("Training Layers:")
|
| 696 |
+
# for name, param in model.named_parameters():
|
| 697 |
+
# if param.requires_grad:
|
| 698 |
+
# print(name, param.grad)
|
| 699 |
+
|
| 700 |
+
# trainer = CPMTrainer(
|
| 701 |
+
# model=model,
|
| 702 |
+
# tokenizer=tokenizer,
|
| 703 |
+
# args=training_args,
|
| 704 |
+
# **data_module,
|
| 705 |
+
# )
|
| 706 |
+
trainer = Trainer(
|
| 707 |
+
model=model,
|
| 708 |
+
tokenizer=tokenizer,
|
| 709 |
+
args=training_args,
|
| 710 |
+
train_dataset=train_ds,
|
| 711 |
+
eval_dataset=eval_ds,
|
| 712 |
+
data_collator=custom_collate_fn
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 716 |
+
trainer.train(resume_from_checkpoint=True)
|
| 717 |
+
else:
|
| 718 |
+
trainer.train()
|
| 719 |
+
|
| 720 |
+
trainer.save_state()
|
| 721 |
+
|
| 722 |
+
safe_save_model_for_hf_trainer(
|
| 723 |
+
trainer=trainer,
|
| 724 |
+
output_dir=training_args.output_dir,
|
| 725 |
+
bias=lora_args.lora_bias)
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
if __name__ == "__main__":
|
| 729 |
+
train()
|
omni_speech/train/train_multiturn.py
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
|
| 2 |
+
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
| 3 |
+
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import copy
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import pathlib
|
| 23 |
+
from typing import Dict, Optional, Sequence, List
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
import transformers
|
| 28 |
+
import tokenizers
|
| 29 |
+
|
| 30 |
+
from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
|
| 31 |
+
from torch.utils.data import Dataset
|
| 32 |
+
from omni_speech.train.omni_trainer import OmniTrainer
|
| 33 |
+
from audiomentations import AddBackgroundNoise, PolarityInversion
|
| 34 |
+
|
| 35 |
+
from omni_speech import conversation as conversation_lib
|
| 36 |
+
from omni_speech.model import *
|
| 37 |
+
from omni_speech.utils import *
|
| 38 |
+
from omni_speech.datasets.preprocess import *
|
| 39 |
+
import whisper
|
| 40 |
+
import time
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class ModelArguments:
|
| 44 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
| 45 |
+
version: Optional[str] = field(default="llama_3")
|
| 46 |
+
freeze_backbone: bool = field(default=False)
|
| 47 |
+
tune_speech_projector: bool = field(default=False)
|
| 48 |
+
tune_speech_encoder: bool = field(default=False)
|
| 49 |
+
tune_speech_generator_only: bool = field(default=False)
|
| 50 |
+
speech_encoder_type: Optional[str] = field(default=None)
|
| 51 |
+
speech_encoder: Optional[str] = field(default=None)
|
| 52 |
+
pretrain_speech_projector: Optional[str] = field(default=None)
|
| 53 |
+
speech_projector_type: Optional[str] = field(default='linear')
|
| 54 |
+
speech_generator_type: Optional[str] = field(default='ctc')
|
| 55 |
+
# ctc_decoder_config: str = "(2,4096,32,11008)" # num layers, hidden sizes, attn heads, ff dimensions of LLaMA
|
| 56 |
+
ctc_decoder_config: str = "(2,4096,32,22016)"
|
| 57 |
+
ctc_upsample_factor: int = 25
|
| 58 |
+
ctc_loss_weight: float = 1.0
|
| 59 |
+
unit_vocab_size: int = 1000
|
| 60 |
+
speech_encoder_ds_rate: int = 5
|
| 61 |
+
speech_encoder_hidden_size: int = 1280
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
|
| 65 |
+
class DataArguments:
|
| 66 |
+
data_path: str = field(default=None,
|
| 67 |
+
metadata={"help": "Path to the training data."})
|
| 68 |
+
dev_path: str = field(default=None,
|
| 69 |
+
metadata={"help": "Path to the dev data."})
|
| 70 |
+
is_multimodal: bool = False
|
| 71 |
+
input_type: str = field(default="mel")
|
| 72 |
+
speech_normalize: bool = False
|
| 73 |
+
mel_size: int = 128
|
| 74 |
+
has_tgt_units: bool = False
|
| 75 |
+
augment_prob: float = field(
|
| 76 |
+
default=0.0,
|
| 77 |
+
metadata={"help": "The probability of applying augmentation transform."}
|
| 78 |
+
)
|
| 79 |
+
augment_path: str = field(default=None,
|
| 80 |
+
metadata={"help": "Path to the augment data."})
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
|
| 84 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 85 |
+
cache_dir: Optional[str] = field(default=None)
|
| 86 |
+
optim: str = field(default="adamw_torch")
|
| 87 |
+
freeze_speech_projector: bool = field(default=False)
|
| 88 |
+
model_max_length: int = field(
|
| 89 |
+
default=512,
|
| 90 |
+
metadata={
|
| 91 |
+
"help":
|
| 92 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 93 |
+
},
|
| 94 |
+
)
|
| 95 |
+
double_quant: bool = field(
|
| 96 |
+
default=True,
|
| 97 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
| 98 |
+
)
|
| 99 |
+
quant_type: str = field(
|
| 100 |
+
default="nf4",
|
| 101 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
| 102 |
+
)
|
| 103 |
+
bits: int = field(
|
| 104 |
+
default=16,
|
| 105 |
+
metadata={"help": "How many bits to use."}
|
| 106 |
+
)
|
| 107 |
+
lora_enable: bool = False
|
| 108 |
+
lora_r: int = 64
|
| 109 |
+
lora_alpha: int = 16
|
| 110 |
+
lora_dropout: float = 0.05
|
| 111 |
+
lora_weight_path: str = ""
|
| 112 |
+
lora_bias: str = "none"
|
| 113 |
+
speech_projector_lr: Optional[float] = None
|
| 114 |
+
group_by_modality_length: bool = field(default=False)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class LazySupervisedDataset(Dataset):
|
| 118 |
+
"""Dataset for supervised fine-tuning."""
|
| 119 |
+
|
| 120 |
+
def __init__(self, data_path: str,
|
| 121 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 122 |
+
data_args: DataArguments):
|
| 123 |
+
super(LazySupervisedDataset, self).__init__()
|
| 124 |
+
list_data_dict = json.load(open(data_path, "r"))
|
| 125 |
+
|
| 126 |
+
self.tokenizer = tokenizer
|
| 127 |
+
self.list_data_dict = list_data_dict
|
| 128 |
+
self.data_args = data_args
|
| 129 |
+
if self.data_args.augment_prob != 0.0:
|
| 130 |
+
with open(self.data_args.augment_path, "r") as f:
|
| 131 |
+
augment_path_list = f.read().splitlines()
|
| 132 |
+
self.transform = AddBackgroundNoise(
|
| 133 |
+
sounds_path=augment_path_list,
|
| 134 |
+
min_snr_db=5.0,
|
| 135 |
+
max_snr_db=30.0,
|
| 136 |
+
noise_transform=PolarityInversion(),
|
| 137 |
+
p=self.data_args.augment_prob
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def __len__(self):
|
| 141 |
+
return len(self.list_data_dict)
|
| 142 |
+
|
| 143 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
| 144 |
+
# TODO: define number of retries somewhere else
|
| 145 |
+
num_base_retries = 3
|
| 146 |
+
num_final_retries = 300
|
| 147 |
+
|
| 148 |
+
# try the current sample first
|
| 149 |
+
for attempt_idx in range(num_base_retries):
|
| 150 |
+
try:
|
| 151 |
+
sample = self._get_item(i)
|
| 152 |
+
return sample
|
| 153 |
+
except Exception as e:
|
| 154 |
+
# sleep 1s in case it is a cloud disk issue
|
| 155 |
+
print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
|
| 156 |
+
time.sleep(1)
|
| 157 |
+
|
| 158 |
+
# try other samples, in case it is file corruption issue
|
| 159 |
+
for attempt_idx in range(num_base_retries):
|
| 160 |
+
try:
|
| 161 |
+
next_index = min(i + 1, len(self.list_data_dict) - 1)
|
| 162 |
+
# sample_idx = random.choice(range(len(self)))
|
| 163 |
+
sample = self._get_item(next_index)
|
| 164 |
+
return sample
|
| 165 |
+
except Exception as e:
|
| 166 |
+
# no need to sleep
|
| 167 |
+
print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)
|
| 168 |
+
pass
|
| 169 |
+
|
| 170 |
+
try:
|
| 171 |
+
sample = self._get_item(i)
|
| 172 |
+
return sample
|
| 173 |
+
except Exception as e:
|
| 174 |
+
raise e
|
| 175 |
+
|
| 176 |
+
def process_speech(self, speech_file):
|
| 177 |
+
speech = whisper.load_audio(speech_file)
|
| 178 |
+
if self.data_args.augment_prob != 0.0:
|
| 179 |
+
speech = self.transform(speech, sample_rate=16000)
|
| 180 |
+
if self.data_args.input_type == "raw":
|
| 181 |
+
speech = torch.from_numpy(speech)
|
| 182 |
+
if self.model_config.data_args.speech_normalize:
|
| 183 |
+
speech = torch.nn.functional.layer_norm(speech, speech.shape)
|
| 184 |
+
elif self.data_args.input_type == "mel":
|
| 185 |
+
speech = whisper.pad_or_trim(speech)
|
| 186 |
+
speech = whisper.log_mel_spectrogram(speech, n_mels=self.data_args.mel_size).permute(1, 0)
|
| 187 |
+
speech_lengths = torch.LongTensor([speech.shape[0]])
|
| 188 |
+
return speech, speech_lengths
|
| 189 |
+
|
| 190 |
+
def _get_item(self, i) -> Dict[str, torch.Tensor]:
|
| 191 |
+
sources = self.list_data_dict[i]
|
| 192 |
+
if isinstance(i, int):
|
| 193 |
+
sources = [sources]
|
| 194 |
+
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
|
| 195 |
+
for item in sources:
|
| 196 |
+
if 'tools' in item:
|
| 197 |
+
tools_dict = {
|
| 198 |
+
"from": "tools",
|
| 199 |
+
"value": item["tools"]
|
| 200 |
+
}
|
| 201 |
+
item["conversations"].insert(0,tools_dict)
|
| 202 |
+
|
| 203 |
+
if self.data_args.has_tgt_units:
|
| 204 |
+
# pad_list = [0]
|
| 205 |
+
# tgt_units = [e["tgt_units"] if "tgt_units" in e else pad_list for e in sources]
|
| 206 |
+
tgt_units = [e["tgt_units"] for e in sources]
|
| 207 |
+
tgt_units = torch.tensor(tgt_units, dtype=torch.long)
|
| 208 |
+
else:
|
| 209 |
+
tgt_units = None
|
| 210 |
+
|
| 211 |
+
if 'speech' in sources[0]:
|
| 212 |
+
import numpy as np
|
| 213 |
+
speech_file = self.list_data_dict[i]['speech']
|
| 214 |
+
if type(speech_file) is list:
|
| 215 |
+
speech = [self.process_speech(f) for f in speech_file]
|
| 216 |
+
else:
|
| 217 |
+
speech = [self.process_speech(speech_file)]
|
| 218 |
+
|
| 219 |
+
sources = preprocess_multimodal(
|
| 220 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
| 221 |
+
self.data_args)
|
| 222 |
+
else:
|
| 223 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
| 224 |
+
data_dict = preprocess(
|
| 225 |
+
sources,
|
| 226 |
+
self.tokenizer,
|
| 227 |
+
has_speech=('speech' in self.list_data_dict[i]))
|
| 228 |
+
if isinstance(i, int):
|
| 229 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
| 230 |
+
labels=data_dict["labels"][0])
|
| 231 |
+
|
| 232 |
+
# speech exist in the data
|
| 233 |
+
if 'speech' in self.list_data_dict[i]:
|
| 234 |
+
data_dict['speech'] = speech
|
| 235 |
+
|
| 236 |
+
if tgt_units is not None:
|
| 237 |
+
data_dict['tgt_units'] = tgt_units[0]
|
| 238 |
+
|
| 239 |
+
data_dict["id"] = self.list_data_dict[i].get("id", i)
|
| 240 |
+
|
| 241 |
+
return data_dict
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@dataclass
|
| 245 |
+
class DataCollatorForSupervisedDataset(object):
|
| 246 |
+
"""Collate examples for supervised fine-tuning."""
|
| 247 |
+
|
| 248 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 249 |
+
|
| 250 |
+
def pad_sequence(self, input_ids, batch_first, padding_value):
|
| 251 |
+
if self.tokenizer.padding_side == "left":
|
| 252 |
+
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
|
| 253 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
|
| 254 |
+
if self.tokenizer.padding_side == "left":
|
| 255 |
+
input_ids = torch.flip(input_ids, [1])
|
| 256 |
+
return input_ids
|
| 257 |
+
|
| 258 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 259 |
+
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
| 260 |
+
# input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "id"))
|
| 261 |
+
input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids]
|
| 262 |
+
labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels]
|
| 263 |
+
if self.tokenizer.pad_token_id is None:
|
| 264 |
+
# self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model.
|
| 265 |
+
self.tokenizer.pad_token_id = 0 # This gets the best result. Don't know why.
|
| 266 |
+
input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
| 267 |
+
labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
| 268 |
+
batch = dict(input_ids=input_ids, labels=labels.long() if labels.dtype == torch.int32 else labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id))
|
| 269 |
+
# batch = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ids=ids)
|
| 270 |
+
|
| 271 |
+
if 'speech' in instances[0]:
|
| 272 |
+
speechs = [instance['speech'] for instance in instances]
|
| 273 |
+
|
| 274 |
+
speech = [sp[0] for sp_list in speechs for sp in sp_list]
|
| 275 |
+
speech_lengths = [sp[1] for sp_list in speechs for sp in sp_list]
|
| 276 |
+
|
| 277 |
+
batch["speech"] = speech
|
| 278 |
+
# print(len(speech)) # sum(len(audio) for audio in each batch)
|
| 279 |
+
# print(speech[0].shape) # seq_len, dim
|
| 280 |
+
batch['speech_lengths'] = speech_lengths
|
| 281 |
+
# print(speech_lengths[0].shape) # seq_len
|
| 282 |
+
|
| 283 |
+
if 'tgt_units' in instances[0]:
|
| 284 |
+
tgt_units = [instance['tgt_units'] for instance in instances]
|
| 285 |
+
tgt_units = self.pad_sequence(tgt_units, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
| 286 |
+
batch['tgt_units'] = tgt_units
|
| 287 |
+
# print(batch['tgt_units'])
|
| 288 |
+
# print("---------------")
|
| 289 |
+
# print(batch['input_ids'])
|
| 290 |
+
|
| 291 |
+
return batch
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
|
| 295 |
+
data_args) -> Dict:
|
| 296 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 297 |
+
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
| 298 |
+
data_path=data_args.data_path,
|
| 299 |
+
data_args=data_args)
|
| 300 |
+
if data_args.dev_path is not None:
|
| 301 |
+
dev_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
| 302 |
+
data_path=data_args.dev_path,
|
| 303 |
+
data_args=data_args)
|
| 304 |
+
else:
|
| 305 |
+
dev_dataset = None
|
| 306 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
| 307 |
+
return dict(train_dataset=train_dataset,
|
| 308 |
+
eval_dataset=dev_dataset,
|
| 309 |
+
data_collator=data_collator)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def train(attn_implementation="flash_attention_2"):
|
| 313 |
+
|
| 314 |
+
parser = transformers.HfArgumentParser(
|
| 315 |
+
(ModelArguments, DataArguments, TrainingArguments))
|
| 316 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 317 |
+
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 318 |
+
|
| 319 |
+
bnb_model_from_pretrained_args = {}
|
| 320 |
+
if training_args.bits in [4, 8]:
|
| 321 |
+
from transformers import BitsAndBytesConfig
|
| 322 |
+
bnb_model_from_pretrained_args.update(dict(
|
| 323 |
+
device_map={"": training_args.device},
|
| 324 |
+
load_in_4bit=training_args.bits == 4,
|
| 325 |
+
load_in_8bit=training_args.bits == 8,
|
| 326 |
+
quantization_config=BitsAndBytesConfig(
|
| 327 |
+
load_in_4bit=training_args.bits == 4,
|
| 328 |
+
load_in_8bit=training_args.bits == 8,
|
| 329 |
+
llm_int8_skip_modules=["speech_projector"],
|
| 330 |
+
llm_int8_threshold=6.0,
|
| 331 |
+
llm_int8_has_fp16_weight=False,
|
| 332 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
| 333 |
+
bnb_4bit_use_double_quant=training_args.double_quant,
|
| 334 |
+
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
|
| 335 |
+
)
|
| 336 |
+
))
|
| 337 |
+
|
| 338 |
+
if data_args.has_tgt_units:
|
| 339 |
+
if model_args.version == "llama_3":
|
| 340 |
+
model = OmniSpeech2SLlamaForCausalLM.from_pretrained(
|
| 341 |
+
model_args.model_name_or_path,
|
| 342 |
+
cache_dir=training_args.cache_dir,
|
| 343 |
+
attn_implementation=attn_implementation,
|
| 344 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 345 |
+
**bnb_model_from_pretrained_args
|
| 346 |
+
)
|
| 347 |
+
elif model_args.version == "qwen":
|
| 348 |
+
model = OmniSpeech2SQwen2ForCausalLM.from_pretrained(
|
| 349 |
+
model_args.model_name_or_path,
|
| 350 |
+
cache_dir=training_args.cache_dir,
|
| 351 |
+
attn_implementation=attn_implementation,
|
| 352 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 353 |
+
**bnb_model_from_pretrained_args
|
| 354 |
+
)
|
| 355 |
+
else:
|
| 356 |
+
raise ValueError("--currently only support llama or qwen model!")
|
| 357 |
+
else:
|
| 358 |
+
if model_args.version == "llama_3":
|
| 359 |
+
model = OmniSpeechLlamaForCausalLM.from_pretrained(
|
| 360 |
+
model_args.model_name_or_path,
|
| 361 |
+
cache_dir=training_args.cache_dir,
|
| 362 |
+
attn_implementation=attn_implementation,
|
| 363 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 364 |
+
**bnb_model_from_pretrained_args
|
| 365 |
+
)
|
| 366 |
+
elif model_args.version == "qwen":
|
| 367 |
+
model = OmniSpeechQwen2ForCausalLM.from_pretrained(
|
| 368 |
+
model_args.model_name_or_path,
|
| 369 |
+
cache_dir=training_args.cache_dir,
|
| 370 |
+
attn_implementation=attn_implementation,
|
| 371 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
| 372 |
+
**bnb_model_from_pretrained_args
|
| 373 |
+
)
|
| 374 |
+
else:
|
| 375 |
+
raise ValueError("--currently only support llama or qwen model!")
|
| 376 |
+
model.config.use_cache = False
|
| 377 |
+
|
| 378 |
+
if model_args.freeze_backbone:
|
| 379 |
+
model.model.requires_grad_(False)
|
| 380 |
+
|
| 381 |
+
if training_args.bits in [4, 8]:
|
| 382 |
+
from peft import prepare_model_for_kbit_training
|
| 383 |
+
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 384 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
|
| 385 |
+
|
| 386 |
+
if training_args.gradient_checkpointing:
|
| 387 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 388 |
+
model.enable_input_require_grads()
|
| 389 |
+
else:
|
| 390 |
+
def make_inputs_require_grad(module, input, output):
|
| 391 |
+
output.requires_grad_(True)
|
| 392 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 393 |
+
|
| 394 |
+
if training_args.lora_enable:
|
| 395 |
+
from peft import LoraConfig, get_peft_model
|
| 396 |
+
lora_config = LoraConfig(
|
| 397 |
+
r=training_args.lora_r,
|
| 398 |
+
lora_alpha=training_args.lora_alpha,
|
| 399 |
+
target_modules=find_all_linear_names(model),
|
| 400 |
+
lora_dropout=training_args.lora_dropout,
|
| 401 |
+
bias=training_args.lora_bias,
|
| 402 |
+
task_type="CAUSAL_LM",
|
| 403 |
+
)
|
| 404 |
+
if training_args.bits == 16:
|
| 405 |
+
if training_args.bf16:
|
| 406 |
+
model.to(torch.bfloat16)
|
| 407 |
+
if training_args.fp16:
|
| 408 |
+
model.to(torch.float16)
|
| 409 |
+
model = get_peft_model(model, lora_config)
|
| 410 |
+
|
| 411 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 412 |
+
model_args.model_name_or_path,
|
| 413 |
+
cache_dir=training_args.cache_dir,
|
| 414 |
+
model_max_length=training_args.model_max_length,
|
| 415 |
+
padding_side="right",
|
| 416 |
+
use_fast=False,
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 420 |
+
model.config.max_length = training_args.model_max_length
|
| 421 |
+
|
| 422 |
+
if model_args.version in conversation_lib.conv_templates:
|
| 423 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
|
| 424 |
+
else:
|
| 425 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["llama_3"]
|
| 426 |
+
|
| 427 |
+
if model_args.speech_encoder is not None:
|
| 428 |
+
model.get_model().initialize_speech_modules(
|
| 429 |
+
model_args=model_args,
|
| 430 |
+
fsdp=training_args.fsdp
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
data_args.is_multimodal = True
|
| 434 |
+
|
| 435 |
+
model.config.tokenizer_padding_side = tokenizer.padding_side
|
| 436 |
+
model.config.tokenizer_model_max_length = tokenizer.model_max_length
|
| 437 |
+
|
| 438 |
+
model.config.tune_speech_projector = training_args.tune_speech_projector = model_args.tune_speech_projector
|
| 439 |
+
|
| 440 |
+
model.config.speech_normalize = data_args.speech_normalize
|
| 441 |
+
|
| 442 |
+
for p in model.get_speech_encoder().parameters():
|
| 443 |
+
p.requires_grad = False
|
| 444 |
+
|
| 445 |
+
if model_args.tune_speech_projector:
|
| 446 |
+
model.requires_grad_(False)
|
| 447 |
+
for p in model.get_speech_projector().parameters():
|
| 448 |
+
p.requires_grad = True
|
| 449 |
+
|
| 450 |
+
model.config.freeze_speech_projector = training_args.freeze_speech_projector
|
| 451 |
+
if training_args.freeze_speech_projector:
|
| 452 |
+
for p in model.get_speech_projector().parameters():
|
| 453 |
+
p.requires_grad = False
|
| 454 |
+
|
| 455 |
+
if training_args.bits in [4, 8]:
|
| 456 |
+
model.get_model().speech_projector.to(dtype=compute_dtype, device=training_args.device)
|
| 457 |
+
|
| 458 |
+
model.config.speech_projector_lr = training_args.speech_projector_lr
|
| 459 |
+
|
| 460 |
+
if data_args.has_tgt_units:
|
| 461 |
+
model.initialize_speech_generator(model_args=model_args)
|
| 462 |
+
|
| 463 |
+
if training_args.bits in [4, 8]:
|
| 464 |
+
from peft.tuners.lora import LoraLayer
|
| 465 |
+
for name, module in model.named_modules():
|
| 466 |
+
if isinstance(module, LoraLayer):
|
| 467 |
+
if training_args.bf16:
|
| 468 |
+
module = module.to(torch.bfloat16)
|
| 469 |
+
if 'norm' in name:
|
| 470 |
+
module = module.to(torch.float32)
|
| 471 |
+
if 'lm_head' in name or 'embed_tokens' in name:
|
| 472 |
+
if hasattr(module, 'weight'):
|
| 473 |
+
if training_args.bf16 and module.weight.dtype == torch.float32:
|
| 474 |
+
module = module.to(torch.bfloat16)
|
| 475 |
+
|
| 476 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer,
|
| 477 |
+
data_args=data_args)
|
| 478 |
+
|
| 479 |
+
print("Training Layers:")
|
| 480 |
+
for name, param in model.named_parameters():
|
| 481 |
+
if param.requires_grad:
|
| 482 |
+
print(name, param.grad)
|
| 483 |
+
|
| 484 |
+
trainer = OmniTrainer(model=model,
|
| 485 |
+
tokenizer=tokenizer,
|
| 486 |
+
args=training_args,
|
| 487 |
+
**data_module)
|
| 488 |
+
|
| 489 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 490 |
+
trainer.train(resume_from_checkpoint=True)
|
| 491 |
+
else:
|
| 492 |
+
trainer.train()
|
| 493 |
+
trainer.save_state()
|
| 494 |
+
|
| 495 |
+
model.config.use_cache = True
|
| 496 |
+
|
| 497 |
+
if training_args.lora_enable:
|
| 498 |
+
state_dict = get_peft_state_maybe_zero_3(
|
| 499 |
+
model.named_parameters(), training_args.lora_bias
|
| 500 |
+
)
|
| 501 |
+
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
|
| 502 |
+
model.named_parameters()
|
| 503 |
+
)
|
| 504 |
+
if training_args.local_rank == 0 or training_args.local_rank == -1:
|
| 505 |
+
model.config.save_pretrained(training_args.output_dir)
|
| 506 |
+
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
|
| 507 |
+
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
|
| 508 |
+
else:
|
| 509 |
+
safe_save_model_for_hf_trainer(trainer=trainer,
|
| 510 |
+
output_dir=training_args.output_dir)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
if __name__ == "__main__":
|
| 514 |
+
train()
|
| 515 |
+
|
omni_speech/train/trainer.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import deepspeed
|
| 4 |
+
from transformers import Trainer
|
| 5 |
+
from transformers.trainer_pt_utils import nested_detach
|
| 6 |
+
from transformers.utils import is_sagemaker_mp_enabled
|
| 7 |
+
from transformers.trainer import *
|
| 8 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CPMTrainer(Trainer):
|
| 12 |
+
def compute_loss(self, model, inputs, return_outputs=False):
|
| 13 |
+
if "labels" in inputs:
|
| 14 |
+
labels = inputs.pop("labels")
|
| 15 |
+
else:
|
| 16 |
+
labels = None
|
| 17 |
+
|
| 18 |
+
if not self.args.use_lora:
|
| 19 |
+
outputs = self.model(data = inputs, use_cache=False)
|
| 20 |
+
else:
|
| 21 |
+
with self.model._enable_peft_forward_hooks(**inputs):
|
| 22 |
+
outputs = self.model.base_model(data = inputs, use_cache=False)
|
| 23 |
+
|
| 24 |
+
if labels is not None:
|
| 25 |
+
# Flatten the tokens
|
| 26 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 27 |
+
logits = outputs.logits.view(-1,
|
| 28 |
+
self.model.config.vocab_size).contiguous()
|
| 29 |
+
labels = labels.view(-1).long().contiguous()
|
| 30 |
+
# Enable model parallelism
|
| 31 |
+
labels = labels.to(logits.device)
|
| 32 |
+
loss = loss_fct(logits, labels)
|
| 33 |
+
else:
|
| 34 |
+
if isinstance(outputs, dict) and "loss" not in outputs:
|
| 35 |
+
raise ValueError(
|
| 36 |
+
"The model did not return a loss from the inputs, only the following keys: "
|
| 37 |
+
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
|
| 38 |
+
)
|
| 39 |
+
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
| 40 |
+
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
| 41 |
+
|
| 42 |
+
return (loss, outputs) if return_outputs else loss
|
| 43 |
+
|
| 44 |
+
def prediction_step(
|
| 45 |
+
self,
|
| 46 |
+
model: nn.Module,
|
| 47 |
+
inputs: Dict[str, Union[torch.Tensor, Any]],
|
| 48 |
+
prediction_loss_only: bool,
|
| 49 |
+
ignore_keys: Optional[List[str]] = None,
|
| 50 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 51 |
+
"""
|
| 52 |
+
Perform an evaluation step on `model` using `inputs`.
|
| 53 |
+
|
| 54 |
+
Subclass and override to inject custom behavior.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
model (`nn.Module`):
|
| 58 |
+
The model to evaluate.
|
| 59 |
+
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
|
| 60 |
+
The inputs and targets of the model.
|
| 61 |
+
|
| 62 |
+
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
| 63 |
+
argument `labels`. Check your model's documentation for all accepted arguments.
|
| 64 |
+
prediction_loss_only (`bool`):
|
| 65 |
+
Whether or not to return the loss only.
|
| 66 |
+
ignore_keys (`List[str]`, *optional*):
|
| 67 |
+
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
| 68 |
+
gathering predictions.
|
| 69 |
+
|
| 70 |
+
Return:
|
| 71 |
+
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
|
| 72 |
+
logits and labels (each being optional).
|
| 73 |
+
"""
|
| 74 |
+
has_labels = (
|
| 75 |
+
False
|
| 76 |
+
if len(self.label_names) == 0
|
| 77 |
+
else all(inputs.get(k) is not None for k in self.label_names)
|
| 78 |
+
)
|
| 79 |
+
# For CLIP-like models capable of returning loss values.
|
| 80 |
+
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
|
| 81 |
+
# is `True` in `model.forward`.
|
| 82 |
+
return_loss = inputs.get("return_loss", None)
|
| 83 |
+
if return_loss is None:
|
| 84 |
+
return_loss = self.can_return_loss
|
| 85 |
+
loss_without_labels = (
|
| 86 |
+
True if len(self.label_names) == 0 and return_loss else False
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
inputs = self._prepare_inputs(inputs)
|
| 90 |
+
if ignore_keys is None:
|
| 91 |
+
if hasattr(self.model, "config"):
|
| 92 |
+
ignore_keys = getattr(
|
| 93 |
+
self.model.config, "keys_to_ignore_at_inference", []
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
ignore_keys = []
|
| 97 |
+
|
| 98 |
+
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
|
| 99 |
+
if has_labels or loss_without_labels:
|
| 100 |
+
labels = nested_detach(tuple(inputs.get(name)
|
| 101 |
+
for name in self.label_names))
|
| 102 |
+
if len(labels) == 1:
|
| 103 |
+
labels = labels[0]
|
| 104 |
+
else:
|
| 105 |
+
labels = None
|
| 106 |
+
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
if is_sagemaker_mp_enabled():
|
| 109 |
+
raw_outputs = smp_forward_only(model, inputs)
|
| 110 |
+
if has_labels or loss_without_labels:
|
| 111 |
+
if isinstance(raw_outputs, dict):
|
| 112 |
+
loss_mb = raw_outputs["loss"]
|
| 113 |
+
logits_mb = tuple(
|
| 114 |
+
v
|
| 115 |
+
for k, v in raw_outputs.items()
|
| 116 |
+
if k not in ignore_keys + ["loss"]
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
loss_mb = raw_outputs[0]
|
| 120 |
+
logits_mb = raw_outputs[1:]
|
| 121 |
+
|
| 122 |
+
loss = loss_mb.reduce_mean().detach().cpu()
|
| 123 |
+
logits = smp_nested_concat(logits_mb)
|
| 124 |
+
else:
|
| 125 |
+
loss = None
|
| 126 |
+
if isinstance(raw_outputs, dict):
|
| 127 |
+
logits_mb = tuple(
|
| 128 |
+
v for k, v in raw_outputs.items() if k not in ignore_keys
|
| 129 |
+
)
|
| 130 |
+
else:
|
| 131 |
+
logits_mb = raw_outputs
|
| 132 |
+
logits = smp_nested_concat(logits_mb)
|
| 133 |
+
else:
|
| 134 |
+
if has_labels or loss_without_labels:
|
| 135 |
+
with self.compute_loss_context_manager():
|
| 136 |
+
loss, outputs = self.compute_loss(
|
| 137 |
+
model, inputs, return_outputs=True
|
| 138 |
+
)
|
| 139 |
+
loss = loss.mean().detach()
|
| 140 |
+
|
| 141 |
+
if isinstance(outputs, dict):
|
| 142 |
+
logits = tuple(
|
| 143 |
+
v
|
| 144 |
+
for k, v in outputs.items()
|
| 145 |
+
if k not in ignore_keys + ["loss"]
|
| 146 |
+
)
|
| 147 |
+
else:
|
| 148 |
+
logits = outputs[1:]
|
| 149 |
+
else:
|
| 150 |
+
loss = None
|
| 151 |
+
with self.compute_loss_context_manager():
|
| 152 |
+
outputs = model(**inputs)
|
| 153 |
+
if isinstance(outputs, dict):
|
| 154 |
+
logits = tuple(
|
| 155 |
+
v for k, v in outputs.items() if k not in ignore_keys
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
logits = outputs
|
| 159 |
+
# TODO: this needs to be fixed and made cleaner later.
|
| 160 |
+
if self.args.past_index >= 0:
|
| 161 |
+
self._past = outputs[self.args.past_index - 1]
|
| 162 |
+
|
| 163 |
+
if prediction_loss_only:
|
| 164 |
+
return (loss, None, None)
|
| 165 |
+
|
| 166 |
+
logits = nested_detach(logits)
|
| 167 |
+
if len(logits) == 1:
|
| 168 |
+
logits = logits[0]
|
| 169 |
+
|
| 170 |
+
return (loss, logits, labels)
|
| 171 |
+
|
| 172 |
+
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
|
| 173 |
+
"""
|
| 174 |
+
Perform a training step on a batch of inputs.
|
| 175 |
+
|
| 176 |
+
Subclass and override to inject custom behavior.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
model (`nn.Module`):
|
| 180 |
+
The model to train.
|
| 181 |
+
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
|
| 182 |
+
The inputs and targets of the model.
|
| 183 |
+
|
| 184 |
+
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
| 185 |
+
argument `labels`. Check your model's documentation for all accepted arguments.
|
| 186 |
+
|
| 187 |
+
Return:
|
| 188 |
+
`torch.Tensor`: The tensor with training loss on this batch.
|
| 189 |
+
"""
|
| 190 |
+
model.train()
|
| 191 |
+
inputs = self._prepare_inputs(inputs)
|
| 192 |
+
|
| 193 |
+
if is_sagemaker_mp_enabled():
|
| 194 |
+
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
|
| 195 |
+
return loss_mb.reduce_mean().detach().to(self.args.device)
|
| 196 |
+
|
| 197 |
+
with self.compute_loss_context_manager():
|
| 198 |
+
loss = self.compute_loss(model, inputs)
|
| 199 |
+
|
| 200 |
+
del inputs
|
| 201 |
+
torch.cuda.empty_cache()
|
| 202 |
+
|
| 203 |
+
if self.args.n_gpu > 1:
|
| 204 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
| 205 |
+
|
| 206 |
+
if self.use_apex:
|
| 207 |
+
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
| 208 |
+
scaled_loss.backward()
|
| 209 |
+
else:
|
| 210 |
+
self.accelerator.backward(loss)
|
| 211 |
+
|
| 212 |
+
return loss.detach() / self.args.gradient_accumulation_steps
|
| 213 |
+
|
| 214 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 215 |
+
# If we are executing this function, we are the process zero, so we don't check for that.
|
| 216 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 217 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 218 |
+
logger.info(f"Saving model checkpoint to {output_dir}")
|
| 219 |
+
|
| 220 |
+
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
|
| 221 |
+
# Save a trained model and configuration using `save_pretrained()`.
|
| 222 |
+
# They can then be reloaded using `from_pretrained()`
|
| 223 |
+
if not isinstance(self.model, supported_classes):
|
| 224 |
+
if state_dict is None:
|
| 225 |
+
state_dict = self.model.state_dict()
|
| 226 |
+
|
| 227 |
+
if isinstance(unwrap_model(self.model), supported_classes):
|
| 228 |
+
unwrap_model(self.model).save_pretrained(
|
| 229 |
+
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
| 230 |
+
)
|
| 231 |
+
else:
|
| 232 |
+
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
| 233 |
+
if self.args.save_safetensors:
|
| 234 |
+
safetensors.torch.save_file(
|
| 235 |
+
state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
| 239 |
+
else:
|
| 240 |
+
|
| 241 |
+
self.model.save_pretrained(
|
| 242 |
+
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
if self.tokenizer is not None:
|
| 246 |
+
self.tokenizer.save_pretrained(output_dir)
|
| 247 |
+
|
| 248 |
+
# Good practice: save your training arguments together with the trained model
|
| 249 |
+
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
scripts/continue.sh
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# GPUS_PER_NODE=8
|
| 4 |
+
# NNODES=1
|
| 5 |
+
# NODE_RANK=0
|
| 6 |
+
# MASTER_ADDR=localhost
|
| 7 |
+
# MASTER_PORT=6001
|
| 8 |
+
|
| 9 |
+
MODEL="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/minicpmo_sft_asr"
|
| 10 |
+
TOKENIZER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
|
| 11 |
+
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
|
| 12 |
+
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
| 13 |
+
# See the section for finetuning in README for more information.
|
| 14 |
+
DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/train_asr_mixed_500k.jsonl"
|
| 15 |
+
EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/dev_asr_mixed.jsonl"
|
| 16 |
+
|
| 17 |
+
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
|
| 18 |
+
# if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
|
| 19 |
+
LLM_TYPE="qwen"
|
| 20 |
+
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# DISTRIBUTED_ARGS="
|
| 24 |
+
# --nproc_per_node $GPUS_PER_NODE \
|
| 25 |
+
# --nnodes $NNODES \
|
| 26 |
+
# --node_rank $NODE_RANK \
|
| 27 |
+
# --master_addr $MASTER_ADDR \
|
| 28 |
+
# --master_port $MASTER_PORT
|
| 29 |
+
# "
|
| 30 |
+
|
| 31 |
+
deepspeed ../omni_speech/train/train_minicpmo.py \
|
| 32 |
+
--deepspeed zero2.json \
|
| 33 |
+
--model_name_or_path $MODEL \
|
| 34 |
+
--tokenizer_path $TOKENIZER_PATH \
|
| 35 |
+
--llm_type $LLM_TYPE \
|
| 36 |
+
--data_path $DATA \
|
| 37 |
+
--eval_data_path $EVAL_DATA \
|
| 38 |
+
--remove_unused_columns false \
|
| 39 |
+
--label_names "labels" \
|
| 40 |
+
--prediction_loss_only false \
|
| 41 |
+
--bf16 true \
|
| 42 |
+
--do_train \
|
| 43 |
+
--do_eval \
|
| 44 |
+
--tune_speech true \
|
| 45 |
+
--tune_llm false \
|
| 46 |
+
--model_max_length $MODEL_MAX_Length \
|
| 47 |
+
--eval_steps 2000 \
|
| 48 |
+
--output_dir ../checkpoints/minicpmo_sft_asr \
|
| 49 |
+
--num_train_epochs 2 \
|
| 50 |
+
--logging_strategy "steps" \
|
| 51 |
+
--per_device_train_batch_size 1 \
|
| 52 |
+
--per_device_eval_batch_size 1 \
|
| 53 |
+
--gradient_accumulation_steps 4 \
|
| 54 |
+
--evaluation_strategy "steps" \
|
| 55 |
+
--save_strategy "steps" \
|
| 56 |
+
--save_steps 5000 \
|
| 57 |
+
--save_total_limit 1 \
|
| 58 |
+
--learning_rate 1e-5 \
|
| 59 |
+
--max_grad_norm 20. \
|
| 60 |
+
--weight_decay 0. \
|
| 61 |
+
--warmup_ratio 0.03 \
|
| 62 |
+
--lr_scheduler_type "cosine" \
|
| 63 |
+
--logging_steps 1 \
|
| 64 |
+
--tf32 True \
|
| 65 |
+
--gradient_checkpointing true
|
scripts/ds_config_zero2.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
|
| 11 |
+
"bf16": {
|
| 12 |
+
"enabled": "auto"
|
| 13 |
+
},
|
| 14 |
+
|
| 15 |
+
"optimizer": {
|
| 16 |
+
"type": "AdamW",
|
| 17 |
+
"params": {
|
| 18 |
+
"lr": "auto",
|
| 19 |
+
"betas": "auto",
|
| 20 |
+
"eps": "auto",
|
| 21 |
+
"weight_decay": "auto"
|
| 22 |
+
}
|
| 23 |
+
},
|
| 24 |
+
|
| 25 |
+
"scheduler": {
|
| 26 |
+
"type": "WarmupLR",
|
| 27 |
+
"params": {
|
| 28 |
+
"warmup_min_lr": "auto",
|
| 29 |
+
"warmup_max_lr": "auto",
|
| 30 |
+
"warmup_num_steps": "auto"
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
|
| 34 |
+
"zero_optimization": {
|
| 35 |
+
"stage": 2,
|
| 36 |
+
"offload_optimizer": {
|
| 37 |
+
"device": "none",
|
| 38 |
+
"pin_memory": true
|
| 39 |
+
},
|
| 40 |
+
"allgather_partitions": true,
|
| 41 |
+
"allgather_bucket_size": 2e8,
|
| 42 |
+
"overlap_comm": true,
|
| 43 |
+
"reduce_scatter": true,
|
| 44 |
+
"reduce_bucket_size": 2e8,
|
| 45 |
+
"contiguous_gradients": true
|
| 46 |
+
},
|
| 47 |
+
|
| 48 |
+
"gradient_accumulation_steps": "auto",
|
| 49 |
+
"gradient_clipping": "auto",
|
| 50 |
+
"steps_per_print": 100,
|
| 51 |
+
"train_batch_size": "auto",
|
| 52 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 53 |
+
"wall_clock_breakdown": false
|
| 54 |
+
}
|
scripts/ds_config_zero3.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
|
| 23 |
+
"scheduler": {
|
| 24 |
+
"type": "WarmupLR",
|
| 25 |
+
"params": {
|
| 26 |
+
"warmup_min_lr": "auto",
|
| 27 |
+
"warmup_max_lr": "auto",
|
| 28 |
+
"warmup_num_steps": "auto"
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
|
| 32 |
+
"zero_optimization": {
|
| 33 |
+
"stage": 3,
|
| 34 |
+
"offload_optimizer": {
|
| 35 |
+
"device": "none",
|
| 36 |
+
"pin_memory": true
|
| 37 |
+
},
|
| 38 |
+
"offload_param": {
|
| 39 |
+
"device": "none",
|
| 40 |
+
"pin_memory": true
|
| 41 |
+
},
|
| 42 |
+
"overlap_comm": true,
|
| 43 |
+
"contiguous_gradients": true,
|
| 44 |
+
"sub_group_size": 1e9,
|
| 45 |
+
"reduce_bucket_size": "auto",
|
| 46 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 47 |
+
"stage3_param_persistence_threshold": "auto",
|
| 48 |
+
"stage3_max_live_parameters": 1e9,
|
| 49 |
+
"stage3_max_reuse_distance": 1e9,
|
| 50 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 51 |
+
},
|
| 52 |
+
|
| 53 |
+
"gradient_accumulation_steps": "auto",
|
| 54 |
+
"gradient_clipping": "auto",
|
| 55 |
+
"steps_per_print": 100,
|
| 56 |
+
"train_batch_size": "auto",
|
| 57 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 58 |
+
"wall_clock_breakdown": false
|
| 59 |
+
}
|
scripts/export.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc_speech_decoder_fixed_all/checkpoint-4000
|
| 4 |
+
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 5 |
+
PROMPT_VERSION=qwen
|
| 6 |
+
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/moss/moss_100K_phase3_tgt_units_processed.jsonl
|
| 7 |
+
# DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250103.jsonl
|
| 8 |
+
CACHE_DIR="../output/cached_sft_speech_decoder_20250103"
|
| 9 |
+
|
| 10 |
+
deepspeed --master_port 29501 ../omni_speech/train/export.py \
|
| 11 |
+
--deepspeed zero2.json \
|
| 12 |
+
--model_name_or_path $MODEL_PATH \
|
| 13 |
+
--version $PROMPT_VERSION \
|
| 14 |
+
--data_path $DATA_PATH \
|
| 15 |
+
--cache_dir $CACHE_DIR \
|
| 16 |
+
--speech_encoder $SPEECH_ENCODER \
|
| 17 |
+
--mel_size 80 \
|
| 18 |
+
--speech_encoder_hidden_size 1024 \
|
| 19 |
+
--speech_encoder_type whisper \
|
| 20 |
+
--tune_speech_generator_only True \
|
| 21 |
+
--bf16 True \
|
| 22 |
+
--output_dir ../checkpoints/tmp \
|
| 23 |
+
--num_train_epochs 8 \
|
| 24 |
+
--per_device_train_batch_size 1 \
|
| 25 |
+
--per_device_eval_batch_size 1 \
|
| 26 |
+
--gradient_accumulation_steps 2 \
|
| 27 |
+
--evaluation_strategy "no" \
|
| 28 |
+
--save_strategy "steps" \
|
| 29 |
+
--save_steps 2000 \
|
| 30 |
+
--save_total_limit 1 \
|
| 31 |
+
--learning_rate 1e-4 \
|
| 32 |
+
--weight_decay 0. \
|
| 33 |
+
--warmup_ratio 0.03 \
|
| 34 |
+
--logging_steps 10 \
|
| 35 |
+
--tf32 True \
|
| 36 |
+
--model_max_length 2048 \
|
| 37 |
+
--gradient_checkpointing True \
|
| 38 |
+
--dataloader_num_workers 8 \
|
| 39 |
+
--has_tgt_units True
|
scripts/finetune.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct
|
| 4 |
+
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 5 |
+
SPEECH_ADAPTER=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr/speech_projector.bin
|
| 6 |
+
PROMPT_VERSION=qwen
|
| 7 |
+
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/train_20250112_fc_mixed_vfva_text_fake_audios.jsonl
|
| 8 |
+
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250112_fc_mixed_vfva_text_fake_audios.jsonl
|
| 9 |
+
CACHE_DIR="../output/cached_sft_20250112"
|
| 10 |
+
|
| 11 |
+
deepspeed ../omni_speech/train/train_mem.py \
|
| 12 |
+
--deepspeed zero2.json \
|
| 13 |
+
--model_name_or_path $MODEL_PATH \
|
| 14 |
+
--version $PROMPT_VERSION \
|
| 15 |
+
--data_path $DATA_PATH \
|
| 16 |
+
--dev_path $DEV_PATH \
|
| 17 |
+
--cache_dir $CACHE_DIR \
|
| 18 |
+
--speech_encoder $SPEECH_ENCODER \
|
| 19 |
+
--mel_size 80 \
|
| 20 |
+
--speech_encoder_hidden_size 1024 \
|
| 21 |
+
--speech_encoder_type whisper \
|
| 22 |
+
--pretrain_speech_projector $SPEECH_ADAPTER \
|
| 23 |
+
--bf16 True \
|
| 24 |
+
--output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc-mixed-vfva-text \
|
| 25 |
+
--num_train_epochs 2 \
|
| 26 |
+
--per_device_train_batch_size 1 \
|
| 27 |
+
--per_device_eval_batch_size 1 \
|
| 28 |
+
--gradient_accumulation_steps 4 \
|
| 29 |
+
--evaluation_strategy "steps" \
|
| 30 |
+
--save_strategy "steps" \
|
| 31 |
+
--eval_steps 2000 \
|
| 32 |
+
--save_steps 6000 \
|
| 33 |
+
--save_total_limit 1 \
|
| 34 |
+
--learning_rate 2e-5 \
|
| 35 |
+
--weight_decay 0. \
|
| 36 |
+
--warmup_ratio 0.03 \
|
| 37 |
+
--lr_scheduler_type "cosine" \
|
| 38 |
+
--logging_steps 1 \
|
| 39 |
+
--tf32 True \
|
| 40 |
+
--model_max_length 8192 \
|
| 41 |
+
--gradient_checkpointing True \
|
| 42 |
+
--dataloader_num_workers 8
|
scripts/finetune_llm_speech_decoder.sh
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# it currently supports for batch = 1 only.
|
| 4 |
+
|
| 5 |
+
MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct
|
| 6 |
+
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 7 |
+
SPEECH_ADAPTER=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr/speech_projector.bin
|
| 8 |
+
PROMPT_VERSION=qwen
|
| 9 |
+
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/train_20250106_fc_mixed_tgt_units.jsonl
|
| 10 |
+
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250106_fc_mixed_tgt_units.jsonl
|
| 11 |
+
CACHE_DIR="../output/cached_sft_speech_decoder_all_20250103"
|
| 12 |
+
|
| 13 |
+
deepspeed ../omni_speech/train/train_mem.py \
|
| 14 |
+
--deepspeed zero2.json \
|
| 15 |
+
--model_name_or_path $MODEL_PATH \
|
| 16 |
+
--version $PROMPT_VERSION \
|
| 17 |
+
--data_path $DATA_PATH \
|
| 18 |
+
--dev_path $DEV_PATH \
|
| 19 |
+
--cache_dir $CACHE_DIR \
|
| 20 |
+
--speech_encoder $SPEECH_ENCODER \
|
| 21 |
+
--mel_size 80 \
|
| 22 |
+
--speech_encoder_hidden_size 1024 \
|
| 23 |
+
--speech_encoder_type whisper \
|
| 24 |
+
--pretrain_speech_projector $SPEECH_ADAPTER \
|
| 25 |
+
--bf16 True \
|
| 26 |
+
--output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc_speech_decoder_fixed_all \
|
| 27 |
+
--num_train_epochs 3 \
|
| 28 |
+
--per_device_train_batch_size 1 \
|
| 29 |
+
--per_device_eval_batch_size 1 \
|
| 30 |
+
--gradient_accumulation_steps 4 \
|
| 31 |
+
--evaluation_strategy "steps" \
|
| 32 |
+
--save_strategy "steps" \
|
| 33 |
+
--eval_steps 2000 \
|
| 34 |
+
--save_steps 2000 \
|
| 35 |
+
--save_total_limit 1 \
|
| 36 |
+
--learning_rate 1e-5 \
|
| 37 |
+
--weight_decay 0. \
|
| 38 |
+
--warmup_ratio 0.03 \
|
| 39 |
+
--lr_scheduler_type "cosine" \
|
| 40 |
+
--logging_steps 1 \
|
| 41 |
+
--tf32 True \
|
| 42 |
+
--model_max_length 1024 \
|
| 43 |
+
--gradient_checkpointing True \
|
| 44 |
+
--dataloader_num_workers 8 \
|
| 45 |
+
--has_tgt_units True \
|
| 46 |
+
--ctc_loss_weight 2.0
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc
|
| 50 |
+
# SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 51 |
+
# PROMPT_VERSION=qwen
|
| 52 |
+
# DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/moss/moss_100K_phase3_tgt_units_processed.jsonl
|
| 53 |
+
# # DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250106_fc_mixed_tgt_units.jsonl
|
| 54 |
+
# CACHE_DIR="../output/cached_sft_speech_decoder_all_20250103"
|
| 55 |
+
|
| 56 |
+
# deepspeed ../omni_speech/train/train_mem.py \
|
| 57 |
+
# --deepspeed zero2.json \
|
| 58 |
+
# --model_name_or_path $MODEL_PATH \
|
| 59 |
+
# --version $PROMPT_VERSION \
|
| 60 |
+
# --data_path $DATA_PATH \
|
| 61 |
+
# --cache_dir $CACHE_DIR \
|
| 62 |
+
# --speech_encoder $SPEECH_ENCODER \
|
| 63 |
+
# --mel_size 80 \
|
| 64 |
+
# --speech_encoder_hidden_size 1024 \
|
| 65 |
+
# --speech_encoder_type whisper \
|
| 66 |
+
# --bf16 True \
|
| 67 |
+
# --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc_speech_decoder_all \
|
| 68 |
+
# --num_train_epochs 5 \
|
| 69 |
+
# --per_device_train_batch_size 1 \
|
| 70 |
+
# --per_device_eval_batch_size 1 \
|
| 71 |
+
# --gradient_accumulation_steps 4 \
|
| 72 |
+
# --evaluation_strategy "no" \
|
| 73 |
+
# --save_strategy "steps" \
|
| 74 |
+
# --save_steps 10000 \
|
| 75 |
+
# --save_total_limit 1 \
|
| 76 |
+
# --learning_rate 1e-4 \
|
| 77 |
+
# --weight_decay 0. \
|
| 78 |
+
# --warmup_ratio 0.03 \
|
| 79 |
+
# --logging_steps 1 \
|
| 80 |
+
# --tf32 True \
|
| 81 |
+
# --model_max_length 2048 \
|
| 82 |
+
# --gradient_checkpointing True \
|
| 83 |
+
# --dataloader_num_workers 8 \
|
| 84 |
+
# --has_tgt_units True \
|
| 85 |
+
# --ctc_loss_weight 10.0
|
scripts/finetune_lora.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct
|
| 4 |
+
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 5 |
+
SPEECH_ADAPTER=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr/speech_projector.bin
|
| 6 |
+
PROMPT_VERSION=qwen
|
| 7 |
+
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/train_tmp.jsonl
|
| 8 |
+
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_tmp.jsonl
|
| 9 |
+
CACHE_DIR="../output/cached_sft"
|
| 10 |
+
|
| 11 |
+
deepspeed ../omni_speech/train/train_mem.py \
|
| 12 |
+
--deepspeed zero2.json \
|
| 13 |
+
--lora_enable True \
|
| 14 |
+
--model_name_or_path $MODEL_PATH \
|
| 15 |
+
--version $PROMPT_VERSION \
|
| 16 |
+
--data_path $DATA_PATH \
|
| 17 |
+
--dev_path $DEV_PATH \
|
| 18 |
+
--cache_dir $CACHE_DIR \
|
| 19 |
+
--speech_encoder $SPEECH_ENCODER \
|
| 20 |
+
--mel_size 80 \
|
| 21 |
+
--speech_encoder_hidden_size 1024 \
|
| 22 |
+
--speech_encoder_type whisper \
|
| 23 |
+
--pretrain_speech_projector $SPEECH_ADAPTER \
|
| 24 |
+
--bf16 True \
|
| 25 |
+
--output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-lora \
|
| 26 |
+
--num_train_epochs 18 \
|
| 27 |
+
--per_device_train_batch_size 2 \
|
| 28 |
+
--per_device_eval_batch_size 1 \
|
| 29 |
+
--gradient_accumulation_steps 4 \
|
| 30 |
+
--evaluation_strategy "steps" \
|
| 31 |
+
--save_strategy "steps" \
|
| 32 |
+
--eval_steps 1000 \
|
| 33 |
+
--save_steps 1000 \
|
| 34 |
+
--save_total_limit 1 \
|
| 35 |
+
--learning_rate 2e-5 \
|
| 36 |
+
--optim adamw_torch \
|
| 37 |
+
--weight_decay 0. \
|
| 38 |
+
--warmup_ratio 0.03 \
|
| 39 |
+
--logging_steps 1 \
|
| 40 |
+
--tf32 True \
|
| 41 |
+
--model_max_length 2048 \
|
| 42 |
+
--gradient_checkpointing True \
|
| 43 |
+
--dataloader_num_workers 8
|
scripts/finetune_minicpmo.sh
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# GPUS_PER_NODE=8
|
| 4 |
+
# NNODES=1
|
| 5 |
+
# NODE_RANK=0
|
| 6 |
+
# MASTER_ADDR=localhost
|
| 7 |
+
# MASTER_PORT=6001
|
| 8 |
+
|
| 9 |
+
MODEL="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/minicpmo_sft_asr"
|
| 10 |
+
TOKENIZER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
|
| 11 |
+
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
|
| 12 |
+
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
| 13 |
+
# See the section for finetuning in README for more information.
|
| 14 |
+
DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/sft/train_20250219_fc_mixed_text_filter_a_um.jsonl"
|
| 15 |
+
EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/sft/dev_20250219_fc_mixed_text_filter_a_um.jsonl"
|
| 16 |
+
|
| 17 |
+
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
|
| 18 |
+
# if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
|
| 19 |
+
LLM_TYPE="qwen"
|
| 20 |
+
MODEL_MAX_Length=8192 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# DISTRIBUTED_ARGS="
|
| 24 |
+
# --nproc_per_node $GPUS_PER_NODE \
|
| 25 |
+
# --nnodes $NNODES \
|
| 26 |
+
# --node_rank $NODE_RANK \
|
| 27 |
+
# --master_addr $MASTER_ADDR \
|
| 28 |
+
# --master_port $MASTER_PORT
|
| 29 |
+
# "
|
| 30 |
+
|
| 31 |
+
deepspeed ../omni_speech/train/train_minicpmo.py \
|
| 32 |
+
--deepspeed zero2.json \
|
| 33 |
+
--model_name_or_path $MODEL \
|
| 34 |
+
--tokenizer_path $TOKENIZER_PATH \
|
| 35 |
+
--llm_type $LLM_TYPE \
|
| 36 |
+
--data_path $DATA \
|
| 37 |
+
--eval_data_path $EVAL_DATA \
|
| 38 |
+
--remove_unused_columns false \
|
| 39 |
+
--label_names "labels" \
|
| 40 |
+
--prediction_loss_only false \
|
| 41 |
+
--bf16 true \
|
| 42 |
+
--do_train \
|
| 43 |
+
--do_eval \
|
| 44 |
+
--tune_speech true \
|
| 45 |
+
--tune_llm true \
|
| 46 |
+
--model_max_length $MODEL_MAX_Length \
|
| 47 |
+
--eval_steps 1000 \
|
| 48 |
+
--output_dir ../checkpoints/minicpmo_sft_vi_fc_fixed \
|
| 49 |
+
--num_train_epochs 1 \
|
| 50 |
+
--logging_strategy "steps" \
|
| 51 |
+
--per_device_train_batch_size 1 \
|
| 52 |
+
--per_device_eval_batch_size 1 \
|
| 53 |
+
--gradient_accumulation_steps 4 \
|
| 54 |
+
--evaluation_strategy "steps" \
|
| 55 |
+
--save_strategy "no" \
|
| 56 |
+
--save_steps 4000 \
|
| 57 |
+
--save_total_limit 1 \
|
| 58 |
+
--learning_rate 1e-5 \
|
| 59 |
+
--max_grad_norm 20. \
|
| 60 |
+
--weight_decay 0. \
|
| 61 |
+
--warmup_ratio 0.03 \
|
| 62 |
+
--lr_scheduler_type "cosine" \
|
| 63 |
+
--logging_steps 1 \
|
| 64 |
+
--tf32 True \
|
| 65 |
+
--gradient_checkpointing true
|
scripts/finetune_minicpmo_asr.sh
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# GPUS_PER_NODE=8
|
| 4 |
+
# NNODES=1
|
| 5 |
+
# NODE_RANK=0
|
| 6 |
+
# MASTER_ADDR=localhost
|
| 7 |
+
# MASTER_PORT=6001
|
| 8 |
+
|
| 9 |
+
MODEL="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
|
| 10 |
+
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
|
| 11 |
+
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
| 12 |
+
# See the section for finetuning in README for more information.
|
| 13 |
+
DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/train_asr_mixed_500k.jsonl"
|
| 14 |
+
EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/dev_asr_mixed.jsonl"
|
| 15 |
+
|
| 16 |
+
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
|
| 17 |
+
# if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
|
| 18 |
+
LLM_TYPE="qwen"
|
| 19 |
+
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# DISTRIBUTED_ARGS="
|
| 23 |
+
# --nproc_per_node $GPUS_PER_NODE \
|
| 24 |
+
# --nnodes $NNODES \
|
| 25 |
+
# --node_rank $NODE_RANK \
|
| 26 |
+
# --master_addr $MASTER_ADDR \
|
| 27 |
+
# --master_port $MASTER_PORT
|
| 28 |
+
# "
|
| 29 |
+
|
| 30 |
+
deepspeed ../omni_speech/train/train_minicpmo.py \
|
| 31 |
+
--deepspeed zero2.json \
|
| 32 |
+
--model_name_or_path $MODEL \
|
| 33 |
+
--llm_type $LLM_TYPE \
|
| 34 |
+
--data_path $DATA \
|
| 35 |
+
--eval_data_path $EVAL_DATA \
|
| 36 |
+
--remove_unused_columns false \
|
| 37 |
+
--label_names "labels" \
|
| 38 |
+
--prediction_loss_only false \
|
| 39 |
+
--bf16 true \
|
| 40 |
+
--do_train \
|
| 41 |
+
--do_eval \
|
| 42 |
+
--tune_speech true \
|
| 43 |
+
--tune_llm false \
|
| 44 |
+
--model_max_length $MODEL_MAX_Length \
|
| 45 |
+
--eval_steps 4000 \
|
| 46 |
+
--output_dir ../checkpoints/minicpmo_sft_asr_new \
|
| 47 |
+
--num_train_epochs 1 \
|
| 48 |
+
--logging_strategy "steps" \
|
| 49 |
+
--per_device_train_batch_size 1 \
|
| 50 |
+
--per_device_eval_batch_size 1 \
|
| 51 |
+
--gradient_accumulation_steps 4 \
|
| 52 |
+
--evaluation_strategy "steps" \
|
| 53 |
+
--save_strategy "steps" \
|
| 54 |
+
--save_steps 10000 \
|
| 55 |
+
--save_total_limit 1 \
|
| 56 |
+
--learning_rate 2e-4 \
|
| 57 |
+
--max_grad_norm 20. \
|
| 58 |
+
--weight_decay 0. \
|
| 59 |
+
--warmup_ratio 0.03 \
|
| 60 |
+
--lr_scheduler_type "cosine" \
|
| 61 |
+
--logging_steps 1 \
|
| 62 |
+
--tf32 True \
|
| 63 |
+
--gradient_checkpointing true
|
scripts/finetune_speech_decoder.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# it currently supports for batch = 1 only.
|
| 4 |
+
|
| 5 |
+
MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc-mixed-vfva-text
|
| 6 |
+
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 7 |
+
PROMPT_VERSION=qwen
|
| 8 |
+
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/20250114_tgt_unit_preprocessed_combined_mix_text_filtered.jsonl
|
| 9 |
+
# DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250103.jsonl
|
| 10 |
+
CACHE_DIR="../output/cached_sft_speech_decoder_20250114"
|
| 11 |
+
|
| 12 |
+
deepspeed ../omni_speech/train/train_mem.py \
|
| 13 |
+
--deepspeed zero2.json \
|
| 14 |
+
--model_name_or_path $MODEL_PATH \
|
| 15 |
+
--version $PROMPT_VERSION \
|
| 16 |
+
--data_path $DATA_PATH \
|
| 17 |
+
--cache_dir $CACHE_DIR \
|
| 18 |
+
--speech_encoder $SPEECH_ENCODER \
|
| 19 |
+
--mel_size 80 \
|
| 20 |
+
--speech_encoder_hidden_size 1024 \
|
| 21 |
+
--speech_encoder_type whisper \
|
| 22 |
+
--tune_speech_generator_only True \
|
| 23 |
+
--bf16 True \
|
| 24 |
+
--output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc-mixed-vfva-text_speech-decoder \
|
| 25 |
+
--num_train_epochs 16 \
|
| 26 |
+
--per_device_train_batch_size 1 \
|
| 27 |
+
--per_device_eval_batch_size 1 \
|
| 28 |
+
--gradient_accumulation_steps 4 \
|
| 29 |
+
--evaluation_strategy "no" \
|
| 30 |
+
--save_strategy "no" \
|
| 31 |
+
--save_steps 3000 \
|
| 32 |
+
--save_total_limit 1 \
|
| 33 |
+
--learning_rate 2e-4 \
|
| 34 |
+
--max_grad_norm 200. \
|
| 35 |
+
--weight_decay 0. \
|
| 36 |
+
--warmup_ratio 0.03 \
|
| 37 |
+
--logging_steps 1 \
|
| 38 |
+
--tf32 True \
|
| 39 |
+
--model_max_length 4096 \
|
| 40 |
+
--gradient_checkpointing True \
|
| 41 |
+
--dataloader_num_workers 8 \
|
| 42 |
+
--has_tgt_units True
|
scripts/minicpmp_config.json
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"batch_vision_input": true,
|
| 3 |
+
"drop_vision_last_layer": false,
|
| 4 |
+
"image_size": 448,
|
| 5 |
+
"audio_chunk_length": 1.0,
|
| 6 |
+
"audio_config": {
|
| 7 |
+
"_name_or_path": "openai/whisper-medium",
|
| 8 |
+
"architectures": [
|
| 9 |
+
"MiniCPMWhisperEncoder"
|
| 10 |
+
],
|
| 11 |
+
"begin_suppress_tokens": [
|
| 12 |
+
220,
|
| 13 |
+
50257
|
| 14 |
+
],
|
| 15 |
+
"bos_token_id": 50257,
|
| 16 |
+
"d_model": 1024,
|
| 17 |
+
"decoder_attention_heads": 16,
|
| 18 |
+
"decoder_ffn_dim": 4096,
|
| 19 |
+
"decoder_layers": 24,
|
| 20 |
+
"decoder_start_token_id": 50258,
|
| 21 |
+
"encoder_attention_heads": 16,
|
| 22 |
+
"encoder_ffn_dim": 4096,
|
| 23 |
+
"encoder_layers": 24,
|
| 24 |
+
"eos_token_id": 50257,
|
| 25 |
+
"forced_decoder_ids": [
|
| 26 |
+
[
|
| 27 |
+
1,
|
| 28 |
+
50259
|
| 29 |
+
],
|
| 30 |
+
[
|
| 31 |
+
2,
|
| 32 |
+
50359
|
| 33 |
+
],
|
| 34 |
+
[
|
| 35 |
+
3,
|
| 36 |
+
50363
|
| 37 |
+
]
|
| 38 |
+
],
|
| 39 |
+
"max_length": 448,
|
| 40 |
+
"model_type": "whisper",
|
| 41 |
+
"num_hidden_layers": 24,
|
| 42 |
+
"pad_token_id": 50257,
|
| 43 |
+
"suppress_tokens": [
|
| 44 |
+
1,
|
| 45 |
+
2,
|
| 46 |
+
7,
|
| 47 |
+
8,
|
| 48 |
+
9,
|
| 49 |
+
10,
|
| 50 |
+
14,
|
| 51 |
+
25,
|
| 52 |
+
26,
|
| 53 |
+
27,
|
| 54 |
+
28,
|
| 55 |
+
29,
|
| 56 |
+
31,
|
| 57 |
+
58,
|
| 58 |
+
59,
|
| 59 |
+
60,
|
| 60 |
+
61,
|
| 61 |
+
62,
|
| 62 |
+
63,
|
| 63 |
+
90,
|
| 64 |
+
91,
|
| 65 |
+
92,
|
| 66 |
+
93,
|
| 67 |
+
359,
|
| 68 |
+
503,
|
| 69 |
+
522,
|
| 70 |
+
542,
|
| 71 |
+
873,
|
| 72 |
+
893,
|
| 73 |
+
902,
|
| 74 |
+
918,
|
| 75 |
+
922,
|
| 76 |
+
931,
|
| 77 |
+
1350,
|
| 78 |
+
1853,
|
| 79 |
+
1982,
|
| 80 |
+
2460,
|
| 81 |
+
2627,
|
| 82 |
+
3246,
|
| 83 |
+
3253,
|
| 84 |
+
3268,
|
| 85 |
+
3536,
|
| 86 |
+
3846,
|
| 87 |
+
3961,
|
| 88 |
+
4183,
|
| 89 |
+
4667,
|
| 90 |
+
6585,
|
| 91 |
+
6647,
|
| 92 |
+
7273,
|
| 93 |
+
9061,
|
| 94 |
+
9383,
|
| 95 |
+
10428,
|
| 96 |
+
10929,
|
| 97 |
+
11938,
|
| 98 |
+
12033,
|
| 99 |
+
12331,
|
| 100 |
+
12562,
|
| 101 |
+
13793,
|
| 102 |
+
14157,
|
| 103 |
+
14635,
|
| 104 |
+
15265,
|
| 105 |
+
15618,
|
| 106 |
+
16553,
|
| 107 |
+
16604,
|
| 108 |
+
18362,
|
| 109 |
+
18956,
|
| 110 |
+
20075,
|
| 111 |
+
21675,
|
| 112 |
+
22520,
|
| 113 |
+
26130,
|
| 114 |
+
26161,
|
| 115 |
+
26435,
|
| 116 |
+
28279,
|
| 117 |
+
29464,
|
| 118 |
+
31650,
|
| 119 |
+
32302,
|
| 120 |
+
32470,
|
| 121 |
+
36865,
|
| 122 |
+
42863,
|
| 123 |
+
47425,
|
| 124 |
+
49870,
|
| 125 |
+
50254,
|
| 126 |
+
50258,
|
| 127 |
+
50358,
|
| 128 |
+
50359,
|
| 129 |
+
50360,
|
| 130 |
+
50361,
|
| 131 |
+
50362
|
| 132 |
+
],
|
| 133 |
+
"torch_dtype": "float32"
|
| 134 |
+
},
|
| 135 |
+
"audio_pool_step": 2,
|
| 136 |
+
"chunk_input": true,
|
| 137 |
+
"model_type": "minicpmo",
|
| 138 |
+
"patch_size": 14,
|
| 139 |
+
"query_num": 64,
|
| 140 |
+
"slice_config": {
|
| 141 |
+
"max_slice_nums": 9,
|
| 142 |
+
"model_type": "minicpmv"
|
| 143 |
+
},
|
| 144 |
+
"slice_mode": true,
|
| 145 |
+
"torch_dtype": "bfloat16",
|
| 146 |
+
"transformers_version": "4.44.2",
|
| 147 |
+
"tts_config": {
|
| 148 |
+
"model_type": "conditional_chattts",
|
| 149 |
+
"llm_dim": 3584
|
| 150 |
+
},
|
| 151 |
+
"use_cache": false,
|
| 152 |
+
"use_image_id": true,
|
| 153 |
+
"vision_batch_size": 16,
|
| 154 |
+
"vision_config": {
|
| 155 |
+
"hidden_size": 1152,
|
| 156 |
+
"image_size": 980,
|
| 157 |
+
"intermediate_size": 4304,
|
| 158 |
+
"model_type": "siglip_vision_model",
|
| 159 |
+
"num_attention_heads": 16,
|
| 160 |
+
"num_hidden_layers": 27,
|
| 161 |
+
"patch_size": 14
|
| 162 |
+
}
|
| 163 |
+
}
|
scripts/pretrain_minicpmo_test.sh
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# GPUS_PER_NODE=8
|
| 4 |
+
# NNODES=1
|
| 5 |
+
# NODE_RANK=0
|
| 6 |
+
# MASTER_ADDR=localhost
|
| 7 |
+
# MASTER_PORT=6001
|
| 8 |
+
|
| 9 |
+
# MODEL="/data1/speech/anhnmt2/cuongnm/EOT/Qwen2.5-0.5B-Instruct"
|
| 10 |
+
PRETRAINED_LLM="/data1/speech/anhnmt2/cuongnm/EOT/Qwen2.5-0.5B-Instruct"
|
| 11 |
+
MODEL="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct"
|
| 12 |
+
# PRETRAINED_LLM="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct"
|
| 13 |
+
TOKENIZER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
|
| 14 |
+
AUDIO_ENCODER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
|
| 15 |
+
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
|
| 16 |
+
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
| 17 |
+
# See the section for finetuning in README for more information.
|
| 18 |
+
# DATA="/data1/speech/anhnmt2/cuongnm/datasets/asr/train_asr_mixed_balanced_1M5_train.json "
|
| 19 |
+
# EVAL_DATA="/data1/speech/anhnmt2/cuongnm/datasets/asr/train_asr_mixed_balanced_1M5_dev.json "
|
| 20 |
+
# DATA="/data1/speech/anhnmt2/dataset/s2s/english/minicpmo/train_asr_eng_100000_new_dataloader.jsonl"
|
| 21 |
+
# EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/english/minicpmo/dev_asr_eng_1000_new_dataloader.jsonl"
|
| 22 |
+
DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/train_asr_mixed_500k.jsonl"
|
| 23 |
+
EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/dev_asr_mixed.jsonl"
|
| 24 |
+
CONFIG_PATH="minicpmp_config.json"
|
| 25 |
+
AUGMENT_PATH="/data1/speech/anhnmt2/dataset/s2s/augment/noise_list_non_speech.txt"
|
| 26 |
+
|
| 27 |
+
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
|
| 28 |
+
# if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
|
| 29 |
+
LLM_TYPE="qwen"
|
| 30 |
+
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
|
| 31 |
+
CACHE_DIR="../output/cached_sft_20252502"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# DISTRIBUTED_ARGS="
|
| 35 |
+
# --nproc_per_node $GPUS_PER_NODE \
|
| 36 |
+
# --nnodes $NNODES \
|
| 37 |
+
# --node_rank $NODE_RANK \
|
| 38 |
+
# --master_addr $MASTER_ADDR \
|
| 39 |
+
# --master_port $MASTER_PORT
|
| 40 |
+
# "
|
| 41 |
+
DEEPSPEED_CMD="/home/anhnmt2/.local/bin/deepspeed"
|
| 42 |
+
|
| 43 |
+
# Kiểm tra file thực thi DeepSpeed
|
| 44 |
+
if [ ! -x "$DEEPSPEED_CMD" ]; then
|
| 45 |
+
echo "Error: DeepSpeed executable not found at $DEEPSPEED_CMD."
|
| 46 |
+
echo "Try reinstalling with: pip install deepspeed"
|
| 47 |
+
exit 1
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
CUDA_LAUNCH_BLOCKING=1 "$DEEPSPEED_CMD" --master_port 29501 ../omni_speech/train/train_minicpmo_test.py \
|
| 52 |
+
--deepspeed zero2.json \
|
| 53 |
+
--model_name_or_path $MODEL \
|
| 54 |
+
--pretrained_llm_path $PRETRAINED_LLM \
|
| 55 |
+
--tokenizer_path $TOKENIZER_PATH \
|
| 56 |
+
--cache_dir $CACHE_DIR \
|
| 57 |
+
--audio_encoder_path $AUDIO_ENCODER_PATH \
|
| 58 |
+
--llm_type $LLM_TYPE \
|
| 59 |
+
--data_path $DATA \
|
| 60 |
+
--eval_data_path $EVAL_DATA \
|
| 61 |
+
--config_path $CONFIG_PATH \
|
| 62 |
+
--remove_unused_columns false \
|
| 63 |
+
--prediction_loss_only false \
|
| 64 |
+
--bf16 true \
|
| 65 |
+
--do_train \
|
| 66 |
+
--do_eval \
|
| 67 |
+
--tune_speech false \
|
| 68 |
+
--tune_llm false \
|
| 69 |
+
--model_max_length $MODEL_MAX_Length \
|
| 70 |
+
--eval_steps 3000 \
|
| 71 |
+
--output_dir ../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector \
|
| 72 |
+
--num_train_epochs 3 \
|
| 73 |
+
--logging_strategy "steps" \
|
| 74 |
+
--per_device_train_batch_size 8 \
|
| 75 |
+
--per_device_eval_batch_size 8 \
|
| 76 |
+
--gradient_accumulation_steps 4 \
|
| 77 |
+
--evaluation_strategy "steps" \
|
| 78 |
+
--save_strategy "steps" \
|
| 79 |
+
--save_steps 5000 \
|
| 80 |
+
--save_total_limit 1 \
|
| 81 |
+
--learning_rate 5e-5 \
|
| 82 |
+
--weight_decay 0. \
|
| 83 |
+
--warmup_ratio 0.03 \
|
| 84 |
+
--lr_scheduler_type "cosine" \
|
| 85 |
+
--logging_steps 1 \
|
| 86 |
+
--tf32 true \
|
| 87 |
+
--gradient_checkpointing true
|
| 88 |
+
# --augment_prob 0.2 \
|
| 89 |
+
# --augment_path $AUGMENT_PATH
|
scripts/pretrained.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct
|
| 4 |
+
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 5 |
+
PROMPT_VERSION=qwen
|
| 6 |
+
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/asr/dataset/train_asr_eng_5M.jsonl
|
| 7 |
+
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/asr/dataset/dev_asr_libri_spgi.jsonl
|
| 8 |
+
CACHE_DIR="../output/cached_asr_full"
|
| 9 |
+
AUGMENT_PATH="/data1/speech/anhnmt2/dataset/s2s/augment/noise_list_non_speech.txt"
|
| 10 |
+
|
| 11 |
+
deepspeed ../omni_speech/train/train_mem.py \
|
| 12 |
+
--deepspeed zero2.json \
|
| 13 |
+
--model_name_or_path $MODEL_PATH \
|
| 14 |
+
--version $PROMPT_VERSION \
|
| 15 |
+
--data_path $DATA_PATH \
|
| 16 |
+
--dev_path $DEV_PATH \
|
| 17 |
+
--cache_dir $CACHE_DIR \
|
| 18 |
+
--speech_encoder $SPEECH_ENCODER \
|
| 19 |
+
--mel_size 80 \
|
| 20 |
+
--speech_encoder_hidden_size 1024 \
|
| 21 |
+
--speech_encoder_type whisper \
|
| 22 |
+
--bf16 True \
|
| 23 |
+
--output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr-5M \
|
| 24 |
+
--num_train_epochs 4 \
|
| 25 |
+
--tune_speech_projector True \
|
| 26 |
+
--per_device_train_batch_size 16 \
|
| 27 |
+
--per_device_eval_batch_size 4 \
|
| 28 |
+
--gradient_accumulation_steps 2 \
|
| 29 |
+
--evaluation_strategy "steps" \
|
| 30 |
+
--save_strategy "steps" \
|
| 31 |
+
--eval_steps 2000 \
|
| 32 |
+
--save_steps 2000 \
|
| 33 |
+
--save_total_limit 1 \
|
| 34 |
+
--learning_rate 1e-3 \
|
| 35 |
+
--weight_decay 0. \
|
| 36 |
+
--warmup_ratio 0.03 \
|
| 37 |
+
--lr_scheduler_type "cosine" \
|
| 38 |
+
--logging_steps 1 \
|
| 39 |
+
--tf32 True \
|
| 40 |
+
--model_max_length 4096 \
|
| 41 |
+
--gradient_checkpointing True \
|
| 42 |
+
--dataloader_num_workers 8
|
| 43 |
+
# --augment_prob 0.2 \
|
| 44 |
+
# --augment_path $AUGMENT_PATH \
|
scripts/pretrained_minicpmo.sh
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# GPUS_PER_NODE=8
|
| 4 |
+
# NNODES=1
|
| 5 |
+
# NODE_RANK=0
|
| 6 |
+
# MASTER_ADDR=localhost
|
| 7 |
+
# MASTER_PORT=6001
|
| 8 |
+
|
| 9 |
+
MODEL="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct"
|
| 10 |
+
PRETRAINED_LLM="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct"
|
| 11 |
+
TOKENIZER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
|
| 12 |
+
AUDIO_ENCODER_PATH="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6"
|
| 13 |
+
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
|
| 14 |
+
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
| 15 |
+
# See the section for finetuning in README for more information.
|
| 16 |
+
DATA="/data1/speech/anhnmt2/dataset/s2s/english/minicpmo/train_asr_eng_100000_new_dataloader.jsonl"
|
| 17 |
+
EVAL_DATA="/data1/speech/anhnmt2/dataset/s2s/english/minicpmo/dev_asr_eng_1000_new_dataloader.jsonl"
|
| 18 |
+
CONFIG_PATH="minicpmp_config.json"
|
| 19 |
+
AUGMENT_PATH="/data1/speech/anhnmt2/dataset/s2s/augment/noise_list_non_speech.txt"
|
| 20 |
+
|
| 21 |
+
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3",
|
| 22 |
+
# if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
|
| 23 |
+
LLM_TYPE="qwen"
|
| 24 |
+
MODEL_MAX_Length=4096 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
|
| 25 |
+
CACHE_DIR="../output/cached_sft_20252502"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# DISTRIBUTED_ARGS="
|
| 29 |
+
# --nproc_per_node $GPUS_PER_NODE \
|
| 30 |
+
# --nnodes $NNODES \
|
| 31 |
+
# --node_rank $NODE_RANK \
|
| 32 |
+
# --master_addr $MASTER_ADDR \
|
| 33 |
+
# --master_port $MASTER_PORT
|
| 34 |
+
# "
|
| 35 |
+
|
| 36 |
+
deepspeed --master_port 29501 ../omni_speech/train/train_minicpmo.py \
|
| 37 |
+
--deepspeed zero2.json \
|
| 38 |
+
--model_name_or_path $MODEL \
|
| 39 |
+
--pretrained_llm_path $PRETRAINED_LLM \
|
| 40 |
+
--tokenizer_path $TOKENIZER_PATH \
|
| 41 |
+
--cache_dir $CACHE_DIR \
|
| 42 |
+
--audio_encoder_path $AUDIO_ENCODER_PATH \
|
| 43 |
+
--llm_type $LLM_TYPE \
|
| 44 |
+
--data_path $DATA \
|
| 45 |
+
--eval_data_path $EVAL_DATA \
|
| 46 |
+
--config_path $CONFIG_PATH \
|
| 47 |
+
--remove_unused_columns false \
|
| 48 |
+
--prediction_loss_only false \
|
| 49 |
+
--bf16 true \
|
| 50 |
+
--do_train \
|
| 51 |
+
--do_eval \
|
| 52 |
+
--tune_speech true \
|
| 53 |
+
--tune_llm false \
|
| 54 |
+
--model_max_length $MODEL_MAX_Length \
|
| 55 |
+
--eval_steps 1000 \
|
| 56 |
+
--output_dir ../checkpoints/minicpmo_whisper-medium_Qwen2.5-3B_pretrained-asr \
|
| 57 |
+
--num_train_epochs 1 \
|
| 58 |
+
--logging_strategy "steps" \
|
| 59 |
+
--per_device_train_batch_size 1 \
|
| 60 |
+
--per_device_eval_batch_size 1 \
|
| 61 |
+
--gradient_accumulation_steps 4 \
|
| 62 |
+
--evaluation_strategy "steps" \
|
| 63 |
+
--save_strategy "no" \
|
| 64 |
+
--save_steps 2000 \
|
| 65 |
+
--save_total_limit 1 \
|
| 66 |
+
--learning_rate 2e-4 \
|
| 67 |
+
--weight_decay 0. \
|
| 68 |
+
--warmup_ratio 0.03 \
|
| 69 |
+
--lr_scheduler_type "cosine" \
|
| 70 |
+
--logging_steps 1 \
|
| 71 |
+
--tf32 true \
|
| 72 |
+
--gradient_checkpointing true
|
| 73 |
+
# --augment_prob 0.2 \
|
| 74 |
+
# --augment_path $AUGMENT_PATH
|
scripts/test_llama.sh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Llama-3.1-8B-Instruct
|
| 4 |
+
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 5 |
+
PROMPT_VERSION=llama_3
|
| 6 |
+
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/new/train_asr_eng_50000.jsonl
|
| 7 |
+
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/new/dev_asr_eng_5000.jsonl
|
| 8 |
+
CACHE_DIR="../output/cached_asr"
|
| 9 |
+
|
| 10 |
+
deepspeed ../omni_speech/train/train.py \
|
| 11 |
+
--deepspeed zero2.json \
|
| 12 |
+
--model_name_or_path $MODEL_PATH \
|
| 13 |
+
--version $PROMPT_VERSION \
|
| 14 |
+
--data_path $DATA_PATH \
|
| 15 |
+
--dev_path $DEV_PATH \
|
| 16 |
+
--cache_dir $CACHE_DIR \
|
| 17 |
+
--speech_encoder $SPEECH_ENCODER \
|
| 18 |
+
--mel_size 80 \
|
| 19 |
+
--speech_encoder_hidden_size 1024 \
|
| 20 |
+
--speech_encoder_type whisper \
|
| 21 |
+
--bf16 True \
|
| 22 |
+
--output_dir ../checkpoints/llama-omni-pretrained-asr-test \
|
| 23 |
+
--num_train_epochs 10 \
|
| 24 |
+
--tune_speech_projector True \
|
| 25 |
+
--per_device_train_batch_size 4 \
|
| 26 |
+
--per_device_eval_batch_size 2 \
|
| 27 |
+
--gradient_accumulation_steps 4 \
|
| 28 |
+
--evaluation_strategy "steps" \
|
| 29 |
+
--save_strategy "steps" \
|
| 30 |
+
--eval_steps 2000 \
|
| 31 |
+
--save_steps 2000 \
|
| 32 |
+
--save_total_limit 1 \
|
| 33 |
+
--learning_rate 1e-3 \
|
| 34 |
+
--optim adamw_torch \
|
| 35 |
+
--weight_decay 0. \
|
| 36 |
+
--warmup_ratio 0.03 \
|
| 37 |
+
--logging_steps 1 \
|
| 38 |
+
--tf32 True \
|
| 39 |
+
--model_max_length 2048 \
|
| 40 |
+
--gradient_checkpointing True \
|
| 41 |
+
--dataloader_num_workers 8
|
scripts/test_qwen.sh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-1.5B-Instruct
|
| 4 |
+
SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
|
| 5 |
+
PROMPT_VERSION=qwen
|
| 6 |
+
DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/new/dev_asr_eng_5000_multiturn.jsonl
|
| 7 |
+
DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/new/dev_asr_eng_5000_multiturn.jsonl
|
| 8 |
+
CACHE_DIR="../output/cached_asr"
|
| 9 |
+
|
| 10 |
+
deepspeed ../omni_speech/train/train_multiturn.py \
|
| 11 |
+
--deepspeed zero2.json \
|
| 12 |
+
--model_name_or_path $MODEL_PATH \
|
| 13 |
+
--version $PROMPT_VERSION \
|
| 14 |
+
--data_path $DATA_PATH \
|
| 15 |
+
--dev_path $DEV_PATH \
|
| 16 |
+
--cache_dir $CACHE_DIR \
|
| 17 |
+
--speech_encoder $SPEECH_ENCODER \
|
| 18 |
+
--mel_size 80 \
|
| 19 |
+
--speech_encoder_hidden_size 1024 \
|
| 20 |
+
--speech_encoder_type whisper \
|
| 21 |
+
--bf16 True \
|
| 22 |
+
--output_dir ../checkpoints/llama-omni-pretrained-asr-qwen \
|
| 23 |
+
--num_train_epochs 10 \
|
| 24 |
+
--tune_speech_projector True \
|
| 25 |
+
--per_device_train_batch_size 4 \
|
| 26 |
+
--per_device_eval_batch_size 2 \
|
| 27 |
+
--gradient_accumulation_steps 4 \
|
| 28 |
+
--evaluation_strategy "steps" \
|
| 29 |
+
--save_strategy "steps" \
|
| 30 |
+
--eval_steps 2000 \
|
| 31 |
+
--save_steps 2000 \
|
| 32 |
+
--save_total_limit 1 \
|
| 33 |
+
--learning_rate 1e-3 \
|
| 34 |
+
--optim adamw_torch \
|
| 35 |
+
--weight_decay 0. \
|
| 36 |
+
--warmup_ratio 0.03 \
|
| 37 |
+
--logging_steps 1 \
|
| 38 |
+
--tf32 True \
|
| 39 |
+
--model_max_length 2048 \
|
| 40 |
+
--gradient_checkpointing True \
|
| 41 |
+
--dataloader_num_workers 8
|
scripts/wandb/debug-internal.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-04-10T17:19:28.842729448+07:00","level":"INFO","msg":"stream: starting","core version":"0.19.8","symlink path":"/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/run-20250410_171928-pfaibe0c/logs/debug-core.log"}
|
| 2 |
+
{"time":"2025-04-10T17:19:28.960322418+07:00","level":"INFO","msg":"created new stream","id":"pfaibe0c"}
|
| 3 |
+
{"time":"2025-04-10T17:19:28.960351593+07:00","level":"INFO","msg":"stream: started","id":"pfaibe0c"}
|
| 4 |
+
{"time":"2025-04-10T17:19:28.960375959+07:00","level":"INFO","msg":"writer: Do: started","stream_id":"pfaibe0c"}
|
| 5 |
+
{"time":"2025-04-10T17:19:28.960456552+07:00","level":"INFO","msg":"handler: started","stream_id":"pfaibe0c"}
|
| 6 |
+
{"time":"2025-04-10T17:19:28.961574927+07:00","level":"INFO","msg":"sender: started","stream_id":"pfaibe0c"}
|
| 7 |
+
{"time":"2025-04-10T17:19:29.497777718+07:00","level":"INFO","msg":"Starting system monitor"}
|
scripts/wandb/debug.log
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Current SDK version is 0.19.8
|
| 2 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Configure stats pid to 1734298
|
| 3 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Loading settings from /home/anhnmt2/.config/wandb/settings
|
| 4 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Loading settings from /data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/settings
|
| 5 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_setup.py:_flush():67] Loading settings from environment variables
|
| 6 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:setup_run_log_directory():647] Logging user logs to /data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/run-20250410_171928-pfaibe0c/logs/debug.log
|
| 7 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:setup_run_log_directory():648] Logging internal logs to /data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/run-20250410_171928-pfaibe0c/logs/debug-internal.log
|
| 8 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:init():761] calling init triggers
|
| 9 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:init():766] wandb.init called with sweep_config: {}
|
| 10 |
+
config: {'_wandb': {}}
|
| 11 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:init():784] starting backend
|
| 12 |
+
2025-04-10 17:19:28,830 INFO MainThread:1734298 [wandb_init.py:init():788] sending inform_init request
|
| 13 |
+
2025-04-10 17:19:28,834 INFO MainThread:1734298 [backend.py:_multiprocessing_setup():101] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
|
| 14 |
+
2025-04-10 17:19:28,834 INFO MainThread:1734298 [wandb_init.py:init():798] backend started and connected
|
| 15 |
+
2025-04-10 17:19:28,836 INFO MainThread:1734298 [wandb_init.py:init():891] updated telemetry
|
| 16 |
+
2025-04-10 17:19:28,852 INFO MainThread:1734298 [wandb_init.py:init():915] communicating run to backend with 90.0 second timeout
|
| 17 |
+
2025-04-10 17:19:29,493 INFO MainThread:1734298 [wandb_init.py:init():990] starting run threads in backend
|
| 18 |
+
2025-04-10 17:19:29,890 INFO MainThread:1734298 [wandb_run.py:_console_start():2375] atexit reg
|
| 19 |
+
2025-04-10 17:19:29,891 INFO MainThread:1734298 [wandb_run.py:_redirect():2227] redirect: wrap_raw
|
| 20 |
+
2025-04-10 17:19:29,891 INFO MainThread:1734298 [wandb_run.py:_redirect():2292] Wrapping output streams.
|
| 21 |
+
2025-04-10 17:19:29,891 INFO MainThread:1734298 [wandb_run.py:_redirect():2315] Redirects installed.
|
| 22 |
+
2025-04-10 17:19:29,895 INFO MainThread:1734298 [wandb_init.py:init():1032] run started, returning control to user process
|
| 23 |
+
2025-04-10 17:19:29,898 INFO MainThread:1734298 [wandb_run.py:_config_callback():1261] config_cb None None {'use_cache': False, 'query_num': 64, 'image_size': 448, 'drop_vision_last_layer': False, 'batch_vision_input': True, 'use_image_id': True, 'vision_batch_size': 16, 'audio_pool_step': 2, 'audio_chunk_length': 1.0, 'stream_input': False, 'init_vision': False, 'init_audio': True, 'init_tts': False, 'processor_path': '/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6', 'pretrained_encoder_path': '/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6', 'pretrained_llm_path': '/data1/speech/anhnmt2/cuongnm/EOT/Qwen2.5-0.5B-Instruct', 'chunk_input': True, 'slice_config': {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': None, 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': None, 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'model_type': 'minicpmv', 'patch_size': 14, 'max_slice_nums': 9, 'scale_resolution': 448}, 'slice_mode': True, 'vision_config': {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': None, 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': None, 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'model_type': 'siglip_vision_model', 'hidden_size': 1152, 'intermediate_size': 4304, 'num_hidden_layers': 27, 'num_attention_heads': 16, 'num_channels': 3, 'patch_size': 14, 'image_size': 980, 'attention_dropout': 0.0, 'layer_norm_eps': 1e-06, 'hidden_act': 'gelu_pytorch_tanh'}, 'audio_config': {'vocab_size': 51865, 'num_mel_bins': 80, 'd_model': 1024, 'encoder_layers': 24, 'encoder_attention_heads': 16, 'decoder_layers': 24, 'decoder_attention_heads': 16, 'decoder_ffn_dim': 4096, 'encoder_ffn_dim': 4096, 'dropout': 0.0, 'attention_dropout': 0.0, 'activation_dropout': 0.0, 'activation_function': 'gelu', 'init_std': 0.02, 'encoder_layerdrop': 0.0, 'decoder_layerdrop': 0.0, 'use_cache': True, 'num_hidden_layers': 24, 'scale_embedding': False, 'max_source_positions': 1500, 'max_target_positions': 448, 'classifier_proj_size': 256, 'use_weighted_layer_sum': False, 'apply_spec_augment': False, 'mask_time_prob': 0.05, 'mask_time_length': 10, 'mask_time_min_masks': 2, 'mask_feature_prob': 0.0, 'mask_feature_length': 10, 'mask_feature_min_masks': 0, 'median_filter_width': 7, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float32', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': True, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 448, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257], 'architectures': ['MiniCPMWhisperEncoder'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 50257, 'pad_token_id': 50257, 'eos_token_id': 50257, 'sep_token_id': None, 'decoder_start_token_id': 50258, 'task_specific_params': None, 'problem_type': None, '_name_or_path': 'openai/whisper-medium', 'forced_decoder_ids': [[1, 50259], [2, 50359], [3, 50363]], 'model_type': 'whisper'}, 'tts_config': {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': None, 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': True, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 20, 'top_p': 0.7, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': None, 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'model_type': 'conditional_chattts', 'llm_dim': 3584, 'hidden_size': 768, 'intermediate_size': 3072, 'num_attention_heads': 12, 'num_hidden_layers': 20, 'max_position_embeddings': 4096, 'num_audio_tokens': 626, 'num_text_tokens': 21178, 'num_mel_bins': 100, 'num_vq': 4, 'use_speaker_embedding': True, 'use_llm_hidden_state': False, 'spk_emb_token_id': 21143, 'num_spk_embs': 1, 'audio_bos_token_id': 21132, 'text_eos_token_id': 21133, 'use_text': True, 'streaming': True, 'streaming_text_chunk_size': 10, 'streaming_text_reserved_len': 300, 'streaming_audio_chunk_size': 50, 'attn_implementation': 'sdpa', 'use_mlp': True, 'aug_loss_weight': True}, 'patch_size': 14, 'vocab_size': 152064, 'max_position_embeddings': 32768, 'hidden_size': 3584, 'intermediate_size': 18944, 'num_hidden_layers': 28, 'num_attention_heads': 28, 'use_sliding_window': False, 'sliding_window': None, 'max_window_layers': 28, 'num_key_value_heads': 4, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-06, 'rope_theta': 1000000.0, 'rope_scaling': None, 'attention_dropout': 0.0, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float32', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 2048, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['Qwen2ForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 151643, 'pad_token_id': None, 'eos_token_id': 151645, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct', 'transformers_version': '4.45.0', 'model_type': 'minicpmo', 'output_dir': '../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector', 'overwrite_output_dir': False, 'do_train': True, 'do_eval': True, 'do_predict': False, 'eval_strategy': 'steps', 'prediction_loss_only': False, 'per_device_train_batch_size': 8, 'per_device_eval_batch_size': 8, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 4, 'eval_accumulation_steps': None, 'eval_delay': 0, 'torch_empty_cache_steps': None, 'learning_rate': 5e-05, 'weight_decay': 0.0, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 3.0, 'max_steps': -1, 'lr_scheduler_type': 'cosine', 'lr_scheduler_kwargs': {}, 'warmup_ratio': 0.03, 'warmup_steps': 0, 'log_level': 'passive', 'log_level_replica': 'warning', 'log_on_each_node': True, 'logging_dir': '../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector/runs/Apr10_17-18-52_dgx-a100-5', 'logging_strategy': 'steps', 'logging_first_step': False, 'logging_steps': 1.0, 'logging_nan_inf_filter': True, 'save_strategy': 'steps', 'save_steps': 5000, 'save_total_limit': 1, 'save_safetensors': True, 'save_on_each_node': False, 'save_only_model': False, 'restore_callback_states_from_checkpoint': False, 'no_cuda': False, 'use_cpu': False, 'use_mps_device': False, 'seed': 42, 'data_seed': None, 'jit_mode_eval': False, 'use_ipex': False, 'bf16': True, 'fp16': False, 'fp16_opt_level': 'O1', 'half_precision_backend': 'auto', 'bf16_full_eval': False, 'fp16_full_eval': False, 'tf32': True, 'local_rank': 0, 'ddp_backend': None, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 3000, 'dataloader_num_workers': 0, 'dataloader_prefetch_factor': None, 'past_index': -1, 'run_name': '../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector', 'disable_tqdm': False, 'remove_unused_columns': False, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'fsdp': [], 'fsdp_min_num_params': 0, 'fsdp_config': {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, 'fsdp_transformer_layer_cls_to_wrap': None, 'accelerator_config': {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}, 'deepspeed': 'zero2.json', 'label_smoothing_factor': 0.0, 'optim': 'adamw_torch', 'optim_args': None, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'ddp_bucket_cap_mb': None, 'ddp_broadcast_buffers': None, 'dataloader_pin_memory': True, 'dataloader_persistent_workers': False, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': False, 'resume_from_checkpoint': None, 'hub_model_id': None, 'hub_strategy': 'every_save', 'hub_token': '<HUB_TOKEN>', 'hub_private_repo': False, 'hub_always_push': False, 'gradient_checkpointing': True, 'gradient_checkpointing_kwargs': {'use_reentrant': False}, 'include_inputs_for_metrics': False, 'eval_do_concat_batches': True, 'fp16_backend': 'auto', 'evaluation_strategy': 'steps', 'push_to_hub_model_id': None, 'push_to_hub_organization': None, 'push_to_hub_token': '<PUSH_TO_HUB_TOKEN>', 'mp_parameters': '', 'auto_find_batch_size': False, 'full_determinism': False, 'torchdynamo': None, 'ray_scope': 'last', 'ddp_timeout': 1800, 'torch_compile': False, 'torch_compile_backend': None, 'torch_compile_mode': None, 'dispatch_batches': None, 'split_batches': None, 'include_tokens_per_second': False, 'include_num_input_tokens_seen': False, 'neftune_noise_alpha': None, 'optim_target_modules': None, 'batch_eval_metrics': False, 'eval_on_start': False, 'use_liger_kernel': False, 'eval_use_gather_object': False, 'cache_dir': '../output/cached_sft_20252502', 'model_max_length': 2048, 'tune_vision': True, 'tune_speech': False, 'tune_llm': False, 'llm_type': 'qwen', 'use_lora': False, 'max_slice_nums': 9, 'config_path': 'minicpmp_config.json', 'init_speech': True}
|
| 24 |
+
2025-04-10 17:19:29,901 INFO MainThread:1734298 [wandb_config.py:__setitem__():154] config set model/num_parameters = 802971264 - <bound method Run._config_callback of <wandb.sdk.wandb_run.Run object at 0x1553c5eda240>>
|
| 25 |
+
2025-04-10 17:19:29,901 INFO MainThread:1734298 [wandb_run.py:_config_callback():1261] config_cb model/num_parameters 802971264 None
|
scripts/wandb/latest-run/files/output.log
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
0%| | 0/43233 [00:00<?, ?it/s]DEBUG:numba.core.byteflow:bytecode dump:
|
| 2 |
+
> 0 NOP(arg=None, lineno=1141)
|
| 3 |
+
2 RESUME(arg=0, lineno=1141)
|
| 4 |
+
4 LOAD_FAST(arg=0, lineno=1144)
|
| 5 |
+
6 LOAD_CONST(arg=1, lineno=1144)
|
| 6 |
+
8 BINARY_SUBSCR(arg=None, lineno=1144)
|
| 7 |
+
12 STORE_FAST(arg=3, lineno=1144)
|
| 8 |
+
14 LOAD_FAST(arg=1, lineno=1145)
|
| 9 |
+
16 UNARY_NEGATIVE(arg=None, lineno=1145)
|
| 10 |
+
18 LOAD_FAST(arg=3, lineno=1145)
|
| 11 |
+
20 SWAP(arg=2, lineno=1145)
|
| 12 |
+
22 COPY(arg=2, lineno=1145)
|
| 13 |
+
24 COMPARE_OP(arg=26, lineno=1145)
|
| 14 |
+
28 POP_JUMP_IF_FALSE(arg=5, lineno=1145)
|
| 15 |
+
30 LOAD_FAST(arg=1, lineno=1145)
|
| 16 |
+
32 COMPARE_OP(arg=26, lineno=1145)
|
| 17 |
+
36 POP_JUMP_IF_FALSE(arg=5, lineno=1145)
|
| 18 |
+
38 JUMP_FORWARD(arg=2, lineno=1145)
|
| 19 |
+
> 40 POP_TOP(arg=None, lineno=1145)
|
| 20 |
+
42 JUMP_FORWARD(arg=2, lineno=1145)
|
| 21 |
+
> 44 LOAD_CONST(arg=1, lineno=1146)
|
| 22 |
+
46 STORE_FAST(arg=3, lineno=1146)
|
| 23 |
+
> 48 LOAD_FAST(arg=0, lineno=1148)
|
| 24 |
+
50 LOAD_CONST(arg=2, lineno=1148)
|
| 25 |
+
52 BINARY_SUBSCR(arg=None, lineno=1148)
|
| 26 |
+
56 STORE_FAST(arg=4, lineno=1148)
|
| 27 |
+
58 LOAD_FAST(arg=1, lineno=1149)
|
| 28 |
+
60 UNARY_NEGATIVE(arg=None, lineno=1149)
|
| 29 |
+
62 LOAD_FAST(arg=4, lineno=1149)
|
| 30 |
+
64 SWAP(arg=2, lineno=1149)
|
| 31 |
+
66 COPY(arg=2, lineno=1149)
|
| 32 |
+
68 COMPARE_OP(arg=26, lineno=1149)
|
| 33 |
+
72 POP_JUMP_IF_FALSE(arg=5, lineno=1149)
|
| 34 |
+
74 LOAD_FAST(arg=1, lineno=1149)
|
| 35 |
+
76 COMPARE_OP(arg=26, lineno=1149)
|
| 36 |
+
80 POP_JUMP_IF_FALSE(arg=5, lineno=1149)
|
| 37 |
+
82 JUMP_FORWARD(arg=2, lineno=1149)
|
| 38 |
+
> 84 POP_TOP(arg=None, lineno=1149)
|
| 39 |
+
86 JUMP_FORWARD(arg=2, lineno=1149)
|
| 40 |
+
> 88 LOAD_CONST(arg=1, lineno=1150)
|
| 41 |
+
90 STORE_FAST(arg=4, lineno=1150)
|
| 42 |
+
> 92 LOAD_FAST(arg=2, lineno=1152)
|
| 43 |
+
94 POP_JUMP_IF_FALSE(arg=43, lineno=1152)
|
| 44 |
+
96 LOAD_GLOBAL(arg=1, lineno=1153)
|
| 45 |
+
106 LOAD_ATTR(arg=2, lineno=1153)
|
| 46 |
+
126 LOAD_FAST(arg=3, lineno=1153)
|
| 47 |
+
128 CALL(arg=1, lineno=1153)
|
| 48 |
+
136 LOAD_GLOBAL(arg=1, lineno=1153)
|
| 49 |
+
146 LOAD_ATTR(arg=2, lineno=1153)
|
| 50 |
+
166 LOAD_FAST(arg=4, lineno=1153)
|
| 51 |
+
168 CALL(arg=1, lineno=1153)
|
| 52 |
+
176 COMPARE_OP(arg=55, lineno=1153)
|
| 53 |
+
180 RETURN_VALUE(arg=None, lineno=1153)
|
| 54 |
+
> 182 LOAD_GLOBAL(arg=1, lineno=1155)
|
| 55 |
+
192 LOAD_ATTR(arg=4, lineno=1155)
|
| 56 |
+
212 LOAD_FAST(arg=3, lineno=1155)
|
| 57 |
+
214 CALL(arg=1, lineno=1155)
|
| 58 |
+
222 LOAD_GLOBAL(arg=1, lineno=1155)
|
| 59 |
+
232 LOAD_ATTR(arg=4, lineno=1155)
|
| 60 |
+
252 LOAD_FAST(arg=4, lineno=1155)
|
| 61 |
+
254 CALL(arg=1, lineno=1155)
|
| 62 |
+
262 COMPARE_OP(arg=55, lineno=1155)
|
| 63 |
+
266 RETURN_VALUE(arg=None, lineno=1155)
|
| 64 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
|
| 65 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 66 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
|
| 67 |
+
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1141)
|
| 68 |
+
DEBUG:numba.core.byteflow:stack []
|
| 69 |
+
DEBUG:numba.core.byteflow:dispatch pc=2, inst=RESUME(arg=0, lineno=1141)
|
| 70 |
+
DEBUG:numba.core.byteflow:stack []
|
| 71 |
+
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=0, lineno=1144)
|
| 72 |
+
DEBUG:numba.core.byteflow:stack []
|
| 73 |
+
DEBUG:numba.core.byteflow:dispatch pc=6, inst=LOAD_CONST(arg=1, lineno=1144)
|
| 74 |
+
DEBUG:numba.core.byteflow:stack ['$x4.0']
|
| 75 |
+
DEBUG:numba.core.byteflow:dispatch pc=8, inst=BINARY_SUBSCR(arg=None, lineno=1144)
|
| 76 |
+
DEBUG:numba.core.byteflow:stack ['$x4.0', '$const6.1']
|
| 77 |
+
DEBUG:numba.core.byteflow:dispatch pc=12, inst=STORE_FAST(arg=3, lineno=1144)
|
| 78 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2']
|
| 79 |
+
DEBUG:numba.core.byteflow:dispatch pc=14, inst=LOAD_FAST(arg=1, lineno=1145)
|
| 80 |
+
DEBUG:numba.core.byteflow:stack []
|
| 81 |
+
DEBUG:numba.core.byteflow:dispatch pc=16, inst=UNARY_NEGATIVE(arg=None, lineno=1145)
|
| 82 |
+
DEBUG:numba.core.byteflow:stack ['$threshold14.3']
|
| 83 |
+
DEBUG:numba.core.byteflow:dispatch pc=18, inst=LOAD_FAST(arg=3, lineno=1145)
|
| 84 |
+
DEBUG:numba.core.byteflow:stack ['$16unary_negative.4']
|
| 85 |
+
DEBUG:numba.core.byteflow:dispatch pc=20, inst=SWAP(arg=2, lineno=1145)
|
| 86 |
+
DEBUG:numba.core.byteflow:stack ['$16unary_negative.4', '$x018.5']
|
| 87 |
+
DEBUG:numba.core.byteflow:dispatch pc=22, inst=COPY(arg=2, lineno=1145)
|
| 88 |
+
DEBUG:numba.core.byteflow:stack ['$x018.5', '$16unary_negative.4']
|
| 89 |
+
DEBUG:numba.core.byteflow:dispatch pc=24, inst=COMPARE_OP(arg=26, lineno=1145)
|
| 90 |
+
DEBUG:numba.core.byteflow:stack ['$x018.5', '$16unary_negative.4', '$x018.5']
|
| 91 |
+
DEBUG:numba.core.byteflow:dispatch pc=28, inst=POP_JUMP_IF_FALSE(arg=5, lineno=1145)
|
| 92 |
+
DEBUG:numba.core.byteflow:stack ['$x018.5', '$24compare_op.6']
|
| 93 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=30, stack=('$x018.5',), blockstack=(), npush=0), Edge(pc=40, stack=('$x018.5',), blockstack=(), npush=0)]
|
| 94 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=30 nstack_initial=1), State(pc_initial=40 nstack_initial=1)])
|
| 95 |
+
DEBUG:numba.core.byteflow:stack: ['$phi30.0']
|
| 96 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=30 nstack_initial=1)
|
| 97 |
+
DEBUG:numba.core.byteflow:dispatch pc=30, inst=LOAD_FAST(arg=1, lineno=1145)
|
| 98 |
+
DEBUG:numba.core.byteflow:stack ['$phi30.0']
|
| 99 |
+
DEBUG:numba.core.byteflow:dispatch pc=32, inst=COMPARE_OP(arg=26, lineno=1145)
|
| 100 |
+
DEBUG:numba.core.byteflow:stack ['$phi30.0', '$threshold30.1']
|
| 101 |
+
DEBUG:numba.core.byteflow:dispatch pc=36, inst=POP_JUMP_IF_FALSE(arg=5, lineno=1145)
|
| 102 |
+
DEBUG:numba.core.byteflow:stack ['$32compare_op.2']
|
| 103 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=38, stack=(), blockstack=(), npush=0), Edge(pc=48, stack=(), blockstack=(), npush=0)]
|
| 104 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=40 nstack_initial=1), State(pc_initial=38 nstack_initial=0), State(pc_initial=48 nstack_initial=0)])
|
| 105 |
+
DEBUG:numba.core.byteflow:stack: ['$phi40.0']
|
| 106 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=40 nstack_initial=1)
|
| 107 |
+
DEBUG:numba.core.byteflow:dispatch pc=40, inst=POP_TOP(arg=None, lineno=1145)
|
| 108 |
+
DEBUG:numba.core.byteflow:stack ['$phi40.0']
|
| 109 |
+
DEBUG:numba.core.byteflow:dispatch pc=42, inst=JUMP_FORWARD(arg=2, lineno=1145)
|
| 110 |
+
DEBUG:numba.core.byteflow:stack []
|
| 111 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=48, stack=(), blockstack=(), npush=0)]
|
| 112 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=38 nstack_initial=0), State(pc_initial=48 nstack_initial=0), State(pc_initial=48 nstack_initial=0)])
|
| 113 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 114 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=38 nstack_initial=0)
|
| 115 |
+
DEBUG:numba.core.byteflow:dispatch pc=38, inst=JUMP_FORWARD(arg=2, lineno=1145)
|
| 116 |
+
DEBUG:numba.core.byteflow:stack []
|
| 117 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=44, stack=(), blockstack=(), npush=0)]
|
| 118 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=48 nstack_initial=0), State(pc_initial=48 nstack_initial=0), State(pc_initial=44 nstack_initial=0)])
|
| 119 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 120 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=48 nstack_initial=0)
|
| 121 |
+
DEBUG:numba.core.byteflow:dispatch pc=48, inst=LOAD_FAST(arg=0, lineno=1148)
|
| 122 |
+
DEBUG:numba.core.byteflow:stack []
|
| 123 |
+
DEBUG:numba.core.byteflow:dispatch pc=50, inst=LOAD_CONST(arg=2, lineno=1148)
|
| 124 |
+
DEBUG:numba.core.byteflow:stack ['$x48.0']
|
| 125 |
+
DEBUG:numba.core.byteflow:dispatch pc=52, inst=BINARY_SUBSCR(arg=None, lineno=1148)
|
| 126 |
+
DEBUG:numba.core.byteflow:stack ['$x48.0', '$const50.1']
|
| 127 |
+
DEBUG:numba.core.byteflow:dispatch pc=56, inst=STORE_FAST(arg=4, lineno=1148)
|
| 128 |
+
DEBUG:numba.core.byteflow:stack ['$52binary_subscr.2']
|
| 129 |
+
DEBUG:numba.core.byteflow:dispatch pc=58, inst=LOAD_FAST(arg=1, lineno=1149)
|
| 130 |
+
DEBUG:numba.core.byteflow:stack []
|
| 131 |
+
DEBUG:numba.core.byteflow:dispatch pc=60, inst=UNARY_NEGATIVE(arg=None, lineno=1149)
|
| 132 |
+
DEBUG:numba.core.byteflow:stack ['$threshold58.3']
|
| 133 |
+
DEBUG:numba.core.byteflow:dispatch pc=62, inst=LOAD_FAST(arg=4, lineno=1149)
|
| 134 |
+
DEBUG:numba.core.byteflow:stack ['$60unary_negative.4']
|
| 135 |
+
DEBUG:numba.core.byteflow:dispatch pc=64, inst=SWAP(arg=2, lineno=1149)
|
| 136 |
+
DEBUG:numba.core.byteflow:stack ['$60unary_negative.4', '$x162.5']
|
| 137 |
+
DEBUG:numba.core.byteflow:dispatch pc=66, inst=COPY(arg=2, lineno=1149)
|
| 138 |
+
DEBUG:numba.core.byteflow:stack ['$x162.5', '$60unary_negative.4']
|
| 139 |
+
DEBUG:numba.core.byteflow:dispatch pc=68, inst=COMPARE_OP(arg=26, lineno=1149)
|
| 140 |
+
DEBUG:numba.core.byteflow:stack ['$x162.5', '$60unary_negative.4', '$x162.5']
|
| 141 |
+
DEBUG:numba.core.byteflow:dispatch pc=72, inst=POP_JUMP_IF_FALSE(arg=5, lineno=1149)
|
| 142 |
+
DEBUG:numba.core.byteflow:stack ['$x162.5', '$68compare_op.6']
|
| 143 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=74, stack=('$x162.5',), blockstack=(), npush=0), Edge(pc=84, stack=('$x162.5',), blockstack=(), npush=0)]
|
| 144 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=48 nstack_initial=0), State(pc_initial=44 nstack_initial=0), State(pc_initial=74 nstack_initial=1), State(pc_initial=84 nstack_initial=1)])
|
| 145 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=44 nstack_initial=0), State(pc_initial=74 nstack_initial=1), State(pc_initial=84 nstack_initial=1)])
|
| 146 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 147 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=44 nstack_initial=0)
|
| 148 |
+
DEBUG:numba.core.byteflow:dispatch pc=44, inst=LOAD_CONST(arg=1, lineno=1146)
|
| 149 |
+
DEBUG:numba.core.byteflow:stack []
|
| 150 |
+
DEBUG:numba.core.byteflow:dispatch pc=46, inst=STORE_FAST(arg=3, lineno=1146)
|
| 151 |
+
DEBUG:numba.core.byteflow:stack ['$const44.0']
|
| 152 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=48, stack=(), blockstack=(), npush=0)]
|
| 153 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=74 nstack_initial=1), State(pc_initial=84 nstack_initial=1), State(pc_initial=48 nstack_initial=0)])
|
| 154 |
+
DEBUG:numba.core.byteflow:stack: ['$phi74.0']
|
| 155 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=74 nstack_initial=1)
|
| 156 |
+
DEBUG:numba.core.byteflow:dispatch pc=74, inst=LOAD_FAST(arg=1, lineno=1149)
|
| 157 |
+
DEBUG:numba.core.byteflow:stack ['$phi74.0']
|
| 158 |
+
DEBUG:numba.core.byteflow:dispatch pc=76, inst=COMPARE_OP(arg=26, lineno=1149)
|
| 159 |
+
DEBUG:numba.core.byteflow:stack ['$phi74.0', '$threshold74.1']
|
| 160 |
+
DEBUG:numba.core.byteflow:dispatch pc=80, inst=POP_JUMP_IF_FALSE(arg=5, lineno=1149)
|
| 161 |
+
DEBUG:numba.core.byteflow:stack ['$76compare_op.2']
|
| 162 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=82, stack=(), blockstack=(), npush=0), Edge(pc=92, stack=(), blockstack=(), npush=0)]
|
| 163 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=84 nstack_initial=1), State(pc_initial=48 nstack_initial=0), State(pc_initial=82 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
|
| 164 |
+
DEBUG:numba.core.byteflow:stack: ['$phi84.0']
|
| 165 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=84 nstack_initial=1)
|
| 166 |
+
DEBUG:numba.core.byteflow:dispatch pc=84, inst=POP_TOP(arg=None, lineno=1149)
|
| 167 |
+
DEBUG:numba.core.byteflow:stack ['$phi84.0']
|
| 168 |
+
DEBUG:numba.core.byteflow:dispatch pc=86, inst=JUMP_FORWARD(arg=2, lineno=1149)
|
| 169 |
+
DEBUG:numba.core.byteflow:stack []
|
| 170 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=92, stack=(), blockstack=(), npush=0)]
|
| 171 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=48 nstack_initial=0), State(pc_initial=82 nstack_initial=0), State(pc_initial=92 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
|
| 172 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=82 nstack_initial=0), State(pc_initial=92 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
|
| 173 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 174 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=82 nstack_initial=0)
|
| 175 |
+
DEBUG:numba.core.byteflow:dispatch pc=82, inst=JUMP_FORWARD(arg=2, lineno=1149)
|
| 176 |
+
DEBUG:numba.core.byteflow:stack []
|
| 177 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=88, stack=(), blockstack=(), npush=0)]
|
| 178 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=92 nstack_initial=0), State(pc_initial=92 nstack_initial=0), State(pc_initial=88 nstack_initial=0)])
|
| 179 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 180 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=92 nstack_initial=0)
|
| 181 |
+
DEBUG:numba.core.byteflow:dispatch pc=92, inst=LOAD_FAST(arg=2, lineno=1152)
|
| 182 |
+
DEBUG:numba.core.byteflow:stack []
|
| 183 |
+
DEBUG:numba.core.byteflow:dispatch pc=94, inst=POP_JUMP_IF_FALSE(arg=43, lineno=1152)
|
| 184 |
+
DEBUG:numba.core.byteflow:stack ['$zero_pos92.0']
|
| 185 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=96, stack=(), blockstack=(), npush=0), Edge(pc=182, stack=(), blockstack=(), npush=0)]
|
| 186 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=92 nstack_initial=0), State(pc_initial=88 nstack_initial=0), State(pc_initial=96 nstack_initial=0), State(pc_initial=182 nstack_initial=0)])
|
| 187 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=88 nstack_initial=0), State(pc_initial=96 nstack_initial=0), State(pc_initial=182 nstack_initial=0)])
|
| 188 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 189 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=88 nstack_initial=0)
|
| 190 |
+
DEBUG:numba.core.byteflow:dispatch pc=88, inst=LOAD_CONST(arg=1, lineno=1150)
|
| 191 |
+
DEBUG:numba.core.byteflow:stack []
|
| 192 |
+
DEBUG:numba.core.byteflow:dispatch pc=90, inst=STORE_FAST(arg=4, lineno=1150)
|
| 193 |
+
DEBUG:numba.core.byteflow:stack ['$const88.0']
|
| 194 |
+
DEBUG:numba.core.byteflow:end state. edges=[Edge(pc=92, stack=(), blockstack=(), npush=0)]
|
| 195 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=96 nstack_initial=0), State(pc_initial=182 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
|
| 196 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 197 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=96 nstack_initial=0)
|
| 198 |
+
DEBUG:numba.core.byteflow:dispatch pc=96, inst=LOAD_GLOBAL(arg=1, lineno=1153)
|
| 199 |
+
DEBUG:numba.core.byteflow:stack []
|
| 200 |
+
DEBUG:numba.core.byteflow:dispatch pc=106, inst=LOAD_ATTR(arg=2, lineno=1153)
|
| 201 |
+
DEBUG:numba.core.byteflow:stack ['$null$96.1', '$96load_global.0']
|
| 202 |
+
DEBUG:numba.core.byteflow:dispatch pc=126, inst=LOAD_FAST(arg=3, lineno=1153)
|
| 203 |
+
DEBUG:numba.core.byteflow:stack ['$null$96.1', '$106load_attr.2']
|
| 204 |
+
DEBUG:numba.core.byteflow:dispatch pc=128, inst=CALL(arg=1, lineno=1153)
|
| 205 |
+
DEBUG:numba.core.byteflow:stack ['$null$96.1', '$106load_attr.2', '$x0126.3']
|
| 206 |
+
DEBUG:numba.core.byteflow:dispatch pc=136, inst=LOAD_GLOBAL(arg=1, lineno=1153)
|
| 207 |
+
DEBUG:numba.core.byteflow:stack ['$128call.4']
|
| 208 |
+
DEBUG:numba.core.byteflow:dispatch pc=146, inst=LOAD_ATTR(arg=2, lineno=1153)
|
| 209 |
+
DEBUG:numba.core.byteflow:stack ['$128call.4', '$null$136.6', '$136load_global.5']
|
| 210 |
+
DEBUG:numba.core.byteflow:dispatch pc=166, inst=LOAD_FAST(arg=4, lineno=1153)
|
| 211 |
+
DEBUG:numba.core.byteflow:stack ['$128call.4', '$null$136.6', '$146load_attr.7']
|
| 212 |
+
DEBUG:numba.core.byteflow:dispatch pc=168, inst=CALL(arg=1, lineno=1153)
|
| 213 |
+
DEBUG:numba.core.byteflow:stack ['$128call.4', '$null$136.6', '$146load_attr.7', '$x1166.8']
|
| 214 |
+
DEBUG:numba.core.byteflow:dispatch pc=176, inst=COMPARE_OP(arg=55, lineno=1153)
|
| 215 |
+
DEBUG:numba.core.byteflow:stack ['$128call.4', '$168call.9']
|
| 216 |
+
DEBUG:numba.core.byteflow:dispatch pc=180, inst=RETURN_VALUE(arg=None, lineno=1153)
|
| 217 |
+
DEBUG:numba.core.byteflow:stack ['$176compare_op.10']
|
| 218 |
+
DEBUG:numba.core.byteflow:end state. edges=[]
|
| 219 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=182 nstack_initial=0), State(pc_initial=92 nstack_initial=0)])
|
| 220 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 221 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=182 nstack_initial=0)
|
| 222 |
+
DEBUG:numba.core.byteflow:dispatch pc=182, inst=LOAD_GLOBAL(arg=1, lineno=1155)
|
| 223 |
+
DEBUG:numba.core.byteflow:stack []
|
| 224 |
+
DEBUG:numba.core.byteflow:dispatch pc=192, inst=LOAD_ATTR(arg=4, lineno=1155)
|
| 225 |
+
DEBUG:numba.core.byteflow:stack ['$null$182.1', '$182load_global.0']
|
| 226 |
+
DEBUG:numba.core.byteflow:dispatch pc=212, inst=LOAD_FAST(arg=3, lineno=1155)
|
| 227 |
+
DEBUG:numba.core.byteflow:stack ['$null$182.1', '$192load_attr.2']
|
| 228 |
+
DEBUG:numba.core.byteflow:dispatch pc=214, inst=CALL(arg=1, lineno=1155)
|
| 229 |
+
DEBUG:numba.core.byteflow:stack ['$null$182.1', '$192load_attr.2', '$x0212.3']
|
| 230 |
+
DEBUG:numba.core.byteflow:dispatch pc=222, inst=LOAD_GLOBAL(arg=1, lineno=1155)
|
| 231 |
+
DEBUG:numba.core.byteflow:stack ['$214call.4']
|
| 232 |
+
DEBUG:numba.core.byteflow:dispatch pc=232, inst=LOAD_ATTR(arg=4, lineno=1155)
|
| 233 |
+
DEBUG:numba.core.byteflow:stack ['$214call.4', '$null$222.6', '$222load_global.5']
|
| 234 |
+
DEBUG:numba.core.byteflow:dispatch pc=252, inst=LOAD_FAST(arg=4, lineno=1155)
|
| 235 |
+
DEBUG:numba.core.byteflow:stack ['$214call.4', '$null$222.6', '$232load_attr.7']
|
| 236 |
+
DEBUG:numba.core.byteflow:dispatch pc=254, inst=CALL(arg=1, lineno=1155)
|
| 237 |
+
DEBUG:numba.core.byteflow:stack ['$214call.4', '$null$222.6', '$232load_attr.7', '$x1252.8']
|
| 238 |
+
DEBUG:numba.core.byteflow:dispatch pc=262, inst=COMPARE_OP(arg=55, lineno=1155)
|
| 239 |
+
DEBUG:numba.core.byteflow:stack ['$214call.4', '$254call.9']
|
| 240 |
+
DEBUG:numba.core.byteflow:dispatch pc=266, inst=RETURN_VALUE(arg=None, lineno=1155)
|
| 241 |
+
DEBUG:numba.core.byteflow:stack ['$262compare_op.10']
|
| 242 |
+
DEBUG:numba.core.byteflow:end state. edges=[]
|
| 243 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=92 nstack_initial=0)])
|
| 244 |
+
DEBUG:numba.core.byteflow:-------------------------Prune PHIs-------------------------
|
| 245 |
+
DEBUG:numba.core.byteflow:Used_phis: defaultdict(<class 'set'>,
|
| 246 |
+
{State(pc_initial=0 nstack_initial=0): set(),
|
| 247 |
+
State(pc_initial=30 nstack_initial=1): {'$phi30.0'},
|
| 248 |
+
State(pc_initial=38 nstack_initial=0): set(),
|
| 249 |
+
State(pc_initial=40 nstack_initial=1): set(),
|
| 250 |
+
State(pc_initial=44 nstack_initial=0): set(),
|
| 251 |
+
State(pc_initial=48 nstack_initial=0): set(),
|
| 252 |
+
State(pc_initial=74 nstack_initial=1): {'$phi74.0'},
|
| 253 |
+
State(pc_initial=82 nstack_initial=0): set(),
|
| 254 |
+
State(pc_initial=84 nstack_initial=1): set(),
|
| 255 |
+
State(pc_initial=88 nstack_initial=0): set(),
|
| 256 |
+
State(pc_initial=92 nstack_initial=0): set(),
|
| 257 |
+
State(pc_initial=96 nstack_initial=0): set(),
|
| 258 |
+
State(pc_initial=182 nstack_initial=0): set()})
|
| 259 |
+
DEBUG:numba.core.byteflow:defmap: {'$phi30.0': State(pc_initial=0 nstack_initial=0),
|
| 260 |
+
'$phi40.0': State(pc_initial=0 nstack_initial=0),
|
| 261 |
+
'$phi74.0': State(pc_initial=48 nstack_initial=0),
|
| 262 |
+
'$phi84.0': State(pc_initial=48 nstack_initial=0)}
|
| 263 |
+
DEBUG:numba.core.byteflow:phismap: defaultdict(<class 'set'>,
|
| 264 |
+
{'$phi30.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
|
| 265 |
+
'$phi40.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
|
| 266 |
+
'$phi74.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))},
|
| 267 |
+
'$phi84.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))}})
|
| 268 |
+
DEBUG:numba.core.byteflow:changing phismap: defaultdict(<class 'set'>,
|
| 269 |
+
{'$phi30.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
|
| 270 |
+
'$phi40.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
|
| 271 |
+
'$phi74.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))},
|
| 272 |
+
'$phi84.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))}})
|
| 273 |
+
DEBUG:numba.core.byteflow:keep phismap: {'$phi30.0': {('$x018.5', State(pc_initial=0 nstack_initial=0))},
|
| 274 |
+
'$phi74.0': {('$x162.5', State(pc_initial=48 nstack_initial=0))}}
|
| 275 |
+
DEBUG:numba.core.byteflow:new_out: defaultdict(<class 'dict'>,
|
| 276 |
+
{State(pc_initial=0 nstack_initial=0): {'$phi30.0': '$x018.5'},
|
| 277 |
+
State(pc_initial=48 nstack_initial=0): {'$phi74.0': '$x162.5'}})
|
| 278 |
+
DEBUG:numba.core.byteflow:----------------------DONE Prune PHIs-----------------------
|
| 279 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=0 nstack_initial=0):
|
| 280 |
+
AdaptBlockInfo(insts=((0, {}), (2, {}), (4, {'res': '$x4.0'}), (6, {'res': '$const6.1'}), (8, {'index': '$const6.1', 'target': '$x4.0', 'res': '$8binary_subscr.2'}), (12, {'value': '$8binary_subscr.2'}), (14, {'res': '$threshold14.3'}), (16, {'value': '$threshold14.3', 'res': '$16unary_negative.4'}), (18, {'res': '$x018.5'}), (24, {'lhs': '$16unary_negative.4', 'rhs': '$x018.5', 'res': '$24compare_op.6'}), (28, {'pred': '$24compare_op.6'})), outgoing_phis={'$phi30.0': '$x018.5'}, blockstack=(), active_try_block=None, outgoing_edgepushed={30: ('$x018.5',), 40: ('$x018.5',)})
|
| 281 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=30 nstack_initial=1):
|
| 282 |
+
AdaptBlockInfo(insts=((30, {'res': '$threshold30.1'}), (32, {'lhs': '$phi30.0', 'rhs': '$threshold30.1', 'res': '$32compare_op.2'}), (36, {'pred': '$32compare_op.2'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={38: (), 48: ()})
|
| 283 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=38 nstack_initial=0):
|
| 284 |
+
AdaptBlockInfo(insts=((38, {}),), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={44: ()})
|
| 285 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=40 nstack_initial=1):
|
| 286 |
+
AdaptBlockInfo(insts=((42, {}),), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={48: ()})
|
| 287 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=44 nstack_initial=0):
|
| 288 |
+
AdaptBlockInfo(insts=((44, {'res': '$const44.0'}), (46, {'value': '$const44.0'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={48: ()})
|
| 289 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=48 nstack_initial=0):
|
| 290 |
+
AdaptBlockInfo(insts=((48, {'res': '$x48.0'}), (50, {'res': '$const50.1'}), (52, {'index': '$const50.1', 'target': '$x48.0', 'res': '$52binary_subscr.2'}), (56, {'value': '$52binary_subscr.2'}), (58, {'res': '$threshold58.3'}), (60, {'value': '$threshold58.3', 'res': '$60unary_negative.4'}), (62, {'res': '$x162.5'}), (68, {'lhs': '$60unary_negative.4', 'rhs': '$x162.5', 'res': '$68compare_op.6'}), (72, {'pred': '$68compare_op.6'})), outgoing_phis={'$phi74.0': '$x162.5'}, blockstack=(), active_try_block=None, outgoing_edgepushed={74: ('$x162.5',), 84: ('$x162.5',)})
|
| 291 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=74 nstack_initial=1):
|
| 292 |
+
AdaptBlockInfo(insts=((74, {'res': '$threshold74.1'}), (76, {'lhs': '$phi74.0', 'rhs': '$threshold74.1', 'res': '$76compare_op.2'}), (80, {'pred': '$76compare_op.2'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={82: (), 92: ()})
|
| 293 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=82 nstack_initial=0):
|
| 294 |
+
AdaptBlockInfo(insts=((82, {}),), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={88: ()})
|
| 295 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=84 nstack_initial=1):
|
| 296 |
+
AdaptBlockInfo(insts=((86, {}),), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={92: ()})
|
| 297 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=88 nstack_initial=0):
|
| 298 |
+
AdaptBlockInfo(insts=((88, {'res': '$const88.0'}), (90, {'value': '$const88.0'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={92: ()})
|
| 299 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=92 nstack_initial=0):
|
| 300 |
+
AdaptBlockInfo(insts=((92, {'res': '$zero_pos92.0'}), (94, {'pred': '$zero_pos92.0'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={96: (), 182: ()})
|
| 301 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=96 nstack_initial=0):
|
| 302 |
+
AdaptBlockInfo(insts=((96, {'idx': 0, 'res': '$96load_global.0'}), (106, {'item': '$96load_global.0', 'res': '$106load_attr.2'}), (126, {'res': '$x0126.3'}), (128, {'func': '$106load_attr.2', 'args': ['$x0126.3'], 'kw_names': None, 'res': '$128call.4'}), (136, {'idx': 0, 'res': '$136load_global.5'}), (146, {'item': '$136load_global.5', 'res': '$146load_attr.7'}), (166, {'res': '$x1166.8'}), (168, {'func': '$146load_attr.7', 'args': ['$x1166.8'], 'kw_names': None, 'res': '$168call.9'}), (176, {'lhs': '$128call.4', 'rhs': '$168call.9', 'res': '$176compare_op.10'}), (180, {'retval': '$176compare_op.10', 'castval': '$180return_value.11'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={})
|
| 303 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=182 nstack_initial=0):
|
| 304 |
+
AdaptBlockInfo(insts=((182, {'idx': 0, 'res': '$182load_global.0'}), (192, {'item': '$182load_global.0', 'res': '$192load_attr.2'}), (212, {'res': '$x0212.3'}), (214, {'func': '$192load_attr.2', 'args': ['$x0212.3'], 'kw_names': None, 'res': '$214call.4'}), (222, {'idx': 0, 'res': '$222load_global.5'}), (232, {'item': '$222load_global.5', 'res': '$232load_attr.7'}), (252, {'res': '$x1252.8'}), (254, {'func': '$232load_attr.7', 'args': ['$x1252.8'], 'kw_names': None, 'res': '$254call.9'}), (262, {'lhs': '$214call.4', 'rhs': '$254call.9', 'res': '$262compare_op.10'}), (266, {'retval': '$262compare_op.10', 'castval': '$266return_value.11'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={})
|
| 305 |
+
DEBUG:numba.core.interpreter:label 0:
|
| 306 |
+
x = arg(0, name=x) ['x']
|
| 307 |
+
threshold = arg(1, name=threshold) ['threshold']
|
| 308 |
+
zero_pos = arg(2, name=zero_pos) ['zero_pos']
|
| 309 |
+
$const6.1 = const(int, 0) ['$const6.1']
|
| 310 |
+
x0 = getitem(value=x, index=$const6.1, fn=<built-in function getitem>) ['$const6.1', 'x', 'x0']
|
| 311 |
+
$16unary_negative.4 = unary(fn=<built-in function neg>, value=threshold) ['$16unary_negative.4', 'threshold']
|
| 312 |
+
$24compare_op.6 = $16unary_negative.4 <= x0 ['$16unary_negative.4', '$24compare_op.6', 'x0']
|
| 313 |
+
bool28 = global(bool: <class 'bool'>) ['bool28']
|
| 314 |
+
$28pred = call bool28($24compare_op.6, func=bool28, args=(Var($24compare_op.6, audio.py:1145),), kws=(), vararg=None, varkwarg=None, target=None) ['$24compare_op.6', '$28pred', 'bool28']
|
| 315 |
+
$phi30.0 = x0 ['$phi30.0', 'x0']
|
| 316 |
+
branch $28pred, 30, 40 ['$28pred']
|
| 317 |
+
label 30:
|
| 318 |
+
$32compare_op.2 = $phi30.0 <= threshold ['$32compare_op.2', '$phi30.0', 'threshold']
|
| 319 |
+
bool36 = global(bool: <class 'bool'>) ['bool36']
|
| 320 |
+
$36pred = call bool36($32compare_op.2, func=bool36, args=(Var($32compare_op.2, audio.py:1145),), kws=(), vararg=None, varkwarg=None, target=None) ['$32compare_op.2', '$36pred', 'bool36']
|
| 321 |
+
branch $36pred, 38, 48 ['$36pred']
|
| 322 |
+
label 38:
|
| 323 |
+
jump 44 []
|
| 324 |
+
label 40:
|
| 325 |
+
jump 48 []
|
| 326 |
+
label 44:
|
| 327 |
+
x0 = const(int, 0) ['x0']
|
| 328 |
+
jump 48 []
|
| 329 |
+
label 48:
|
| 330 |
+
$const50.1 = const(int, -1) ['$const50.1']
|
| 331 |
+
x1 = getitem(value=x, index=$const50.1, fn=<built-in function getitem>) ['$const50.1', 'x', 'x1']
|
| 332 |
+
$60unary_negative.4 = unary(fn=<built-in function neg>, value=threshold) ['$60unary_negative.4', 'threshold']
|
| 333 |
+
$68compare_op.6 = $60unary_negative.4 <= x1 ['$60unary_negative.4', '$68compare_op.6', 'x1']
|
| 334 |
+
bool72 = global(bool: <class 'bool'>) ['bool72']
|
| 335 |
+
$72pred = call bool72($68compare_op.6, func=bool72, args=(Var($68compare_op.6, audio.py:1149),), kws=(), vararg=None, varkwarg=None, target=None) ['$68compare_op.6', '$72pred', 'bool72']
|
| 336 |
+
$phi74.0 = x1 ['$phi74.0', 'x1']
|
| 337 |
+
branch $72pred, 74, 84 ['$72pred']
|
| 338 |
+
label 74:
|
| 339 |
+
$76compare_op.2 = $phi74.0 <= threshold ['$76compare_op.2', '$phi74.0', 'threshold']
|
| 340 |
+
bool80 = global(bool: <class 'bool'>) ['bool80']
|
| 341 |
+
$80pred = call bool80($76compare_op.2, func=bool80, args=(Var($76compare_op.2, audio.py:1149),), kws=(), vararg=None, varkwarg=None, target=None) ['$76compare_op.2', '$80pred', 'bool80']
|
| 342 |
+
branch $80pred, 82, 92 ['$80pred']
|
| 343 |
+
label 82:
|
| 344 |
+
jump 88 []
|
| 345 |
+
label 84:
|
| 346 |
+
jump 92 []
|
| 347 |
+
label 88:
|
| 348 |
+
x1 = const(int, 0) ['x1']
|
| 349 |
+
jump 92 []
|
| 350 |
+
label 92:
|
| 351 |
+
bool94 = global(bool: <class 'bool'>) ['bool94']
|
| 352 |
+
$94pred = call bool94(zero_pos, func=bool94, args=(Var(zero_pos, audio.py:1141),), kws=(), vararg=None, varkwarg=None, target=None) ['$94pred', 'bool94', 'zero_pos']
|
| 353 |
+
branch $94pred, 96, 182 ['$94pred']
|
| 354 |
+
label 96:
|
| 355 |
+
$96load_global.0 = global(np: <module 'numpy' from '/home/anhnmt2/.local/lib/python3.12/site-packages/numpy/__init__.py'>) ['$96load_global.0']
|
| 356 |
+
$106load_attr.2 = getattr(value=$96load_global.0, attr=signbit) ['$106load_attr.2', '$96load_global.0']
|
| 357 |
+
$128call.4 = call $106load_attr.2(x0, func=$106load_attr.2, args=[Var(x0, audio.py:1144)], kws=(), vararg=None, varkwarg=None, target=None) ['$106load_attr.2', '$128call.4', 'x0']
|
| 358 |
+
$136load_global.5 = global(np: <module 'numpy' from '/home/anhnmt2/.local/lib/python3.12/site-packages/numpy/__init__.py'>) ['$136load_global.5']
|
| 359 |
+
$146load_attr.7 = getattr(value=$136load_global.5, attr=signbit) ['$136load_global.5', '$146load_attr.7']
|
| 360 |
+
$168call.9 = call $146load_attr.7(x1, func=$146load_attr.7, args=[Var(x1, audio.py:1148)], kws=(), vararg=None, varkwarg=None, target=None) ['$146load_attr.7', '$168call.9', 'x1']
|
| 361 |
+
$176compare_op.10 = $128call.4 != $168call.9 ['$128call.4', '$168call.9', '$176compare_op.10']
|
| 362 |
+
$180return_value.11 = cast(value=$176compare_op.10) ['$176compare_op.10', '$180return_value.11']
|
| 363 |
+
return $180return_value.11 ['$180return_value.11']
|
| 364 |
+
label 182:
|
| 365 |
+
$182load_global.0 = global(np: <module 'numpy' from '/home/anhnmt2/.local/lib/python3.12/site-packages/numpy/__init__.py'>) ['$182load_global.0']
|
| 366 |
+
$192load_attr.2 = getattr(value=$182load_global.0, attr=sign) ['$182load_global.0', '$192load_attr.2']
|
| 367 |
+
$214call.4 = call $192load_attr.2(x0, func=$192load_attr.2, args=[Var(x0, audio.py:1144)], kws=(), vararg=None, varkwarg=None, target=None) ['$192load_attr.2', '$214call.4', 'x0']
|
| 368 |
+
$222load_global.5 = global(np: <module 'numpy' from '/home/anhnmt2/.local/lib/python3.12/site-packages/numpy/__init__.py'>) ['$222load_global.5']
|
| 369 |
+
$232load_attr.7 = getattr(value=$222load_global.5, attr=sign) ['$222load_global.5', '$232load_attr.7']
|
| 370 |
+
$254call.9 = call $232load_attr.7(x1, func=$232load_attr.7, args=[Var(x1, audio.py:1148)], kws=(), vararg=None, varkwarg=None, target=None) ['$232load_attr.7', '$254call.9', 'x1']
|
| 371 |
+
$262compare_op.10 = $214call.4 != $254call.9 ['$214call.4', '$254call.9', '$262compare_op.10']
|
| 372 |
+
$266return_value.11 = cast(value=$262compare_op.10) ['$262compare_op.10', '$266return_value.11']
|
| 373 |
+
return $266return_value.11 ['$266return_value.11']
|
| 374 |
+
|
| 375 |
+
DEBUG:numba.core.byteflow:bytecode dump:
|
| 376 |
+
> 0 NOP(arg=None, lineno=1039)
|
| 377 |
+
2 RESUME(arg=0, lineno=1039)
|
| 378 |
+
4 LOAD_FAST(arg=0, lineno=1042)
|
| 379 |
+
6 LOAD_CONST(arg=1, lineno=1042)
|
| 380 |
+
8 BINARY_SUBSCR(arg=None, lineno=1042)
|
| 381 |
+
12 LOAD_FAST(arg=0, lineno=1042)
|
| 382 |
+
14 LOAD_CONST(arg=2, lineno=1042)
|
| 383 |
+
16 BINARY_SUBSCR(arg=None, lineno=1042)
|
| 384 |
+
20 COMPARE_OP(arg=68, lineno=1042)
|
| 385 |
+
24 LOAD_FAST(arg=0, lineno=1042)
|
| 386 |
+
26 LOAD_CONST(arg=1, lineno=1042)
|
| 387 |
+
28 BINARY_SUBSCR(arg=None, lineno=1042)
|
| 388 |
+
32 LOAD_FAST(arg=0, lineno=1042)
|
| 389 |
+
34 LOAD_CONST(arg=3, lineno=1042)
|
| 390 |
+
36 BINARY_SUBSCR(arg=None, lineno=1042)
|
| 391 |
+
40 COMPARE_OP(arg=92, lineno=1042)
|
| 392 |
+
44 BINARY_OP(arg=1, lineno=1042)
|
| 393 |
+
48 RETURN_VALUE(arg=None, lineno=1042)
|
| 394 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
|
| 395 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 396 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
|
| 397 |
+
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1039)
|
| 398 |
+
DEBUG:numba.core.byteflow:stack []
|
| 399 |
+
DEBUG:numba.core.byteflow:dispatch pc=2, inst=RESUME(arg=0, lineno=1039)
|
| 400 |
+
DEBUG:numba.core.byteflow:stack []
|
| 401 |
+
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=0, lineno=1042)
|
| 402 |
+
DEBUG:numba.core.byteflow:stack []
|
| 403 |
+
DEBUG:numba.core.byteflow:dispatch pc=6, inst=LOAD_CONST(arg=1, lineno=1042)
|
| 404 |
+
DEBUG:numba.core.byteflow:stack ['$x4.0']
|
| 405 |
+
DEBUG:numba.core.byteflow:dispatch pc=8, inst=BINARY_SUBSCR(arg=None, lineno=1042)
|
| 406 |
+
DEBUG:numba.core.byteflow:stack ['$x4.0', '$const6.1']
|
| 407 |
+
DEBUG:numba.core.byteflow:dispatch pc=12, inst=LOAD_FAST(arg=0, lineno=1042)
|
| 408 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2']
|
| 409 |
+
DEBUG:numba.core.byteflow:dispatch pc=14, inst=LOAD_CONST(arg=2, lineno=1042)
|
| 410 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$x12.3']
|
| 411 |
+
DEBUG:numba.core.byteflow:dispatch pc=16, inst=BINARY_SUBSCR(arg=None, lineno=1042)
|
| 412 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$x12.3', '$const14.4']
|
| 413 |
+
DEBUG:numba.core.byteflow:dispatch pc=20, inst=COMPARE_OP(arg=68, lineno=1042)
|
| 414 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$16binary_subscr.5']
|
| 415 |
+
DEBUG:numba.core.byteflow:dispatch pc=24, inst=LOAD_FAST(arg=0, lineno=1042)
|
| 416 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6']
|
| 417 |
+
DEBUG:numba.core.byteflow:dispatch pc=26, inst=LOAD_CONST(arg=1, lineno=1042)
|
| 418 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$x24.7']
|
| 419 |
+
DEBUG:numba.core.byteflow:dispatch pc=28, inst=BINARY_SUBSCR(arg=None, lineno=1042)
|
| 420 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$x24.7', '$const26.8']
|
| 421 |
+
DEBUG:numba.core.byteflow:dispatch pc=32, inst=LOAD_FAST(arg=0, lineno=1042)
|
| 422 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9']
|
| 423 |
+
DEBUG:numba.core.byteflow:dispatch pc=34, inst=LOAD_CONST(arg=3, lineno=1042)
|
| 424 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$x32.10']
|
| 425 |
+
DEBUG:numba.core.byteflow:dispatch pc=36, inst=BINARY_SUBSCR(arg=None, lineno=1042)
|
| 426 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$x32.10', '$const34.11']
|
| 427 |
+
DEBUG:numba.core.byteflow:dispatch pc=40, inst=COMPARE_OP(arg=92, lineno=1042)
|
| 428 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$36binary_subscr.12']
|
| 429 |
+
DEBUG:numba.core.byteflow:dispatch pc=44, inst=BINARY_OP(arg=1, lineno=1042)
|
| 430 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$40compare_op.13']
|
| 431 |
+
DEBUG:numba.core.byteflow:dispatch pc=48, inst=RETURN_VALUE(arg=None, lineno=1042)
|
| 432 |
+
DEBUG:numba.core.byteflow:stack ['$binop_and_44.14']
|
| 433 |
+
DEBUG:numba.core.byteflow:end state. edges=[]
|
| 434 |
+
DEBUG:numba.core.byteflow:-------------------------Prune PHIs-------------------------
|
| 435 |
+
DEBUG:numba.core.byteflow:Used_phis: defaultdict(<class 'set'>, {State(pc_initial=0 nstack_initial=0): set()})
|
| 436 |
+
DEBUG:numba.core.byteflow:defmap: {}
|
| 437 |
+
DEBUG:numba.core.byteflow:phismap: defaultdict(<class 'set'>, {})
|
| 438 |
+
DEBUG:numba.core.byteflow:changing phismap: defaultdict(<class 'set'>, {})
|
| 439 |
+
DEBUG:numba.core.byteflow:keep phismap: {}
|
| 440 |
+
DEBUG:numba.core.byteflow:new_out: defaultdict(<class 'dict'>, {})
|
| 441 |
+
DEBUG:numba.core.byteflow:----------------------DONE Prune PHIs-----------------------
|
| 442 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=0 nstack_initial=0):
|
| 443 |
+
AdaptBlockInfo(insts=((0, {}), (2, {}), (4, {'res': '$x4.0'}), (6, {'res': '$const6.1'}), (8, {'index': '$const6.1', 'target': '$x4.0', 'res': '$8binary_subscr.2'}), (12, {'res': '$x12.3'}), (14, {'res': '$const14.4'}), (16, {'index': '$const14.4', 'target': '$x12.3', 'res': '$16binary_subscr.5'}), (20, {'lhs': '$8binary_subscr.2', 'rhs': '$16binary_subscr.5', 'res': '$20compare_op.6'}), (24, {'res': '$x24.7'}), (26, {'res': '$const26.8'}), (28, {'index': '$const26.8', 'target': '$x24.7', 'res': '$28binary_subscr.9'}), (32, {'res': '$x32.10'}), (34, {'res': '$const34.11'}), (36, {'index': '$const34.11', 'target': '$x32.10', 'res': '$36binary_subscr.12'}), (40, {'lhs': '$28binary_subscr.9', 'rhs': '$36binary_subscr.12', 'res': '$40compare_op.13'}), (44, {'op': '&', 'lhs': '$20compare_op.6', 'rhs': '$40compare_op.13', 'res': '$binop_and_44.14'}), (48, {'retval': '$binop_and_44.14', 'castval': '$48return_value.15'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={})
|
| 444 |
+
DEBUG:numba.core.interpreter:label 0:
|
| 445 |
+
x = arg(0, name=x) ['x']
|
| 446 |
+
$const6.1 = const(int, 0) ['$const6.1']
|
| 447 |
+
$8binary_subscr.2 = getitem(value=x, index=$const6.1, fn=<built-in function getitem>) ['$8binary_subscr.2', '$const6.1', 'x']
|
| 448 |
+
$const14.4 = const(int, -1) ['$const14.4']
|
| 449 |
+
$16binary_subscr.5 = getitem(value=x, index=$const14.4, fn=<built-in function getitem>) ['$16binary_subscr.5', '$const14.4', 'x']
|
| 450 |
+
$20compare_op.6 = $8binary_subscr.2 > $16binary_subscr.5 ['$16binary_subscr.5', '$20compare_op.6', '$8binary_subscr.2']
|
| 451 |
+
$const26.8 = const(int, 0) ['$const26.8']
|
| 452 |
+
$28binary_subscr.9 = getitem(value=x, index=$const26.8, fn=<built-in function getitem>) ['$28binary_subscr.9', '$const26.8', 'x']
|
| 453 |
+
$const34.11 = const(int, 1) ['$const34.11']
|
| 454 |
+
$36binary_subscr.12 = getitem(value=x, index=$const34.11, fn=<built-in function getitem>) ['$36binary_subscr.12', '$const34.11', 'x']
|
| 455 |
+
$40compare_op.13 = $28binary_subscr.9 >= $36binary_subscr.12 ['$28binary_subscr.9', '$36binary_subscr.12', '$40compare_op.13']
|
| 456 |
+
$binop_and_44.14 = $20compare_op.6 & $40compare_op.13 ['$20compare_op.6', '$40compare_op.13', '$binop_and_44.14']
|
| 457 |
+
$48return_value.15 = cast(value=$binop_and_44.14) ['$48return_value.15', '$binop_and_44.14']
|
| 458 |
+
return $48return_value.15 ['$48return_value.15']
|
| 459 |
+
|
| 460 |
+
DEBUG:numba.core.byteflow:bytecode dump:
|
| 461 |
+
> 0 NOP(arg=None, lineno=1045)
|
| 462 |
+
2 RESUME(arg=0, lineno=1045)
|
| 463 |
+
4 LOAD_FAST(arg=0, lineno=1048)
|
| 464 |
+
6 LOAD_CONST(arg=1, lineno=1048)
|
| 465 |
+
8 BINARY_SUBSCR(arg=None, lineno=1048)
|
| 466 |
+
12 LOAD_FAST(arg=0, lineno=1048)
|
| 467 |
+
14 LOAD_CONST(arg=2, lineno=1048)
|
| 468 |
+
16 BINARY_SUBSCR(arg=None, lineno=1048)
|
| 469 |
+
20 COMPARE_OP(arg=2, lineno=1048)
|
| 470 |
+
24 LOAD_FAST(arg=0, lineno=1048)
|
| 471 |
+
26 LOAD_CONST(arg=1, lineno=1048)
|
| 472 |
+
28 BINARY_SUBSCR(arg=None, lineno=1048)
|
| 473 |
+
32 LOAD_FAST(arg=0, lineno=1048)
|
| 474 |
+
34 LOAD_CONST(arg=3, lineno=1048)
|
| 475 |
+
36 BINARY_SUBSCR(arg=None, lineno=1048)
|
| 476 |
+
40 COMPARE_OP(arg=26, lineno=1048)
|
| 477 |
+
44 BINARY_OP(arg=1, lineno=1048)
|
| 478 |
+
48 RETURN_VALUE(arg=None, lineno=1048)
|
| 479 |
+
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
|
| 480 |
+
DEBUG:numba.core.byteflow:stack: []
|
| 481 |
+
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
|
| 482 |
+
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1045)
|
| 483 |
+
DEBUG:numba.core.byteflow:stack []
|
| 484 |
+
DEBUG:numba.core.byteflow:dispatch pc=2, inst=RESUME(arg=0, lineno=1045)
|
| 485 |
+
DEBUG:numba.core.byteflow:stack []
|
| 486 |
+
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=0, lineno=1048)
|
| 487 |
+
DEBUG:numba.core.byteflow:stack []
|
| 488 |
+
DEBUG:numba.core.byteflow:dispatch pc=6, inst=LOAD_CONST(arg=1, lineno=1048)
|
| 489 |
+
DEBUG:numba.core.byteflow:stack ['$x4.0']
|
| 490 |
+
DEBUG:numba.core.byteflow:dispatch pc=8, inst=BINARY_SUBSCR(arg=None, lineno=1048)
|
| 491 |
+
DEBUG:numba.core.byteflow:stack ['$x4.0', '$const6.1']
|
| 492 |
+
DEBUG:numba.core.byteflow:dispatch pc=12, inst=LOAD_FAST(arg=0, lineno=1048)
|
| 493 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2']
|
| 494 |
+
DEBUG:numba.core.byteflow:dispatch pc=14, inst=LOAD_CONST(arg=2, lineno=1048)
|
| 495 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$x12.3']
|
| 496 |
+
DEBUG:numba.core.byteflow:dispatch pc=16, inst=BINARY_SUBSCR(arg=None, lineno=1048)
|
| 497 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$x12.3', '$const14.4']
|
| 498 |
+
DEBUG:numba.core.byteflow:dispatch pc=20, inst=COMPARE_OP(arg=2, lineno=1048)
|
| 499 |
+
DEBUG:numba.core.byteflow:stack ['$8binary_subscr.2', '$16binary_subscr.5']
|
| 500 |
+
DEBUG:numba.core.byteflow:dispatch pc=24, inst=LOAD_FAST(arg=0, lineno=1048)
|
| 501 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6']
|
| 502 |
+
DEBUG:numba.core.byteflow:dispatch pc=26, inst=LOAD_CONST(arg=1, lineno=1048)
|
| 503 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$x24.7']
|
| 504 |
+
DEBUG:numba.core.byteflow:dispatch pc=28, inst=BINARY_SUBSCR(arg=None, lineno=1048)
|
| 505 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$x24.7', '$const26.8']
|
| 506 |
+
DEBUG:numba.core.byteflow:dispatch pc=32, inst=LOAD_FAST(arg=0, lineno=1048)
|
| 507 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9']
|
| 508 |
+
DEBUG:numba.core.byteflow:dispatch pc=34, inst=LOAD_CONST(arg=3, lineno=1048)
|
| 509 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$x32.10']
|
| 510 |
+
DEBUG:numba.core.byteflow:dispatch pc=36, inst=BINARY_SUBSCR(arg=None, lineno=1048)
|
| 511 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$x32.10', '$const34.11']
|
| 512 |
+
DEBUG:numba.core.byteflow:dispatch pc=40, inst=COMPARE_OP(arg=26, lineno=1048)
|
| 513 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$28binary_subscr.9', '$36binary_subscr.12']
|
| 514 |
+
DEBUG:numba.core.byteflow:dispatch pc=44, inst=BINARY_OP(arg=1, lineno=1048)
|
| 515 |
+
DEBUG:numba.core.byteflow:stack ['$20compare_op.6', '$40compare_op.13']
|
| 516 |
+
DEBUG:numba.core.byteflow:dispatch pc=48, inst=RETURN_VALUE(arg=None, lineno=1048)
|
| 517 |
+
DEBUG:numba.core.byteflow:stack ['$binop_and_44.14']
|
| 518 |
+
DEBUG:numba.core.byteflow:end state. edges=[]
|
| 519 |
+
DEBUG:numba.core.byteflow:-------------------------Prune PHIs-------------------------
|
| 520 |
+
DEBUG:numba.core.byteflow:Used_phis: defaultdict(<class 'set'>, {State(pc_initial=0 nstack_initial=0): set()})
|
| 521 |
+
DEBUG:numba.core.byteflow:defmap: {}
|
| 522 |
+
DEBUG:numba.core.byteflow:phismap: defaultdict(<class 'set'>, {})
|
| 523 |
+
DEBUG:numba.core.byteflow:changing phismap: defaultdict(<class 'set'>, {})
|
| 524 |
+
DEBUG:numba.core.byteflow:keep phismap: {}
|
| 525 |
+
DEBUG:numba.core.byteflow:new_out: defaultdict(<class 'dict'>, {})
|
| 526 |
+
DEBUG:numba.core.byteflow:----------------------DONE Prune PHIs-----------------------
|
| 527 |
+
DEBUG:numba.core.byteflow:block_infos State(pc_initial=0 nstack_initial=0):
|
| 528 |
+
AdaptBlockInfo(insts=((0, {}), (2, {}), (4, {'res': '$x4.0'}), (6, {'res': '$const6.1'}), (8, {'index': '$const6.1', 'target': '$x4.0', 'res': '$8binary_subscr.2'}), (12, {'res': '$x12.3'}), (14, {'res': '$const14.4'}), (16, {'index': '$const14.4', 'target': '$x12.3', 'res': '$16binary_subscr.5'}), (20, {'lhs': '$8binary_subscr.2', 'rhs': '$16binary_subscr.5', 'res': '$20compare_op.6'}), (24, {'res': '$x24.7'}), (26, {'res': '$const26.8'}), (28, {'index': '$const26.8', 'target': '$x24.7', 'res': '$28binary_subscr.9'}), (32, {'res': '$x32.10'}), (34, {'res': '$const34.11'}), (36, {'index': '$const34.11', 'target': '$x32.10', 'res': '$36binary_subscr.12'}), (40, {'lhs': '$28binary_subscr.9', 'rhs': '$36binary_subscr.12', 'res': '$40compare_op.13'}), (44, {'op': '&', 'lhs': '$20compare_op.6', 'rhs': '$40compare_op.13', 'res': '$binop_and_44.14'}), (48, {'retval': '$binop_and_44.14', 'castval': '$48return_value.15'})), outgoing_phis={}, blockstack=(), active_try_block=None, outgoing_edgepushed={})
|
| 529 |
+
DEBUG:numba.core.interpreter:label 0:
|
| 530 |
+
x = arg(0, name=x) ['x']
|
| 531 |
+
$const6.1 = const(int, 0) ['$const6.1']
|
| 532 |
+
$8binary_subscr.2 = getitem(value=x, index=$const6.1, fn=<built-in function getitem>) ['$8binary_subscr.2', '$const6.1', 'x']
|
| 533 |
+
$const14.4 = const(int, -1) ['$const14.4']
|
| 534 |
+
$16binary_subscr.5 = getitem(value=x, index=$const14.4, fn=<built-in function getitem>) ['$16binary_subscr.5', '$const14.4', 'x']
|
| 535 |
+
$20compare_op.6 = $8binary_subscr.2 < $16binary_subscr.5 ['$16binary_subscr.5', '$20compare_op.6', '$8binary_subscr.2']
|
| 536 |
+
$const26.8 = const(int, 0) ['$const26.8']
|
| 537 |
+
$28binary_subscr.9 = getitem(value=x, index=$const26.8, fn=<built-in function getitem>) ['$28binary_subscr.9', '$const26.8', 'x']
|
| 538 |
+
$const34.11 = const(int, 1) ['$const34.11']
|
| 539 |
+
$36binary_subscr.12 = getitem(value=x, index=$const34.11, fn=<built-in function getitem>) ['$36binary_subscr.12', '$const34.11', 'x']
|
| 540 |
+
$40compare_op.13 = $28binary_subscr.9 <= $36binary_subscr.12 ['$28binary_subscr.9', '$36binary_subscr.12', '$40compare_op.13']
|
| 541 |
+
$binop_and_44.14 = $20compare_op.6 & $40compare_op.13 ['$20compare_op.6', '$40compare_op.13', '$binop_and_44.14']
|
| 542 |
+
$48return_value.15 = cast(value=$binop_and_44.14) ['$48return_value.15', '$binop_and_44.14']
|
| 543 |
+
return $48return_value.15 ['$48return_value.15']
|
| 544 |
+
|
| 545 |
+
0%| | 14/43233 [00:59<27:10:42, 2.26s/it]Traceback (most recent call last):
|
| 546 |
+
{'loss': 6.6172, 'grad_norm': 10.827251434326172, 'learning_rate': 3.8550501156515035e-08, 'epoch': 0.0}
|
| 547 |
+
{'loss': 5.794, 'grad_norm': 14.017024040222168, 'learning_rate': 7.710100231303007e-08, 'epoch': 0.0}
|
| 548 |
+
{'loss': 6.8788, 'grad_norm': 13.020977020263672, 'learning_rate': 1.1565150346954511e-07, 'epoch': 0.0}
|
| 549 |
+
{'loss': 6.6162, 'grad_norm': 18.2950439453125, 'learning_rate': 1.5420200462606014e-07, 'epoch': 0.0}
|
| 550 |
+
{'loss': 6.9646, 'grad_norm': 14.263402938842773, 'learning_rate': 1.9275250578257518e-07, 'epoch': 0.0}
|
| 551 |
+
{'loss': 6.631, 'grad_norm': 13.121792793273926, 'learning_rate': 2.3130300693909022e-07, 'epoch': 0.0}
|
| 552 |
+
{'loss': 7.2093, 'grad_norm': 18.358381271362305, 'learning_rate': 2.6985350809560526e-07, 'epoch': 0.0}
|
| 553 |
+
{'loss': 7.7485, 'grad_norm': 24.542631149291992, 'learning_rate': 3.084040092521203e-07, 'epoch': 0.0}
|
| 554 |
+
{'loss': 6.2489, 'grad_norm': 12.371420860290527, 'learning_rate': 3.469545104086353e-07, 'epoch': 0.0}
|
| 555 |
+
{'loss': 6.6981, 'grad_norm': 22.148744583129883, 'learning_rate': 3.8550501156515036e-07, 'epoch': 0.0}
|
| 556 |
+
{'loss': 6.9253, 'grad_norm': 25.32149314880371, 'learning_rate': 4.240555127216654e-07, 'epoch': 0.0}
|
| 557 |
+
{'loss': 6.7832, 'grad_norm': 36.084407806396484, 'learning_rate': 4.6260601387818044e-07, 'epoch': 0.0}
|
| 558 |
+
{'loss': 6.9372, 'grad_norm': 12.946453094482422, 'learning_rate': 5.011565150346955e-07, 'epoch': 0.0}
|
| 559 |
+
{'loss': 6.6468, 'grad_norm': 14.272549629211426, 'learning_rate': 5.397070161912105e-07, 'epoch': 0.0}
|
scripts/wandb/latest-run/files/requirements.txt
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tomlkit==0.12.0
|
| 2 |
+
python-dotenv==1.0.1
|
| 3 |
+
SQLAlchemy==2.0.36
|
| 4 |
+
psutil==6.1.0
|
| 5 |
+
anyio==4.8.0
|
| 6 |
+
onnxruntime==1.20.1
|
| 7 |
+
antlr4-python3-runtime==4.9.3
|
| 8 |
+
httpx-sse==0.4.0
|
| 9 |
+
annotated-types==0.7.0
|
| 10 |
+
tqdm==4.66.5
|
| 11 |
+
simplejson==3.19.3
|
| 12 |
+
csvw==3.5.1
|
| 13 |
+
pooch==1.8.2
|
| 14 |
+
trl==0.9.6
|
| 15 |
+
more-itertools==10.5.0
|
| 16 |
+
jiter==0.6.1
|
| 17 |
+
markdown2==2.5.1
|
| 18 |
+
segments==2.2.1
|
| 19 |
+
opentelemetry-instrumentation-asgi==0.50b0
|
| 20 |
+
Deprecated==1.2.15
|
| 21 |
+
pyasn1_modules==0.4.1
|
| 22 |
+
bcrypt==4.2.1
|
| 23 |
+
opentelemetry-util-http==0.50b0
|
| 24 |
+
intervaltree==3.1.0
|
| 25 |
+
hjson==3.1.0
|
| 26 |
+
modelscope==1.18.1
|
| 27 |
+
fastapi==0.112.4
|
| 28 |
+
pyarrow==17.0.0
|
| 29 |
+
sounddevice==0.5.1
|
| 30 |
+
modelscope_studio==0.4.0.9
|
| 31 |
+
build==1.2.2.post1
|
| 32 |
+
oauthlib==3.2.2
|
| 33 |
+
gunicorn==23.0.0
|
| 34 |
+
pyasn1==0.6.1
|
| 35 |
+
matplotlib==3.9.2
|
| 36 |
+
speechbrain==0.5.16
|
| 37 |
+
joblib==1.4.2
|
| 38 |
+
tyro==0.8.13
|
| 39 |
+
rsa==4.9
|
| 40 |
+
numba==0.60.0
|
| 41 |
+
fastprogress==1.0.3
|
| 42 |
+
wrapt==1.17.0
|
| 43 |
+
PyPika==0.48.9
|
| 44 |
+
dacite==1.8.1
|
| 45 |
+
googleapis-common-protos==1.66.0
|
| 46 |
+
openai==1.68.0
|
| 47 |
+
tabulate==0.9.0
|
| 48 |
+
monotonic==1.6
|
| 49 |
+
lazy_loader==0.4
|
| 50 |
+
google-auth==2.37.0
|
| 51 |
+
fairseq==0.12.3
|
| 52 |
+
opentelemetry-semantic-conventions==0.50b0
|
| 53 |
+
sacrebleu==2.4.3
|
| 54 |
+
requests-toolbelt==1.0.0
|
| 55 |
+
ruff==0.7.0
|
| 56 |
+
bitsandbytes==0.43.1
|
| 57 |
+
tenacity==9.0.0
|
| 58 |
+
uvloop==0.21.0
|
| 59 |
+
Pygments==2.18.0
|
| 60 |
+
langchain==0.3.18
|
| 61 |
+
typer==0.12.5
|
| 62 |
+
uritemplate==4.1.1
|
| 63 |
+
rich==13.9.3
|
| 64 |
+
lion-pytorch==0.2.3
|
| 65 |
+
pydub==0.25.1
|
| 66 |
+
fastcore==1.7.28
|
| 67 |
+
encodec==0.1.1
|
| 68 |
+
cytoolz==1.0.1
|
| 69 |
+
huggingface-hub==0.26.1
|
| 70 |
+
python-dateutil==2.9.0.post0
|
| 71 |
+
duckduckgo_search==7.3.2
|
| 72 |
+
rfc3986==1.5.0
|
| 73 |
+
wavedrom==2.0.3.post3
|
| 74 |
+
sentence-transformers==3.3.1
|
| 75 |
+
httpx==0.28.1
|
| 76 |
+
colorlog==6.9.0
|
| 77 |
+
xxhash==3.5.0
|
| 78 |
+
termcolor==2.5.0
|
| 79 |
+
importlib_resources==6.4.5
|
| 80 |
+
lilcom==1.8.1
|
| 81 |
+
llamafactory==0.9.2.dev0
|
| 82 |
+
lhotse==1.31.0
|
| 83 |
+
kiwisolver==1.4.7
|
| 84 |
+
watchfiles==1.0.3
|
| 85 |
+
marshmallow==3.23.1
|
| 86 |
+
overrides==7.7.0
|
| 87 |
+
langchain-text-splitters==0.3.6
|
| 88 |
+
lxml==5.3.0
|
| 89 |
+
blinker==1.8.2
|
| 90 |
+
whisper==1.1.10
|
| 91 |
+
triton==3.1.0
|
| 92 |
+
python-multipart==0.0.12
|
| 93 |
+
isodate==0.7.2
|
| 94 |
+
wandb==0.19.8
|
| 95 |
+
nvidia-ml-py==12.560.30
|
| 96 |
+
h11==0.14.0
|
| 97 |
+
zipp==3.20.2
|
| 98 |
+
transformers==4.45.0
|
| 99 |
+
websocket-client==1.8.0
|
| 100 |
+
opentelemetry-instrumentation==0.50b0
|
| 101 |
+
pydantic==2.9.2
|
| 102 |
+
latex2mathml==3.77.0
|
| 103 |
+
numpy-rms==0.4.2
|
| 104 |
+
opentelemetry-exporter-otlp-proto-grpc==1.29.0
|
| 105 |
+
humanfriendly==10.0
|
| 106 |
+
decorator==5.1.1
|
| 107 |
+
fonttools==4.54.1
|
| 108 |
+
fire==0.7.0
|
| 109 |
+
ninja==1.11.1.1
|
| 110 |
+
shortuuid==1.0.13
|
| 111 |
+
tiktoken==0.8.0
|
| 112 |
+
aliyun-python-sdk-kms==2.16.5
|
| 113 |
+
einops==0.8.0
|
| 114 |
+
threadpoolctl==3.5.0
|
| 115 |
+
docker-pycreds==0.4.0
|
| 116 |
+
Flask==3.0.3
|
| 117 |
+
opentelemetry-sdk==1.29.0
|
| 118 |
+
opentelemetry-exporter-otlp-proto-common==1.29.0
|
| 119 |
+
pylatexenc==2.10
|
| 120 |
+
orjson==3.10.10
|
| 121 |
+
durationpy==0.9
|
| 122 |
+
addict==2.4.0
|
| 123 |
+
py-cpuinfo==9.0.0
|
| 124 |
+
contourpy==1.3.0
|
| 125 |
+
crcmod==1.7
|
| 126 |
+
pydantic-settings==2.6.1
|
| 127 |
+
pyproject_hooks==1.2.0
|
| 128 |
+
future==1.0.0
|
| 129 |
+
jsonschema-specifications==2024.10.1
|
| 130 |
+
coloredlogs==15.0.1
|
| 131 |
+
timm==0.6.13
|
| 132 |
+
deepspeed==0.14.5
|
| 133 |
+
referencing==0.35.1
|
| 134 |
+
binpacking==1.5.2
|
| 135 |
+
peft==0.12.0
|
| 136 |
+
language-tags==1.2.0
|
| 137 |
+
speechtokenizer==1.0.1
|
| 138 |
+
shellingham==1.5.4
|
| 139 |
+
primp==0.12.1
|
| 140 |
+
tavily-python==0.5.1
|
| 141 |
+
uvicorn==0.32.0
|
| 142 |
+
opentelemetry-proto==1.29.0
|
| 143 |
+
typing-inspect==0.9.0
|
| 144 |
+
backoff==2.2.1
|
| 145 |
+
sortedcontainers==2.4.0
|
| 146 |
+
gitdb==4.0.12
|
| 147 |
+
aiofiles==23.2.1
|
| 148 |
+
jsonschema==4.23.0
|
| 149 |
+
svgwrite==1.4.3
|
| 150 |
+
protobuf==5.29.1
|
| 151 |
+
starlette==0.38.6
|
| 152 |
+
transformers-stream-generator==0.0.5
|
| 153 |
+
sentry-sdk==2.22.0
|
| 154 |
+
toolz==1.0.0
|
| 155 |
+
einops-exts==0.0.4
|
| 156 |
+
WhisperSpeech==0.8
|
| 157 |
+
hydra-core==1.3.2
|
| 158 |
+
portalocker==2.10.1
|
| 159 |
+
jieba==0.42.1
|
| 160 |
+
pandas==2.2.3
|
| 161 |
+
requests==2.32.3
|
| 162 |
+
flash-attn==2.6.3
|
| 163 |
+
msgpack==1.1.0
|
| 164 |
+
chroma-hnswlib==0.7.6
|
| 165 |
+
librosa==0.10.2.post1
|
| 166 |
+
sniffio==1.3.1
|
| 167 |
+
smmap==5.0.2
|
| 168 |
+
opentelemetry-api==1.29.0
|
| 169 |
+
websockets==14.2
|
| 170 |
+
kubernetes==31.0.0
|
| 171 |
+
audioread==3.0.1
|
| 172 |
+
docstring_parser==0.16
|
| 173 |
+
scipy==1.12.0
|
| 174 |
+
aliyun-python-sdk-core==2.16.0
|
| 175 |
+
accelerate==1.0.0
|
| 176 |
+
dill==0.3.8
|
| 177 |
+
llama-omni==1.0.0
|
| 178 |
+
mdurl==0.1.2
|
| 179 |
+
chromadb==0.5.23
|
| 180 |
+
oss2==2.19.0
|
| 181 |
+
rdflib==7.1.1
|
| 182 |
+
bibtexparser==2.0.0b8
|
| 183 |
+
rpds-py==0.22.3
|
| 184 |
+
soundfile==0.12.1
|
| 185 |
+
langdetect==1.0.9
|
| 186 |
+
duckdb==1.2.0
|
| 187 |
+
numpy==1.26.3
|
| 188 |
+
dataclasses-json==0.6.7
|
| 189 |
+
tokenizers==0.20.3
|
| 190 |
+
cpm-kernels==1.0.11
|
| 191 |
+
einx==0.3.0
|
| 192 |
+
langchain-core==0.3.34
|
| 193 |
+
clldutils==3.24.0
|
| 194 |
+
openai-whisper==20240930
|
| 195 |
+
setuptools==69.5.1
|
| 196 |
+
requests-oauthlib==2.0.0
|
| 197 |
+
langchain-community==0.3.17
|
| 198 |
+
langsmith==0.2.3
|
| 199 |
+
colorama==0.4.6
|
| 200 |
+
omegaconf==2.3.0
|
| 201 |
+
asgiref==3.8.1
|
| 202 |
+
pydantic_core==2.23.4
|
| 203 |
+
ffmpy==0.4.0
|
| 204 |
+
multiprocess==0.70.16
|
| 205 |
+
mmh3==5.0.1
|
| 206 |
+
babel==2.16.0
|
| 207 |
+
phonemizer==3.3.0
|
| 208 |
+
pycryptodome==3.21.0
|
| 209 |
+
gradio==4.44.1
|
| 210 |
+
google-genai==1.5.0
|
| 211 |
+
tzdata==2024.2
|
| 212 |
+
llvmlite==0.43.0
|
| 213 |
+
cachetools==5.5.0
|
| 214 |
+
seaborn==0.13.2
|
| 215 |
+
httptools==0.6.4
|
| 216 |
+
GitPython==3.1.44
|
| 217 |
+
markdown-it-py==3.0.0
|
| 218 |
+
beartype==0.20.2
|
| 219 |
+
whisper_normalizer==0.0.10
|
| 220 |
+
dlinfo==1.2.1
|
| 221 |
+
vocos==0.1.0
|
| 222 |
+
itsdangerous==2.2.0
|
| 223 |
+
bitarray==3.0.0
|
| 224 |
+
opentelemetry-instrumentation-fastapi==0.50b0
|
| 225 |
+
setproctitle==1.3.5
|
| 226 |
+
cycler==0.12.1
|
| 227 |
+
vector-quantize-pytorch==1.18.5
|
| 228 |
+
jmespath==0.10.0
|
| 229 |
+
mypy-extensions==1.0.0
|
| 230 |
+
flatbuffers==24.3.25
|
| 231 |
+
scikit-learn==1.5.2
|
| 232 |
+
pytz==2024.2
|
| 233 |
+
pyparsing==3.2.0
|
| 234 |
+
posthog==3.7.4
|
| 235 |
+
rouge==1.0.1
|
| 236 |
+
semantic-version==2.10.0
|
| 237 |
+
httpcore==1.0.6
|
| 238 |
+
soxr==0.5.0.post1
|
| 239 |
+
importlib_metadata==8.5.0
|
| 240 |
+
audiomentations==0.36.1
|
| 241 |
+
shtab==1.7.1
|
| 242 |
+
Unidecode==1.3.8
|
| 243 |
+
click==8.1.8
|
| 244 |
+
tensorboardX==2.6.2.2
|
| 245 |
+
greenlet==3.1.1
|
| 246 |
+
nltk==3.9.1
|
| 247 |
+
gradio_client==1.3.0
|
| 248 |
+
datasets==2.21.0
|
| 249 |
+
attrdict==2.0.1
|
| 250 |
+
llamafactory==0.9.2.dev0
|
| 251 |
+
ms-swift==2.6.0.dev0
|
| 252 |
+
Brotli==1.0.9
|
| 253 |
+
Cython==3.0.10
|
| 254 |
+
HyperPyYAML==1.2.2
|
| 255 |
+
Markdown==3.6
|
| 256 |
+
MarkupSafe==2.1.3
|
| 257 |
+
PySocks==1.7.1
|
| 258 |
+
PyYAML==6.0.1
|
| 259 |
+
absl-py==2.1.0
|
| 260 |
+
aiohttp==3.9.5
|
| 261 |
+
aiosignal==1.3.1
|
| 262 |
+
anaconda-anon-usage==0.4.4
|
| 263 |
+
archspec==0.2.3
|
| 264 |
+
attrs==23.2.0
|
| 265 |
+
boltons==23.0.0
|
| 266 |
+
certifi==2024.6.2
|
| 267 |
+
cffi==1.16.0
|
| 268 |
+
charset-normalizer==2.0.4
|
| 269 |
+
click==8.1.7
|
| 270 |
+
conda==24.5.0
|
| 271 |
+
conda-content-trust==0.2.0
|
| 272 |
+
conda-libmamba-solver==24.1.0
|
| 273 |
+
conda-package-handling==2.2.0
|
| 274 |
+
conda_package_streaming==0.9.0
|
| 275 |
+
cryptography==42.0.5
|
| 276 |
+
distro==1.9.0
|
| 277 |
+
filelock==3.13.1
|
| 278 |
+
frozendict==2.4.2
|
| 279 |
+
frozenlist==1.4.1
|
| 280 |
+
fsspec==2024.6.0
|
| 281 |
+
grpcio==1.64.1
|
| 282 |
+
huggingface-hub==0.23.3
|
| 283 |
+
idna==3.7
|
| 284 |
+
Jinja2==3.1.4
|
| 285 |
+
jiwer==3.0.4
|
| 286 |
+
jsonargparse==4.29.0
|
| 287 |
+
jsonpatch==1.33
|
| 288 |
+
jsonpointer==2.1
|
| 289 |
+
kaldialign==0.9.1
|
| 290 |
+
libmambapy==1.5.8
|
| 291 |
+
lightning==2.2.5
|
| 292 |
+
lightning-utilities==0.11.2
|
| 293 |
+
llvmlite==0.42.0
|
| 294 |
+
menuinst==2.0.2
|
| 295 |
+
mkl-fft==1.3.8
|
| 296 |
+
mkl-random==1.2.4
|
| 297 |
+
mkl-service==2.4.0
|
| 298 |
+
mpmath==1.3.0
|
| 299 |
+
multidict==6.0.5
|
| 300 |
+
networkx==3.2.1
|
| 301 |
+
numba==0.59.1
|
| 302 |
+
numpy==1.26.4
|
| 303 |
+
packaging==23.2
|
| 304 |
+
pillow==10.3.0
|
| 305 |
+
pip==24.0
|
| 306 |
+
platformdirs==3.10.0
|
| 307 |
+
pluggy==1.0.0
|
| 308 |
+
protobuf==4.25.3
|
| 309 |
+
pycosat==0.6.6
|
| 310 |
+
pycparser==2.21
|
| 311 |
+
pytorch-lightning==2.2.5
|
| 312 |
+
rapidfuzz==3.9.3
|
| 313 |
+
regex==2024.5.15
|
| 314 |
+
requests==2.31.0
|
| 315 |
+
ruamel.yaml==0.18.6
|
| 316 |
+
ruamel.yaml.clib==0.2.8
|
| 317 |
+
safetensors==0.4.3
|
| 318 |
+
scipy==1.13.1
|
| 319 |
+
sentencepiece==0.2.0
|
| 320 |
+
setuptools==69.5.1
|
| 321 |
+
six==1.16.0
|
| 322 |
+
sympy==1.12
|
| 323 |
+
tensorboard==2.17.0
|
| 324 |
+
tensorboard-data-server==0.7.2
|
| 325 |
+
tokenizers==0.19.1
|
| 326 |
+
torch==2.2.1
|
| 327 |
+
torch-complex==0.4.3
|
| 328 |
+
torchaudio==2.2.1
|
| 329 |
+
torchmetrics==1.4.0.post0
|
| 330 |
+
torchvision==0.17.1
|
| 331 |
+
tqdm==4.66.2
|
| 332 |
+
transformers==4.41.2
|
| 333 |
+
truststore==0.8.0
|
| 334 |
+
typeguard==2.13.3
|
| 335 |
+
typing_extensions==4.11.0
|
| 336 |
+
urllib3==2.1.0
|
| 337 |
+
Werkzeug==3.0.3
|
| 338 |
+
wheel==0.43.0
|
| 339 |
+
yarl==1.9.4
|
| 340 |
+
zstandard==0.22.0
|
| 341 |
+
warprnnt_pytorch==0.1
|
scripts/wandb/latest-run/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-5.15.0-1029-nvidia-x86_64-with-glibc2.31",
|
| 3 |
+
"python": "CPython 3.12.3",
|
| 4 |
+
"startedAt": "2025-04-10T10:19:28.834922Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--local_rank=0",
|
| 7 |
+
"--deepspeed",
|
| 8 |
+
"zero2.json",
|
| 9 |
+
"--model_name_or_path",
|
| 10 |
+
"/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-7B-Instruct",
|
| 11 |
+
"--pretrained_llm_path",
|
| 12 |
+
"/data1/speech/anhnmt2/cuongnm/EOT/Qwen2.5-0.5B-Instruct",
|
| 13 |
+
"--tokenizer_path",
|
| 14 |
+
"/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6",
|
| 15 |
+
"--cache_dir",
|
| 16 |
+
"../output/cached_sft_20252502",
|
| 17 |
+
"--audio_encoder_path",
|
| 18 |
+
"/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/omni_speech/model/minicpmo/MiniCPM-o-2_6",
|
| 19 |
+
"--llm_type",
|
| 20 |
+
"qwen",
|
| 21 |
+
"--data_path",
|
| 22 |
+
"/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/train_asr_mixed_500k.jsonl",
|
| 23 |
+
"--eval_data_path",
|
| 24 |
+
"/data1/speech/anhnmt2/dataset/s2s/minicpmo/asr/dev_asr_mixed.jsonl",
|
| 25 |
+
"--config_path",
|
| 26 |
+
"minicpmp_config.json",
|
| 27 |
+
"--remove_unused_columns",
|
| 28 |
+
"false",
|
| 29 |
+
"--prediction_loss_only",
|
| 30 |
+
"false",
|
| 31 |
+
"--bf16",
|
| 32 |
+
"true",
|
| 33 |
+
"--do_train",
|
| 34 |
+
"--do_eval",
|
| 35 |
+
"--tune_speech",
|
| 36 |
+
"false",
|
| 37 |
+
"--tune_llm",
|
| 38 |
+
"false",
|
| 39 |
+
"--model_max_length",
|
| 40 |
+
"2048",
|
| 41 |
+
"--eval_steps",
|
| 42 |
+
"3000",
|
| 43 |
+
"--output_dir",
|
| 44 |
+
"../checkpoints/minicpmo_whisper-medium_Qwen2.5-0.5B_pretrained-asr-projector",
|
| 45 |
+
"--num_train_epochs",
|
| 46 |
+
"3",
|
| 47 |
+
"--logging_strategy",
|
| 48 |
+
"steps",
|
| 49 |
+
"--per_device_train_batch_size",
|
| 50 |
+
"8",
|
| 51 |
+
"--per_device_eval_batch_size",
|
| 52 |
+
"8",
|
| 53 |
+
"--gradient_accumulation_steps",
|
| 54 |
+
"4",
|
| 55 |
+
"--evaluation_strategy",
|
| 56 |
+
"steps",
|
| 57 |
+
"--save_strategy",
|
| 58 |
+
"steps",
|
| 59 |
+
"--save_steps",
|
| 60 |
+
"5000",
|
| 61 |
+
"--save_total_limit",
|
| 62 |
+
"1",
|
| 63 |
+
"--learning_rate",
|
| 64 |
+
"5e-5",
|
| 65 |
+
"--weight_decay",
|
| 66 |
+
"0.",
|
| 67 |
+
"--warmup_ratio",
|
| 68 |
+
"0.03",
|
| 69 |
+
"--lr_scheduler_type",
|
| 70 |
+
"cosine",
|
| 71 |
+
"--logging_steps",
|
| 72 |
+
"1",
|
| 73 |
+
"--tf32",
|
| 74 |
+
"true",
|
| 75 |
+
"--gradient_checkpointing",
|
| 76 |
+
"true"
|
| 77 |
+
],
|
| 78 |
+
"program": "/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/../omni_speech/train/train_minicpmo_test.py",
|
| 79 |
+
"codePath": "omni_speech/train/train_minicpmo_test.py",
|
| 80 |
+
"git": {
|
| 81 |
+
"remote": "https://bitbucket.org/vinbdi-slp/half-streaming-speech-nlp.git",
|
| 82 |
+
"commit": "3876ef3c080c3ca44ad5ea0bd316241f0323ada6"
|
| 83 |
+
},
|
| 84 |
+
"email": "cuong220103@gmail.com",
|
| 85 |
+
"root": "/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts",
|
| 86 |
+
"host": "dgx-a100-5",
|
| 87 |
+
"executable": "/opt/conda/bin/python3",
|
| 88 |
+
"cpu_count": 128,
|
| 89 |
+
"cpu_count_logical": 256,
|
| 90 |
+
"gpu": "NVIDIA A100-SXM4-40GB",
|
| 91 |
+
"gpu_count": 1,
|
| 92 |
+
"disk": {
|
| 93 |
+
"/": {
|
| 94 |
+
"total": "1900954378240",
|
| 95 |
+
"used": "286067507200"
|
| 96 |
+
}
|
| 97 |
+
},
|
| 98 |
+
"memory": {
|
| 99 |
+
"total": "1081975545856"
|
| 100 |
+
},
|
| 101 |
+
"cpu": {
|
| 102 |
+
"count": 128,
|
| 103 |
+
"countLogical": 256
|
| 104 |
+
},
|
| 105 |
+
"gpu_nvidia": [
|
| 106 |
+
{
|
| 107 |
+
"name": "NVIDIA A100-SXM4-40GB",
|
| 108 |
+
"memoryTotal": "42949672960",
|
| 109 |
+
"cudaCores": 6912,
|
| 110 |
+
"architecture": "Ampere"
|
| 111 |
+
}
|
| 112 |
+
],
|
| 113 |
+
"slurm": {
|
| 114 |
+
"cluster_name": "slurm",
|
| 115 |
+
"conf": "/cm/shared/apps/slurm/var/etc/slurm/slurm.conf",
|
| 116 |
+
"cpus_on_node": "24",
|
| 117 |
+
"cpus_per_task": "24",
|
| 118 |
+
"gpus_on_node": "1",
|
| 119 |
+
"gpus_per_node": "1",
|
| 120 |
+
"gtids": "0",
|
| 121 |
+
"job_cpus_per_node": "24",
|
| 122 |
+
"job_end_time": "1775042326",
|
| 123 |
+
"job_gid": "1400",
|
| 124 |
+
"job_group": "speech",
|
| 125 |
+
"job_id": "5154",
|
| 126 |
+
"job_name": "bash",
|
| 127 |
+
"job_nodelist": "dgx-a100-5",
|
| 128 |
+
"job_num_nodes": "1",
|
| 129 |
+
"job_partition": "defq",
|
| 130 |
+
"job_qos": "normal",
|
| 131 |
+
"job_start_time": "1743506326",
|
| 132 |
+
"job_uid": "1407",
|
| 133 |
+
"job_user": "anhnmt2",
|
| 134 |
+
"jobid": "5154",
|
| 135 |
+
"launch_node_ipaddr": "192.168.100.102",
|
| 136 |
+
"localid": "0",
|
| 137 |
+
"mpi_type": "pmix",
|
| 138 |
+
"nnodes": "1",
|
| 139 |
+
"nodeid": "0",
|
| 140 |
+
"nodelist": "dgx-a100-5",
|
| 141 |
+
"nprocs": "1",
|
| 142 |
+
"ntasks": "1",
|
| 143 |
+
"ntasks_per_node": "1",
|
| 144 |
+
"pmix_mapping_serv": "(vector,(0,1,1))",
|
| 145 |
+
"pmixp_abort_agent_port": "37119",
|
| 146 |
+
"prio_process": "0",
|
| 147 |
+
"procid": "0",
|
| 148 |
+
"pty_port": "45373",
|
| 149 |
+
"pty_win_col": "137",
|
| 150 |
+
"pty_win_row": "10",
|
| 151 |
+
"srun_comm_host": "192.168.100.102",
|
| 152 |
+
"srun_comm_port": "43475",
|
| 153 |
+
"step_gpus": "4",
|
| 154 |
+
"step_id": "0",
|
| 155 |
+
"step_launcher_port": "43475",
|
| 156 |
+
"step_nodelist": "dgx-a100-5",
|
| 157 |
+
"step_num_nodes": "1",
|
| 158 |
+
"step_num_tasks": "1",
|
| 159 |
+
"step_tasks_per_node": "1",
|
| 160 |
+
"stepid": "0",
|
| 161 |
+
"submit_dir": "/data1/speech/anhnmt2/ASR/speechgpt/slurm/submit",
|
| 162 |
+
"submit_host": "login-1",
|
| 163 |
+
"task_pid": "268175",
|
| 164 |
+
"tasks_per_node": "1",
|
| 165 |
+
"topology_addr": "dgx-a100-5",
|
| 166 |
+
"topology_addr_pattern": "node",
|
| 167 |
+
"umask": "0022",
|
| 168 |
+
"working_cluster": "slurm:bcm10-headnode:6817:9984:109"
|
| 169 |
+
},
|
| 170 |
+
"cudaVersion": "12.2"
|
| 171 |
+
}
|
scripts/wandb/latest-run/logs/debug-core.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-04-10T17:19:28.173097267+07:00","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpm4_vxj8m/port-1734298.txt","pid":1734298,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false}
|
| 2 |
+
{"time":"2025-04-10T17:19:28.173483898+07:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":44091,"Zone":""}}
|
| 3 |
+
{"time":"2025-04-10T17:19:28.173583196+07:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":1734298}
|
| 4 |
+
{"time":"2025-04-10T17:19:28.338675346+07:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:60304"}
|
| 5 |
+
{"time":"2025-04-10T17:19:28.838813222+07:00","level":"INFO","msg":"handleInformInit: received","streamId":"pfaibe0c","id":"127.0.0.1:60304"}
|
| 6 |
+
{"time":"2025-04-10T17:19:28.960357084+07:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"pfaibe0c","id":"127.0.0.1:60304"}
|
| 7 |
+
{"time":"2025-04-10T17:20:36.908864225+07:00","level":"INFO","msg":"received shutdown signal","signal":15}
|
scripts/wandb/latest-run/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-04-10T17:19:28.842729448+07:00","level":"INFO","msg":"stream: starting","core version":"0.19.8","symlink path":"/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/scripts/wandb/run-20250410_171928-pfaibe0c/logs/debug-core.log"}
|
| 2 |
+
{"time":"2025-04-10T17:19:28.960322418+07:00","level":"INFO","msg":"created new stream","id":"pfaibe0c"}
|
| 3 |
+
{"time":"2025-04-10T17:19:28.960351593+07:00","level":"INFO","msg":"stream: started","id":"pfaibe0c"}
|
| 4 |
+
{"time":"2025-04-10T17:19:28.960375959+07:00","level":"INFO","msg":"writer: Do: started","stream_id":"pfaibe0c"}
|
| 5 |
+
{"time":"2025-04-10T17:19:28.960456552+07:00","level":"INFO","msg":"handler: started","stream_id":"pfaibe0c"}
|
| 6 |
+
{"time":"2025-04-10T17:19:28.961574927+07:00","level":"INFO","msg":"sender: started","stream_id":"pfaibe0c"}
|
| 7 |
+
{"time":"2025-04-10T17:19:29.497777718+07:00","level":"INFO","msg":"Starting system monitor"}
|