Update README.md
Browse files
README.md
CHANGED
|
@@ -139,3 +139,153 @@ for i, code_list in enumerate(code_lists):
|
|
| 139 |
wavfile.write(filename, 24000, sample_np)
|
| 140 |
print(f"Saved audio to: {filename}")
|
| 141 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
wavfile.write(filename, 24000, sample_np)
|
| 140 |
print(f"Saved audio to: {filename}")
|
| 141 |
```
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
## Streaming sample
|
| 145 |
+
|
| 146 |
+
### Server side command
|
| 147 |
+
```
|
| 148 |
+
python3 -m vllm.entrypoints.openai.api_server --model VoiceCore_gptq --host 0.0.0.0 --port 8000 --max-model-len 9000
|
| 149 |
+
```
|
| 150 |
+
### Client side script
|
| 151 |
+
```
|
| 152 |
+
import torch
|
| 153 |
+
from transformers import AutoTokenizer
|
| 154 |
+
from snac import SNAC
|
| 155 |
+
import requests
|
| 156 |
+
import json
|
| 157 |
+
import sounddevice as sd
|
| 158 |
+
import numpy as np
|
| 159 |
+
import queue
|
| 160 |
+
import threading
|
| 161 |
+
|
| 162 |
+
# --- Server settings and model setup (unchanged) ---
SERVER_URL = "http://192.168.1.16:8000/v1/completions"
TOKENIZER_PATH = "webbigdata/VoiceCore_gptq"
MODEL_NAME = "VoiceCore_gptq"

# Texts to synthesize; the main loop sends one request per entry.
prompts = [
    "テストです",
    "ジーピーティーキュー、問題なく動いてますかね?圧縮しすぎると別人の声になっちゃう事があるんですよね、ふふふ",
]
# Voice label prepended to each prompt as "<voice>: <text>"; leave empty to send the text as is.
chosen_voice = "matsukaze_male[neutral]"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
# Token ids placed before and after the encoded prompt when building the request.
start_token = [128259]
end_tokens = [128009, 128260, 128261]

print("Loading SNAC decoder to CPU...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to("cpu")
print("SNAC model loaded.")
# Token id the client searches for; only tokens after it are decoded as audio.
audio_start_token = 128257
|
| 182 |
+
|
| 183 |
+
def redistribute_codes(code_list):
    """Split a flat list of codes into three layers and decode them to audio.

    The list is read in frames of 7 codes. Within each frame, position 0 goes
    to layer 1, positions 1 and 4 go to layer 2, and positions 2, 3, 5 and 6
    go to layer 3. Position ``k`` has ``k * 4096`` subtracted before decoding.

    Args:
        code_list: Sequence of integer codes; its length should be a
            non-zero multiple of 7.

    Returns:
        The tensor produced by ``snac_model.decode``, or an empty tensor when
        ``code_list`` is empty or its length is not a multiple of 7.
    """
    frame_count, leftover = divmod(len(code_list), 7)
    # Nothing decodable: callers check ``numel() > 0`` before queueing audio,
    # so an empty tensor is the "no audio" signal.
    if leftover or frame_count == 0:
        return torch.tensor([])

    layer_1, layer_2, layer_3 = [], [], []
    for start in range(0, frame_count * 7, 7):
        frame = code_list[start:start + 7]
        layer_1.append(frame[0])
        layer_2.extend((frame[1] - 4096, frame[4] - 4 * 4096))
        layer_3.extend((
            frame[2] - 2 * 4096,
            frame[3] - 3 * 4096,
            frame[5] - 5 * 4096,
            frame[6] - 6 * 4096,
        ))

    # Each layer becomes a (1, n) tensor, matching the original unsqueeze(0).
    codes = [torch.tensor(layer).unsqueeze(0) for layer in (layer_1, layer_2, layer_3)]
    # Decoding is inference only, so skip autograd bookkeeping.
    with torch.inference_mode():
        return snac_model.decode(codes)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def audio_playback_worker(q, stream):
    """Write queued items to ``stream`` until a ``None`` sentinel is received.

    Args:
        q: Queue to read from; ``q.get()`` blocks until an item is available.
        stream: Object whose ``write`` method is called with each item.
    """
    chunk = q.get()
    # Identity check on purpose: items may be arrays, for which ``==`` is elementwise.
    while chunk is not None:
        stream.write(chunk)
        chunk = q.get()
|
| 202 |
+
|
| 203 |
+
# Number of tokens decoded per audio chunk (four 7-code frames).
CHUNK_SIZE = 28
# Amount subtracted from each generated token id before decoding.
AUDIO_CODE_OFFSET = 128266


def _enqueue_audio(token_ids, audio_queue):
    """Decode token ids via redistribute_codes and queue any resulting samples."""
    samples = redistribute_codes([t - AUDIO_CODE_OFFSET for t in token_ids])
    if samples.numel() > 0:
        audio_queue.put(samples.detach().squeeze().numpy())


# For each prompt: stream a completion from the server, decode the generated
# tokens in CHUNK_SIZE groups, and play the audio on a background thread while
# the rest of the response is still arriving.
for i, prompt in enumerate(prompts):
    print("\n" + "=" * 50)
    print(f"Processing prompt ({i+1}/{len(prompts)}): '{prompt}'")
    print("=" * 50)

    prompt_ = (f"{chosen_voice}: " + prompt) if chosen_voice else prompt
    input_ids = tokenizer.encode(prompt_)
    final_token_ids = start_token + input_ids + end_tokens

    payload = {
        "model": MODEL_NAME, "prompt": final_token_ids,
        "max_tokens": 8192, "temperature": 0.6, "top_p": 0.90,
        "repetition_penalty": 1.1, "stop_token_ids": [128258],
        "stream": True
    }

    token_buffer = []
    found_audio_start = False

    audio_queue = queue.Queue()
    playback_stream = sd.OutputStream(samplerate=24000, channels=1, dtype='float32')
    playback_stream.start()

    playback_thread = threading.Thread(target=audio_playback_worker, args=(audio_queue, playback_stream))
    playback_thread.start()

    try:
        # The context manager closes the streamed connection even when the
        # loop exits early or an error occurs.
        with requests.post(
            SERVER_URL,
            headers={"Content-Type": "application/json"},
            json=payload,
            stream=True,
        ) as response:
            response.raise_for_status()

            for line in response.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode('utf-8')
                if not decoded_line.startswith('data: '):
                    continue
                content = decoded_line[6:]
                if content == '[DONE]':
                    break

                try:
                    chunk = json.loads(content)
                    text_chunk = chunk['choices'][0]['text']
                    if text_chunk:
                        # The server streams text, so it is re-encoded to recover token ids.
                        token_buffer.extend(tokenizer.encode(text_chunk, add_special_tokens=False))

                    if not found_audio_start:
                        try:
                            start_index = token_buffer.index(audio_start_token)
                        except ValueError:
                            # Audio has not started yet; wait for more data.
                            continue
                        token_buffer = token_buffer[start_index + 1:]
                        found_audio_start = True
                        print("Audio start token found. Starting playback...")

                    while len(token_buffer) >= CHUNK_SIZE:
                        tokens_to_process = token_buffer[:CHUNK_SIZE]
                        token_buffer = token_buffer[CHUNK_SIZE:]
                        _enqueue_audio(tokens_to_process, audio_queue)

                except Exception as e:
                    # Best effort: report a bad chunk (malformed JSON, decode
                    # failure, ...) and keep consuming the stream.
                    print(f"処理中にエラー: {e}")

        # Flush whatever complete 7-code frames remain after the stream ends.
        if found_audio_start and token_buffer:
            remaining_length = (len(token_buffer) // 7) * 7
            if remaining_length > 0:
                _enqueue_audio(token_buffer[:remaining_length], audio_queue)

    except requests.exceptions.RequestException as e:
        print(f"サーバーへのリクエストでエラーが発生しました: {e}")
    finally:
        # The None sentinel tells the playback worker to exit once the queue drains.
        audio_queue.put(None)
        playback_thread.join()
        playback_stream.stop()
        playback_stream.close()
        print("Playback finished for this prompt.")

print("\nAll processing complete!")
|
| 291 |
+
```
|