jblast94 commited on
Commit
24c936f
·
verified ·
1 Parent(s): 6fd45d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -26
app.py CHANGED
@@ -5,8 +5,40 @@ import os
5
  # You must use the exact same model name as your repo
6
  MODEL_ID = "nineninesix/Kani-TTS-370m"
7
 
 
 
 
8
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def generate_speech(text: str, model_choice: str, speaker_display: str):
 
10
  if not text.strip():
11
  return "Please enter text for speech generation.", None
12
 
@@ -14,13 +46,18 @@ def generate_speech(text: str, model_choice: str, speaker_display: str):
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(f"Using device: {device}")
16
 
17
- # --- This is the key part to load a specific model ---
 
 
 
 
18
  if model_choice not in MODELS:
19
  return f"Model '{model_choice}' not found.", None
20
 
21
  selected_model = MODELS[model_choice]
22
 
23
- # --- This part handles speakers ---
 
24
  cfg = selected_model[1] # Model config
25
  speaker_map = cfg.get('speaker_id', {}) if cfg is not None else {}
26
  if speaker_display and speaker_map:
@@ -31,7 +68,6 @@ def generate_speech(text: str, model_choice: str, speaker_display: str):
31
  print(f"Generating speech with {model_choice}...")
32
 
33
  # --- Use the specific part of the model for generation ---
34
- model_to_generate = selected_model[0]
35
  audio, _, time_report = model_to_generate.run_model(
36
  text=text,
37
  speaker_id=speaker_id,
@@ -45,25 +81,7 @@ def generate_speech(text: str, model_choice: str, speaker_display: str):
45
 
46
  return (sample_rate, audio), time_report
47
 
48
- def load_models():
49
- global MODELS
50
- if not MODELS:
51
- print("Loading models into GPU memory...")
52
- from transformers import AutoModel
53
- model_path = MODEL_ID
54
-
55
- # Load both the main model and its config
56
- model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
57
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
58
-
59
- MODELS = {
60
- "Kani TTS 370M": (model, config)
61
- }
62
-
63
- print(f"Models loaded. Available speakers: {list(config.speaker_id.keys()) if config.speaker_id else []}")
64
- return MODELS
65
-
66
- # --- Gradio interface setup ---
67
  MODELS = load_models()
68
 
69
  with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
@@ -76,7 +94,10 @@ with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
76
  )
77
 
78
  # --- Speaker selector (populated on model load) ---
79
- all_speakers = list(MODELS[list(MODELS.keys())[0]][1].speaker_id.keys()) if MODELS and MODELS[list(MODELS.keys())[0]][1] and MODELS[list(MODELS.keys())[0]][1].speaker_id else []
 
 
 
80
  speaker_dropdown = gr.Dropdown(
81
  choices=all_speakers,
82
  value=None,
@@ -91,18 +112,19 @@ with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
91
 
92
  audio_output = gr.Audio(label="Generated Audio", type="numpy")
93
 
94
- # --- Event handlers ---
95
  model_dropdown.change(
96
  fn=lambda choice: gr.update(choices=list(MODELS[choice][1].speaker_id.keys()), value=None, visible=True) if MODELS and MODELS[choice][1].speaker_id else gr.update(visible=False),
97
  inputs=[model_dropdown],
98
  outputs=[speaker_dropdown]
99
  )
100
 
 
101
  generate_btn.click(
102
  fn=generate_speech,
103
  inputs=[text_input, model_dropdown, speaker_dropdown],
104
  outputs=[audio_output]
105
  )
106
 
107
- # --- This is the API enabling line ---
108
- demo.queue().launch(show_api=True)
 
5
  # You must use the exact same model name as your repo
6
  MODEL_ID = "nineninesix/Kani-TTS-370m"
7
 
8
+ # --- Global variable to store loaded models ---
9
+ MODELS = {}
10
+
11
  @spaces.GPU
12
+ def load_models():
13
+ """Load models into GPU memory and store in a global variable."""
14
+ global MODELS
15
+ if not MODELS:
16
+ print("Loading models into GPU memory...")
17
+ from transformers import AutoModel, AutoConfig
18
+
19
+ model_path = MODEL_ID
20
+
21
+ # Load both the main model and its configuration
22
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
23
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
24
+
25
+ # Store the loaded model and its configuration in the global variable
26
+ MODELS = {
27
+ "Kani TTS 370M": (model, config)
28
+ }
29
+
30
+ print(f"Models loaded. Available speakers: {list(config.speaker_id.keys()) if config.speaker_id else []}")
31
+ return MODELS
32
+
33
+ # --- Define a separate function for updating the stats display ---
34
+ def update_stats_display():
35
+ """This function gets the agent's stats and returns a formatted string for Gradio."""
36
+ # This assumes 'agent' is a global instance of your ConversationalAgent class
37
+ stats_text = agent.get_memory_stats()
38
+ return gr.Markdown(f"### 📊 Memory Stats\n{stats_text}")
39
+
40
  def generate_speech(text: str, model_choice: str, speaker_display: str):
41
+ """Generate speech using the selected model."""
42
  if not text.strip():
43
  return "Please enter text for speech generation.", None
44
 
 
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
  print(f"Using device: {device}")
48
 
49
+ # Ensure models are loaded
50
+ if not MODELS:
51
+ load_models()
52
+
53
+ # Get the selected model from the global variable
54
  if model_choice not in MODELS:
55
  return f"Model '{model_choice}' not found.", None
56
 
57
  selected_model = MODELS[model_choice]
58
 
59
+ # --- This is the key part to load a specific model ---
60
+ model_to_generate = selected_model[0]
61
  cfg = selected_model[1] # Model config
62
  speaker_map = cfg.get('speaker_id', {}) if cfg is not None else {}
63
  if speaker_display and speaker_map:
 
68
  print(f"Generating speech with {model_choice}...")
69
 
70
  # --- Use the specific part of the model for generation ---
 
71
  audio, _, time_report = model_to_generate.run_model(
72
  text=text,
73
  speaker_id=speaker_id,
 
81
 
82
  return (sample_rate, audio), time_report
83
 
84
+ # --- Create and configure the Gradio interface ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  MODELS = load_models()
86
 
87
  with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
 
94
  )
95
 
96
  # --- Speaker selector (populated on model load) ---
97
+ all_speakers = []
98
+ if MODELS and list(MODELS.keys())[0] and MODELS[list(MODELS.keys())[0]][1]:
99
+ all_speakers.extend(list(MODELS[list(MODELS.keys())[0]][1].speaker_id.keys()))
100
+ all_speakers = sorted(list(set(all_speakers)))
101
  speaker_dropdown = gr.Dropdown(
102
  choices=all_speakers,
103
  value=None,
 
112
 
113
  audio_output = gr.Audio(label="Generated Audio", type="numpy")
114
 
115
+ # --- Define the event to update the speakers when the model changes ---
116
  model_dropdown.change(
117
  fn=lambda choice: gr.update(choices=list(MODELS[choice][1].speaker_id.keys()), value=None, visible=True) if MODELS and MODELS[choice][1].speaker_id else gr.update(visible=False),
118
  inputs=[model_dropdown],
119
  outputs=[speaker_dropdown]
120
  )
121
 
122
+ # --- Wire up the main generation button ---
123
  generate_btn.click(
124
  fn=generate_speech,
125
  inputs=[text_input, model_dropdown, speaker_dropdown],
126
  outputs=[audio_output]
127
  )
128
 
129
+ # --- This is the API-enabling line ---
130
+ demo.queue().launch(show_api=True)