PatienceIzere committed on
Commit
a6e99c9
·
verified ·
1 Parent(s): 535de7a

Update hf_transcriber.py

Browse files
Files changed (1) hide show
  1. hf_transcriber.py +25 -20
hf_transcriber.py CHANGED
@@ -59,24 +59,26 @@ class HFTranscriber:
59
  except Exception as e:
60
  raise Exception(f"Failed to load model {self.model_name}: {str(e)}")
61
 
62
- def transcribe_audio(self, audio_path: str) -> Tuple[List[int], int]:
63
  """
64
- Transcribe audio file to notes using the loaded Hugging Face model.
65
 
66
  Args:
67
- audio_path (str): Path to the audio file
 
68
 
69
  Returns:
70
- tuple: (notes, sample_rate) where notes is a list of MIDI note numbers
71
  """
72
  try:
73
- # Load and preprocess audio
74
- waveform, sample_rate = self._load_audio(audio_path)
75
-
 
76
  if self.is_speecht5:
77
  # Process the audio input for SpeechT5
78
  inputs = self.processor(
79
- audio=waveform,
80
  sampling_rate=sample_rate,
81
  return_tensors="pt"
82
  ).to(self.device)
@@ -91,10 +93,12 @@ class HFTranscriber:
91
 
92
  # Decode the generated ids to text
93
  transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
94
  else:
95
- # Process the audio input for wav2vec2
96
  inputs = self.processor(
97
- waveform,
98
  sampling_rate=sample_rate,
99
  return_tensors="pt",
100
  padding=True
@@ -105,18 +109,19 @@ class HFTranscriber:
105
  logits = self.model(inputs).logits
106
 
107
  # Get predicted token ids
108
- predicted_ids = torch.argmax(logits, dim=-1)
 
 
 
 
 
 
 
 
 
109
 
110
- # Decode the predicted ids to text
111
- transcription = self.processor.batch_decode(predicted_ids)[0]
112
-
113
- # Convert text to MIDI notes (simplified example)
114
- notes = self._text_to_midi_notes(transcription)
115
-
116
- return notes, sample_rate
117
-
118
  except Exception as e:
119
- raise Exception(f"Transcription failed: {str(e)}")
120
 
121
  def _text_to_midi_notes(self, text: str) -> List[int]:
122
  """Convert transcribed text to MIDI notes (simplified example)."""
 
59
  except Exception as e:
60
  raise Exception(f"Failed to load model {self.model_name}: {str(e)}")
61
 
62
+ def transcribe_audio(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
63
  """
64
+ Transcribe audio data to text using the loaded Hugging Face model.
65
 
66
  Args:
67
+ audio_array (np.ndarray): Audio data as a numpy array
68
+ sample_rate (int): Sample rate of the audio data
69
 
70
  Returns:
71
+ dict: Dictionary containing 'text' and optionally 'word_timestamps'
72
  """
73
  try:
74
+ # Convert to mono if needed
75
+ if len(audio_array.shape) > 1:
76
+ audio_array = librosa.to_mono(audio_array)
77
+
78
  if self.is_speecht5:
79
  # Process the audio input for SpeechT5
80
  inputs = self.processor(
81
+ audio=audio_array,
82
  sampling_rate=sample_rate,
83
  return_tensors="pt"
84
  ).to(self.device)
 
93
 
94
  # Decode the generated ids to text
95
  transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
96
+ return {'text': transcription}
97
+
98
  else:
99
+ # Process the audio input for wav2vec2/whisper
100
  inputs = self.processor(
101
+ audio_array,
102
  sampling_rate=sample_rate,
103
  return_tensors="pt",
104
  padding=True
 
109
  logits = self.model(inputs).logits
110
 
111
  # Get predicted token ids
112
+ pred_ids = torch.argmax(logits, dim=-1)
113
+
114
+ # Convert to text
115
+ transcription = self.processor.batch_decode(pred_ids)[0]
116
+
117
+ # Return the transcription text
118
+ return {
119
+ 'text': transcription,
120
+ 'word_timestamps': [] # Word-level timestamps not available in this basic implementation
121
+ }
122
 
 
 
 
 
 
 
 
 
123
  except Exception as e:
124
+ raise Exception(f"Error during transcription: {str(e)}")
125
 
126
  def _text_to_midi_notes(self, text: str) -> List[int]:
127
  """Convert transcribed text to MIDI notes (simplified example)."""