"""
Example Inference Script for LWM-Spectro Model

This script demonstrates how to:
1. Load the pre-trained MoE model
2. Load and preprocess a spectrogram
3. Perform inference
4. Interpret results
"""

import sys
import traceback
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Add project root to path so the local ``pretraining`` package is importable
sys.path.append(str(Path(__file__).parent))

from pretraining.pretrained_model import PretrainedLWM  # noqa: E402

# Default label order: index i is mapped to output logit i of the model.
DEFAULT_CLASSES = ('LTE', 'WiFi', '5G')

# Checkpoint used by the examples below.
DEFAULT_MODEL_PATH = "mixture/runs/embedding_router/moe_checkpoint.pth"


def _resolve_device(device):
    """Return a torch.device for *device*, falling back to CPU only when a
    CUDA device was requested but CUDA is not available."""
    requested = torch.device(device)
    if requested.type == 'cuda' and not torch.cuda.is_available():
        return torch.device('cpu')
    return requested


class SpectrogramClassifier:
    """Wrapper class for easy inference with LWM-Spectro model"""

    def __init__(self, model_path, device='cuda', classes=None):
        """
        Initialize the classifier

        Args:
            model_path: Path to the trained model checkpoint (.pth file)
            device: Requested device, e.g. 'cuda', 'cuda:1' or 'cpu'. A CUDA
                request falls back to CPU when CUDA is unavailable.
            classes: Optional sequence of class names in model-output order.
                Defaults to ('LTE', 'WiFi', '5G').
        """
        self.device = _resolve_device(device)
        print(f"Using device: {self.device}")

        # Class mapping: index i <-> output logit i
        self.classes = list(DEFAULT_CLASSES if classes is None else classes)

        # Load model
        self.model = self._load_model(model_path)
        self.model.eval()

    def _load_model(self, model_path):
        """
        Load the trained model from checkpoint

        Accepts a raw state dict, a dict wrapping one under
        'model_state_dict' or 'state_dict', or a pickled nn.Module.

        Note: torch.load unpickles the file -- only load trusted checkpoints.
        """
        checkpoint = torch.load(model_path, map_location=self.device)

        # Handle different checkpoint formats
        if isinstance(checkpoint, torch.nn.Module):
            state_dict = checkpoint.state_dict()
        elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Checkpoints saved from nn.DataParallel / DDP prefix every key with
        # 'module.'; strip it so the keys match the bare model.
        prefix = 'module.'
        if (isinstance(state_dict, dict) and state_dict
                and all(isinstance(k, str) and k.startswith(prefix)
                        for k in state_dict)):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

        # Initialize model (adjust architecture as needed)
        model = PretrainedLWM()  # or your specific model class

        # strict=False tolerates architecture drift, but any missing weight
        # silently keeps its random init -- report mismatches explicitly.
        result = model.load_state_dict(state_dict, strict=False)
        missing = list(getattr(result, 'missing_keys', ()))
        unexpected = list(getattr(result, 'unexpected_keys', ()))
        if missing or unexpected:
            print(f"Warning: checkpoint/model mismatch - "
                  f"{len(missing)} missing key(s) {missing[:5]}, "
                  f"{len(unexpected)} unexpected key(s) {unexpected[:5]}")

        model.to(self.device)
        return model

    def load_spectrogram(self, image_path, target_size=(128, 128)):
        """
        Load and preprocess a spectrogram image

        Args:
            image_path: Path to spectrogram image file
            target_size: Target size for resizing (height, width)

        Returns:
            float32 tensor of shape [1, 1, H, W] with values in [0, 1],
            on self.device.
        """
        # Open inside a context manager so the file handle is released;
        # convert() loads the pixels, so the copy outlives the file.
        with Image.open(image_path) as src:
            gray = src.convert('L')  # Convert to grayscale

        # PIL takes (width, height); target_size is (height, width)
        gray = gray.resize((target_size[1], target_size[0]), Image.BILINEAR)

        # Convert to numpy array and normalize 8-bit pixels to [0, 1]
        img_array = np.array(gray, dtype=np.float32) / 255.0

        # Convert to tensor [1, 1, H, W]
        tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)
        return tensor.to(self.device)

    def predict(self, spectrogram, return_probs=False):
        """
        Perform inference on a single spectrogram

        Args:
            spectrogram: Path to an image file, a preprocessed [1, 1, H, W]
                tensor, or a bare [H, W] tensor
            return_probs: If True, return class probabilities along with prediction

        Returns:
            If return_probs=False: predicted class name
            If return_probs=True: (predicted class name, probability dict)

        Raises:
            ValueError: if the model output width does not match self.classes
        """
        # Load spectrogram if path is provided
        if isinstance(spectrogram, (str, Path)):
            spectrogram = self.load_spectrogram(spectrogram)
        else:
            if spectrogram.dim() == 2:
                # Bare [H, W] image -> [1, 1, H, W]
                spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)
            # Caller-built tensors may live on another device than the model
            spectrogram = spectrogram.to(self.device)

        # Inference
        with torch.no_grad():
            # NOTE(review): assumes the model returns class logits shaped
            # [batch, num_classes] -- confirm for PretrainedLWM.
            output = self.model(spectrogram)
            if output.dim() != 2 or output.shape[1] != len(self.classes):
                raise ValueError(
                    f"Model output shape {tuple(output.shape)} does not match "
                    f"{len(self.classes)} classes {self.classes}")
            probabilities = F.softmax(output, dim=1)

        # .item() requires a single sample in the batch
        predicted_idx = torch.argmax(probabilities, dim=1).item()
        predicted_class = self.classes[predicted_idx]

        if return_probs:
            prob_dict = dict(zip(self.classes, probabilities[0].tolist()))
            return predicted_class, prob_dict

        return predicted_class

    def predict_batch(self, spectrogram_paths):
        """
        Perform inference on multiple spectrograms, one at a time

        Args:
            spectrogram_paths: List of paths to spectrogram images

        Returns:
            List of predicted class names, in input order
        """
        return [self.predict(path) for path in spectrogram_paths]

    def visualize_prediction(self, image_path, save_path=None):
        """
        Visualize spectrogram with prediction

        Args:
            image_path: Path to spectrogram image
            save_path: Optional path to save the visualization
        """
        # Load original image for display; copy() detaches it from the file
        with Image.open(image_path) as src:
            img = src.copy()

        # Get prediction with probabilities
        pred_class, probs = self.predict(image_path, return_probs=True)

        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Display spectrogram (cmap only affects single-channel images)
        ax1.imshow(img, cmap='viridis')
        ax1.set_title(f'Input Spectrogram\nPredicted: {pred_class}',
                      fontsize=14, fontweight='bold')
        ax1.axis('off')

        # Display probability distribution
        classes = list(probs.keys())
        probabilities = list(probs.values())
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

        bars = ax2.barh(classes, probabilities, color=colors)
        ax2.set_xlabel('Probability', fontsize=12)
        ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
        ax2.set_xlim(0, 1)

        # Add probability values on bars
        for bar, prob in zip(bars, probabilities):
            width = bar.get_width()
            ax2.text(width, bar.get_y() + bar.get_height() / 2,
                     f'{prob:.3f}', ha='left', va='center', fontsize=11)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to: {save_path}")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)


# ============================================================================
# Example Usage
# ============================================================================

def _get_classifier(classifier=None):
    """Return *classifier*, or build one from DEFAULT_MODEL_PATH when None."""
    if classifier is None:
        classifier = SpectrogramClassifier(DEFAULT_MODEL_PATH, device='cuda')
    return classifier


def example_single_inference(classifier=None):
    """Example: Single spectrogram inference (optionally reusing *classifier*)"""
    print("=" * 60)
    print("Example 1: Single Spectrogram Inference")
    print("=" * 60)

    # Initialize classifier
    classifier = _get_classifier(classifier)

    # Single inference
    image_path = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"
    prediction = classifier.predict(image_path)
    print(f"\nPrediction: {prediction}")

    # With probabilities
    pred_class, probs = classifier.predict(image_path, return_probs=True)
    print(f"\nPredicted Class: {pred_class}")
    print("\nClass Probabilities:")
    for cls, prob in probs.items():
        print(f"  {cls}: {prob:.4f}")


def example_batch_inference(classifier=None):
    """Example: Batch inference on multiple spectrograms"""
    print("\n" + "=" * 60)
    print("Example 2: Batch Inference")
    print("=" * 60)

    # Initialize classifier
    classifier = _get_classifier(classifier)

    # Multiple images
    image_paths = [
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/LTE/QAM16/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/WiFi/QAM64/rate3-4/sample_0001.png",
    ]

    # Batch prediction
    predictions = classifier.predict_batch(image_paths)

    print("\nBatch Predictions:")
    for path, pred in zip(image_paths, predictions):
        print(f"  {Path(path).name}: {pred}")


def example_visualization(classifier=None):
    """Example: Visualize prediction with probabilities"""
    print("\n" + "=" * 60)
    print("Example 3: Prediction Visualization")
    print("=" * 60)

    # Initialize classifier
    classifier = _get_classifier(classifier)

    # Visualize prediction
    image_path = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"
    classifier.visualize_prediction(image_path, save_path="prediction_result.png")


def example_custom_preprocessing(classifier=None):
    """Example: Custom preprocessing and inference"""
    print("\n" + "=" * 60)
    print("Example 4: Custom Preprocessing")
    print("=" * 60)

    # Initialize classifier
    classifier = _get_classifier(classifier)

    # Start from the standard preprocessing (grayscale, resize, [0, 1]) so the
    # custom transform is applied to the same input layout the model expects.
    tensor = classifier.load_spectrogram(
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png")

    # Apply custom transformations (example: add noise). randn_like keeps the
    # tensor's float32 dtype and device, matching the model's weights.
    noise = torch.randn_like(tensor) * 0.01
    tensor = torch.clamp(tensor + noise, 0.0, 1.0)

    # Predict
    prediction = classifier.predict(tensor)
    print(f"\nPrediction on noisy image: {prediction}")


def example_error_analysis(classifier=None):
    """Example: Analyze predictions across different SNR levels"""
    print("\n" + "=" * 60)
    print("Example 5: SNR-based Error Analysis")
    print("=" * 60)

    # Initialize classifier
    classifier = _get_classifier(classifier)

    # Test across different SNR levels
    snr_levels = ['SNR-5dB', 'SNR0dB', 'SNR5dB', 'SNR10dB',
                  'SNR15dB', 'SNR20dB', 'SNR25dB']
    base_path = Path("spectrograms/5G/QPSK/rate1-2")

    print("\nPredictions across SNR levels:")
    for snr in snr_levels:
        snr_path = base_path / snr / "sample_0001.png"
        if not snr_path.exists():
            print(f"  {snr}: skipped (file not found: {snr_path})")
            continue
        pred_class, probs = classifier.predict(str(snr_path), return_probs=True)
        confidence = max(probs.values())
        print(f"  {snr}: {pred_class} (confidence: {confidence:.3f})")


if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("LWM-Spectro Inference Examples")
    print("=" * 60)

    try:
        # Load the checkpoint once and share it across all examples
        shared_classifier = SpectrogramClassifier(DEFAULT_MODEL_PATH, device='cuda')

        # Run examples
        example_single_inference(shared_classifier)
        example_batch_inference(shared_classifier)
        example_visualization(shared_classifier)
        example_custom_preprocessing(shared_classifier)
        example_error_analysis(shared_classifier)

        print("\n" + "=" * 60)
        print("All examples completed successfully!")
        print("=" * 60)

    except FileNotFoundError as e:
        print(f"\nError: {e}")
        print("\nNote: Update the file paths in the examples to match your directory structure.")
        sys.exit(1)
    except Exception as e:
        print(f"\nError: {e}")
        traceback.print_exc()
        print("\nPlease ensure:")
        print("  1. Model checkpoint exists at specified path")
        print("  2. Spectrogram images are available")
        print("  3. All dependencies are installed")
        sys.exit(1)