BissakaAI commited on
Commit
761e783
·
verified ·
1 Parent(s): 0e51547

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -26
app.py CHANGED
@@ -13,43 +13,46 @@ HF_TOKEN = os.getenv("HF_TOKEN")
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  # ----------------------------
16
- # Load Processor & Model
17
  # ----------------------------
18
- print("🔹 Loading ASR processor...")
19
  processor = AutoProcessor.from_pretrained(
20
  ASR_MODEL_ID,
21
  token=HF_TOKEN
22
  )
23
 
24
- print("🔹 Loading ASR model...")
25
  asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
26
  ASR_MODEL_ID,
27
  token=HF_TOKEN
28
  ).to(DEVICE)
29
 
30
  asr_model.eval()
31
- print("✅ ASR model loaded successfully")
32
 
33
  # ----------------------------
34
- # Audio Preprocessing
35
  # ----------------------------
36
  def preprocess_audio(audio):
37
- """
38
- Ensures mono audio at 16kHz (required by SeamlessM4T)
39
- """
40
  if audio is None:
41
  return None
42
 
43
- speech, sr = audio # (numpy array, sample rate)
 
 
 
 
 
 
 
 
 
44
 
45
- # Convert stereo to mono
46
  if speech.ndim > 1:
47
  speech = np.mean(speech, axis=1)
48
 
49
- # Force float32
50
- speech = speech.astype("float32")
51
 
52
- # Resample to 16kHz if needed
53
  if sr != 16000:
54
  speech = librosa.resample(
55
  speech,
@@ -74,7 +77,6 @@ def transcribe_audio(audio):
74
  return_tensors="pt"
75
  ).to(DEVICE)
76
 
77
- # Auto language detection (no hardcoding)
78
  forced_decoder_ids = processor.get_decoder_prompt_ids(
79
  task="transcribe"
80
  )
@@ -94,23 +96,15 @@ def transcribe_audio(audio):
94
  return transcription.strip()
95
 
96
  # ----------------------------
97
- # Gradio Interface
98
  # ----------------------------
99
  demo = gr.Interface(
100
  fn=transcribe_audio,
101
- inputs=gr.Audio(
102
- type="numpy",
103
- label="Upload or Record Speech"
104
- ),
105
- outputs=gr.Textbox(
106
- label="Transcription"
107
- ),
108
  title="HealthAtlas ASR Service",
109
- description="Speech → Text | Automatic language detection | Emergency-safe"
110
  )
111
 
112
- # ----------------------------
113
- # Launch (REQUIRED)
114
- # ----------------------------
115
  if __name__ == "__main__":
116
  demo.launch()
 
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  # ----------------------------
16
+ # Load Model
17
  # ----------------------------
 
18
  processor = AutoProcessor.from_pretrained(
19
  ASR_MODEL_ID,
20
  token=HF_TOKEN
21
  )
22
 
 
23
  asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
24
  ASR_MODEL_ID,
25
  token=HF_TOKEN
26
  ).to(DEVICE)
27
 
28
  asr_model.eval()
 
29
 
30
  # ----------------------------
31
+ # Audio Handling (FIXED)
32
  # ----------------------------
33
  def preprocess_audio(audio):
 
 
 
34
  if audio is None:
35
  return None
36
 
37
+ # Handle all Gradio formats safely
38
+ if isinstance(audio, tuple):
39
+ if isinstance(audio[0], np.ndarray):
40
+ speech = audio[0]
41
+ sr = audio[1]
42
+ else:
43
+ sr = audio[0]
44
+ speech = audio[1]
45
+ else:
46
+ return None
47
 
48
+ # Convert stereo mono
49
  if speech.ndim > 1:
50
  speech = np.mean(speech, axis=1)
51
 
52
+ # Ensure float32
53
+ speech = speech.astype(np.float32)
54
 
55
+ # Force 16kHz
56
  if sr != 16000:
57
  speech = librosa.resample(
58
  speech,
 
77
  return_tensors="pt"
78
  ).to(DEVICE)
79
 
 
80
  forced_decoder_ids = processor.get_decoder_prompt_ids(
81
  task="transcribe"
82
  )
 
96
  return transcription.strip()
97
 
98
  # ----------------------------
99
+ # Gradio UI
100
  # ----------------------------
101
  demo = gr.Interface(
102
  fn=transcribe_audio,
103
+ inputs=gr.Audio(type="numpy", label="Upload or Record Speech"),
104
+ outputs=gr.Textbox(label="Transcription"),
 
 
 
 
 
105
  title="HealthAtlas ASR Service",
106
+ description="Automatic language detection | Emergency-safe"
107
  )
108
 
 
 
 
109
  if __name__ == "__main__":
110
  demo.launch()