# Hugging Face Spaces app (status banner from the page: "Running").
import os

import gradio as gr
import torch
import torchaudio
from peft import PeftConfig, PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor
# Run on the GPU when PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_model():
    """Load the base Whisper model and attach the fine-tuned PEFT adapter.

    Returns:
        PeftModel: Belle-whisper-large-v3-zh with the TArtx/MinD_CH_PEFT_ID
        adapter applied, moved to the selected device and set to eval mode.
    """
    peft_model_id = "TArtx/MinD_CH_PEFT_ID"
    # NOTE(review): the original also fetched PeftConfig.from_pretrained()
    # but never used it — dropped to avoid a pointless hub round-trip.
    base_model = WhisperForConditionalGeneration.from_pretrained(
        "BELLE-2/Belle-whisper-large-v3-zh",
        device_map=None,
    ).to(device)
    peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
    # Inference-only app: disable dropout and other training-time behavior.
    peft_model.eval()
    return peft_model
def transcribe(audio_path):
    """Transcribe an uploaded audio file with the PEFT-tuned Whisper model.

    Args:
        audio_path: Filesystem path supplied by the Gradio Audio component,
            or None when the user submitted without uploading a file.

    Returns:
        str: The transcription, or a human-readable error message.
    """
    if audio_path is None:
        return "Please upload an audio file."
    try:
        # Keep the raw waveform on CPU. The original moved it to `device`
        # first, but torchaudio.transforms.Resample keeps its filter kernel
        # on CPU, so resampling a CUDA tensor raises a device-mismatch
        # error — and the tensor is converted to NumPy below anyway.
        waveform, sample_rate = torchaudio.load(audio_path)

        # Downmix stereo / multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Whisper's feature extractor expects 16 kHz input.
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        audio_array = waveform.squeeze().numpy()

        # Build log-mel features and move only those to the model's device.
        inputs = processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt",
        ).to(device)

        # No gradients are needed at inference time.
        with torch.inference_mode():
            predicted_ids = model.generate(**inputs)
        transcription = processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]
        return transcription
    except Exception as e:
        # Surface the failure in the UI instead of crashing the Space.
        return f"Error during transcription: {str(e)}"
# Module-level singletons shared by transcribe(); loaded once at startup
# so the first request does not pay the model-loading cost.
print("Loading model...")
_BASE_MODEL_ID = "BELLE-2/Belle-whisper-large-v3-zh"
processor = WhisperProcessor.from_pretrained(
    _BASE_MODEL_ID,
    language="Chinese",
    task="transcribe",
)
model = load_model()
print("Model loaded!")
# Web UI: one audio-file input, plain-text transcription output.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Chinese-Mindong Speech Recognition",
    description=(
        "Upload an audio file for transcription. "
        "Model optimized for Eastern Min dialect."
    ),
)
def main() -> None:
    """Start the Gradio server (Spaces invokes this via the script guard)."""
    iface.launch()


if __name__ == "__main__":
    main()