Anjan9320 committed on
Commit
062feb0
·
verified ·
1 Parent(s): 9ef067d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +58 -37
model.py CHANGED
@@ -22,6 +22,8 @@ from huggingface_hub import hf_hub_download
22
  from safetensors.torch import load_file
23
  import os
24
 
 
 
25
  class INF5Config(PretrainedConfig):
26
  model_type = "inf5"
27
 
@@ -64,46 +66,66 @@ class INF5Model(PreTrainedModel):
64
  # # Load state dict into model
65
  self.ema_model.load_state_dict(state_dict, strict=False)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def extract_speaker_embedding(self, ref_audio_path: str, ref_text: str):
68
  """
69
  Extract speaker embedding or reference features from audio and text.
70
- Converts audio to WAV if needed. Returns NumPy array for saving/reuse.
71
  """
72
  if not os.path.exists(ref_audio_path):
73
  raise FileNotFoundError(f"Reference audio file '{ref_audio_path}' not found.")
74
 
75
- ext = os.path.splitext(ref_audio_path)[-1].lower()
 
76
 
77
- # Convert to WAV if input is MP3 or MP4
78
- if ext not in [".wav"]:
79
- audio = AudioSegment.from_file(ref_audio_path)
80
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file:
81
- temp_path = temp_wav_file.name
82
- audio.export(temp_path, format="wav")
83
- ref_audio_path = temp_path # Use converted path
84
 
85
- # Extract embedding
86
- speaker_embedding, _ = preprocess_ref_audio_text(ref_audio_path, ref_text)
 
87
 
88
- # Clean up if we created a temp file
89
- if ext not in [".wav"] and os.path.exists(ref_audio_path):
90
- os.remove(ref_audio_path)
91
-
92
- # Convert to NumPy for easy saving
93
  if isinstance(speaker_embedding, torch.Tensor):
94
  speaker_embedding = speaker_embedding.detach().cpu().numpy()
95
 
96
  return speaker_embedding
97
 
98
  def forward(self, text: str, speaker_embedding=None, ref_audio_path=None, ref_text=None):
99
- # Validate input
100
  if speaker_embedding is None:
101
  if not ref_audio_path or not ref_text:
102
  raise ValueError("You must provide either a speaker_embedding or both ref_audio_path and ref_text.")
103
- # Extract speaker embedding from reference audio/text
104
- speaker_embedding, _ = preprocess_ref_audio_text(ref_audio_path, ref_text)
 
105
  else:
106
- # Convert numpy to tensor if needed
107
  if isinstance(speaker_embedding, np.ndarray):
108
  speaker_embedding = torch.tensor(speaker_embedding, dtype=torch.float32)
109
  speaker_embedding = speaker_embedding.to(self.device)
@@ -111,7 +133,6 @@ class INF5Model(PreTrainedModel):
111
  self.ema_model.to(self.device)
112
  self.vocoder.to(self.device)
113
 
114
- # Inference from embedding (no ref_audio/ref_text needed)
115
  audio, final_sample_rate, _ = infer_from_embedding(
116
  speaker_embedding=speaker_embedding,
117
  text=text,
@@ -147,28 +168,30 @@ class INF5Model(PreTrainedModel):
147
 
148
 
149
  if __name__ == '__main__':
150
- model = INF5Model(INF5Config(ckpt_path="checkpoints/model_best.pt", vocab_path="checkpoints/vocab.txt"))
151
- model.save_pretrained("INF5")
152
- model.config.save_pretrained("INF5")
153
-
154
  import numpy as np
155
  import soundfile as sf
156
  from transformers import AutoConfig, AutoModel
157
- from f5_tts.infer.utils_infer import (
158
- preprocess_ref_audio_text,
159
- )
160
-
161
  AutoConfig.register("inf5", INF5Config)
162
  AutoModel.register(INF5Config, INF5Model)
163
 
 
 
 
 
 
 
164
  model = AutoModel.from_pretrained("INF5")
165
 
166
  # Step 1: Extract speaker embedding from reference audio + text
167
- speaker_embedding = extract_speaker_embedding(
168
  "prompts/PAN_F_HAPPY_00001.wav",
169
  "ਭਹੰਪੀ ਵਿੱਚ ਸਮਾਰਕਾਂ ਦੇ ਭਵਨ ਨਿਰਮਾਣ ਕਲਾ ਦੇ ਵੇਰਵੇ ਗੁੰਝਲਦਾਰ ਅਤੇ ਹੈਰਾਨ ਕਰਨ ਵਾਲੇ ਹਨ, ਜੋ ਮੈਨੂੰ ਖੁਸ਼ ਕਰਦੇ ਹਨ।"
170
  )
171
- nnp.save("speaker_embedding.npy", speaker_embedding)
172
 
173
  # Step 2: Load saved embedding (simulate reuse)
174
  loaded_embedding = np.load("speaker_embedding.npy")
@@ -179,15 +202,14 @@ if __name__ == '__main__':
179
  speaker_embedding=loaded_embedding
180
  )
181
 
 
182
  if audio.dtype == np.int16:
183
  audio = audio.astype(np.float32) / 32768.0
184
- sf.write("samples/namaste.wav", np.array(audio, dtype=np.float32), samplerate=24000)
185
 
 
186
  from huggingface_hub import HfApi
187
-
188
  repo_id = "svp19/INF5" # Change to your HF repo
189
-
190
- # Upload model directory to HF
191
  api = HfApi()
192
  api.upload_folder(
193
  folder_path="INF5",
@@ -196,8 +218,7 @@ if __name__ == '__main__':
196
  )
197
  print(f"Model pushed to https://huggingface.co/{repo_id} 🚀")
198
 
199
- print("Verify Upload")
200
- from transformers import AutoModel
201
  model = AutoModel.from_pretrained(repo_id)
202
  print("Success")
203
 
 
22
  from safetensors.torch import load_file
23
  import os
24
 
25
+ import torchaudio
26
+
27
  class INF5Config(PretrainedConfig):
28
  model_type = "inf5"
29
 
 
66
  # # Load state dict into model
67
  self.ema_model.load_state_dict(state_dict, strict=False)
68
 
69
def _extract_embedding_from_audio_and_text(self, audio_path: str, text: str) -> torch.Tensor:
    """Run reference audio and text through ``self.ema_model`` and return its speaker embedding.

    The audio is loaded with ``torchaudio.load``, resampled to 24 kHz when its
    native sample rate differs, and passed together with ``text`` to
    ``self.ema_model`` inside ``torch.no_grad()``.

    Args:
        audio_path: Path of the audio file to load.
        text: Text handed to the model alongside the waveform.

    Returns:
        The ``speaker_embedding`` entry of the model output, with size-1
        dimensions removed by ``squeeze()``.

    Raises:
        RuntimeError: If the model output carries no usable
            ``speaker_embedding`` attribute or dictionary key.
    """
    # Device of the first registered parameter; used for every tensor below.
    device = next(self.parameters()).device

    waveform, sample_rate = torchaudio.load(audio_path)
    # Move the waveform BEFORE resampling: the resampler is placed on `device`,
    # so feeding it a CPU tensor would fail with a device mismatch on GPU.
    waveform = waveform.to(device)

    target_sample_rate = 24000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_sample_rate
        ).to(device)
        waveform = resampler(waveform)

    # NOTE(review): assumes ema_model's forward accepts (waveform, text) and
    # exposes a `speaker_embedding` in its output - confirm against the model class.
    with torch.no_grad():
        outputs = self.ema_model(waveform, text)

    # Attribute access first (object-style outputs), then dict-style lookup.
    speaker_embedding = getattr(outputs, "speaker_embedding", None)
    if speaker_embedding is None and isinstance(outputs, dict):
        speaker_embedding = outputs.get("speaker_embedding")
    if speaker_embedding is None:
        raise RuntimeError("Speaker embedding not found in model output")

    return speaker_embedding.squeeze()
94
+
95
+
96
  def extract_speaker_embedding(self, ref_audio_path: str, ref_text: str):
97
  """
98
  Extract speaker embedding or reference features from audio and text.
99
+ Converts audio to WAV if needed. Returns numpy array for saving/reuse.
100
  """
101
  if not os.path.exists(ref_audio_path):
102
  raise FileNotFoundError(f"Reference audio file '{ref_audio_path}' not found.")
103
 
104
+ # Step 1: Preprocess audio + text (clip silence, convert etc)
105
+ processed_audio_path, processed_text = preprocess_ref_audio_text(ref_audio_path, ref_text)
106
 
107
+ # Step 2: Use model’s internal method to extract embedding from processed audio + text
108
+ # IMPORTANT: Replace `self._extract_embedding_from_audio_and_text` with your actual method!
109
+ speaker_embedding = self._extract_embedding_from_audio_and_text(processed_audio_path, processed_text)
 
 
 
 
110
 
111
+ # Clean up temporary processed file if created
112
+ if processed_audio_path != ref_audio_path and os.path.exists(processed_audio_path):
113
+ os.remove(processed_audio_path)
114
 
115
+ # Convert to numpy if it’s a tensor
 
 
 
 
116
  if isinstance(speaker_embedding, torch.Tensor):
117
  speaker_embedding = speaker_embedding.detach().cpu().numpy()
118
 
119
  return speaker_embedding
120
 
121
  def forward(self, text: str, speaker_embedding=None, ref_audio_path=None, ref_text=None):
 
122
  if speaker_embedding is None:
123
  if not ref_audio_path or not ref_text:
124
  raise ValueError("You must provide either a speaker_embedding or both ref_audio_path and ref_text.")
125
+ # Extract the speaker embedding from the reference audio and text
126
+ speaker_embedding = self.extract_speaker_embedding(ref_audio_path, ref_text)
127
+ speaker_embedding = torch.tensor(speaker_embedding, dtype=torch.float32).to(self.device)
128
  else:
 
129
  if isinstance(speaker_embedding, np.ndarray):
130
  speaker_embedding = torch.tensor(speaker_embedding, dtype=torch.float32)
131
  speaker_embedding = speaker_embedding.to(self.device)
 
133
  self.ema_model.to(self.device)
134
  self.vocoder.to(self.device)
135
 
 
136
  audio, final_sample_rate, _ = infer_from_embedding(
137
  speaker_embedding=speaker_embedding,
138
  text=text,
 
168
 
169
 
170
  if __name__ == '__main__':
171
+ import os
 
 
 
172
  import numpy as np
173
  import soundfile as sf
174
  from transformers import AutoConfig, AutoModel
175
+ from f5_tts.infer.utils_infer import preprocess_ref_audio_text
176
+
177
+ # Register your custom config and model
 
178
  AutoConfig.register("inf5", INF5Config)
179
  AutoModel.register(INF5Config, INF5Model)
180
 
181
+ # Instantiate your model with config
182
+ model = INF5Model(INF5Config(ckpt_path="checkpoints/model_best.pt", vocab_path="checkpoints/vocab.txt"))
183
+ model.save_pretrained("INF5")
184
+ model.config.save_pretrained("INF5")
185
+
186
+ # Load model via HF AutoModel interface for proper loading from the saved folder
187
  model = AutoModel.from_pretrained("INF5")
188
 
189
  # Step 1: Extract speaker embedding from reference audio + text
190
+ speaker_embedding = model.extract_speaker_embedding(
191
  "prompts/PAN_F_HAPPY_00001.wav",
192
  "ਭਹੰਪੀ ਵਿੱਚ ਸਮਾਰਕਾਂ ਦੇ ਭਵਨ ਨਿਰਮਾਣ ਕਲਾ ਦੇ ਵੇਰਵੇ ਗੁੰਝਲਦਾਰ ਅਤੇ ਹੈਰਾਨ ਕਰਨ ਵਾਲੇ ਹਨ, ਜੋ ਮੈਨੂੰ ਖੁਸ਼ ਕਰਦੇ ਹਨ।"
193
  )
194
+ np.save("speaker_embedding.npy", speaker_embedding)
195
 
196
  # Step 2: Load saved embedding (simulate reuse)
197
  loaded_embedding = np.load("speaker_embedding.npy")
 
202
  speaker_embedding=loaded_embedding
203
  )
204
 
205
+ # Normalize audio dtype if needed before saving
206
  if audio.dtype == np.int16:
207
  audio = audio.astype(np.float32) / 32768.0
208
+ sf.write("samples/namaste.wav", audio.astype(np.float32), samplerate=24000)
209
 
210
+ # Upload model directory to Hugging Face Hub
211
  from huggingface_hub import HfApi
 
212
  repo_id = "svp19/INF5" # Change to your HF repo
 
 
213
  api = HfApi()
214
  api.upload_folder(
215
  folder_path="INF5",
 
218
  )
219
  print(f"Model pushed to https://huggingface.co/{repo_id} 🚀")
220
 
221
+ # Verify upload by reloading
 
222
  model = AutoModel.from_pretrained(repo_id)
223
  print("Success")
224