mds04 commited on
Commit
ff61ee9
·
verified ·
1 Parent(s): 5a3ca1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -46
app.py CHANGED
@@ -1,46 +1,52 @@
1
  import os, io
2
- import gradio as gr
3
  import numpy as np
4
  import soundfile as sf
 
5
  import torch
 
 
 
6
  from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
7
 
8
- # ---------- EDIT THESE TO YOUR ACTUAL MODEL IDS ----------
9
  MODEL_CATALOG = {
10
  "Iban (ASR)": {
11
- "repo_id": "mds04/iban_transcription", # <- change to exact name
12
  "language": "iban",
13
- "task": "transcribe",
14
  },
15
  "Bukar Sadong (ASR)": {
16
- "repo_id": "mds04/bukar_sadong_transcription", # <- change to exact name
17
  "language": "bukar-sadong",
18
- "task": "transcribe",
19
  },
20
  }
21
- # ---------------------------------------------------------
 
 
 
 
22
 
23
- # Private model(s)? Add Space secret HF_TOKEN (Settings → Secrets)
24
  HF_TOKEN = os.getenv("HF_TOKEN", None)
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  dtype = torch.float16 if device == "cuda" else torch.float32
28
 
29
- # Lazy loaders so we don't load both models at startup
30
- _MODEL_CACHE = {}
31
 
32
- def _load_bundle(key):
33
- if key in _MODEL_CACHE:
34
- return _MODEL_CACHE[key]
35
- info = MODEL_CATALOG[key]
36
  proc = AutoProcessor.from_pretrained(info["repo_id"], token=HF_TOKEN)
37
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
38
  info["repo_id"], token=HF_TOKEN, torch_dtype=dtype
39
  ).to(device).eval()
40
- _MODEL_CACHE[key] = (proc, model)
41
- return _MODEL_CACHE[key]
42
 
43
- def _resample_to_16k(x, sr):
 
44
  if sr == 16000:
45
  return x.astype(np.float32)
46
  duration = x.shape[0] / sr
@@ -48,62 +54,103 @@ def _resample_to_16k(x, sr):
48
  t_new = np.linspace(0.0, duration, num=int(duration * 16000), endpoint=False)
49
  return np.interp(t_new, t_old, x).astype(np.float32)
50
 
51
- def _load_audio_16k(path: str) -> np.ndarray:
52
- with open(path, "rb") as f:
53
- data, sr = sf.read(io.BytesIO(f.read()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  if data.ndim == 2:
55
- data = data.mean(axis=1) # stereo -> mono
56
  return _resample_to_16k(data, sr)
57
 
58
- def transcribe(model_choice, audio_path, force_lang, task_choice, max_tokens):
59
- if not audio_path:
 
 
 
 
 
 
 
 
 
60
  return "Please upload or record audio."
 
61
  processor, model = _load_bundle(model_choice)
62
- audio = _load_audio_16k(audio_path)
63
  inputs = processor(audio=audio, sampling_rate=16000, return_tensors="pt")
64
  inputs = {k: v.to(device) for k, v in inputs.items()}
65
 
66
  gen_kwargs = dict(max_new_tokens=int(max_tokens), do_sample=False)
67
 
68
- # Try to force language/task if supported (Whisper-style)
69
- if force_lang or task_choice:
70
- if hasattr(processor, "get_decoder_prompt_ids"):
71
- lang = MODEL_CATALOG[model_choice]["language"] if force_lang else None
72
- try:
73
- gen_kwargs["forced_decoder_ids"] = processor.get_decoder_prompt_ids(
74
- language=lang,
75
- task=task_choice
76
- )
77
- except Exception:
78
- pass
79
 
80
  with torch.no_grad():
81
  ids = model.generate(**inputs, **gen_kwargs)
82
 
83
- text = processor.batch_decode(ids, skip_special_tokens=True)[0]
84
- return text
85
 
86
  with gr.Blocks(title="Iban & Bukar Sadong ASR") as demo:
87
  gr.Markdown("## Iban & Bukar Sadong Transcription\nSelect a model, then upload or record audio.")
 
88
  with gr.Row():
89
  model_choice = gr.Dropdown(
90
  choices=list(MODEL_CATALOG.keys()),
91
- value=list(MODEL_CATALOG.keys())[0],
92
  label="Model"
93
  )
 
94
  with gr.Row():
95
- audio_in = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio")
 
 
96
  with gr.Row():
97
- force_lang = gr.Checkbox(value=True, label="Force model’s language prompt")
98
- task_choice = gr.Dropdown(choices=["transcribe","translate"], value="transcribe", label="Task")
99
- max_tokens = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
 
 
 
 
 
 
 
100
 
101
- btn = gr.Button("Transcribe")
102
  out = gr.Textbox(label="Transcription", lines=4)
 
103
 
 
104
  btn.click(
105
- transcribe,
106
- inputs=[model_choice, audio_in, force_lang, task_choice, max_tokens],
107
  outputs=out
108
  )
109
 
 
1
  import os, io
 
2
  import numpy as np
3
  import soundfile as sf
4
+ import requests
5
  import torch
6
+ import gradio as gr
7
+ import spaces # <-- needed for GPU Zero
8
+
9
  from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
10
 
11
+ # --------------------- CONFIG: EDIT THESE ---------------------
12
  MODEL_CATALOG = {
13
  "Iban (ASR)": {
14
+ "repo_id": "mds04/iban_transcription", # <-- exact repo id
15
  "language": "iban",
 
16
  },
17
  "Bukar Sadong (ASR)": {
18
+ "repo_id": "mds04/bukar_sadong_transcription", # <-- exact repo id
19
  "language": "bukar-sadong",
 
20
  },
21
  }
22
+ DEFAULT_MODEL = "Iban (ASR)"
23
+ DEFAULT_TASK = "transcribe" # or "translate" if your model supports it
24
+ DEFAULT_FORCE_LANG = True
25
+ DEFAULT_MAX_TOKENS = 256
26
+ # --------------------------------------------------------------
27
 
28
+ # Private model(s)? Add Space Secret: HF_TOKEN
29
  HF_TOKEN = os.getenv("HF_TOKEN", None)
30
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  dtype = torch.float16 if device == "cuda" else torch.float32
33
 
34
+ # Lazy cache to avoid loading both models at startup
35
+ _MODEL_CACHE: dict[str, tuple[AutoProcessor, AutoModelForSpeechSeq2Seq]] = {}
36
 
37
+ def _load_bundle(model_key: str):
38
+ if model_key in _MODEL_CACHE:
39
+ return _MODEL_CACHE[model_key]
40
+ info = MODEL_CATALOG[model_key]
41
  proc = AutoProcessor.from_pretrained(info["repo_id"], token=HF_TOKEN)
42
+ mdl = AutoModelForSpeechSeq2Seq.from_pretrained(
43
  info["repo_id"], token=HF_TOKEN, torch_dtype=dtype
44
  ).to(device).eval()
45
+ _MODEL_CACHE[model_key] = (proc, mdl)
46
+ return _MODEL_CACHE[model_key]
47
 
48
+ def _resample_to_16k(x: np.ndarray, sr: int) -> np.ndarray:
49
+ """Naive linear resampler to 16k (no librosa)."""
50
  if sr == 16000:
51
  return x.astype(np.float32)
52
  duration = x.shape[0] / sr
 
54
  t_new = np.linspace(0.0, duration, num=int(duration * 16000), endpoint=False)
55
  return np.interp(t_new, t_old, x).astype(np.float32)
56
 
57
+ def _read_audio_bytes(path_or_url: str) -> bytes:
58
+ if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
59
+ r = requests.get(path_or_url, timeout=30)
60
+ r.raise_for_status()
61
+ return r.content
62
+ with open(path_or_url, "rb") as f:
63
+ return f.read()
64
+
65
+ def _load_audio_16k(input_obj) -> np.ndarray:
66
+ """
67
+ Accepts:
68
+ - str filepath,
69
+ - dict from Gradio v4 with {'path': <url or filepath>, 'meta': {...}}
70
+ Returns mono float32 @ 16k
71
+ """
72
+ if isinstance(input_obj, dict) and "path" in input_obj:
73
+ path_or_url = input_obj["path"]
74
+ elif isinstance(input_obj, str):
75
+ path_or_url = input_obj
76
+ else:
77
+ raise ValueError("Unsupported audio input format")
78
+
79
+ raw = _read_audio_bytes(path_or_url)
80
+ data, sr = sf.read(io.BytesIO(raw))
81
  if data.ndim == 2:
82
+ data = data.mean(axis=1)
83
  return _resample_to_16k(data, sr)
84
 
85
+ # --------- IMPORTANT FOR GPU ZERO: decorate the main handler ----------
86
+ @spaces.GPU # <- tells Space to allocate GPU for this function
87
+ def transcribe(model_choice, audio_input, task_choice, force_lang, max_tokens):
88
+ """
89
+ model_choice: str (dropdown)
90
+ audio_input: filepath or dict with 'path'
91
+ task_choice: "transcribe" | "translate"
92
+ force_lang: bool
93
+ max_tokens: int (slider)
94
+ """
95
+ if not audio_input:
96
  return "Please upload or record audio."
97
+
98
  processor, model = _load_bundle(model_choice)
99
+ audio = _load_audio_16k(audio_input)
100
  inputs = processor(audio=audio, sampling_rate=16000, return_tensors="pt")
101
  inputs = {k: v.to(device) for k, v in inputs.items()}
102
 
103
  gen_kwargs = dict(max_new_tokens=int(max_tokens), do_sample=False)
104
 
105
+ # Force language (Whisper-style) if available and requested
106
+ if force_lang and hasattr(processor, "get_decoder_prompt_ids"):
107
+ try:
108
+ lang = MODEL_CATALOG[model_choice]["language"]
109
+ gen_kwargs["forced_decoder_ids"] = processor.get_decoder_prompt_ids(
110
+ language=lang, task=task_choice
111
+ )
112
+ except Exception:
113
+ pass
 
 
114
 
115
  with torch.no_grad():
116
  ids = model.generate(**inputs, **gen_kwargs)
117
 
118
+ return processor.batch_decode(ids, skip_special_tokens=True)[0]
119
+ # ---------------------------------------------------------------------
120
 
121
  with gr.Blocks(title="Iban & Bukar Sadong ASR") as demo:
122
  gr.Markdown("## Iban & Bukar Sadong Transcription\nSelect a model, then upload or record audio.")
123
+
124
  with gr.Row():
125
  model_choice = gr.Dropdown(
126
  choices=list(MODEL_CATALOG.keys()),
127
+ value=DEFAULT_MODEL,
128
  label="Model"
129
  )
130
+
131
  with gr.Row():
132
+ # Use type="filepath" so we get a path; code also supports remote URLs
133
+ audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
134
+
135
  with gr.Row():
136
+ task_choice = gr.Dropdown(
137
+ choices=["transcribe", "translate"],
138
+ value=DEFAULT_TASK,
139
+ label="Task"
140
+ )
141
+ force_lang = gr.Checkbox(
142
+ value=DEFAULT_FORCE_LANG,
143
+ label="Force model’s language prompt"
144
+ )
145
+ max_tokens = gr.Slider(64, 512, value=DEFAULT_MAX_TOKENS, step=16, label="Max new tokens")
146
 
 
147
  out = gr.Textbox(label="Transcription", lines=4)
148
+ btn = gr.Button("Transcribe")
149
 
150
+ # IMPORTANT: inputs here must match the function signature order
151
  btn.click(
152
+ fn=transcribe,
153
+ inputs=[model_choice, audio_in, task_choice, force_lang, max_tokens],
154
  outputs=out
155
  )
156