|
|
""" |
|
|
Example Inference Script for LWM-Spectro Model |
|
|
|
|
|
This script demonstrates how to: |
|
|
1. Load the pre-trained MoE model |
|
|
2. Load and preprocess a spectrogram |
|
|
3. Perform inference |
|
|
4. Interpret results |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
from pathlib import Path |
|
|
import sys |
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).parent)) |
|
|
|
|
|
from pretraining.pretrained_model import PretrainedLWM |
|
|
|
|
|
|
|
|
class SpectrogramClassifier:
    """Wrapper class for easy inference with the LWM-Spectro MoE model.

    Handles checkpoint loading, spectrogram preprocessing, single and
    batch prediction, and visualization of class probabilities.
    """

    def __init__(self, model_path, device='cuda'):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the trained model checkpoint (.pth file)
            device: 'cuda' or 'cpu' (silently falls back to CPU when CUDA
                is unavailable)
        """
        # Fall back to CPU transparently so the same script runs everywhere.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        self.model = self._load_model(model_path)
        self.model.eval()  # inference mode: disable dropout / BN updates

        # Index order must match the label encoding used at training time.
        self.classes = ['LTE', 'WiFi', '5G']

    def _load_model(self, model_path):
        """Load the trained model from a checkpoint file.

        Accepts either a raw state dict or a training checkpoint dict that
        wraps it under 'model_state_dict' or 'state_dict'.

        Args:
            model_path: Path to the checkpoint (.pth file)

        Returns:
            A PretrainedLWM instance on self.device with weights loaded.
        """
        # NOTE(security): torch.load unpickles arbitrary objects -- only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Unwrap the common checkpoint layouts down to a plain state dict.
        if isinstance(checkpoint, dict):
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            elif 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint
        else:
            state_dict = checkpoint

        model = PretrainedLWM()

        # strict=False tolerates architecture drift, but silently ignoring
        # mismatches can mask loading the wrong checkpoint -- surface them.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            print(f"Warning: missing keys when loading checkpoint: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")

        model.to(self.device)

        return model

    def load_spectrogram(self, image_path, target_size=(128, 128)):
        """
        Load and preprocess a spectrogram image.

        Args:
            image_path: Path to spectrogram image file
            target_size: Target size for resizing (height, width)

        Returns:
            Float tensor of shape (1, 1, H, W) in [0, 1] on self.device,
            ready for inference.
        """
        # Grayscale: the model expects a single input channel.
        img = Image.open(image_path).convert('L')

        # PIL expects (width, height), hence the swapped tuple.
        img = img.resize((target_size[1], target_size[0]), Image.BILINEAR)

        # Scale 8-bit pixel values to [0, 1].
        img_array = np.array(img, dtype=np.float32) / 255.0

        # Add batch and channel dimensions: (H, W) -> (1, 1, H, W).
        tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)

        return tensor.to(self.device)

    def predict(self, spectrogram, return_probs=False):
        """
        Perform inference on a spectrogram.

        Args:
            spectrogram: Preprocessed spectrogram tensor or path to image file
            return_probs: If True, return class probabilities along with
                the prediction

        Returns:
            If return_probs=False: predicted class name
            If return_probs=True: (predicted class name, probability dict)
        """
        # Accept a file path and preprocess it on the fly.
        if isinstance(spectrogram, (str, Path)):
            spectrogram = self.load_spectrogram(spectrogram)

        with torch.no_grad():
            output = self.model(spectrogram)
            probabilities = F.softmax(output, dim=1)
            predicted_idx = torch.argmax(probabilities, dim=1).item()

        predicted_class = self.classes[predicted_idx]

        if return_probs:
            prob_dict = {
                cls: probabilities[0, i].item()
                for i, cls in enumerate(self.classes)
            }
            return predicted_class, prob_dict

        return predicted_class

    def predict_batch(self, spectrogram_paths):
        """
        Perform inference on multiple spectrograms, one at a time.

        Args:
            spectrogram_paths: List of paths to spectrogram images

        Returns:
            List of predicted class names, in input order.
        """
        predictions = []
        for path in spectrogram_paths:
            pred = self.predict(path)
            predictions.append(pred)

        return predictions

    def visualize_prediction(self, image_path, save_path=None):
        """
        Visualize a spectrogram side-by-side with its class probabilities.

        Args:
            image_path: Path to spectrogram image
            save_path: Optional path to save the visualization (PNG, 300 dpi)
        """
        # Show the original (unresized) image; the model sees the
        # preprocessed version via predict().
        img = Image.open(image_path)

        pred_class, probs = self.predict(image_path, return_probs=True)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Left panel: input spectrogram with the predicted label.
        ax1.imshow(img, cmap='viridis')
        ax1.set_title(f'Input Spectrogram\nPredicted: {pred_class}', fontsize=14, fontweight='bold')
        ax1.axis('off')

        # Right panel: horizontal bar chart of per-class probabilities.
        classes = list(probs.keys())
        probabilities = list(probs.values())
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

        bars = ax2.barh(classes, probabilities, color=colors)
        ax2.set_xlabel('Probability', fontsize=12)
        ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
        ax2.set_xlim(0, 1)

        # Annotate each bar with its numeric probability.
        for bar, prob in zip(bars, probabilities):
            width = bar.get_width()
            ax2.text(width, bar.get_y() + bar.get_height()/2,
                     f'{prob:.3f}', ha='left', va='center', fontsize=11)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to: {save_path}")

        plt.show()
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def example_single_inference():
    """Demonstrate inference on one spectrogram, with and without probabilities."""
    banner = "=" * 60
    print(banner)
    print("Example 1: Single Spectrogram Inference")
    print(banner)

    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    image_path = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"

    # Class label only.
    prediction = classifier.predict(image_path)
    print(f"\nPrediction: {prediction}")

    # Label together with the full per-class probability breakdown.
    pred_class, probs = classifier.predict(image_path, return_probs=True)
    print(f"\nPredicted Class: {pred_class}")
    print("\nClass Probabilities:")
    for cls, prob in probs.items():
        print(f" {cls}: {prob:.4f}")
|
|
|
|
|
|
|
|
def example_batch_inference():
    """Demonstrate classifying several spectrogram files in one call."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 2: Batch Inference")
    print(banner)

    classifier = SpectrogramClassifier(
        "mixture/runs/embedding_router/moe_checkpoint.pth", device='cuda'
    )

    # One sample per signal type.
    image_paths = [
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/LTE/QAM16/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/WiFi/QAM64/rate3-4/sample_0001.png",
    ]

    results = classifier.predict_batch(image_paths)

    print("\nBatch Predictions:")
    for path, pred in zip(image_paths, results):
        print(f" {Path(path).name}: {pred}")
|
|
|
|
|
|
|
|
def example_visualization():
    """Demonstrate the side-by-side prediction/probability visualization."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 3: Prediction Visualization")
    print(banner)

    classifier = SpectrogramClassifier(
        "mixture/runs/embedding_router/moe_checkpoint.pth", device='cuda'
    )

    # Render the spectrogram and probability chart; also save to disk.
    classifier.visualize_prediction(
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png",
        save_path="prediction_result.png",
    )
|
|
|
|
|
|
|
|
def example_custom_preprocessing():
    """Demonstrate hand-rolled preprocessing (added noise) before inference."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 4: Custom Preprocessing")
    print(banner)

    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    # Grayscale pixels scaled into [0, 1].
    source = Image.open("spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png")
    pixels = np.array(source.convert('L'), dtype=np.float32) / 255.0

    # Perturb with mild Gaussian noise, clipping back into range.
    noisy = np.clip(pixels + np.random.normal(0, 0.01, pixels.shape), 0, 1)

    # Shape (1, 1, H, W): batch and channel dims expected by the model.
    batch = torch.from_numpy(noisy).unsqueeze(0).unsqueeze(0)
    batch = batch.to(classifier.device)

    print(f"\nPrediction on noisy image: {classifier.predict(batch)}")
|
|
|
|
|
|
|
|
def example_error_analysis():
    """Demonstrate how predictions vary for the same signal across SNR levels."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 5: SNR-based Error Analysis")
    print(banner)

    classifier = SpectrogramClassifier(
        "mixture/runs/embedding_router/moe_checkpoint.pth", device='cuda'
    )

    base_path = Path("spectrograms/5G/QPSK/rate1-2")
    snr_levels = ['SNR-5dB', 'SNR0dB', 'SNR5dB', 'SNR10dB', 'SNR15dB', 'SNR20dB', 'SNR25dB']

    print("\nPredictions across SNR levels:")
    for snr in snr_levels:
        sample = base_path / snr / "sample_0001.png"
        if not sample.exists():
            continue  # skip SNR levels without a generated sample
        pred_class, probs = classifier.predict(str(sample), return_probs=True)
        confidence = max(probs.values())
        print(f" {snr}: {pred_class} (confidence: {confidence:.3f})")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    banner = "=" * 60
    print("\n" + banner)
    print("LWM-Spectro Inference Examples")
    print(banner)

    try:
        # Run every example in order; any failure aborts the remainder.
        for demo in (
            example_single_inference,
            example_batch_inference,
            example_visualization,
            example_custom_preprocessing,
            example_error_analysis,
        ):
            demo()

        print("\n" + banner)
        print("All examples completed successfully!")
        print(banner)

    except FileNotFoundError as e:
        # Expected when the demo assets aren't present locally.
        print(f"\nError: {e}")
        print("\nNote: Update the file paths in the examples to match your directory structure.")
    except Exception as e:
        # Broad catch is intentional: this is a top-level demo boundary.
        print(f"\nError: {e}")
        print("\nPlease ensure:")
        print(" 1. Model checkpoint exists at specified path")
        print(" 2. Spectrogram images are available")
        print(" 3. All dependencies are installed")
|
|
|