import os
import shutil

import imageio.v2 as imageio
import matplotlib.animation as animation
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
from adjustText import adjust_text

from .interpolation import interpolate_grid
from .basemaps import draw_etopo_basemap


class Plot_3DField_Data:
    """
    A class for visualizing 3D spatiotemporal field data (e.g., ash concentration)
    across time and altitude levels.

    This class uses matplotlib and cartopy to create:
    - Animated GIFs of spatial fields at given altitudes
    - Vertical profile animations over time
    - Exported static frames with metadata annotations and zoomed views

    Each figure has two panels: the full domain and a zoomed window centred on
    the location of the global maximum concentration. All frames of one
    animation share a single set of contour levels, so the colorbar is valid
    for every frame.

    Parameters
    ----------
    animator : object
        A container holding the dataset, including:
        - datasets: list of xarray-like objects with an 'ash_concentration'
          variable (indexed by level first) and an ``attrs`` metadata dict
        - lons, lats: 1D longitude and latitude arrays
        - lat_grid, lon_grid: 2D grid arrays for spatial mapping
        - levels: 1D array of vertical altitude levels (e.g., in km)
    output_dir : str
        Base directory for saving all outputs (relative paths are resolved
        against the current working directory). Defaults to "plots".
    cmap : str
        Matplotlib colormap name. Defaults to "rainbow".
    fps : int
        Frames per second for GIFs. Defaults to 2.
    include_metadata : bool
        Whether to annotate animated figures with simulation metadata.
        Defaults to True.
    threshold : float
        Value threshold (e.g., in g/m³) to highlight exceedances. Defaults to 0.1.
    zoom_width_deg : float
        Width of the zoomed-in region in degrees. Defaults to 6.0.
    zoom_height_deg : float
        Height of the zoomed-in region in degrees. Defaults to 6.0.
    zoom_level : int
        Zoom level passed to basemap drawing. Defaults to 7.
    basemap_type : str
        Type of basemap to draw (passed to draw_etopo_basemap). Defaults to "basemap".

    Methods
    -------
    plot_single_z_level(z_km, filename)
        Generate an animation over time at a specific altitude level.
    plot_vertical_profile_at_time(t_index, filename=None)
        Generate time-slice animations for every altitude level.
    animate_altitude(t_index, output_path)
        Animate altitude slices for one timestep.
    animate_all_altitude_profiles(output_folder='altitude_profiles')
        Generate vertical animations for all time steps.
    export_frames_as_jpgs(include_metadata=True)
        Export individual frames as static `.jpg` images with annotations.
    """

    # Values at or below this floor are dropped from log-scaled plots.
    _LOG_CUTOFF = 1e-3
    # Number of contour levels per plot.
    _N_LEVELS = 20

    def __init__(self, animator, output_dir="plots", cmap="rainbow", fps=2,
                 include_metadata=True, threshold=0.1, zoom_width_deg=6.0,
                 zoom_height_deg=6.0, zoom_level=7, basemap_type="basemap"):
        self.animator = animator
        self.output_dir = os.path.abspath(os.path.join(os.getcwd(), output_dir))
        os.makedirs(self.output_dir, exist_ok=True)
        self.cmap = cmap
        self.fps = fps
        self.include_metadata = include_metadata
        self.threshold = threshold
        self.zoom_width = zoom_width_deg
        self.zoom_height = zoom_height_deg
        self.zoom_level = zoom_level
        self.basemap_type = basemap_type

        # Load the country polygons once; they are reused for every label pass.
        countries_shp = shpreader.natural_earth(
            resolution='110m', category='cultural', name='admin_0_countries'
        )
        self.country_geoms = list(shpreader.Reader(countries_shp).records())

        # Cache full-domain extent bounds.
        self.lon_min = np.min(self.animator.lons)
        self.lon_max = np.max(self.animator.lons)
        self.lat_min = np.min(self.animator.lats)
        self.lat_max = np.max(self.animator.lats)

    # ------------------------------------------------------------------ #
    # Small utilities
    # ------------------------------------------------------------------ #
    def _make_dirs(self, path):
        """Create the parent directory of ``path`` (resolved against the CWD)."""
        path = os.path.abspath(os.path.join(os.getcwd(), os.path.dirname(path)))
        os.makedirs(path, exist_ok=True)

    def _get_zoom_indices(self, center_lat, center_lon):
        """Return grid indices and bounds of the zoom window around a centre.

        Returns
        -------
        tuple
            ``(lat_idx, lon_idx, lon_min, lon_max, lat_min, lat_max)``. The
            index arrays select the entries of ``animator.lats`` and
            ``animator.lons`` that fall inside the window.
        """
        lon_min = center_lon - self.zoom_width / 2
        lon_max = center_lon + self.zoom_width / 2
        lat_min = center_lat - self.zoom_height / 2
        lat_max = center_lat + self.zoom_height / 2
        lat_idx = np.where((self.animator.lats >= lat_min) & (self.animator.lats <= lat_max))[0]
        lon_idx = np.where((self.animator.lons >= lon_min) & (self.animator.lons <= lon_max))[0]
        return lat_idx, lon_idx, lon_min, lon_max, lat_min, lat_max

    def _get_max_concentration_location(self):
        """Return ``(lat, lon)`` of the largest concentration over all times and levels.

        NaN entries are ignored. Returns ``(None, None)`` when no dataset holds
        any non-NaN value.
        """
        max_conc = -np.inf
        center_lat = center_lon = None
        for ds in self.animator.datasets:
            values = ds['ash_concentration'].values
            for z in range(len(self.animator.levels)):
                data = values[z]
                # Skip all-NaN slices: nanmax/nanargmax would raise on them.
                if not np.isfinite(data).any():
                    continue
                local_max = np.nanmax(data)
                if local_max > max_conc:
                    max_conc = local_max
                    max_idx = np.unravel_index(np.nanargmax(data), data.shape)
                    center_lat = self.animator.lat_grid[max_idx]
                    center_lon = self.animator.lon_grid[max_idx]
        return center_lat, center_lon

    def _prepare_field(self, data):
        """Interpolate a raw 2D slice onto the plotting grid and mask negatives as NaN."""
        interp = interpolate_grid(data, self.animator.lon_grid, self.animator.lat_grid)
        return np.where(interp < 0, np.nan, interp)

    def _compute_levels(self, min_val, max_val):
        """Choose contour levels for a value range.

        Log spacing is used only when every value exceeds ``_LOG_CUTOFF`` and
        the range spans more than two orders of magnitude.

        Returns
        -------
        tuple
            ``(levels, use_log)``, or ``(None, False)`` when ``max_val`` is not
            positive. Such a field has no strictly increasing levels, which
            ``contourf`` requires.
        """
        if not max_val > 0:
            return None, False
        use_log = min_val > self._LOG_CUTOFF and (max_val / (min_val + 1e-6)) > 100
        if use_log:
            levels = np.logspace(np.log10(self._LOG_CUTOFF), np.log10(max_val), self._N_LEVELS)
        else:
            levels = np.linspace(0, max_val, self._N_LEVELS)
        return levels, use_log

    @staticmethod
    def _add_colorbar(fig, mappable, axes, **kwargs):
        """Attach a vertical concentration colorbar spanning ``axes``."""
        cbar = fig.colorbar(mappable, ax=axes, orientation='vertical', **kwargs)
        cbar.set_label("Ash concentration (g/m³)")
        cbar.ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:.2g}'))
        return cbar

    # ------------------------------------------------------------------ #
    # Drawing helpers
    # ------------------------------------------------------------------ #
    def _add_country_labels(self, ax, extent):
        """Label countries whose centroid lies inside ``extent`` [lon0, lon1, lat0, lat1]."""
        proj = ccrs.PlateCarree()
        texts = []
        for country in self.country_geoms:
            try:
                name = country.attributes['NAME_LONG']
                centroid = country.geometry.centroid
                lon, lat = centroid.x, centroid.y
            except Exception:
                # Best effort: skip records with missing names or degenerate geometry.
                continue
            if extent[0] <= lon <= extent[1] and extent[2] <= lat <= extent[3]:
                texts.append(ax.text(lon, lat, name, fontsize=6, transform=proj,
                                     ha='center', va='center', color='white',
                                     bbox=dict(facecolor='black', alpha=0.5, linewidth=0)))
        if texts:
            adjust_text(texts, ax=ax, only_move={'points': 'y', 'text': 'y'},
                        arrowprops=dict(arrowstyle="->", color='white', lw=0.5))

    def _plot_frame(self, ax, data, lons, lats, title, levels, scale_label, proj):
        """Draw basemap, filled and line contours of ``data`` on ``ax``; return the ContourSet.

        When ``scale_label`` is "Log", a LogNorm maps the log-spaced levels
        evenly onto the colormap. A linear norm would squeeze most of them
        into the lowest colours.
        """
        draw_etopo_basemap(ax, mode=self.basemap_type, zoom=self.zoom_level)
        norm = mcolors.LogNorm(vmin=levels[0], vmax=levels[-1]) if scale_label == "Log" else None
        c = ax.contourf(lons, lats, data, levels=levels, cmap=self.cmap, norm=norm,
                        alpha=0.6, transform=proj)
        ax.contour(lons, lats, data, levels=levels, colors='black', linewidths=0.5, transform=proj)
        ax.set_title(title)
        ax.set_extent([lons.min(), lons.max(), lats.min(), lats.max()])
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS, linestyle=':')
        ax.add_feature(cfeature.LAND)
        ax.add_feature(cfeature.OCEAN)
        return c

    def _draw_panels(self, ax_full, ax_zoom, interp, levels, use_log, title_prefix,
                     zoom, proj, warn_prefix):
        """Draw the full-domain and zoomed panels for one field.

        Parameters
        ----------
        interp : ndarray
            Field on the (lats, lons) grid, with negatives already masked to NaN.
        zoom : tuple
            Output of ``_get_zoom_indices``.
        warn_prefix : str
            Leading text of the threshold-exceedance annotation.

        Returns
        -------
        The ContourSet of the full-domain panel, for use by a colorbar.
        """
        lat_idx, lon_idx, lon_min, lon_max, lat_min, lat_max = zoom
        scale_label = "Log" if use_log else "Linear"
        if use_log:
            data_for_plot = np.where(interp > self._LOG_CUTOFF, interp, np.nan)
        else:
            data_for_plot = interp

        c = self._plot_frame(ax_full, data_for_plot, self.animator.lons, self.animator.lats,
                             f"{title_prefix} (Full - {scale_label})", levels, scale_label, proj)
        self._add_country_labels(ax_full, [self.lon_min, self.lon_max, self.lat_min, self.lat_max])

        # contourf needs at least a 2x2 grid; a window over a coarse grid may be smaller.
        has_zoom = len(lat_idx) >= 2 and len(lon_idx) >= 2
        lon_zoom = self.animator.lons[lon_idx]
        lat_zoom = self.animator.lats[lat_idx]
        if has_zoom:
            self._plot_frame(ax_zoom, data_for_plot[np.ix_(lat_idx, lon_idx)], lon_zoom, lat_zoom,
                             f"{title_prefix} (Zoom - {scale_label})", levels, scale_label, proj)
            self._add_country_labels(ax_zoom, [lon_min, lon_max, lat_min, lat_max])
        else:
            ax_zoom.set_title(f"{title_prefix} (Zoom - no grid points in window)")

        finite = interp[np.isfinite(interp)]
        peak = finite.max() if finite.size else np.nan
        # Outline and annotate only when the threshold is actually exceeded.
        if peak > self.threshold:
            ax_full.contour(self.animator.lons, self.animator.lats, interp,
                            levels=[self.threshold], colors='red', linewidths=2, transform=proj)
            if has_zoom:
                ax_zoom.contour(lon_zoom, lat_zoom, interp[np.ix_(lat_idx, lon_idx)],
                                levels=[self.threshold], colors='red', linewidths=2,
                                transform=proj)
            ax_zoom.text(0.99, 0.01, f"{warn_prefix}: {peak:.2f} > {self.threshold} g/m³",
                         transform=ax_zoom.transAxes, ha='right', va='bottom', fontsize=9,
                         color='red', bbox=dict(facecolor='white', alpha=0.8, edgecolor='red'))
        return c

    def _draw_metadata_sidebar(self, fig, meta_dict):
        """Place a monospace box with run metadata near the bottom-left of ``fig``."""
        lines = [
            f"Run name: {meta_dict.get('run_name', 'N/A')}",
            f"Run time: {meta_dict.get('run_time', 'N/A')}",
            f"Met data: {meta_dict.get('met_data', 'N/A')}",
            f"Start release: {meta_dict.get('start_of_release', 'N/A')}",
            f"End release: {meta_dict.get('end_of_release', 'N/A')}",
            f"Source strength: {meta_dict.get('source_strength', 'N/A')} g/s",
            f"Release loc: {meta_dict.get('release_location', 'N/A')}",
            f"Release height: {meta_dict.get('release_height', 'N/A')} m asl",
            f"Run duration: {meta_dict.get('run_duration', 'N/A')}",
        ]
        fig.text(0.1, 0.095, "\n".join(lines), va='center', ha='left', fontsize=9,
                 family='monospace', color='black',
                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'))

    def _save_field_animation(self, frames, output_path, center_lat, center_lon, figsize,
                              meta=None, colorbar_kwargs=None, layout_rect=None):
        """Render ``(title_prefix, field)`` frames as a two-panel GIF.

        One set of contour levels is computed from all frames, so the single
        colorbar is correct for every frame. The metadata box, when ``meta``
        is given, is drawn once rather than per frame.

        Returns
        -------
        bool
            True if the GIF was written; False if no frame has a positive
            value to contour.
        """
        global_min = min(np.min(f[np.isfinite(f)]) for _, f in frames)
        global_max = max(np.max(f[np.isfinite(f)]) for _, f in frames)
        levels, use_log = self._compute_levels(global_min, global_max)
        if levels is None:
            return False

        zoom = self._get_zoom_indices(center_lat, center_lon)
        colorbar_kwargs = colorbar_kwargs or {}
        proj = ccrs.PlateCarree()
        fig = plt.figure(figsize=figsize)
        try:
            ax1 = fig.add_subplot(1, 2, 1, projection=proj)
            ax2 = fig.add_subplot(1, 2, 2, projection=proj)
            if meta is not None:
                self._draw_metadata_sidebar(fig, meta)
            colorbar = []  # holds the colorbar once the first frame has been drawn

            def update(frame_idx):
                title_prefix, interp = frames[frame_idx]
                ax1.clear()
                ax2.clear()
                c = self._draw_panels(ax1, ax2, interp, levels, use_log, title_prefix, zoom,
                                      proj, "⚠ Max Threshold Exceeded")
                if not colorbar:
                    colorbar.append(self._add_colorbar(fig, c, [ax1, ax2], **colorbar_kwargs))
                return []

            self._make_dirs(output_path)
            if layout_rect is None:
                fig.tight_layout()
            else:
                fig.tight_layout(rect=layout_rect)
            ani = animation.FuncAnimation(fig, update, frames=range(len(frames)),
                                          blit=False, cache_frame_data=False)
            ani.save(output_path, writer='pillow', fps=self.fps, dpi=300)
        finally:
            plt.close(fig)
        return True

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def plot_single_z_level(self, z_km, filename="z_level.gif"):
        """Animate the field over time at altitude ``z_km``.

        The GIF is written to ``<output_dir>/z_levels/<filename>``. ``filename``
        may contain sub-directories. Altitude matching uses ``np.isclose``.
        """
        matches = np.flatnonzero(np.isclose(np.asarray(self.animator.levels, dtype=float), z_km))
        if matches.size == 0:
            print(f"Z level {z_km} km not found.")
            return
        z_index = matches[0]
        output_path = os.path.join(self.output_dir, "z_levels", filename)

        # Interpolate each time step once; the cached grids are reused when rendering.
        frames = []
        for t, ds in enumerate(self.animator.datasets):
            interp = self._prepare_field(ds['ash_concentration'].values[z_index])
            if np.isfinite(interp).any():
                frames.append((f"T{t + 1} | Alt: {z_km} km", interp))
        center_lat, center_lon = self._get_max_concentration_location()
        if not frames or center_lat is None or center_lon is None:
            print(f"No valid frames for Z={z_km} km.")
            return

        meta = self.animator.datasets[0].attrs if self.include_metadata else None
        saved = self._save_field_animation(frames, output_path, center_lat, center_lon,
                                           figsize=(16, 8), meta=meta)
        if saved:
            print(f"✅ Saved Z-level animation to {output_path}")
        else:
            print(f"No valid frames for Z={z_km} km.")

    def plot_vertical_profile_at_time(self, t_index, filename=None):
        """Generate a time-slice animation for every altitude level.

        Each level's GIF is written to
        ``<output_dir>/z_levels/vertical_profiles_timeSlice/TimeSlices_Z<z>km.gif``.

        NOTE(review): ``t_index`` and ``filename`` are currently unused. The
        method animates over all time steps at each level rather than producing
        a profile for one timestep; ``animate_altitude`` does the latter.
        Behaviour is kept as-is because callers may rely on the files it produces.
        """
        for z_val in self.animator.levels:
            gif_name = f"TimeSlices_Z{z_val:.1f}km.gif"
            self.plot_single_z_level(
                z_val, filename=os.path.join("vertical_profiles_timeSlice", gif_name)
            )

    def animate_altitude(self, t_index: int, output_path: str):
        """Animate altitude slices of time step ``t_index`` and save the GIF to ``output_path``."""
        n_times = len(self.animator.datasets)
        if not 0 <= t_index < n_times:
            print(f"Invalid time index {t_index}. Must be between 0 and {n_times - 1}.")
            return
        ds = self.animator.datasets[t_index]

        center_lat, center_lon = self._get_max_concentration_location()
        if center_lat is None or center_lon is None:
            print(f"No valid data found for time T{t_index + 1}. Skipping...")
            return

        values = ds['ash_concentration'].values
        frames = []
        for z_index, z_val in enumerate(self.animator.levels):
            interp = self._prepare_field(values[z_index])
            if np.isfinite(interp).any():
                frames.append((f"T{t_index + 1} | Alt: {z_val} km", interp))
        if not frames:
            print(f"No valid Z-levels at time T{t_index + 1}.")
            return

        meta = ds.attrs if self.include_metadata else None
        saved = self._save_field_animation(frames, output_path, center_lat, center_lon,
                                           figsize=(18, 7), meta=meta,
                                           colorbar_kwargs={'shrink': 0.75},
                                           layout_rect=[0.02, 0.02, 0.98, 0.98])
        if saved:
            print(f"✅ Saved vertical profile animation for T{t_index + 1} to {output_path}")
        else:
            print(f"No valid Z-levels at time T{t_index + 1}.")

    def animate_all_altitude_profiles(self, output_folder='altitude_profiles'):
        """Run ``animate_altitude`` for every time step.

        Output goes to ``<output_dir>/<output_folder>``. An absolute
        ``output_folder`` is used as given.
        """
        output_folder = os.path.join(self.output_dir, output_folder)
        os.makedirs(output_folder, exist_ok=True)
        for t_index in range(len(self.animator.datasets)):
            output_path = os.path.join(output_folder, f"vertical_T{t_index + 1:02d}.gif")
            print(f"🔄 Generating vertical profile animation for T{t_index + 1}...")
            self.animate_altitude(t_index, output_path)

    def export_frames_as_jpgs(self, include_metadata: bool = True):
        """Save every (level, time) field as a static two-panel JPEG.

        Frames go to ``<output_dir>/frames/Z<z>km/frame_<t>.jpg``. Each frame
        uses its own contour levels and colorbar. Fields with no positive
        values are skipped.
        """
        output_folder = os.path.join(self.output_dir, "frames")
        os.makedirs(output_folder, exist_ok=True)

        # The zoom centre is the same for every frame; compute it once, outside the loops.
        center_lat, center_lon = self._get_max_concentration_location()
        if center_lat is None or center_lon is None:
            print("No valid data found in any dataset. Nothing to export.")
            return
        zoom = self._get_zoom_indices(center_lat, center_lon)
        meta = self.animator.datasets[0].attrs
        proj = ccrs.PlateCarree()

        for z_index, z_val in enumerate(self.animator.levels):
            z_dir = os.path.join(output_folder, f"Z{z_val:.1f}km")
            os.makedirs(z_dir, exist_ok=True)
            for t, ds in enumerate(self.animator.datasets):
                interp = self._prepare_field(ds['ash_concentration'].values[z_index])
                valid_vals = interp[np.isfinite(interp)]
                if valid_vals.size == 0:
                    continue
                levels, use_log = self._compute_levels(valid_vals.min(), valid_vals.max())
                if levels is None:
                    continue

                frame_path = os.path.join(z_dir, f"frame_{t + 1:04d}.jpg")
                fig = plt.figure(figsize=(16, 8))
                try:
                    ax1 = fig.add_subplot(1, 2, 1, projection=proj)
                    ax2 = fig.add_subplot(1, 2, 2, projection=proj)
                    c1 = self._draw_panels(ax1, ax2, interp, levels, use_log,
                                           f"T{t + 1} | Alt: {z_val} km", zoom, proj, "⚠ Max")
                    if include_metadata:
                        self._draw_metadata_sidebar(fig, meta)
                    self._add_colorbar(fig, c1, [ax1, ax2], shrink=0.75, pad=0.03)
                    fig.savefig(frame_path, dpi=150, bbox_inches='tight')
                finally:
                    plt.close(fig)
                print(f"📸 Saved {frame_path}")