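"""Hugging Face Space (Gradio + PyHARP): transcribes solo piano audio into a MIDI file with Transkun."""
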
import spaces
from transkun.ModelTransformer import TransKun, writeMidi
import numpy as np
import torch
from functools import lru_cache
import librosa

import tempfile
import os

import pkg_resources
import moduleconf

from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio
import gradio as gr

model_card = ModelCard(
    name="Transkun Piano Transcription",
    description="Transcribes a solo piano performance into MIDI notation",
    author="Yujia Yan, Zhiyao Duan",
    tags=["transcription"]
)

@lru_cache(maxsize=2)
def load_model(device: str):
    defaultWeight = pkg_resources.resource_filename("transkun", "pretrained/2.0.pt")
    defaultConf = pkg_resources.resource_filename("transkun", "pretrained/2.0.conf")

    confManager = moduleconf.parseFromFile(defaultConf)
    conf = confManager["Model"].config

    checkpoint = torch.load(defaultWeight, map_location=device)
    model = TransKun(conf=conf).to(device)

    # Some checkpoints store weights under "best_state_dict"; fall back to "state_dict".
    if "best_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["best_state_dict"], strict=False)
    else:
        model.load_state_dict(checkpoint["state_dict"], strict=False)

    model.eval()
    return model
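
# Hedged sketch for local debugging: load_model() also accepts "cpu", though the
# Space itself always runs transcribe() on CUDA (see below). Caching is per
# device string via lru_cache(maxsize=2).
#   model = load_model("cpu")
#   print(model.fs)  # the sample rate the model expects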

@spaces.GPU
def transcribe(input_file):
    device = "cuda"
    model = load_model(device)

    signal = load_audio(input_file)
    waveform = np.asarray(signal.audio_data)
    sr = int(signal.sample_rate)

    # Drop any singleton dimensions (e.g. a (1, N) array becomes (N,)).
    waveform = np.squeeze(waveform)

    # If 2D, assume (channels, samples). Make mono.
    if waveform.ndim == 2:
        if waveform.shape[0] > 1:
            waveform = waveform.mean(axis=0)
        else:
            waveform = waveform.reshape(-1)

    # Resample if the audio's rate differs from the model's expected rate.
    if sr != model.fs:
        waveform = librosa.resample(waveform.astype(np.float32), orig_sr=sr, target_sr=model.fs)
        sr = model.fs

    # Shape the audio as a single-channel (num_samples, 1) tensor for the model.
    x = torch.from_numpy(waveform.reshape(-1, 1)).to(device)

    notesEst = model.transcribe(x, discardSecondHalf=False)

    outputMidi = writeMidi(notesEst)

    out_fd, out_path = tempfile.mkstemp(suffix=".mid")
    os.close(out_fd)
    outputMidi.write(out_path)
    return out_path
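
# Hedged usage sketch ("piano.wav" is a hypothetical local path; requires a
# CUDA device, since transcribe() hardcodes device = "cuda"):
#   midi_path = transcribe("piano.wav")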

def process_fn(input_audio_path: str) -> str:
    midi_path = transcribe(input_audio_path)
    return midi_path

with gr.Blocks() as demo:
    input_audio = gr.Audio(label="Upload Solo Piano Audio", type="filepath").harp_required(True)
    # Ensure this output is serialized as a midi_track.
    output_midi = gr.File(
        label="Output MIDI File",
        file_types=[".mid", ".midi"],
        type="filepath"
    ).harp_required(True)

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_midi],
        process_fn=process_fn
    )

demo.queue().launch(share=True, show_error=True)