legolasyiu committed on
Commit
cf972e4
·
verified ·
1 Parent(s): 505b4c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -20
app.py CHANGED
@@ -8,34 +8,63 @@ from transformers import AutoProcessor, TextIteratorStreamer
8
  from threading import Thread
9
 
10
  TARGET_SAMPLING_RATE = 16000
11
- print("Loading model and processor...")
 
 
 
12
  processor = AutoProcessor.from_pretrained('EpistemeAI/Audiogemma-3N-finetune')
 
 
 
13
  model, _ = FastModel.from_pretrained(
14
  model_name='EpistemeAI/Audiogemma-3N-finetune',
15
  max_seq_length=512,
16
  load_in_4bit=True,
17
  dtype=torch.bfloat16,
18
  )
19
- print("Model and processor loaded successfully.")
 
 
 
 
 
 
 
20
 
21
  def transcribe_and_translate(audio_input):
22
  """
23
- This function takes audio data from the Gradio component, processes it,
24
- and then streams the model's transcription and translation back to the UI.
25
  """
26
  if audio_input is None:
27
  yield "Error: Please upload or record a German audio file first."
28
  return
29
- sample_rate, audio_array = audio_input
 
 
 
 
 
 
 
 
 
 
 
 
30
  if audio_array.ndim > 1:
31
  audio_array = audio_array.mean(axis=1)
 
32
  audio_array = audio_array.astype(np.float32)
 
 
33
  if sample_rate != TARGET_SAMPLING_RATE:
34
  audio_array = librosa.resample(
35
  y=audio_array,
36
  orig_sr=sample_rate,
37
  target_sr=TARGET_SAMPLING_RATE
38
  )
 
39
  messages = [
40
  {
41
  'role': 'system',
@@ -50,34 +79,77 @@ def transcribe_and_translate(audio_input):
50
  'role': 'user',
51
  'content': [
52
  {'type': 'audio', 'audio': audio_array},
53
- {'type': 'text', 'text': 'Please transcribe this audio and translate it to English. Give both, the transcription and the translation.'}
54
  ]
55
  }
56
  ]
 
 
57
  inputs = processor.apply_chat_template(
58
  messages,
59
  add_generation_prompt=True,
60
  tokenize=True,
61
  return_dict=True,
62
  return_tensors='pt'
63
- ).to('cuda', dtype=torch.bfloat16)
64
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
65
- generation_kwargs = dict(
66
- **inputs,
67
- streamer=streamer,
68
- max_new_tokens=1024,
69
- do_sample=False,
70
  )
71
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  thread.start()
 
 
73
  output_text = ""
74
- for new_text in streamer:
75
- output_text += new_text
76
- yield output_text
 
 
 
 
 
 
 
77
 
78
- # Grab all wav files in the directory
79
  example_audios = glob.glob('test_wav_files/*.wav')
80
- example_list = [ for audio in example_audios]
 
81
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
82
  gr.Markdown(
83
  """
@@ -90,11 +162,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
90
  audio_input = gr.Audio(sources=["upload", "microphone"], type="numpy", label="German Audio")
91
  text_output = gr.Textbox(label="Transcription and Translation", lines=10, interactive=False)
92
  submit_btn = gr.Button("Transcribe and Translate", variant="primary")
 
 
93
  submit_btn.click(
94
  fn=transcribe_and_translate,
95
  inputs=audio_input,
96
  outputs=text_output
97
  )
 
98
  gr.Examples(
99
  examples=example_list,
100
  inputs=audio_input,
@@ -102,5 +177,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
102
  fn=transcribe_and_translate,
103
  cache_examples=False
104
  )
 
105
  if __name__ == "__main__":
106
- demo.launch(share=True)
 
from threading import Thread

# Target sample rate expected by the audio model's feature extractor.
TARGET_SAMPLING_RATE = 16000

# Prefer the GPU when one is visible; this single device handle is reused for
# input tensors later in the app.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading processor and model...")
processor = AutoProcessor.from_pretrained('EpistemeAI/Audiogemma-3N-finetune')

# FastModel.from_pretrained returns a pair; only the model is used here.
# NOTE(review): load_in_4bit together with dtype=torch.bfloat16 depends on the
# FastModel implementation and runtime environment — confirm on the target box.
model, _ = FastModel.from_pretrained(
    model_name='EpistemeAI/Audiogemma-3N-finetune',
    max_seq_length=512,
    load_in_4bit=True,
    dtype=torch.bfloat16,
)

# Some FastModel wrappers manage device placement themselves (e.g. device_map)
# and may not support .to(). Keep the move best-effort, but report why it was
# skipped instead of silently swallowing the error.
try:
    model.to(device)
except Exception as exc:
    print(f"Skipping explicit model.to({device}): {exc}")

print("Model and processor loaded successfully. Device:", device)
34
  def transcribe_and_translate(audio_input):
35
  """
36
+ Generator function for Gradio streaming. Yields progressive output text.
37
+ audio_input from gr.Audio(type="numpy") is (sample_rate, np_array)
38
  """
39
  if audio_input is None:
40
  yield "Error: Please upload or record a German audio file first."
41
  return
42
+
43
+ # Unpack Gradio audio tuple
44
+ try:
45
+ sample_rate, audio_array = audio_input
46
+ except Exception:
47
+ # If Gradio returns just the numpy array sometimes, handle that
48
+ audio_array = audio_input
49
+ sample_rate = TARGET_SAMPLING_RATE
50
+
51
+ # Mono conversion
52
+ if audio_array is None:
53
+ yield "Error: audio data is empty."
54
+ return
55
  if audio_array.ndim > 1:
56
  audio_array = audio_array.mean(axis=1)
57
+
58
  audio_array = audio_array.astype(np.float32)
59
+
60
+ # Resample if needed
61
  if sample_rate != TARGET_SAMPLING_RATE:
62
  audio_array = librosa.resample(
63
  y=audio_array,
64
  orig_sr=sample_rate,
65
  target_sr=TARGET_SAMPLING_RATE
66
  )
67
+
68
  messages = [
69
  {
70
  'role': 'system',
 
79
  'role': 'user',
80
  'content': [
81
  {'type': 'audio', 'audio': audio_array},
82
+ {'type': 'text', 'text': 'Please transcribe this audio and translate it to English. Give both the transcription and the translation.'}
83
  ]
84
  }
85
  ]
86
+
87
+ # Build model inputs. apply_chat_template returns tensors when return_tensors='pt'.
88
  inputs = processor.apply_chat_template(
89
  messages,
90
  add_generation_prompt=True,
91
  tokenize=True,
92
  return_dict=True,
93
  return_tensors='pt'
 
 
 
 
 
 
 
94
  )
95
+
96
+ # Move any tensors to the device (do NOT force dtype changes on integer tensors)
97
+ def _move_to_device(obj):
98
+ if isinstance(obj, torch.Tensor):
99
+ return obj.to(device)
100
+ if isinstance(obj, dict):
101
+ return {k: _move_to_device(v) for k, v in obj.items()}
102
+ if isinstance(obj, (list, tuple)):
103
+ return type(obj)(_move_to_device(x) for x in obj)
104
+ return obj
105
+
106
+ inputs = _move_to_device(inputs)
107
+
108
+ # Prepare the tokenizer-based streamer (TextIteratorStreamer expects a tokenizer)
109
+ tokenizer = getattr(processor, "tokenizer", None)
110
+ if tokenizer is None:
111
+ # fallback: try attribute name used by some processors
112
+ tokenizer = getattr(processor, "tokenizer_fast", None)
113
+
114
+ if tokenizer is None:
115
+ yield "Error: tokenizer not found on processor (needed for streaming)."
116
+ return
117
+
118
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
119
+
120
+ # Prepare generation args - only include tensor keys model.generate expects (e.g., input_ids, attention_mask)
121
+ gen_inputs = {}
122
+ for k, v in inputs.items():
123
+ # typical keys: input_ids, attention_mask, etc. pass tensors only.
124
+ if isinstance(v, torch.Tensor):
125
+ gen_inputs[k] = v
126
+ gen_inputs.update({
127
+ "streamer": streamer,
128
+ "max_new_tokens": 1024,
129
+ "do_sample": False,
130
+ })
131
+
132
+ # Run generation in background thread so we can stream results
133
+ thread = Thread(target=model.generate, kwargs=gen_inputs, daemon=True)
134
  thread.start()
135
+
136
+ # Collect and yield streaming text
137
  output_text = ""
138
+ try:
139
+ for new_text in streamer:
140
+ output_text += new_text
141
+ yield output_text
142
+ except GeneratorExit:
143
+ # Gradio closed the generator early
144
+ return
145
+ finally:
146
+ # ensure thread finishes (optional)
147
+ thread.join(timeout=1)
148
 
149
+ # Grab all wav files in the directory and format examples as lists for one input component
150
  example_audios = glob.glob('test_wav_files/*.wav')
151
+ example_list = [[audio] for audio in example_audios] # gr.Examples expects each example to match inputs
152
+
153
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
154
  gr.Markdown(
155
  """
 
162
  audio_input = gr.Audio(sources=["upload", "microphone"], type="numpy", label="German Audio")
163
  text_output = gr.Textbox(label="Transcription and Translation", lines=10, interactive=False)
164
  submit_btn = gr.Button("Transcribe and Translate", variant="primary")
165
+
166
+ # NOTE: For Gradio streaming to the Textbox, Gradio supports generator-returning functions mapped to an output component.
167
  submit_btn.click(
168
  fn=transcribe_and_translate,
169
  inputs=audio_input,
170
  outputs=text_output
171
  )
172
+
173
  gr.Examples(
174
  examples=example_list,
175
  inputs=audio_input,
 
177
  fn=transcribe_and_translate,
178
  cache_examples=False
179
  )
180
+
if __name__ == "__main__":
    # NOTE(review): share=True opens a public tunnel URL to this app —
    # confirm that exposing the demo publicly is intended.
    demo.launch(share=True)