Update orpheus-tts/engine_class.py
Browse files- orpheus-tts/engine_class.py +10 -3
orpheus-tts/engine_class.py
CHANGED
|
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer
|
|
| 6 |
import threading
|
| 7 |
import queue
|
| 8 |
from decoder import tokens_decoder_sync
|
|
|
|
| 9 |
|
| 10 |
class OrpheusModel:
|
| 11 |
def __init__(self, model_name, dtype=torch.bfloat16, tokenizer=None, **engine_kwargs):
|
|
@@ -86,7 +87,7 @@ class OrpheusModel:
|
|
| 86 |
if voice not in self.engine.available_voices:
|
| 87 |
raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
|
| 88 |
|
| 89 |
-
def _format_prompt(self, prompt, voice="
|
| 90 |
# Use Kartoffel model format based on documentation
|
| 91 |
if voice:
|
| 92 |
full_prompt = f"{voice}: {prompt}"
|
|
@@ -166,9 +167,15 @@ class OrpheusModel:
|
|
| 166 |
token_generator = self.generate_tokens_sync(**kwargs)
|
| 167 |
print("DEBUG: Token generator created successfully")
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
|
|
|
| 172 |
return audio_generator
|
| 173 |
except Exception as e:
|
| 174 |
print(f"DEBUG: Error in generate_speech: {e}")
|
|
|
|
| 6 |
import threading
|
| 7 |
import queue
|
| 8 |
from decoder import tokens_decoder_sync
|
| 9 |
+
from kartoffel_decoder import tokens_decoder_kartoffel_sync
|
| 10 |
|
| 11 |
class OrpheusModel:
|
| 12 |
def __init__(self, model_name, dtype=torch.bfloat16, tokenizer=None, **engine_kwargs):
|
|
|
|
| 87 |
if voice not in self.engine.available_voices:
|
| 88 |
raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
|
| 89 |
|
| 90 |
+
def _format_prompt(self, prompt, voice="Jakob", model_type="larger"):
|
| 91 |
# Use Kartoffel model format based on documentation
|
| 92 |
if voice:
|
| 93 |
full_prompt = f"{voice}: {prompt}"
|
|
|
|
| 167 |
token_generator = self.generate_tokens_sync(**kwargs)
|
| 168 |
print("DEBUG: Token generator created successfully")
|
| 169 |
|
| 170 |
+
# Use the Kartoffel decoder for German models
|
| 171 |
+
if "german" in self.model_name.lower() or "kartoffel" in self.model_name.lower():
|
| 172 |
+
print("DEBUG: Using Kartoffel decoder for German model")
|
| 173 |
+
audio_generator = tokens_decoder_kartoffel_sync(token_generator, self.tokenizer)
|
| 174 |
+
else:
|
| 175 |
+
print("DEBUG: Using original decoder")
|
| 176 |
+
audio_generator = tokens_decoder_sync(token_generator)
|
| 177 |
|
| 178 |
+
print("DEBUG: Audio decoder called successfully")
|
| 179 |
return audio_generator
|
| 180 |
except Exception as e:
|
| 181 |
print(f"DEBUG: Error in generate_speech: {e}")
|