# Marine1 / inference.py
# (Hugging Face page residue from extraction — author shiv207, commit 0195699:
#  "Add safetensors models (secure format) and update documentation")
"""
Marine1 Underwater Acoustic Classifier - Inference Script
Simple example for using the model with Hugging Face
Supports both .pth (pickle) and .safetensors formats
"""
import warnings
from typing import Any, Dict, Tuple

import librosa
import numpy as np
import torch

warnings.filterwarnings('ignore')
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not installed. Install with: pip install safetensors")
class Marine1Classifier:
"""Underwater acoustic classifier using Marine1 model"""
def __init__(self, model_path: str, device: str = None):
"""
Initialize the classifier
Args:
model_path: Path to the model file (.pth or .safetensors)
device: Device to run on ('cuda', 'cpu', or 'mps'). Auto-detected if None.
"""
if device is None:
if torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
self.device = torch.device(device)
print(f"Using device: {self.device}")
# Determine file format
is_safetensors = model_path.endswith('.safetensors')
if is_safetensors:
if not SAFETENSORS_AVAILABLE:
raise ImportError("safetensors not installed. Install with: pip install safetensors")
print(f"Loading safetensors model (secure format)...")
# Load safetensors
state_dict = load_file(model_path, device=str(self.device))
# Parse metadata
from safetensors import safe_open
with safe_open(model_path, framework="pt", device=str(self.device)) as f:
metadata = f.metadata()
# Get class mapping from metadata
import ast
self.class_to_id = ast.literal_eval(metadata.get('class_to_id', "{}"))
if not self.class_to_id:
# Default mapping
self.class_to_id = {
'vessel': 0, 'marine_animal': 1,
'natural_sound': 2, 'other_anthropogenic': 3
}
else:
print(f"Loading PyTorch model (.pth format)...")
# Load checkpoint
checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
# Get class mapping
self.class_to_id = checkpoint['class_to_id']
state_dict = checkpoint['model_state_dict']
self.id_to_class = {v: k for k, v in self.class_to_id.items()}
self.class_names = [self.id_to_class[i] for i in range(len(self.id_to_class))]
# Load model architecture (custom fine-tuned ResNet18)
self.model = self._create_model_architecture(len(self.class_names))
# Load weights
self.model.load_state_dict(state_dict)
self.model.to(self.device)
self.model.eval()
format_type = "safetensors (secure)" if is_safetensors else "PyTorch (.pth)"
print(f"✅ Model loaded successfully ({format_type})")
print(f" Classes: {len(self.class_names)}")
def _create_model_architecture(self, num_classes: int):
"""Create the model architecture matching the trained model"""
import torch.nn as nn
from torchvision import models
class LightweightFineTuned(nn.Module):
def __init__(self, num_classes=4):
super(LightweightFineTuned, self).__init__()
resnet = models.resnet18(weights=None)
# Adapt first layer for grayscale spectrograms
self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = resnet.bn1
self.relu = resnet.relu
self.maxpool = resnet.maxpool
self.layer1 = resnet.layer1
self.layer2 = resnet.layer2
self.layer3 = resnet.layer3
self.layer4 = resnet.layer4
self.avgpool = resnet.avgpool
self.classifier = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(256, num_classes)
)
self.confidence_head = nn.Sequential(
nn.Linear(512, 1),
nn.Sigmoid()
)
def forward(self, x, return_confidence=False):
if len(x.shape) == 3:
x = x.unsqueeze(1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
features = torch.flatten(x, 1)
logits = self.classifier(features)
if return_confidence:
confidence = self.confidence_head(features)
return logits, confidence
return logits
return LightweightFineTuned(num_classes=num_classes)
def process_audio(self, audio_path: str, sr: int = 16000, duration: float = 10.0) -> np.ndarray:
"""
Process audio file into mel spectrogram
Args:
audio_path: Path to audio file
sr: Target sample rate
duration: Maximum duration in seconds
Returns:
Log mel spectrogram as numpy array
"""
# Load audio
y, _ = librosa.load(audio_path, sr=sr, duration=duration)
# Pad if too short
target_length = int(sr * duration)
if len(y) < target_length:
y = np.pad(y, (0, target_length - len(y)), mode='constant')
# Create mel spectrogram
mel_spec = librosa.feature.melspectrogram(
y=y,
sr=sr,
n_mels=128,
n_fft=2048,
hop_length=512,
fmax=8000
)
# Convert to log scale
log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
return log_mel_spec
def predict(self, audio_path: str) -> Dict[str, any]:
"""
Predict the class of an audio file
Args:
audio_path: Path to audio file
Returns:
Dictionary with prediction results
"""
# Process audio
log_mel_spec = self.process_audio(audio_path)
# Prepare input tensor
input_tensor = torch.FloatTensor(log_mel_spec).unsqueeze(0).unsqueeze(0).to(self.device)
# Predict
with torch.no_grad():
outputs = self.model(input_tensor)
probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
# Get results
predicted_idx = probabilities.argmax().item()
predicted_class = self.class_names[predicted_idx]
confidence = probabilities[predicted_idx].item()
# Get all class probabilities
all_probs = {
self.class_names[i]: probabilities[i].item()
for i in range(len(self.class_names))
}
return {
'predicted_class': predicted_class,
'confidence': confidence,
'all_probabilities': all_probs,
'predicted_class_id': predicted_idx
}
def predict_batch(self, audio_paths: list) -> list:
"""
Predict classes for multiple audio files
Args:
audio_paths: List of paths to audio files
Returns:
List of prediction dictionaries
"""
results = []
for audio_path in audio_paths:
try:
result = self.predict(audio_path)
result['audio_path'] = audio_path
results.append(result)
except Exception as e:
print(f"Error processing {audio_path}: {e}")
results.append({
'audio_path': audio_path,
'error': str(e)
})
return results
def main():
"""Example usage"""
import sys
if len(sys.argv) < 3:
print("Usage: python inference.py <model_path> <audio_path>")
print("Example: python inference.py best_model_finetuned.pth underwater_sound.wav")
sys.exit(1)
model_path = sys.argv[1]
audio_path = sys.argv[2]
# Initialize classifier
classifier = Marine1Classifier(model_path)
# Make prediction
print(f"\nProcessing: {audio_path}")
result = classifier.predict(audio_path)
# Display results
print(f"\n{'='*50}")
print(f"Prediction: {result['predicted_class'].replace('_', ' ').title()}")
print(f"Confidence: {result['confidence']*100:.2f}%")
print(f"\nAll Probabilities:")
for class_name, prob in sorted(result['all_probabilities'].items(), key=lambda x: x[1], reverse=True):
print(f" {class_name.replace('_', ' ').title():25s}: {prob*100:6.2f}%")
print(f"{'='*50}\n")
if __name__ == "__main__":
main()