Hameed13 commited on
Commit
9988b14
·
verified ·
1 Parent(s): 3ec7d48

Update yarngpt/generate.py

Browse files
Files changed (1) hide show
  1. yarngpt/generate.py +175 -49
yarngpt/generate.py CHANGED
@@ -1,56 +1,182 @@
1
- def __init__(self, model_name_or_path, processor_name_or_path=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  """
3
- Initialize the TextToSpeech class.
4
 
5
  Args:
6
- model_name_or_path (str): Path or name of the YarnGPT model
7
- processor_name_or_path (str, optional): Path or name of the processor
 
 
 
 
 
 
 
 
8
  """
9
- self.model_name_or_path = model_name_or_path
10
- self.processor_name_or_path = processor_name_or_path or model_name_or_path
11
- self.init_time = INIT_TIMESTAMP
12
- self.user = CURRENT_USER
13
-
14
- logger.info(f"Initializing TextToSpeech with model: {model_name_or_path}")
15
- logger.info(f"Initialization time: {self.init_time}")
16
- logger.info(f"User: {self.user}")
17
-
18
  try:
19
- # Initialize tokenizer using the repository ID
20
- logger.info("Loading tokenizer...")
21
- self.tokenizer = AutoTokenizer.from_pretrained(
22
- self.processor_name_or_path,
23
- token=os.getenv('HF_TOKEN'),
24
- trust_remote_code=True
25
- )
26
- logger.info("Tokenizer loaded successfully")
27
-
28
- # Initialize processor
29
- logger.info("Loading processor...")
30
- self.processor = AutoProcessor.from_pretrained(
31
- self.processor_name_or_path,
32
- token=os.getenv('HF_TOKEN'),
33
- trust_remote_code=True
34
- )
35
- logger.info("Processor loaded successfully")
36
-
37
- # Initialize model
38
- logger.info("Loading model...")
39
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
40
- logger.info(f"Using device: {self.device}")
41
-
42
- self.dtype = torch.float16 if self.device == "cuda" else torch.float32
43
- logger.info(f"Using torch dtype: {self.dtype}")
44
-
45
- self.model = AutoModel.from_pretrained(
46
- self.model_name_or_path,
47
- torch_dtype=self.dtype,
48
- trust_remote_code=True,
49
- token=os.getenv('HF_TOKEN')
50
- ).to(self.device)
51
-
52
- logger.info("Model loaded successfully")
53
-
54
  except Exception as e:
55
- logger.error(f"Error initializing TextToSpeech: {e}")
56
  raise
 
1
+ import os
2
+ import sys
3
+ import logging
4
+ import torch
5
+ import torchaudio
6
+ import numpy as np
7
+ from transformers import AutoTokenizer, AutoProcessor, AutoModel
8
+ from huggingface_hub import hf_hub_download
9
+ import warnings
10
+ import scipy.io.wavfile as wav
11
+ from datetime import datetime
12
+
13
+ # Configure logging
14
+ logging.basicConfig(level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Suppress irrelevant warnings
19
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*The attention mask and the pad token.*")
20
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*torch.nn.utils.weight_norm is deprecated.*")
21
+
22
+ # Constants
23
+ INIT_TIMESTAMP = "2025-05-21 01:36:55"
24
+ CURRENT_USER = "Abdulhameed556"
25
+
26
+ class TextToSpeech:
27
+ def __init__(self, model_name_or_path, processor_name_or_path=None):
28
+ """
29
+ Initialize the TextToSpeech class.
30
+
31
+ Args:
32
+ model_name_or_path (str): Path or name of the YarnGPT model
33
+ processor_name_or_path (str, optional): Path or name of the processor
34
+ """
35
+ self.model_name_or_path = model_name_or_path
36
+ self.processor_name_or_path = processor_name_or_path or model_name_or_path
37
+ self.init_time = INIT_TIMESTAMP
38
+ self.user = CURRENT_USER
39
+
40
+ logger.info(f"Initializing TextToSpeech with model: {model_name_or_path}")
41
+ logger.info(f"Initialization time: {self.init_time}")
42
+ logger.info(f"User: {self.user}")
43
+
44
+ try:
45
+ # Initialize tokenizer using the repository ID
46
+ logger.info("Loading tokenizer...")
47
+ self.tokenizer = AutoTokenizer.from_pretrained(
48
+ self.processor_name_or_path,
49
+ token=os.getenv('HF_TOKEN'),
50
+ trust_remote_code=True
51
+ )
52
+ logger.info("Tokenizer loaded successfully")
53
+
54
+ # Initialize processor
55
+ logger.info("Loading processor...")
56
+ self.processor = AutoProcessor.from_pretrained(
57
+ self.processor_name_or_path,
58
+ token=os.getenv('HF_TOKEN'),
59
+ trust_remote_code=True
60
+ )
61
+ logger.info("Processor loaded successfully")
62
+
63
+ # Initialize model
64
+ logger.info("Loading model...")
65
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
66
+ logger.info(f"Using device: {self.device}")
67
+
68
+ self.dtype = torch.float16 if self.device == "cuda" else torch.float32
69
+ logger.info(f"Using torch dtype: {self.dtype}")
70
+
71
+ self.model = AutoModel.from_pretrained(
72
+ self.model_name_or_path,
73
+ torch_dtype=self.dtype,
74
+ trust_remote_code=True,
75
+ token=os.getenv('HF_TOKEN')
76
+ ).to(self.device)
77
+
78
+ logger.info("Model loaded successfully")
79
+
80
+ except Exception as e:
81
+ logger.error(f"Error initializing TextToSpeech: {e}")
82
+ raise
83
+
84
+ def get_status(self):
85
+ """Return the current status of the TTS system."""
86
+ return {
87
+ "initialized_at": self.init_time,
88
+ "user": self.user,
89
+ "device": self.device,
90
+ "dtype": str(self.dtype),
91
+ "model_name": self.model_name_or_path,
92
+ "processor_name": self.processor_name_or_path,
93
+ "model_loaded": hasattr(self, 'model'),
94
+ "tokenizer_loaded": hasattr(self, 'tokenizer'),
95
+ "processor_loaded": hasattr(self, 'processor')
96
+ }
97
+
98
+ def tts(self, text, accent="nigerian", save_path=None, speed=1.0):
99
+ """
100
+ Generate speech from text.
101
+
102
+ Args:
103
+ text (str): Text to convert to speech
104
+ accent (str, optional): Accent for the speech. Defaults to "nigerian".
105
+ save_path (str, optional): Path to save the audio file. Defaults to None.
106
+ speed (float, optional): Speed factor for speech. Defaults to 1.0.
107
+
108
+ Returns:
109
+ numpy.ndarray: Audio data as a numpy array
110
+ """
111
+ logger.info(f"Generating speech for text: '{text[:50]}...' with accent '{accent}'")
112
+
113
+ try:
114
+ # Prepare input
115
+ inputs = self.processor(
116
+ text=text,
117
+ accent=accent,
118
+ return_tensors="pt",
119
+ padding=True,
120
+ )
121
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
122
+
123
+ # Generate speech
124
+ with torch.no_grad():
125
+ outputs = self.model.generate(
126
+ **inputs,
127
+ pad_token_id=self.tokenizer.pad_token_id,
128
+ max_new_tokens=1000
129
+ )
130
+
131
+ # Process outputs
132
+ audio_data = outputs.generated_wavs.cpu().numpy().squeeze()
133
+
134
+ # Adjust speed if needed
135
+ if speed != 1.0:
136
+ import librosa
137
+ audio_data = librosa.effects.time_stretch(audio_data, rate=speed)
138
+
139
+ # Save if path is provided
140
+ if save_path:
141
+ logger.info(f"Saving audio to {save_path}")
142
+ sample_rate = self.model.config.sampling_rate
143
+ wav.write(save_path, sample_rate, audio_data.astype(np.float32))
144
+
145
+ return audio_data
146
+
147
+ except Exception as e:
148
+ logger.error(f"Error generating speech: {e}")
149
+ raise
150
+
151
+ def generate_audio(text, checkpoint_path, config_path=None, temperature=0.2, top_p=0.7, top_k=50, speed=1.0):
152
  """
153
+ Convenience function to generate audio from text.
154
 
155
  Args:
156
+ text (str): The text to convert to speech
157
+ checkpoint_path (str): Path to the model checkpoint
158
+ config_path (str, optional): Path to model config
159
+ temperature (float, optional): Temperature for generation. Defaults to 0.2.
160
+ top_p (float, optional): Top-p sampling parameter. Defaults to 0.7.
161
+ top_k (int, optional): Top-k sampling parameter. Defaults to 50.
162
+ speed (float, optional): Speed factor for speech. Defaults to 1.0.
163
+
164
+ Returns:
165
+ numpy.ndarray: Generated audio data
166
  """
 
 
 
 
 
 
 
 
 
167
  try:
168
+ start_time = datetime.utcnow()
169
+ logger.info(f"Starting audio generation at {start_time.strftime('%Y-%m-%d %H:%M:%S')}")
170
+
171
+ tts = TextToSpeech(checkpoint_path)
172
+ audio_data = tts.tts(text, speed=speed)
173
+
174
+ end_time = datetime.utcnow()
175
+ duration = (end_time - start_time).total_seconds()
176
+ logger.info(f"Audio generation completed in {duration:.2f} seconds")
177
+
178
+ return audio_data
179
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  except Exception as e:
181
+ logger.error(f"Error in generate_audio: {e}")
182
  raise