David-Chew-HL commited on
Commit
fa1c2c0
·
verified ·
1 Parent(s): c8884b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -20
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  import tempfile
3
  from pathlib import Path
4
 
@@ -11,7 +13,7 @@ MODEL_NAME = "Qwen/Qwen3-ASR-1.7B"
11
  LANG_MAP = {
12
  "English": "English",
13
  "Chinese": "Chinese",
14
- "Bilingual": None, # auto-detect mixed English + Mandarin
15
  }
16
 
17
  device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -21,40 +23,94 @@ model = Qwen3ASRModel.from_pretrained(
21
  MODEL_NAME,
22
  dtype=dtype,
23
  device_map=device_map,
24
- max_inference_batch_size=1,
25
- max_new_tokens=1024,
26
  )
27
 
28
- def transcribe(audio_path: str, mode: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if not audio_path:
30
  raise gr.Error("Please upload an audio file.")
31
 
32
  if mode not in LANG_MAP:
33
  raise gr.Error("Invalid mode selected.")
34
 
35
- language = LANG_MAP[mode]
 
36
 
37
- result = model.transcribe(
38
- audio=audio_path,
39
- language=language,
40
- )[0]
41
 
42
- text = result.text.strip()
 
43
 
44
- if not text:
45
- text = ""
 
 
46
 
47
- out_dir = Path(tempfile.mkdtemp())
48
- txt_path = out_dir / "transcript.txt"
49
- txt_path.write_text(text, encoding="utf-8")
 
 
 
 
50
 
51
- detected_language = getattr(result, "language", None)
 
52
 
53
- meta = f"Mode: {mode}"
54
- if detected_language:
55
- meta += f"\nDetected language: {detected_language}"
 
 
 
56
 
57
- return text, str(txt_path), meta
58
 
59
  with gr.Blocks(title="Qwen3 ASR Transcriber") as demo:
60
  gr.Markdown("# Qwen3 ASR Transcriber")
 
1
  import os
2
+ import shutil
3
+ import subprocess
4
  import tempfile
5
  from pathlib import Path
6
 
 
13
  LANG_MAP = {
14
  "English": "English",
15
  "Chinese": "Chinese",
16
+ "Bilingual": None, # let Qwen auto-detect
17
  }
18
 
19
  device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
 
23
  MODEL_NAME,
24
  dtype=dtype,
25
  device_map=device_map,
26
+ max_inference_batch_size=1
 
27
  )
28
 
29
+
30
+ def normalize_audio(input_path: str, progress: gr.Progress | None = None) -> str:
31
+ """
32
+ Convert uploaded audio to mono 16k WAV.
33
+ No silence trimming. No noise reduction.
34
+ """
35
+ if progress:
36
+ progress(0.15, desc="Preparing audio...")
37
+
38
+ if shutil.which("ffmpeg") is None:
39
+ raise gr.Error("ffmpeg is not installed in this environment.")
40
+
41
+ out_dir = Path(tempfile.mkdtemp())
42
+ out_path = out_dir / "normalized.wav"
43
+
44
+ cmd = [
45
+ "ffmpeg",
46
+ "-y",
47
+ "-i", input_path,
48
+ "-ac", "1", # mono
49
+ "-ar", "16000", # 16 kHz
50
+ "-vn",
51
+ str(out_path),
52
+ ]
53
+
54
+ try:
55
+ subprocess.run(
56
+ cmd,
57
+ check=True,
58
+ stdout=subprocess.DEVNULL,
59
+ stderr=subprocess.DEVNULL,
60
+ )
61
+ except subprocess.CalledProcessError:
62
+ raise gr.Error("Failed to process the uploaded audio file.")
63
+
64
+ return str(out_path)
65
+
66
+
67
+ def make_output_txt(text: str, original_audio_path: str) -> str:
68
+ out_dir = Path(tempfile.mkdtemp())
69
+ stem = Path(original_audio_path).stem or "transcript"
70
+ txt_path = out_dir / f"{stem}.txt"
71
+ txt_path.write_text(text, encoding="utf-8")
72
+ return str(txt_path)
73
+
74
+
75
+ def transcribe(audio_path: str, mode: str, progress=gr.Progress()):
76
  if not audio_path:
77
  raise gr.Error("Please upload an audio file.")
78
 
79
  if mode not in LANG_MAP:
80
  raise gr.Error("Invalid mode selected.")
81
 
82
+ progress(0.05, desc="Starting...")
83
+ normalized_path = None
84
 
85
+ try:
86
+ normalized_path = normalize_audio(audio_path, progress=progress)
 
 
87
 
88
+ progress(0.45, desc="Running transcription...")
89
+ language = LANG_MAP[mode]
90
 
91
+ result = model.transcribe(
92
+ audio=normalized_path,
93
+ language=language,
94
+ )[0]
95
 
96
+ text = (result.text or "").strip()
97
+ txt_path = make_output_txt(text, audio_path)
98
+
99
+ detected_language = getattr(result, "language", None)
100
+ info = f"Mode: {mode}"
101
+ if detected_language:
102
+ info += f"\nDetected language: {detected_language}"
103
 
104
+ progress(1.0, desc="Done")
105
+ return text, txt_path, info
106
 
107
+ finally:
108
+ if normalized_path and os.path.exists(normalized_path):
109
+ try:
110
+ os.remove(normalized_path)
111
+ except OSError:
112
+ pass
113
 
 
114
 
115
  with gr.Blocks(title="Qwen3 ASR Transcriber") as demo:
116
  gr.Markdown("# Qwen3 ASR Transcriber")