BissakaAI committed on
Commit
624a6c7
·
verified ·
1 Parent(s): d76d28d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -16
app.py CHANGED
@@ -2,18 +2,13 @@ import os
2
  import torch
3
  import gradio as gr
4
  import librosa
5
- from transformers import (
6
- AutoProcessor,
7
- SeamlessM4Tv2ForSpeechToText
8
- )
9
-
10
 
11
  ASR_MODEL_ID = "facebook/seamless-m4t-v2-large"
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
-
16
- print("Loading ASR processor...")
17
  processor = AutoProcessor.from_pretrained(
18
  ASR_MODEL_ID,
19
  token=HF_TOKEN
@@ -26,13 +21,12 @@ asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
26
  ).to(DEVICE)
27
 
28
  asr_model.eval()
29
- print("ASR model loaded successfully")
30
 
31
  def transcribe_audio(audio_path):
32
  if audio_path is None:
33
  return "No audio provided."
34
 
35
- # Load audio
36
  speech, sr = librosa.load(audio_path, sr=16000)
37
 
38
  inputs = processor(
@@ -41,27 +35,31 @@ def transcribe_audio(audio_path):
41
  return_tensors="pt"
42
  ).to(DEVICE)
43
 
 
 
 
 
 
44
  with torch.no_grad():
45
- predicted_ids = asr_model.generate(
46
- **inputs,
47
- task="transcribe",
48
- max_new_tokens=300
49
  )
50
 
51
  transcription = processor.batch_decode(
52
- predicted_ids,
53
  skip_special_tokens=True
54
  )[0]
55
 
56
  return transcription.strip()
57
 
58
-
59
  demo = gr.Interface(
60
  fn=transcribe_audio,
61
  inputs=gr.Audio(type="filepath", label="Upload Speech"),
62
  outputs=gr.Textbox(label="Transcription"),
63
  title="HealthAtlas ASR Service",
64
- description="Speech → Text using SeamlessM4T v2"
65
  )
66
 
67
  if __name__ == "__main__":
 
2
  import torch
3
  import gradio as gr
4
  import librosa
5
+ from transformers import AutoProcessor, SeamlessM4Tv2ForSpeechToText
 
 
 
 
6
 
7
  ASR_MODEL_ID = "facebook/seamless-m4t-v2-large"
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ print("🔹 Loading processor...")
 
12
  processor = AutoProcessor.from_pretrained(
13
  ASR_MODEL_ID,
14
  token=HF_TOKEN
 
21
  ).to(DEVICE)
22
 
23
  asr_model.eval()
24
+ print("✅ ASR model loaded")
25
 
26
  def transcribe_audio(audio_path):
27
  if audio_path is None:
28
  return "No audio provided."
29
 
 
30
  speech, sr = librosa.load(audio_path, sr=16000)
31
 
32
  inputs = processor(
 
35
  return_tensors="pt"
36
  ).to(DEVICE)
37
 
38
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
39
+ task="transcribe",
40
+ language="eng"
41
+ )
42
+
43
  with torch.no_grad():
44
+ generated_ids = asr_model.generate(
45
+ inputs["input_features"],
46
+ forced_decoder_ids=forced_decoder_ids,
47
+ max_new_tokens=256
48
  )
49
 
50
  transcription = processor.batch_decode(
51
+ generated_ids,
52
  skip_special_tokens=True
53
  )[0]
54
 
55
  return transcription.strip()
56
 
 
57
  demo = gr.Interface(
58
  fn=transcribe_audio,
59
  inputs=gr.Audio(type="filepath", label="Upload Speech"),
60
  outputs=gr.Textbox(label="Transcription"),
61
  title="HealthAtlas ASR Service",
62
+ description="Speech → Text (SeamlessM4T v2)"
63
  )
64
 
65
  if __name__ == "__main__":