File size: 11,285 Bytes
652afc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 |
"""
Example Inference Script for LWM-Spectro Model
This script demonstrates how to:
1. Load the pre-trained MoE model
2. Load and preprocess a spectrogram
3. Perform inference
4. Interpret results
"""
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
import sys
# Add project root to path
sys.path.append(str(Path(__file__).parent))
from pretraining.pretrained_model import PretrainedLWM
class SpectrogramClassifier:
    """Wrapper class for easy inference with the LWM-Spectro model.

    Loads a checkpoint into a ``PretrainedLWM`` model, turns spectrogram
    images into ``[1, 1, H, W]`` float tensors scaled to ``[0, 1]``, and maps
    the index of the largest softmax score to a class name.
    """

    # Default label order: entry i names output column i of the model.
    DEFAULT_CLASSES = ('LTE', 'WiFi', '5G')

    def __init__(self, model_path, device='cuda', classes=None):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the trained model checkpoint (.pth file).
            device: Requested torch device ('cuda', 'cuda:1', 'cpu', ...).
                A CUDA request falls back to 'cpu' when CUDA is unavailable;
                any other device is used as given.
            classes: Optional sequence of class names ordered like the
                model's output columns. Defaults to ``DEFAULT_CLASSES``.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            device = 'cpu'
        self.device = torch.device(device)
        print(f"Using device: {self.device}")
        # Load model and switch it to eval (inference) mode
        self.model = self._load_model(model_path)
        self.model.eval()
        # Class mapping
        self.classes = list(self.DEFAULT_CLASSES if classes is None else classes)

    def _load_model(self, model_path):
        """Build a ``PretrainedLWM`` and load the checkpoint weights into it.

        Accepts either a raw state dict or a dict wrapping it under
        ``'model_state_dict'`` (preferred) or ``'state_dict'``. A ``'module.'``
        prefix on every key, as written by ``nn.DataParallel``, is stripped.
        Loading is non-strict, so missing/unexpected keys are reported rather
        than ignored silently.
        """
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=self.device)
        # Handle different checkpoint formats
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ('model_state_dict', 'state_dict'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break
        # Without stripping the DataParallel prefix, a non-strict load would
        # match no keys at all and leave the model randomly initialized.
        prefix = 'module.'
        if (isinstance(state_dict, dict) and state_dict
                and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict)):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        # Initialize model (adjust architecture as needed)
        model = PretrainedLWM()  # or your specific model class
        # strict=False tolerates architecture differences, but parameters that
        # are not in the checkpoint keep their initial values -- say so.
        result = model.load_state_dict(state_dict, strict=False)
        missing = list(getattr(result, 'missing_keys', []) or [])
        unexpected = list(getattr(result, 'unexpected_keys', []) or [])
        if missing:
            print(f"Warning: {len(missing)} model parameter(s) not found in checkpoint "
                  f"(left at initial values), e.g. {missing[:5]}")
        if unexpected:
            print(f"Warning: {len(unexpected)} checkpoint entr(y/ies) not used by the model, "
                  f"e.g. {unexpected[:5]}")
        model.to(self.device)
        return model

    def load_spectrogram(self, image_path, target_size=(128, 128)):
        """
        Load and preprocess a spectrogram image.

        Args:
            image_path: Path to spectrogram image file.
            target_size: Target size for resizing (height, width).

        Returns:
            Float32 tensor of shape [1, 1, height, width] with values in
            [0, 1], on ``self.device``.
        """
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(image_path) as img:
            # Grayscale, then resize; PIL's resize expects (width, height).
            gray = img.convert('L').resize((target_size[1], target_size[0]), Image.BILINEAR)
        # Convert to numpy array and normalize 8-bit values to [0, 1]
        img_array = np.array(gray, dtype=np.float32) / 255.0
        # Add batch and channel dimensions -> [1, 1, H, W]
        tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)
        return tensor.to(self.device)

    def predict(self, spectrogram, return_probs=False):
        """
        Perform inference on a single spectrogram.

        Args:
            spectrogram: Path to an image file (str or Path), or an already
                preprocessed tensor/array, typically [1, 1, H, W] as produced
                by ``load_spectrogram``. Tensors/arrays are moved to
                ``self.device`` and cast to float32.
            return_probs: If True, also return the per-class probabilities.

        Returns:
            If return_probs=False: predicted class name.
            If return_probs=True: (predicted class name, {class name: probability}).
        """
        if isinstance(spectrogram, (str, Path)):
            spectrogram = self.load_spectrogram(spectrogram)
        else:
            # Caller-built inputs may live on the CPU or be float64 (NumPy's
            # default); match the model's device and (default) float32 dtype.
            spectrogram = torch.as_tensor(spectrogram, dtype=torch.float32, device=self.device)
        # NOTE(review): assumes the model returns logits shaped
        # [1, len(self.classes)]; with a batch > 1, .item() below raises.
        with torch.no_grad():
            output = self.model(spectrogram)
            probabilities = F.softmax(output, dim=1)
            predicted_idx = torch.argmax(probabilities, dim=1).item()
        predicted_class = self.classes[predicted_idx]
        if return_probs:
            prob_dict = {
                cls: probabilities[0, i].item()
                for i, cls in enumerate(self.classes)
            }
            return predicted_class, prob_dict
        return predicted_class

    def predict_batch(self, spectrogram_paths, return_probs=False):
        """
        Classify several spectrograms, one ``predict`` call per input.

        Args:
            spectrogram_paths: Iterable of image paths (or preprocessed tensors).
            return_probs: Forwarded to ``predict`` for every input.

        Returns:
            List of ``predict`` results, in input order.
        """
        return [self.predict(path, return_probs=return_probs) for path in spectrogram_paths]

    def visualize_prediction(self, image_path, save_path=None):
        """
        Show the spectrogram next to a bar chart of the class probabilities.

        Args:
            image_path: Path to spectrogram image.
            save_path: Optional path to save the visualization.

        Returns:
            The matplotlib Figure that was drawn.
        """
        # Load original image for display; copy() forces decoding so the file
        # handle can be closed right away.
        with Image.open(image_path) as src:
            img = src.copy()
        # Get prediction with probabilities
        pred_class, probs = self.predict(image_path, return_probs=True)
        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        # Display spectrogram
        ax1.imshow(img, cmap='viridis')
        ax1.set_title(f'Input Spectrogram\nPredicted: {pred_class}', fontsize=14, fontweight='bold')
        ax1.axis('off')
        # Display probability distribution
        classes = list(probs.keys())
        probabilities = list(probs.values())
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
        bars = ax2.barh(classes, probabilities, color=colors)
        ax2.set_xlabel('Probability', fontsize=12)
        ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
        ax2.set_xlim(0, 1)
        # Add probability values at the end of each bar
        for bar, prob in zip(bars, probabilities):
            width = bar.get_width()
            ax2.text(width, bar.get_y() + bar.get_height() / 2,
                     f'{prob:.3f}', ha='left', va='center', fontsize=11)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to: {save_path}")
        plt.show()
        return fig
# ============================================================================
# Example Usage
# ============================================================================
def example_single_inference():
    """Classify one spectrogram twice: once for the label alone, once with probabilities."""
    banner = "=" * 60
    print(banner)
    print("Example 1: Single Spectrogram Inference")
    print(banner)

    checkpoint = "mixture/runs/embedding_router/moe_checkpoint.pth"
    clf = SpectrogramClassifier(checkpoint, device='cuda')

    sample = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"

    # Label-only call.
    label = clf.predict(sample)
    print(f"\nPrediction: {label}")

    # Same sample again, this time also requesting per-class probabilities.
    label, class_probs = clf.predict(sample, return_probs=True)
    print(f"\nPredicted Class: {label}")
    print("\nClass Probabilities:")
    for name, p in class_probs.items():
        print(f" {name}: {p:.4f}")
def example_batch_inference():
    """Classify a fixed list of spectrogram files and print one label per file."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 2: Batch Inference")
    print(banner)

    clf = SpectrogramClassifier(
        "mixture/runs/embedding_router/moe_checkpoint.pth", device='cuda'
    )

    samples = [
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/LTE/QAM16/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/WiFi/QAM64/rate3-4/sample_0001.png",
    ]
    labels = clf.predict_batch(samples)

    print("\nBatch Predictions:")
    for sample, label in zip(samples, labels):
        print(f" {Path(sample).name}: {label}")
def example_visualization():
    """Plot one spectrogram beside its class-probability chart and save the figure."""
    separator = "=" * 60
    print("\n" + separator)
    print("Example 3: Prediction Visualization")
    print(separator)

    clf = SpectrogramClassifier(
        "mixture/runs/embedding_router/moe_checkpoint.pth", device='cuda'
    )
    clf.visualize_prediction(
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png",
        save_path="prediction_result.png",
    )
def example_custom_preprocessing():
    """Example: Custom preprocessing and inference.

    Starts from the classifier's standard preprocessing (grayscale, resized,
    scaled to [0, 1]), perturbs the pixels with small Gaussian noise in NumPy,
    and classifies the result by passing the tensor directly to ``predict``.
    """
    print("\n" + "=" * 60)
    print("Example 4: Custom Preprocessing")
    print("=" * 60)
    # Initialize classifier
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')
    # Reuse the classifier's loader so the input gets the same resize and
    # scaling as path-based predictions (the raw PNG may not be 128x128).
    image_path = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"
    img_array = classifier.load_spectrogram(image_path).cpu().numpy()  # [1, 1, H, W], float32
    # Apply custom transformations (example: add noise). np.random.normal
    # yields float64, which would promote the whole array to float64 and
    # clash with the model's (float32 by default) weights -- cast it back.
    noise = np.random.normal(0, 0.01, img_array.shape).astype(np.float32)
    img_array_noisy = np.clip(img_array + noise, 0, 1).astype(np.float32, copy=False)
    # Already shaped [1, 1, H, W]; just move it to the model's device.
    tensor = torch.from_numpy(img_array_noisy).to(classifier.device)
    # Predict
    prediction = classifier.predict(tensor)
    print(f"\nPrediction on noisy image: {prediction}")
def example_error_analysis():
    """Example: Analyze predictions across different SNR levels.

    Classifies ``sample_0001.png`` of the 5G / QPSK / rate1-2 set at each SNR
    level and prints the predicted class with its probability. SNR levels
    whose sample file does not exist are listed at the end instead of being
    skipped silently, so an empty result is not mistaken for a clean run.
    """
    print("\n" + "=" * 60)
    print("Example 5: SNR-based Error Analysis")
    print("=" * 60)
    # Initialize classifier
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')
    # Test across different SNR levels
    snr_levels = ['SNR-5dB', 'SNR0dB', 'SNR5dB', 'SNR10dB', 'SNR15dB', 'SNR20dB', 'SNR25dB']
    base_path = Path("spectrograms/5G/QPSK/rate1-2")
    print("\nPredictions across SNR levels:")
    missing = []
    for snr in snr_levels:
        snr_path = base_path / snr / "sample_0001.png"
        if not snr_path.exists():
            missing.append(snr)
            continue
        pred_class, probs = classifier.predict(str(snr_path), return_probs=True)
        # The predicted class is the argmax, so its probability is the maximum.
        confidence = max(probs.values())
        print(f" {snr}: {pred_class} (confidence: {confidence:.3f})")
    if missing:
        print(f" Skipped (sample_0001.png not found under {base_path}): {', '.join(missing)}")
if __name__ == "__main__":
    import traceback

    print("\n" + "=" * 60)
    print("LWM-Spectro Inference Examples")
    print("=" * 60)
    try:
        # Run examples in order; the first failure stops the run.
        example_single_inference()
        example_batch_inference()
        example_visualization()
        example_custom_preprocessing()
        example_error_analysis()
        print("\n" + "=" * 60)
        print("All examples completed successfully!")
        print("=" * 60)
    except FileNotFoundError as e:
        print(f"\nError: {e}")
        print("\nNote: Update the file paths in the examples to match your directory structure.")
        # Non-zero exit so scripts/CI can tell the run failed.
        sys.exit(1)
    except Exception as e:
        print(f"\nError: {e}")
        # The message alone rarely shows which example or layer failed.
        traceback.print_exc()
        print("\nPlease ensure:")
        print(" 1. Model checkpoint exists at specified path")
        print(" 2. Spectrogram images are available")
        print(" 3. All dependencies are installed")
        sys.exit(1)
|