hyungjoochae commited on
Commit
cf76f12
·
verified ·
1 Parent(s): 7e6df1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -24
app.py CHANGED
@@ -10,14 +10,14 @@ import numpy as np
10
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
11
  import subprocess
12
 
13
- # Install flash-attn without building CUDA part
14
  subprocess.run(
15
  "pip install flash-attn --no-build-isolation",
16
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
17
  shell=True,
18
  )
19
 
20
- # Available models
21
  MODEL_OPTIONS = [
22
  "openai/whisper-tiny",
23
  "openai/whisper-base",
@@ -26,17 +26,19 @@ MODEL_OPTIONS = [
26
  "openai/whisper-large-v3-turbo"
27
  ]
28
 
29
- # Set device and dtype
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
  torch_dtype = torch.float16
32
 
33
- # Default model
34
  current_model_name = MODEL_OPTIONS[-1]
35
 
36
- # Load pipeline for selected model
37
  def load_pipeline(model_name):
38
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
39
- model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True,
 
 
 
40
  attn_implementation="flash_attention_2"
41
  ).to(device)
42
 
@@ -53,19 +55,19 @@ def load_pipeline(model_name):
53
  device=device,
54
  )
55
 
56
- # Initialize pipeline
57
  pipe = load_pipeline(current_model_name)
58
 
59
- # Function to update model
60
- def update_model_and_return_status(model_name):
61
- global pipe, current_model_name
62
  current_model_name = model_name
63
  pipe = load_pipeline(model_name)
64
- return f"✅ Loaded model: {model_name}"
65
 
66
  @spaces.GPU
67
  def stream_transcribe(stream, new_chunk):
68
- start_time = time.time()
69
  try:
70
  sr, y = new_chunk
71
  if y.ndim > 1:
@@ -81,7 +83,7 @@ def stream_transcribe(stream, new_chunk):
81
 
82
  @spaces.GPU
83
  def transcribe(inputs, previous_transcription):
84
- start_time = time.time()
85
  try:
86
  filename = f"{uuid.uuid4().hex}.wav"
87
  sample_rate, audio_data = inputs
@@ -96,14 +98,15 @@ def transcribe(inputs, previous_transcription):
96
  def clear(): return ""
97
  def clear_state(): return None
98
 
99
- # Microphone Interface
100
  with gr.Blocks() as microphone:
101
  with gr.Column():
 
102
  model_dropdown = gr.Dropdown(label="Select Whisper Model", choices=MODEL_OPTIONS, value=current_model_name)
 
103
  model_status = gr.Textbox(label="Model Load Status", value=f"✅ Loaded model: {current_model_name}")
104
- model_dropdown.change(fn=update_model_and_return_status, inputs=model_dropdown, outputs=model_status)
105
 
106
- gr.Markdown("# 🎤 Realtime Whisper ASR (Streaming)")
107
  with gr.Row():
108
  input_audio_microphone = gr.Audio(streaming=True)
109
  output = gr.Textbox(label="Transcription", value="")
@@ -112,22 +115,23 @@ with gr.Blocks() as microphone:
112
  clear_button = gr.Button("Clear Output")
113
  state = gr.State()
114
  input_audio_microphone.stream(
115
- stream_transcribe,
116
- [state, input_audio_microphone],
117
- [state, output, latency_textbox],
118
- time_limit=30,
119
  stream_every=2
120
  )
121
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
122
 
123
- # File Upload Interface
124
  with gr.Blocks() as file:
125
  with gr.Column():
 
126
  model_dropdown_file = gr.Dropdown(label="Select Whisper Model", choices=MODEL_OPTIONS, value=current_model_name)
 
127
  model_status_file = gr.Textbox(label="Model Load Status", value=f"✅ Loaded model: {current_model_name}")
128
- model_dropdown_file.change(fn=update_model_and_return_status, inputs=model_dropdown_file, outputs=model_status_file)
129
 
130
- gr.Markdown("# 📁 Upload Audio File for Transcription")
131
  with gr.Row():
132
  input_audio_file = gr.Audio(sources="upload", type="numpy")
133
  output = gr.Textbox(label="Transcription", value="")
@@ -138,7 +142,7 @@ with gr.Blocks() as file:
138
  submit_button.click(transcribe, [input_audio_file, output], [output, latency_textbox])
139
  clear_button.click(clear, outputs=[output])
140
 
141
- # Combine into demo
142
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
143
  gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
144
 
 
10
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
11
  import subprocess
12
 
13
+ # Install flash-attn
14
  subprocess.run(
15
  "pip install flash-attn --no-build-isolation",
16
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
17
  shell=True,
18
  )
19
 
20
+ # Whisper 모델 리스트
21
  MODEL_OPTIONS = [
22
  "openai/whisper-tiny",
23
  "openai/whisper-base",
 
26
  "openai/whisper-large-v3-turbo"
27
  ]
28
 
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  torch_dtype = torch.float16
31
 
32
+ # 초기 모델 설정
33
  current_model_name = MODEL_OPTIONS[-1]
34
 
35
+ # 모델 불러오기 함수
36
  def load_pipeline(model_name):
37
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
38
+ model_name,
39
+ torch_dtype=torch_dtype,
40
+ low_cpu_mem_usage=True,
41
+ use_safetensors=True,
42
  attn_implementation="flash_attention_2"
43
  ).to(device)
44
 
 
55
  device=device,
56
  )
57
 
58
+ # 전역 상태
59
  pipe = load_pipeline(current_model_name)
60
 
61
+ # 모델 로딩 버튼 함수
62
+ def update_model_with_button(model_name):
63
+ global current_model_name, pipe
64
  current_model_name = model_name
65
  pipe = load_pipeline(model_name)
66
+ return f"✅ Model loaded: {model_name}"
67
 
68
  @spaces.GPU
69
  def stream_transcribe(stream, new_chunk):
70
+ start_time = time.time()
71
  try:
72
  sr, y = new_chunk
73
  if y.ndim > 1:
 
83
 
84
  @spaces.GPU
85
  def transcribe(inputs, previous_transcription):
86
+ start_time = time.time()
87
  try:
88
  filename = f"{uuid.uuid4().hex}.wav"
89
  sample_rate, audio_data = inputs
 
98
  def clear(): return ""
99
  def clear_state(): return None
100
 
101
+ # 마이크 입력 탭
102
  with gr.Blocks() as microphone:
103
  with gr.Column():
104
+ gr.Markdown("### 🎙️ Realtime Whisper Transcription")
105
  model_dropdown = gr.Dropdown(label="Select Whisper Model", choices=MODEL_OPTIONS, value=current_model_name)
106
+ model_load_button = gr.Button("Load Model")
107
  model_status = gr.Textbox(label="Model Load Status", value=f"✅ Loaded model: {current_model_name}")
108
+ model_load_button.click(fn=update_model_with_button, inputs=[model_dropdown], outputs=[model_status])
109
 
 
110
  with gr.Row():
111
  input_audio_microphone = gr.Audio(streaming=True)
112
  output = gr.Textbox(label="Transcription", value="")
 
115
  clear_button = gr.Button("Clear Output")
116
  state = gr.State()
117
  input_audio_microphone.stream(
118
+ stream_transcribe,
119
+ [state, input_audio_microphone],
120
+ [state, output, latency_textbox],
121
+ time_limit=30,
122
  stream_every=2
123
  )
124
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
125
 
126
+ # 파일 업로드
127
  with gr.Blocks() as file:
128
  with gr.Column():
129
+ gr.Markdown("### 📁 Upload Audio File for Transcription")
130
  model_dropdown_file = gr.Dropdown(label="Select Whisper Model", choices=MODEL_OPTIONS, value=current_model_name)
131
+ model_load_button_file = gr.Button("Load Model")
132
  model_status_file = gr.Textbox(label="Model Load Status", value=f"✅ Loaded model: {current_model_name}")
133
+ model_load_button_file.click(fn=update_model_with_button, inputs=[model_dropdown_file], outputs=[model_status_file])
134
 
 
135
  with gr.Row():
136
  input_audio_file = gr.Audio(sources="upload", type="numpy")
137
  output = gr.Textbox(label="Transcription", value="")
 
142
  submit_button.click(transcribe, [input_audio_file, output], [output, latency_textbox])
143
  clear_button.click(clear, outputs=[output])
144
 
145
+ # 통합된 데모 UI
146
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
147
  gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
148