wi-lab committed on
Commit
6e07ee1
·
verified ·
1 Parent(s): 9602105

Upload mixture/run_moe_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mixture/run_moe_inference.py +296 -0
mixture/run_moe_inference.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Run inference using trained Task 1 (Modulation) and Task 2 (SNR/Mobility) MoE models.
3
+
4
+ This script loads two separate MoE checkpoints and performs predictions on input spectrograms:
5
+ - Task 1 MoE: Predicts modulation scheme (QPSK, 16QAM, 64QAM, etc.)
6
+ - Task 2 MoE: Predicts joint SNR and mobility class
7
+
8
+ Usage:
9
+ python -m mixture.run_moe_inference \\
10
+ --task1-checkpoint mixture/runs/task1_moe/moe_checkpoint.pth \\
11
+ --task2-checkpoint mixture/runs/task2_moe/moe_checkpoint.pth \\
12
+ --input spectrograms/city_1_losangeles/LTE/snr_0/pedestrian/QPSK/fft_512_overlap_256/specs_0000.pkl \\
13
+ --index 0
14
+
15
+ Or run on a batch of samples:
16
+ python -m mixture.run_moe_inference \\
17
+ --task1-checkpoint mixture/runs/task1_moe/moe_checkpoint.pth \\
18
+ --task2-checkpoint mixture/runs/task2_moe/moe_checkpoint.pth \\
19
+ --input spectrograms/city_1_losangeles/LTE/snr_0/pedestrian/QPSK/fft_512_overlap_256/specs_0000.pkl \\
20
+ --batch-size 32
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import json
27
+ import sys
28
+ from pathlib import Path
29
+ from typing import Optional
30
+
31
+ import numpy as np
32
+ import torch
33
+
34
+ REPO_ROOT = Path(__file__).resolve().parent.parent
35
+ sys.path.append(str(REPO_ROOT))
36
+
37
+ from mixture.train_embedding_router import MoEPredictor, load_all_samples # type: ignore
38
+
39
+
40
def load_spectrogram_sample(file_path: Path, index: Optional[int] = None) -> torch.Tensor:
    """Read spectrograms from a pickle file and return them as a float tensor.

    Args:
        file_path: Location of the pickle file holding the spectrograms.
        index: Position of the single sample to return. When ``None``,
            every sample in the file is returned.

    Returns:
        A float tensor with either the selected sample (expected shape
        [H, W]) or all samples (expected shape [N, H, W]).

    Raises:
        IndexError: If ``index`` is negative or not smaller than the number
            of samples stored in the file.
    """
    all_specs = load_all_samples(str(file_path))

    # No index requested: hand back the whole file as one tensor.
    if index is None:
        return torch.from_numpy(all_specs).float()

    sample_count = all_specs.shape[0]
    # Negative indices are rejected on purpose rather than wrapping around.
    if not 0 <= index < sample_count:
        raise IndexError(f"Index {index} out of range for file with {sample_count} samples")

    return torch.from_numpy(all_specs[index]).float()
58
+
59
+
60
def format_prediction_output(result: dict, task_name: str) -> str:
    """Render one prediction result as a multi-line block for the console.

    Three result layouts are recognised, checked in this order:

    * a ``"label"`` key: single prediction, shown with its confidence;
    * a ``"labels"`` key: batch prediction, shown as the batch size and the
      first five labels (followed by ``...`` when there are more);
    * otherwise: the raw ``"predicted_class"`` and its confidence.

    When ``result["routing"]`` is a non-empty list, the expert routing
    weights are appended. If its first element is itself a list, only that
    first per-sample list is shown.

    Args:
        result: Prediction dictionary to display.
        task_name: Heading used in the first output line.

    Returns:
        The formatted text, lines joined with newlines.
    """
    header = f"\n{task_name} Prediction:"
    rule = "-" * 60

    if "label" in result:
        body = [
            f" Predicted: {result['label']}",
            f" Confidence: {result['confidence']:.4f}",
        ]
    elif "labels" in result:
        labels = result["labels"]
        more = "..." if len(labels) > 5 else ""
        body = [
            f" Batch size: {len(labels)}",
            f" Predictions: {labels[:5]}{more}",
        ]
    else:
        body = [
            f" Predicted class: {result['predicted_class']}",
            f" Confidence: {result['confidence']:.4f}",
        ]

    out = [header, rule, *body]

    routing = result.get("routing")
    if routing:
        out.append("\n Routing Weights:")
        if isinstance(routing, list):
            # Batched routing is a list of per-sample lists; show the first sample only.
            first = routing[0]
            per_expert = first if isinstance(first, list) else routing
            out.extend(
                f" {entry['expert']:20s} ({entry['comm']:4s}): {entry['weight']:.4f}"
                for entry in per_expert
            )

    return "\n".join(out)
86
+
87
+
88
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inference script."""
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        "--task1-checkpoint",
        type=Path,
        required=True,
        help="Path to Task 1 (modulation) MoE checkpoint",
    )
    parser.add_argument(
        "--task2-checkpoint",
        type=Path,
        required=True,
        help="Path to Task 2 (SNR/mobility) MoE checkpoint",
    )
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Path to input spectrogram pickle file",
    )
    parser.add_argument(
        "--index",
        type=int,
        default=None,
        help="Index of sample to process (default: process all samples in file)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="If processing multiple samples, batch size for inference (default: process all at once)",
    )
    parser.add_argument(
        "--show-probabilities",
        action="store_true",
        help="Show full class probability distributions",
    )
    parser.add_argument(
        "--show-routing",
        action="store_true",
        help="Show expert routing weights",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional: save predictions to JSON file",
    )
    parser.add_argument(
        "--device",
        choices=["cuda", "cpu", "auto"],
        default="auto",
        help="Device to use for inference (default: auto-detect)",
    )
    return parser


def _is_batched(value) -> bool:
    """Return True when ``value`` holds one entry per sample rather than a scalar."""
    if isinstance(value, (list, tuple)):
        return True
    return isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim > 0


def _split_batch_result(result: dict) -> list[dict]:
    """Turn one prediction dict into a list of per-sample records.

    A result whose ``"labels"`` and ``"predicted_class"`` are both scalar (or
    absent) is treated as a single-sample result and returned unchanged inside
    a one-element list. Otherwise one record with the keys
    ``"predicted_class"``, ``"label"`` and ``"confidence"`` is built per
    sample; scalar or missing fields are repeated for every sample.

    NOTE(review): the batched layout (parallel sequences under
    ``"predicted_class"``, ``"labels"``, ``"confidence"``) is inferred from how
    this script consumes ``MoEPredictor.predict`` output — confirm against
    that method.
    """
    predicted = result.get("predicted_class")
    labels = result.get("labels")
    confidence = result.get("confidence")

    if _is_batched(labels):
        count = len(labels)
    elif _is_batched(predicted):
        count = len(predicted)
    else:
        return [result]

    def pick(value, position):
        # Per-sample fields are indexed; scalar/missing fields are shared as-is.
        return value[position] if _is_batched(value) else value

    return [
        {
            "predicted_class": pick(predicted, i),
            "label": pick(labels, i),
            "confidence": pick(confidence, i),
        }
        for i in range(count)
    ]


def _print_label_summary(title: str, records: list, num_samples: int) -> None:
    """Print how often each predicted label occurs among ``records``."""
    # Local import keeps this block self-contained without touching file-level imports.
    from collections import Counter

    if not records:
        return

    counts = Counter(record.get("label", "Unknown") for record in records)
    try:
        ordered = sorted(counts)
    except TypeError:
        # Mixed label types (e.g. None next to strings) are not orderable directly.
        ordered = sorted(counts, key=str)

    print(f"\n{title} predictions:")
    for label in ordered:
        count = counts[label]
        print(f" {label}: {count} samples ({count/num_samples*100:.1f}%)")


def _to_jsonable(obj):
    """Recursively convert tensors, arrays and NumPy scalars to plain Python values."""
    if isinstance(obj, dict):
        return {key: _to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(item) for item in obj]
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # Covers np.integer / np.floating as well as np.bool_ and friends.
        return obj.item()
    return obj


def main() -> None:
    """Command-line entry point: run both MoE predictors on the input spectrograms.

    Parses the arguments, loads the Task 1 and Task 2 checkpoints, runs
    prediction either in a single call or in ``--batch-size`` chunks, prints a
    label-frequency summary for each task, and optionally writes the
    per-sample predictions to a JSON file.

    Exits with the argparse usage error (status 2) when ``--batch-size`` is
    given but is not a positive integer.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Reject unusable batch sizes before any expensive model loading:
    # 0 would divide by zero below and negatives would process nothing.
    if args.batch_size is not None and args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    # Set device
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    print(f"[INFO] Using device: {device}")

    # Load MoE models
    print(f"[INFO] Loading Task 1 MoE from {args.task1_checkpoint}")
    task1_predictor = MoEPredictor.from_checkpoint(args.task1_checkpoint, device)

    print(f"[INFO] Loading Task 2 MoE from {args.task2_checkpoint}")
    task2_predictor = MoEPredictor.from_checkpoint(args.task2_checkpoint, device)

    # Load input spectrogram(s)
    print(f"[INFO] Loading spectrogram(s) from {args.input}")
    spectrograms = load_spectrogram_sample(args.input, args.index)

    if spectrograms.dim() == 2:
        print(f"[INFO] Processing single spectrogram of shape {tuple(spectrograms.shape)}")
        num_samples = 1
    else:
        print(f"[INFO] Processing {spectrograms.shape[0]} spectrograms")
        num_samples = spectrograms.shape[0]

    def predict_both(specs):
        """Run both task predictors on ``specs`` with the requested extras."""
        options = {
            "return_probabilities": args.show_probabilities,
            "return_routing": args.show_routing,
        }
        return (
            task1_predictor.predict(specs, **options),
            task2_predictor.predict(specs, **options),
        )

    # Both lists always hold one record per sample so the summary and the
    # JSON output have the same layout in batched and non-batched mode.
    results = {"task1": [], "task2": []}

    if num_samples == 1 or args.batch_size is None:
        # Single inference call
        print("\n" + "="*60)
        print("RUNNING INFERENCE")
        print("="*60)

        task1_result, task2_result = predict_both(spectrograms)

        results["task1"] = _split_batch_result(task1_result)
        results["task2"] = _split_batch_result(task2_result)

        # Print results
        print(format_prediction_output(task1_result, "Task 1 (Modulation)"))
        print(format_prediction_output(task2_result, "Task 2 (SNR/Mobility)"))

    else:
        # Batch processing
        print("\n" + "="*60)
        print(f"RUNNING BATCH INFERENCE ({args.batch_size} samples at a time)")
        print("="*60)

        num_batches = (num_samples + args.batch_size - 1) // args.batch_size

        for batch_idx, start_idx in enumerate(range(0, num_samples, args.batch_size)):
            end_idx = min(start_idx + args.batch_size, num_samples)
            batch_specs = spectrograms[start_idx:end_idx]

            print(f"\n[Batch {batch_idx+1}/{num_batches}] Processing samples {start_idx} to {end_idx-1}")

            task1_batch_result, task2_batch_result = predict_both(batch_specs)

            results["task1"].extend(_split_batch_result(task1_batch_result))
            results["task2"].extend(_split_batch_result(task2_batch_result))

    # Print summary
    print("\n" + "="*60)
    print("INFERENCE SUMMARY")
    print("="*60)
    print(f"Total samples processed: {num_samples}")

    _print_label_summary("Task 1 (Modulation)", results["task1"], num_samples)
    _print_label_summary("Task 2 (SNR/Mobility)", results["task2"], num_samples)

    # Save results to file if requested
    if args.output:
        output_path = args.output.expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)

        output_data = {
            "input_file": str(args.input),
            "num_samples": num_samples,
            "task1_predictions": _to_jsonable(results["task1"]),
            "task2_predictions": _to_jsonable(results["task2"]),
        }

        with output_path.open("w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)

        print(f"\n[INFO] Results saved to {output_path}")

    print("\n" + "="*60)
    print("INFERENCE COMPLETE")
    print("="*60 + "\n")
292
+
293
+
294
+ if __name__ == "__main__":
295
+ main()
296
+