"""Generate exploratory plots for the iNaturalist world-model master dataset.

Reads ``metadata/inat_world_model_master.parquet`` and writes six PNG figures
into ``plots/``:

1. global observation map coloured by elevation,
2. per-species geographic spread maps,
3. global observations per year and per month,
4. per-species monthly observation timelines,
5. AVONET trait space (body mass vs. hand-wing index),
6. a grid of example photos for two species.

A world-boundary GeoJSON is downloaded into ``data/`` on first run and used
as a basemap when it can be loaded; otherwise maps are drawn without it.
"""
import os
import urllib.request

import pandas as pd
import numpy as np
import matplotlib

matplotlib.use('Agg')  # Headless mode to avoid hangs; must precede the pyplot import
import matplotlib.pyplot as plt
import seaborn as sns
import geopandas as gpd

# Define directories
PLOTS_DIR = 'plots'
DATA_DIR = 'data'
os.makedirs(PLOTS_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

# 1. Download world countries map GeoJSON if not present
geojson_path = os.path.join(DATA_DIR, 'countries.geojson')
geojson_url = 'https://raw.githubusercontent.com/datasets/geo-boundaries-world-110m/master/countries.geojson'
if not os.path.exists(geojson_path):
    print("Downloading world boundaries GeoJSON...")
    try:
        urllib.request.urlretrieve(geojson_url, geojson_path)
        print("Download complete.")
    except Exception as e:  # best-effort: the basemap is optional
        print(f"Error downloading GeoJSON: {e}")
        # urlretrieve writes while downloading; drop any partial file so the
        # next run retries instead of trying to parse a truncated GeoJSON.
        if os.path.exists(geojson_path):
            os.remove(geojson_path)

# Load world boundaries (None means "draw maps without a basemap").
world = None
if os.path.exists(geojson_path):
    try:
        world = gpd.read_file(geojson_path)
    except Exception as e:  # best-effort: an unreadable file must not abort the run
        print(f"Could not read {geojson_path}; plotting without basemap: {e}")


def _draw_basemap(ax):
    """Draw country outlines onto *ax* when the world GeoJSON was loaded."""
    if world is not None:
        world.plot(ax=ax, color='#eef0f2', edgecolor='#bdc3c7', linewidth=0.5)


def _save(fig, filename, dpi):
    """Save *fig* as PLOTS_DIR/*filename* (tight bbox) and close it."""
    fig.savefig(os.path.join(PLOTS_DIR, filename), bbox_inches='tight', dpi=dpi)
    plt.close(fig)


# Load iNaturalist unified master metadata
print("Loading master dataset...")
df = pd.read_parquet('metadata/inat_world_model_master.parquet')
df['date_dt'] = pd.to_datetime(df['date'].astype(str))
df['year'] = df['date_dt'].dt.year
df['month'] = df['date_dt'].dt.month

# Set style
sns.set_theme(style="whitegrid")
plt.rcParams.update({
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'figure.titlesize': 16,
    'figure.dpi': 150,
})

# Color palette (currently unused: each plot below sets its colors inline)
palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"]

# ----------------- PLOT 1: Global Geographic Observations Map -----------------
print("Generating Plot 1: Global Geographic Map...")
fig, ax = plt.subplots(figsize=(14, 8))
_draw_basemap(ax)

# Plot observations colored by elevation
sc = ax.scatter(
    df['longitude'], df['latitude'],
    c=df['elevation'].clip(0, 3000),  # Clip for better visual contrast
    cmap='terrain', alpha=0.3, s=1.5, label='Observation',
)
cbar = fig.colorbar(sc, ax=ax, shrink=0.6, pad=0.02)
cbar.set_label('Elevation (m)', rotation=270, labelpad=15)

ax.set_title("Global iNaturalist Observations Distribution (Colored by Elevation)")
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_xlim(-180, 180)
ax.set_ylim(-60, 85)
plt.tight_layout()
_save(fig, 'geographic_map_global.png', dpi=200)

# ----------------- PLOT 2: Geographical Spread Comparison -----------------
# Target top-2 and bottom-2 geographical spread species
# Top: Pacific Golden-Plover (Pluvialis fulva), Chestnut Munia (Lonchura atricapilla)
# Bottom: Taiwan Barbet (Psilopogon nuchalis), Florida Scrub Jay (Aphelocoma coerulescens)
print("Generating Plot 2: Geographic Spread Comparison...")
target_geo_species = [
    ("Pluvialis fulva", "Pacific Golden-Plover (Top 1 Spread)", "#1f77b4"),
    ("Lonchura atricapilla", "Chestnut Munia (Top 2 Spread)", "#e74c3c"),
    ("Psilopogon nuchalis", "Taiwan Barbet (Bottom 1 Spread)", "#2ecc71"),
    ("Aphelocoma coerulescens", "Florida Scrub Jay (Bottom 2 Spread)", "#9b59b6"),
]

fig, axes = plt.subplots(2, 2, figsize=(16, 10), sharex=False, sharey=False)
for ax, (sci_name, label, color) in zip(axes.flatten(), target_geo_species):
    _draw_basemap(ax)
    ax.set_title(f"{label}\n({sci_name})")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

    sp_df = df[df['name'] == sci_name]
    coords = sp_df[['longitude', 'latitude']].dropna()
    if coords.empty:
        # min()/max() would be NaN here and set_xlim(NaN) raises ValueError.
        ax.text(0.5, 0.5, "No observations found", ha='center', va='center',
                transform=ax.transAxes)
        ax.set_xlim(-180, 180)
        ax.set_ylim(-60, 85)
        continue

    ax.scatter(coords['longitude'], coords['latitude'],
               color=color, alpha=0.6, s=15, label=label)

    # Calculate geographical box zoom around the observations for clarity
    lon_min, lon_max = coords['longitude'].min(), coords['longitude'].max()
    lat_min, lat_max = coords['latitude'].min(), coords['latitude'].max()
    lon_pad = max(5, (lon_max - lon_min) * 0.2)
    lat_pad = max(5, (lat_max - lat_min) * 0.2)

    # If wide spread, set to world view, else zoom to region
    if (lon_max - lon_min) > 120 or (lat_max - lat_min) > 80:
        ax.set_xlim(-180, 180)
        ax.set_ylim(-60, 85)
    else:
        ax.set_xlim(lon_min - lon_pad, lon_max + lon_pad)
        ax.set_ylim(lat_min - lat_pad, lat_max + lat_pad)

plt.suptitle("Geographical Spread Comparison: Migratory vs. Localized Species", y=0.98)
plt.tight_layout()
_save(fig, 'geographic_spread_comparison.png', dpi=200)

# ----------------- PLOT 3: Global Temporal & Seasonality -----------------
print("Generating Plot 3: Global Temporal & Seasonality...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Timeline (Observations by Year)
year_counts = df['year'].value_counts().sort_index()
# Filter to recent 25 years to focus on major timeline data
year_counts = year_counts[year_counts.index >= 2000]
# The year column is float when any date is NaT; show "2019", not "2019.0".
year_counts.index = year_counts.index.astype(int)
sns.barplot(x=year_counts.index, y=year_counts.values, ax=ax1, color='#34495e')
ax1.set_title("Timeline: Global Observations per Year (>= 2000)")
ax1.set_xlabel("Year")
ax1.set_ylabel("Observation Count")
# Rotate existing tick labels without freezing them via set_xticklabels.
ax1.tick_params(axis='x', labelrotation=45)

# Seasonality (Observations by Month)
# Reindex to all 12 months so the bars always line up with month_labels,
# even if some month has no observations.
month_counts = df['month'].value_counts().reindex(range(1, 13), fill_value=0)
month_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
sns.barplot(x=month_labels, y=month_counts.values, ax=ax2, color='#2980b9')
ax2.set_title("Seasonality: Global Observations by Month")
ax2.set_xlabel("Month")
ax2.set_ylabel("Observation Count")

plt.tight_layout()
_save(fig, 'temporal_global.png', dpi=150)

# ----------------- PLOT 4: Temporal Spread Comparison -----------------
# Target top-2 and bottom-2 temporal spread species
# Top: Marabou Stork (Leptoptilos crumenifer), Golden-cheeked Warbler (Setophaga chrysoparia)
# Bottom: Great-tailed Grackle (Quiscalus mexicanus), Booted Warbler (Iduna caligata)
print("Generating Plot 4: Temporal Spread Comparison...")
target_temp_species = [
    ("Leptoptilos crumenifer", "Marabou Stork (Top 1 Temporal Spread)", "#d35400"),
    ("Setophaga chrysoparia", "Golden-cheeked Warbler (Top 2 Temporal Spread)", "#27ae60"),
    ("Quiscalus mexicanus", "Great-tailed Grackle (Bottom 1 Temporal Spread)", "#2980b9"),
    ("Iduna caligata", "Booted Warbler (Bottom 2 Temporal Spread)", "#8e44ad"),
]

fig, axes = plt.subplots(2, 2, figsize=(16, 10))
for ax, (sci_name, label, color) in zip(axes.flatten(), target_temp_species):
    sp_df = df[df['name'] == sci_name]
    ax.set_title(f"{label}\n({sci_name})")
    ax.set_xlabel("Timeline")
    ax.set_ylabel("Observation Count per Month")
    ax.grid(True, linestyle='--', alpha=0.5)

    # Count observations per calendar month. Periods avoid building
    # Timestamps from float year/month keys (which happen when date_dt has NaT).
    sp_df_monthly = sp_df['date_dt'].dt.to_period('M').value_counts().sort_index()
    if sp_df_monthly.empty:
        ax.text(0.5, 0.5, "No observations found", ha='center', va='center',
                transform=ax.transAxes)
        continue

    index_dates = sp_df_monthly.index.to_timestamp()  # first day of each month
    ax.plot(index_dates, sp_df_monthly.values, marker='o', linestyle='-',
            color=color, linewidth=1.5, markersize=3)

plt.suptitle("Temporal Spread Comparison: Broad Timeline vs. Clustered Observations", y=0.98)
plt.tight_layout()
_save(fig, 'temporal_spread_comparison.png', dpi=200)

# ----------------- PLOT 5: AVONET Trait Space -----------------
print("Generating Plot 5: AVONET Trait Space...")
fig, ax = plt.subplots(figsize=(11, 7))

# Filter out records with null lifestyles or traits
clean_traits = df.dropna(subset=['Mass', 'Hand-Wing.Index', 'Primary.Lifestyle'])
# Limit data to speed up plotting and improve visibility
clean_sample = clean_traits.sample(n=min(len(clean_traits), 20000), random_state=42)

sns.scatterplot(
    data=clean_sample, x='Mass', y='Hand-Wing.Index', hue='Primary.Lifestyle',
    palette='Set2', alpha=0.6, s=15, ax=ax,
)
ax.set_xscale('log')
ax.set_title("AVONET Trait Space: Flight Capability vs. Body Mass")
ax.set_xlabel("Body Mass (g, Log Scale)")
ax.set_ylabel("Hand-Wing Index (Flight Efficiency / Dispersal)")
ax.legend(title='Primary Lifestyle', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
_save(fig, 'avonet_traits.png', dpi=150)

# ----------------- PLOT 6: Species Image Grid (Old vs. New & Geographic) -----------------
print("Generating Plot 6: Species Image Grid...")
images_to_plot = {
    "Pluvialis fulva (Pacific Golden-Plover)": [
        ("data/images/train/03368_Animalia_Chordata_Aves_Charadriiformes_Charadriidae_Pluvialis_fulva/660d715f-871a-47e5-862f-dcdb86f4afdd.jpg",
         "Oldest (1980)\nHawaii (21.3, -158.1)"),
        ("data/images/train/03368_Animalia_Chordata_Aves_Charadriiformes_Charadriidae_Pluvialis_fulva/a0acb86a-bdbc-422a-83ea-3bf6b76aeaaa.jpg",
         "Old (1998)\nMidway Atoll (28.2, -177.4)"),
        ("data/images/val/03368_Animalia_Chordata_Aves_Charadriiformes_Charadriidae_Pluvialis_fulva/c460631c-3000-4984-b9bd-96dd167855b5.jpg",
         "New (2019)\nSiberia, RU (70.9, 73.9)"),
        ("data/images/train/03368_Animalia_Chordata_Aves_Charadriiformes_Charadriidae_Pluvialis_fulva/18deb1d6-b0c8-44d1-84e8-ba8791b586af.jpg",
         "New (2019)\nMassachusetts, US (41.6, -70.0)"),
    ],
    "Lonchura atricapilla (Chestnut Munia)": [
        ("data/images/train/03799_Animalia_Chordata_Aves_Passeriformes_Estrildidae_Lonchura_atricapilla/b7ab52d0-9d7e-49ba-abf6-407eb568385a.jpg",
         "Oldest (1979)\nTaiwan (24.7, 121.8)"),
        ("data/images/train/03799_Animalia_Chordata_Aves_Passeriformes_Estrildidae_Lonchura_atricapilla/7f28e644-19f5-4420-b59f-6f71ca9c552c.jpg",
         "Old (2003)\nPortugal (41.9, -8.8)"),
        ("data/images/train/03799_Animalia_Chordata_Aves_Passeriformes_Estrildidae_Lonchura_atricapilla/570e0a9f-7c7e-4e08-8059-7c1aabb4f176.jpg",
         "New (2019)\nSulawesi, ID (0.5, 122.0)"),
        ("data/images/train/03799_Animalia_Chordata_Aves_Passeriformes_Estrildidae_Lonchura_atricapilla/6cc4541f-7289-4710-b062-e4779e061f6b.jpg",
         "New (2019)\nSabah, MY (6.0, 116.1)"),
    ],
}

fig, axes = plt.subplots(2, 4, figsize=(18, 10))
for r_idx, (species_name, items) in enumerate(images_to_plot.items()):
    for c_idx, (img_path, label) in enumerate(items):
        ax = axes[r_idx, c_idx]
        if os.path.exists(img_path):
            ax.imshow(plt.imread(img_path))
        else:
            ax.text(0.5, 0.5, "Image Missing", ha='center', va='center')
        ax.set_title(label, fontsize=12)
        ax.axis('off')

        # Add row label on the far left column (axes coords, left of the image)
        if c_idx == 0:
            ax.text(-0.2, 0.5, species_name, rotation=90, va='center', ha='right',
                    transform=ax.transAxes, fontsize=14, fontweight='bold')

plt.suptitle("Photo Quality Evolution & Geographical Differences for Widely Spread Species",
             y=0.98, fontsize=16, fontweight='bold')
plt.tight_layout()
_save(fig, 'species_image_grid.png', dpi=200)

print("All plots (including species image grid) generated successfully in the 'plots/' directory!")