Hameed13 committed on
Commit
411b835
·
verified ·
1 Parent(s): ab83827

Update yarngpt/generate.py

Browse files
Files changed (1) hide show
  1. yarngpt/generate.py +118 -80
yarngpt/generate.py CHANGED
@@ -1,114 +1,152 @@
1
- import torch
2
- import torchaudio
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from yarngpt.audiotokenizer import AudioTokenizerV2
5
  import os
 
6
  import logging
 
 
 
 
 
 
7
 
8
  # Configure logging
9
- logging.basicConfig(
10
- level=logging.INFO,
11
- format='[%(asctime)s] %(message)s',
12
- datefmt='%Y-%m-%d %H:%M:%S'
13
- )
 
 
14
 
15
  class TextToSpeech:
16
- """Custom TextToSpeech class that mimics the successful Colab implementation"""
17
 
18
- def __init__(self):
19
- """Initialize the TTS components"""
20
- logging.info("Initializing TextToSpeech class...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Set paths
23
  try:
24
- tokenizer_path = "saheedniyi/YarnGPT2"
 
 
 
25
 
26
- # Check if we're running on HF Spaces or local environment
27
- if os.path.exists("/home/user"): # HF Spaces environment
28
- base_path = "/home/user"
29
- else:
30
- base_path = "."
31
-
32
- wav_tokenizer_config_path = os.path.join(base_path, "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml")
33
- wav_tokenizer_model_path = os.path.join(base_path, "wavtokenizer_large_speech_320_24k.ckpt")
34
 
35
- # Check if files exist
36
- if not os.path.exists(wav_tokenizer_config_path):
37
- logging.warning(f"Config file not found at {wav_tokenizer_config_path}")
38
-
39
- if not os.path.exists(wav_tokenizer_model_path):
40
- logging.warning(f"Model file not found at {wav_tokenizer_model_path}")
41
-
42
- # Initialize audio tokenizer
43
- logging.info("Initializing audio tokenizer...")
44
- self.audio_tokenizer = AudioTokenizerV2(
45
- tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
46
- )
 
 
 
 
 
 
 
47
 
48
- # Load model
49
- logging.info(f"Loading model from {tokenizer_path}...")
50
- self.model = AutoModelForCausalLM.from_pretrained(
51
- tokenizer_path,
52
- torch_dtype="auto"
53
- ).to(self.audio_tokenizer.device)
54
 
55
- logging.info("TextToSpeech initialization complete")
56
 
57
  except Exception as e:
58
- logging.error(f"Error initializing TextToSpeech: {str(e)}")
59
  import traceback
60
  traceback.print_exc()
61
  raise
62
 
63
- def tts(self, text, output_file, accent="nigerian", language="english", speaker="tayo"):
64
- """Generate speech from text and save to file
 
 
65
 
66
  Args:
67
- text: Text to convert to speech
68
- output_file: Path to save the audio file
69
- accent: Type of accent (currently ignored, uses speaker instead)
70
- language: Language ("english", "yoruba", "igbo", "hausa")
71
- speaker: Voice to use (default: "tayo")
72
-
73
  Returns:
74
- Path to generated audio file
75
  """
 
 
76
  try:
77
- # Map accent to speaker if needed
78
- if accent == "nigerian" and speaker == "tayo":
79
- # Use default speaker
80
- pass
81
- elif accent != "nigerian":
82
- # Could map different accents to different speakers
83
- logging.info(f"Accent '{accent}' requested, using speaker '{speaker}'")
84
-
85
- logging.info(f"Generating audio for text: '{text[:50]}...'")
86
- logging.info(f"Using speaker: {speaker}, language: {language}")
87
 
88
- # Create prompt
89
- prompt = self.audio_tokenizer.create_prompt(text, lang=language, speaker_name=speaker)
90
- input_ids = self.audio_tokenizer.tokenize_prompt(prompt)
 
 
 
 
91
 
92
- # Generate audio
93
- output = self.model.generate(
94
- input_ids=input_ids,
95
- temperature=0.1,
96
- repetition_penalty=1.1,
97
- max_length=4000,
98
- )
99
 
100
- # Convert to audio
101
- codes = self.audio_tokenizer.get_codes(output)
102
- audio = self.audio_tokenizer.get_audio(codes)
 
 
 
 
 
 
 
103
 
104
- # Save audio file
105
- logging.info(f"Saving audio to {output_file}")
106
- torchaudio.save(output_file, audio, sample_rate=24000)
 
107
 
108
- return output_file
 
 
 
109
 
110
  except Exception as e:
111
- logging.error(f"Error in TTS generation: {str(e)}")
112
  import traceback
113
  traceback.print_exc()
114
  raise
 
 
 
 
 
# Standard library
import logging
import os
import sys
import warnings
from typing import Optional, Tuple

# Third-party
import numpy as np
import soundfile as sf
import torch
from transformers import AutoModel, AutoProcessor, AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Suppress irrelevant warnings emitted by transformers / torch internals.
for _noisy_pattern in (
    ".*The attention mask and the pad token.*",
    ".*torch.nn.utils.weight_norm is deprecated.*",
):
    warnings.filterwarnings("ignore", category=UserWarning, message=_noisy_pattern)
20
class TextToSpeech:
    """Nigerian Text-to-Speech synthesizer using YarnGPT models."""

    def __init__(self, model_name_or_path, processor_name_or_path=None, disable_playback=True):
        """
        Initialize the TextToSpeech class.

        Args:
            model_name_or_path (str): Path or name of the YarnGPT model
            processor_name_or_path (str, optional): Path or name of the processor;
                defaults to ``model_name_or_path`` when omitted.
            disable_playback (bool, optional): Whether to disable audio playback

        Raises:
            Exception: Re-raises any tokenizer/processor/model loading failure
                after logging it with its traceback.
        """
        self.model_name_or_path = model_name_or_path
        self.processor_name_or_path = processor_name_or_path or model_name_or_path
        self.disable_playback = disable_playback

        # Set environment variable to disable PortAudio (headless environments
        # like HF Spaces have no audio device).
        if disable_playback:
            os.environ["OUTETTS_NO_PORTAUDIO"] = "1"

        logger.info("Initializing TextToSpeech with model: %s", model_name_or_path)

        try:
            # Initialize tokenizer
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.processor_name_or_path)
            logger.info("Tokenizer loaded successfully")

            # Initialize processor
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(self.processor_name_or_path)
            logger.info("Processor loaded successfully")

            # Initialize model with appropriate device
            logger.info("Loading model...")
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info("Using device: %s", self.device)

            self.dtype = self._select_dtype()

            # trust_remote_code=True is required because YarnGPT ships custom
            # model code on the Hub.
            self.model = AutoModel.from_pretrained(
                self.model_name_or_path,
                torch_dtype=self.dtype,
                trust_remote_code=True,
            ).to(self.device)

            logger.info("Model loaded successfully")

        except Exception as e:
            # FIX: logger.exception records the message AND the traceback in one
            # call, replacing the old logger.error + traceback.print_exc pair.
            logger.exception(f"Error initializing TextToSpeech: {e}")
            raise

    def _select_dtype(self):
        """Pick the torch dtype for the model: float16 on CUDA when a fp16
        allocation succeeds, otherwise float32.

        Returns:
            torch.dtype: ``torch.float16`` or ``torch.float32``.
        """
        if self.device == "cuda":
            try:
                # Probe: allocating a tiny fp16 tensor verifies the device
                # accepts float16. (FIX: the probe result was previously bound
                # to an unused `dummy_tensor` variable.)
                torch.zeros(1, device=self.device, dtype=torch.float16)
                logger.info("Using torch.float16 for better performance")
                return torch.float16
            except Exception:
                logger.info("Failed to use torch.float16, falling back to torch.float32")
                return torch.float32
        # Use float32 on CPU.
        logger.info("Using torch.float32 on CPU device")
        return torch.float32

    def _adjust_speed(self, audio_data, speed):
        """Best-effort time-stretch of `audio_data` by factor `speed` via librosa.

        Returns the input unchanged when speed is 1.0, invalid, or when
        librosa is unavailable / the stretch fails.
        """
        if speed == 1.0:
            return audio_data
        if speed <= 0:
            # FIX: a non-positive speed was previously skipped silently.
            logger.warning(f"Invalid speed factor {speed} (must be > 0); speed adjustment skipped")
            return audio_data
        try:
            import librosa
            audio_data = librosa.effects.time_stretch(audio_data, rate=speed)
            logger.info(f"Adjusted audio speed by factor {speed}")
        except ImportError:
            logger.warning("librosa not available, speed adjustment skipped")
        except Exception as e:
            logger.warning(f"Speed adjustment failed: {e}")
        return audio_data

    def tts(self, text: str, accent: str = "nigerian", save_path: Optional[str] = None,
            speed: float = 1.0, get_array: bool = False) -> Optional[Tuple[np.ndarray, int]]:
        """
        Generate speech from text.

        Args:
            text (str): Text to convert to speech
            accent (str, optional): Accent for the speech. Defaults to "nigerian".
            save_path (str, optional): Path to save the audio file. Defaults to None.
            speed (float, optional): Speed factor for speech; must be > 0. Defaults to 1.0.
            get_array (bool, optional): Return audio as numpy array. Defaults to False.

        Returns:
            Tuple[numpy.ndarray, int] or None: Audio data and sample rate if get_array=True

        Raises:
            Exception: Re-raises any generation failure after logging it.
        """
        logger.info(f"Generating speech for text: '{text[:50]}...' with accent '{accent}'")

        try:
            # Prepare input; the processor handles accent conditioning.
            inputs = self.processor(
                text=text,
                accent=accent,
                return_tensors="pt",
                padding=True,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Generate speech; no_grad avoids building an autograd graph.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    pad_token_id=self.tokenizer.pad_token_id,
                    max_new_tokens=1000,
                )

            # Process outputs.
            # NOTE(review): assumes the generation output exposes `generated_wavs`
            # and the model config exposes `sampling_rate` — confirm against the
            # YarnGPT model card / remote code.
            audio_data = outputs.generated_wavs.cpu().numpy().squeeze()
            sample_rate = self.model.config.sampling_rate

            # Adjust speed if needed (best-effort; see _adjust_speed).
            audio_data = self._adjust_speed(audio_data, speed)

            # Save if path is provided
            if save_path:
                logger.info(f"Saving audio to {save_path}")
                sf.write(save_path, audio_data, sample_rate)

            # Return the audio data and sample rate if requested
            if get_array:
                return audio_data, sample_rate
            return None

        except Exception as e:
            # FIX: logger.exception replaces logger.error + traceback.print_exc.
            logger.exception(f"Error generating speech: {e}")
            raise