import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import gradio as gr

MODEL_PATH = "nambn0321/ASR_models"

# Load the fine-tuned Wav2Vec2 CTC model and its processor once at startup.
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH).eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def transcribe(audio):
    try:
        if audio is None:
            return "No audio provided"

        # Gradio's numpy audio component returns a (sample_rate, samples) tuple.
        sr, data = audio
        print(f"Sample rate: {sr}, Audio shape: {data.shape}")

        waveform = torch.tensor(data, dtype=torch.float32)

        # Normalize 16-bit PCM integers to the [-1, 1] range the model expects;
        # leave float input untouched.
        if data.dtype == np.int16:
            waveform = waveform / 32768.0

        # Collapse stereo (samples, channels) to mono, then add a batch dimension.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=-1)
        waveform = waveform.unsqueeze(0)

        # Wav2Vec2 expects 16 kHz audio; resample if the input rate differs.
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            waveform = resampler(waveform)

        inputs = processor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs.input_values.to(device)

        with torch.no_grad():
            logits = model(input_values).logits

        # Greedy CTC decoding: take the most likely token at each frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return transcription.strip()

    except Exception as e:
        print("Error during transcription:", str(e))
        return f"Error: {str(e)}"


gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="numpy", label="Upload WAV/MP3 file"),
    outputs=gr.Textbox(label="Transcription"),
    title="ASR Demo",
    description="Upload an audio file (WAV or MP3) and get the transcription.",
).launch()