File size: 4,649 Bytes
164603c
 
eb18e14
164603c
6fd45d2
 
e9bcb5a
24c936f
 
 
164603c
24c936f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fd45d2
24c936f
164603c
6fd45d2
164603c
 
 
 
 
24c936f
 
 
 
 
6fd45d2
 
 
 
 
24c936f
 
6fd45d2
eb18e14
 
 
 
 
6fd45d2
164603c
 
6fd45d2
 
 
 
 
 
 
 
 
 
164603c
 
6fd45d2
 
24c936f
6fd45d2
 
 
088ca61
164603c
6fd45d2
 
 
 
 
164603c
6fd45d2
24c936f
 
 
 
6fd45d2
 
 
 
 
 
eb18e14
6fd45d2
 
 
 
 
 
 
24c936f
6fd45d2
 
eb18e14
 
 
6fd45d2
24c936f
164603c
6fd45d2
 
 
164603c
 
24c936f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
import torch
import os

# You must use the exact same model name as your repo
MODEL_ID = "nineninesix/Kani-TTS-370m"

# --- Global variable to store loaded models ---
MODELS = {}

# Bug fix: `spaces` was used but never imported (NameError at import time).
# It only exists on Hugging Face Spaces (ZeroGPU), so fall back to a no-op
# decorator when running locally.
try:
    import spaces
    _gpu = spaces.GPU
except ImportError:  # local / non-Spaces environment
    def _gpu(fn):
        return fn


@_gpu
def load_models():
    """Load the TTS model and its configuration into memory (idempotent).

    Populates the module-level MODELS dict on the first call; later calls
    are cheap no-ops.

    Returns:
        dict: mapping of display name -> (model, config) tuple.
    """
    global MODELS
    if not MODELS:
        print("Loading models into GPU memory...")
        from transformers import AutoModel, AutoConfig

        model_path = MODEL_ID

        # Load both the main model and its configuration
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        # Store the loaded model and its configuration in the global variable
        MODELS = {
            "Kani TTS 370M": (model, config)
        }

        # NOTE(review): assumes the remote config defines `speaker_id`
        # (trust_remote_code model) — confirm against the model repo.
        speaker_map = getattr(config, "speaker_id", None) or {}
        print(f"Models loaded. Available speakers: {list(speaker_map.keys())}")
    # Bug fix: the original `return` was inside the `if`, so any call after
    # the first returned None — and `MODELS = load_models()` at module level
    # would then clobber the cache. Always return the populated dict.
    return MODELS

# --- Define a separate function for updating the stats display ---
def update_stats_display():
    """Render the conversational agent's memory statistics as Markdown.

    Returns:
        gr.Markdown: a component showing the current memory stats.
    """
    # NOTE(review): `agent` is not defined in this file — presumably a
    # module-level ConversationalAgent created elsewhere; verify before
    # wiring this into the UI.
    memory_stats = agent.get_memory_stats()
    formatted = f"### 📊 Memory Stats\n{memory_stats}"
    return gr.Markdown(formatted)

def generate_speech(text: str, model_choice: str, speaker_display: str):
    """Generate speech audio for `text` using the selected model and speaker.

    Args:
        text: Input text to synthesize.
        model_choice: Key into the global MODELS dict.
        speaker_display: Speaker display name; may be empty for the default.

    Returns:
        tuple: ((sample_rate, audio), time_report) on success, or
        (error_message, None) on failure.
    """
    if not text.strip():
        return "Please enter text for speech generation.", None

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Lazily load models if the module-level cache is still empty.
        if not MODELS:
            load_models()

        # Get the selected model from the global variable
        if model_choice not in MODELS:
            return f"Model '{model_choice}' not found.", None

        model_to_generate, cfg = MODELS[model_choice]

        # Bug fix: transformers config objects are not dicts — the original
        # `cfg.get('speaker_id', {})` would raise AttributeError. Use getattr,
        # matching how the rest of the file reads `config.speaker_id`.
        speaker_map = getattr(cfg, "speaker_id", None) or {}
        if speaker_display and speaker_map:
            speaker_id = speaker_map.get(speaker_display)
        else:
            speaker_id = None

        print(f"Generating speech with {model_choice}...")

        # NOTE(review): `run_model` comes from the remote model code
        # (trust_remote_code) — confirm its signature against the repo.
        audio, _, time_report = model_to_generate.run_model(
            text=text,
            speaker_id=speaker_id,
            temperature=0.7,
            repetition_penalty=1.2,
            max_tokens=1024
        )

        sample_rate = 22050  # assumed KaniTTS output rate — TODO confirm from config
        print("Speech generation completed!")

        return (sample_rate, audio), time_report
    except Exception as e:
        # Bug fix: the original `try` had no `except`/`finally` — a
        # SyntaxError that prevented the whole file from loading. Report
        # the failure in the same (message, None) shape as the other
        # error returns instead of crashing the worker.
        print(f"Speech generation failed: {e}")
        return f"Error: {e}", None

# --- Create and configure the Gradio interface ---
MODELS = load_models()

with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
    gr.Markdown("# 😻 KaniTTS: Fast and Expressive Speech Generation Model")

    # Robustness: the original `list(MODELS.keys())[0]` raises IndexError if
    # loading failed and MODELS is empty.
    default_model = next(iter(MODELS), None)

    model_dropdown = gr.Dropdown(
        choices=list(MODELS.keys()),
        value=default_model,
        label="Selected Model"
    )

    # --- Speaker selector (populated from the default model's config) ---
    all_speakers = []
    if default_model is not None:
        default_cfg = MODELS[default_model][1]
        # getattr guards against a config without `speaker_id` (trust_remote_code).
        speaker_map = getattr(default_cfg, "speaker_id", None) or {}
        all_speakers = sorted(speaker_map.keys())
    speaker_dropdown = gr.Dropdown(
        choices=all_speakers,
        value=None,
        label="Speaker",
        visible=True,
        allow_custom_value=True
    )

    text_input = gr.Textbox(label="Text", lines=5)

    generate_btn = gr.Button("Generate Speech", variant="primary")

    audio_output = gr.Audio(label="Generated Audio", type="numpy")
    # Bug fix: generate_speech returns TWO values ((sr, audio), time_report)
    # but the original click handler declared only one output, which Gradio
    # rejects at click time. Surface the report in its own component.
    report_output = gr.Textbox(label="Generation Report", interactive=False)

    def _speaker_choices(choice):
        """Return a dropdown update listing the chosen model's speakers."""
        cfg = MODELS[choice][1] if choice in MODELS else None
        speaker_map = getattr(cfg, "speaker_id", None) or {}
        if speaker_map:
            return gr.update(choices=sorted(speaker_map.keys()), value=None, visible=True)
        return gr.update(visible=False)

    # --- Update the speaker list whenever the model selection changes ---
    model_dropdown.change(
        fn=_speaker_choices,
        inputs=[model_dropdown],
        outputs=[speaker_dropdown]
    )

    # --- Wire up the main generation button ---
    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, model_dropdown, speaker_dropdown],
        outputs=[audio_output, report_output]
    )

# Launch outside the Blocks context so the layout is fully built first.
# --- This is the API-enabling line ---
demo.queue().launch(show_api=True)