Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| from transformers import AutoFeatureExtractor, AutoModelForAudioClassification | |
# Hugging Face Hub checkpoint: a genre classifier fine-tuned on GTZAN.
model_name = "rakib730/finetuned-gtzan"

# The feature extractor and the classifier weights come from the same checkpoint,
# so preprocessing matches what the model was trained with.
extractor = AutoFeatureExtractor.from_pretrained(model_name)

# nn.Module.eval() returns the module itself, so inference mode (no dropout,
# frozen normalization statistics) is switched on as part of loading.
model = AutoModelForAudioClassification.from_pretrained(model_name).eval()
# Audio classification function.
def classify_music(audio):
    """Predict the music genre of a clip supplied by a Gradio Audio input.

    Args:
        audio: Value produced by ``gr.Audio(type="numpy")``: a
            ``(sample_rate, data)`` tuple, sample rate FIRST. ``data`` is a
            NumPy array of shape ``(samples,)`` or ``(samples, channels)``,
            usually integer PCM. ``None`` when the user submits without
            uploading anything.

    Returns:
        The predicted genre label taken from ``model.config.id2label``, or a
        short message when there is no audio to classify.
    """
    if audio is None:
        return "Please upload an audio clip."

    # Gradio's numpy audio format is (sample_rate, data); the original code
    # unpacked these in the wrong order.
    sample_rate, data = audio

    # Resample to the rate the feature extractor was configured with; 16 kHz
    # is the value previously hard-coded and is kept as the fallback.
    target_rate = getattr(extractor, "sampling_rate", 16000)

    waveform = torch.as_tensor(data)
    if waveform.numel() == 0:
        return "The uploaded clip contains no audio samples."

    # Integer PCM (e.g. int16) must be scaled to floats in [-1, 1]: both the
    # resampler and the feature extractor operate on floating-point audio.
    if waveform.is_floating_point():
        waveform = waveform.to(torch.float32)
    else:
        full_scale = float(torch.iinfo(waveform.dtype).max)
        waveform = waveform.to(torch.float32) / full_scale

    # Multi-channel audio arrives as (samples, channels); average to mono.
    if waveform.ndim == 2:
        waveform = waveform.mean(dim=1)

    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=int(sample_rate), new_freq=int(target_rate)
        )
        waveform = resampler(waveform)

    # waveform is always a torch tensor here, so .numpy() is safe even when
    # no resampling happened.
    inputs = extractor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # Batch size is 1, so the argmax over the class dimension is a single id.
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class_id]
# Gradio UI: one audio-upload input, one text output showing the predicted genre.
audio_input = gr.Audio(type="numpy", label="Upload a Music Clip (WAV/MP3)")
genre_output = gr.Textbox(label="Predicted Genre")

demo = gr.Interface(
    fn=classify_music,
    inputs=audio_input,
    outputs=genre_output,
    title="🎵 Music Genre Classifier",
    description="Upload a short music clip and get the predicted genre using a fine-tuned GTZAN model.",
    live=False,
)

# Start the web server (blocking).
demo.launch()