BissakaAI committed on
Commit
46b214b
·
verified ·
1 Parent(s): 761e783

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -25
app.py CHANGED
@@ -5,16 +5,12 @@ import librosa
5
  import numpy as np
6
  from transformers import AutoProcessor, SeamlessM4Tv2ForSpeechToText
7
 
8
- # ----------------------------
9
- # Config
10
- # ----------------------------
11
  ASR_MODEL_ID = "facebook/seamless-m4t-v2-large"
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
- # ----------------------------
16
- # Load Model
17
- # ----------------------------
18
  processor = AutoProcessor.from_pretrained(
19
  ASR_MODEL_ID,
20
  token=HF_TOKEN
@@ -27,14 +23,12 @@ asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
27
 
28
  asr_model.eval()
29
 
30
- # ----------------------------
31
- # Audio Handling (FIXED)
32
- # ----------------------------
33
  def preprocess_audio(audio):
34
  if audio is None:
35
  return None
36
 
37
- # Handle all Gradio formats safely
38
  if isinstance(audio, tuple):
39
  if isinstance(audio[0], np.ndarray):
40
  speech = audio[0]
@@ -45,11 +39,10 @@ def preprocess_audio(audio):
45
  else:
46
  return None
47
 
48
- # Convert stereo → mono
49
  if speech.ndim > 1:
50
  speech = np.mean(speech, axis=1)
51
 
52
- # Ensure float32
53
  speech = speech.astype(np.float32)
54
 
55
  # Force 16kHz
@@ -62,9 +55,8 @@ def preprocess_audio(audio):
62
 
63
  return speech
64
 
65
- # ----------------------------
66
- # ASR Function
67
- # ----------------------------
68
  def transcribe_audio(audio):
69
  speech = preprocess_audio(audio)
70
 
@@ -72,19 +64,14 @@ def transcribe_audio(audio):
72
  return "No audio provided."
73
 
74
  inputs = processor(
75
- audios=speech,
76
  sampling_rate=16000,
77
  return_tensors="pt"
78
  ).to(DEVICE)
79
 
80
- forced_decoder_ids = processor.get_decoder_prompt_ids(
81
- task="transcribe"
82
- )
83
-
84
  with torch.no_grad():
85
  generated_ids = asr_model.generate(
86
  inputs["input_features"],
87
- forced_decoder_ids=forced_decoder_ids,
88
  max_new_tokens=256
89
  )
90
 
@@ -95,15 +82,13 @@ def transcribe_audio(audio):
95
 
96
  return transcription.strip()
97
 
98
- # ----------------------------
99
- # Gradio UI
100
- # ----------------------------
101
  demo = gr.Interface(
102
  fn=transcribe_audio,
103
  inputs=gr.Audio(type="numpy", label="Upload or Record Speech"),
104
  outputs=gr.Textbox(label="Transcription"),
105
  title="HealthAtlas ASR Service",
106
- description="Automatic language detection | Emergency-safe"
107
  )
108
 
109
  if __name__ == "__main__":
 
5
  import numpy as np
6
  from transformers import AutoProcessor, SeamlessM4Tv2ForSpeechToText
7
 
8
+
 
 
9
  ASR_MODEL_ID = "facebook/seamless-m4t-v2-large"
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+
 
 
14
  processor = AutoProcessor.from_pretrained(
15
  ASR_MODEL_ID,
16
  token=HF_TOKEN
 
23
 
24
  asr_model.eval()
25
 
26
+ # Audio preprocessing
 
 
27
  def preprocess_audio(audio):
28
  if audio is None:
29
  return None
30
 
31
+ # Gradio returns (sr, np.ndarray) OR (np.ndarray, sr)
32
  if isinstance(audio, tuple):
33
  if isinstance(audio[0], np.ndarray):
34
  speech = audio[0]
 
39
  else:
40
  return None
41
 
42
+ # Stereo → mono
43
  if speech.ndim > 1:
44
  speech = np.mean(speech, axis=1)
45
 
 
46
  speech = speech.astype(np.float32)
47
 
48
  # Force 16kHz
 
55
 
56
  return speech
57
 
58
+
59
+ #ASR
 
60
  def transcribe_audio(audio):
61
  speech = preprocess_audio(audio)
62
 
 
64
  return "No audio provided."
65
 
66
  inputs = processor(
67
+ audio=speech,
68
  sampling_rate=16000,
69
  return_tensors="pt"
70
  ).to(DEVICE)
71
 
 
 
 
 
72
  with torch.no_grad():
73
  generated_ids = asr_model.generate(
74
  inputs["input_features"],
 
75
  max_new_tokens=256
76
  )
77
 
 
82
 
83
  return transcription.strip()
84
 
85
+
 
 
86
  demo = gr.Interface(
87
  fn=transcribe_audio,
88
  inputs=gr.Audio(type="numpy", label="Upload or Record Speech"),
89
  outputs=gr.Textbox(label="Transcription"),
90
  title="HealthAtlas ASR Service",
91
+ description="Automatic language detection (Seamless-M4T v2)"
92
  )
93
 
94
  if __name__ == "__main__":