Ryanus committed on
Commit
d224b23
·
verified ·
1 Parent(s): 4a42dd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -239
app.py CHANGED
@@ -1,240 +1,25 @@
1
- import os
2
- import torch
3
- import torchaudio
4
  import gradio as gr
5
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
6
- from diffusers import StableAudioPipeline
7
- import numpy as np
8
- from scipy.io import wavfile
9
-
10
- # --- IMPLEMENT MODEL LOADING AND FUNCTIONS ACCORDING TO GPA REPO ---
11
-
12
- # Global variable to store the model and processor after initial load
13
- model_instance = None
14
- processor_instance = None
15
- tts_pipeline_instance = None
16
-
17
- def load_gpa_model():
18
- global model_instance, processor_instance, tts_pipeline_instance
19
-
20
- # Use environment variables for cache directory if available (useful for Spaces)
21
- cache_dir = os.getenv('HF_HOME', './hf_cache')
22
-
23
- # --- ASR Model Loading ---
24
- print("Loading ASR Model...")
25
- asr_model_id = "AutoArk-AI/GPA-0.9B-preview-ASR"
26
- try:
27
- model_instance = AutoModelForSpeechSeq2Seq.from_pretrained(
28
- asr_model_id,
29
- torch_dtype=torch.float16, # Use float16 for efficiency if supported
30
- low_cpu_mem_usage=True,
31
- use_safetensors=True,
32
- cache_dir=cache_dir
33
- ).to("cuda" if torch.cuda.is_available() else "cpu") # Move to GPU if available
34
-
35
- processor_instance = AutoProcessor.from_pretrained(asr_model_id, cache_dir=cache_dir)
36
- print("ASR Model loaded successfully.")
37
-
38
- except Exception as e:
39
- print(f"Error loading ASR model: {e}")
40
- raise gr.Error(f"Failed to load ASR model: {e}")
41
-
42
- # --- TTS Pipeline Loading ---
43
- print("Loading TTS Pipeline...")
44
- tts_model_id = "AutoArk-AI/GPA-0.9B-preview-TTS"
45
- try:
46
- # The TTS model appears to be based on Stable Audio Open Repo
47
- tts_pipeline_instance = StableAudioPipeline.from_pretrained(
48
- tts_model_id,
49
- torch_dtype=torch.float16,
50
- cache_dir=cache_dir
51
- ).to("cuda" if torch.cuda.is_available() else "cpu")
52
- print("TTS Pipeline loaded successfully.")
53
-
54
- except Exception as e:
55
- print(f"Error loading TTS pipeline: {e}")
56
- raise gr.Error(f"Failed to load TTS pipeline: {e}")
57
-
58
- print("All models loaded successfully!")
59
- return model_instance, processor_instance, tts_pipeline_instance
60
-
61
-
62
- def run_tts(text, pipe, device):
63
- """Run TTS using the StableAudioPipeline."""
64
- if not text.strip():
65
- raise gr.Error("Text input cannot be empty.")
66
-
67
- try:
68
- # Generate audio using the pipeline
69
- # The exact parameters might need fine-tuning based on the model's expected prompt format
70
- output = pipe(
71
- prompt=text,
72
- negative_prompt="", # You might want to adjust this
73
- num_inference_steps=100, # Adjust steps as needed
74
- audio_end_size=1024 * 48000 // 32, # Example: ~10 seconds at 48kHz, adjust as needed
75
- generator=torch.Generator().manual_seed(42), # For reproducibility
76
- )
77
-
78
- # Extract audio tensor
79
- audio_tensor = output.audios[0] # Shape: [channels, time_steps]
80
-
81
- # Convert to numpy array and then to the expected format for Gradio (float32 [-1, 1])
82
- audio_np = audio_tensor.cpu().numpy()
83
- # Ensure shape is (time_steps,) for mono or (time_steps, channels) for stereo
84
- if audio_np.ndim > 1 and audio_np.shape[0] == 1:
85
- audio_np = audio_np[0] # Flatten if it's (1, time_steps)
86
- elif audio_np.ndim > 1 and audio_np.shape[0] == 2:
87
- audio_np = audio_np.T # Transpose if it's (2, time_steps) -> (time_steps, 2)
88
-
89
- # Normalize if values are outside [-1, 1] range (depends on model output scale)
90
- if np.max(np.abs(audio_np)) > 1.0:
91
- audio_np = audio_np / np.max(np.abs(audio_np))
92
-
93
- # Create a temporary file to save the audio
94
- temp_filename = "temp_tts_output.wav"
95
- # Gradio expects int16 wav files for filepath mode, but accepts float32 for numpy arrays.
96
- # Saving as int16 wav for compatibility.
97
- scaled_audio = np.int16(audio_np * 32767)
98
- wavfile.write(temp_filename, 48000, scaled_audio) # Assuming 48kHz sample rate
99
- print(f"TTS completed, saved to {temp_filename}")
100
- return temp_filename
101
-
102
- except Exception as e:
103
- print(f"TTS Error: {e}")
104
- raise gr.Error(f"TTS generation failed: {e}")
105
-
106
-
107
- def run_asr(audio_path, model, processor, device):
108
- """Run ASR using the Whisper-based model."""
109
- if not audio_path:
110
- raise gr.Error("Audio input is required for ASR.")
111
-
112
- try:
113
- # Load and preprocess audio
114
- audio_input, sr = torchaudio.load(audio_path)
115
- # Resample to 16kHz if needed (Whisper typically uses 16kHz)
116
- if sr != 16000:
117
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
118
- audio_input = resampler(audio_input)
119
-
120
- # Take the mean along the channel axis if stereo
121
- audio_array = audio_input.mean(dim=0).numpy()
122
-
123
- # Create the pipeline using the loaded model and processor
124
- pipe = pipeline(
125
- "automatic-speech-recognition",
126
- model=model,
127
- tokenizer=processor.tokenizer,
128
- feature_extractor=processor.feature_extractor,
129
- max_new_tokens=128,
130
- chunk_length_s=15,
131
- batch_size=16,
132
- torch_dtype=torch.float16,
133
- device=device,
134
- )
135
-
136
- # Perform transcription
137
- result = pipe(audio_array)
138
- print(f"ASR completed: {result['text']}")
139
- return result["text"]
140
-
141
- except Exception as e:
142
- print(f"ASR Error: {e}")
143
- raise gr.Error(f"ASR transcription failed: {e}")
144
-
145
- # Attempt to load the model when the app starts
146
- print("Starting model loading process...")
147
- try:
148
- model_instance, processor_instance, tts_pipeline_instance = load_gpa_model()
149
- device = "cuda" if torch.cuda.is_available() else "cpu"
150
- print(f"Models loaded successfully on {device}.")
151
- except Exception as e:
152
- print(f"Critical Error during startup: {e}")
153
- model_instance = None
154
- processor_instance = None
155
- tts_pipeline_instance = None
156
- device = None
157
-
158
-
159
- def tts_interface(text):
160
- if tts_pipeline_instance is None:
161
- raise gr.Error("TTS model not loaded. Cannot perform TTS.")
162
- try:
163
- output_path = run_tts(text, tts_pipeline_instance, device)
164
- return output_path
165
- except Exception as e:
166
- print(f"TTS Interface Error: {e}")
167
- raise gr.Error(f"TTS failed: {e}")
168
-
169
- def asr_interface(audio):
170
- if model_instance is None or processor_instance is None:
171
- raise gr.Error("ASR model not loaded. Cannot perform ASR.")
172
- try:
173
- transcription = run_asr(audio, model_instance, processor_instance, device)
174
- return transcription
175
- except Exception as e:
176
- print(f"ASR Interface Error: {e}")
177
- raise gr.Error(f"ASR failed: {e}")
178
-
179
- # VC is not explicitly detailed as a separate model in the latest info found, so it's omitted for now
180
- # If a specific VC model exists later, it can be added similarly.
181
-
182
-
183
- with gr.Blocks(title="GPA Model Demo") as demo:
184
- gr.Markdown(
185
- """
186
- # GPA Model Demo (0.9B Preview)
187
- Unified TTS and ASR powered by AutoArk-AI's GPA model.
188
- """
189
- )
190
-
191
- with gr.Tab("Text-to-Speech (TTS)"):
192
- with gr.Row():
193
- with gr.Column():
194
- text_input_tts = gr.Textbox(
195
- label="Input Text",
196
- placeholder="Enter text to convert to speech...",
197
- lines=5
198
- )
199
- tts_button = gr.Button("Generate Speech", variant="primary")
200
- with gr.Column():
201
- audio_output_tts = gr.Audio(
202
- label="Generated Audio",
203
- type="filepath"
204
- )
205
- tts_button.click(
206
- fn=tts_interface,
207
- inputs=text_input_tts,
208
- outputs=audio_output_tts
209
- )
210
-
211
- with gr.Tab("Automatic Speech Recognition (ASR)"):
212
- with gr.Row():
213
- with gr.Column():
214
- audio_input_asr = gr.Audio(
215
- label="Upload Audio File",
216
- type="filepath",
217
- sources=["upload"],
218
- )
219
- asr_button = gr.Button("Transcribe Speech", variant="primary")
220
- with gr.Column():
221
- text_output_asr = gr.Textbox(
222
- label="Transcribed Text",
223
- placeholder="Transcription will appear here...",
224
- interactive=False
225
- )
226
- asr_button.click(
227
- fn=asr_interface,
228
- inputs=audio_input_asr,
229
- outputs=text_output_asr
230
- )
231
-
232
- gr.Markdown(
233
- """
234
- ---
235
- *Powered by [AutoArk-AI/GPA](https://huggingface.co/AutoArk-AI/GPA). Deployed on Hugging Face Spaces.*
236
- """
237
- )
238
-
239
- if __name__ == "__main__":
240
- demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
+ # 加載模型和分詞器
5
+ model_name = "AutoArk-AI/GPA"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda") # 如果使用 GPU
8
+
9
+ def generate_text(input_text):
10
+ # 將輸入文本進行分詞並生成輸出
11
+ inputs = tokenizer(input_text, return_tensors="pt").to("cuda") # 如果使用 GPU
12
+ outputs = model.generate(**inputs, max_length=50)
13
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
14
+
15
+ # 創建 Gradio 界面
16
+ interface = gr.Interface(
17
+ fn=generate_text,
18
+ inputs=gr.Textbox(lines=5, placeholder="輸入你的文本..."),
19
+ outputs="text",
20
+ title="AutoArk-AI/GPA 模型演示",
21
+ description="輸入文本,模型將生成回覆。"
22
+ )
23
+
24
+ # 啟動界面
25
+ interface.launch()