# EarthExplorer / data_utils.py
# Uploaded by VoyagerX: "sync from hf" (commit 29fab93)
import fsspec
import pyarrow.parquet as pq
import numpy as np
from PIL import Image
from io import BytesIO
from rasterio.io import MemoryFile
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt
from matplotlib.patches import Rectangle
import math
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
def crop_center(img_array, cropx, cropy):
    """Return the centered ``cropy`` x ``cropx`` window of an image array.

    Args:
        img_array: Array whose first two axes are (rows, columns). Any
            trailing axes (e.g. channels) are passed through untouched, so
            both 2-D and 3-D arrays are accepted.
        cropx: Width of the crop window, in columns.
        cropy: Height of the crop window, in rows.

    Returns:
        A slice (view) of ``img_array``. When the requested window is larger
        than the array along an axis, the whole extent of that axis is
        returned.
    """
    height, width = img_array.shape[:2]
    # Clamp to 0: a negative start would be interpreted by slicing as an
    # offset from the end and yield an off-centre fragment.
    left = max(width // 2 - cropx // 2, 0)
    top = max(height // 2 - cropy // 2, 0)
    return img_array[top:top + cropy, left:left + cropx]
def read_tif_bytes(tif_bytes):
    """Decode in-memory GeoTIFF bytes into an array.

    The bytes are opened through a rasterio ``MemoryFile`` with the GTiff
    driver, the dataset is read in full, and length-1 axes are dropped from
    the result with ``squeeze()``.
    """
    with MemoryFile(tif_bytes) as memory_file, \
            memory_file.open(driver='GTiff') as dataset:
        pixels = dataset.read()
        return pixels.squeeze()
def read_row_memory(row_dict, columns=None):
    """Fetch selected columns of one Parquet row group and decode them.

    Args:
        row_dict: Mapping with ``'parquet_url'`` (location of the Parquet
            file, opened through fsspec) and ``'parquet_row'``. The latter is
            passed to ``read_row_group`` as the row-group index.
        columns: Column names to read. Defaults to ``["thumbnail"]``.

    Returns:
        Dict mapping each requested column name to its decoded value:
        ``'thumbnail'`` is opened as a PIL image; every other column is
        decoded with ``read_tif_bytes``.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if columns is None:
        columns = ["thumbnail"]

    url = row_dict['parquet_url']
    row_idx = row_dict['parquet_row']

    fs_options = {
        "cache_type": "readahead",
        "block_size": 5 * 1024 * 1024,
    }

    with fsspec.open(url, mode='rb', **fs_options) as f:
        with pq.ParquetFile(f) as pf:
            table = pf.read_row_group(row_idx, columns=columns)

    # The table is fully materialised at this point, so decoding does not
    # need the remote file handle to remain open.
    row_output = {}
    for col in columns:
        # Only the first row of the group is used.
        # NOTE(review): presumably each row group holds a single row - confirm.
        col_data = table[col][0].as_py()
        if col != 'thumbnail':
            row_output[col] = read_tif_bytes(col_data)
        else:
            stream = BytesIO(col_data)
            row_output[col] = Image.open(stream)
    return row_output
def _to_modelscope_url(url):
    """Rewrite a huggingface.co / hf-mirror.com URL to modelscope.cn; return other URLs unchanged."""
    for domain in ('huggingface.co', 'hf-mirror.com'):
        if domain in url:
            return (
                url.replace(f'https://{domain}', 'https://modelscope.cn')
                .replace('resolve/main', 'resolve/master')
            )
    return url


def _bands_to_uint8_rgb(rgb_img):
    """Scale a raw band stack to uint8: ``2.5 * value / 10000`` clipped to [0, 1], times 255."""
    rgb_norm = (2.5 * (rgb_img.astype(float) / 10000.0)).clip(0, 1)
    return (rgb_norm * 255).astype(np.uint8)


def download_and_process_image(product_id, df_source=None, verbose=True):
    """Fetch the B04/B03/B02 bands of a product and build RGB preview images.

    Args:
        product_id: Value looked up in the ``'product_id'`` column of
            ``df_source``; the first matching row is used.
        df_source: DataFrame holding at least ``'product_id'`` and
            ``'parquet_url'`` columns (``read_row_memory`` additionally reads
            ``'parquet_row'`` from the selected row).
        verbose: When true, print progress, statistics and error messages.

    Returns:
        ``(img_384, img_full)`` as PIL images on success, where ``img_384``
        is the 384x384 centre crop (or a resize of the full image when it is
        smaller than 384 pixels along either axis). ``(None, None)`` on any
        failure; exceptions raised while fetching or processing are caught
        and their traceback is printed.
    """
    if df_source is None:
        if verbose:
            print("❌ Error: No DataFrame provided.")
        return None, None

    row_subset = df_source[df_source['product_id'] == product_id]
    if len(row_subset) == 0:
        if verbose:
            print(f"❌ Error: Product ID {product_id} not found in DataFrame.")
        return None, None

    row_dict = row_subset.iloc[0].to_dict()

    if 'parquet_url' not in row_dict:
        if verbose:
            print("❌ Error: 'parquet_url' missing in metadata.")
        return None, None

    # Data is fetched from ModelScope instead of Hugging Face or its mirror.
    row_dict['parquet_url'] = _to_modelscope_url(row_dict['parquet_url'])

    if verbose:
        print(f"⬇️ Fetching data for {product_id} from {row_dict['parquet_url']}...")

    try:
        band_names = ['B04', 'B03', 'B02']
        bands_data = read_row_memory(row_dict, columns=band_names)

        if not all(b in bands_data for b in band_names):
            if verbose:
                print(f"❌ Error: Missing bands in fetched data for {product_id}")
            return None, None

        # Stack in B04, B03, B02 order along a new last axis.
        rgb_img = np.stack([bands_data[b] for b in band_names], axis=-1)

        if verbose:
            print(f"Raw RGB stats: Min={rgb_img.min()}, Max={rgb_img.max()}, Mean={rgb_img.mean()}, Dtype={rgb_img.dtype}")

        rgb_uint8 = _bands_to_uint8_rgb(rgb_img)

        if verbose:
            print(f"Processed RGB stats: Min={rgb_uint8.min()}, Max={rgb_uint8.max()}, Mean={rgb_uint8.mean()}")

        img_full = Image.fromarray(rgb_uint8)

        if rgb_uint8.shape[0] >= 384 and rgb_uint8.shape[1] >= 384:
            cropped_array = crop_center(rgb_uint8, 384, 384)
            img_384 = Image.fromarray(cropped_array)
        else:
            if verbose:
                print(f"⚠️ Image too small {rgb_uint8.shape}, resizing to 384x384.")
            img_384 = img_full.resize((384, 384))

        if verbose:
            print(f"✅ Successfully processed {product_id}")
        return img_384, img_full

    except Exception as e:
        if verbose:
            print(f"❌ Error processing {product_id}: {e}")
        # The traceback is printed regardless of ``verbose``.
        import traceback
        traceback.print_exc()
        return None, None
# Esri World Imagery tile source
class EsriImagery(cimgt.GoogleTiles):
    """Tile source that points ``cimgt.GoogleTiles`` at Esri's World Imagery MapServer."""

    def _image_url(self, tile):
        """Return the tile URL for ``tile``, given as an ``(x, y, z)`` triple."""
        col, row, level = tile
        # The URL path is ordered zoom / y / x.
        return f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{level}/{row}/{col}'
from PIL import Image, ImageDraw, ImageFont
def get_placeholder_image(text="Image Unavailable", size=(384, 384)):
    """Create a light-grey placeholder image carrying a short message.

    Args:
        text: Message drawn in black on the image.
        size: ``(width, height)`` of the image in pixels.

    Returns:
        A new RGB PIL image filled with (200, 200, 200) with ``text`` drawn
        starting 20 pixels from the left edge at half the image height.
    """
    img = Image.new('RGB', size, color=(200, 200, 200))
    d = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        # Best effort: draw with font=None rather than fail. Catching
        # Exception (not a bare except) keeps KeyboardInterrupt and
        # SystemExit propagating.
        font = None

    # The text is anchored at a fixed offset, not truly centred; exact
    # centring would require measuring the rendered text.
    d.text((20, size[1] // 2), text, fill=(0, 0, 0), font=font)
    return img
def get_esri_satellite_image(lat, lon, score=None, rank=None, query=None):
    """
    Render an Esri World Imagery map centred on (lat, lon) as a PIL image.

    The map spans 0.05 degrees on each side of the point, uses zoom level 14
    tiles, marks the centre with a yellow '+', and outlines a 3840 m x 3840 m
    box in red. ``query``, ``rank`` and ``score`` (when given) become title
    lines. The figure is built with the object-oriented Matplotlib API
    (``Figure`` + ``FigureCanvasAgg``) rather than pyplot.

    On any exception a message is printed and a placeholder image is
    returned instead.
    """
    try:
        tiles = EsriImagery()

        # Object-oriented figure; attaching the Agg canvas is done for its side effect.
        figure = Figure(figsize=(5, 5), dpi=100)
        FigureCanvasAgg(figure)
        axes = figure.add_subplot(1, 1, 1, projection=tiles.crs)

        # Roughly 10 km x 10 km around the point
        half_span_deg = 0.05
        west, east = lon - half_span_deg, lon + half_span_deg
        south, north = lat - half_span_deg, lat + half_span_deg
        axes.set_extent([west, east, south, north], crs=ccrs.PlateCarree())

        # Imagery layer at zoom level 14
        axes.add_image(tiles, 14)

        # Centre marker
        axes.plot(
            lon,
            lat,
            marker='+',
            color='yellow',
            markersize=12,
            markeredgewidth=2,
            transform=ccrs.PlateCarree(),
        )

        # Bounding box: 384 * 10 = 3840 m per side, converted to degrees
        # with the approximation 1 deg lat = 111320 m and
        # 1 deg lon = 111320 m * cos(lat).
        side_m = 384 * 10
        height_deg = side_m / 111320
        width_deg = side_m / (111320 * math.cos(math.radians(lat)))
        lower_left = (lon - width_deg / 2, lat - height_deg / 2)
        outline = Rectangle(
            lower_left,
            width_deg,
            height_deg,
            linewidth=2,
            edgecolor='red',
            facecolor='none',
            transform=ccrs.PlateCarree(),
        )
        axes.add_patch(outline)

        # Title: one line per piece of supplied metadata
        caption_lines = []
        if query:
            caption_lines.append(f"{query}")
        if rank is not None:
            caption_lines.append(f"Rank {rank}")
        if score is not None:
            caption_lines.append(f"Score: {score:.4f}")
        axes.set_title("\n".join(caption_lines), fontsize=10)

        # Render to an in-memory PNG and hand it back as a PIL image
        png_buffer = BytesIO()
        figure.savefig(png_buffer, format='png', bbox_inches='tight')
        png_buffer.seek(0)
        return Image.open(png_buffer)

    except Exception as e:
        # Known network failures get a short warning to avoid log spam;
        # anything else is reported with the exception text.
        error_msg = str(e)
        network_markers = (
            "Connection reset by peer",
            "Network is unreachable",
            "urlopen error",
        )
        if any(marker in error_msg for marker in network_markers):
            print(f"⚠️ Network warning: Could not fetch Esri satellite map for ({lat:.4f}, {lon:.4f}). Server might be offline.")
        else:
            print(f"Error generating Esri image for {lat}, {lon}: {e}")

        # Fall back to a placeholder image carrying the coordinates
        return get_placeholder_image(f"Map Unavailable\n({lat:.2f}, {lon:.2f})")
def get_esri_satellite_image_url(lat, lon, zoom=14):
    """
    Returns the URL for the Esri World Imagery tile at the given location.

    The tile indices are computed with the standard Web-Mercator ("slippy
    map") formulas and inserted into the World Imagery MapServer tile path
    as ``{zoom}/{y}/{x}``.

    Args:
        lat: Latitude in degrees. Values beyond the Web-Mercator limit
            (about +/-85.0511) are clamped to it.
        lon: Longitude in degrees. Resulting tile columns outside the valid
            range are clamped to the first/last column.
        zoom: Tile zoom level (non-negative integer), 14 by default.

    Returns:
        The tile URL as a string, or ``None`` when the inputs cannot be
        interpreted as finite numbers or ``zoom`` is negative.
    """
    try:
        lat = float(lat)
        lon = float(lon)
        zoom = int(zoom)
    except (TypeError, ValueError):
        return None
    if zoom < 0 or not (math.isfinite(lat) and math.isfinite(lon)):
        return None

    # Web Mercator is undefined at the poles; clamp to its usual limit.
    max_lat = 85.05112878
    lat = max(-max_lat, min(max_lat, lat))

    tiles_per_axis = 2 ** zoom
    x = math.floor((lon + 180.0) / 360.0 * tiles_per_axis)
    y = math.floor(
        (1.0 - math.asinh(math.tan(math.radians(lat))) / math.pi)
        / 2.0
        * tiles_per_axis
    )

    # lon == 180 or a clamped latitude can land exactly on the far edge;
    # keep the indices within [0, tiles_per_axis - 1].
    x = min(max(x, 0), tiles_per_axis - 1)
    y = min(max(y, 0), tiles_per_axis - 1)

    return f'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{zoom}/{y}/{x}'