Hameed13 commited on
Commit
109c3b2
·
verified ·
1 Parent(s): 3662c55

Update yarngpt/generate.py

Browse files
Files changed (1) hide show
  1. yarngpt/generate.py +55 -35
yarngpt/generate.py CHANGED
@@ -9,6 +9,7 @@ from huggingface_hub import hf_hub_download
9
  import warnings
10
  import scipy.io.wavfile as wav
11
  from datetime import datetime
 
12
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO,
@@ -16,7 +17,7 @@ logging.basicConfig(level=logging.INFO,
16
  logger = logging.getLogger(__name__)
17
 
18
  # Constants
19
- INIT_TIMESTAMP = "2025-05-21 02:08:00"
20
  CURRENT_USER = "Abdulhameed556"
21
 
22
  class TextToSpeech:
@@ -26,40 +27,30 @@ class TextToSpeech:
26
  self.processor_name_or_path = processor_name_or_path or model_name_or_path
27
  self.init_time = INIT_TIMESTAMP
28
  self.user = CURRENT_USER
 
29
 
30
  logger.info(f"Initializing TextToSpeech with model: {model_name_or_path}")
31
 
32
  try:
 
 
 
 
 
 
33
  # Initialize configuration
34
  config = Speech2Text2Config.from_pretrained(
35
  pretrained_model_name_or_path=self.model_name_or_path,
36
- cache_dir="/code/cache",
37
- token=os.getenv('HF_TOKEN'),
38
- trust_remote_code=True
39
  )
40
 
41
- # Download tokenizer files
42
- logger.info("Downloading tokenizer files...")
43
- tokenizer_files = ["tokenizer_config.json", "special_tokens_map.json", "vocab.json"]
44
- for file in tokenizer_files:
45
- try:
46
- hf_hub_download(
47
- repo_id=self.model_name_or_path,
48
- filename=file,
49
- cache_dir="/code/cache",
50
- token=os.getenv('HF_TOKEN')
51
- )
52
- except Exception as e:
53
- logger.warning(f"Could not download {file}: {e}")
54
-
55
  # Initialize tokenizer
56
  logger.info("Loading tokenizer...")
57
  self.tokenizer = AutoTokenizer.from_pretrained(
58
- self.model_name_or_path,
59
  config=config,
60
- cache_dir="/code/cache",
61
- token=os.getenv('HF_TOKEN'),
62
- trust_remote_code=True
63
  )
64
 
65
  # Initialize model
@@ -70,26 +61,55 @@ class TextToSpeech:
70
  self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
71
  self.model_name_or_path,
72
  config=config,
73
- cache_dir="/code/cache",
74
- token=os.getenv('HF_TOKEN'),
75
- trust_remote_code=True
76
  ).to(self.device)
77
 
78
- # Load processor
79
- logger.info("Loading processor...")
80
- self.processor = AutoProcessor.from_pretrained(
81
- self.model_name_or_path,
82
- cache_dir="/code/cache",
83
- token=os.getenv('HF_TOKEN'),
84
- trust_remote_code=True
85
- )
86
-
87
  logger.info("Model initialization complete")
88
 
89
  except Exception as e:
90
  logger.error(f"Error initializing TextToSpeech: {e}")
91
  raise
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def tts(self, text, speed=1.0):
94
  """Generate speech from text."""
95
  try:
@@ -114,7 +134,7 @@ class TextToSpeech:
114
  )
115
 
116
  # Convert to audio
117
- audio = self.processor.batch_decode(output, skip_special_tokens=True)[0]
118
 
119
  # Apply speed adjustment if needed
120
  if speed != 1.0:
 
9
  import warnings
10
  import scipy.io.wavfile as wav
11
  from datetime import datetime
12
+ import json
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO,
 
17
  logger = logging.getLogger(__name__)
18
 
19
  # Constants
20
+ INIT_TIMESTAMP = "2025-05-21 02:21:23"
21
  CURRENT_USER = "Abdulhameed556"
22
 
23
  class TextToSpeech:
 
27
  self.processor_name_or_path = processor_name_or_path or model_name_or_path
28
  self.init_time = INIT_TIMESTAMP
29
  self.user = CURRENT_USER
30
+ self.cache_dir = "/code/cache"
31
 
32
  logger.info(f"Initializing TextToSpeech with model: {model_name_or_path}")
33
 
34
  try:
35
+ # Create cache directory if it doesn't exist
36
+ os.makedirs(self.cache_dir, exist_ok=True)
37
+
38
+ # Create tokenizer files locally if they don't exist
39
+ self._create_tokenizer_files()
40
+
41
  # Initialize configuration
42
  config = Speech2Text2Config.from_pretrained(
43
  pretrained_model_name_or_path=self.model_name_or_path,
44
+ cache_dir=self.cache_dir,
45
+ token=os.getenv('HF_TOKEN')
 
46
  )
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  # Initialize tokenizer
49
  logger.info("Loading tokenizer...")
50
  self.tokenizer = AutoTokenizer.from_pretrained(
51
+ self.cache_dir, # Use local cache directory
52
  config=config,
53
+ token=os.getenv('HF_TOKEN')
 
 
54
  )
55
 
56
  # Initialize model
 
61
  self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
62
  self.model_name_or_path,
63
  config=config,
64
+ cache_dir=self.cache_dir,
65
+ token=os.getenv('HF_TOKEN')
 
66
  ).to(self.device)
67
 
 
 
 
 
 
 
 
 
 
68
  logger.info("Model initialization complete")
69
 
70
  except Exception as e:
71
  logger.error(f"Error initializing TextToSpeech: {e}")
72
  raise
73
 
74
+ def _create_tokenizer_files(self):
75
+ """Create necessary tokenizer files in cache directory."""
76
+ tokenizer_files = {
77
+ "tokenizer_config.json": {
78
+ "name_or_path": self.model_name_or_path,
79
+ "padding_side": "right",
80
+ "truncation_side": "right",
81
+ "model_max_length": 1024,
82
+ "bos_token": "<s>",
83
+ "eos_token": "</s>",
84
+ "unk_token": "<unk>",
85
+ "pad_token": "<pad>",
86
+ "mask_token": "<mask>",
87
+ "special_tokens_map_file": "special_tokens_map.json",
88
+ "tokenizer_class": "Speech2Text2Tokenizer"
89
+ },
90
+ "special_tokens_map.json": {
91
+ "bos_token": "<s>",
92
+ "eos_token": "</s>",
93
+ "pad_token": "<pad>",
94
+ "unk_token": "<unk>",
95
+ "mask_token": "<mask>"
96
+ },
97
+ "vocab.json": {
98
+ "<s>": 0,
99
+ "<pad>": 1,
100
+ "</s>": 2,
101
+ "<unk>": 3,
102
+ "<mask>": 4
103
+ }
104
+ }
105
+
106
+ logger.info("Creating tokenizer files in cache directory...")
107
+ for filename, content in tokenizer_files.items():
108
+ filepath = os.path.join(self.cache_dir, filename)
109
+ with open(filepath, 'w', encoding='utf-8') as f:
110
+ json.dump(content, f, indent=2)
111
+ logger.info(f"Created {filename}")
112
+
113
  def tts(self, text, speed=1.0):
114
  """Generate speech from text."""
115
  try:
 
134
  )
135
 
136
  # Convert to audio
137
+ audio = output[0].cpu().numpy()
138
 
139
  # Apply speed adjustment if needed
140
  if speed != 1.0: