"""
Marine1 Underwater Acoustic Classifier - Inference Script

Simple example for using the model with Hugging Face.
Supports both .pth (pickle) and .safetensors formats.
"""

import torch
import librosa
import numpy as np
from typing import Any, Dict, Optional, Tuple
import warnings

# NOTE(review): this silences *all* warnings process-wide, for every program
# that imports this module, not only warnings raised by this script.
warnings.filterwarnings('ignore')

try:
    from safetensors.torch import load_file
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not installed. Install with: pip install safetensors")


class Marine1Classifier:
    """Underwater acoustic classifier using Marine1 model"""

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the model file (.pth or .safetensors). The
                format is chosen purely from the file extension.
            device: Device to run on ('cuda', 'cpu', or 'mps'). Auto-detected
                if None, preferring CUDA, then MPS, then CPU.

        Raises:
            ImportError: If a .safetensors file is given but the safetensors
                package is not installed.
            KeyError: If a .pth checkpoint lacks 'class_to_id' or
                'model_state_dict', or if the class ids are not 0..N-1.
        """
        self.device = torch.device(self._resolve_device(device))
        print(f"Using device: {self.device}")

        # Determine file format
        is_safetensors = model_path.endswith('.safetensors')
        if is_safetensors:
            self.class_to_id, state_dict = self._load_safetensors(model_path)
        else:
            self.class_to_id, state_dict = self._load_pth(model_path)

        self.id_to_class = {v: k for k, v in self.class_to_id.items()}
        # Index i of class_names is logit i; this requires the ids to be the
        # contiguous integers 0..N-1 (otherwise a KeyError is raised here).
        self.class_names = [self.id_to_class[i] for i in range(len(self.id_to_class))]

        # Load model architecture (custom fine-tuned ResNet18)
        self.model = self._create_model_architecture(len(self.class_names))

        # Load weights (strict: key names must match the architecture exactly)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        format_type = "safetensors (secure)" if is_safetensors else "PyTorch (.pth)"
        print(f"✅ Model loaded successfully ({format_type})")
        print(f" Classes: {len(self.class_names)}")

    @staticmethod
    def _resolve_device(device: Optional[str]) -> str:
        """Return *device* unchanged, or auto-detect CUDA > MPS > CPU when None."""
        if device is not None:
            return device
        if torch.cuda.is_available():
            return 'cuda'
        # getattr guard: older torch builds have no torch.backends.mps module.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return 'mps'
        return 'cpu'

    def _load_safetensors(self, model_path: str) -> Tuple[Dict[str, int], Dict[str, torch.Tensor]]:
        """Load the class mapping and state dict from a .safetensors file."""
        if not SAFETENSORS_AVAILABLE:
            raise ImportError("safetensors not installed. Install with: pip install safetensors")
        print("Loading safetensors model (secure format)...")

        # Load safetensors
        state_dict = load_file(model_path, device=str(self.device))

        # Parse metadata from safetensors (local imports: the package is optional)
        import ast
        from safetensors import safe_open
        with safe_open(model_path, framework="pt", device=str(self.device)) as f:
            # metadata() returns None when the file was saved without metadata;
            # fall back to an empty dict so the default mapping below applies.
            metadata = f.metadata() or {}

        # Get class mapping from metadata. literal_eval only accepts Python
        # literals, so it cannot execute code embedded in the metadata string.
        class_to_id = ast.literal_eval(metadata.get('class_to_id', "{}"))
        if not class_to_id:
            # Default mapping
            class_to_id = {
                'vessel': 0,
                'marine_animal': 1,
                'natural_sound': 2,
                'other_anthropogenic': 3,
            }
        return class_to_id, state_dict

    def _load_pth(self, model_path: str) -> Tuple[Dict[str, int], Dict[str, torch.Tensor]]:
        """Load the class mapping and state dict from a pickled .pth checkpoint."""
        print("Loading PyTorch model (.pth format)...")
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load .pth files from trusted
        # sources; prefer the .safetensors format for untrusted downloads.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        # Get class mapping
        class_to_id = checkpoint['class_to_id']
        state_dict = checkpoint['model_state_dict']
        return class_to_id, state_dict

    def _create_model_architecture(self, num_classes: int):
        """
        Create the model architecture matching the trained model.

        Attribute names (conv1, bn1, layer1..layer4, classifier,
        confidence_head) must match the saved state-dict keys exactly.
        """
        import torch.nn as nn
        from torchvision import models

        class LightweightFineTuned(nn.Module):
            def __init__(self, num_classes=4):
                super().__init__()
                resnet = models.resnet18(weights=None)

                # Adapt first layer for grayscale spectrograms (1 input channel)
                self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
                self.bn1 = resnet.bn1
                self.relu = resnet.relu
                self.maxpool = resnet.maxpool
                self.layer1 = resnet.layer1
                self.layer2 = resnet.layer2
                self.layer3 = resnet.layer3
                self.layer4 = resnet.layer4
                self.avgpool = resnet.avgpool

                self.classifier = nn.Sequential(
                    nn.Dropout(0.5),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, num_classes),
                )
                # Auxiliary head; only evaluated when return_confidence=True.
                self.confidence_head = nn.Sequential(
                    nn.Linear(512, 1),
                    nn.Sigmoid(),
                )

            def forward(self, x, return_confidence=False):
                # Accept a 3-D input by inserting a channel axis at dim 1.
                if len(x.shape) == 3:
                    x = x.unsqueeze(1)
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.maxpool(x)
                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)
                x = self.avgpool(x)
                features = torch.flatten(x, 1)
                logits = self.classifier(features)
                if return_confidence:
                    confidence = self.confidence_head(features)
                    return logits, confidence
                return logits

        return LightweightFineTuned(num_classes=num_classes)

    def process_audio(self, audio_path: str, sr: int = 16000, duration: float = 10.0) -> np.ndarray:
        """
        Process an audio file into a log-mel spectrogram.

        The audio is resampled to *sr*, truncated to at most *duration*
        seconds, and zero-padded at the end when shorter.

        Args:
            audio_path: Path to audio file.
            sr: Target sample rate.
            duration: Maximum duration in seconds.

        Returns:
            Log mel spectrogram (128 mel bands x frames), in dB relative to
            the spectrogram's own maximum.
        """
        # Load audio
        y, _ = librosa.load(audio_path, sr=sr, duration=duration)

        # Pad if too short so every clip yields the same number of frames
        target_length = int(sr * duration)
        if len(y) < target_length:
            y = np.pad(y, (0, target_length - len(y)), mode='constant')

        # Create mel spectrogram
        # NOTE(review): fmax=8000 equals the Nyquist frequency only at the
        # default sr=16000; these settings presumably mirror training.
        mel_spec = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=128, n_fft=2048, hop_length=512, fmax=8000
        )

        # Convert to log scale
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
        return log_mel_spec

    def predict(self, audio_path: str) -> Dict[str, Any]:
        """
        Predict the class of an audio file.

        Args:
            audio_path: Path to audio file.

        Returns:
            Dictionary with keys 'predicted_class' (str), 'confidence'
            (softmax probability of that class), 'all_probabilities'
            (class name -> probability) and 'predicted_class_id' (int).
        """
        # Process audio
        log_mel_spec = self.process_audio(audio_path)

        # Prepare input tensor: add batch and channel axes -> (1, 1, mels, frames)
        input_tensor = (
            torch.as_tensor(log_mel_spec, dtype=torch.float32)
            .unsqueeze(0)
            .unsqueeze(0)
            .to(self.device)
        )

        # Predict
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

        # Get results
        predicted_idx = probabilities.argmax().item()
        predicted_class = self.class_names[predicted_idx]
        confidence = probabilities[predicted_idx].item()

        # Get all class probabilities
        all_probs = {
            name: prob for name, prob in zip(self.class_names, probabilities.tolist())
        }

        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_probabilities': all_probs,
            'predicted_class_id': predicted_idx,
        }

    def predict_batch(self, audio_paths: list) -> list:
        """
        Predict classes for multiple audio files.

        Files are processed one at a time; a failure on one file is reported
        and recorded as an {'audio_path', 'error'} entry instead of aborting
        the batch.

        Args:
            audio_paths: List of paths to audio files.

        Returns:
            List of prediction dictionaries, in input order, each also
            carrying its 'audio_path'.
        """
        results = []
        for audio_path in audio_paths:
            try:
                result = self.predict(audio_path)
            except Exception as e:  # best-effort batch: keep going on failure
                print(f"Error processing {audio_path}: {e}")
                results.append({
                    'audio_path': audio_path,
                    'error': str(e),
                })
            else:
                result['audio_path'] = audio_path
                results.append(result)
        return results


def main():
    """Example usage: classify one audio file given on the command line."""
    import sys

    if len(sys.argv) < 3:
        print("Usage: python inference.py <model_path> <audio_path>")
        print("Example: python inference.py best_model_finetuned.pth underwater_sound.wav")
        sys.exit(1)

    model_path = sys.argv[1]
    audio_path = sys.argv[2]

    # Initialize classifier
    classifier = Marine1Classifier(model_path)

    # Make prediction
    print(f"\nProcessing: {audio_path}")
    result = classifier.predict(audio_path)

    # Display results
    print(f"\n{'='*50}")
    print(f"Prediction: {result['predicted_class'].replace('_', ' ').title()}")
    print(f"Confidence: {result['confidence']*100:.2f}%")
    print("\nAll Probabilities:")
    for class_name, prob in sorted(result['all_probabilities'].items(), key=lambda x: x[1], reverse=True):
        print(f" {class_name.replace('_', ' ').title():25s}: {prob*100:6.2f}%")
    print(f"{'='*50}\n")


if __name__ == "__main__":
    main()