"""
Frame-rendering helpers for episode-replay GIFs.

``render_frame`` turns one simulator state dict into an 800x800 RGB image;
``render_episode_gif`` stitches such frames into an animated GIF.
"""
from __future__ import annotations
from typing import List
import numpy as np
def render_frame(state: dict, step: int, stats: dict | None = None) -> np.ndarray:
"""
Render a ground-truth state dict into an RGB uint8 array (H_px, W_px, 3).
The figure is 8x8 inches at 100 dpi = 800x800 px.
Main panel (top 85%): grid. Bottom strip: stats bar.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyArrow
import io
grid = state["grid"]
rows = len(grid)
cols = len(grid[0]) if rows > 0 else 1
fig = plt.figure(figsize=(8, 8), dpi=100)
# Main panel
ax = fig.add_axes([0.02, 0.15, 0.96, 0.83])
# Stats strip
ax_bar = fig.add_axes([0.02, 0.01, 0.96, 0.12])
ax_bar.axis("off")
# ββ Build colour grid ββ
rgb = np.ones((rows, cols, 3))
for r in range(rows):
for c in range(cols):
cell = grid[r][c]
fs = cell["fire_state"]
intensity = cell.get("fire_intensity", 0.0)
if fs == "burning":
sat = 0.4 + 0.6 * intensity
rgb[r, c] = [1.0, 1.0 - sat * 0.8, 0.0]
elif fs == "ember":
rgb[r, c] = [0.9, 0.4, 0.0]
elif fs == "burned_out":
rgb[r, c] = [0.25, 0.22, 0.20]
elif fs == "firebreak":
rgb[r, c] = [0.55, 0.35, 0.15]
elif fs == "suppressed":
rgb[r, c] = [0.6, 0.8, 0.6]
else:
# Unburned: shade by fuel
fuel = cell.get("fuel_type", "grass")
if fuel == "water":
rgb[r, c] = [0.3, 0.5, 0.9]
elif fuel == "road":
rgb[r, c] = [0.7, 0.7, 0.7]
elif fuel == "timber":
rgb[r, c] = [0.1, 0.45, 0.1]
elif fuel == "shrub":
rgb[r, c] = [0.5, 0.7, 0.2]
elif fuel == "urban":
rgb[r, c] = [0.8, 0.75, 0.7]
else:
rgb[r, c] = [0.7, 0.85, 0.4]
ax.imshow(rgb, origin="upper", aspect="auto", interpolation="nearest")
# ββ Populated cell outlines ββ
for r in range(rows):
for c in range(cols):
if grid[r][c].get("is_populated"):
rect = mpatches.Rectangle(
(c - 0.5, r - 0.5), 1, 1,
linewidth=1.5, edgecolor="blue", facecolor="none"
)
ax.add_patch(rect)
# ββ Crew markers ββ
resources = state.get("resources", {})
for crew in resources.get("crews", []):
if not crew.get("is_deployed") or not crew.get("is_active", True):
continue
cr, cc = crew["row"], crew["col"]
ax.plot(cc, cr, "o", color="lime", markersize=7, markeredgecolor="black", markeredgewidth=0.8)
ax.text(cc, cr - 0.6, crew["crew_id"].replace("crew_", "c"),
ha="center", va="bottom", fontsize=5, color="white",
fontweight="bold")
ax.set_xlim(-0.5, cols - 0.5)
ax.set_ylim(rows - 0.5, -0.5)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(f"Step {step}", fontsize=9, pad=2)
# ββ Stats strip ββ
weather = state.get("weather", {})
wind_spd = weather.get("wind_speed_kmh", 0)
wind_dir = weather.get("wind_direction_deg", 0)
cells_burning = state.get("cells_burning", 0) if stats is None else stats.get("cells_burning", 0)
containment = state.get("containment_pct", 0) if stats is None else stats.get("containment_pct", 0)
pop_lost = state.get("population_lost", 0) if stats is None else stats.get("population_lost", 0)
# Fallback: compute from grid if not in state root
if cells_burning == 0:
cells_burning = sum(1 for r in grid for c in r if c["fire_state"] == "burning")
strip_text = (
f"Step {step} | Burning: {cells_burning} | Containment: {containment:.1f}% | "
f"Pop lost: {pop_lost} | Wind: {wind_spd:.0f} km/h"
)
ax_bar.text(0.5, 0.5, strip_text, ha="center", va="center",
fontsize=8, transform=ax_bar.transAxes,
bbox=dict(boxstyle="round,pad=0.3", facecolor="#f0f0f0", edgecolor="gray"))
# Wind arrow
import math
rad = math.radians(wind_dir)
dx, dy = math.sin(rad) * 0.08, -math.cos(rad) * 0.08
ax_bar.annotate("", xy=(0.92 + dx, 0.5 + dy), xytext=(0.92 - dx, 0.5 - dy),
xycoords="axes fraction",
arrowprops=dict(arrowstyle="->", color="darkred", lw=1.5))
# Convert figure to RGB array
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=100)
plt.close(fig)
buf.seek(0)
import imageio.v3 as iio
img = iio.imread(buf, extension=".png")
return img[:, :, :3].astype(np.uint8)
def render_episode_gif(frames: List[np.ndarray], output_path: str, fps: int = 5) -> None:
    """Stitch RGB frames into an animated GIF at the given fps.

    Args:
        frames: Frames in display order, e.g. outputs of ``render_frame``
            (presumably all the same shape — TODO confirm imageio requirement).
        output_path: Destination file path for the GIF.
        fps: Playback rate in frames per second; must be positive.

    Raises:
        ValueError: If ``frames`` is empty or ``fps`` is not positive.
    """
    # len() rather than truthiness so a stacked ndarray of frames also works.
    if len(frames) == 0:
        raise ValueError("render_episode_gif needs at least one frame")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")
    import imageio.v3 as iio

    # duration is the per-frame delay in milliseconds; loop=0 repeats forever.
    iio.imwrite(output_path, frames, extension=".gif", loop=0,
                duration=round(1000 / fps))