import os
import tempfile
from functools import lru_cache

import gradio as gr
import librosa
import moduleconf
import numpy as np
import pkg_resources
import spaces
import torch
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio
from transkun.ModelTransformer import TransKun, writeMidi

model_card = ModelCard(
    name="Transkun Piano Transcription",
    description="Transcribes a solo piano performance into MIDI notation",
    author="Yujia Yan, Zhiyao Duan",
    tags=["transcription"],
)


@lru_cache(maxsize=2)
def load_model(device: str):
    # Locate the pretrained weights and config bundled with the transkun package.
    defaultWeight = pkg_resources.resource_filename("transkun", "pretrained/2.0.pt")
    defaultConf = pkg_resources.resource_filename("transkun", "pretrained/2.0.conf")

    confManager = moduleconf.parseFromFile(defaultConf)
    conf = confManager["Model"].config

    checkpoint = torch.load(defaultWeight, map_location=device)
    model = TransKun(conf=conf).to(device)

    # Checkpoints may store weights under "best_state_dict" or "state_dict".
    if "best_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["best_state_dict"], strict=False)
    else:
        model.load_state_dict(checkpoint["state_dict"], strict=False)

    model.eval()
    return model


@spaces.GPU
def transcribe(input_file):
    device = "cuda"
    model = load_model(device)

    signal = load_audio(input_file)
    waveform = np.asarray(signal.audio_data)
    sr = int(signal.sample_rate)

    waveform = np.squeeze(waveform)
    # If 2D, assume (channels, samples) and downmix to mono.
    if waveform.ndim == 2:
        if waveform.shape[0] > 1:
            waveform = waveform.mean(axis=0)
        else:
            waveform = waveform.reshape(-1)

    # Resample to the model's expected sample rate if needed.
    if sr != model.fs:
        waveform = librosa.resample(
            waveform.astype(np.float32), orig_sr=sr, target_sr=model.fs
        )
        sr = model.fs

    # The model expects a (num_samples, num_channels) tensor.
    x = torch.from_numpy(waveform.reshape(-1, 1)).to(device)
    notesEst = model.transcribe(x, discardSecondHalf=False)

    outputMidi = writeMidi(notesEst)
    out_fd, out_path = tempfile.mkstemp(suffix=".mid")
    os.close(out_fd)
    outputMidi.write(out_path)
    return out_path


def process_fn(input_audio_path: str) -> str:
    return transcribe(input_audio_path)


with gr.Blocks() as demo:
    input_audio = gr.Audio(
        label="Upload Solo Piano Audio", type="filepath"
    ).harp_required(True)

    # Ensure this is serialized as a MIDI track.
    output_midi = gr.File(
        label="Output MIDI File",
        file_types=[".mid", ".midi"],
        type="filepath",
    ).harp_required(True)

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_midi],
        process_fn=process_fn,
    )

demo.queue().launch(share=True, show_error=True)
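
# Optional: a quick local sanity check (a minimal sketch; "example.wav" is a
# hypothetical path, not shipped with this app). Call process_fn directly by
# temporarily inserting the line below *above* demo.queue().launch(...), since
# launch() blocks the process until the app exits:
#
#     print("Wrote MIDI to", process_fn("example.wav"))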