.gitattributes CHANGED
@@ -50,4 +50,3 @@ assets/Humpback[[:space:]]Whale[[:space:]]-[[:space:]]Megaptera[[:space:]]novaea
50
  assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.m4a filter=lfs diff=lfs merge=lfs -text
51
  assets/Walrus[[:space:]]-[[:space:]]Odobenus[[:space:]]rosmarus.wav filter=lfs diff=lfs merge=lfs -text
52
  assets/ESP_logo_white.png filter=lfs diff=lfs merge=lfs -text
53
- assets/American[[:space:]]Crow[[:space:]]-[[:space:]]Corvus[[:space:]]brachyrhynchos.mp3 filter=lfs diff=lfs merge=lfs -text
 
50
  assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.m4a filter=lfs diff=lfs merge=lfs -text
51
  assets/Walrus[[:space:]]-[[:space:]]Odobenus[[:space:]]rosmarus.wav filter=lfs diff=lfs merge=lfs -text
52
  assets/ESP_logo_white.png filter=lfs diff=lfs merge=lfs -text
 
NatureLM/config.py CHANGED
@@ -136,7 +136,6 @@ class GenerateConfig(BaseModel, extra="forbid", validate_assignment=True):
136
  temperature: float
137
  repetition_penalty: float
138
  length_penalty: float
139
- merging_alpha: float = 1.0
140
 
141
 
142
  class ModelConfig(BaseModel, extra="forbid", validate_assignment=True):
 
136
  temperature: float
137
  repetition_penalty: float
138
  length_penalty: float
 
139
 
140
 
141
  class ModelConfig(BaseModel, extra="forbid", validate_assignment=True):
NatureLM/models/NatureLM.py CHANGED
@@ -12,10 +12,8 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- import hashlib
16
  import logging
17
  import os
18
- from collections import OrderedDict
19
  from pathlib import Path
20
  from typing import Literal, Union
21
 
@@ -37,98 +35,8 @@ from .Qformer import BertConfig, BertLMHeadModel
37
  from .utils import StoppingCriteriaSub
38
 
39
  torch.backends.cuda.matmul.allow_tf32 = True
40
- auth_token = os.getenv("llama", None)
41
-
42
-
43
- class AudioEncodingCache:
44
- """LRU cache for audio encoding with content-based hashing."""
45
-
46
- def __init__(self, capacity: int = 100):
47
- self.capacity = capacity
48
- self.cache = OrderedDict()
49
- self.hits = 0
50
- self.misses = 0
51
-
52
- def _compute_hash(
53
- self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor | None = None
54
- ) -> str:
55
- """Compute a hash key from the audio tensor and padding mask."""
56
- # Use a sample of the tensor for efficiency (first, middle, last portions)
57
- B, L = raw_wav.shape
58
- sample_size = min(1000, L) # Sample 1000 points or entire length if smaller
59
-
60
- # Sample from beginning, middle, and end
61
- indices = torch.cat(
62
- [
63
- torch.arange(min(sample_size // 3, L)),
64
- torch.arange(L // 2, min(L // 2 + sample_size // 3, L)),
65
- torch.arange(max(0, L - sample_size // 3), L),
66
- ]
67
- )
68
-
69
- sampled_wav = raw_wav[:, indices].cpu().numpy().tobytes()
70
-
71
- # Create hash from audio data, shape, and padding mask presence
72
- hash_obj = hashlib.sha256(sampled_wav)
73
- hash_obj.update(str(raw_wav.shape).encode())
74
- hash_obj.update(str(raw_wav.dtype).encode())
75
-
76
- if audio_padding_mask is not None:
77
- mask_sample = audio_padding_mask[:, indices].cpu().numpy().tobytes()
78
- hash_obj.update(mask_sample)
79
- hash_obj.update(str(audio_padding_mask.shape).encode())
80
- else:
81
- hash_obj.update(b"no_mask")
82
-
83
- return hash_obj.hexdigest()
84
-
85
- def get(self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor = None):
86
- """Retrieve cached encoding if available."""
87
- key = self._compute_hash(raw_wav, audio_padding_mask)
88
-
89
- if key in self.cache:
90
- self.hits += 1
91
- # Move to end (most recently used)
92
- self.cache.move_to_end(key)
93
- return self.cache[key]
94
-
95
- self.misses += 1
96
- return None
97
-
98
- def put(self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor, value: tuple):
99
- """Store encoding in cache (on CPU to save GPU memory)."""
100
- key = self._compute_hash(raw_wav, audio_padding_mask)
101
-
102
- # Move tensors to CPU for storage
103
- audio_embeds, audio_atts = value
104
- cached_value = (audio_embeds.cpu(), audio_atts.cpu())
105
-
106
- # Add to cache
107
- self.cache[key] = cached_value
108
- self.cache.move_to_end(key)
109
-
110
- # Evict oldest if over capacity
111
- if len(self.cache) > self.capacity:
112
- self.cache.popitem(last=False)
113
-
114
- def clear(self):
115
- """Clear the cache."""
116
- self.cache.clear()
117
- self.hits = 0
118
- self.misses = 0
119
-
120
- def get_stats(self):
121
- """Get cache statistics."""
122
- total = self.hits + self.misses
123
- hit_rate = self.hits / total if total > 0 else 0
124
- return {
125
- "hits": self.hits,
126
- "misses": self.misses,
127
- "hit_rate": hit_rate,
128
- "size": len(self.cache),
129
- "capacity": self.capacity,
130
- }
131
 
 
132
 
133
  class NatureLM(nn.Module, PyTorchModelHubMixin):
134
  def __init__(
@@ -157,16 +65,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
157
  max_txt_len: int = 128,
158
  end_sym: str = "</s>",
159
  device: str = "cuda",
160
- audio_encoding_cache_size: int = 100,
161
  ):
162
  super().__init__()
163
 
164
- self.audio_encoding_cache = (
165
- AudioEncodingCache(capacity=audio_encoding_cache_size)
166
- if audio_encoding_cache_size > 0
167
- else None
168
- )
169
-
170
  self.beats_path = beats_path
171
  self.beats_cfg = beats_cfg
172
  self.use_audio_Qformer = use_audio_Qformer
@@ -183,9 +84,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
183
 
184
  logging.info(f"Llama path: {llama_path}")
185
  logging.info("Loading Llama Tokenizer")
186
- self.llama_tokenizer = AutoTokenizer.from_pretrained(
187
- llama_path, use_fast=False, use_auth_token=auth_token
188
- )
189
  self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
190
  self.llama_tokenizer.padding_side = "right"
191
 
@@ -196,6 +95,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
196
  torch_dtype=torch.float32,
197
  attn_implementation="eager",
198
  device_map="cpu",
 
199
  )
200
  # An issue with tiny-llama is that pad_token_id was set to -1, but
201
  # model.save_pretrained checks generation configs and does not allow -1 as
@@ -206,6 +106,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
206
  llama_path,
207
  torch_dtype=torch.bfloat16,
208
  attn_implementation=flash_attn,
 
209
  )
210
 
211
  self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
@@ -234,9 +135,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
234
  self.beats = BEATs(cfg=BEATsConfig(dict(self.beats_cfg)))
235
 
236
  if self.beats_path:
237
- beats_ckpt = universal_torch_load(
238
- self.beats_path, cache_mode="none", map_location="cpu"
239
- )
240
  self.beats.load_state_dict(beats_ckpt["model"])
241
 
242
  self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
@@ -437,15 +336,11 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
437
  audio_embeds = self.ln_audio(audio_embeds)
438
 
439
  # Generate attention mask
440
- audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
441
- audio_embeds.device
442
- )
443
 
444
  if self.window_level_Qformer:
445
  B, T, C = audio_embeds.shape # batch, T, Channels
446
- kernel = round(
447
- 1500 * self.second_per_window / 30.0
448
- ) # 160 ms patches; calculate kernel size
449
  stride = round(1500 * self.second_stride / 30.0) # Calculate stride size
450
  kernel = (1, kernel)
451
  stride = (1, stride)
@@ -465,9 +360,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
465
  audio_embeds_overlap, [0, 3, 2, 1]
466
  ) # (B, num_windows, kernel_size, C)
467
  audio_embeds = audio_embeds_overlap.reshape(-1, kernel[1], C)
468
- audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
469
- audio_embeds.device
470
- )
471
 
472
  # Q-Former mechanism
473
  query_tokens = self.audio_query_tokens.expand(audio_embeds.shape[0], -1, -1)
@@ -483,19 +376,13 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
483
  if self.window_level_Qformer:
484
  audio_embeds = audio_embeds.view(B, -1, audio_embeds.size(2)).contiguous()
485
 
486
- audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
487
- audio_embeds.device
488
- )
489
 
490
  elif self.htsat:
491
  # HTSAT processing
492
  audio_embeds = self.ln_audio(audio_embeds)
493
- audio_embeds = self.audio_llama_proj(audio_embeds).reshape(
494
- -1, 30, self.llama_model.config.hidden_size
495
- )
496
- audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
497
- audio_embeds.device
498
- )
499
 
500
  else:
501
  raise NotImplementedError("no audio qformer or max pooling")
@@ -503,32 +390,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
503
  return audio_embeds, audio_atts
504
 
505
  def encode_audio(self, raw_wav, audio_padding_mask=None):
506
- # Only use cache during inference (not training)
507
- if self.audio_encoding_cache is not None and not self.training:
508
- cached_result = self.audio_encoding_cache.get(raw_wav, audio_padding_mask)
509
- if cached_result is not None:
510
- print("#### Audio encoding cache hit ####")
511
- # Move cached tensors back to the model's device
512
- audio_embeds, audio_atts = cached_result
513
- return audio_embeds.to(self.device), audio_atts.to(self.device)
514
-
515
- # Compute encoding if not cached
516
  with torch.autocast(self.device.type, dtype=torch.bfloat16):
517
  audio_embeds, audio_pad_mask = self.beats(raw_wav, padding_mask=audio_padding_mask)
518
- result = self._encode_auditory_feature(
519
- audio_embeds=audio_embeds, audio_pad_mask=audio_pad_mask
520
- )
521
-
522
- # Store in cache if enabled and in inference mode
523
- if self.audio_encoding_cache is not None and not self.training:
524
- self.audio_encoding_cache.put(raw_wav, audio_padding_mask, result)
525
-
526
- return result
527
-
528
- def clear_audio_embed_cache(self):
529
- """Clear the audio encoding cache."""
530
- if self.audio_encoding_cache is not None:
531
- self.audio_encoding_cache.clear()
532
 
533
  def prompt_wrap(self, audio_embeds, audio_atts, prompt: list[str]):
534
  """Merge audio embeddings with embeddings of the tokens in the prompt.
@@ -576,9 +440,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
576
  wrapped_atts = []
577
 
578
  for part in prompt_parts:
579
- tokens = self.llama_tokenizer(
580
- part, return_tensors="pt", add_special_tokens=False
581
- ).to(device)
582
  part_embeds = self.llama_embed_tokens(tokens.input_ids).squeeze(0)
583
  part_atts = tokens.attention_mask.squeeze(0)
584
  wrapped_embeds.append(part_embeds)
@@ -644,9 +506,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
644
 
645
  # BOS token embeddings
646
  bos_token_id = self.llama_tokenizer.bos_token_id
647
- bos = torch.full(
648
- (batch_size, 1), bos_token_id, dtype=torch.long, device=audio_embeds.device
649
- )
650
  bos_embeds = self.llama_embed_tokens(bos)
651
 
652
  # Prepare lists to collect per-sample embeddings, attention masks, and targets
@@ -661,9 +521,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
661
 
662
  # Extract non-padded text embeddings and attention mask
663
  text_embed = to_regress_embeds[i][to_regress_tokens.attention_mask[i].bool()]
664
- text_att = to_regress_tokens.attention_mask[i][
665
- to_regress_tokens.attention_mask[i].bool()
666
- ]
667
 
668
  # Extract corresponding targets for the text tokens
669
  target = targets[i][to_regress_tokens.attention_mask[i].bool()]
@@ -723,9 +581,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
723
  shift_logits.view(-1, nvocab), # Flatten to [batch_size * (seq_len-1), vocab_size]
724
  shift_labels.view(-1), # Flatten to [batch_size * (seq_len-1)]
725
  )
726
- loss_per_token = loss_per_token.view(
727
- shift_labels.size()
728
- ) # Reshape back to [batch_size, seq_len-1]
729
 
730
  # Create mask
731
  mask = shift_labels != -100 # [batch_size, seq_len-1]
@@ -741,9 +597,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
741
  predicted_tokens = shift_logits.argmax(dim=-1) # [batch_size, seq_len-1]
742
 
743
  # Compute per-example correct counts
744
- correct_per_sample = (
745
- ((predicted_tokens == shift_labels) & mask).sum(dim=1).float()
746
- ) # [batch_size]
747
  total_tokens_per_sample = mask.sum(dim=1).float() # [batch_size]
748
 
749
  # Total correct and total tokens across the batch
@@ -761,37 +615,8 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
761
 
762
  return {"loss": loss, "per_example_loss": loss_per_example}
763
 
764
- def model_merging_scaling(self, merging_alpha, adapter_name="default"):
765
- """
766
- Performs model merging with the base model by adjusting the scaling of the LoRA adapters as described in
767
- "Model Merging Improves Zero-Shot Generalization in Bioacoustic Foundation Models"
768
- (https://arxiv.org/abs/2511.05171).
769
-
770
- The best value for alpha is task- and dataset-specific, but the paper found alpha values between
771
- 0.4 and 0.6 to perform generally well.
772
-
773
- Args:
774
- merging_alpha: The merging_alpha used for interpolation.
775
- adapter_name (str): The name of the adapter to rescale when merging.
776
- """
777
-
778
- # Store original scaling on first call, then always scale relative to original
779
- if not hasattr(self, "_original_lora_scaling"):
780
- self._original_lora_scaling = {}
781
- for name, module in self.llama_model.named_modules():
782
- if hasattr(module, "r") and isinstance(module.r, dict) and adapter_name in module.r:
783
- self._original_lora_scaling[name] = module.scaling[adapter_name]
784
-
785
- for name, module in self.llama_model.named_modules():
786
- if name in self._original_lora_scaling:
787
- module.scaling[adapter_name] = merging_alpha * self._original_lora_scaling[name]
788
-
789
  @torch.inference_mode()
790
- def generate(self, samples, generate_cfg, prompts) -> list[str]:
791
- merging_alpha = getattr(generate_cfg, "merging_alpha", 1.0)
792
- if merging_alpha != 1.0:
793
- self.model_merging_scaling(merging_alpha)
794
-
795
  batch_size = len(prompts)
796
 
797
  raw_wav = samples["raw_wav"]
@@ -820,7 +645,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
820
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
821
 
822
  with torch.autocast(self.device.type, dtype=torch.bfloat16):
823
- outputs = self.llama_model.generate( # TODO: Wrap the llama_model with outlines https://outlines-dev.github.io/outlines/reference/models/transformers/
824
  inputs_embeds=embeds.bfloat16(),
825
  max_new_tokens=generate_cfg.max_new_tokens,
826
  stopping_criteria=stopping_criteria,
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import logging
16
  import os
 
17
  from pathlib import Path
18
  from typing import Literal, Union
19
 
 
35
  from .utils import StoppingCriteriaSub
36
 
37
  torch.backends.cuda.matmul.allow_tf32 = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ auth_token = os.getenv('llama')
40
 
41
  class NatureLM(nn.Module, PyTorchModelHubMixin):
42
  def __init__(
 
65
  max_txt_len: int = 128,
66
  end_sym: str = "</s>",
67
  device: str = "cuda",
 
68
  ):
69
  super().__init__()
70
 
 
 
 
 
 
 
71
  self.beats_path = beats_path
72
  self.beats_cfg = beats_cfg
73
  self.use_audio_Qformer = use_audio_Qformer
 
84
 
85
  logging.info(f"Llama path: {llama_path}")
86
  logging.info("Loading Llama Tokenizer")
87
+ self.llama_tokenizer = AutoTokenizer.from_pretrained(llama_path, use_fast=False, use_auth_token=auth_token)
 
 
88
  self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
89
  self.llama_tokenizer.padding_side = "right"
90
 
 
95
  torch_dtype=torch.float32,
96
  attn_implementation="eager",
97
  device_map="cpu",
98
+ use_auth_token=auth_token
99
  )
100
  # An issue with tiny-llama is that pad_token_id was set to -1, but
101
  # model.save_pretrained checks generation configs and does not allow -1 as
 
106
  llama_path,
107
  torch_dtype=torch.bfloat16,
108
  attn_implementation=flash_attn,
109
+ use_auth_token=auth_token
110
  )
111
 
112
  self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
 
135
  self.beats = BEATs(cfg=BEATsConfig(dict(self.beats_cfg)))
136
 
137
  if self.beats_path:
138
+ beats_ckpt = universal_torch_load(self.beats_path, cache_mode="none", map_location="cpu")
 
 
139
  self.beats.load_state_dict(beats_ckpt["model"])
140
 
141
  self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
 
336
  audio_embeds = self.ln_audio(audio_embeds)
337
 
338
  # Generate attention mask
339
+ audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
 
 
340
 
341
  if self.window_level_Qformer:
342
  B, T, C = audio_embeds.shape # batch, T, Channels
343
+ kernel = round(1500 * self.second_per_window / 30.0) # 160 ms patches; calculate kernel size
 
 
344
  stride = round(1500 * self.second_stride / 30.0) # Calculate stride size
345
  kernel = (1, kernel)
346
  stride = (1, stride)
 
360
  audio_embeds_overlap, [0, 3, 2, 1]
361
  ) # (B, num_windows, kernel_size, C)
362
  audio_embeds = audio_embeds_overlap.reshape(-1, kernel[1], C)
363
+ audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
 
 
364
 
365
  # Q-Former mechanism
366
  query_tokens = self.audio_query_tokens.expand(audio_embeds.shape[0], -1, -1)
 
376
  if self.window_level_Qformer:
377
  audio_embeds = audio_embeds.view(B, -1, audio_embeds.size(2)).contiguous()
378
 
379
+ audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
 
 
380
 
381
  elif self.htsat:
382
  # HTSAT processing
383
  audio_embeds = self.ln_audio(audio_embeds)
384
+ audio_embeds = self.audio_llama_proj(audio_embeds).reshape(-1, 30, self.llama_model.config.hidden_size)
385
+ audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
 
 
 
 
386
 
387
  else:
388
  raise NotImplementedError("no audio qformer or max pooling")
 
390
  return audio_embeds, audio_atts
391
 
392
  def encode_audio(self, raw_wav, audio_padding_mask=None):
 
 
 
 
 
 
 
 
 
 
393
  with torch.autocast(self.device.type, dtype=torch.bfloat16):
394
  audio_embeds, audio_pad_mask = self.beats(raw_wav, padding_mask=audio_padding_mask)
395
+ return self._encode_auditory_feature(audio_embeds=audio_embeds, audio_pad_mask=audio_pad_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
  def prompt_wrap(self, audio_embeds, audio_atts, prompt: list[str]):
398
  """Merge audio embeddings with embeddings of the tokens in the prompt.
 
440
  wrapped_atts = []
441
 
442
  for part in prompt_parts:
443
+ tokens = self.llama_tokenizer(part, return_tensors="pt", add_special_tokens=False).to(device)
 
 
444
  part_embeds = self.llama_embed_tokens(tokens.input_ids).squeeze(0)
445
  part_atts = tokens.attention_mask.squeeze(0)
446
  wrapped_embeds.append(part_embeds)
 
506
 
507
  # BOS token embeddings
508
  bos_token_id = self.llama_tokenizer.bos_token_id
509
+ bos = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=audio_embeds.device)
 
 
510
  bos_embeds = self.llama_embed_tokens(bos)
511
 
512
  # Prepare lists to collect per-sample embeddings, attention masks, and targets
 
521
 
522
  # Extract non-padded text embeddings and attention mask
523
  text_embed = to_regress_embeds[i][to_regress_tokens.attention_mask[i].bool()]
524
+ text_att = to_regress_tokens.attention_mask[i][to_regress_tokens.attention_mask[i].bool()]
 
 
525
 
526
  # Extract corresponding targets for the text tokens
527
  target = targets[i][to_regress_tokens.attention_mask[i].bool()]
 
581
  shift_logits.view(-1, nvocab), # Flatten to [batch_size * (seq_len-1), vocab_size]
582
  shift_labels.view(-1), # Flatten to [batch_size * (seq_len-1)]
583
  )
584
+ loss_per_token = loss_per_token.view(shift_labels.size()) # Reshape back to [batch_size, seq_len-1]
 
 
585
 
586
  # Create mask
587
  mask = shift_labels != -100 # [batch_size, seq_len-1]
 
597
  predicted_tokens = shift_logits.argmax(dim=-1) # [batch_size, seq_len-1]
598
 
599
  # Compute per-example correct counts
600
+ correct_per_sample = ((predicted_tokens == shift_labels) & mask).sum(dim=1).float() # [batch_size]
 
 
601
  total_tokens_per_sample = mask.sum(dim=1).float() # [batch_size]
602
 
603
  # Total correct and total tokens across the batch
 
615
 
616
  return {"loss": loss, "per_example_loss": loss_per_example}
617
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
  @torch.inference_mode()
619
+ def generate(self, samples, generate_cfg, prompts):
 
 
 
 
620
  batch_size = len(prompts)
621
 
622
  raw_wav = samples["raw_wav"]
 
645
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
646
 
647
  with torch.autocast(self.device.type, dtype=torch.bfloat16):
648
+ outputs = self.llama_model.generate(
649
  inputs_embeds=embeds.bfloat16(),
650
  max_new_tokens=generate_cfg.max_new_tokens,
651
  stopping_criteria=stopping_criteria,
README.md CHANGED
@@ -1,21 +1,25 @@
1
  ---
2
- title: NatureLM-audio Demo
3
- emoji: 🔈
4
- colorFrom: green
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.38.2
8
  app_file: app.py
9
- pinned: true
10
  license: apache-2.0
11
- short_description: Analyze your bioacoustic data with NatureLM-audio
12
- thumbnail: >-
13
- https://cdn-uploads.huggingface.co/production/uploads/67e0630403121d657d96b0a4/VwZf6xhy8xz-AIr8rykvB.png
14
  ---
15
 
16
- # NatureLM-audio Demo
17
 
18
- This is a demo of the NatureLM-audio model. Users can upload an audio file containing animal vocalizations and ask questions about them in a chat interface.
 
 
 
 
 
 
19
 
20
  ## Usage
21
 
@@ -27,4 +31,4 @@ This is a demo of the NatureLM-audio model. Users can upload an audio file conta
27
 
28
  The app uses lazy loading to start quickly. The model is only loaded when you first interact with it, not during app initialization. This prevents timeout issues on HuggingFace Spaces.
29
 
30
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: NatureLM Audio Demo
3
+ emoji: 🎵
4
+ colorFrom: purple
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.38.2
8
  app_file: app.py
9
+ pinned: false
10
  license: apache-2.0
11
+ short_description: Audio analysis with NatureLM model
 
 
12
  ---
13
 
14
+ # NatureLM Audio Demo
15
 
16
+ This is a demo of the NatureLM audio analysis model. The app provides three main features:
17
+
18
+ ## Features
19
+
20
+ 1. **Chat Interface**: Upload audio files and ask questions about them
21
+ 2. **Batch Processing**: Process multiple audio files with the same task
22
+ 3. **Long Recording Analysis**: Analyze long audio recordings by chunking them
23
 
24
  ## Usage
25
 
 
31
 
32
  The app uses lazy loading to start quickly. The model is only loaded when you first interact with it, not during app initialization. This prevents timeout issues on HuggingFace Spaces.
33
 
34
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import spaces
2
- import uuid
3
  import warnings
4
  import traceback
5
  import numpy as np
@@ -10,37 +9,16 @@ from collections import Counter
10
  import gradio as gr
11
  import torch
12
  import torchaudio
13
- import soundfile as sf
14
  import matplotlib.pyplot as plt
15
 
16
  from NatureLM.config import Config
17
  from NatureLM.models.NatureLM import NatureLM
18
  from NatureLM.infer import Pipeline
19
 
20
- from data_store import upload_data
21
-
22
-
23
  warnings.filterwarnings("ignore")
24
  SAMPLE_RATE = 16000 # Default sample rate for NatureLM-audio
25
- DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
26
- MIN_AUDIO_DURATION: float = 0.5 # seconds
27
- MAX_HISTORY_TURNS = (
28
- 3 # Maximum number of conversation turns to include in context (user + assistant pairs)
29
- )
30
-
31
- # Load model at startup if CUDA is available
32
- print(f"Device: {DEVICE}")
33
- model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
34
- model = model.eval().to(DEVICE)
35
- model = Pipeline(model)
36
-
37
-
38
- def check_audio_duration_greater(audio_path: str) -> bool:
39
- """Check the duration of the audio file."""
40
- info = sf.info(audio_path)
41
- duration = info.duration # info.num_frames / info.sample_rate
42
- if not duration >= MIN_AUDIO_DURATION:
43
- raise gr.Error(f"Audio duration must be at least {MIN_AUDIO_DURATION} seconds.")
44
 
45
 
46
  def get_spectrogram(audio: torch.Tensor) -> plt.Figure:
@@ -86,6 +64,85 @@ def get_spectrogram(audio: torch.Tensor) -> plt.Figure:
86
  return fig
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def take_majority_vote(results: list[list[dict]]) -> list[str]:
90
  """For each audio file, take the majority vote of the labels across all windows"""
91
  outputs = []
@@ -110,19 +167,35 @@ def prompt_lm(
110
  hop_length_seconds: float = 10.0,
111
  ) -> list[str]:
112
  """Generate response using the model
 
113
  Args:
114
  audios (list[str]): List of audio file paths
115
  queries (list[str] | str): Query or list of queries to process
116
  window_length_seconds (float): Length of the window for processing audio
117
  hop_length_seconds (float): Hop length for processing audio
 
118
  Returns:
119
  list[str]: List of generated responses for each audio-query pair
120
  """
121
- if model is None:
122
- return "❌ Model not loaded. Please check the model configuration."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
125
- results: list[list[dict]] = model(
126
  audios,
127
  queries,
128
  window_length_seconds=window_length_seconds,
@@ -164,65 +237,33 @@ def add_user_query(chatbot_history: list[dict], chat_input: str) -> list[dict]:
164
  return chatbot_history
165
 
166
  chatbot_history.append({"role": "user", "content": chat_input.strip()})
 
 
 
 
 
 
 
 
 
 
 
167
  return chatbot_history
168
 
169
 
170
- def send_data_to_hub(chatbot_history: list[dict], audio: str, session_id: str):
171
- """Upload data to hub"""
172
- if not chatbot_history or len(chatbot_history) < 2:
173
- return
174
- user_text = chatbot_history[-2]["content"]
175
- model_response = chatbot_history[-1]["content"]
176
- upload_data(audio, user_text, model_response, session_id)
177
-
178
-
179
  def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
180
- """Generate response from the model based on user input and audio file with conversation history"""
181
  try:
182
- # Warn if conversation is getting long
183
- num_turns = len(chatbot_history)
184
- if num_turns > MAX_HISTORY_TURNS * 2: # Each turn = user + assistant message
185
- gr.Warning(
186
- "⚠️ Long conversations may affect response quality. Consider starting a new conversation with the Clear button."
187
- )
188
-
189
- # Build conversation context from history
190
- conversation_context = []
191
- for message in chatbot_history:
192
- if message["role"] == "user":
193
- conversation_context.append(f"User: {message['content']}")
194
- elif message["role"] == "assistant":
195
- conversation_context.append(f"Assistant: {message['content']}")
196
-
197
- # Get the last user message
198
  last_user_message = ""
199
  for message in reversed(chatbot_history):
200
  if message["role"] == "user":
201
  last_user_message = message["content"]
202
  break
203
-
204
- # Format the full prompt with conversation history
205
- if len(conversation_context) > 2: # More than just the current query
206
- # Include previous turns (limit to last MAX_HISTORY_TURNS exchanges)
207
- # recent_context = conversation_context[
208
- # -(MAX_HISTORY_TURNS + 1) : -1
209
- # ] # Exclude current message
210
- recent_context = conversation_context
211
-
212
- full_prompt = (
213
- "Previous conversation:\n"
214
- + "\n".join(recent_context)
215
- + "\n\nCurrent question: "
216
- + last_user_message
217
- )
218
- else:
219
- full_prompt = last_user_message
220
-
221
- print("\nFull prompt with history:", full_prompt)
222
-
223
  response = prompt_lm(
224
  audios=[audio_input],
225
- queries=[full_prompt.strip()],
226
  window_length_seconds=100_000,
227
  hop_length_seconds=100_000,
228
  )
@@ -236,7 +277,7 @@ def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
236
  print(f"Error generating response: {e}")
237
  traceback.print_exc()
238
  response = "Error generating response. Please try again."
239
-
240
  # Add model response to chat history
241
  chatbot_history.append({"role": "assistant", "content": response})
242
 
@@ -245,7 +286,19 @@ def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
245
 
246
  def main(
247
  assets_dir: Path,
 
 
248
  ):
 
 
 
 
 
 
 
 
 
 
249
  # Check if assets directory exists, if not create a placeholder
250
  if not assets_dir.exists():
251
  print(f"Warning: Assets directory {assets_dir} does not exist")
@@ -255,8 +308,7 @@ def main(
255
  laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
256
  frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
257
  robin_audio = assets_dir / "yell-YELLAMRO20160506SM3.mp3"
258
- whale_audio = assets_dir / "Humpback Whale - Megaptera novaeangliae.wav"
259
- crow_audio = assets_dir / "American Crow - Corvus brachyrhynchos.mp3"
260
 
261
  examples = {
262
  "Identifying Focal Species (Lazuli Bunting)": [
@@ -271,30 +323,35 @@ def main(
271
  str(robin_audio),
272
  "Caption the audio, using the scientific name for any animal species.",
273
  ],
274
- "Identifying Focal Species (Megaptera novaeangliae)": [
275
- str(whale_audio),
276
- "What is the scientific name for the focal species in the audio?",
277
- ],
278
- "Speaker Count (American Crow)": [
279
- str(crow_audio),
280
  "How many individuals are vocalizing in this audio?",
281
  ],
282
- "Caption the audio (Humpback Whale)": [str(whale_audio), "Caption the audio."],
 
 
 
 
 
 
 
 
283
  }
284
 
285
- gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
286
-
287
  with gr.Blocks(
288
  title="NatureLM-audio",
289
- theme=gr.themes.Base(primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]),
 
 
 
290
  ) as app:
291
- with gr.Row():
292
- gr.HTML("""
293
  <div style="display: flex; align-items: center; gap: 12px;">
294
  <picture>
295
- <source srcset="/gradio_api/file=assets/ESP_logo_white.png" media="(prefers-color-scheme: dark)">
296
- <source srcset="/gradio_api/file=assets/esp_logo.png" media="(prefers-color-scheme: light)">
297
- <img src="/gradio_api/file=assets/esp_logo.png"
298
  alt="ESP Logo"
299
  style="height: 40px; width: auto;">
300
  </picture>
@@ -304,8 +361,7 @@ def main(
304
 
305
  with gr.Tabs():
306
  with gr.Tab("Analyze Audio"):
307
- session_id = gr.State(str(uuid.uuid4()))
308
- # uploaded_audio = gr.State()
309
  # Status indicator
310
  # status_text = gr.Textbox(
311
  # value=model_manager.get_status(),
@@ -325,7 +381,7 @@ def main(
325
  <div class="banner-text">Upload your first audio file below or select a pre-loaded example below.</div>
326
  </div>
327
  </div>
328
- <a href="https://huggingface.co/blog/EarthSpeciesProject/nature-lm-audio-ui-demo/" target="_blank" class="link-btn">View Tutorial</a>
329
  </div>
330
  """,
331
  padding=False,
@@ -338,14 +394,6 @@ def main(
338
  interactive=True,
339
  sources=["upload"],
340
  )
341
- # check that audio duration is greater than MIN_AUDIO_DURATION
342
- # raise
343
- audio_input.change(
344
- fn=check_audio_duration_greater,
345
- inputs=[audio_input],
346
- outputs=[],
347
- )
348
-
349
  with gr.Accordion(
350
  label="Toggle Spectrogram", open=False, visible=False
351
  ) as spectrogram:
@@ -398,7 +446,7 @@ def main(
398
  lines=1,
399
  show_label=False,
400
  submit_btn="Send",
401
- container=True,
402
  autofocus=False,
403
  elem_id="chat-input",
404
  )
@@ -420,6 +468,11 @@ def main(
420
  updated_history = add_user_query(chatbot_history, chat_input)
421
  return updated_history, ""
422
 
 
 
 
 
 
423
  clear_button = gr.ClearButton(
424
  components=[chatbot, chat_input, audio_input, plotter],
425
  visible=False,
@@ -459,11 +512,19 @@ def main(
459
  chat,
460
  plotter,
461
  ],
 
 
 
 
462
  ).then(
463
  fn=make_spectrogram_figure,
464
  inputs=[audio_input],
465
  outputs=[plotter],
466
- )
 
 
 
 
467
 
468
  # When submit clicked first:
469
  # 1. Validate and add user query to chat history
@@ -482,20 +543,18 @@ def main(
482
  lambda: gr.update(visible=True), # Show clear button
483
  None,
484
  [clear_button],
485
- ).then(
486
- send_data_to_hub,
487
- [chatbot, audio_input, session_id],
488
- None,
489
  )
490
 
491
- clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])
 
 
492
 
493
  with gr.Tab("Sample Library"):
494
  with gr.Row():
495
  with gr.Column():
496
  gr.Markdown("### Download Sample Audio")
497
  gr.Markdown(
498
- """Feel free to explore these sample audio files. To download, click the button in the top-right corner of each audio file. You can also find a large collection of publicly available animal sounds on
499
  [Xenocanto](https://xeno-canto.org/explore/taxonomy) and [Watkins Marine Mammal Sound Database](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm)."""
500
  )
501
  samples = [
@@ -508,8 +567,8 @@ def main(
508
  "Green Tree Frog",
509
  ),
510
  (
511
- "assets/American Crow - Corvus brachyrhynchos.mp3",
512
- "American Crow",
513
  ),
514
  (
515
  "assets/Gray Wolf - Canis lupus italicus.m4a",
@@ -531,46 +590,33 @@ def main(
531
  type="filepath",
532
  show_download_button=True,
533
  )
 
 
 
 
534
 
535
  with gr.Tab("💡 Help"):
536
- gr.HTML("""
537
- <div class="banner">
538
- <div style="display: flex; padding: 0px; align-items: center; flex: 1;">
539
- <div style="font-size: 20px; margin-right: 12px;"></div>
540
- <div style="flex: 1;">
541
- <div class="banner-header">Help us improve the model!</div>
542
- <div class="banner-text">Found an issue or have suggestions? Join us on Discourse to share feedback and questions.</div>
543
- </div>
544
- </div>
545
- <a href="https://earthspeciesproject.discourse.group/t/feedback-for-naturelm-audio-ui-hugging-face-spaces-demo/17" target="_blank" class="link-btn">Share Feedback</a>
546
- </div>
547
  <div class="guide-section">
548
- <h3>Getting Started</h3>
549
  <ol style="margin-top: 12px; padding-left: 20px; color: #6b7280; font-size: 14px; line-height: 1.6;">
550
- <li style="margin-bottom: 8px;"><strong>Upload your audio</strong> or click on a pre-loaded example. Drag and drop your audio file containing animal vocalizations, or click on an example.</li>
551
- <li style="margin-bottom: 8px;"><strong>Trim your audio (if needed)</strong> by clicking the scissors icon on the bottom right of the audio panel. Try to keep your audio to 10 seconds or less.</li>
552
- <li style="margin-bottom: 8px;"><strong>View the Spectrogram (optional)</strong>. You can easily view/hide the spectrogram of your audio for closer analysis.</li>
553
- <li style="margin-bottom: 8px;"><strong>Select a task or write your own</strong>. Select an option from pre-loaded tasks. This will auto-fill the text box with a prompt, so all you have to do is hit Send. Or, type a custom prompt directly into the chat.</li>
554
- <li style="margin-bottom: 0;"><strong>Send and Analyze Audio</strong>. Press "Send" or type Enter to begin processing your audio. Ask follow-up questions or press "Clear" to start a new conversation.</li>
555
  </ol>
556
  <p></p>
557
  </div>
 
558
  <div class="guide-section">
559
- <h3>Tips</h3>
560
  <b>Prompting Best Practices</b>
561
- <ul style="margin-top: 12px; padding-left: 20px; color: #6b7280; font-size: 14px; line-height: 1.6;">
562
- <li>When possible, use scientific or taxonomic names and mention the context if known (geographic area/location, time of day or year, habitat type)</li>
563
- <li>Ask one question at a time, and be specific about what you want to know</li>
564
- <ul> Don't ask: <i>"Analyze this audio and tell me all you know about it."</i></ul>
565
- <ul>✅ Do ask: <i>"What species made this sound?"</i></ul>
566
- <li>Keep prompts more open-ended and avoid asking Yes/No or very targeted questions</li>
567
- <ul>❌ Don't ask: <i>"Is there a bottlenose dolphin vocalizing in the audio? Yes or No."</i></ul>
568
- <ul>✅ Do ask: <i>"What focal species, if any, are heard in the audio?"</i></em></ul>
569
- <li>Giving the model options to choose works well for broader categories (less so for specific species)</li>
570
- <ul>❌ Don't ask: <i>"Classify the audio into one of the following species: Bottlenose Dolphin, Orca, Great Gray Owl"‍</i></ul>
571
- <ul>✅ Do ask: <i>"Classify the audio into one of the following categories: Cetaceans, Aves, or None."</i></ul>
572
  </ul>
573
- <br>
574
  <b>Audio Files</b>
575
  <ul style="margin-top: 12px; padding-left: 20px; color: #6b7280; font-size: 14px; line-height: 1.6;">
576
  <li>Supported formats: .wav, .mp3, .aac, .flac, .ogg, .webm, .midi, .aiff, .wma, .opus, .amr</li>
@@ -582,20 +628,32 @@ def main(
582
  <div class="guide-section">
583
  <h3>Learn More</h3>
584
  <ul style="margin-top: 12px; padding-left: 20px; color: #6b7280; font-size: 14px; line-height: 1.6;">
585
- <li>Read our <a href="https://huggingface.co/blog/EarthSpeciesProject/nature-lm-audio-ui-demo/" target="_blank">recent blog post</a> with a step-by-step tutorial</li>
586
  <li>Check out the <a href="https://arxiv.org/abs/2411.07186" target="_blank">published paper</a> for a deeper technical dive on NatureLM-audio.</li>
587
  <li>Visit the <a href="https://earthspecies.github.io/naturelm-audio-demo/" target="_blank">NatureLM-audio Demo Page</a> for additional context, a demo video, and more examples of the model in action.</li>
588
  <li>Sign up for our <a href="https://forms.gle/WjrbmFhKkzmEgwvY7" target="_blank">closed beta waitlist</a>, if you’re interested in testing upcoming features like longer audio files and batch processing.</li>
589
  </ul>
 
 
 
 
590
  </div>
591
  </div>
592
  """)
593
 
594
  app.css = """
 
 
 
 
 
 
 
595
  #chat-input textarea {
596
  background: white;
597
  flex: 1;
598
  }
 
599
  #chat-input .submit-button {
600
  padding: 10px;
601
  margin: 2px 6px;
@@ -624,6 +682,7 @@ def main(
624
  color: #374151;
625
  margin-bottom: 4px;
626
  }
 
627
  .banner .banner-text {
628
  style="font-size: 14px;
629
  color: #6b7280;
@@ -642,10 +701,30 @@ def main(
642
  display: inline-block;
643
  transition: background 0.2s ease;
644
  }
 
645
  .link-btn:hover {
646
  background: #2563eb;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647
  }
648
-
649
  .guide-section {
650
  margin-bottom: 32px;
651
  border-radius: 8px;
@@ -667,10 +746,12 @@ def main(
667
  #chat-input {
668
  background: #1e1e1e;
669
  }
 
670
  #chat-input textarea {
671
  background: #1e1e1e;
672
  color: white;
673
  }
 
674
  .banner {
675
  background: #1e1e1e;
676
  color: white;
@@ -687,6 +768,8 @@ def main(
687
  # Create and launch the app
688
  app = main(
689
  assets_dir=Path("assets"),
 
 
690
  )
691
 
692
  if __name__ == "__main__":
 
1
  import spaces
 
2
  import warnings
3
  import traceback
4
  import numpy as np
 
9
  import gradio as gr
10
  import torch
11
  import torchaudio
 
12
  import matplotlib.pyplot as plt
13
 
14
  from NatureLM.config import Config
15
  from NatureLM.models.NatureLM import NatureLM
16
  from NatureLM.infer import Pipeline
17
 
 
 
 
18
  warnings.filterwarnings("ignore")
19
  SAMPLE_RATE = 16000 # Default sample rate for NatureLM-audio
20
+ CURRENT_AUDIO = "" # Placeholder for current audio file
21
+ FIRST_QUERY: bool = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  def get_spectrogram(audio: torch.Tensor) -> plt.Figure:
 
64
  return fig
65
 
66
 
67
+ class ModelManager:
68
+ """Manages model loading and state"""
69
+
70
+ def __init__(self):
71
+ self.model: Optional[NatureLM] = None
72
+ self.config: Optional[Config] = None
73
+ self.is_loaded = False
74
+ self.is_loading = False
75
+ self.load_failed = False
76
+
77
+ def check_availability(self) -> tuple[bool, str]:
78
+ """Check if the model is available for download"""
79
+ try:
80
+ from huggingface_hub import model_info
81
+
82
+ model_info("EarthSpeciesProject/NatureLM-audio")
83
+ return True, "Model is available"
84
+ except Exception as e:
85
+ return False, f"Model not available: {str(e)}"
86
+
87
+ def reset_state(self):
88
+ """Reset the model loading state to allow retrying after a failure"""
89
+ self.model = None
90
+ self.is_loaded = False
91
+ self.is_loading = False
92
+ self.load_failed = False
93
+ return self.get_status()
94
+
95
+ def get_status(self) -> str:
96
+ """Get the current model loading status"""
97
+ if self.is_loaded:
98
+ return "✅ Model loaded and ready"
99
+ elif self.is_loading:
100
+ return "🔄 Loading model... Please wait"
101
+ elif self.load_failed:
102
+ return "❌ Model failed to load. Please check the configuration."
103
+ else:
104
+ return "⏳ Ready to load model on first use"
105
+
106
+ def load_model(self) -> Optional[NatureLM]:
107
+ """Load the model if needed"""
108
+ if self.is_loaded:
109
+ return self.model
110
+
111
+ if self.is_loading or self.load_failed:
112
+ return None
113
+
114
+ try:
115
+ self.is_loading = True
116
+ print("Loading model...")
117
+
118
+ # Check if model is available first
119
+ available, message = self.check_availability()
120
+ if not available:
121
+ raise Exception(f"Model not available: {message}")
122
+
123
+ model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
124
+ model = model.eval().to("cuda")
125
+
126
+ self.model = Pipeline(model)
127
+ self.is_loaded = True
128
+ self.is_loading = False
129
+ print("Model loaded successfully!")
130
+
131
+ except Exception as e:
132
+ print(f"Error loading model: {e}")
133
+ self.is_loading = False
134
+ self.load_failed = True
135
+ return None
136
+
137
+
138
+ # Global model manager instance
139
+ model_manager = ModelManager()
140
+
141
+ # @spaces.GPU
142
+ # def load_model():
143
+ # model_manager.load_model()
144
+
145
+
146
  def take_majority_vote(results: list[list[dict]]) -> list[str]:
147
  """For each audio file, take the majority vote of the labels across all windows"""
148
  outputs = []
 
167
  hop_length_seconds: float = 10.0,
168
  ) -> list[str]:
169
  """Generate response using the model
170
+
171
  Args:
172
  audios (list[str]): List of audio file paths
173
  queries (list[str] | str): Query or list of queries to process
174
  window_length_seconds (float): Length of the window for processing audio
175
  hop_length_seconds (float): Hop length for processing audio
176
+
177
  Returns:
178
  list[str]: List of generated responses for each audio-query pair
179
  """
180
+ if model_manager.model is None:
181
+ model_manager.load_model()
182
+
183
+ if model_manager.model is None:
184
+ if model_manager.is_loading:
185
+ return "🔄 Loading model for the first query. This takes 20-30 seconds..👷🏽‍♂️����🪚"
186
+ # while True:
187
+ # if model_manager.is_loaded:
188
+ # model = model_manager.model
189
+ # break
190
+ # elif model_manager.load_failed:
191
+ # return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again using the retry button."
192
+ elif model_manager.load_failed:
193
+ return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again using the retry button."
194
+ else:
195
+ return "Demo mode: Model not loaded. Please check the model configuration."
196
 
197
  with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
198
+ results: list[list[dict]] = model_manager.model(
199
  audios,
200
  queries,
201
  window_length_seconds=window_length_seconds,
 
237
  return chatbot_history
238
 
239
  chatbot_history.append({"role": "user", "content": chat_input.strip()})
240
+ global FIRST_QUERY
241
+ if FIRST_QUERY:
242
+ # Add an assistant message indicating model is loading
243
+ chatbot_history.append(
244
+ {
245
+ "role": "assistant",
246
+ "content": "🔄 Loading model for the first query. This takes 30-40 seconds..👷🏽‍♂️🔨🪚",
247
+ }
248
+ )
249
+ FIRST_QUERY = False
250
+
251
  return chatbot_history
252
 
253
 
 
 
 
 
 
 
 
 
 
254
  def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
255
+ """Generate response from the model based on user input and audio file"""
256
  try:
257
+ # Get the last user message from chat history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  last_user_message = ""
259
  for message in reversed(chatbot_history):
260
  if message["role"] == "user":
261
  last_user_message = message["content"]
262
  break
263
+ print("\nUser message:", last_user_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  response = prompt_lm(
265
  audios=[audio_input],
266
+ queries=[last_user_message.strip()],
267
  window_length_seconds=100_000,
268
  hop_length_seconds=100_000,
269
  )
 
277
  print(f"Error generating response: {e}")
278
  traceback.print_exc()
279
  response = "Error generating response. Please try again."
280
+
281
  # Add model response to chat history
282
  chatbot_history.append({"role": "assistant", "content": response})
283
 
 
286
 
287
  def main(
288
  assets_dir: Path,
289
+ cfg_path: str | Path,
290
+ options: list[str] = [],
291
  ):
292
+ # Load configuration
293
+ try:
294
+ cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options)
295
+ model_manager.config = cfg
296
+ print("Configuration loaded successfully")
297
+ except Exception as e:
298
+ print(f"Warning: Could not load config: {e}")
299
+ print("Running in demo mode")
300
+ model_manager.config = None
301
+
302
  # Check if assets directory exists, if not create a placeholder
303
  if not assets_dir.exists():
304
  print(f"Warning: Assets directory {assets_dir} does not exist")
 
308
  laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
309
  frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
310
  robin_audio = assets_dir / "yell-YELLAMRO20160506SM3.mp3"
311
+ vireo_audio = assets_dir / "yell-YELLWarblingVireoMammoth20150614T29ms.mp3"
 
312
 
313
  examples = {
314
  "Identifying Focal Species (Lazuli Bunting)": [
 
323
  str(robin_audio),
324
  "Caption the audio, using the scientific name for any animal species.",
325
  ],
326
+ "Caption the audio (Warbling Vireo)": [str(vireo_audio), "Caption the audio."],
327
+ "Speaker Count (Lazuli Bunting)": [
328
+ str(laz_audio),
 
 
 
329
  "How many individuals are vocalizing in this audio?",
330
  ],
331
+ "Caption the audio (Green Tree Frog)": [
332
+ str(frog_audio),
333
+ "Caption the audio, using the common name for any animal species.",
334
+ ],
335
+ "Caption the audio (American Robin)": [
336
+ str(robin_audio),
337
+ "Caption the audio, using the scientific name for any animal species.",
338
+ ],
339
+ "Caption the audio (Warbling Vireo)": [str(vireo_audio), "Caption the audio."],
340
  }
341
 
 
 
342
  with gr.Blocks(
343
  title="NatureLM-audio",
344
+ theme=gr.themes.Base(
345
+ primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]
346
+ ),
347
+ css="styles.css",
348
  ) as app:
349
+ header = gr.HTML("""
 
350
  <div style="display: flex; align-items: center; gap: 12px;">
351
  <picture>
352
+ <source srcset="https://huggingface.co/spaces/EarthSpeciesProject/NatureLM-Audio/resolve/main/assets/ESP_logo_white.png" media="(prefers-color-scheme: dark)">
353
+ <source srcset="https://huggingface.co/spaces/EarthSpeciesProject/NatureLM-Audio/resolve/main/assets/esp_logo.png" media="(prefers-color-scheme: light)">
354
+ <img src="https://huggingface.co/spaces/EarthSpeciesProject/NatureLM-Audio/resolve/main/assets/esp_logo.png"
355
  alt="ESP Logo"
356
  style="height: 40px; width: auto;">
357
  </picture>
 
361
 
362
  with gr.Tabs():
363
  with gr.Tab("Analyze Audio"):
364
+ uploaded_audio = gr.State()
 
365
  # Status indicator
366
  # status_text = gr.Textbox(
367
  # value=model_manager.get_status(),
 
381
  <div class="banner-text">Upload your first audio file below or select a pre-loaded example below.</div>
382
  </div>
383
  </div>
384
+ <a href="https://www.earthspecies.org/blog" target="_blank" class="link-btn">View Tutorial</a>
385
  </div>
386
  """,
387
  padding=False,
 
394
  interactive=True,
395
  sources=["upload"],
396
  )
 
 
 
 
 
 
 
 
397
  with gr.Accordion(
398
  label="Toggle Spectrogram", open=False, visible=False
399
  ) as spectrogram:
 
446
  lines=1,
447
  show_label=False,
448
  submit_btn="Send",
449
+ container=False,
450
  autofocus=False,
451
  elem_id="chat-input",
452
  )
 
468
  updated_history = add_user_query(chatbot_history, chat_input)
469
  return updated_history, ""
470
 
471
+ def update_current_audio(audio_input):
472
+ global CURRENT_AUDIO
473
+ if audio_input != CURRENT_AUDIO:
474
+ CURRENT_AUDIO = audio_input
475
+
476
  clear_button = gr.ClearButton(
477
  components=[chatbot, chat_input, audio_input, plotter],
478
  visible=False,
 
512
  chat,
513
  plotter,
514
  ],
515
+ ).then(
516
+ fn=update_current_audio,
517
+ inputs=[audio_input],
518
+ outputs=[],
519
  ).then(
520
  fn=make_spectrogram_figure,
521
  inputs=[audio_input],
522
  outputs=[plotter],
523
+ )# .then(
524
+ # fn=load_model,
525
+ # inputs=[],
526
+ # outputs=[],
527
+ # )
528
 
529
  # When submit clicked first:
530
  # 1. Validate and add user query to chat history
 
543
  lambda: gr.update(visible=True), # Show clear button
544
  None,
545
  [clear_button],
 
 
 
 
546
  )
547
 
548
+ clear_button.click(
549
+ lambda: gr.ClearButton(visible=False), None, [clear_button]
550
+ )
551
 
552
  with gr.Tab("Sample Library"):
553
  with gr.Row():
554
  with gr.Column():
555
  gr.Markdown("### Download Sample Audio")
556
  gr.Markdown(
557
+ """Feel free to explore these sample audio files. To download, click the button in the top-right corner of each audio file, or **Download All**. You can also find a large collection of publicly available animal sounds on
558
  [Xenocanto](https://xeno-canto.org/explore/taxonomy) and [Watkins Marine Mammal Sound Database](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm)."""
559
  )
560
  samples = [
 
567
  "Green Tree Frog",
568
  ),
569
  (
570
+ "assets/Eastern Gray Squirrel - Sciurus carolinensis.wav",
571
+ "Eastern Gray Squirrel",
572
  ),
573
  (
574
  "assets/Gray Wolf - Canis lupus italicus.m4a",
 
590
  type="filepath",
591
  show_download_button=True,
592
  )
593
+ with gr.Row():
594
+ gr.HTML("""<center>
595
+ <a href="https://huggingface.co/spaces/EarthSpeciesProject/NatureLM-Audio/resolve/main/assets/Sample_Audio_Files_NatureLM_audio.zip" download class="download-btn">Download All</a></center>
596
+ """)
597
 
598
  with gr.Tab("💡 Help"):
599
+ gr.HTML("""
 
 
 
 
 
 
 
 
 
 
600
  <div class="guide-section">
601
+ <h3>Getting Started</h3>
602
  <ol style="margin-top: 12px; padding-left: 20px; color: #6b7280; font-size: 14px; line-height: 1.6;">
603
+ <li style="margin-bottom: 8px;"><strong>Upload your audio</strong> - Click the upload area or drag and drop your audio file containing animal vocalizations.</li>
604
+ <li style="margin-bottom: 8px;"><strong>Trim your audio (if needed)</strong> - Try to keep your audio to 10 seconds or less.</li>
605
+ <li style="margin-bottom: 8px;"><strong>View the Spectrogram (optional)</strong> - You can easily view/hide the spectrogram of your audio for closer analysis.</li>
606
+ <li style="margin-bottom: 8px;"><strong>Select a task or write your own</strong> - Select an option from pre-loaded tasks. This will auto-fill the text box with a prompt, so all you have to do is hit Send. Or, type a custom prompt directly into the chat.</li>
607
+ <li style="margin-bottom: 0;"><strong>Send and Analyze Audio</strong> - Press "Send" or type Enter to begin processing your audio. Ask follow-up questions or press "Clear" to start a new conversation.</li>
608
  </ol>
609
  <p></p>
610
  </div>
611
+
612
  <div class="guide-section">
613
+ <h3>Tips & Tricks</h3>
614
  <b>Prompting Best Practices</b>
615
+ <ul style="margin-top: 12px; padding-left: 20px; color: #6b7280; font-size: 14px; line-height: 1.6;">
616
+ <li>Be specific about what you want to know (e.g., "What species made this call?" vs "Analyze this audio")</li>
617
+ <li>Mention the context if known (geographic area/location, time of day or year, habitat type)</li>
618
+ <li>[TO ADD: examples of classification prompts that do and don't work well]</li>
 
 
 
 
 
 
 
619
  </ul>
 
620
  <b>Audio Files</b>
621
  <ul style="margin-top: 12px; padding-left: 20px; color: #6b7280; font-size: 14px; line-height: 1.6;">
622
  <li>Supported formats: .wav, .mp3, .aac, .flac, .ogg, .webm, .midi, .aiff, .wma, .opus, .amr</li>
 
628
  <div class="guide-section">
629
  <h3>Learn More</h3>
630
  <ul style="margin-top: 12px; padding-left: 20px; color: #6b7280; font-size: 14px; line-height: 1.6;">
631
+ <li>Read our <a href="https://earthspecies.org/blog" target="_blank">recent blog post</a> with a step-by-step tutorial</li>
632
  <li>Check out the <a href="https://arxiv.org/abs/2411.07186" target="_blank">published paper</a> for a deeper technical dive on NatureLM-audio.</li>
633
  <li>Visit the <a href="https://earthspecies.github.io/naturelm-audio-demo/" target="_blank">NatureLM-audio Demo Page</a> for additional context, a demo video, and more examples of the model in action.</li>
634
  <li>Sign up for our <a href="https://forms.gle/WjrbmFhKkzmEgwvY7" target="_blank">closed beta waitlist</a>, if you’re interested in testing upcoming features like longer audio files and batch processing.</li>
635
  </ul>
636
+ </div>
637
+ <div class="guide-section">
638
+ <h4>Help us improve the model!</h4>
639
+ <p>Found an issue or have suggestions? Please join us on <a href="https://earthspeciesproject.discourse.group/" target="_blank">Discourse</a> to share any feedback, questions, bug reports, or other ideas. Your input helps make NatureLM-audio better for everyone.</p>
640
  </div>
641
  </div>
642
  """)
643
 
644
  app.css = """
645
+ #chat-input {
646
+ background: white;
647
+ padding: 10px;
648
+ min-height: 44px;
649
+ display: flex;
650
+ align-items: center;
651
+ }
652
  #chat-input textarea {
653
  background: white;
654
  flex: 1;
655
  }
656
+
657
  #chat-input .submit-button {
658
  padding: 10px;
659
  margin: 2px 6px;
 
682
  color: #374151;
683
  margin-bottom: 4px;
684
  }
685
+
686
  .banner .banner-text {
687
  style="font-size: 14px;
688
  color: #6b7280;
 
701
  display: inline-block;
702
  transition: background 0.2s ease;
703
  }
704
+
705
  .link-btn:hover {
706
  background: #2563eb;
707
+ }
708
+ .download-btn {
709
+ padding: 10px 20px;
710
+ border-radius: 6px;
711
+ font-size: 13px;
712
+ font-weight: 500;
713
+ cursor: pointer;
714
+ border: none;
715
+ background: #3b82f6;
716
+ color: white;
717
+ text-decoration: none;
718
+ display: block;
719
+ text-align: center;
720
+ transition: background 0.2s ease;
721
+ width: 200px;
722
+ box-sizing: border-box;
723
+ }
724
+
725
+ .download-btn:hover {
726
+ background: #2563eb;
727
  }
 
728
  .guide-section {
729
  margin-bottom: 32px;
730
  border-radius: 8px;
 
746
  #chat-input {
747
  background: #1e1e1e;
748
  }
749
+
750
  #chat-input textarea {
751
  background: #1e1e1e;
752
  color: white;
753
  }
754
+
755
  .banner {
756
  background: #1e1e1e;
757
  color: white;
 
768
  # Create and launch the app
769
  app = main(
770
  assets_dir=Path("assets"),
771
+ cfg_path=Path("configs/inference.yml"),
772
+ options=[],
773
  )
774
 
775
  if __name__ == "__main__":
assets/American Crow - Corvus brachyrhynchos.mp3 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d0f76bff28d3e3021be495754b28ef3924bc32ff0c657b67bd4ee6bb177a1f8e
3
- size 2164626
 
 
 
 
configs/inference.yml CHANGED
@@ -59,4 +59,3 @@ generate:
59
  temperature: 0.1
60
  repetition_penalty: 1.0
61
  length_penalty: 1.0
62
- merging_alpha: 0.5
 
59
  temperature: 0.1
60
  repetition_penalty: 1.0
61
  length_penalty: 1.0
 
data_store.py DELETED
@@ -1,58 +0,0 @@
1
- import os
2
- from pathlib import Path
3
- import uuid
4
- import json
5
- from huggingface_hub import HfApi, HfFileSystem
6
-
7
- DATASET_REPO = "EarthSpeciesProject/naturelm-audio-space-logs"
8
- SPLIT = "test"
9
- TESTING = os.getenv("TESTING", "0") == "1"
10
- api = HfApi(token=os.getenv("HF_TOKEN", None))
11
- # Upload audio
12
- # check if file exists
13
- hf_fs = HfFileSystem(token=os.getenv("HF_TOKEN", None))
14
-
15
-
16
- def upload_data(audio: str | Path, user_text: str, model_response: str, session_id: str = ""):
17
- data_id = str(uuid.uuid4())
18
-
19
- if TESTING:
20
- data_id = "test-" + data_id
21
- session_id = "test-" + session_id
22
-
23
- # Audio path in repo
24
- suffix = Path(audio).suffix
25
- audio_p = f"{SPLIT}/audio/" + session_id + suffix
26
-
27
- if not hf_fs.exists(f"datasets/{DATASET_REPO}/{audio_p}"):
28
- api.upload_file(
29
- path_or_fileobj=str(audio),
30
- path_in_repo=audio_p,
31
- repo_id=DATASET_REPO,
32
- repo_type="dataset",
33
- )
34
-
35
- text = {
36
- "user_message": user_text,
37
- "model_response": model_response,
38
- "file_name": "audio/" + session_id + suffix, # has to be relative to metadata.jsonl
39
- "original_fn": os.path.basename(audio),
40
- "id": data_id,
41
- "session_id": session_id,
42
- }
43
-
44
- # Append to a jsonl file in the repo
45
- # APPEND DOESNT WORK, have to open first
46
- if hf_fs.exists(f"datasets/{DATASET_REPO}/{SPLIT}/metadata.jsonl"):
47
- with hf_fs.open(f"datasets/{DATASET_REPO}/{SPLIT}/metadata.jsonl", "r") as f:
48
- lines = f.readlines()
49
- lines.append(json.dumps(text) + "\n")
50
- with hf_fs.open(f"datasets/{DATASET_REPO}/{SPLIT}/metadata.jsonl", "w") as f:
51
- f.writelines(lines)
52
- else:
53
- with hf_fs.open(f"datasets/{DATASET_REPO}/{SPLIT}/metadata.jsonl", "w") as f:
54
- f.write(json.dumps(text) + "\n")
55
-
56
- # Write a separate file instead
57
- # with hf_fs.open(f"datasets/{DATASET_REPO}/{data_id}.json", "w") as f:
58
- # json.dump(text, f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -14,7 +14,6 @@ soundfile>=0.13.1
14
  spaces>=0.40.0
15
  torch>=2.8.0
16
  torchaudio>=2.8.0
17
- torchcodec>=0.8.0
18
  tqdm>=4.67.1
19
- transformers[sentencepiece]==4.55.3
20
  matplotlib>=3.10.5
 
14
  spaces>=0.40.0
15
  torch>=2.8.0
16
  torchaudio>=2.8.0
 
17
  tqdm>=4.67.1
18
+ transformers[sentencepiece]>=4.55.2
19
  matplotlib>=3.10.5