|
|
""" |
|
|
Marine1 Underwater Acoustic Classifier - Inference Script |
|
|
Simple example for using the model with Hugging Face |
|
|
Supports both .pth (pickle) and .safetensors formats |
|
|
""" |
|
|
|
|
|
import warnings
from typing import Any, Dict, Optional, Tuple

import librosa
import numpy as np
import torch
|
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
try: |
|
|
from safetensors.torch import load_file |
|
|
SAFETENSORS_AVAILABLE = True |
|
|
except ImportError: |
|
|
SAFETENSORS_AVAILABLE = False |
|
|
print("Warning: safetensors not installed. Install with: pip install safetensors") |
|
|
|
|
|
|
|
|
class Marine1Classifier:
    """Underwater acoustic classifier using the Marine1 model.

    Loads a trained checkpoint from either the secure ``.safetensors``
    format or a pickled ``.pth`` file, rebuilds the network architecture,
    and exposes single-file and batch prediction over audio recordings.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the model file (.pth or .safetensors)
            device: Device to run on ('cuda', 'cpu', or 'mps'). Auto-detected if None.

        Raises:
            ImportError: If model_path is a .safetensors file but the
                safetensors package is not installed.
        """
        if device is None:
            # Prefer CUDA, then Apple-silicon MPS, then CPU. The hasattr
            # guard keeps this working on torch builds without an mps backend.
            if torch.cuda.is_available():
                device = 'cuda'
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'

        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        is_safetensors = model_path.endswith('.safetensors')

        if is_safetensors:
            if not SAFETENSORS_AVAILABLE:
                raise ImportError("safetensors not installed. Install with: pip install safetensors")

            print("Loading safetensors model (secure format)...")

            state_dict = load_file(model_path, device=str(self.device))

            # The class-name mapping is stored as a repr'd dict in the
            # file-level metadata.
            from safetensors import safe_open
            with safe_open(model_path, framework="pt", device=str(self.device)) as f:
                # metadata() returns None when the file carries no metadata;
                # normalize to an empty dict so .get() below is safe.
                metadata = f.metadata() or {}

            import ast
            # literal_eval (not eval) so untrusted metadata cannot run code.
            self.class_to_id = ast.literal_eval(metadata.get('class_to_id', "{}"))
            if not self.class_to_id:
                # No mapping in metadata: fall back to the default label set.
                self.class_to_id = {
                    'vessel': 0, 'marine_animal': 1,
                    'natural_sound': 2, 'other_anthropogenic': 3
                }
        else:
            print("Loading PyTorch model (.pth format)...")

            # SECURITY: weights_only=False unpickles arbitrary Python
            # objects; only load .pth checkpoints from trusted sources.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

            self.class_to_id = checkpoint['class_to_id']
            state_dict = checkpoint['model_state_dict']

        # Inverse lookup plus an index-ordered list of class names.
        self.id_to_class = {v: k for k, v in self.class_to_id.items()}
        self.class_names = [self.id_to_class[i] for i in range(len(self.id_to_class))]

        self.model = self._create_model_architecture(len(self.class_names))

        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        format_type = "safetensors (secure)" if is_safetensors else "PyTorch (.pth)"
        print(f"✅ Model loaded successfully ({format_type})")
        print(f"   Classes: {len(self.class_names)}")

    def _create_model_architecture(self, num_classes: int):
        """Create the model architecture matching the trained model.

        A ResNet-18 backbone (randomly initialized here; real weights come
        from the checkpoint) with the stem conv adapted to 1-channel
        spectrogram input, a dropout/linear classification head, and an
        auxiliary sigmoid confidence head.
        """
        import torch.nn as nn
        from torchvision import models

        class LightweightFineTuned(nn.Module):
            def __init__(self, num_classes=4):
                super().__init__()

                resnet = models.resnet18(weights=None)

                # Replace the stock 3-channel stem with a 1-channel conv so
                # mono spectrograms feed in directly; the remaining stem
                # layers are reused from the ResNet.
                self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
                self.bn1 = resnet.bn1
                self.relu = resnet.relu
                self.maxpool = resnet.maxpool

                self.layer1 = resnet.layer1
                self.layer2 = resnet.layer2
                self.layer3 = resnet.layer3
                self.layer4 = resnet.layer4
                self.avgpool = resnet.avgpool

                self.classifier = nn.Sequential(
                    nn.Dropout(0.5),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, num_classes)
                )

                # Auxiliary head producing a scalar confidence in [0, 1].
                self.confidence_head = nn.Sequential(
                    nn.Linear(512, 1),
                    nn.Sigmoid()
                )

            def forward(self, x, return_confidence=False):
                # Accept (batch, mels, time) input by adding a channel dim.
                if len(x.shape) == 3:
                    x = x.unsqueeze(1)

                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.maxpool(x)

                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

                x = self.avgpool(x)
                features = torch.flatten(x, 1)
                logits = self.classifier(features)

                if return_confidence:
                    confidence = self.confidence_head(features)
                    return logits, confidence

                return logits

        return LightweightFineTuned(num_classes=num_classes)

    def process_audio(self, audio_path: str, sr: int = 16000, duration: float = 10.0) -> np.ndarray:
        """
        Process audio file into a log mel spectrogram.

        Args:
            audio_path: Path to audio file
            sr: Target sample rate
            duration: Maximum duration in seconds

        Returns:
            Log mel spectrogram as a (128, time) numpy array
        """
        # Load (resampling to sr) at most `duration` seconds of mono audio.
        y, _ = librosa.load(audio_path, sr=sr, duration=duration)

        # Zero-pad short clips so every spectrogram has the same width.
        target_length = int(sr * duration)
        if len(y) < target_length:
            y = np.pad(y, (0, target_length - len(y)), mode='constant')

        mel_spec = librosa.feature.melspectrogram(
            y=y,
            sr=sr,
            n_mels=128,
            n_fft=2048,
            hop_length=512,
            fmax=8000
        )

        # Convert power to dB, referenced to the spectrogram's peak.
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        return log_mel_spec

    def predict(self, audio_path: str) -> Dict[str, Any]:
        """
        Predict the class of an audio file.

        Args:
            audio_path: Path to audio file

        Returns:
            Dictionary with keys 'predicted_class', 'confidence',
            'all_probabilities', and 'predicted_class_id'
        """
        log_mel_spec = self.process_audio(audio_path)

        # Shape (1, 1, n_mels, time): batch and channel dims for the CNN.
        input_tensor = torch.FloatTensor(log_mel_spec).unsqueeze(0).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

        predicted_idx = probabilities.argmax().item()
        predicted_class = self.class_names[predicted_idx]
        confidence = probabilities[predicted_idx].item()

        all_probs = {
            self.class_names[i]: probabilities[i].item()
            for i in range(len(self.class_names))
        }

        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'all_probabilities': all_probs,
            'predicted_class_id': predicted_idx
        }

    def predict_batch(self, audio_paths: list) -> list:
        """
        Predict classes for multiple audio files.

        Files that fail to process are reported as {'audio_path', 'error'}
        entries instead of aborting the whole batch.

        Args:
            audio_paths: List of paths to audio files

        Returns:
            List of prediction dictionaries
        """
        results = []
        for audio_path in audio_paths:
            try:
                result = self.predict(audio_path)
                result['audio_path'] = audio_path
                results.append(result)
            except Exception as e:
                print(f"Error processing {audio_path}: {e}")
                results.append({
                    'audio_path': audio_path,
                    'error': str(e)
                })
        return results
|
|
|
|
|
|
|
|
def main():
    """Example usage: classify one audio file from the command line."""
    import sys

    # Guard clause: require both a model path and an audio path.
    if len(sys.argv) < 3:
        print("Usage: python inference.py <model_path> <audio_path>")
        print("Example: python inference.py best_model_finetuned.pth underwater_sound.wav")
        sys.exit(1)

    model_file, audio_file = sys.argv[1], sys.argv[2]

    classifier = Marine1Classifier(model_file)

    print(f"\nProcessing: {audio_file}")
    result = classifier.predict(audio_file)

    # Pretty-print the prediction plus the full probability breakdown,
    # highest probability first.
    banner = "=" * 50
    print("\n" + banner)
    print(f"Prediction: {result['predicted_class'].replace('_', ' ').title()}")
    print(f"Confidence: {result['confidence']*100:.2f}%")
    print("\nAll Probabilities:")
    ranked = sorted(result['all_probabilities'].items(), key=lambda item: item[1], reverse=True)
    for label, prob in ranked:
        print(f"  {label.replace('_', ' ').title():25s}: {prob*100:6.2f}%")
    print(banner + "\n")


if __name__ == "__main__":
    main()
|
|
|