# transkun/app.py
import spaces
from transkun.ModelTransformer import TransKun, writeMidi
import numpy as np
import torch
from functools import lru_cache
import librosa
import tempfile
import os
import pkg_resources
import moduleconf
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio
import gradio as gr

model_card = ModelCard(
    name="Transkun Piano Transcription",
    description="Transcribes a solo piano performance into MIDI notation",
    author="Yujia Yan, Zhiyao Duan",
    tags=["transcription"],
)

@lru_cache(maxsize=2)
def load_model(device: str):
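    # Locate the pretrained weights and config files bundled with the transkun package.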
    defaultWeight = pkg_resources.resource_filename("transkun", "pretrained/2.0.pt")
    defaultConf = pkg_resources.resource_filename("transkun", "pretrained/2.0.conf")
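
    # moduleconf parses the packaged .conf file into the model configuration object.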
    confManager = moduleconf.parseFromFile(defaultConf)
    conf = confManager["Model"].config

    checkpoint = torch.load(defaultWeight, map_location=device)
    model = TransKun(conf=conf).to(device)
    # Prefer the best-performing checkpoint if present, otherwise fall back to
    # the regular state dict.
    if "best_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["best_state_dict"], strict=False)
    else:
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    model.eval()
    return model


@spaces.GPU
def transcribe(input_file):
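    # @spaces.GPU requests a GPU for the duration of this call; load_model is
    # memoized with lru_cache, so the weights are only loaded once per device string.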
    device = "cuda"
    model = load_model(device)

    signal = load_audio(input_file)
    waveform = np.asarray(signal.audio_data)
    sr = int(signal.sample_rate)
    waveform = np.squeeze(waveform)
    # If 2D, assume (channels, samples) and downmix to mono.
    if waveform.ndim == 2:
        if waveform.shape[0] > 1:
            waveform = waveform.mean(axis=0)
        else:
            waveform = waveform.reshape(-1)
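
    # Resample to the model's native sample rate if the input differs.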
    if sr != model.fs:
        waveform = librosa.resample(waveform.astype(np.float32), orig_sr=sr, target_sr=model.fs)
        sr = model.fs
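
    # Reshape to (num_samples, 1) so the mono waveform has an explicit channel
    # dimension before it is handed to model.transcribe.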
    x = torch.from_numpy(waveform.reshape(-1, 1)).to(device)
    notesEst = model.transcribe(x, discardSecondHalf=False)
    outputMidi = writeMidi(notesEst)
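
    # Write the MIDI to a standard temporary file and return its path for Gradio to serve.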
    out_fd, out_path = tempfile.mkstemp(suffix=".mid")
    os.close(out_fd)
    outputMidi.write(out_path)
    return out_path


def process_fn(input_audio_path: str) -> str:
    midi_path = transcribe(input_audio_path)
    return midi_path

with gr.Blocks() as demo:
    input_audio = gr.Audio(label="Upload Solo Piano Audio", type="filepath").harp_required(True)
    # Ensure the output is serialized as a midi_track for HARP clients.
    output_midi = gr.File(
        label="Output MIDI File",
        file_types=[".mid", ".midi"],
        type="filepath",
    ).harp_required(True)
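
    # build_endpoint registers the model card, input/output components, and
    # process_fn as a HARP-compatible Gradio endpoint.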
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_midi],
        process_fn=process_fn,
    )

demo.queue().launch(share=True, show_error=True)
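
# A minimal local sanity check, assuming a CUDA device is available and that
# "example.wav" is a placeholder path to a solo piano recording:
#
#     midi_path = process_fn("example.wav")
#     print("MIDI written to", midi_path)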