# Hugging Face Space: Transkun solo-piano transcription, served as a HARP/Gradio endpoint.
import os
import tempfile
from functools import lru_cache

import gradio as gr
import librosa
import moduleconf
import numpy as np
import pkg_resources
import spaces
import torch
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio

from transkun.ModelTransformer import TransKun, writeMidi
# Metadata displayed by the HARP endpoint for this model.
model_card = ModelCard(
    name="Transkun Piano Transcription",
    description=("Transcribes solo piano performance into MIDI notation"),
    author="Yujia Yan, Zhiyao Duan",
    tags=["transcription"],
)
@lru_cache(maxsize=None)
def load_model(device: str):
    """Load the bundled pretrained TransKun 2.0 model onto *device*.

    Memoized with ``lru_cache`` (already imported at the top of the file but
    previously unused) so the config parse and checkpoint read happen once
    per device string instead of on every transcription request.

    Args:
        device: torch device string (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        The TransKun model in eval mode on *device*.
    """
    default_weight = pkg_resources.resource_filename("transkun", "pretrained/2.0.pt")
    default_conf = pkg_resources.resource_filename("transkun", "pretrained/2.0.conf")
    conf_manager = moduleconf.parseFromFile(default_conf)
    conf = conf_manager["Model"].config
    checkpoint = torch.load(default_weight, map_location=device)
    model = TransKun(conf=conf).to(device)
    # Checkpoints store weights under either key depending on how they were
    # saved; strict=False tolerates minor key mismatches in either case.
    state_key = "best_state_dict" if "best_state_dict" in checkpoint else "state_dict"
    model.load_state_dict(checkpoint[state_key], strict=False)
    model.eval()
    return model
def transcribe(input_file):
    """Transcribe a solo-piano audio file into a MIDI file.

    Args:
        input_file: path to the input audio file.

    Returns:
        Path to a temporary ``.mid`` file containing the transcription.
    """
    # Fix: device was hard-coded to "cuda", which crashes on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(device)

    signal = load_audio(input_file)
    waveform = np.asarray(signal.audio_data)
    sr = int(signal.sample_rate)

    waveform = np.squeeze(waveform)
    # If still 2D after squeeze, assume (channels, samples) and downmix to mono.
    if waveform.ndim == 2:
        if waveform.shape[0] > 1:
            waveform = waveform.mean(axis=0)
        else:
            waveform = waveform.reshape(-1)

    # Resample to the model's native sample rate when the input differs.
    if sr != model.fs:
        waveform = librosa.resample(waveform.astype(np.float32), orig_sr=sr, target_sr=model.fs)
        sr = model.fs

    # Model expects a (samples, channels) tensor; mono -> single column.
    x = torch.from_numpy(waveform.reshape(-1, 1)).to(device)
    notes_est = model.transcribe(x, discardSecondHalf=False)
    output_midi = writeMidi(notes_est)

    # mkstemp yields a unique path; close our fd so write() can reopen it.
    out_fd, out_path = tempfile.mkstemp(suffix=".mid")
    os.close(out_fd)
    output_midi.write(out_path)
    return out_path
def process_fn(input_audio_path: str) -> str:
    """HARP endpoint callback: transcribe the audio and return the MIDI path."""
    return transcribe(input_audio_path)
with gr.Blocks() as demo:
    # Audio input component; HARP requires it to be explicitly marked required.
    input_audio = gr.Audio(
        label="Upload Solo Piano Audio",
        type="filepath",
    ).harp_required(True)

    # Output as a file component so HARP serializes it as a midi_track.
    output_midi = gr.File(
        label="Output MIDI File",
        file_types=[".mid", ".midi"],
        type="filepath",
    ).harp_required(True)

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_midi],
        process_fn=process_fn,
    )

demo.queue().launch(share=True, show_error=True)