BissakaAI committed on
Commit
c97244e
·
verified ·
1 Parent(s): ba7ca0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -8,17 +8,20 @@ ASR_MODEL_ID = "facebook/seamless-m4t-v2-large"
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
 
11
  processor = AutoProcessor.from_pretrained(
12
  ASR_MODEL_ID,
13
  token=HF_TOKEN
14
  )
15
 
 
16
  asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
17
  ASR_MODEL_ID,
18
  token=HF_TOKEN
19
  ).to(DEVICE)
20
 
21
  asr_model.eval()
 
22
 
23
  def transcribe_audio(audio):
24
  if audio is None:
@@ -26,8 +29,9 @@ def transcribe_audio(audio):
26
 
27
  speech, sr = audio
28
 
 
29
  if sr != 16000:
30
- speech = librosa.resample(speech, sr, 16000)
31
 
32
  inputs = processor(
33
  audios=speech,
@@ -35,15 +39,9 @@ def transcribe_audio(audio):
35
  return_tensors="pt"
36
  ).to(DEVICE)
37
 
38
- forced_decoder_ids = processor.get_decoder_prompt_ids(
39
- task="transcribe",
40
- language="eng"
41
- )
42
-
43
  with torch.no_grad():
44
  generated_ids = asr_model.generate(
45
  inputs["input_features"],
46
- forced_decoder_ids=forced_decoder_ids,
47
  max_new_tokens=256
48
  )
49
 
@@ -58,7 +56,8 @@ demo = gr.Interface(
58
  fn=transcribe_audio,
59
  inputs=gr.Audio(type="numpy", label="Upload Speech"),
60
  outputs=gr.Textbox(label="Transcription"),
61
- title="HealthAtlas ASR Service",
 
62
  )
63
 
64
  if __name__ == "__main__":
 
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ print("Loading ASR processor...")
12
  processor = AutoProcessor.from_pretrained(
13
  ASR_MODEL_ID,
14
  token=HF_TOKEN
15
  )
16
 
17
+ print("Loading ASR model...")
18
  asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
19
  ASR_MODEL_ID,
20
  token=HF_TOKEN
21
  ).to(DEVICE)
22
 
23
  asr_model.eval()
24
+ print("ASR model loaded successfully")
25
 
26
  def transcribe_audio(audio):
27
  if audio is None:
 
29
 
30
  speech, sr = audio
31
 
32
+ # Ensure 16kHz
33
  if sr != 16000:
34
+ speech = librosa.resample(speech, orig_sr=sr, target_sr=16000)
35
 
36
  inputs = processor(
37
  audios=speech,
 
39
  return_tensors="pt"
40
  ).to(DEVICE)
41
 
 
 
 
 
 
42
  with torch.no_grad():
43
  generated_ids = asr_model.generate(
44
  inputs["input_features"],
 
45
  max_new_tokens=256
46
  )
47
 
 
56
  fn=transcribe_audio,
57
  inputs=gr.Audio(type="numpy", label="Upload Speech"),
58
  outputs=gr.Textbox(label="Transcription"),
59
+ title="HealthAtlas ASR Service (Auto Language)",
60
+ description="Speech → Text with automatic language detection"
61
  )
62
 
63
  if __name__ == "__main__":