Cheeky Sparrow committed on
Commit
59b01a4
·
1 Parent(s): d2e4f4f

better app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -100
app.py CHANGED
@@ -2,7 +2,7 @@ import re
2
  import tempfile
3
  from collections import Counter
4
  from pathlib import Path
5
- from typing import Literal
6
 
7
  import gradio as gr
8
  import torch
@@ -12,95 +12,92 @@ from NatureLM.models.NatureLM import NatureLM
12
  from NatureLM.utils import generate_sample_batches, prepare_sample_waveforms
13
  import spaces
14
 
15
- CONFIG: Config = None
16
- MODEL: NatureLM = None
17
- MODEL_LOADED = False
18
- MODEL_LOADING = False
19
- MODEL_LOAD_FAILED = False
20
 
21
-
22
- def check_model_availability():
23
- """Check if the model is available for download"""
24
- try:
25
- from huggingface_hub import model_info
26
- info = model_info("EarthSpeciesProject/NatureLM-audio")
27
- return True, "Model is available"
28
- except Exception as e:
29
- return False, f"Model not available: {str(e)}"
30
-
31
-
32
- def reset_model_state():
33
- """Reset the model loading state to allow retrying after a failure"""
34
- global MODEL, MODEL_LOADED, MODEL_LOADING, MODEL_LOAD_FAILED
35
- MODEL = None
36
- MODEL_LOADED = False
37
- MODEL_LOADING = False
38
- MODEL_LOAD_FAILED = False
39
- return get_model_status()
40
-
41
-
42
- def get_model_status():
43
- """Get the current model loading status"""
44
- if MODEL_LOADED:
45
- return "βœ… Model loaded and ready"
46
- elif MODEL_LOADING:
47
- return "πŸ”„ Loading model... Please wait"
48
- elif MODEL_LOAD_FAILED:
49
- return "❌ Model failed to load. Please check the configuration."
50
- else:
51
- return "⏳ Ready to load model on first use"
52
-
53
-
54
- def load_model_if_needed():
55
- """Lazy load the model when first needed"""
56
- global MODEL, MODEL_LOADED, MODEL_LOADING, MODEL_LOAD_FAILED
57
 
58
- if MODEL_LOADED:
59
- return MODEL
 
 
 
 
60
 
61
- if MODEL_LOADING:
62
- # Model is currently loading, return a message to try again
63
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- if MODEL_LOAD_FAILED:
66
- # Model has already failed to load, don't try again
67
- return None
 
 
 
 
 
 
 
68
 
69
- if MODEL is None:
 
 
 
 
 
 
 
70
  try:
71
- MODEL_LOADING = True
72
  print("Loading model...")
73
 
74
  # Check if model is available first
75
- available, message = check_model_availability()
76
  if not available:
77
  raise Exception(f"Model not available: {message}")
78
 
79
  model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
80
- model.to("cuda") # Use CPU for HuggingFace Spaces
81
  model.eval()
82
- MODEL = model
83
- MODEL_LOADED = True
84
- MODEL_LOADING = False
 
85
  print("Model loaded successfully!")
86
- return MODEL
 
87
  except Exception as e:
88
  print(f"Error loading model: {e}")
89
- MODEL_LOADING = False
90
- MODEL_LOAD_FAILED = True
91
  return None
92
-
93
- return MODEL
 
 
 
94
 
95
  @spaces.GPU
96
- def prompt_lm(audios: list[str], messages: list[dict[str, str]]):
97
- # Always try to load the model if needed
98
- model = load_model_if_needed()
99
 
100
  if model is None:
101
- if MODEL_LOADING:
102
  return "πŸ”„ Loading model... This may take a few minutes on first use. Please try again in a moment."
103
- elif MODEL_LOAD_FAILED:
104
  return "❌ Model failed to load. This could be due to:\nβ€’ No internet connection\nβ€’ Insufficient disk space\nβ€’ Model repository access issues\n\nPlease check your connection and try again using the retry button."
105
  else:
106
  return "Demo mode: Model not loaded. Please check the model configuration."
@@ -115,12 +112,12 @@ def prompt_lm(audios: list[str], messages: list[dict[str, str]]):
115
  r"<\|start_header_id\|>system<\|end_header_id\|>\n\nCutting Knowledge Date: [^\n]+\nToday Date: [^\n]+\n\n<\|eot_id\|>",
116
  "",
117
  prompt_text,
118
- ) # exclude the system header from the prompt
119
- prompt_text = re.sub("\\n", r"\\n", prompt_text) # FIXME this is a hack to fix the issue #34
120
 
121
  print(f"{prompt_text=}")
122
  with torch.cuda.amp.autocast(dtype=torch.float16):
123
- llm_answer = model.generate(samples, CONFIG.generate, prompts=[prompt_text])
124
  return llm_answer[0]
125
 
126
 
@@ -159,8 +156,9 @@ def combine_model_inputs(msgs: list[dict[str, str]]) -> dict[str, list[str]]:
159
  files.append(path)
160
  case _:
161
  messages.append(msg)
 
 
162
  joined_messages = []
163
- # join consecutive messages from the same role
164
  for msg in messages:
165
  if joined_messages and joined_messages[-1]["role"] == msg["role"]:
166
  joined_messages[-1]["content"] += msg["content"]
@@ -175,20 +173,19 @@ def bot_response(history: list):
175
  combined_inputs = combine_model_inputs(history)
176
  response = prompt_lm(combined_inputs["files"], combined_inputs["messages"])
177
  history.append({"role": "assistant", "content": response})
178
-
179
  return history
180
 
181
 
182
  def _chat_tab(examples):
183
- # Add status indicator
184
  status_text = gr.Textbox(
185
- value=get_model_status(),
186
  label="Model Status",
187
  interactive=False,
188
  visible=True
189
  )
190
 
191
- # Add retry button that only shows when model failed to load
192
  retry_button = gr.Button(
193
  "πŸ”„ Retry Loading Model",
194
  visible=False,
@@ -201,7 +198,6 @@ def _chat_tab(examples):
201
  bubble_full_width=False,
202
  type="messages",
203
  render_markdown=False,
204
- # editable="user", # disable because of https://github.com/gradio-app/gradio/issues/10320
205
  resizeable=True,
206
  )
207
 
@@ -218,20 +214,20 @@ def _chat_tab(examples):
218
  )
219
 
220
  # Update status after bot response
221
- bot_msg.then(lambda: get_model_status(), None, [status_text])
222
  bot_msg.then(lambda: gr.ClearButton(visible=True), None, [clear_button])
223
  clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])
224
 
225
  # Handle retry button
226
  retry_button.click(
227
- reset_model_state,
228
  None,
229
  [status_text]
230
  )
231
 
232
  # Show/hide retry button based on model status
233
  def update_retry_button_visibility():
234
- return gr.Button(visible=MODEL_LOAD_FAILED)
235
 
236
  # Update retry button visibility when status changes
237
  bot_msg.then(update_retry_button_visibility, None, [retry_button])
@@ -253,11 +249,11 @@ def summarize_batch_results(results):
253
 
254
 
255
  def run_batch_inference(files, task, progress=gr.Progress()) -> str:
256
- model = load_model_if_needed()
257
  if model is None:
258
- if MODEL_LOADING:
259
  return "πŸ”„ Loading model... This may take a few minutes on first use. Please try again in a moment."
260
- elif MODEL_LOAD_FAILED:
261
  return "❌ Model failed to load. This could be due to:\nβ€’ No internet connection\nβ€’ Insufficient disk space\nβ€’ Model repository access issues\n\nPlease check your connection and try again."
262
  else:
263
  return "Demo mode: Model not loaded. Please check the model configuration."
@@ -310,10 +306,6 @@ def to_raven_format(outputs: dict[int, str], chunk_len: int = 10) -> str:
310
  last_label = ""
311
  row = 1
312
 
313
- # The "Selection" column is just the row number.
314
- # The "view" column will always say "Spectrogram 1".
315
- # Channel can always be "1".
316
- # For the frequency bounds we can just use 0 and 1/2 the sample rate
317
  for offset, label in sorted(outputs.items()):
318
  if label != last_label and last_label:
319
  raven_output.append(get_line(row, current_offset, offset, last_label))
@@ -332,11 +324,11 @@ def to_raven_format(outputs: dict[int, str], chunk_len: int = 10) -> str:
332
 
333
 
334
  def _run_long_recording_inference(file, task, chunk_len: int = 10, hop_len: int = 5, progress=gr.Progress()):
335
- model = load_model_if_needed()
336
  if model is None:
337
- if MODEL_LOADING:
338
  return "πŸ”„ Loading model... This may take a few minutes on first use. Please try again in a moment.", None
339
- elif MODEL_LOAD_FAILED:
340
  return "❌ Model failed to load. This could be due to:\nβ€’ No internet connection\nβ€’ Insufficient disk space\nβ€’ Model repository access issues\n\nPlease check your connection and try again.", None
341
  else:
342
  return "Demo mode: Model not loaded. Please check the model configuration.", None
@@ -346,12 +338,12 @@ def _run_long_recording_inference(file, task, chunk_len: int = 10, hop_len: int
346
  offset = 0
347
 
348
  prompt = f"<Audio><AudioHere></Audio> {task}"
349
- prompt = CONFIG.model.prompt_template.format(prompt)
350
 
351
  for batch in progress.tqdm(generate_sample_batches(file, cuda_enabled, chunk_len=chunk_len, hop_len=hop_len)):
352
  prompt_strs = [prompt] * len(batch["audio_chunk_sizes"])
353
  with torch.cuda.amp.autocast(dtype=torch.float16):
354
- llm_answers = model.generate(batch, CONFIG.generate, prompts=prompt_strs)
355
  for answer in llm_answers:
356
  outputs[offset] = answer
357
  offset += hop_len
@@ -400,23 +392,22 @@ def _long_recording_tab():
400
  [output, download_raven],
401
  )
402
 
403
- @spaces.GPU
404
  def main(
405
  assets_dir: Path,
406
  cfg_path: str | Path,
407
  options: list[str] = [],
408
  device: str = "cuda",
409
  ):
410
- global CONFIG
411
-
412
  try:
413
  cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options)
414
- CONFIG = cfg
415
  print("Configuration loaded successfully")
416
  except Exception as e:
417
  print(f"Warning: Could not load config: {e}")
418
  print("Running in demo mode")
419
- CONFIG = None
420
 
421
  # Check if assets directory exists, if not create a placeholder
422
  if not assets_dir.exists():
@@ -466,15 +457,15 @@ def main(
466
  _long_recording_tab()
467
 
468
  return app
469
-
470
- # At the bottom of the file:
 
471
  app = main(
472
  assets_dir=Path("assets"),
473
  cfg_path=Path("configs/inference.yml"),
474
  options=[],
475
- device="cuda", # TODO: from config depending on zerogpu! (to change)
476
  )
477
 
478
- # Launch the app
479
  if __name__ == "__main__":
480
  app.launch()
 
2
  import tempfile
3
  from collections import Counter
4
  from pathlib import Path
5
+ from typing import Literal, Optional
6
 
7
  import gradio as gr
8
  import torch
 
12
  from NatureLM.utils import generate_sample_batches, prepare_sample_waveforms
13
  import spaces
14
 
 
 
 
 
 
15
 
16
class ModelManager:
    """Manages lazy loading and lifecycle state of the NatureLM model.

    State:
        model       -- the loaded NatureLM instance, or None.
        config      -- inference Config; assigned externally (by main()).
        is_loaded   -- True once load_model() has succeeded.
        is_loading  -- True while a load attempt is in progress.
        load_failed -- True after a failed load; blocks retries until
                       reset_state() is called.
    """

    def __init__(self):
        self.model: Optional[NatureLM] = None
        self.config: Optional[Config] = None
        self.is_loaded = False
        self.is_loading = False
        self.load_failed = False

    def check_availability(self) -> tuple[bool, str]:
        """Check if the model is available for download.

        Returns:
            (ok, message) -- ok is True when the hub repo is reachable.
        """
        try:
            # Imported lazily so the app can start even if huggingface_hub
            # is missing; any failure is reported, not raised.
            from huggingface_hub import model_info

            model_info("EarthSpeciesProject/NatureLM-audio")
            return True, "Model is available"
        except Exception as e:
            return False, f"Model not available: {str(e)}"

    def reset_state(self) -> str:
        """Reset the model loading state to allow retrying after a failure.

        Returns:
            The refreshed human-readable status string.
        """
        self.model = None
        self.is_loaded = False
        self.is_loading = False
        self.load_failed = False
        return self.get_status()

    def get_status(self) -> str:
        """Get a human-readable description of the current loading state."""
        if self.is_loaded:
            return "✅ Model loaded and ready"
        elif self.is_loading:
            return "🔄 Loading model... Please wait"
        elif self.load_failed:
            return "❌ Model failed to load. Please check the configuration."
        else:
            return "⏳ Ready to load model on first use"

    def load_model(self, device: str = "cuda") -> Optional[NatureLM]:
        """Load the model if needed and return it.

        Returns None while a load is already in progress or after a
        previous failure (call reset_state() to allow a retry).

        Args:
            device: torch device to place the model on (default "cuda").
        """
        if self.is_loaded:
            return self.model

        if self.is_loading or self.load_failed:
            return None

        try:
            self.is_loading = True
            print("Loading model...")

            # Fail fast with a clear message when the hub is unreachable.
            available, message = self.check_availability()
            if not available:
                raise Exception(f"Model not available: {message}")

            model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
            model.to(device)
            model.eval()

            self.model = model
            self.is_loaded = True
            self.is_loading = False
            print("Model loaded successfully!")
            return model

        except Exception as e:
            print(f"Error loading model: {e}")
            self.is_loading = False
            self.load_failed = True
            return None


# Global model manager instance shared by all Gradio handlers.
model_manager = ModelManager()
90
+
91
 
92
  @spaces.GPU
93
+ def prompt_lm(audios: list[str], messages: list[dict[str, str]]) -> str:
94
+ """Generate response using the model"""
95
+ model = model_manager.load_model()
96
 
97
  if model is None:
98
+ if model_manager.is_loading:
99
  return "πŸ”„ Loading model... This may take a few minutes on first use. Please try again in a moment."
100
+ elif model_manager.load_failed:
101
  return "❌ Model failed to load. This could be due to:\nβ€’ No internet connection\nβ€’ Insufficient disk space\nβ€’ Model repository access issues\n\nPlease check your connection and try again using the retry button."
102
  else:
103
  return "Demo mode: Model not loaded. Please check the model configuration."
 
112
  r"<\|start_header_id\|>system<\|end_header_id\|>\n\nCutting Knowledge Date: [^\n]+\nToday Date: [^\n]+\n\n<\|eot_id\|>",
113
  "",
114
  prompt_text,
115
+ )
116
+ prompt_text = re.sub("\\n", r"\\n", prompt_text)
117
 
118
  print(f"{prompt_text=}")
119
  with torch.cuda.amp.autocast(dtype=torch.float16):
120
+ llm_answer = model.generate(samples, model_manager.config.generate, prompts=[prompt_text])
121
  return llm_answer[0]
122
 
123
 
 
156
  files.append(path)
157
  case _:
158
  messages.append(msg)
159
+
160
+ # Join consecutive messages from the same role
161
  joined_messages = []
 
162
  for msg in messages:
163
  if joined_messages and joined_messages[-1]["role"] == msg["role"]:
164
  joined_messages[-1]["content"] += msg["content"]
 
173
  combined_inputs = combine_model_inputs(history)
174
  response = prompt_lm(combined_inputs["files"], combined_inputs["messages"])
175
  history.append({"role": "assistant", "content": response})
 
176
  return history
177
 
178
 
179
  def _chat_tab(examples):
180
+ # Status indicator
181
  status_text = gr.Textbox(
182
+ value=model_manager.get_status(),
183
  label="Model Status",
184
  interactive=False,
185
  visible=True
186
  )
187
 
188
+ # Retry button that only shows when model failed to load
189
  retry_button = gr.Button(
190
  "πŸ”„ Retry Loading Model",
191
  visible=False,
 
198
  bubble_full_width=False,
199
  type="messages",
200
  render_markdown=False,
 
201
  resizeable=True,
202
  )
203
 
 
214
  )
215
 
216
  # Update status after bot response
217
+ bot_msg.then(lambda: model_manager.get_status(), None, [status_text])
218
  bot_msg.then(lambda: gr.ClearButton(visible=True), None, [clear_button])
219
  clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])
220
 
221
  # Handle retry button
222
  retry_button.click(
223
+ model_manager.reset_state,
224
  None,
225
  [status_text]
226
  )
227
 
228
  # Show/hide retry button based on model status
229
  def update_retry_button_visibility():
230
+ return gr.Button(visible=model_manager.load_failed)
231
 
232
  # Update retry button visibility when status changes
233
  bot_msg.then(update_retry_button_visibility, None, [retry_button])
 
249
 
250
 
251
  def run_batch_inference(files, task, progress=gr.Progress()) -> str:
252
+ model = model_manager.load_model()
253
  if model is None:
254
+ if model_manager.is_loading:
255
  return "πŸ”„ Loading model... This may take a few minutes on first use. Please try again in a moment."
256
+ elif model_manager.load_failed:
257
  return "❌ Model failed to load. This could be due to:\nβ€’ No internet connection\nβ€’ Insufficient disk space\nβ€’ Model repository access issues\n\nPlease check your connection and try again."
258
  else:
259
  return "Demo mode: Model not loaded. Please check the model configuration."
 
306
  last_label = ""
307
  row = 1
308
 
 
 
 
 
309
  for offset, label in sorted(outputs.items()):
310
  if label != last_label and last_label:
311
  raven_output.append(get_line(row, current_offset, offset, last_label))
 
324
 
325
 
326
  def _run_long_recording_inference(file, task, chunk_len: int = 10, hop_len: int = 5, progress=gr.Progress()):
327
+ model = model_manager.load_model()
328
  if model is None:
329
+ if model_manager.is_loading:
330
  return "πŸ”„ Loading model... This may take a few minutes on first use. Please try again in a moment.", None
331
+ elif model_manager.load_failed:
332
  return "❌ Model failed to load. This could be due to:\nβ€’ No internet connection\nβ€’ Insufficient disk space\nβ€’ Model repository access issues\n\nPlease check your connection and try again.", None
333
  else:
334
  return "Demo mode: Model not loaded. Please check the model configuration.", None
 
338
  offset = 0
339
 
340
  prompt = f"<Audio><AudioHere></Audio> {task}"
341
+ prompt = model_manager.config.model.prompt_template.format(prompt)
342
 
343
  for batch in progress.tqdm(generate_sample_batches(file, cuda_enabled, chunk_len=chunk_len, hop_len=hop_len)):
344
  prompt_strs = [prompt] * len(batch["audio_chunk_sizes"])
345
  with torch.cuda.amp.autocast(dtype=torch.float16):
346
+ llm_answers = model.generate(batch, model_manager.config.generate, prompts=prompt_strs)
347
  for answer in llm_answers:
348
  outputs[offset] = answer
349
  offset += hop_len
 
392
  [output, download_raven],
393
  )
394
 
395
+
396
  def main(
397
  assets_dir: Path,
398
  cfg_path: str | Path,
399
  options: list[str] = [],
400
  device: str = "cuda",
401
  ):
402
+ # Load configuration
 
403
  try:
404
  cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options)
405
+ model_manager.config = cfg
406
  print("Configuration loaded successfully")
407
  except Exception as e:
408
  print(f"Warning: Could not load config: {e}")
409
  print("Running in demo mode")
410
+ model_manager.config = None
411
 
412
  # Check if assets directory exists, if not create a placeholder
413
  if not assets_dir.exists():
 
457
  _long_recording_tab()
458
 
459
  return app
460
+
461
+
462
# Build the Gradio app at import time so both `python app.py` and hosted
# runners (e.g. HuggingFace Spaces importing `app`) get the same instance.
app = main(
    assets_dir=Path("assets"),
    cfg_path=Path("configs/inference.yml"),
    options=[],
    device="cuda",
)

if __name__ == "__main__":
    # Launch the UI only when executed directly as a script.
    app.launch()