#!/usr/bin/env python3
"""Quick spectrogram visualiser for lwm-spectro datasets."""

from __future__ import annotations

import argparse
import math
import numbers
import pickle
import sys
from fractions import Fraction
from pathlib import Path

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover - main() reports the missing dependency
    # Keep the path/pickle helpers importable without matplotlib installed.
    plt = None  # type: ignore[assignment]

try:
    from core.paths import get_spectrogram_base_dir
except Exception:  # pragma: no cover - fallback when module unavailable
    get_spectrogram_base_dir = None  # type: ignore

SCRIPT_DIR = Path(__file__).resolve().parent
DEFAULT_BASE_CANDIDATES = [
    SCRIPT_DIR / 'spectrograms',
    SCRIPT_DIR.parent / 'spectrograms',
    Path('D:/Namhyun/lwm_data'),
    Path('/mnt/d/Namhyun/lwm_data'),
]

# Characters replaced by "_" when building the auto-generated output filename.
_FILENAME_TRANSLATION = str.maketrans({' ': '_', '/': '_', '\\': '_'})


def resolve_base_dir() -> Path:
    """Return the directory under which spectrogram data is searched.

    Preference order: the directory reported by ``get_spectrogram_base_dir``
    (when that import succeeded and the directory exists), then the first
    existing entry of ``DEFAULT_BASE_CANDIDATES``, then the current working
    directory.
    """
    if get_spectrogram_base_dir is not None:
        base = Path(get_spectrogram_base_dir())
        if base.exists():
            return base
    for cand in DEFAULT_BASE_CANDIDATES:
        if cand.exists():
            return cand
    return Path.cwd()


def detect_existing(route: Path, base_dir: Path) -> Path | None:
    """Return the resolved form of *route* if it exists, else ``None``.

    An absolute *route* is checked directly; otherwise it is interpreted
    relative to *base_dir*.
    """
    if route.is_absolute() and route.exists():
        return route.resolve()
    rel_candidate = (base_dir / route).resolve()
    if rel_candidate.exists():
        return rel_candidate
    return None


def tokens_to_path(tokens: list[str], base_dir: Path) -> Path:
    """Turn a list of path fragments into an existing path under *base_dir*.

    The fragments are first joined and tried as a literal path. If that does
    not exist, each token is matched in turn with a recursive glob below the
    directory found for the previous token; when several directories match,
    the last one in sorted order is used.

    Raises:
        ValueError: if *tokens* is empty.
        FileNotFoundError: if a token matches no directory.
    """
    if not tokens:
        raise ValueError('At least one token or path fragment is required')
    existing = detect_existing(Path(*tokens), base_dir)
    if existing is not None:
        return existing
    current = base_dir
    for tok in tokens:
        matches = [m for m in sorted(current.glob(f'**/{tok}')) if m.is_dir()]
        if not matches:
            raise FileNotFoundError(f'Cannot match token "{tok}" under {current}')
        current = matches[-1]
    return current


def resolve_city_dirs(city_ids: list[int], base_dir: Path) -> list[Path]:
    """Return every ``city_<id>_*`` directory directly under *base_dir*.

    Raises:
        FileNotFoundError: if any requested id matches no directory.
    """
    roots: list[Path] = []
    for cid in city_ids:
        pattern = f'city_{cid}_*'
        matches = sorted(p for p in base_dir.glob(pattern) if p.is_dir())
        if not matches:
            raise FileNotFoundError(f'No directory matching {pattern} under {base_dir}')
        roots.extend(matches)
    return roots


def is_within(path: Path, root: Path) -> bool:
    """Return True when the resolved *path* equals or lies below *root*."""
    try:
        path.resolve().relative_to(root.resolve())
    except ValueError:
        return False
    return True


def extract_city_token(path: Path) -> str | None:
    """Return the first component of the resolved *path* starting with ``city_``."""
    for part in path.resolve().parts:
        if part.startswith('city_'):
            return part
    return None


def format_city_token(city_token: str) -> str:
    """Render a ``city_<id>_<name>`` token for display, e.g. ``City 0 New York``.

    Tokens that do not have at least three ``_``-separated parts beginning
    with ``city`` are title-cased with underscores replaced by spaces.
    """
    parts = city_token.split('_')
    if len(parts) >= 3 and parts[0].lower() == 'city':
        prefix = parts[0].capitalize()
        identifier = parts[1]
        remainder = ' '.join(p.capitalize() for p in parts[2:])
        return ' '.join(part for part in (prefix, identifier, remainder) if part)
    return city_token.replace('_', ' ').title()


def find_pickle(path: Path) -> Path:
    """Return *path* if it is a file, else the first ``*.pkl`` found below it.

    Raises:
        FileNotFoundError: if *path* is not a file and contains no ``.pkl``.
    """
    if path.is_file():
        return path
    matches = sorted(path.rglob('*.pkl'))
    if not matches:
        raise FileNotFoundError(f"No .pkl file found under {path}")
    return matches[0]


def load_spectrogram(path: Path, index: int, average: bool) -> tuple[np.ndarray, dict]:
    """Load one spectrogram (or the mean of all) from a pickle file.

    The pickle is expected to hold a dict with a ``'spectrograms'`` entry that
    is indexed along its first axis, and an optional ``'configuration'`` dict.

    Args:
        path: pickle file to read.
        index: sample index along the first axis; ignored when *average*.
        average: return the mean over the first axis instead of one sample.

    Returns:
        ``(image, meta)`` where *meta* is a copy of the configuration plus a
        ``'label'`` entry (``'Mean'`` or ``'Index <n>'``).

    Raises:
        KeyError: if the payload has no ``'spectrograms'`` entry.
        TypeError: if the payload is not a dict.
        IndexError: if no spectrograms are stored or *index* is out of range.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only open pickles produced by a trusted pipeline.
    with path.open('rb') as f:
        payload = pickle.load(f)
    if not isinstance(payload, dict):
        raise TypeError(f'Expected a dict payload in {path}, got {type(payload).__name__}')

    specs = np.asarray(payload['spectrograms'])
    # A stored ``None`` configuration is treated like a missing one.
    meta = dict(payload.get('configuration') or {})

    if specs.ndim == 0 or specs.shape[0] == 0:
        # Averaging zero samples would silently produce a NaN image.
        raise IndexError(f'No spectrograms stored in {path}')

    if average:
        img = specs.mean(axis=0)
        meta['label'] = 'Mean'
    else:
        if not (0 <= index < specs.shape[0]):
            raise IndexError(f'Index {index} out of range (0..{specs.shape[0]-1})')
        img = specs[index]
        meta['label'] = f'Index {index}'
    return img, meta


def resolve_route(tokens: list[str], base_dir: Path, city_dirs: list[Path]) -> Path:
    """Resolve *tokens* to a path, optionally restricted to *city_dirs*.

    Without city directories, the match under *base_dir* is returned. With
    city directories, the base match is accepted only if it lies inside one
    of them; otherwise each city directory is searched in order.

    Raises:
        FileNotFoundError: if nothing suitable is found.
    """
    base_candidate: Path | None = None
    try:
        base_candidate = tokens_to_path(tokens, base_dir)
    except (FileNotFoundError, ValueError):
        base_candidate = None

    if city_dirs:
        if base_candidate is not None and any(is_within(base_candidate, city) for city in city_dirs):
            return base_candidate
        for root in city_dirs:
            try:
                return tokens_to_path(tokens, root)
            except (FileNotFoundError, ValueError):
                continue
        raise FileNotFoundError('Failed to resolve route within the selected city directories.')

    if base_candidate is not None:
        return base_candidate
    raise FileNotFoundError('Failed to resolve route within the base directory.')


def _is_real_number(value: object) -> bool:
    """Return True for real numbers, including NumPy scalars, but not bools."""
    # NumPy integer scalars are not ``int`` subclasses but are registered with
    # the ``numbers`` ABCs, so this accepts them where ``(int, float)`` did not.
    return isinstance(value, numbers.Real) and not isinstance(value, bool)


def _code_rate_fraction(code_rate) -> Fraction | None:
    """Approximate *code_rate* as a fraction with denominator <= 16, or None."""
    try:
        return Fraction(float(code_rate)).limit_denominator(16)
    except (ValueError, OverflowError):  # NaN / infinity
        return None


def _snr_display(snr):
    """Return *snr* as an int when it is within 1e-6 of one, else unchanged."""
    # round() raises on NaN/inf, so non-finite values are passed through.
    if math.isfinite(snr) and abs(snr - round(snr)) < 1e-6:
        return int(round(snr))
    return snr


def _metadata_tokens(meta: dict, *, compact: bool) -> list[str]:
    """Describe standard, modulation, code rate, SNR and speed from *meta*.

    With ``compact=False`` tokens are formatted for the plot title
    (``rate 3/4``, ``SNR 5 dB``); with ``compact=True`` for a filename
    (``rate3-4``, ``SNR5dB``). Missing or non-numeric entries are skipped.
    """
    tokens = [str(meta[key]) for key in ('standard', 'modulation') if meta.get(key)]

    code_rate = meta.get('code_rate')
    if _is_real_number(code_rate):
        frac = _code_rate_fraction(code_rate)
        if frac is None:
            tokens.append(f'rate{code_rate}' if compact else f'rate {code_rate}')
        elif compact:
            tokens.append(f'rate{frac.numerator}-{frac.denominator}')
        else:
            tokens.append(f'rate {frac.numerator}/{frac.denominator}')

    snr = meta.get('snr')
    if _is_real_number(snr):
        shown = _snr_display(snr)
        tokens.append(f'SNR{shown}dB' if compact else f'SNR {shown} dB')

    speed = meta.get('speed') or meta.get('speed_name')
    if speed:
        tokens.append(str(speed))
    return tokens


def _compute_extent(shape: tuple, meta: dict) -> list | None:
    """Return ``[t0_us, t1_us, f0_mhz, f1_mhz]`` for imshow, or None.

    Needs a 2-D image shape plus finite numeric ``freq_resolution_hz``,
    ``sample_rate`` (> 0), ``nperseg`` and ``noverlap`` in *meta*. The
    frequency axis is laid out around zero.
    """
    if len(shape) != 2:
        return None
    freq_res = meta.get('freq_resolution_hz')
    sample_rate = meta.get('sample_rate')
    nperseg = meta.get('nperseg')
    noverlap = meta.get('noverlap')
    required = (freq_res, sample_rate, nperseg, noverlap)
    if not all(_is_real_number(v) and math.isfinite(v) for v in required):
        return None
    if not sample_rate > 0:
        return None

    hop_samples = max(int(nperseg - noverlap), 1)
    hop = hop_samples / sample_rate  # seconds between adjacent time bins
    height, width = shape
    freq_lo = -(height // 2) * freq_res
    freq_hi = (height - height // 2) * freq_res
    return [0.0, hop * width * 1e6, freq_lo / 1e6, freq_hi / 1e6]


def _axis_labels(meta: dict, has_extent: bool) -> tuple[str, str]:
    """Return ``(xlabel, ylabel)``: physical units if scaled, else bin labels."""
    if has_extent:
        return 'Time (µs)', 'Frequency (MHz)'
    xlabel = 'Time bins'
    time_res = meta.get('time_resolution_ms')
    if _is_real_number(time_res):
        xlabel += f' (~{float(time_res):.3f} ms hop)'
    ylabel = 'Frequency bins'
    freq_res = meta.get('freq_resolution_hz')
    if _is_real_number(freq_res):
        ylabel += f' (~{float(freq_res) / 1e3:.1f} kHz/bin)'
    return xlabel, ylabel


def _default_output_name(pkl_path: Path, city_token: str | None, meta: dict) -> str:
    """Build the auto-generated PNG filename from the pickle name and metadata."""
    tokens = [pkl_path.stem]
    if city_token:
        tokens.append(city_token)
    tokens.extend(_metadata_tokens(meta, compact=True))
    tokens.append(meta['label'])
    return '_'.join(str(tok).translate(_FILENAME_TRANSLATION) for tok in tokens) + '.png'


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the visualiser."""
    parser = argparse.ArgumentParser(
        description='Plot spectrogram pickle outputs for lwm-spectro.')
    parser.add_argument(
        '--route', nargs='+', required=True,
        help='Path fragments (e.g., LTE QAM16 rate3-4 SNR5dB static) leading to the target pickle.')
    parser.add_argument(
        '--city', type=int, nargs='+',
        help='City indices to search (e.g., --city 0 1). Defaults to searching the entire base directory.')
    parser.add_argument(
        '--index', type=int, default=0,
        help='Sample index inside pickle (default: 0).')
    parser.add_argument(
        '--average', action='store_true',
        help='Plot the mean over all samples instead of a single index.')
    parser.add_argument(
        '--save', type=Path,
        help='Optional output path. Defaults to current directory with an auto-generated filename.')
    parser.add_argument(
        '--no-show', action='store_true',
        help='Skip opening an interactive window (image is still saved).')
    return parser


def main() -> None:
    """Locate a spectrogram pickle, plot it, save a PNG and optionally show it.

    Resolution and loading failures are printed to stderr and end the process
    with exit status 1.
    """
    args = _build_parser().parse_args()

    if plt is None:
        print('matplotlib is required to plot spectrograms.', file=sys.stderr)
        sys.exit(1)

    base_dir = resolve_base_dir()
    try:
        city_dirs = resolve_city_dirs(args.city, base_dir) if args.city else []
        target_path = resolve_route(args.route, base_dir, city_dirs)
        pkl_path = find_pickle(target_path)
    except FileNotFoundError as err:
        print(err, file=sys.stderr)
        sys.exit(1)

    city_token = extract_city_token(pkl_path)

    try:
        img, meta = load_spectrogram(pkl_path, args.index, args.average)
    except (IndexError, KeyError, TypeError, pickle.UnpicklingError, EOFError) as err:
        print(f'Failed to load spectrogram from {pkl_path}: {err}', file=sys.stderr)
        sys.exit(1)

    # Values already in dBm (10*log10(P[W]/1 mW))
    img_dbm = np.asarray(img, dtype=np.float64)
    extent = _compute_extent(img_dbm.shape, meta)

    plt.figure(figsize=(6.4, 5.4))
    im = plt.imshow(img_dbm, aspect='auto', origin='lower', cmap='viridis', extent=extent)
    plt.colorbar(im, fraction=0.046, pad=0.04, label='Power (dBm)')

    title_tokens = [format_city_token(city_token)] if city_token else []
    title_tokens.extend(_metadata_tokens(meta, compact=False))
    plt.title(' | '.join(title_tokens) if title_tokens else pkl_path.stem)

    xlabel, ylabel = _axis_labels(meta, extent is not None)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    if args.save is not None:
        out_path = args.save
    else:
        out_path = Path.cwd() / _default_output_name(pkl_path, city_token, meta)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=200)

    if not args.no_show:
        plt.show()


if __name__ == '__main__':
    main()