lwm-spectro / plot /plot_spectrogram.py
Namhyun Kim
Sync local development code into HF repo
eaaeb1b
#!/usr/bin/env python3
"""Quick spectrogram visualiser for lwm-spectro datasets."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import pickle
import matplotlib.pyplot as plt
import numpy as np
from fractions import Fraction
# Optional project helper that reports the configured spectrogram directory.
# If the import fails for any reason the name is bound to None instead;
# resolve_base_dir() checks for None before calling it and falls back to
# DEFAULT_BASE_CANDIDATES.
try:
    from core.paths import get_spectrogram_base_dir
except Exception: # pragma: no cover - fallback when module unavailable
    get_spectrogram_base_dir = None # type: ignore
# Absolute directory containing this script; anchors the relative candidates below.
SCRIPT_DIR = Path(__file__).resolve().parent
# Fallback dataset roots. resolve_base_dir() walks this list in order and
# returns the first entry that exists on disk, so earlier entries win.
# The last two entries are machine-specific absolute paths (a Windows-style
# drive path and a /mnt/d mount path).
DEFAULT_BASE_CANDIDATES = [
    SCRIPT_DIR / 'spectrograms',
    SCRIPT_DIR.parent / 'spectrograms',
    Path('D:/Namhyun/lwm_data'),
    Path('/mnt/d/Namhyun/lwm_data'),
]
def resolve_base_dir() -> Path:
    """Pick the directory under which spectrogram data is searched.

    Preference order:
      1. the directory reported by ``get_spectrogram_base_dir`` (when that
         helper was importable and the directory exists),
      2. the first existing entry of ``DEFAULT_BASE_CANDIDATES``,
      3. the current working directory.
    """
    if get_spectrogram_base_dir is not None:
        configured = Path(get_spectrogram_base_dir())
        if configured.exists():
            return configured

    first_existing = next(
        (option for option in DEFAULT_BASE_CANDIDATES if option.exists()),
        None,
    )
    # Only ask for the working directory when nothing else matched.
    return Path.cwd() if first_existing is None else first_existing
def detect_existing(route: Path, base_dir: Path) -> Path | None:
    """Return the resolved location of ``route`` if it exists, else ``None``.

    An absolute ``route`` that exists is resolved and returned directly.
    Otherwise ``route`` is joined onto ``base_dir`` and that location is
    returned when it exists.
    """
    if route.is_absolute():
        if route.exists():
            return route.resolve()

    joined = (base_dir / route).resolve()
    return joined if joined.exists() else None
def tokens_to_path(tokens: list[str], base_dir: Path) -> Path:
    """Turn a sequence of path fragments into an existing path.

    The fragments are first joined and tried as a literal path (absolute, or
    relative to ``base_dir``).  If that does not exist, each fragment is
    searched for recursively as a directory name below the directory found
    for the previous fragment, starting at ``base_dir``.  When several
    directories match a fragment, the one that sorts last is used.

    Raises:
        ValueError: if ``tokens`` is empty.
        FileNotFoundError: if a fragment matches no directory.
    """
    if not tokens:
        raise ValueError('At least one token or path fragment is required')

    literal_hit = detect_existing(Path(*tokens), base_dir)
    if literal_hit is not None:
        return literal_hit

    cursor = base_dir
    for fragment in tokens:
        # The greatest matching directory is the one a sorted list would end with.
        deepest = max(
            (hit for hit in cursor.glob(f'**/{fragment}') if hit.is_dir()),
            default=None,
        )
        if deepest is None:
            raise FileNotFoundError(f'Cannot match token "{fragment}" under {cursor}')
        cursor = deepest
    return cursor
def resolve_city_dirs(city_ids: list[int], base_dir: Path) -> list[Path]:
    """Collect the ``city_<id>_*`` directories directly under ``base_dir``.

    Directories are returned grouped by the order of ``city_ids`` and sorted
    within each id.

    Raises:
        FileNotFoundError: if any requested id has no matching directory.
    """
    collected: list[Path] = []
    for city_id in city_ids:
        pattern = f'city_{city_id}_*'
        found = [entry for entry in base_dir.glob(pattern) if entry.is_dir()]
        if not found:
            raise FileNotFoundError(f'No directory matching {pattern} under {base_dir}')
        found.sort()
        collected += found
    return collected
def is_within(path: Path, root: Path) -> bool:
    """Report whether ``path`` is ``root`` itself or lies somewhere below it.

    Both arguments are resolved first, so relative paths and symlinks are
    compared by their resolved locations.
    """
    resolved_path = path.resolve()
    resolved_root = root.resolve()
    return resolved_path.is_relative_to(resolved_root)
def extract_city_token(path: Path) -> str | None:
    """Return the first component of the resolved ``path`` that starts with
    ``'city_'``, or ``None`` when no component does."""
    components = path.resolve().parts
    return next((name for name in components if name.startswith('city_')), None)
def format_city_token(city_token: str) -> str:
    """Convert a directory token such as ``city_0_new_york`` into a display
    string such as ``City 0 New York``.

    Tokens that do not have the ``city_<id>_<name...>`` shape are simply
    title-cased with underscores turned into spaces.
    """
    pieces = city_token.split('_')
    if len(pieces) < 3 or pieces[0].lower() != 'city':
        return city_token.replace('_', ' ').title()

    keyword, identifier, *name_words = pieces
    place_name = ' '.join(word.capitalize() for word in name_words)
    # Empty segments (e.g. from a trailing underscore) are dropped.
    return ' '.join(filter(None, (keyword.capitalize(), identifier, place_name)))
def find_pickle(path: Path) -> Path:
    """Locate the pickle file that ``path`` refers to.

    A file path is returned unchanged.  For anything else the tree below
    ``path`` is searched for ``*.pkl`` entries and the one that sorts first
    is returned.

    Raises:
        FileNotFoundError: if no ``*.pkl`` entry is found.
    """
    if path.is_file():
        return path

    earliest = min(path.rglob('*.pkl'), default=None)
    if earliest is None:
        raise FileNotFoundError(f"No .pkl file found under {path}")
    return earliest
def load_spectrogram(path: Path, index: int, average: bool) -> tuple[np.ndarray, dict]:
    """Load one spectrogram, or the mean of all of them, from a pickle file.

    The pickle must contain a mapping with a ``'spectrograms'`` entry; samples
    are indexed along its first axis.  An optional ``'configuration'`` mapping
    is returned as metadata.

    Args:
        path: Pickle file to read.
        index: Sample to select along the first axis (ignored when
            ``average`` is true).  Negative indices are rejected.
        average: When true, return the mean over the first axis instead of a
            single sample.

    Returns:
        ``(image, metadata)`` where ``metadata`` is a copy of the stored
        configuration plus a ``'label'`` entry (``'Mean'`` or
        ``'Index <n>'``).

    Raises:
        KeyError: if the payload has no ``'spectrograms'`` entry.
        IndexError: if the file holds no spectrograms or ``index`` is out of
            range.
    """
    # NOTE(security): unpickling can execute arbitrary code embedded in the
    # file.  Only open pickles that come from a trusted source.
    with path.open('rb') as f:
        payload = pickle.load(f)

    specs = np.asarray(payload['spectrograms'])
    # The configuration may be missing or stored as None; treat both as empty
    # so building the metadata below cannot fail with a TypeError.
    cfg = payload.get('configuration') or {}

    # Fail loudly on an empty stack instead of averaging nothing into NaNs.
    if specs.ndim == 0 or specs.shape[0] == 0:
        raise IndexError(f'No spectrograms stored in {path}')

    if average:
        img = specs.mean(axis=0)
        label = 'Mean'
    else:
        if not (0 <= index < specs.shape[0]):
            raise IndexError(f'Index {index} out of range (0..{specs.shape[0]-1})')
        img = specs[index]
        label = f'Index {index}'
    return img, {**cfg, 'label': label}
def resolve_route(tokens: list[str], base_dir: Path, city_dirs: list[Path]) -> Path:
    """Resolve route fragments to a path, optionally restricted to city directories.

    The fragments are always resolved against ``base_dir`` first.  Without
    ``city_dirs`` that result is returned.  With ``city_dirs`` it is only
    accepted when it lies inside one of them; otherwise each city directory
    is searched in turn and the first successful resolution wins.

    Raises:
        FileNotFoundError: if nothing suitable can be resolved.
    """
    try:
        global_hit = tokens_to_path(tokens, base_dir)
    except (FileNotFoundError, ValueError):
        global_hit = None

    # Unrestricted search: the base-directory result is the answer.
    if not city_dirs:
        if global_hit is None:
            raise FileNotFoundError('Failed to resolve route within the base directory.')
        return global_hit

    # Restricted search: accept the base-directory result only inside a selected city.
    if global_hit is not None:
        for city_root in city_dirs:
            if is_within(global_hit, city_root):
                return global_hit

    for city_root in city_dirs:
        try:
            scoped_hit = tokens_to_path(tokens, city_root)
        except (FileNotFoundError, ValueError):
            continue
        return scoped_hit

    raise FileNotFoundError('Failed to resolve route within the selected city directories.')
def _snr_display(snr: float) -> float:
    """Return ``snr`` as an int when it is within 1e-6 of a whole number.

    NaN and infinite values cannot be rounded to an int, so they are
    returned unchanged instead of raising.
    """
    try:
        nearest = round(snr)
    except (ValueError, OverflowError):
        return snr
    return int(nearest) if abs(snr - nearest) < 1e-6 else snr


def _metadata_tokens(meta: dict, *, compact: bool) -> list[str]:
    """Describe standard, modulation, code rate, SNR and speed as text tokens.

    ``compact=False`` yields the spaced form used in the plot title
    (``rate 3/4``, ``SNR 5 dB``); ``compact=True`` yields the filename form
    (``rate3-4``, ``SNR5dB``).  Entries that are missing or falsy are skipped.
    """
    tokens = [str(meta[key]) for key in ('standard', 'modulation') if meta.get(key)]

    code_rate = meta.get('code_rate')
    if isinstance(code_rate, (int, float)):
        try:
            frac = Fraction(code_rate).limit_denominator(16)
        except (ValueError, OverflowError):
            # NaN / infinite rates have no rational form; show the raw value.
            rate_text = str(code_rate)
        else:
            separator = '-' if compact else '/'
            rate_text = f'{frac.numerator}{separator}{frac.denominator}'
        tokens.append(f'rate{rate_text}' if compact else f'rate {rate_text}')

    snr = meta.get('snr')
    if isinstance(snr, (int, float)):
        shown = _snr_display(snr)
        tokens.append(f'SNR{shown}dB' if compact else f'SNR {shown} dB')

    speed = meta.get('speed') or meta.get('speed_name')
    if speed:
        tokens.append(str(speed))
    return tokens


def _physical_extent(meta: dict, shape: tuple[int, ...]) -> list[float] | None:
    """Compute the ``imshow`` extent in the units of the physical axis labels.

    Returns ``None`` (meaning: plot against bin indices) when the image is
    not 2-D or when ``nperseg``, ``noverlap``, ``sample_rate`` or
    ``freq_resolution_hz`` is missing, non-numeric, or the sample rate is
    not positive.
    """
    if len(shape) != 2:
        return None
    nperseg = meta.get('nperseg')
    noverlap = meta.get('noverlap')
    sample_rate = meta.get('sample_rate')
    freq_res = meta.get('freq_resolution_hz')
    if not all(isinstance(v, (int, float)) for v in (nperseg, noverlap, sample_rate, freq_res)):
        return None
    if sample_rate <= 0:
        return None

    # Hop between consecutive columns, in samples (at least 1), then divided
    # by the sample rate.  Assumes sample_rate is in Hz -- TODO confirm.
    hop = max(int(nperseg - noverlap), 1) / sample_rate
    height, width = shape
    # Rows are mapped symmetrically around zero frequency; this presumes the
    # stored spectrogram is centred on 0 Hz -- TODO confirm against producer.
    freq_low = -(height // 2) * freq_res
    freq_high = (height - height // 2) * freq_res
    # Scale to match the 'Time (µs)' and 'Frequency (MHz)' axis labels.
    return [0.0, hop * width * 1e6, freq_low / 1e6, freq_high / 1e6]


def _default_output_path(pkl_path: Path, city_token: str | None, meta: dict) -> Path:
    """Build the auto-generated PNG path in the current working directory."""
    tokens = [pkl_path.stem]
    if city_token:
        tokens.append(city_token)
    tokens.extend(_metadata_tokens(meta, compact=True))
    tokens.append(meta['label'])
    # Spaces and slashes are replaced so every token is a safe filename part.
    out_name = '_'.join(str(tok).replace(' ', '_').replace('/', '_') for tok in tokens) + '.png'
    return Path.cwd() / out_name


def main() -> None:
    """Command-line entry point: locate a spectrogram pickle, plot it, save it.

    Steps: parse the arguments, resolve the route to a pickle file (optionally
    restricted to selected city directories), load the requested sample or
    the mean, draw it with matplotlib, save a PNG, and show the window unless
    ``--no-show`` was given.

    Any resolution or loading failure is printed to stderr and the process
    exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Plot spectrogram pickle outputs for lwm-spectro.')
    parser.add_argument('--route', nargs='+', required=True,
                        help='Path fragments (e.g., LTE QAM16 rate3-4 SNR5dB static) leading to the target pickle.')
    parser.add_argument('--city', type=int, nargs='+',
                        help='City indices to search (e.g., --city 0 1). Defaults to searching the entire base directory.')
    parser.add_argument('--index', type=int, default=0, help='Sample index inside pickle (default: 0).')
    parser.add_argument('--average', action='store_true', help='Plot the mean over all samples instead of a single index.')
    parser.add_argument('--save', type=Path,
                        help='Optional output path. Defaults to current directory with an auto-generated filename.')
    parser.add_argument('--no-show', action='store_true', help='Skip opening an interactive window (image is still saved).')
    args = parser.parse_args()

    base_dir = resolve_base_dir()
    try:
        city_dirs: list[Path] = resolve_city_dirs(args.city, base_dir) if args.city else []
        target_path = resolve_route(args.route, base_dir, city_dirs)
        pkl_path = find_pickle(target_path)
        img, meta = load_spectrogram(pkl_path, args.index, args.average)
    except (FileNotFoundError, IndexError, KeyError, pickle.UnpicklingError, EOFError) as err:
        # Corrupt or truncated pickles get the same clean exit as a bad route.
        print(err, file=sys.stderr)
        sys.exit(1)
    city_token = extract_city_token(pkl_path)

    # Values already in dBm (10*log10(P[W]/1 mW))
    img_dbm = np.asarray(img, dtype=np.float64)
    extent = _physical_extent(meta, img_dbm.shape)

    plt.figure(figsize=(6.4, 5.4))
    im = plt.imshow(img_dbm, aspect='auto', origin='lower', cmap='viridis', extent=extent)
    plt.colorbar(im, fraction=0.046, pad=0.04, label='Power (dBm)')

    title_tokens = [format_city_token(city_token)] if city_token else []
    title_tokens.extend(_metadata_tokens(meta, compact=False))
    plt.title(' | '.join(title_tokens) if title_tokens else pkl_path.stem)

    if extent is not None:
        xlabel = 'Time (µs)'
        ylabel = 'Frequency (MHz)'
    else:
        # No physical scaling available: label the axes in bins and mention
        # the per-bin resolution when the metadata provides it.
        time_res = meta.get('time_resolution_ms')
        freq_res = meta.get('freq_resolution_hz')
        xlabel = 'Time bins'
        if isinstance(time_res, (int, float)):
            xlabel += f' (~{time_res:.3f} ms hop)'
        ylabel = 'Frequency bins'
        if isinstance(freq_res, (int, float)):
            ylabel += f' (~{freq_res/1e3:.1f} kHz/bin)'
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    if args.save is not None:
        out_path = args.save
    else:
        out_path = _default_output_path(pkl_path, city_token, meta)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=200)
    if not args.no_show:
        plt.show()
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()