File size: 3,180 Bytes
b538a96
 
 
 
 
 
40a0ff4
 
b538a96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40a0ff4
b538a96
 
 
 
 
40a0ff4
 
b538a96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import mimetypes
import os
import shutil
import sys

# Extend the import path with ./amt/src (resolved relative to this file) before
# the project imports further down run. `os` has to be imported before this
# line: the path is built with os.path, and using it ahead of `import os`
# raises NameError as soon as the module is loaded.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))

import gradio as gr

from model_helper import load_model_checkpoint, transcribe
from prepare_media import prepare_media

# Key of the MODELS entry that is loaded at start-up.
MODEL_NAME = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
# Value handed to the '-pr' flag of every model.
PRECISION = '16'# if torch.cuda.is_available() else '32'# @param ["32", "bf16-mixed", "16"]
# Value handed to the '-p' flag of every model.
PROJECT = '2024'


def _model_spec(checkpoint, *extra_args):
    """Build one MODELS entry for *checkpoint*.

    Every entry shares the same framing: the checkpoint name, then
    ``'-p', PROJECT``, then the model-specific flags, then
    ``'-pr', PRECISION``. Assembling that framing in one place keeps the
    entries from drifting apart (the "YPTF+Multi (PS)" entry used to pass
    PROJECT to '-pr').

    Args:
        checkpoint: Checkpoint name; stored under "checkpoint" and used as
            the first element of "args".
        *extra_args: Model-specific flag/value strings, inserted between the
            '-p' and '-pr' pairs in the order given.

    Returns:
        A dict with the keys "checkpoint" (str) and "args" (list of str).
    """
    return {
        "checkpoint": checkpoint,
        "args": [checkpoint, '-p', PROJECT, *extra_args, '-pr', PRECISION],
    }


# Flags shared verbatim by the two "YPTF.MoE+Multi" variants; they differ only
# in the checkpoint they load.
_MOE_MULTI_ARGS = (
    '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
    '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
    '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
    '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1',
)

# Display name -> {"checkpoint": ..., "args": [...]}. Only the "args" list of
# the entry selected by MODEL_NAME is read elsewhere in this file.
MODELS = {
    "YMT3+": _model_spec(
        "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt",
    ),
    "YPTF+Single (noPS)": _model_spec(
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        '-enc', 'perceiver-tf', '-ac', 'spec',
        '-hop', '300', '-atc', '1',
    ),
    "YPTF+Multi (PS)": _model_spec(
        "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
        '-tk', 'mc13_full_plus_256',
        '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
        '-ac', 'spec', '-hop', '300', '-atc', '1',
    ),
    "YPTF.MoE+Multi (noPS)": _model_spec(
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        *_MOE_MULTI_ARGS,
    ),
    "YPTF.MoE+Multi (PS)": _model_spec(
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        *_MOE_MULTI_ARGS,
    ),
}

# Load the checkpoint selected by MODEL_NAME once, at import time, onto the CPU.
model = load_model_checkpoint(
    args=MODELS[MODEL_NAME]["args"],
    device="cpu",
)
# GPU placement is currently disabled:
#model.to("cuda")


def handle_audio(file_path: str) -> str:
    """Copy *file_path* to ``received_audio<ext>`` in the working directory.

    The extension of the copy is the first non-empty value of:

    1. the extension ``mimetypes`` maps to the MIME type guessed from
       *file_path*,
    2. the extension of *file_path* itself,
    3. ``".bin"``.

    The destination name is fixed apart from the extension, so a later call
    that resolves to the same extension overwrites the earlier copy.

    Args:
        file_path: Path of the file to copy.

    Returns:
        The relative path of the copy.
    """
    mime_type, _ = mimetypes.guess_type(file_path)
    # guess_type() yields None for unknown or missing extensions, and
    # guess_extension(None) raises AttributeError instead of returning None,
    # so skip the lookup and let the fallbacks below take over.
    guessed_ext = mimetypes.guess_extension(mime_type) if mime_type else None
    ext = guessed_ext or os.path.splitext(file_path)[1] or ".bin"

    output_path = f"received_audio{ext}"
    shutil.copy(file_path, output_path)
    return output_path

# Gradio UI: one audio input, delivered to handle_audio as a file path, and
# one file output that serves whatever path handle_audio returns.
demo = gr.Interface(
    handle_audio,
    gr.Audio(type="filepath"),
    gr.File(),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()