File size: 3,706 Bytes
164603c
 
eb18e14
164603c
6fd45d2
 
e9bcb5a
164603c
6fd45d2
164603c
6fd45d2
164603c
 
 
 
 
6fd45d2
 
 
 
 
 
 
 
eb18e14
 
 
 
 
6fd45d2
164603c
 
6fd45d2
 
 
 
 
 
 
 
 
 
 
164603c
 
6fd45d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164603c
6fd45d2
 
164603c
6fd45d2
 
 
 
088ca61
164603c
6fd45d2
 
 
 
 
164603c
6fd45d2
 
 
 
 
 
 
 
eb18e14
6fd45d2
 
 
 
 
 
 
 
 
 
eb18e14
 
 
6fd45d2
164603c
6fd45d2
 
 
164603c
 
6fd45d2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import torch
import os

# You must use the exact same model name as your repo
# Hugging Face Hub repo id; consumed by load_models() below.
MODEL_ID = "nineninesix/Kani-TTS-370m"

# NOTE(review): `spaces` is never imported in this file — the HF Spaces SDK
# (`import spaces`) must be added at the top of the file, or this decorator
# removed, for the module to import at all.
@spaces.GPU
def generate_speech(text: str, model_choice: str, speaker_display: str):
    """Synthesize speech for ``text`` using the selected model and speaker.

    Args:
        text: Text to synthesize; blank/whitespace-only input is rejected.
        model_choice: Key into the module-level ``MODELS`` registry.
        speaker_display: Display name of the speaker; mapped to an internal
            id via the model config's ``speaker_id`` table (may be empty).

    Returns:
        ``((sample_rate, audio), time_report)`` on success, or
        ``(None, error_message)`` on failure — matching the Gradio outputs
        ``[audio, report]`` in that order.
    """
    if not text.strip():
        # Error text belongs in the report slot; a string in the audio
        # slot would make the gr.Audio component fail to render.
        return None, "Please enter text for speech generation."

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # --- This is the key part to load a specific model ---
        if model_choice not in MODELS:
            return None, f"Model '{model_choice}' not found."

        selected_model = MODELS[model_choice]

        # --- This part handles speakers ---
        # Registry entries are (model, config) tuples; config may be None.
        cfg = selected_model[1]  # Model config
        # NOTE(review): assumes cfg supports dict-style .get() — a plain
        # transformers PretrainedConfig does not; confirm the remote-code
        # config class provides it.
        speaker_map = cfg.get('speaker_id', {}) if cfg is not None else {}
        if speaker_display and speaker_map:
            speaker_id = speaker_map.get(speaker_display)
        else:
            speaker_id = None

        print(f"Generating speech with {model_choice}...")

        # --- Use the specific part of the model for generation ---
        model_to_generate = selected_model[0]
        audio, _, time_report = model_to_generate.run_model(
            text=text,
            speaker_id=speaker_id,
            temperature=0.7,
            repetition_penalty=1.2,
            max_tokens=1024
        )

        # assumes run_model emits audio at 22.05 kHz — TODO confirm
        sample_rate = 22050
        print("Speech generation completed!")

        return (sample_rate, audio), time_report
    except Exception as e:
        # The original `try:` had no except clause at all (a SyntaxError).
        # Surface the failure in the report output instead of crashing.
        print(f"Speech generation failed: {e}")
        return None, f"Error: {e}"

def load_models():
    """Load the TTS model into the module-level ``MODELS`` registry.

    Idempotent: if the registry is already populated, the cached registry
    is returned untouched (the original returned ``None`` in that case,
    which would have clobbered the registry at the call site).

    Returns:
        dict mapping display name -> (model, config) tuple.
    """
    global MODELS
    # The original read MODELS before it was ever assigned, raising
    # NameError on the first call; fall back to an empty registry.
    try:
        cached = MODELS
    except NameError:
        cached = {}
    if not cached:
        print("Loading models into GPU memory...")
        # AutoConfig was used below but never imported — import both here.
        from transformers import AutoModel, AutoConfig
        model_path = MODEL_ID

        # Load both the main model and its config
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        MODELS = {
            "Kani TTS 370M": (model, config)
        }

        # NOTE(review): assumes the remote-code config class defines a
        # `speaker_id` attribute — confirm against the model repo.
        print(f"Models loaded. Available speakers: {list(config.speaker_id.keys()) if config.speaker_id else []}")
    return MODELS

# --- Gradio interface setup ---
# Pre-define the registry so load_models() can safely test `not MODELS`
# on the first call (it reads the global before assigning it).
MODELS = {}
MODELS = load_models()


def _speakers_for(choice):
    """Return the speaker display names for a model, or [] if it has none.

    Centralizes the "config exists and exposes speaker_id" check that was
    previously duplicated inline in three places.
    """
    cfg = MODELS[choice][1] if choice in MODELS else None
    return list(cfg.speaker_id.keys()) if cfg is not None and cfg.speaker_id else []


with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
    gr.Markdown("# 😻 KaniTTS: Fast and Expressive Speech Generation Model")

    model_names = list(MODELS.keys())
    model_dropdown = gr.Dropdown(
        choices=model_names,
        # Guard against an empty registry instead of crashing on [0].
        value=model_names[0] if model_names else None,
        label="Selected Model"
    )

    # --- Speaker selector (populated on model load) ---
    all_speakers = _speakers_for(model_names[0]) if model_names else []
    speaker_dropdown = gr.Dropdown(
        choices=all_speakers,
        value=None,
        label="Speaker",
        visible=True,
        allow_custom_value=True
    )

    text_input = gr.Textbox(label="Text", lines=5)

    generate_btn = gr.Button("Generate Speech", variant="primary")

    audio_output = gr.Audio(label="Generated Audio", type="numpy")
    # generate_speech returns TWO values (audio, time_report); the original
    # wired only one output, which makes Gradio error on every click.
    report_output = gr.Textbox(label="Generation Report", interactive=False)

    # --- Event handlers ---
    # Repopulate the speaker list when the model changes; hide the dropdown
    # for models without a speaker table.
    model_dropdown.change(
        fn=lambda choice: (
            gr.update(choices=_speakers_for(choice), value=None, visible=True)
            if _speakers_for(choice) else gr.update(visible=False)
        ),
        inputs=[model_dropdown],
        outputs=[speaker_dropdown]
    )

    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, model_dropdown, speaker_dropdown],
        outputs=[audio_output, report_output]
    )

    # --- This is the API enabling line ---
    demo.queue().launch(show_api=True)