Spaces:
Runtime error
Runtime error
File size: 8,342 Bytes
29fab93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import fsspec
import pyarrow.parquet as pq
import numpy as np
from PIL import Image
from io import BytesIO
from rasterio.io import MemoryFile
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt
from matplotlib.patches import Rectangle
import math
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
def crop_center(img_array, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of an image array.

    Args:
        img_array: Array whose first two axes are (rows, columns). Any
            trailing axes (for example a channel axis) are left untouched,
            so both 2-D and 3-D arrays are accepted.
        cropx: Width of the crop window, in pixels.
        cropy: Height of the crop window, in pixels.

    Returns:
        The slice of ``img_array`` centred on the image. If the requested
        window is larger than the image along an axis, the whole axis is
        returned rather than an off-centre slice.
    """
    # Only the two leading axes matter; this keeps single-band arrays working.
    height, width = img_array.shape[:2]
    # Clamp at 0: a negative start index would be interpreted as "from the
    # end" and silently yield a shifted, undersized crop.
    startx = max(width // 2 - cropx // 2, 0)
    starty = max(height // 2 - cropy // 2, 0)
    return img_array[starty:starty + cropy, startx:startx + cropx]
def read_tif_bytes(tif_bytes):
    """Decode an in-memory GeoTIFF and return its pixel data.

    Args:
        tif_bytes: Raw bytes of a GeoTIFF file.

    Returns:
        The result of the dataset's ``read()`` with length-1 axes removed
        via ``squeeze()`` (presumably a NumPy array; a single-band file
        would then come back 2-D — confirm against rasterio).
    """
    # Both the in-memory file and the dataset opened on it are closed when
    # the combined ``with`` exits.
    with MemoryFile(tif_bytes) as memory_file, \
            memory_file.open(driver='GTiff') as dataset:
        pixels = dataset.read().squeeze()
        return pixels
def read_row_memory(row_dict, columns=None):
    """Fetch selected columns of one parquet row group into memory.

    Args:
        row_dict: Mapping providing ``'parquet_url'`` (location of the
            parquet file, opened through fsspec) and ``'parquet_row'``
            (the index handed to ``ParquetFile.read_row_group``).
        columns: Iterable of column names to read. Defaults to
            ``["thumbnail"]`` when omitted.

    Returns:
        dict mapping each requested column to its decoded value: the
        ``'thumbnail'`` column is opened as a PIL image, every other column
        is decoded with ``read_tif_bytes``.
    """
    # A ``None`` sentinel avoids sharing one mutable default list between
    # calls; ``list(...)`` also lets callers pass a tuple.
    columns = ["thumbnail"] if columns is None else list(columns)

    url = row_dict['parquet_url']
    row_idx = row_dict['parquet_row']

    # Read-ahead caching with 5 MiB blocks for the remote file handle.
    fs_options = {
        "cache_type": "readahead",
        "block_size": 5 * 1024 * 1024,
    }

    with fsspec.open(url, mode='rb', **fs_options) as f:
        with pq.ParquetFile(f) as pf:
            table = pf.read_row_group(row_idx, columns=columns)

    row_output = {}
    for col in columns:
        # Only the first entry of the row group is used.
        col_data = table[col][0].as_py()
        if col != 'thumbnail':
            row_output[col] = read_tif_bytes(col_data)
        else:
            row_output[col] = Image.open(BytesIO(col_data))
    return row_output
def _to_modelscope_url(url):
    """Rewrite a huggingface.co / hf-mirror.com URL to its modelscope.cn form."""
    for host in ("huggingface.co", "hf-mirror.com"):
        if host in url:
            rewritten = url.replace(f"https://{host}", "https://modelscope.cn")
            return rewritten.replace('resolve/main', 'resolve/master')
    return url


def _bands_to_rgb_uint8(bands_data, verbose=True):
    """Stack the B04/B03/B02 bands into an 8-bit RGB array."""
    rgb_img = np.stack(
        [bands_data['B04'], bands_data['B03'], bands_data['B02']], axis=-1
    )
    if verbose:
        print(f"Raw RGB stats: Min={rgb_img.min()}, Max={rgb_img.max()}, Mean={rgb_img.mean()}, Dtype={rgb_img.dtype}")

    # Scale by 1/10000, brighten by 2.5x and clip to [0, 1] before going to
    # 8 bits.
    # NOTE(review): inputs that are already in 0-255 would come out almost
    # black here; the original had a no-op check for that case — confirm
    # whether such inputs can occur.
    rgb_norm = (2.5 * (rgb_img.astype(float) / 10000.0)).clip(0, 1)
    rgb_uint8 = (rgb_norm * 255).astype(np.uint8)
    if verbose:
        print(f"Processed RGB stats: Min={rgb_uint8.min()}, Max={rgb_uint8.max()}, Mean={rgb_uint8.mean()}")
    return rgb_uint8


def download_and_process_image(product_id, df_source=None, verbose=True):
    """Fetch the RGB bands of a product and build display images from them.

    Args:
        product_id: Value to look up in the ``'product_id'`` column.
        df_source: DataFrame holding at least ``'product_id'`` and
            ``'parquet_url'`` columns (plus whatever ``read_row_memory``
            needs). Required despite the ``None`` default.
        verbose: When true, print progress and error messages.

    Returns:
        ``(img_384, img_full)`` as PIL images — a 384x384 centre crop (or a
        resize when the image is smaller than that) and the full-size RGB
        image — or ``(None, None)`` on any failure.
    """
    if df_source is None:
        if verbose:
            print("❌ Error: No DataFrame provided.")
        return None, None

    row_subset = df_source[df_source['product_id'] == product_id]
    if row_subset.empty:
        if verbose:
            print(f"❌ Error: Product ID {product_id} not found in DataFrame.")
        return None, None

    # Only the first matching row is used.
    row_dict = row_subset.iloc[0].to_dict()

    # A missing column and a non-string cell (None/NaN) are both treated as
    # "no URL"; substring tests on a non-string would raise TypeError.
    url = row_dict.get('parquet_url')
    if not isinstance(url, str):
        if verbose:
            print("❌ Error: 'parquet_url' missing in metadata.")
        return None, None
    row_dict['parquet_url'] = _to_modelscope_url(url)

    if verbose:
        print(f"⬇️ Fetching data for {product_id} from {row_dict['parquet_url']}...")

    band_names = ['B04', 'B03', 'B02']
    target = 384
    try:
        bands_data = read_row_memory(row_dict, columns=list(band_names))
        if not all(b in bands_data for b in band_names):
            if verbose:
                print(f"❌ Error: Missing bands in fetched data for {product_id}")
            return None, None

        rgb_uint8 = _bands_to_rgb_uint8(bands_data, verbose=verbose)
        img_full = Image.fromarray(rgb_uint8)

        if rgb_uint8.shape[0] >= target and rgb_uint8.shape[1] >= target:
            img_384 = Image.fromarray(crop_center(rgb_uint8, target, target))
        else:
            if verbose:
                print(f"⚠️ Image too small {rgb_uint8.shape}, resizing to 384x384.")
            img_384 = img_full.resize((target, target))

        if verbose:
            print(f"✅ Successfully processed {product_id}")
        return img_384, img_full
    except Exception as e:
        # Best-effort boundary: any fetch/decode failure degrades to
        # (None, None). The traceback is printed regardless of ``verbose``,
        # as in the original.
        if verbose:
            print(f"❌ Error processing {product_id}: {e}")
        import traceback
        traceback.print_exc()
        return None, None
# Tile source for Esri World Imagery
class EsriImagery(cimgt.GoogleTiles):
    """Cartopy tile source whose tiles come from Esri World Imagery."""

    def _image_url(self, tile):
        """Return the tile URL for an ``(x, y, z)`` tile triple."""
        col, row, level = tile
        # The ArcGIS tile path is ordered zoom / row (y) / column (x).
        return (
            'https://server.arcgisonline.com/ArcGIS/rest/services/'
            'World_Imagery/MapServer/tile/'
            f'{level}/{row}/{col}'
        )
from PIL import Image, ImageDraw, ImageFont
def get_placeholder_image(text="Image Unavailable", size=(384, 384)):
    """Create a light-grey placeholder image carrying a short message.

    Args:
        text: Message drawn onto the image.
        size: ``(width, height)`` of the image in pixels.

    Returns:
        A new RGB PIL image with ``text`` drawn in black.
    """
    img = Image.new('RGB', size, color=(200, 200, 200))
    d = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        # Best effort: draw without an explicit font. Catching Exception
        # rather than using a bare except lets KeyboardInterrupt and
        # SystemExit propagate.
        font = None
    # Rough placement only: fixed left margin, vertically at mid-height.
    # Proper centring would need font metrics.
    d.text((20, size[1] // 2), text, fill=(0, 0, 0), font=font)
    return img
def _format_coord(value, digits):
    """Format a coordinate with fixed decimals; fall back to ``str`` if it is not numeric."""
    try:
        return f"{value:.{digits}f}"
    except (TypeError, ValueError):
        return str(value)


def get_esri_satellite_image(lat, lon, score=None, rank=None, query=None):
    """Render an Esri World Imagery map around a point as a PIL image.

    The map is drawn with Cartopy on a Matplotlib ``Figure`` created through
    the object-oriented API instead of pyplot (chosen by the original author
    for thread safety). It shows a yellow ``+`` at the point and a red
    3840 m x 3840 m box centred on it.

    Args:
        lat: Latitude of the centre point, in degrees.
        lon: Longitude of the centre point, in degrees.
        score: Optional score shown in the title with four decimals.
        rank: Optional rank shown in the title.
        query: Optional query text shown in the title.

    Returns:
        A PIL image of the rendered map. On any failure a grey placeholder
        image from ``get_placeholder_image`` is returned instead; this
        function does not raise for rendering or network errors.
    """
    try:
        imagery = EsriImagery()

        fig = Figure(figsize=(5, 5), dpi=100)
        # Constructing the canvas registers it on the figure; the object
        # itself is not needed afterwards.
        FigureCanvasAgg(fig)
        ax = fig.add_subplot(1, 1, 1, projection=imagery.crs)

        # +/-0.05 degrees around the point (roughly 10 km x 10 km).
        extent_deg = 0.05
        ax.set_extent(
            [lon - extent_deg, lon + extent_deg, lat - extent_deg, lat + extent_deg],
            crs=ccrs.PlateCarree(),
        )

        # Imagery at zoom level 14.
        ax.add_image(imagery, 14)

        # Centre marker.
        ax.plot(lon, lat, marker='+', color='yellow', markersize=12,
                markeredgewidth=2, transform=ccrs.PlateCarree())

        # Bounding box of 384 px * 10 m = 3840 m per side, converted to
        # degrees with the approximation 1 deg lat = 111320 m and
        # 1 deg lon = 111320 m * cos(lat).
        box_size_m = 384 * 10
        dlat = box_size_m / 111320
        dlon = box_size_m / (111320 * math.cos(math.radians(lat)))

        # Rectangle is anchored at its bottom-left corner.
        rect = Rectangle((lon - dlon / 2, lat - dlat / 2), dlon, dlat,
                         linewidth=2, edgecolor='red', facecolor='none',
                         transform=ccrs.PlateCarree())
        ax.add_patch(rect)

        title_parts = []
        if query:
            title_parts.append(f"{query}")
        if rank is not None:
            title_parts.append(f"Rank {rank}")
        if score is not None:
            title_parts.append(f"Score: {score:.4f}")
        ax.set_title("\n".join(title_parts), fontsize=10)

        # Render to an in-memory PNG and hand it back as a PIL image.
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        # Network failures get a short warning instead of the raw error to
        # avoid log spam.
        error_msg = str(e)
        network_markers = (
            "Connection reset by peer",
            "Network is unreachable",
            "urlopen error",
        )
        if any(marker in error_msg for marker in network_markers):
            print(f"⚠️ Network warning: Could not fetch Esri satellite map for ({_format_coord(lat, 4)}, {_format_coord(lon, 4)}). Server might be offline.")
        else:
            print(f"Error generating Esri image for {lat}, {lon}: {e}")

        # Coordinates are formatted defensively so that this fallback cannot
        # itself raise when lat/lon are not numeric (e.g. None).
        return get_placeholder_image(
            f"Map Unavailable\n({_format_coord(lat, 2)}, {_format_coord(lon, 2)})"
        )
def get_esri_satellite_image_url(lat, lon, zoom=14):
    """Placeholder for a direct Esri World Imagery tile URL lookup.

    Not implemented: the tile-coordinate calculation was never written, so
    this always returns ``None``. The static map rendering path is the
    working alternative.

    Args:
        lat: Latitude in degrees (currently unused).
        lon: Longitude in degrees (currently unused).
        zoom: Tile zoom level (currently unused).

    Returns:
        Always ``None``.
    """
    # The previous body built a tile source inside a try/except that
    # swallowed every error and never used the result; the outcome was
    # always None, which is now stated directly.
    return None
|