|
|
import argparse |
|
|
import os |
|
|
import wave |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from huggingface_hub import snapshot_download |
|
|
from snac import SNAC |
|
|
|
|
|
|
|
|
def load_models(model_path: str, device: str = "cuda"):
    """Load the SNAC audio decoder, the causal LM and its tokenizer.

    Args:
        model_path: Local path or hub id of the (merged) language model.
        device: Torch device both models are moved to.

    Returns:
        ``(model, tokenizer, snac_model)`` — the LM in bfloat16 on ``device``,
        its tokenizer, and the ``hubertsiuzdak/snac_24khz`` SNAC model on
        ``device``.
    """
    print("Loading SNAC model...")
    decoder = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    print(f"Loading Orpheus model from: {model_path}")
    lm = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    ).to(device)

    # NOTE(review): `fix_mistral_regex` is forwarded to the tokenizer as-is;
    # confirm the installed transformers version accepts this keyword.
    tok = AutoTokenizer.from_pretrained(model_path, fix_mistral_regex=True)

    print(f"Models loaded on {device}")
    return lm, tok, decoder
|
|
|
|
|
|
|
|
def process_prompt(prompt: str, voice: str, tokenizer, device: str):
    """Tokenize ``prompt`` and wrap it in the special tokens the model expects.

    The text is prefixed with ``"— "`` and tokenized, then framed as::

        [128259 (SOH)] + text_ids + [128009 (EOT), 128260 (EOH)]

    NOTE(review): ``voice`` is accepted but NOT used. The reference app.py
    format is ``f"{voice}: {text}"``; confirm that the fine-tuned checkpoint
    really expects the ``"— "`` prefix instead.

    Args:
        prompt: Text to be spoken.
        voice: Speaker name (currently unused, see note above).
        tokenizer: Hugging Face tokenizer, called with ``return_tensors="pt"``.
        device: Device the returned tensors are moved to.

    Returns:
        ``(input_ids, attention_mask)``, both int64 tensors of shape
        ``(1, seq_len)``; the mask is all ones.
    """
    text_ids = tokenizer(f"— {prompt}", return_tensors="pt").input_ids

    prefix = torch.tensor([[128259]], dtype=torch.int64)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64)
    framed = torch.cat((prefix, text_ids, suffix), dim=1)

    mask = torch.ones_like(framed)
    return framed.to(device), mask.to(device)
|
|
|
|
|
|
|
|
def parse_output(generated_ids: torch.Tensor):
    """Extract SNAC code values from the tokens returned by ``model.generate``.

    Post-processing (same steps as app.py):
      1. keep only the tokens after the LAST occurrence of 128257
         (presumably the start-of-audio marker);
      2. drop every 128258 token (the ``eos_token_id`` used for generation);
      3. trim to a whole number of 7-value frames;
      4. subtract the 128266 token offset.

    Only the first sequence is processed. The marker is searched within that
    row, not across the whole batch.

    NOTE: if 128257 never appears, the whole row, including the prompt
    tokens, is used. That yields invalid (negative) codes.

    Args:
        generated_ids: Token ids of shape ``(batch, seq_len)``; a 1-D
            ``(seq_len,)`` tensor is accepted as well.

    Returns:
        A list of Python ints whose length is a multiple of 7.
    """
    marker_token = 128257
    removed_token = 128258
    code_offset = 128266

    row = generated_ids[0] if generated_ids.dim() > 1 else generated_ids

    positions = (row == marker_token).nonzero(as_tuple=True)[0]
    if positions.numel() > 0:
        row = row[positions[-1].item() + 1 :]

    row = row[row != removed_token]
    usable = (row.size(0) // 7) * 7

    # One bulk conversion to Python ints instead of one 0-d tensor per code
    # (each of which would be a separate tiny op / device sync downstream).
    return (row[:usable] - code_offset).tolist()
|
|
|
|
|
|
|
|
def redistribute_codes(code_list, snac_model: "SNAC"):
    """Split a flat SNAC code list into its three layers and decode it to audio.

    Every 7-value frame ``f`` first has ``k * 4096`` removed from slot ``k``.
    It is then distributed like app.py does:
      layer 1: f[0]
      layer 2: f[1], f[4]
      layer 3: f[2], f[3], f[5], f[6]

    Args:
        code_list: Flat sequence of codes (Python ints or 0-d tensors), e.g.
            the output of ``parse_output``. Trailing values that do not fill
            a complete 7-value frame are ignored.
        snac_model: SNAC decoder; the device of its first parameter is used.

    Returns:
        The decoded waveform as a squeezed NumPy array.

    Raises:
        ValueError: if ``code_list`` contains no complete frame, or a code is
            outside ``[0, 4096)`` once its slot offset is removed (assumes a
            codebook size of 4096, matching the per-slot offset — confirm
            against the SNAC config).
    """
    codebook_size = 4096
    codes = [int(c) for c in code_list]

    # Floor division: only complete frames (the old `(len + 1) // 7`
    # overran the list when its length was 6 mod 7).
    n_frames = len(codes) // 7
    if n_frames == 0:
        raise ValueError("code_list contains no complete 7-code frame")

    layer_1, layer_2, layer_3 = [], [], []
    for frame_idx in range(n_frames):
        base = 7 * frame_idx
        frame = [codes[base + k] - k * codebook_size for k in range(7)]
        for slot, value in enumerate(frame):
            if not 0 <= value < codebook_size:
                raise ValueError(
                    f"SNAC code {value} at frame {frame_idx}, slot {slot} "
                    f"is outside [0, {codebook_size})"
                )
        layer_1.append(frame[0])
        layer_2.extend((frame[1], frame[4]))
        layer_3.extend((frame[2], frame[3], frame[5], frame[6]))

    device = next(snac_model.parameters()).device
    layers = [
        torch.tensor(layer, dtype=torch.long, device=device).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]

    # Inference only: don't build an autograd graph through the decoder.
    with torch.no_grad():
        audio_hat = snac_model.decode(layers)
    return audio_hat.detach().squeeze().cpu().numpy()
|
|
|
|
|
|
|
|
def generate_speech_once(
    text: str,
    voice: str,
    model,
    tokenizer,
    snac_model,
    temperature: float = 0.8,
    top_p: float = 0.9,
    repetition_penalty: float = 1.05,
    max_new_tokens: int = 7500,
):
    """Run one sampling pass and decode it into a single audio clip.

    Mirrors app.py: one generation pass, one waveform.

    Args:
        text: Text to synthesize.
        voice: Forwarded to ``process_prompt``.
        model: Causal LM whose ``generate`` produces the audio tokens.
        tokenizer: Tokenizer matching ``model``.
        snac_model: SNAC decoder used to turn codes into samples.
        temperature, top_p, repetition_penalty, max_new_tokens: Sampling
            settings forwarded to ``model.generate``.

    Returns:
        ``(24000, samples)`` on success, or ``None`` when ``text`` is empty
        or whitespace only.
    """
    target_device = next(model.parameters()).device

    if not text.strip():
        return None

    prompt_ids, prompt_mask = process_prompt(text, voice, tokenizer, target_device)

    sampling_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_return_sequences=1,
        # Generation stops at token 128258, which parse_output also strips.
        eos_token_id=128258,
    )
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            **sampling_kwargs,
        )

    samples = redistribute_codes(parse_output(output_ids), snac_model)
    return 24000, samples
|
|
|
|
|
|
|
|
def save_wav(path: str, sr: int, audio: np.ndarray):
    """Write mono float audio to ``path`` as a 16-bit PCM WAV file.

    Samples are clipped to ``[-1.0, 1.0]`` and scaled by 32767 (the integer
    cast truncates toward zero). NaN samples are written as silence (0).

    Args:
        path: Destination file path; an existing file is overwritten.
        sr: Sample rate in Hz written to the WAV header.
        audio: Float samples, nominally in ``[-1.0, 1.0]``.
    """
    # NaN passes through np.clip unchanged and its int16 cast is undefined;
    # map it to silence. (+/-inf still ends up at +/-1.0 after clipping.)
    clipped = np.clip(np.nan_to_num(audio, nan=0.0), -1.0, 1.0)

    # WAV PCM samples are little-endian regardless of the host byte order.
    pcm = (clipped * 32767).astype("<i2")

    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sr)
        wf.writeframes(pcm.tobytes())
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, load models, synthesize, save.

    Exits with a usage error when ``--text`` is empty or whitespace only. This
    is checked before the models are loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Pfad zum gemergten Modell (z.B. checkpoints/merged)",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Text, der gesprochen werden soll",
    )
    parser.add_argument(
        "--voice",
        type=str,
        default="leo",
        help="",
    )
    parser.add_argument(
        "--outfile",
        type=str,
        default="output.wav",
        help="Ausgabedatei (WAV)",
    )
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1.1)
    parser.add_argument("--max_new_tokens", type=int, default=1200)

    args = parser.parse_args()

    # generate_speech_once returns None for blank text, which cannot be
    # unpacked into (sr, audio); reject it before the expensive model load.
    if not args.text.strip():
        parser.error("--text must not be empty")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer, snac_model = load_models(args.model_path, device=device)

    print("Generating speech...")
    result = generate_speech_once(
        text=args.text,
        voice=args.voice,
        model=model,
        tokenizer=tokenizer,
        snac_model=snac_model,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        max_new_tokens=args.max_new_tokens,
    )
    if result is None:
        raise SystemExit("No speech was generated for the given --text.")
    sr, audio = result

    # Create the output directory up front so a missing folder does not
    # discard an already-finished generation.
    out_dir = os.path.dirname(args.outfile)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"Saving to {args.outfile}")
    save_wav(args.outfile, sr, audio)
    print("Done.")
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|