BissakaAI commited on
Commit
0e51547
·
verified ·
1 Parent(s): c97244e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -12
app.py CHANGED
@@ -2,36 +2,71 @@ import os
2
  import torch
3
  import gradio as gr
4
  import librosa
 
5
  from transformers import AutoProcessor, SeamlessM4Tv2ForSpeechToText
6
 
 
 
 
7
  ASR_MODEL_ID = "facebook/seamless-m4t-v2-large"
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- print("Loading ASR processor...")
 
 
 
12
  processor = AutoProcessor.from_pretrained(
13
  ASR_MODEL_ID,
14
  token=HF_TOKEN
15
  )
16
 
17
- print("Loading ASR model...")
18
  asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
19
  ASR_MODEL_ID,
20
  token=HF_TOKEN
21
  ).to(DEVICE)
22
 
23
  asr_model.eval()
24
- print("ASR model loaded successfully")
25
 
26
- def transcribe_audio(audio):
 
 
 
 
 
 
27
  if audio is None:
28
- return "No audio provided."
 
 
29
 
30
- speech, sr = audio
 
 
31
 
32
- # Ensure 16kHz
 
 
 
33
  if sr != 16000:
34
- speech = librosa.resample(speech, orig_sr=sr, target_sr=16000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  inputs = processor(
37
  audios=speech,
@@ -39,9 +74,15 @@ def transcribe_audio(audio):
39
  return_tensors="pt"
40
  ).to(DEVICE)
41
 
 
 
 
 
 
42
  with torch.no_grad():
43
  generated_ids = asr_model.generate(
44
  inputs["input_features"],
 
45
  max_new_tokens=256
46
  )
47
 
@@ -52,13 +93,24 @@ def transcribe_audio(audio):
52
 
53
  return transcription.strip()
54
 
 
 
 
55
  demo = gr.Interface(
56
  fn=transcribe_audio,
57
- inputs=gr.Audio(type="numpy", label="Upload Speech"),
58
- outputs=gr.Textbox(label="Transcription"),
59
- title="HealthAtlas ASR Service (Auto Language)",
60
- description="Speech → Text with automatic language detection"
 
 
 
 
 
61
  )
62
 
 
 
 
63
  if __name__ == "__main__":
64
  demo.launch()
 
2
  import torch
3
  import gradio as gr
4
  import librosa
5
+ import numpy as np
6
  from transformers import AutoProcessor, SeamlessM4Tv2ForSpeechToText
7
 
8
+ # ----------------------------
9
+ # Config
10
+ # ----------------------------
11
  ASR_MODEL_ID = "facebook/seamless-m4t-v2-large"
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
+ # ----------------------------
16
+ # Load Processor & Model
17
+ # ----------------------------
18
+ print("🔹 Loading ASR processor...")
19
  processor = AutoProcessor.from_pretrained(
20
  ASR_MODEL_ID,
21
  token=HF_TOKEN
22
  )
23
 
24
+ print("🔹 Loading ASR model...")
25
  asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
26
  ASR_MODEL_ID,
27
  token=HF_TOKEN
28
  ).to(DEVICE)
29
 
30
  asr_model.eval()
31
+ print("✅ ASR model loaded successfully")
32
 
33
+ # ----------------------------
34
+ # Audio Preprocessing
35
+ # ----------------------------
36
+ def preprocess_audio(audio):
37
+ """
38
+ Ensures mono audio at 16kHz (required by SeamlessM4T)
39
+ """
40
  if audio is None:
41
+ return None
42
+
43
+ speech, sr = audio # (numpy array, sample rate)
44
 
45
+ # Convert stereo to mono
46
+ if speech.ndim > 1:
47
+ speech = np.mean(speech, axis=1)
48
 
49
+ # Force float32
50
+ speech = speech.astype("float32")
51
+
52
+ # Resample to 16kHz if needed
53
  if sr != 16000:
54
+ speech = librosa.resample(
55
+ speech,
56
+ orig_sr=sr,
57
+ target_sr=16000
58
+ )
59
+
60
+ return speech
61
+
62
+ # ----------------------------
63
+ # ASR Function
64
+ # ----------------------------
65
+ def transcribe_audio(audio):
66
+ speech = preprocess_audio(audio)
67
+
68
+ if speech is None or len(speech) == 0:
69
+ return "No audio provided."
70
 
71
  inputs = processor(
72
  audios=speech,
 
74
  return_tensors="pt"
75
  ).to(DEVICE)
76
 
77
+ # Auto language detection (no hardcoding)
78
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
79
+ task="transcribe"
80
+ )
81
+
82
  with torch.no_grad():
83
  generated_ids = asr_model.generate(
84
  inputs["input_features"],
85
+ forced_decoder_ids=forced_decoder_ids,
86
  max_new_tokens=256
87
  )
88
 
 
93
 
94
  return transcription.strip()
95
 
96
+ # ----------------------------
97
+ # Gradio Interface
98
+ # ----------------------------
99
  demo = gr.Interface(
100
  fn=transcribe_audio,
101
+ inputs=gr.Audio(
102
+ type="numpy",
103
+ label="Upload or Record Speech"
104
+ ),
105
+ outputs=gr.Textbox(
106
+ label="Transcription"
107
+ ),
108
+ title="HealthAtlas ASR Service",
109
+ description="Speech → Text | Automatic language detection | Emergency-safe"
110
  )
111
 
112
+ # ----------------------------
113
+ # Launch (REQUIRED)
114
+ # ----------------------------
115
  if __name__ == "__main__":
116
  demo.launch()