eduard76 committed on
Commit
4e1229a
·
verified ·
1 Parent(s): f971ff0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -73,25 +73,21 @@ class ProfessionalVoiceAgent:
73
  def load_whisper(self):
74
  """Load Whisper model for speech recognition"""
75
  try:
76
- if self.use_large_models:
77
- # Use larger Whisper for better accuracy
78
- model_name = "openai/whisper-small"
79
- logger.info(f"Loading Whisper Small for best accuracy...")
80
- else:
81
- model_name = "openai/whisper-tiny"
82
- logger.info(f"Loading Whisper Tiny...")
83
 
84
  self.whisper_processor = WhisperProcessor.from_pretrained(model_name)
85
  self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
86
  model_name,
87
- torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
 
88
  ).to(self.device)
89
 
90
- # Enable Flash Attention if available
91
- if hasattr(self.whisper_model.config, "use_flash_attention_2"):
92
- self.whisper_model.config.use_flash_attention_2 = True
93
 
94
- logger.info("✓ Whisper loaded successfully")
95
 
96
  except Exception as e:
97
  logger.error(f"Failed to load Whisper: {e}")
@@ -123,7 +119,10 @@ class ProfessionalVoiceAgent:
123
  # Add padding token
124
  self.chat_tokenizer.pad_token = self.chat_tokenizer.eos_token
125
 
126
- logger.info("✓ Chat model loaded successfully")
 
 
 
127
 
128
  except Exception as e:
129
  logger.error(f"Failed to load chat model: {e}")
@@ -149,6 +148,10 @@ class ProfessionalVoiceAgent:
149
  torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
150
  ).to(self.device)
151
 
 
 
 
 
152
  # Load speaker embeddings for voice
153
  try:
154
  logger.info("Loading speaker embeddings dataset...")
@@ -212,6 +215,12 @@ class ProfessionalVoiceAgent:
212
  import librosa
213
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
214
 
 
 
 
 
 
 
215
  if self.whisper_processor and hasattr(self.whisper_model, 'generate'):
216
  # Use loaded model
217
  input_features = self.whisper_processor(
@@ -220,14 +229,15 @@ class ProfessionalVoiceAgent:
220
  return_tensors="pt"
221
  ).input_features.to(self.device)
222
 
223
- # Generate token ids
224
  with torch.cuda.amp.autocast(enabled=self.device.type == "cuda"):
225
- predicted_ids = self.whisper_model.generate(
226
- input_features,
227
- max_new_tokens=128,
228
- num_beams=5, # Better accuracy
229
- temperature=0.0 # Deterministic
230
- )
 
231
 
232
  # Decode token ids to text
233
  transcription = self.whisper_processor.batch_decode(
@@ -274,18 +284,18 @@ class ProfessionalVoiceAgent:
274
  max_length=512
275
  ).to(self.device)
276
 
277
- # Generate response
278
  with torch.cuda.amp.autocast(enabled=self.device.type == "cuda"):
279
  with torch.no_grad():
280
  outputs = self.chat_model.generate(
281
  inputs,
282
- max_new_tokens=100,
283
  temperature=temperature,
284
  top_p=0.9,
285
- do_sample=True,
286
  pad_token_id=self.chat_tokenizer.eos_token_id,
287
  eos_token_id=self.chat_tokenizer.eos_token_id,
288
- num_beams=3
289
  )
290
 
291
  # Decode response
 
73
  def load_whisper(self):
74
  """Load Whisper model for speech recognition"""
75
  try:
76
+ # Use tiny model for speed - small is too slow
77
+ model_name = "openai/whisper-tiny"
78
+ logger.info(f"Loading Whisper Tiny for fast processing...")
 
 
 
 
79
 
80
  self.whisper_processor = WhisperProcessor.from_pretrained(model_name)
81
  self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
82
  model_name,
83
+ torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
84
+ low_cpu_mem_usage=True
85
  ).to(self.device)
86
 
87
+ # Set to eval mode for inference
88
+ self.whisper_model.eval()
 
89
 
90
+ logger.info(f"✓ Whisper loaded on {self.device}")
91
 
92
  except Exception as e:
93
  logger.error(f"Failed to load Whisper: {e}")
 
119
  # Add padding token
120
  self.chat_tokenizer.pad_token = self.chat_tokenizer.eos_token
121
 
122
+ # Set to eval mode
123
+ self.chat_model.eval()
124
+
125
+ logger.info(f"✓ Chat model loaded on {self.device}")
126
 
127
  except Exception as e:
128
  logger.error(f"Failed to load chat model: {e}")
 
148
  torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
149
  ).to(self.device)
150
 
151
+ # Set to eval mode
152
+ self.tts_model.eval()
153
+ self.vocoder.eval()
154
+
155
  # Load speaker embeddings for voice
156
  try:
157
  logger.info("Loading speaker embeddings dataset...")
 
215
  import librosa
216
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
217
 
218
+ # Trim silence and limit audio length for speed (max 30 seconds)
219
+ max_samples = 16000 * 30 # 30 seconds at 16kHz
220
+ if len(audio_data) > max_samples:
221
+ logger.warning(f"Audio trimmed from {len(audio_data)/16000:.1f}s to 30s")
222
+ audio_data = audio_data[:max_samples]
223
+
224
  if self.whisper_processor and hasattr(self.whisper_model, 'generate'):
225
  # Use loaded model
226
  input_features = self.whisper_processor(
 
229
  return_tensors="pt"
230
  ).input_features.to(self.device)
231
 
232
+ # Generate token ids - optimized for speed
233
  with torch.cuda.amp.autocast(enabled=self.device.type == "cuda"):
234
+ with torch.no_grad():
235
+ predicted_ids = self.whisper_model.generate(
236
+ input_features,
237
+ max_new_tokens=64, # Reduced for faster processing
238
+ num_beams=1, # Greedy decoding for speed
239
+ do_sample=False # Deterministic
240
+ )
241
 
242
  # Decode token ids to text
243
  transcription = self.whisper_processor.batch_decode(
 
284
  max_length=512
285
  ).to(self.device)
286
 
287
+ # Generate response - optimized for speed
288
  with torch.cuda.amp.autocast(enabled=self.device.type == "cuda"):
289
  with torch.no_grad():
290
  outputs = self.chat_model.generate(
291
  inputs,
292
+ max_new_tokens=50, # Shorter for faster response
293
  temperature=temperature,
294
  top_p=0.9,
295
+ do_sample=True if temperature > 0 else False,
296
  pad_token_id=self.chat_tokenizer.eos_token_id,
297
  eos_token_id=self.chat_tokenizer.eos_token_id,
298
+ num_beams=1 # Greedy for speed
299
  )
300
 
301
  # Decode response