# Page header captured together with this file (not part of the program):
#   Hugging Face Space status: Sleeping
#   File size: 10,876 bytes
#   Revisions touching this file: cfea739, c110c9c, 868635f, 1265936, 4f0125c
# plot_generator.py
# Generate air pollution maps for India using GeoPandas for the map outline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend for web apps
import geopandas as gpd
from pathlib import Path
from datetime import datetime
from constants import INDIA_BOUNDS, COLOR_THEMES
import warnings
warnings.filterwarnings('ignore')
class IndiaMapPlotter:
    """Render static air-pollution maps over India with a shapefile outline on top.

    The boundary shapefile is read once in ``__init__`` and drawn over the
    gridded data for every map produced by ``create_india_map``.
    """

    # Single-pass character substitutions used to build file-system-safe names.
    _NAME_TABLE = str.maketrans({'/': '_', ' ': '_', '₂': '2', '₃': '3', '.': '_'})
    _TIME_TABLE = str.maketrans({'-': None, ':': None, ' ': '_'})

    def __init__(self, plots_dir="plots", shapefile_path="shapefiles/India_State_Boundary.shp"):
        """
        Initialize the map plotter

        Parameters:
            plots_dir (str): Directory to save plots (created, including
                missing parent directories, if it does not exist)
            shapefile_path (str): Path to India boundary shapefile

        Raises:
            FileNotFoundError: If reading or reprojecting the shapefile fails
                for any reason; the underlying error is chained as the cause.
        """
        self.plots_dir = Path(plots_dir)
        # parents=True so a nested output path (e.g. "out/plots") also works.
        self.plots_dir.mkdir(parents=True, exist_ok=True)

        try:
            self.india_map = gpd.read_file(shapefile_path)
            # Ensure it's in lat/lon (WGS84); a missing CRS is left untouched.
            if self.india_map.crs is not None and self.india_map.crs.to_epsg() != 4326:
                self.india_map = self.india_map.to_crs(epsg=4326)
        except Exception as e:
            raise FileNotFoundError(f"Could not read the shapefile at '{shapefile_path}'. "
                                    f"Please ensure the file exists. Error: {e}") from e

        # Process-wide Matplotlib defaults for high-resolution, tightly cropped output.
        plt.rcParams['figure.dpi'] = 300
        plt.rcParams['savefig.dpi'] = 300
        plt.rcParams['savefig.bbox'] = 'tight'
        plt.rcParams['font.size'] = 10

    @staticmethod
    def _latitudes_ascending(lats):
        """Return True when latitude increases with row index (or cannot be determined).

        Accepts a 1-D latitude axis or an N-D grid; for a grid the first column
        is inspected, which assumes latitude varies along axis 0 (as produced by
        ``np.meshgrid(lons, lats)``) — TODO confirm for non-meshgrid inputs.
        """
        lat_axis = np.asarray(lats)
        if lat_axis.ndim == 0:
            return True
        if lat_axis.ndim > 1:
            lat_axis = lat_axis.reshape(lat_axis.shape[0], -1)[:, 0]
        if lat_axis.size < 2:
            return True
        return bool(lat_axis[0] < lat_axis[-1])

    @staticmethod
    def _compose_title(display_name, pressure_level, time_stamp, custom_title=None):
        """Build the plot title; optional parts are omitted when their value is falsy."""
        if custom_title:
            return custom_title
        title = f'{display_name} Concentration over India (Static)'
        if pressure_level:
            title += f' at {pressure_level} hPa'
        # Skip the suffix instead of rendering "on None" when no timestamp was given.
        if time_stamp:
            title += f' on {time_stamp}'
        return title

    def create_india_map(self, data_values, metadata, color_theme=None, save_plot=True, custom_title=None):
        """
        Create air pollution map over India

        Parameters:
            data_values: Array of values to display with ``imshow``; NaNs are
                excluded from the statistics and the colour scaling.
                (presumably 2-D with shape (n_lat, n_lon) — TODO confirm)
            metadata (dict): Requires 'lats', 'lons', 'variable_name',
                'display_name' and 'units'; 'pressure_level' and
                'timestamp_str' are optional.
            color_theme (str | None): Colormap name. When None, the 'cmap'
                entry for the variable in AIR_POLLUTION_VARIABLES is used,
                falling back to 'viridis'.
            save_plot (bool): Save the figure through ``_save_plot`` when True.
            custom_title (str | None): Title overriding the generated one.

        Returns:
            str | None: Path of the saved plot, or None when save_plot is False.

        Raises:
            Exception: Any failure is re-raised as
                ``Exception("Error creating map: ...")`` with the original
                error chained as the cause.
        """
        fig = None
        try:
            lats = np.asarray(metadata['lats'])
            lons = np.asarray(metadata['lons'])
            var_name = metadata['variable_name']
            display_name = metadata['display_name']
            units = metadata['units']
            pressure_level = metadata.get('pressure_level')
            time_stamp = metadata.get('timestamp_str')

            # Imported at call time, as in the original code.
            from constants import AIR_POLLUTION_VARIABLES
            var_settings = AIR_POLLUTION_VARIABLES.get(var_name, {})

            if color_theme is None:
                color_theme = var_settings.get('cmap', 'viridis')
            if color_theme not in COLOR_THEMES:
                print(f"Warning: Color theme '{color_theme}' not found, using 'viridis'")
                color_theme = 'viridis'

            # Validate the data before allocating a figure.
            valid_data = data_values[~np.isnan(data_values)]
            if len(valid_data) == 0:
                raise ValueError("All data values are NaN - cannot create plot")

            # Percentile-based colour limits keep outliers from flattening the scale.
            vmax_percentile = var_settings.get('vmax_percentile', 90)
            vmin = np.nanpercentile(valid_data, 5)
            vmax = np.nanpercentile(valid_data, vmax_percentile)
            if vmax <= vmin:
                vmax = vmin + 1.0

            # Create figure and axes - wide layout to match the interactive plot.
            fig = plt.figure(figsize=(16, 10))
            ax = fig.add_subplot(1, 1, 1)

            # Set map extent
            ax.set_xlim(INDIA_BOUNDS['lon_min'], INDIA_BOUNDS['lon_max'])
            ax.set_ylim(INDIA_BOUNDS['lat_min'], INDIA_BOUNDS['lat_max'])

            # 1. Pollution data in the background (lower zorder), one pixel per cell.
            extent = [lons.min(), lons.max(), lats.min(), lats.max()]
            # imshow with origin='lower' expects row 0 to be the southernmost row,
            # so flip the rows when the latitude axis is descending.
            if self._latitudes_ascending(lats):
                plot_data = data_values
            else:
                plot_data = np.flipud(data_values)
            im = ax.imshow(plot_data, cmap=color_theme, vmin=vmin, vmax=vmax,
                           extent=extent, origin='lower', aspect='auto',
                           interpolation='nearest', zorder=1)

            # Auto-adjust bounds if INDIA_BOUNDS does not contain the shapefile
            # (both longitude and latitude are checked).
            xmin, ymin, xmax, ymax = self.india_map.total_bounds
            lon_ok = (INDIA_BOUNDS['lon_min'] <= xmin <= INDIA_BOUNDS['lon_max']
                      and INDIA_BOUNDS['lon_min'] <= xmax <= INDIA_BOUNDS['lon_max'])
            lat_ok = (INDIA_BOUNDS['lat_min'] <= ymin <= INDIA_BOUNDS['lat_max']
                      and INDIA_BOUNDS['lat_min'] <= ymax <= INDIA_BOUNDS['lat_max'])
            if not (lon_ok and lat_ok):
                print("⚠️ Warning: Using shapefile's actual bounds instead of INDIA_BOUNDS.")
                ax.set_xlim(xmin, xmax)
                ax.set_ylim(ymin, ymax)

            # 2. India map outlines on top of the data (higher zorder).
            self.india_map.plot(ax=ax, edgecolor='black', facecolor='none',
                                linewidth=0.8, zorder=2)

            # Add colorbar (explicit figure call avoids pyplot's "current figure" state).
            cbar = fig.colorbar(im, ax=ax, shrink=0.6, pad=0.02, aspect=30)
            cbar_label = f"{display_name}" + (f" ({units})" if units else "")
            cbar.set_label(cbar_label, fontsize=12, labelpad=15)

            # Add gridlines and labels
            ax.grid(True, linestyle='--', alpha=0.6, color='gray', zorder=3)
            ax.set_xlabel("Longitude", fontsize=10)
            ax.set_ylabel("Latitude", fontsize=10)
            ax.tick_params(axis='both', which='major', labelsize=10)

            title = self._compose_title(display_name, pressure_level, time_stamp, custom_title)
            ax.set_title(title, fontsize=14, pad=20, weight='bold')

            # Statistics box (top-left) and colour-theme box (bottom-right).
            stats_text = self._create_stats_text(valid_data, units)
            ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                    bbox=dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.9),
                    verticalalignment='top', fontsize=10, zorder=4)
            # Fall back to the raw theme name if it has no entry in COLOR_THEMES.
            theme_label = COLOR_THEMES[color_theme] if color_theme in COLOR_THEMES else color_theme
            theme_text = f"Color Theme: {theme_label}"
            ax.text(0.98, 0.02, theme_text, transform=ax.transAxes,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8),
                    horizontalalignment='right', verticalalignment='bottom', fontsize=9, zorder=4)

            fig.tight_layout()

            plot_path = None
            if save_plot:
                plot_path = self._save_plot(fig, var_name, display_name, pressure_level, color_theme, time_stamp)
            plt.close(fig)
            return plot_path
        except Exception as e:
            # Release only the figure created by this call; other figures are left alone.
            if fig is not None:
                plt.close(fig)
            raise Exception(f"Error creating map: {str(e)}") from e

    # NOTE: no `create_comparison_plot` method is defined in this class; per the
    # original author's note, such a method would need the same zorder ordering
    # (data at zorder=1, boundary outlines at zorder=2).

    def _create_stats_text(self, data, units):
        """Return a multi-line summary (Min/Max/Mean/Median/Std) of ``data``.

        Parameters:
            data: Array-like of numbers; NaN-aware NumPy reductions are used.
            units (str | None): Unit label appended after each value when truthy.

        Returns:
            str: One "Name: value[ units]" line per statistic, newline-joined.
        """
        units_str = f" {units}" if units else ""
        stats = {
            'Min': np.nanmin(data),
            'Max': np.nanmax(data),
            'Mean': np.nanmean(data),
            'Median': np.nanmedian(data),
            'Std': np.nanstd(data),
        }

        def format_number(val):
            # Fewer decimals for larger magnitudes keeps the box compact.
            if abs(val) >= 1000:
                return f"{val:.0f}"
            if abs(val) >= 10:
                return f"{val:.1f}"
            return f"{val:.2f}"

        stats_lines = [f"{name}: {format_number(val)}{units_str}" for name, val in stats.items()]
        return "\n".join(stats_lines)

    def _save_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp):
        """Save ``fig`` as a PNG in ``self.plots_dir`` and return its path as a string.

        The file name is built from the sanitized display name, an optional
        "<level>hPa" part, the colour theme and the sanitized timestamp.
        """
        # Handle None values with fallbacks
        display_name = display_name or var_name or 'Unknown'
        time_stamp = time_stamp or 'Unknown_Time'

        safe_display_name = display_name.translate(self._NAME_TABLE)
        safe_time_stamp = time_stamp.translate(self._TIME_TABLE)

        filename_parts = [f"{safe_display_name}_India"]
        if pressure_level:
            filename_parts.append(f"{int(pressure_level)}hPa")
        filename_parts.extend([color_theme, safe_time_stamp])
        filename = "_".join(filename_parts) + ".png"

        plot_path = self.plots_dir / filename
        fig.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
        print(f"Plot saved: {plot_path}")
        return str(plot_path)

    def list_available_themes(self):
        """Return the COLOR_THEMES object imported from ``constants``."""
        return COLOR_THEMES
def test_plot_generator():
    """Smoke-test IndiaMapPlotter with synthetic data.

    Builds a noisy sinusoidal field on a 50 x 60 lat/lon grid and renders it
    with the 'YlOrRd' theme.

    Returns:
        bool: True when a plot was created; False when the shapefile is
        missing or when plotter construction or plotting raised an exception.
    """
    print("Testing plot generator with GeoPandas and zorder fix...")

    shapefile_path = "shapefiles/India_State_Boundary.shp"
    if not Path(shapefile_path).exists():
        print(f"❌ Test failed: Shapefile not found at '{shapefile_path}'.")
        print("Please make sure you have unzipped 'India_State_Boundary.zip' into a 'shapefiles' folder.")
        return False

    lats, lons = np.linspace(6, 38, 50), np.linspace(68, 98, 60)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    data = np.sin(lat_grid * 0.1) * np.cos(lon_grid * 0.1) * 100 + 50
    data += np.random.normal(0, 10, data.shape)

    metadata = {
        'variable_name': 'pm25', 'display_name': 'PM2.5', 'units': 'µg/m³',
        'lats': lats, 'lons': lons, 'pressure_level': None,
        'timestamp_str': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }

    try:
        # Construction is inside the try so an unreadable shapefile is reported
        # as a failed test (False) rather than escaping as an exception.
        plotter = IndiaMapPlotter(shapefile_path=shapefile_path)
        plot_path = plotter.create_india_map(data, metadata, color_theme='YlOrRd')
        print(f"✅ Test plot created successfully: {plot_path}")
        return True
    except Exception as e:
        print(f"❌ Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False
# Run the smoke test when this file is executed as a script.
if __name__ == "__main__":
    test_plot_generator()