wi-lab committed on
Commit
652afc6
·
verified ·
1 Parent(s): 063084c

Upload example_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example_inference.py +334 -0
example_inference.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example Inference Script for LWM-Spectro Model
3
+
4
+ This script demonstrates how to:
5
+ 1. Load the pre-trained MoE model
6
+ 2. Load and preprocess a spectrogram
7
+ 3. Perform inference
8
+ 4. Interpret results
9
+ """
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from PIL import Image
15
+ import matplotlib.pyplot as plt
16
+ from pathlib import Path
17
+ import sys
18
+
19
+ # Add project root to path
20
+ sys.path.append(str(Path(__file__).parent))
21
+
22
+ from pretraining.pretrained_model import PretrainedLWM
23
+
24
+
25
class SpectrogramClassifier:
    """Wrapper class for easy inference with LWM-Spectro model."""

    def __init__(self, model_path, device='cuda'):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the trained model checkpoint (.pth file)
            device: 'cuda' or 'cpu' (silently falls back to CPU when CUDA
                is unavailable)
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load model and switch to eval mode (disables dropout / BN updates)
        self.model = self._load_model(model_path)
        self.model.eval()

        # Class mapping: output index i corresponds to self.classes[i]
        self.classes = ['LTE', 'WiFi', '5G']

    def _load_model(self, model_path):
        """Load the trained model from checkpoint.

        Handles the common checkpoint layouts: a raw state dict, or a dict
        wrapping it under 'model_state_dict' or 'state_dict'.
        """
        checkpoint = torch.load(model_path, map_location=self.device)

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict):
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            elif 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint
        else:
            state_dict = checkpoint

        # Initialize model (adjust architecture as needed)
        model = PretrainedLWM()  # or your specific model class

        # strict=False tolerates architecture drift, but silently dropped keys
        # are a classic source of "model loads but predicts garbage" bugs, so
        # report any mismatch explicitly instead of swallowing it.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            print(f"Warning: missing keys in checkpoint: {result.missing_keys}")
        if result.unexpected_keys:
            print(f"Warning: unexpected keys in checkpoint: {result.unexpected_keys}")

        model.to(self.device)
        return model

    def load_spectrogram(self, image_path, target_size=(128, 128)):
        """
        Load and preprocess a spectrogram image.

        Args:
            image_path: Path to spectrogram image file
            target_size: Target size for resizing (height, width)

        Returns:
            Float tensor of shape [1, 1, H, W] with values in [0, 1],
            placed on self.device.
        """
        img = Image.open(image_path).convert('L')  # single-channel grayscale

        # PIL's resize takes (width, height), hence the swapped indices
        img = img.resize((target_size[1], target_size[0]), Image.BILINEAR)

        # Normalize 8-bit pixel values to [0, 1]
        img_array = np.array(img, dtype=np.float32) / 255.0

        # Add batch and channel dimensions -> [1, 1, H, W]
        tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)
        return tensor.to(self.device)

    def predict(self, spectrogram, return_probs=False):
        """
        Perform inference on a spectrogram.

        Args:
            spectrogram: Preprocessed spectrogram tensor or path to image file
            return_probs: If True, return class probabilities along with prediction

        Returns:
            If return_probs=False: predicted class name
            If return_probs=True: (predicted class name, probability dict)
        """
        # Accept either a ready tensor or a path-like input
        if isinstance(spectrogram, (str, Path)):
            spectrogram = self.load_spectrogram(spectrogram)

        with torch.no_grad():
            output = self.model(spectrogram)
            probabilities = F.softmax(output, dim=1)
            predicted_idx = torch.argmax(probabilities, dim=1).item()

        predicted_class = self.classes[predicted_idx]

        if return_probs:
            prob_dict = {
                cls: probabilities[0, i].item()
                for i, cls in enumerate(self.classes)
            }
            return predicted_class, prob_dict

        return predicted_class

    def predict_batch(self, spectrogram_paths):
        """
        Perform batch inference on multiple spectrograms.

        load_spectrogram resizes every image to a fixed shape, so all inputs
        are stacked into a single [N, 1, H, W] batch and classified in ONE
        forward pass instead of one model call per image.

        Args:
            spectrogram_paths: List of paths to spectrogram images (tensors
                that are already preprocessed are accepted too)

        Returns:
            List of predicted class names, in input order (empty list for
            empty input).
        """
        if not spectrogram_paths:
            return []

        # Each element becomes [1, 1, H, W]; concatenate along the batch dim
        tensors = [
            item if isinstance(item, torch.Tensor) else self.load_spectrogram(item)
            for item in spectrogram_paths
        ]
        batch = torch.cat(tensors, dim=0)

        with torch.no_grad():
            output = self.model(batch)
            # argmax over logits equals argmax over softmax probabilities
            predicted_idx = torch.argmax(output, dim=1)

        return [self.classes[i] for i in predicted_idx.tolist()]

    def visualize_prediction(self, image_path, save_path=None):
        """
        Visualize spectrogram with prediction.

        Args:
            image_path: Path to spectrogram image
            save_path: Optional path to save the visualization
        """
        # Load original image for display (unprocessed, full resolution)
        img = Image.open(image_path)

        # Get prediction with probabilities
        pred_class, probs = self.predict(image_path, return_probs=True)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Left panel: the spectrogram with the predicted class in the title
        ax1.imshow(img, cmap='viridis')
        ax1.set_title(f'Input Spectrogram\nPredicted: {pred_class}', fontsize=14, fontweight='bold')
        ax1.axis('off')

        # Right panel: horizontal bar chart of the class probabilities
        classes = list(probs.keys())
        probabilities = list(probs.values())
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

        bars = ax2.barh(classes, probabilities, color=colors)
        ax2.set_xlabel('Probability', fontsize=12)
        ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
        ax2.set_xlim(0, 1)

        # Annotate each bar with its numeric probability
        for bar, prob in zip(bars, probabilities):
            width = bar.get_width()
            ax2.text(width, bar.get_y() + bar.get_height() / 2,
                     f'{prob:.3f}', ha='left', va='center', fontsize=11)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to: {save_path}")

        plt.show()
191
+
192
+ # ============================================================================
193
+ # Example Usage
194
+ # ============================================================================
195
+
196
def example_single_inference():
    """Example: Single spectrogram inference"""
    banner = "=" * 60
    print(banner)
    print("Example 1: Single Spectrogram Inference")
    print(banner)

    # Build the classifier from the trained MoE checkpoint
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    # Plain prediction: just the class name
    image_path = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"
    print(f"\nPrediction: {classifier.predict(image_path)}")

    # Prediction plus the full probability distribution
    pred_class, probs = classifier.predict(image_path, return_probs=True)
    print(f"\nPredicted Class: {pred_class}")
    print("\nClass Probabilities:")
    for cls, prob in probs.items():
        print(f"  {cls}: {prob:.4f}")
217
+
218
+
219
def example_batch_inference():
    """Example: Batch inference on multiple spectrograms"""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 2: Batch Inference")
    print(banner)

    # Build the classifier from the trained MoE checkpoint
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    # One sample from each technology class
    image_paths = [
        "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/LTE/QAM16/rate1-2/SNR10dB/sample_0001.png",
        "spectrograms/WiFi/QAM64/rate3-4/sample_0001.png",
    ]

    # Classify all images at once
    predictions = classifier.predict_batch(image_paths)

    print("\nBatch Predictions:")
    for path, pred in zip(image_paths, predictions):
        print(f"  {Path(path).name}: {pred}")
242
+
243
+
244
def example_visualization():
    """Example: Visualize prediction with probabilities"""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 3: Prediction Visualization")
    print(banner)

    # Build the classifier from the trained MoE checkpoint
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    # Render the spectrogram + probability chart and save it to disk
    image_path = "spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png"
    classifier.visualize_prediction(image_path, save_path="prediction_result.png")
257
+
258
+
259
def example_custom_preprocessing():
    """Example: Custom preprocessing and inference"""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 4: Custom Preprocessing")
    print(banner)

    # Build the classifier from the trained MoE checkpoint
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    # Load the image manually instead of going through load_spectrogram
    img = Image.open("spectrograms/5G/QPSK/rate1-2/SNR10dB/sample_0001.png")
    img_array = np.array(img.convert('L'), dtype=np.float32) / 255.0

    # Example custom transformation: additive Gaussian noise, clipped to [0, 1]
    noise = np.random.normal(0, 0.01, img_array.shape)
    img_array_noisy = np.clip(img_array + noise, 0, 1)

    # Shape to [1, 1, H, W] and move to the classifier's device
    tensor = torch.from_numpy(img_array_noisy).unsqueeze(0).unsqueeze(0)
    tensor = tensor.to(classifier.device)

    # predict accepts a preprocessed tensor directly
    print(f"\nPrediction on noisy image: {classifier.predict(tensor)}")
284
+
285
+
286
def example_error_analysis():
    """Example: Analyze predictions across different SNR levels"""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 5: SNR-based Error Analysis")
    print(banner)

    # Build the classifier from the trained MoE checkpoint
    model_path = "mixture/runs/embedding_router/moe_checkpoint.pth"
    classifier = SpectrogramClassifier(model_path, device='cuda')

    # Sweep the same signal configuration across SNR levels
    base_path = Path("spectrograms/5G/QPSK/rate1-2")
    snr_levels = ['SNR-5dB', 'SNR0dB', 'SNR5dB', 'SNR10dB', 'SNR15dB', 'SNR20dB', 'SNR25dB']

    print("\nPredictions across SNR levels:")
    for snr in snr_levels:
        snr_path = base_path / snr / "sample_0001.png"
        if not snr_path.exists():
            continue  # skip SNR levels with no generated sample
        pred_class, probs = classifier.predict(str(snr_path), return_probs=True)
        # Confidence = probability of the winning class
        confidence = max(probs.values())
        print(f"  {snr}: {pred_class} (confidence: {confidence:.3f})")
307
+
308
+
309
if __name__ == "__main__":
    banner = "=" * 60
    print("\n" + banner)
    print("LWM-Spectro Inference Examples")
    print(banner)

    try:
        # Run every example in sequence
        for example in (
            example_single_inference,
            example_batch_inference,
            example_visualization,
            example_custom_preprocessing,
            example_error_analysis,
        ):
            example()

        print("\n" + banner)
        print("All examples completed successfully!")
        print(banner)

    except FileNotFoundError as e:
        # Missing checkpoint or spectrogram files — point at the likely fix
        print(f"\nError: {e}")
        print("\nNote: Update the file paths in the examples to match your directory structure.")
    except Exception as e:
        # Top-level boundary: report and give a setup checklist instead of a traceback
        print(f"\nError: {e}")
        print("\nPlease ensure:")
        print("  1. Model checkpoint exists at specified path")
        print("  2. Spectrogram images are available")
        print("  3. All dependencies are installed")