# lwm-spectro / example_inference.py
# Uploaded by wi-lab with huggingface_hub (commit 652afc6).
"""
Example Inference Script for LWM-Spectro Model
This script demonstrates how to:
1. Load the pre-trained MoE model
2. Load and preprocess a spectrogram
3. Perform inference
4. Interpret results
"""
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
import sys
# Add project root to path
sys.path.append(str(Path(__file__).parent))
from pretraining.pretrained_model import PretrainedLWM
class SpectrogramClassifier:
    """Wrapper class for easy inference with LWM-Spectro model.

    Loads a checkpoint into ``PretrainedLWM``, turns spectrogram images into
    ``[1, 1, H, W]`` float32 tensors scaled to ``[0, 1]``, and maps the
    softmax of the model output (over dim 1) to class names.
    """

    # Default label order; index i corresponds to output logit i.
    DEFAULT_CLASSES = ('LTE', 'WiFi', '5G')

    def __init__(self, model_path, device='cuda', classes=None):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the trained model checkpoint (.pth file).
            device: Torch device string such as 'cuda', 'cuda:1' or 'cpu'.
                A CUDA request falls back to CPU when CUDA is unavailable;
                any other device type is used as given.
            classes: Optional sequence of class names in the order of the
                model's output logits. Defaults to ``DEFAULT_CLASSES``.
        """
        requested = torch.device(device)
        if requested.type == 'cuda' and not torch.cuda.is_available():
            requested = torch.device('cpu')
        self.device = requested
        print(f"Using device: {self.device}")

        # Load model
        self.model = self._load_model(model_path)
        self.model.eval()

        # Class mapping
        self.classes = list(classes) if classes is not None else list(self.DEFAULT_CLASSES)

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the parameter mapping stored in a loaded checkpoint object.

        Unwraps ``{'model_state_dict': ...}`` and ``{'state_dict': ...}``
        containers; anything else is assumed to already be a state dict.
        """
        if isinstance(checkpoint, dict):
            if 'model_state_dict' in checkpoint:
                return checkpoint['model_state_dict']
            if 'state_dict' in checkpoint:
                return checkpoint['state_dict']
        return checkpoint

    def _load_model(self, model_path):
        """Build the model and load weights from ``model_path``.

        ``strict=False`` is kept so partially matching checkpoints still load,
        but missing/unexpected keys are reported: otherwise a checkpoint for a
        different architecture would load no weights at all and the model
        would silently predict from its random initialization.
        """
        # NOTE(review): torch.load unpickles arbitrary Python objects — only
        # load checkpoints from trusted sources (recent PyTorch versions offer
        # ``weights_only=True`` to restrict this).
        checkpoint = torch.load(model_path, map_location=self.device)
        state_dict = self._extract_state_dict(checkpoint)

        # Initialize model (adjust architecture as needed)
        model = PretrainedLWM()  # or your specific model class

        result = model.load_state_dict(state_dict, strict=False)
        missing, unexpected = result.missing_keys, result.unexpected_keys
        if missing or unexpected:
            print(f"Warning: checkpoint does not fully match {type(model).__name__}: "
                  f"{len(missing)} missing key(s), {len(unexpected)} unexpected key(s).")
            if missing:
                print(f"  e.g. missing: {missing[:5]}")
            if unexpected:
                print(f"  e.g. unexpected: {unexpected[:5]}")
        model.to(self.device)
        return model

    def load_spectrogram(self, image_path, target_size=(128, 128)):
        """
        Load and preprocess a spectrogram image.

        Args:
            image_path: Path to spectrogram image file.
            target_size: Target size for resizing as (height, width).

        Returns:
            Float32 tensor of shape [1, 1, H, W] with values in [0, 1],
            placed on ``self.device``.
        """
        # Convert to grayscale inside the with-block so the file is closed.
        with Image.open(image_path) as img:
            gray = img.convert('L')
        # PIL's resize takes (width, height).
        gray = gray.resize((target_size[1], target_size[0]), Image.BILINEAR)
        img_array = np.array(gray, dtype=np.float32) / 255.0
        tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)
        return tensor.to(self.device)

    def predict(self, spectrogram, return_probs=False):
        """
        Perform inference on a single spectrogram.

        Args:
            spectrogram: Preprocessed tensor (batch size 1) or a path to an
                image file. Tensors are moved to ``self.device`` first.
            return_probs: If True, also return per-class probabilities.

        Returns:
            If return_probs=False: predicted class name.
            If return_probs=True: (predicted class name, {class: probability}).
        """
        if isinstance(spectrogram, (str, Path)):
            spectrogram = self.load_spectrogram(spectrogram)
        # Caller-built tensors may live on another device (e.g. CPU while the
        # model is on CUDA); align them with the model.
        spectrogram = spectrogram.to(self.device)

        with torch.no_grad():
            output = self.model(spectrogram)
            probabilities = F.softmax(output, dim=1)
            # .item() requires a single sample, i.e. batch size 1.
            predicted_idx = torch.argmax(probabilities, dim=1).item()
        predicted_class = self.classes[predicted_idx]

        if return_probs:
            prob_dict = {
                cls: probabilities[0, i].item()
                for i, cls in enumerate(self.classes)
            }
            return predicted_class, prob_dict
        return predicted_class

    def predict_batch(self, spectrogram_paths):
        """
        Classify several spectrogram images.

        Images are processed one at a time (one forward pass per path).

        Args:
            spectrogram_paths: Iterable of paths to spectrogram images.

        Returns:
            List of predicted class names, in input order.
        """
        return [self.predict(path) for path in spectrogram_paths]

    def visualize_prediction(self, image_path, save_path=None):
        """
        Show the spectrogram next to a bar chart of class probabilities.

        Args:
            image_path: Path to spectrogram image.
            save_path: Optional path; if given, the figure is also saved there.
        """
        # copy() reads the pixels so the file can be closed immediately.
        with Image.open(image_path) as src:
            img = src.copy()

        pred_class, probs = self.predict(image_path, return_probs=True)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Display spectrogram
        ax1.imshow(img, cmap='viridis')
        ax1.set_title(f'Input Spectrogram\nPredicted: {pred_class}', fontsize=14, fontweight='bold')
        ax1.axis('off')

        # Display probability distribution
        classes = list(probs.keys())
        probabilities = list(probs.values())
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
        bars = ax2.barh(classes, probabilities, color=colors)
        ax2.set_xlabel('Probability', fontsize=12)
        ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
        ax2.set_xlim(0, 1)

        # Add probability values on bars
        for bar, prob in zip(bars, probabilities):
            width = bar.get_width()
            ax2.text(width, bar.get_y() + bar.get_height() / 2,
                     f'{prob:.3f}', ha='left', va='center', fontsize=11)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to: {save_path}")
        plt.show()
# ============================================================================
# Example Usage
# ============================================================================
def example_single_inference():
    """Demonstrate classifying one spectrogram, first plainly and then with probabilities."""
    banner = "=" * 60
    print(banner)
    print("Example 1: Single Spectrogram Inference")
    print(banner)

    checkpoint = "mixture/runs/embedding_router/moe_checkpoint.pth"
    clf = SpectrogramClassifier(checkpoint, device='cuda')

    sample = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"

    # Plain call: only the predicted class name comes back.
    label = clf.predict(sample)
    print(f"\nPrediction: {label}")

    # Same input again, this time also requesting the per-class probabilities.
    top_class, class_probs = clf.predict(sample, return_probs=True)
    print(f"\nPredicted Class: {top_class}")
    print("\nClass Probabilities:")
    for name, p in class_probs.items():
        print(f" {name}: {p:.4f}")
def example_batch_inference():
    """Demonstrate classifying a list of spectrogram files with one call."""
    separator = "=" * 60
    print("\n" + separator)
    print("Example 2: Batch Inference")
    print(separator)

    clf = SpectrogramClassifier(
        "mixture/runs/embedding_router/moe_checkpoint.pth", device='cuda'
    )

    samples = [
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/LTE/QAM16/rate1-2/SNR10dB/sample_0001.png",
        # NOTE(review): unlike the two paths above, this one has no SNR
        # directory level — confirm it matches the WiFi dataset layout.
        "spectrograms/WiFi/QAM64/rate3-4/sample_0001.png",
    ]

    labels = clf.predict_batch(samples)
    print("\nBatch Predictions:")
    for sample, label in zip(samples, labels):
        print(f" {Path(sample).name}: {label}")
def example_visualization():
    """Demonstrate rendering a spectrogram alongside its class-probability chart."""
    rule = "=" * 60
    print("\n" + rule)
    print("Example 3: Prediction Visualization")
    print(rule)

    clf = SpectrogramClassifier(
        "mixture/runs/embedding_router/moe_checkpoint.pth", device='cuda'
    )
    # Displays the figure and also writes it to prediction_result.png.
    clf.visualize_prediction(
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png",
        save_path="prediction_result.png",
    )
def example_custom_preprocessing():
    """Example: Custom preprocessing and inference.

    Starts from the classifier's standard preprocessing (grayscale, resize,
    scale to [0, 1]). Then applies a custom numpy transform (small Gaussian
    noise) before predicting on the result.
    """
    print("\n" + "=" * 60)
    print("Example 4: Custom Preprocessing")
    print("=" * 60)

    # Initialize classifier
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    # Going through load_spectrogram keeps the input size and scaling identical
    # to the standard pipeline. Opening the raw PNG directly would skip the
    # resize to the model's input size.
    image_path = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"
    img_array = classifier.load_spectrogram(image_path)[0, 0].cpu().numpy()

    # Apply custom transformations (example: add noise).
    # np.random.normal returns float64. Cast to float32 so the tensor keeps the
    # same dtype as load_spectrogram's output (the model is presumably float32).
    noise = np.random.normal(0, 0.01, img_array.shape).astype(np.float32)
    img_array_noisy = np.clip(img_array + noise, 0, 1)

    # Convert back to a [1, 1, H, W] tensor on the model's device
    tensor = torch.from_numpy(img_array_noisy).unsqueeze(0).unsqueeze(0)
    tensor = tensor.to(classifier.device)

    # Predict
    prediction = classifier.predict(tensor)
    print(f"\nPrediction on noisy image: {prediction}")
def example_error_analysis():
    """Example: Analyze predictions across different SNR levels.

    Classifies the first sample at each SNR level for 5G/QPSK/rate1-2.
    Levels whose sample file is absent are reported and skipped, so a
    missing dataset does not look like an empty-but-successful sweep.
    """
    print("\n" + "=" * 60)
    print("Example 5: SNR-based Error Analysis")
    print("=" * 60)

    # Initialize classifier
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    # Test across different SNR levels
    snr_levels = ['SNR-5dB', 'SNR0dB', 'SNR5dB', 'SNR10dB', 'SNR15dB', 'SNR20dB', 'SNR25dB']
    base_path = Path("spectrograms/5G/QPSK/rate1-2")

    print("\nPredictions across SNR levels:")
    for snr in snr_levels:
        snr_path = base_path / snr / "sample_0001.png"
        if not snr_path.exists():
            # Best-effort sweep: keep going, but say what was skipped.
            print(f" {snr}: skipped (file not found: {snr_path})")
            continue
        pred_class, probs = classifier.predict(str(snr_path), return_probs=True)
        confidence = max(probs.values())
        print(f" {snr}: {pred_class} (confidence: {confidence:.3f})")
if __name__ == "__main__":
    # Runs every example in order.
    # Any failure prints a hint and exits with status 1, so shells and CI can
    # detect the failure instead of seeing a successful (status 0) exit.
    print("\n" + "=" * 60)
    print("LWM-Spectro Inference Examples")
    print("=" * 60)

    try:
        # Run examples
        example_single_inference()
        example_batch_inference()
        example_visualization()
        example_custom_preprocessing()
        example_error_analysis()
    except FileNotFoundError as e:
        print(f"\nError: {e}")
        print("\nNote: Update the file paths in the examples to match your directory structure.")
        sys.exit(1)
    except Exception as e:
        import traceback

        # Keep the traceback: the one-line message alone hides where it failed.
        traceback.print_exc()
        print(f"\nError: {e}")
        print("\nPlease ensure:")
        print(" 1. Model checkpoint exists at specified path")
        print(" 2. Spectrogram images are available")
        print(" 3. All dependencies are installed")
        sys.exit(1)
    else:
        print("\n" + "=" * 60)
        print("All examples completed successfully!")
        print("=" * 60)