| | import argparse |
| | import json |
| | import os |
| | import sys |
| | import time |
| | from functools import partial |
| | from multiprocessing import Pool |
| | from queue import Queue |
| | from threading import Thread |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import xarray as xr |
| | from tqdm import tqdm |
| |
|
# Wall-clock reference for the "Done in ...s." report at the end of the run.
START = time.time()
# Base seed forwarded to every worker; ``None`` means non-deterministic sampling.
SEED = None

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Importance sampler of the valid datacubes (after the nan filtering)")
parser.add_argument("zarr_path", help="Path to the Zarr dataset")
parser.add_argument(
    "csv_path", help="Path to the CSV with the valid datacube coordinates (created by the nan filtering)"
)
parser.add_argument("--q_min", type=float, default=1e-4, help="Minimum selection probability (default 1e-4)")
parser.add_argument("--s", type=float, default=1, help="Denominator in the exponential (default 1)")
parser.add_argument("--m", type=float, default=0.1, help="Factor weighting the mean rescaled rain rate (default 0.1)")
parser.add_argument("--n_workers", type=int, default=8, help="Number of parallel workers (default 8)")
parser.add_argument("--n_rand", type=int, default=1, help="Number of random sampling of each datacube (default 1)")
args = parser.parse_args()

# Reject values that would otherwise only fail deep inside the worker pool:
# Pool(0) raises, a negative N_rand makes rng.random() raise for *every*
# datacube (each failure is caught and logged separately by process_datacube),
# and s is the divisor of the rescaling exponent.
if args.n_workers < 1:
    parser.error("--n_workers must be at least 1")
if args.n_rand < 0:
    parser.error("--n_rand must be non-negative")
if args.s <= 0:
    parser.error("--s must be positive")
| |
|
| | |
# Sampling hyper-parameters, bound to the short module-level names that the
# worker functions read.
s = args.s
qmin = args.q_min
m = args.m

n_workers = args.n_workers
N_rand = args.n_rand
# Rows read from the coordinate CSV per pandas chunk.
chunksize = 16000

# Recover the datacube geometry from the CSV file name, laid out as
# ``<p0>_<p1>_<start>-<end>_<Dt>x<w>x<h>_<stepT>x<stepX>x<stepY>_<N_nan>.<ext>``.
# Only the file name itself is parsed: splitting the full path would shift
# every field whenever a parent directory contains an underscore.
name_arr = os.path.splitext(os.path.basename(args.csv_path))[0].split("_")
try:
    dates = name_arr[2]
    # The first three '-'-separated parts form the start date, the rest the
    # end date (presumably YYYY-MM-DD each -- TODO confirm).
    start_date = "-".join(dates.split("-")[0:3])
    end_date = "-".join(dates.split("-")[3:])
    Dt, w, h = (int(v) for v in name_arr[3].split("x"))
    step_T, step_X, step_Y = (int(v) for v in name_arr[4].split("x"))
    N_nan = int(name_arr[5])
except (IndexError, ValueError) as err:
    parser.error(f"cannot parse datacube parameters from CSV name {args.csv_path!r}: {err}")
| |
|
| |
|
| | |
def acceptance_probability(data, q_min=None, m_factor=None):
    """
    Calculate the acceptance probability for importance sampling.

    The probability is ``min(1, q_min + m * nanmean(data))``. When ``data``
    contains no non-NaN value (or is empty) the mean is undefined and the
    floor ``min(1, q_min)`` is returned instead.

    Parameters
    ----------
    data : array-like
        Rescaled rain rate data for a single datacube. NaNs are ignored.
    q_min : float, optional
        Minimum acceptance probability. Defaults to the module-level
        ``qmin`` (the ``--q_min`` option).
    m_factor : float, optional
        Weight of the mean rescaled rain rate. Defaults to the module-level
        ``m`` (the ``--m`` option).

    Returns
    -------
    q : float
        Acceptance probability, capped at 1.
    """
    if q_min is None:
        q_min = qmin
    if m_factor is None:
        m_factor = m

    mean_rate = np.nanmean(data)
    if np.isnan(mean_rate):
        # No rain information at all. Without this guard ``min(1.0, nan)``
        # evaluates to 1.0 (NaN comparisons are always False), which would
        # silently accept such a cube with certainty.
        return min(1.0, q_min)
    return min(1.0, q_min + m_factor * mean_rate)
| |
|
| |
|
def process_datacube(coord, RR, N_rand, seed, acceptance_probability):
    """
    Process a single space-time region for importance sampling.

    Loads the datacube, rescales the rain rate with ``1 - exp(-RR / s)``,
    computes an acceptance probability and performs ``N_rand`` random
    acceptance trials against it.

    Parameters
    ----------
    coord : array-like of int
        Three-element sequence ``(it, ix, iy)`` giving the datacube origin.
        The cube spans the module-level sizes ``Dt``, ``w`` and ``h`` from
        there.
    RR : xr.DataArray
        Rain rate data array from the Zarr dataset.
    N_rand : int
        Number of random acceptance trials per datacube.
    seed : int or None
        Base random seed. ``None`` draws fresh OS entropy on every call.
        Otherwise the generator is seeded with ``(seed, it, ix, iy)``, so a
        run is reproducible while every datacube still gets its own
        random stream.
    acceptance_probability : callable
        Function that takes the rescaled data and returns a probability in
        ``[0, 1]``.

    Returns
    -------
    hits : list of tuple of int
        ``(it, ix, iy)`` repeated once per accepted trial. Empty if no trial
        was accepted or if processing the cube failed.
    """
    try:
        it, ix, iy = coord
        # NOTE(review): positional indexing assumes RR's dimension order is
        # (time, x, y), matching the CSV's (t, x, y) columns -- confirm.
        data = RR[it : it + Dt, ix : ix + w, iy : iy + h]
        # For non-negative rain rates and s > 0 this maps values into [0, 1).
        data = 1 - np.exp(-data / s)

        q = acceptance_probability(data)

        # Re-seeding every cube with the bare ``seed`` would give each one the
        # *same* uniform draws, turning acceptance into a deterministic
        # threshold on q. Mixing in the coordinates keeps runs reproducible
        # while giving each cube its own stream.
        if seed is None:
            rng = np.random.default_rng()
        else:
            rng = np.random.default_rng([seed, int(it), int(ix), int(iy)])
        accepted_count = int(np.count_nonzero(rng.random(N_rand) <= q))

        return [(it, ix, iy)] * accepted_count
    except Exception as e:
        # Best-effort: log and skip this cube instead of aborting the pool.
        # Format from ``coord`` so the handler itself cannot fail on unbound
        # names when the unpacking above was what raised.
        print(f"Error processing region ({', '.join(map(str, coord))}): {e}", file=sys.stderr)
        return []
| |
|
| |
|
def file_writer(output_queue, filename, batch_size=1000):
    """
    Drain lists of ``(t, x, y)`` hits from a queue into a CSV file.

    Meant to run on its own thread. The file is truncated and a ``t,x,y``
    header is written first. Incoming rows accumulate in memory and are
    written, then flushed, once at least ``batch_size`` are pending. Any
    remainder is written when the ``None`` sentinel arrives, after which
    the function reports the output path and returns.

    Parameters
    ----------
    output_queue : queue.Queue
        Source of lists of ``(t, x, y)`` tuples. ``None`` ends the loop.
    filename : str
        Destination CSV path.
    batch_size : int, optional
        Number of buffered rows that triggers a write and flush. Default
        is ``1000``.
    """

    def _emit(handle, rows):
        # One CSV line per hit.
        handle.writelines(f"{t},{x},{y}\n" for t, x, y in rows)

    with open(filename, "w") as out:
        out.write("t,x,y\n")
        pending = []
        for hits in iter(output_queue.get, None):
            pending.extend(hits)
            if len(pending) < batch_size:
                continue
            _emit(out, pending)
            out.flush()
            pending = []
        # Sentinel received: write the tail that never reached batch_size.
        _emit(out, pending)

    print(f"Results saved to {filename}")
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Everything below performs I/O and starts worker processes, so it must not
    # run on a mere import -- in particular under multiprocessing's "spawn" /
    # "forkserver" start methods, which re-import the main module in workers.
    print(f"Opening Zarr dataset: {args.zarr_path}")
    try:
        zg = xr.open_zarr(args.zarr_path, mode="r")
        RR = zg["RR"]
    except Exception as e:
        print(f"Error loading Zarr dataset: {e}", file=sys.stderr)
        sys.exit(1)

    # Output name mirrors the parameters parsed from the input CSV name.
    output_file = f"sampled_datacubes_{start_date}-{end_date}_{Dt}x{w}x{h}_{step_T}x{step_X}x{step_Y}_{N_nan}.csv"
    if os.path.exists(output_file):
        response = input(f"File {output_file} already exists. Overwrite? (y/n): ")
        if response.lower() != "y":
            print("Exiting without overwriting.")
            sys.exit(0)
        print(f"Overwriting {output_file}...")

    # Hits go to a single writer thread through a bounded queue, so the main
    # loop blocks on put() instead of buffering without limit if writing lags.
    output_queue = Queue(maxsize=100)
    writer_thread = Thread(target=file_writer, args=(output_queue, output_file, 1000))
    writer_thread.daemon = False
    writer_thread.start()

    try:
        # Record the run configuration next to the output for traceability.
        metadata = {
            "csv": args.csv_path,
            "zarr": args.zarr_path,
            "file": output_file,
            "start_date": start_date,
            "end_date": end_date,
            "Dt": Dt,
            "w": w,
            "h": h,
            "step_T": step_T,
            "step_X": step_X,
            "step_Y": step_Y,
            "N_nan": N_nan,
            "N_rand": N_rand,
            "n_workers": n_workers,
            "qmin": qmin,
            "m": m,
            "s": s,
            "seed": SEED,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        metadata_filename = output_file.replace(".csv", "_metadata.json")
        with open(metadata_filename, "w") as f:
            json.dump(metadata, f, indent=2)
        print(f"Saved run metadata to {metadata_filename}")

        # Bind the fixed arguments so each task only carries a coordinate row.
        process_datacube_partial = partial(
            process_datacube, RR=RR, N_rand=N_rand, seed=SEED, acceptance_probability=acceptance_probability
        )
        # Split each CSV chunk into roughly n_workers task batches.
        pool_chunksize = max(1, chunksize // n_workers)

        with Pool(n_workers) as pool, tqdm(desc="Processing CSV chunks") as pbar:
            # Stream the coordinate CSV chunk by chunk instead of loading it whole.
            for chunk in pd.read_csv(
                args.csv_path,
                usecols=["t", "x", "y"],
                dtype={"t": "int32", "x": "int32", "y": "int32"},
                engine="c",
                chunksize=chunksize,
            ):
                for hits in pool.imap(process_datacube_partial, chunk.values, chunksize=pool_chunksize):
                    if hits:
                        output_queue.put(hits)
                    # Advances once per datacube, not once per CSV chunk.
                    pbar.update(1)
    finally:
        # Always deliver the sentinel: the writer is a non-daemon thread blocked
        # on get(), so skipping this on an error would hang the interpreter.
        output_queue.put(None)
        writer_thread.join()

    print(f"Done in {time.time() - START}s.")
    sys.exit(0)
| |
|