|
|
"""Animation functionality for creating MP4 videos from multi-dimensional data.""" |
|
|
|
|
|
import os |
|
|
import tempfile |
|
|
import subprocess |
|
|
from typing import Optional, Callable, List |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib.animation import FuncAnimation |
|
|
import xarray as xr |
|
|
|
|
|
from .plot import plot_1d, plot_2d, plot_map, setup_matplotlib |
|
|
from .utils import identify_coordinates, format_value |
|
|
|
|
|
|
|
|
def check_ffmpeg(executable: str = 'ffmpeg', timeout: float = 10.0) -> bool:
    """Check if FFmpeg is available.

    Runs ``<executable> -version`` and reports whether it succeeded.

    Args:
        executable: Name or path of the FFmpeg binary to probe. Defaults to
            ``'ffmpeg'`` (resolved via ``PATH``).
        timeout: Seconds to wait for the probe before treating FFmpeg as
            unavailable.

    Returns:
        True if the command ran and exited with status 0, False otherwise.
    """
    try:
        subprocess.run([executable, '-version'], capture_output=True,
                       check=True, timeout=timeout)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # OSError covers FileNotFoundError (binary missing) as well as
        # PermissionError (present but not executable) and similar failures.
        return False
    return True
|
|
|
|
|
|
|
|
def animate_over_dim(da: xr.DataArray, dim: str, plot_func: Callable = None,
                     fps: int = 10, out: str = "animation.mp4",
                     figsize: tuple = (10, 8), **plot_kwargs) -> str:
    """
    Create an MP4 animation over a specified dimension.

    Frame ``i`` shows ``da.isel({dim: i})``. The y-range (1D) or colour
    range (2D/map) is held fixed across frames so they are comparable.

    Args:
        da: Input DataArray
        dim: Dimension to animate over
        plot_func: ``plot_1d``, ``plot_2d`` or ``plot_map``. It only selects
            the rendering style used here; it is not called. Auto-detected
            from the number of non-animated dimensions if None.
        fps: Frames per second; must be positive
        out: Output file path
        figsize: Figure size
        **plot_kwargs: Additional plotting parameters. This function reads
            ``vmin``/``vmax`` (default: global min/max of ``da``), ``cmap``
            (2D/map, default ``'viridis'``) and, for maps, ``proj``
            (``'PlateCarree'``, ``'Robinson'`` or ``'Mollweide'``; unknown
            names fall back to PlateCarree), ``coastlines`` and
            ``gridlines`` (both default True).

    Returns:
        Path to the created animation file

    Raises:
        RuntimeError: If FFmpeg is unavailable or writing the video fails.
        ValueError: If ``dim`` is missing, there are fewer than 2 frames,
            ``fps`` is not positive, or the plot type is unsupported.
    """
    if not check_ffmpeg():
        raise RuntimeError("FFmpeg is required for creating MP4 animations")

    if dim not in da.dims:
        raise ValueError(f"Dimension '{dim}' not found in DataArray")

    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    setup_matplotlib()

    # da[dim] (unlike da.coords[dim]) also works when ``dim`` has no explicit
    # coordinate: xarray then supplies a default integer index.
    coord_vals = da[dim].values
    n_frames = len(coord_vals)

    if n_frames < 2:
        raise ValueError(f"Need at least 2 frames for animation, got {n_frames}")

    if plot_func is None:
        remaining_dims = [d for d in da.dims if d != dim]
        n_remaining = len(remaining_dims)

        coords = identify_coordinates(da)
        has_geo = 'X' in coords and 'Y' in coords

        if n_remaining == 1:
            plot_func = plot_1d
        elif n_remaining == 2 and has_geo:
            plot_func = plot_map
        elif n_remaining == 2:
            plot_func = plot_2d
        else:
            raise ValueError(f"Cannot auto-detect plot type for {n_remaining}D data")
    elif plot_func not in (plot_1d, plot_2d, plot_map):
        # Only these three have a frame renderer below; anything else would
        # otherwise leave ``animate`` undefined.
        raise ValueError(
            "animate_over_dim supports plot_1d, plot_2d and plot_map, got "
            f"{getattr(plot_func, '__name__', plot_func)!r}"
        )

    # Fix the value range across all frames so axes/colours don't jump.
    if 'vmin' not in plot_kwargs:
        plot_kwargs['vmin'] = float(da.min().values)
    if 'vmax' not in plot_kwargs:
        plot_kwargs['vmax'] = float(da.max().values)
    vmin = plot_kwargs['vmin']
    vmax = plot_kwargs['vmax']

    # Every frame has the same non-animated coordinates as the first one.
    initial_frame = da.isel({dim: 0})
    base_title = da.attrs.get('long_name', da.name or 'Data')

    def frame_title(frame_idx):
        """Build the per-frame title from the animated coordinate value."""
        coord_str = format_value(coord_vals[frame_idx], dim)
        return f"{base_title} - {dim}={coord_str}"

    if plot_func is plot_map:
        # Cartopy is only needed for map output, so import it lazily.
        import cartopy.crs as ccrs

        proj_map = {
            'PlateCarree': ccrs.PlateCarree,
            'Robinson': ccrs.Robinson,
            'Mollweide': ccrs.Mollweide,
        }
        projection = proj_map.get(plot_kwargs.get('proj', 'PlateCarree'),
                                  ccrs.PlateCarree)()
        # The axes must be a cartopy GeoAxes: pcolormesh(transform=<CRS>),
        # coastlines(), gridlines() and set_global() are not available on a
        # plain matplotlib Axes.
        fig, ax = plt.subplots(figsize=figsize,
                               subplot_kw={'projection': projection})
    else:
        fig, ax = plt.subplots(figsize=figsize)

    try:
        if plot_func is plot_1d:
            x_dim = initial_frame.dims[0]
            x_coord = initial_frame[x_dim]
            x_values = x_coord.values

            line, = ax.plot([], [])
            ax.set_xlim(float(x_coord.min()), float(x_coord.max()))
            ax.set_ylim(vmin, vmax)
            ax.set_xlabel(f"{x_dim} ({x_coord.attrs.get('units', '')})")
            ax.set_ylabel(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")

            def animate(frame_idx):
                # Only the line data and the title change between frames.
                line.set_data(x_values, da.isel({dim: frame_idx}).values)
                ax.set_title(frame_title(frame_idx))
                return line,

        elif plot_func is plot_map:
            coords = identify_coordinates(initial_frame)
            lon_dim = coords['X']
            lat_dim = coords['Y']
            lons = initial_frame[lon_dim].values
            lats = initial_frame[lat_dim].values
            cmap = plot_kwargs.get('cmap', 'viridis')
            # Input grid values are interpreted as plain lon/lat.
            data_crs = ccrs.PlateCarree()

            def animate(frame_idx):
                # Redraw from scratch; coastlines/gridlines are re-added below.
                ax.clear()
                frame_data = da.isel({dim: frame_idx})
                im = ax.pcolormesh(lons, lats,
                                   frame_data.transpose(lat_dim, lon_dim).values,
                                   cmap=cmap, vmin=vmin, vmax=vmax,
                                   transform=data_crs, shading='auto')
                if plot_kwargs.get('coastlines', True):
                    ax.coastlines(resolution='50m', color='black', linewidth=0.5)
                if plot_kwargs.get('gridlines', True):
                    ax.gridlines(alpha=0.5)
                ax.set_global()
                ax.set_title(frame_title(frame_idx))
                return [im]

        else:  # plot_2d
            coords = identify_coordinates(initial_frame)
            x_dim = coords.get('X', initial_frame.dims[-1])
            y_dim = coords.get('Y', initial_frame.dims[-2])
            x_coord = initial_frame[x_dim]
            y_coord = initial_frame[y_dim]
            extent = [float(x_coord.min()), float(x_coord.max()),
                      float(y_coord.min()), float(y_coord.max())]
            x_label = f"{x_dim} ({x_coord.attrs.get('units', '')})"
            y_label = f"{y_dim} ({y_coord.attrs.get('units', '')})"
            cmap = plot_kwargs.get('cmap', 'viridis')

            def animate(frame_idx):
                ax.clear()
                frame_data = da.isel({dim: frame_idx})
                im = ax.imshow(frame_data.transpose(y_dim, x_dim).values,
                               extent=extent, aspect='auto', origin='lower',
                               cmap=cmap, vmin=vmin, vmax=vmax)
                ax.set_xlabel(x_label)
                ax.set_ylabel(y_label)
                ax.set_title(frame_title(frame_idx))
                return [im]

        anim = FuncAnimation(fig, animate, frames=n_frames,
                             interval=1000 // fps, blit=False)

        try:
            from matplotlib.animation import FFMpegWriter

            writer = FFMpegWriter(fps=fps, metadata=dict(artist='TensorView'),
                                  bitrate=1800)
            anim.save(out, writer=writer)
        except Exception as e:
            raise RuntimeError(f"Failed to create animation: {str(e)}") from e
    finally:
        # Always release the figure, including on setup/rendering errors.
        plt.close(fig)

    return out
|
|
|
|
|
|
|
|
def create_frame_sequence(da: xr.DataArray, dim: str, plot_func: Callable = None,
                          output_dir: str = "frames", **plot_kwargs) -> List[str]:
    """
    Create a sequence of individual frame images, one per step along ``dim``.

    Frames are written as ``frame_<index>.png`` in ``output_dir``. The index
    is zero-padded to at least 4 digits and to a common width for all frames,
    so lexical filename order (as used by glob) matches frame order.

    Args:
        da: Input DataArray
        dim: Dimension to step over
        plot_func: Called as ``plot_func(frame, **plot_kwargs)``; its return
            value is used as a matplotlib Figure. Auto-detected
            (``plot_1d``/``plot_2d``/``plot_map``) if None.
        output_dir: Directory to save frames (created if missing)
        **plot_kwargs: Additional plotting parameters. ``vmin``/``vmax``
            default to the global min/max of ``da`` so all frames share one
            value range.

    Returns:
        List of frame file paths, in frame order

    Raises:
        ValueError: If ``dim`` is missing or the plot type can't be detected.
    """
    if dim not in da.dims:
        raise ValueError(f"Dimension '{dim}' not found in DataArray")

    os.makedirs(output_dir, exist_ok=True)

    # da[dim] (unlike da.coords[dim]) also works for dimensions without an
    # explicit coordinate; xarray supplies a default integer index.
    coord_vals = da[dim].values

    if plot_func is None:
        remaining_dims = [d for d in da.dims if d != dim]
        n_remaining = len(remaining_dims)

        coords = identify_coordinates(da)
        has_geo = 'X' in coords and 'Y' in coords

        if n_remaining == 1:
            plot_func = plot_1d
        elif n_remaining == 2 and has_geo:
            plot_func = plot_map
        elif n_remaining == 2:
            plot_func = plot_2d
        else:
            raise ValueError(f"Cannot auto-detect plot type for {n_remaining}D data")

    # Shared value range so colours/axes are comparable between frames.
    if 'vmin' not in plot_kwargs:
        plot_kwargs['vmin'] = float(da.min().values)
    if 'vmax' not in plot_kwargs:
        plot_kwargs['vmax'] = float(da.max().values)

    title_prefix = da.attrs.get('long_name', da.name or 'Data')
    # A fixed 4-digit pad sorts frame_10000 before frame_1001; use a common
    # width (never less than the historical 4) instead.
    pad = max(4, len(str(len(coord_vals) - 1)))
    frame_paths = []

    for i, coord_val in enumerate(coord_vals):
        fig = plot_func(da.isel({dim: i}), **plot_kwargs)
        try:
            coord_str = format_value(coord_val, dim)
            fig.suptitle(f"{title_prefix} - {dim}={coord_str}")

            frame_path = os.path.join(output_dir, f"frame_{i:0{pad}d}.png")
            fig.savefig(frame_path, dpi=150, bbox_inches='tight')
        finally:
            # Close even when titling/saving fails so figures don't pile up.
            plt.close(fig)
        frame_paths.append(frame_path)

    return frame_paths
|
|
|
|
|
|
|
|
def frames_to_mp4(frame_dir: str, output_path: str, fps: int = 10, cleanup: bool = True) -> str:
    """
    Convert a directory of frame images to MP4 video.

    Every ``frame_*.png`` in ``frame_dir`` is encoded with libx264
    (yuv420p, CRF 18) using FFmpeg's glob pattern input. Frames are taken in
    lexical filename order.

    Args:
        frame_dir: Directory containing frame images
        output_path: Output MP4 file path (overwritten if it exists)
        fps: Frames per second
        cleanup: Whether to delete the frame files after a successful
            conversion, then remove ``frame_dir`` if it is left empty

    Returns:
        Path to created MP4 file

    Raises:
        RuntimeError: If FFmpeg is unavailable, no frames are found, or
            FFmpeg exits with an error.
    """
    import glob

    if not check_ffmpeg():
        raise RuntimeError("FFmpeg is required for MP4 conversion")

    # Escape the directory so characters like '[' in the path are matched
    # literally when Python looks for frames.
    frame_files = glob.glob(os.path.join(glob.escape(frame_dir), 'frame_*.png'))
    if not frame_files:
        raise RuntimeError(f"No frame_*.png files found in {frame_dir!r}")

    # NOTE: '-pattern_type glob' depends on the FFmpeg build providing glob
    # support and may be unavailable on some platforms.
    cmd = [
        'ffmpeg', '-y',
        '-framerate', str(fps),
        '-pattern_type', 'glob',
        '-i', os.path.join(frame_dir, 'frame_*.png'),
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        '-crf', '18',
        output_path,
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # FFmpeg output is not guaranteed to be valid UTF-8.
        stderr = e.stderr.decode(errors='replace') if e.stderr else ''
        raise RuntimeError(f"FFmpeg failed: {stderr}") from e

    if cleanup:
        for frame_file in frame_files:
            os.remove(frame_file)
        try:
            os.rmdir(frame_dir)
        except OSError:
            # Best effort: the directory may still hold other files.
            pass

    return output_path
|
|
|
|
|
|
|
|
def create_gif(da: xr.DataArray, dim: str, output_path: str = "animation.gif",
               duration: int = 200, plot_func: Callable = None, **plot_kwargs) -> str:
    """
    Create an animated GIF by stepping through ``dim``.

    Frames are rendered to PNGs in a temporary directory (removed afterwards)
    and combined into a looping GIF with Pillow.

    Args:
        da: Input DataArray
        dim: Dimension to animate over
        output_path: Output GIF file path
        duration: Duration per frame in milliseconds
        plot_func: Plotting function to use (auto-detected if None)
        **plot_kwargs: Additional plotting parameters

    Returns:
        Path to created GIF file

    Raises:
        ImportError: If Pillow is not installed.
        ValueError: If ``dim`` is missing, the plot type can't be detected,
            or ``dim`` has no entries (no frames to write).
    """
    try:
        from PIL import Image
    except ImportError as e:
        raise ImportError("Pillow is required for GIF creation") from e

    with tempfile.TemporaryDirectory() as temp_dir:
        frame_paths = create_frame_sequence(da, dim, plot_func, temp_dir, **plot_kwargs)
        if not frame_paths:
            raise ValueError(f"Dimension '{dim}' has no frames to animate")

        images = []
        for frame_path in frame_paths:
            # Load each frame fully into memory and close the file. Open
            # handles can block removal of the temporary directory (e.g. on
            # Windows).
            with Image.open(frame_path) as img:
                images.append(img.copy())

        # loop=0 makes the GIF repeat indefinitely.
        images[0].save(output_path, save_all=True, append_images=images[1:],
                       duration=duration, loop=0)

    return output_path