Spaces:
Running
Running
| import os | |
| import urllib.request | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') # Headless mode to avoid hangs | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import geopandas as gpd | |
# Define directories: rendered figures go to PLOTS_DIR, downloaded inputs to DATA_DIR.
PLOTS_DIR = 'plots'
DATA_DIR = 'data'
os.makedirs(PLOTS_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

# 1. Download world countries map GeoJSON if not present
geojson_path = os.path.join(DATA_DIR, 'countries.geojson')
geojson_url = 'https://raw.githubusercontent.com/datasets/geo-boundaries-world-110m/master/countries.geojson'
if not os.path.exists(geojson_path):
    print("Downloading world boundaries GeoJSON...")
    # Download into a sibling ".part" file and rename only on success, so an
    # interrupted transfer never leaves a truncated countries.geojson behind
    # that the existence check above would treat as a valid cached copy.
    geojson_tmp_path = geojson_path + '.part'
    try:
        urllib.request.urlretrieve(geojson_url, geojson_tmp_path)
        os.replace(geojson_tmp_path, geojson_path)
        print("Download complete.")
    except Exception as e:
        # Best effort: report the failure and carry on; `world` falls back to
        # None just below when the file is absent.
        print(f"Error downloading GeoJSON: {e}")
        if os.path.exists(geojson_tmp_path):
            os.remove(geojson_tmp_path)

# Load world boundaries (None when the boundaries file is unavailable)
world = gpd.read_file(geojson_path) if os.path.exists(geojson_path) else None
# Load iNaturalist unified master metadata
print("Loading master dataset...")
df = pd.read_parquet('metadata/inat_world_model_master.parquet')

# Parse the `date` column once (through its string form), then derive the
# calendar parts from the parsed values.
observation_dates = pd.to_datetime(df['date'].astype(str))
df['date_dt'] = observation_dates
df['year'] = observation_dates.dt.year
df['month'] = observation_dates.dt.month
# Set style: seaborn theme first, then the script's own rc overrides on top
# (same order as before).
sns.set_theme(style="whitegrid")
rc_overrides = {
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'figure.titlesize': 16,
    'figure.dpi': 150,
}
plt.rcParams.update(rc_overrides)

# Color palette helper
palette = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
]
# ----------------- PLOT 1: Global Geographic Observations Map -----------------
print("Generating Plot 1: Global Geographic Map...")
fig, ax = plt.subplots(figsize=(14, 8))

# The basemap is optional: `world` is None when no boundaries file was loaded.
if world is not None:
    world.plot(ax=ax, color='#eef0f2', edgecolor='#bdc3c7', linewidth=0.5)

# One dot per observation, colored by elevation. Elevation is clipped to
# [0, 3000] before color mapping for better visual contrast.
clipped_elevation = df['elevation'].clip(0, 3000)
points = ax.scatter(
    df['longitude'],
    df['latitude'],
    c=clipped_elevation,
    cmap='terrain',
    alpha=0.3,
    s=1.5,
    label='Observation',
)
elevation_bar = fig.colorbar(points, ax=ax, shrink=0.6, pad=0.02)
elevation_bar.set_label('Elevation (m)', rotation=270, labelpad=15)

ax.set(
    title="Global iNaturalist Observations Distribution (Colored by Elevation)",
    xlabel="Longitude",
    ylabel="Latitude",
    xlim=(-180, 180),
    ylim=(-60, 85),
)

fig.tight_layout()
map_path = os.path.join(PLOTS_DIR, 'geographic_map_global.png')
fig.savefig(map_path, bbox_inches='tight', dpi=200)
plt.close(fig)
# ----------------- PLOT 2: Geographical Spread Comparison -----------------
# Target top-2 and bottom-2 geographical spread species
# Top: Pacific Golden-Plover (Pluvialis fulva), Chestnut Munia (Lonchura atricapilla)
# Bottom: Taiwan Barbet (Psilopogon nuchalis), Florida Scrub Jay (Aphelocoma coerulescens)
print("Generating Plot 2: Geographic Spread Comparison...")
# (scientific name used to filter `df['name']`, panel label, marker color)
target_geo_species = [
    ("Pluvialis fulva", "Pacific Golden-Plover (Top 1 Spread)", "#1f77b4"),
    ("Lonchura atricapilla", "Chestnut Munia (Top 2 Spread)", "#e74c3c"),
    ("Psilopogon nuchalis", "Taiwan Barbet (Bottom 1 Spread)", "#2ecc71"),
    ("Aphelocoma coerulescens", "Florida Scrub Jay (Bottom 2 Spread)", "#9b59b6"),
]
fig, axes = plt.subplots(2, 2, figsize=(16, 10), sharex=False, sharey=False)
axes = axes.flatten()
for ax, (sci_name, label, color) in zip(axes, target_geo_species):
    if world is not None:
        world.plot(ax=ax, color='#eef0f2', edgecolor='#bdc3c7', linewidth=0.5)

    ax.set_title(f"{label}\n({sci_name})")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

    # Only rows with both coordinates can be drawn or contribute to the zoom box.
    sp_df = df.loc[df['name'] == sci_name, ['longitude', 'latitude']].dropna()
    if sp_df.empty:
        # With no usable coordinates, min()/max() would be NaN and the zoom
        # limits below would be NaN too, which matplotlib rejects. Fall back
        # to the world view and say so in the panel.
        ax.text(0.5, 0.5, "No observations", ha='center', va='center',
                transform=ax.transAxes)
        ax.set_xlim(-180, 180)
        ax.set_ylim(-60, 85)
        continue

    ax.scatter(
        sp_df['longitude'], sp_df['latitude'],
        color=color, alpha=0.6, s=15, label=label
    )

    # Calculate geographical box zoom around the observations for clarity:
    # pad each side by 20% of the extent, but never by less than 5 degrees.
    lon_min, lon_max = sp_df['longitude'].min(), sp_df['longitude'].max()
    lat_min, lat_max = sp_df['latitude'].min(), sp_df['latitude'].max()
    lon_span = lon_max - lon_min
    lat_span = lat_max - lat_min
    lon_pad = max(5, lon_span * 0.2)
    lat_pad = max(5, lat_span * 0.2)

    # If wide spread, set to world view, else zoom to region
    if lon_span > 120 or lat_span > 80:
        ax.set_xlim(-180, 180)
        ax.set_ylim(-60, 85)
    else:
        ax.set_xlim(lon_min - lon_pad, lon_max + lon_pad)
        ax.set_ylim(lat_min - lat_pad, lat_max + lat_pad)

plt.suptitle("Geographical Spread Comparison: Migratory vs. Localized Species", y=0.98)
plt.tight_layout()
fig.savefig(os.path.join(PLOTS_DIR, 'geographic_spread_comparison.png'), bbox_inches='tight', dpi=200)
plt.close(fig)
# ----------------- PLOT 3: Global Temporal & Seasonality -----------------
print("Generating Plot 3: Global Temporal & Seasonality...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Timeline (Observations by Year)
# If any date failed to parse, `year` holds NaN and becomes a float column;
# casting the remaining values to int keeps tick labels as "2019" rather
# than "2019.0". value_counts() ignored the NaN rows already.
year_counts = df['year'].dropna().astype(int).value_counts().sort_index()
# Filter to recent 25 years to focus on major timeline data
year_counts = year_counts[year_counts.index >= 2000]
sns.barplot(x=year_counts.index, y=year_counts.values, ax=ax1, color='#34495e')
ax1.set_title("Timeline: Global Observations per Year (>= 2000)")
ax1.set_xlabel("Year")
ax1.set_ylabel("Observation Count")
# Rotate the existing tick labels in place instead of re-setting their text.
ax1.tick_params(axis='x', labelrotation=45)

# Seasonality (Observations by Month)
month_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
# Reindex onto months 1..12 so there is exactly one count per label, in
# calendar order, even when a month has no observations. Without this, a
# missing month shifts every later bar under the wrong label (or makes the
# x and y lengths disagree).
month_counts = (
    df['month'].dropna().astype(int).value_counts()
    .reindex(range(1, 13), fill_value=0)
)
sns.barplot(x=month_labels, y=month_counts.values, ax=ax2, color='#2980b9')
ax2.set_title("Seasonality: Global Observations by Month")
ax2.set_xlabel("Month")
ax2.set_ylabel("Observation Count")

plt.tight_layout()
fig.savefig(os.path.join(PLOTS_DIR, 'temporal_global.png'), bbox_inches='tight', dpi=150)
plt.close(fig)
# ----------------- PLOT 4: Temporal Spread Comparison -----------------
# Target top-2 and bottom-2 temporal spread species
# Top: Marabou Stork (Leptoptilos crumenifer), Golden-cheeked Warbler (Setophaga chrysoparia)
# Bottom: Great-tailed Grackle (Quiscalus mexicanus), Booted Warbler (Iduna caligata)
print("Generating Plot 4: Temporal Spread Comparison...")
# (scientific name used to filter `df['name']`, panel label, line color)
target_temp_species = [
    ("Leptoptilos crumenifer", "Marabou Stork (Top 1 Temporal Spread)", "#d35400"),
    ("Setophaga chrysoparia", "Golden-cheeked Warbler (Top 2 Temporal Spread)", "#27ae60"),
    ("Quiscalus mexicanus", "Great-tailed Grackle (Bottom 1 Temporal Spread)", "#2980b9"),
    ("Iduna caligata", "Booted Warbler (Bottom 2 Temporal Spread)", "#8e44ad"),
]
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
axes = axes.flatten()
for ax, (sci_name, label, color) in zip(axes, target_temp_species):
    sp_df = df[df['name'] == sci_name]

    # Count observations per (year, month) and place each count on the first
    # day of its month for the time axis.
    obs_year = sp_df['date_dt'].dt.year
    obs_month = sp_df['date_dt'].dt.month
    sp_df_monthly = sp_df.groupby([obs_year, obs_month]).size()
    index_dates = [
        pd.Timestamp(year=y, month=m, day=1)
        for y, m in sp_df_monthly.index
    ]

    ax.plot(
        index_dates,
        sp_df_monthly.values,
        marker='o',
        linestyle='-',
        color=color,
        linewidth=1.5,
        markersize=3,
    )
    ax.set(
        title=f"{label}\n({sci_name})",
        xlabel="Timeline",
        ylabel="Observation Count per Month",
    )
    ax.grid(True, linestyle='--', alpha=0.5)

fig.suptitle("Temporal Spread Comparison: Broad Timeline vs. Clustered Observations", y=0.98)
fig.tight_layout()
comparison_path = os.path.join(PLOTS_DIR, 'temporal_spread_comparison.png')
fig.savefig(comparison_path, bbox_inches='tight', dpi=200)
plt.close(fig)
# ----------------- PLOT 5: AVONET Trait Space -----------------
print("Generating Plot 5: AVONET Trait Space...")
fig, ax = plt.subplots(figsize=(11, 7))

# Keep only records that have all three plotted traits.
trait_columns = ['Mass', 'Hand-Wing.Index', 'Primary.Lifestyle']
clean_traits = df.dropna(subset=trait_columns)

# Cap the plotted rows at 20000 (fixed seed) to speed up plotting and
# improve visibility.
sample_size = min(len(clean_traits), 20000)
clean_sample = clean_traits.sample(n=sample_size, random_state=42)

sns.scatterplot(
    data=clean_sample,
    x='Mass',
    y='Hand-Wing.Index',
    hue='Primary.Lifestyle',
    palette='Set2',
    alpha=0.6,
    s=15,
    ax=ax,
)
ax.set_xscale('log')
ax.set(
    title="AVONET Trait Space: Flight Capability vs. Body Mass",
    xlabel="Body Mass (g, Log Scale)",
    ylabel="Hand-Wing Index (Flight Efficiency / Dispersal)",
)
# Anchor the legend just outside the right edge of the axes.
ax.legend(title='Primary Lifestyle', bbox_to_anchor=(1.05, 1), loc='upper left')

fig.tight_layout()
traits_path = os.path.join(PLOTS_DIR, 'avonet_traits.png')
fig.savefig(traits_path, bbox_inches='tight', dpi=150)
plt.close(fig)
# ----------------- PLOT 6: Species Image Grid (Old vs. New & Geographic) -----------------
print("Generating Plot 6: Species Image Grid...")

# Each spec is (row title, class folder name, entries); each entry is
# (dataset split, image file stem, panel caption). The full image path is
# "data/images/<split>/<class folder>/<stem>.jpg".
_image_grid_specs = [
    (
        "Pluvialis fulva (Pacific Golden-Plover)",
        "03368_Animalia_Chordata_Aves_Charadriiformes_Charadriidae_Pluvialis_fulva",
        [
            ("train", "660d715f-871a-47e5-862f-dcdb86f4afdd", "Oldest (1980)\nHawaii (21.3, -158.1)"),
            ("train", "a0acb86a-bdbc-422a-83ea-3bf6b76aeaaa", "Old (1998)\nMidway Atoll (28.2, -177.4)"),
            ("val", "c460631c-3000-4984-b9bd-96dd167855b5", "New (2019)\nSiberia, RU (70.9, 73.9)"),
            ("train", "18deb1d6-b0c8-44d1-84e8-ba8791b586af", "New (2019)\nMassachusetts, US (41.6, -70.0)"),
        ],
    ),
    (
        "Lonchura atricapilla (Chestnut Munia)",
        "03799_Animalia_Chordata_Aves_Passeriformes_Estrildidae_Lonchura_atricapilla",
        [
            ("train", "b7ab52d0-9d7e-49ba-abf6-407eb568385a", "Oldest (1979)\nTaiwan (24.7, 121.8)"),
            ("train", "7f28e644-19f5-4420-b59f-6f71ca9c552c", "Old (2003)\nPortugal (41.9, -8.8)"),
            ("train", "570e0a9f-7c7e-4e08-8059-7c1aabb4f176", "New (2019)\nSulawesi, ID (0.5, 122.0)"),
            ("train", "6cc4541f-7289-4710-b062-e4779e061f6b", "New (2019)\nSabah, MY (6.0, 116.1)"),
        ],
    ),
]

# Row title -> ordered list of (image path, panel caption).
images_to_plot = {
    row_title: [
        (f"data/images/{split}/{class_dir}/{stem}.jpg", caption)
        for split, stem, caption in entries
    ]
    for row_title, class_dir, entries in _image_grid_specs
}
# Render `images_to_plot` as a grid: one row per species, one column per image.
# The grid is sized from the data rather than fixed at 2x4, so adding or
# removing a species or image cannot index past the axes array.
# squeeze=False keeps `axes` two-dimensional even for a single row or column.
n_rows = max(1, len(images_to_plot))
n_cols = max(1, max((len(items) for items in images_to_plot.values()), default=1))
# 4.5 x 5 inches per cell gives the 18 x 10 figure used for the 2x4 layout.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 5 * n_rows), squeeze=False)

# Hide the axis decorations of every cell up front, so cells with no image
# assigned (rows shorter than the widest one) stay blank instead of showing
# an empty frame.
for ax in axes.flat:
    ax.axis('off')

for r_idx, (species_name, items) in enumerate(images_to_plot.items()):
    for c_idx, (img_path, label) in enumerate(items):
        ax = axes[r_idx, c_idx]
        if os.path.exists(img_path):
            img = plt.imread(img_path)
            ax.imshow(img)
        else:
            # Keep the cell and its caption, and flag the missing file.
            ax.text(0.5, 0.5, "Image Missing", ha='center', va='center')
        ax.set_title(label, fontsize=12)
        # Add row label on the far left column (placed left of the axes, in
        # axes-fraction coordinates)
        if c_idx == 0:
            ax.text(-0.2, 0.5, species_name, rotation=90, va='center', ha='right',
                    transform=ax.transAxes, fontsize=14, fontweight='bold')

plt.suptitle("Photo Quality Evolution & Geographical Differences for Widely Spread Species", y=0.98, fontsize=16, fontweight='bold')
plt.tight_layout()
fig.savefig(os.path.join(PLOTS_DIR, 'species_image_grid.png'), bbox_inches='tight', dpi=200)
plt.close(fig)
print("All plots (including species image grid) generated successfully in the 'plots/' directory!")