import argparse import os import wave import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer from huggingface_hub import snapshot_download from snac import SNAC def load_models(model_path: str, device: str = "cuda"): # SNAC-Audiodekoder print("Loading SNAC model...") snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz") snac_model = snac_model.to(device) # LLM-TTS-Modell (dein gemergter Orpheus-Checkpoint) print(f"Loading Orpheus model from: {model_path}") model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, ).to(device) # Tokenizer tokenizer = AutoTokenizer.from_pretrained(model_path,fix_mistral_regex=True) print(f"Models loaded on {device}") return model, tokenizer, snac_model def process_prompt(prompt: str, voice: str, tokenizer, device: str): """ 1:1 die Logik aus app.py: - voice + ": " + text - SOH (128259) vornedran - EOT (128009) + EOH (128260) hinten dran """ prompt = f"— {prompt}" input_ids = tokenizer(prompt, return_tensors="pt").input_ids start_token = torch.tensor([[128259]], dtype=torch.int64) # SOH end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # EOT, EOH modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1) attention_mask = torch.ones_like(modified_input_ids) return modified_input_ids.to(device), attention_mask.to(device) def parse_output(generated_ids: torch.Tensor): """ 1:1 aus app.py: - nach Token 128257 schneiden - 128258 entfernen - Codes in 7er-Gruppen trimmen - 128266 abziehen """ token_to_find = 128257 token_to_remove = 128258 token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True) if len(token_indices[1]) > 0: last_occurrence_idx = token_indices[1][-1].item() cropped_tensor = generated_ids[:, last_occurrence_idx + 1 :] else: cropped_tensor = generated_ids processed_rows = [] for row in cropped_tensor: masked_row = row[row != token_to_remove] processed_rows.append(masked_row) code_lists = [] for row in processed_rows: row_length = row.size(0) new_length = (row_length // 7) * 7 trimmed_row = row[:new_length] trimmed_row = [t - 128266 for t in trimmed_row] code_lists.append(trimmed_row) return code_lists[0] def redistribute_codes(code_list, snac_model: SNAC): """ Ebenfalls 1:1 aus app.py – SNAC-Code in Ebenen splitten und dekodieren. """ device = next(snac_model.parameters()).device layer_1 = [] layer_2 = [] layer_3 = [] for i in range((len(code_list) + 1) // 7): layer_1.append(code_list[7 * i]) layer_2.append(code_list[7 * i + 1] - 4096) layer_3.append(code_list[7 * i + 2] - (2 * 4096)) layer_3.append(code_list[7 * i + 3] - (3 * 4096)) layer_2.append(code_list[7 * i + 4] - (4 * 4096)) layer_3.append(code_list[7 * i + 5] - (5 * 4096)) layer_3.append(code_list[7 * i + 6] - (6 * 4096)) codes = [ torch.tensor(layer_1, device=device).unsqueeze(0), torch.tensor(layer_2, device=device).unsqueeze(0), torch.tensor(layer_3, device=device).unsqueeze(0), ] audio_hat = snac_model.decode(codes) return audio_hat.detach().squeeze().cpu().numpy() def generate_speech_once( text: str, voice: str, model, tokenizer, snac_model, temperature: float = 0.8, top_p: float = 0.9, repetition_penalty: float = 1.05, #temperature: float = 0.7, # Some testing for best "Thorsten" experience ;-) #top_p: float = 0.97, #repetition_penalty: float = 1.2, #max_new_tokens: int = 1200, max_new_tokens: int = 7500, ): """ Exakt wie in app.py: 1 Durchlauf, 1 Audio. """ device = next(model.parameters()).device if not text.strip(): return None input_ids, attention_mask = process_prompt(text, voice, tokenizer, device) with torch.no_grad(): generated_ids = model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, num_return_sequences=1, eos_token_id=128258, # End-of-human token ) code_list = parse_output(generated_ids) audio_samples = redistribute_codes(code_list, snac_model) sr = 24000 return sr, audio_samples def save_wav(path: str, sr: int, audio: np.ndarray): # Normalisieren, falls nötig audio_clipped = np.clip(audio, -1.0, 1.0) audio_int16 = (audio_clipped * 32767).astype(np.int16) with wave.open(path, "wb") as wf: wf.setnchannels(1) wf.setsampwidth(2) # 16-bit wf.setframerate(sr) wf.writeframes(audio_int16.tobytes()) def main(): parser = argparse.ArgumentParser() parser.add_argument( "--model_path", type=str, required=True, help="Pfad zum gemergten Modell (z.B. checkpoints/merged)", ) parser.add_argument( "--text", type=str, required=True, help="Text, der gesprochen werden soll", ) parser.add_argument( "--voice", type=str, default="leo", help="", ) parser.add_argument( "--outfile", type=str, default="output.wav", help="Ausgabedatei (WAV)", ) # Defaults wie im HF-Space: parser.add_argument("--temperature", type=float, default=0.6) parser.add_argument("--top_p", type=float, default=0.95) parser.add_argument("--repetition_penalty", type=float, default=1.1) parser.add_argument("--max_new_tokens", type=int, default=1200) args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" model, tokenizer, snac_model = load_models(args.model_path, device=device) print("Generating speech...") sr, audio = generate_speech_once( text=args.text, voice=args.voice, model=model, tokenizer=tokenizer, snac_model=snac_model, temperature=args.temperature, top_p=args.top_p, repetition_penalty=args.repetition_penalty, max_new_tokens=args.max_new_tokens, ) print(f"Saving to {args.outfile}") save_wav(args.outfile, sr, audio) print("Done.") if __name__ == "__main__": main()