Spaces:
Paused
Paused
| """Script to get BEV images from a dataset of locations. | |
| Example usage: | |
| python3.9 -m mia.bev.get_bev | |
| """ | |
| import argparse | |
| import multiprocessing as mp | |
| from pathlib import Path | |
| import io | |
| import os | |
| import requests | |
| import contextlib | |
| import traceback | |
| import colour | |
| import numpy as np | |
| from matplotlib import pyplot as plt | |
| import pandas as pd | |
| import geopandas as gpd | |
| import torch.nn as nn | |
| import torch | |
| from tqdm import tqdm | |
| from filelock import FileLock | |
| from math import sqrt, ceil | |
| import svgwrite | |
| import cairosvg | |
| from PIL import Image | |
| from xml.etree import ElementTree as ET | |
| from pyproj.transformer import Transformer | |
| from shapely.geometry import box | |
| from omegaconf import OmegaConf | |
| import urllib3 | |
| from map_machine.map_configuration import MapConfiguration | |
| from map_machine.scheme import Scheme | |
| from map_machine.geometry.boundary_box import BoundaryBox | |
| from map_machine.osm.osm_getter import NetworkError | |
| from map_machine.osm.osm_reader import OSMData | |
| from map_machine.geometry.flinger import MercatorFlinger | |
| from map_machine.pictogram.icon import ShapeExtractor | |
| from map_machine.workspace import workspace | |
| from map_machine.mapper import Map | |
| from map_machine.constructor import Constructor | |
| from .. import logger | |
| from .image import center_crop_to_size, center_pad | |
# MUST match colors from map rendering style
# NOTE: the key order is significant. draw_bev renders one image per entry
# while iterating this dict, and mask2rgb derives each key's mask channel
# index from its position in the dict.
COLORS = {
    "road": "#000",
    "crossing": "#F00",
    "explicit_pedestrian": "#FF0",
    "park": "#0F0",
    "building": "#F0F",
    "water": "#00F",
    "terrain": "#0FF",
    "parking": "#AAA",
    "train": "#555"
}
# While the color mapping above must match what is in the
# rendering style, the pretty colors below are just for visualization
# purposes and can easily be changed below without worrying.
# Colors set to None will not be rendered in rendered masks
# NOTE: keep the keys in the same order as COLORS; with pretty=True mask2rgb
# maps keys to mask channels using this dict's key order.
PRETTY_COLORS = {
    "road": "#444",
    "crossing": "#F4A261",
    "explicit_pedestrian": "#E9C46A",
    "park": None,
    "building": "#E76F51",
    "water": None,
    "terrain": "#2A9D8F",
    "parking": "#CCC",
    "train": None
}
# Better order for visualization
# (mask2rgb paints classes in this order, so later entries overwrite earlier ones)
VIS_ORDER = ["terrain", "water", "park", "parking", "train",
             "road", "explicit_pedestrian", "crossing", "building"]
def checkColor(code):
    """Build a predicate that tells whether an SVG element is painted with ``code``.

    :param code: colour specification understood by ``colour.Color``
    :return: function taking an element exposing an ``attribs`` mapping and
        returning True when its ``stroke`` or ``fill`` equals ``code``
    """
    def check_ele(ele):
        matched = False
        # Both paint attributes are always inspected (no early exit);
        # the literal value "none" means the element is not painted that way.
        for attr_name in ("stroke", "fill"):
            if attr_name not in ele.attribs:
                continue
            paint = ele.attribs[attr_name]
            if paint == "none":
                continue
            if colour.Color(paint) == colour.Color(code):
                matched = True
        return matched
    return check_ele
def hex2rgb(hex_str):
    """Convert a ``#RGB`` or ``#RRGGBB`` colour string into an ``(r, g, b)`` tuple of ints.

    Leading ``#`` characters are optional.
    """
    digits = hex_str.lstrip('#')
    if len(digits) == 3:
        # Shorthand notation: every digit is doubled ("0F0" -> "00FF00").
        digits = "".join(ch * 2 for ch in digits)
    red, green, blue = (int(digits[pos:pos + 2], 16) for pos in range(0, 6, 2))
    return (red, green, blue)
def mask2rgb(mask, pretty=True):
    """Render a multi-channel semantic mask as an RGB image on a white background.

    :param mask: array of shape (H, W, N); channel ``i`` belongs to the ``i``-th
        key of the selected colour map
    :param pretty: use PRETTY_COLORS when True, COLORS otherwise
    :return: uint8 array of shape (H, W, 3)
    """
    height, width, n_channels = mask.shape
    palette = PRETTY_COLORS if pretty else COLORS
    channel_of = {name: idx for idx, name in zip(range(n_channels), palette)}
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)
    # Paint in VIS_ORDER so that later classes are drawn on top of earlier ones.
    for name in VIS_ORDER:
        hex_code = palette[name]
        if not hex_code:
            # Classes without a colour are left out of the rendering.
            continue
        active = mask[:, :, channel_of[name]] > 0.5
        canvas[active] = np.array(hex2rgb(hex_code))
    return canvas
def draw_bev(bbox: BoundaryBox, osm_data: OSMData,
             configuration: MapConfiguration, meters_per_pixel: float, heading: float):
    """Rasterize OSM data as per-class BEV images.

    One image is rendered for every entry of COLORS (in dict order), keeping
    only the SVG elements whose stroke or fill equals that entry's colour.

    :param bbox: geographic area to draw
    :param osm_data: parsed OSM data
    :param configuration: map_machine drawing configuration
    :param meters_per_pixel: target ground resolution, used to derive the zoom level
    :param heading: angle in degrees; every element is rotated by
        ``360 - heading`` about the centre of the drawing
    :return: ``(images, drawing)``, or ``(None, None)`` when rendering raised
    """
    latitude = bbox.center()[0]
    equator_length = osm_data.equator_length
    # Equation rearranged from https://wiki.openstreetmap.org/wiki/Zoom_levels
    # to obtain the zoom level for a given meters_per_pixel.
    zoom = np.log2(np.abs(equator_length * np.cos(np.deg2rad(latitude)) / meters_per_pixel / 256))
    flinger = MercatorFlinger(bbox, zoom, equator_length)
    canvas_size = flinger.size
    # No file name is given because the drawing is not saved from here.
    drawing: svgwrite.Drawing = svgwrite.Drawing(None, canvas_size)
    constructor: Constructor = Constructor(
        osm_data=osm_data,
        flinger=flinger,
        extractor=ShapeExtractor(workspace.ICONS_PATH, workspace.ICONS_CONFIG_PATH),
        configuration=configuration,
    )
    constructor.construct()
    renderer: Map = Map(flinger=flinger, svg=drawing, configuration=configuration)
    try:
        renderer.draw(constructor)
        # Rotate every element about the centre of the canvas.
        pivot = (canvas_size[0] / 2, canvas_size[1] / 2)
        for element in drawing.elements:
            element.rotate(360 - heading, pivot)
        layers = []
        for class_color in COLORS.values():
            # Keep only the elements drawn in this class' colour, then rasterize.
            layer_svg = drawing.copy()
            keep = checkColor(class_color)
            layer_svg.elements = [el for el in layer_svg.elements if keep(el)]
            png_bytes = cairosvg.svg2png(bytestring=layer_svg.tostring(),
                                         output_width=canvas_size[0],
                                         output_height=canvas_size[1])
            layers.append(Image.open(io.BytesIO(png_bytes)))
    except Exception as err:
        logger.error(f"Failed to render BEV for bbox {bbox.get_format()}. Exception: {repr(err)}. Skipping.. Stack trace: {traceback.format_exc()}")
        return None, None
    return layers, drawing
def process_img(img, num_pixels, heading=None):
    """Rotate + crop an image to correct for heading and ensure correct dimensions.

    :param img: PIL image
    :param num_pixels: side length of the returned square image
    :param heading: rotation angle passed to ``Image.rotate``; falsy values
        (``None`` or ``0``) skip the rotation
    :return: image centre-cropped to ``num_pixels`` x ``num_pixels``
    """
    padded = center_pad(img, num_pixels, num_pixels)
    # Make the image square before rotating (perhaps not needed).
    side = min(padded.size)
    square = center_crop_to_size(padded, side, side)
    if heading:
        square = square.rotate(heading, expand=False, resample=Image.Resampling.BILINEAR)
    return center_crop_to_size(square, num_pixels, num_pixels)
def get_satellite_from_bbox(bbox, output_fp, num_pixels, heading):
    """Download a satellite thumbnail for ``bbox`` and save it to ``output_fp``.

    :param bbox: rectangle coordinates handed to ``ee.Geometry.Rectangle``
        with ``proj="EPSG:4326"``
    :param output_fp: destination path of the saved image
    :param num_pixels: side length of the saved square image
    :param heading: rotation angle forwarded to ``process_img``
    :raises requests.HTTPError: when the thumbnail request does not succeed
    """
    # Imported here because the module only imports Earth Engine under its
    # ``__main__`` guard; without this the function fails with NameError when
    # the module is imported instead of executed.
    # NOTE(review): assumes ee.Initialize() was already called by the caller.
    import ee

    # TODO: This method does not always produce a full satellite image.
    # We need something more consistent like mapbox but free.
    region = ee.Geometry.Rectangle(bbox, proj="EPSG:4326", geodesic=False)
    # Load a satellite image collection, filter it by date and region, then select the first image
    image = ee.ImageCollection('USDA/NAIP/DOQQ') \
        .filterBounds(region) \
        .filterDate('2022-01-01', '2022-12-31') \
        .sort('CLOUDY_PIXEL_PERCENTAGE') \
        .first().select(['R', 'G', 'B'])
    # Reproject the image to a common projection (e.g., EPSG:4326)
    image = image.reproject(crs='EPSG:4326', scale=0.5)
    # Get the image URL
    url = image.getThumbURL({'min': 0, 'max': 255, 'region': region.getInfo()['coordinates']})
    # Download the image. A timeout keeps a stalled connection from hanging the
    # whole run, and the status check surfaces HTTP errors directly instead of
    # as an image decoding failure on an error page.
    response = requests.get(url, timeout=120)
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    robot_cropped_bev_img = process_img(img, num_pixels, heading)
    robot_cropped_bev_img.save(output_fp)
def get_data(address: str, parameters: dict[str, str]) -> bytes:
    """
    Perform an HTTP GET request and return the response body.

    Connection-level failures (``MaxRetryError``) are retried up to 50 times;
    a response with a status other than 200 aborts immediately.

    :param address: URL without parameters
    :param parameters: URL parameters
    :return: raw response body
    :raises NetworkError: on a non-200 response or when every attempt failed
    """
    max_attempts = 50
    urllib3.disable_warnings()
    # A single pool is shared by all attempts and always released in `finally`.
    http = urllib3.PoolManager()
    last_error = None
    try:
        for _ in range(max_attempts):
            try:
                result = http.request("GET", address, fields=parameters)
            except urllib3.exceptions.MaxRetryError as e:
                last_error = e
                continue
            if result.status != 200:
                logger.error(f"Unexpected response body: {result.data}")
                raise NetworkError(f"Cannot download data: {result.status} {result.reason}")
            return result.data
        # Every attempt failed before a response was received. Report it as a
        # NetworkError instead of failing on an unbound `result`.
        raise NetworkError(
            f"Cannot download data: connection failed after {max_attempts} attempts."
        ) from last_error
    finally:
        http.clear()
def get_osm_data(bbox: BoundaryBox, osm_output_fp: Path,
                 overwrite=False, use_lock=False) -> "tuple[OSMData | None, int]":
    """
    Get OSM data within a bounding box using the overpass APIs (or the cached
    file at osm_output_fp) and write downloaded data to osm_output_fp.

    :param bbox: area to fetch
    :param osm_output_fp: cache file that is read from / written to
    :param overwrite: ignore an existing cache file and download again
    :param use_lock: guard cache access with a file lock
    :return: ``(osm_data, attempts)``; ``osm_data`` is None when every attempt failed
    """
    OVERPASS_ENDPOINTS = [
        "http://overpass-api.de/api/map",
        "http://overpass.kumi.systems/api/map",
        "http://maps.mail.ru/osm/tools/overpass/api/map"
    ]
    RETRIES = 10
    osm_data = None
    overpass_endpoints_i = 0
    # Becomes True once the cached file turned out to be unusable, so that the
    # next attempt downloads fresh data instead of re-reading the same file.
    force_download = overwrite
    for retry in range(RETRIES):
        read_from_cache = False
        try:
            # fetch or load from cache
            # A lock is needed if we are using multiple processes without store_osm_per_id
            # Since multiple workers may share the same cached OSM file.
            # Note: Can optimize locking further by implementing a readers-writer lock scheme
            if use_lock:
                lock_fp = osm_output_fp.parent.parent / (osm_output_fp.parent.name + "_tmp_locks") / (osm_output_fp.name + ".lock")
                # The lock directory is not created anywhere else.
                lock_fp.parent.mkdir(parents=True, exist_ok=True)
                lock = FileLock(lock_fp)
            else:
                lock = contextlib.nullcontext()
            with lock:
                if not force_download and osm_output_fp.is_file():
                    read_from_cache = True
                    xml_str = osm_output_fp.read_text(encoding="utf-8")
                else:
                    content: bytes = get_data(
                        address=OVERPASS_ENDPOINTS[overpass_endpoints_i],
                        parameters={"bbox": bbox.get_format()}
                    )
                    xml_str = content.decode("utf-8")
                    if not content.startswith(b"<?xml"):
                        raise Exception(f"Invalid content received: '{xml_str}'")
                    osm_output_fp.write_bytes(content)
            # parse OSM xml string
            tree = ET.fromstring(xml_str)
            parsed = OSMData()
            parsed.parse_osm(tree, parse_nodes=True,
                             parse_relations=True, parse_ways=True)
            # Only publish the result once parsing succeeded, so a failure on
            # the last attempt cannot leak a partially parsed object.
            osm_data = parsed
            break
        except Exception as e:
            if read_from_cache:
                force_download = True
            msg = f"Error: Unable to fetch OSM data for bbox {bbox.get_format()} "\
                  f"for file {osm_output_fp} after {retry+1}/{RETRIES} attempts. Exception: {repr(e)}."
            if retry < RETRIES-1:
                overpass_endpoints_i = (overpass_endpoints_i+1) % len(OVERPASS_ENDPOINTS)
                logger.error(f"{msg}. Retrying with {OVERPASS_ENDPOINTS[overpass_endpoints_i]} endpoint..")
                continue
            else:
                logger.error(f"{msg}. Skipping..")
                break
    return osm_data, retry+1
def get_bev_from_bbox(
    bbox: BoundaryBox,
    num_pixels: int,
    meters_per_pixel: float,
    configuration: MapConfiguration,
    osm_output_fp: Path,
    bev_output_fp: Path,
    mask_output_fp: Path,
    rendered_mask_output_fp: Path,
    osm_data: OSMData=None,
    heading: float=0,
    final_downsample: int=1,
    download_osm_only: bool=False,
    use_osm_cache_lock: bool=False,
) -> None:
    """Get BEV image from a boundary box. Optionally rotate, crop and save the extracted semantic mask.

    :param bbox: area to render
    :param num_pixels: side length of the cropped per-class images
    :param meters_per_pixel: ground resolution used for rendering
    :param configuration: map_machine drawing configuration
    :param osm_output_fp: OSM cache file used when ``osm_data`` is None
    :param bev_output_fp: where to save the svg; falsy disables saving
    :param mask_output_fp: where to save the compressed npz mask; falsy disables saving
    :param rendered_mask_output_fp: where to save the RGB visualization (always
        written with a ``.png`` suffix); falsy disables saving
    :param osm_data: already parsed OSM data; loaded or downloaded when None
    :param heading: rotation handed to ``draw_bev``
    :param final_downsample: max-pooling factor applied to the mask when > 1
    :param download_osm_only: stop once the OSM data is available
    :param use_osm_cache_lock: forwarded to ``get_osm_data`` as ``use_lock``

    The saved mask has shape (H, W, N); it is boolean when ``final_downsample``
    is 1 or less and float32 after max pooling.
    """
    if osm_data is None:
        if osm_output_fp.is_file():
            # Load from cache
            try:
                xml_str = osm_output_fp.read_text(encoding="utf-8")
                tree = ET.fromstring(xml_str)
                cached = OSMData()
                cached.parse_osm(tree, parse_nodes=True,
                                 parse_relations=True, parse_ways=True)
                osm_data = cached
            except Exception as e:
                # The cached file is unusable. Ask for a fresh download
                # (overwrite=True), otherwise get_osm_data would simply read
                # the same file again.
                logger.warning(f"Failed to load cached OSM file {osm_output_fp}. Exception: {repr(e)}. Downloading again.")
                osm_data, _ = get_osm_data(bbox, osm_output_fp, overwrite=True,
                                           use_lock=use_osm_cache_lock)
        else:
            # No local osm planet dump file. Need to download or read from cache
            osm_data, _ = get_osm_data(bbox, osm_output_fp, use_lock=use_osm_cache_lock)
    if osm_data is None:
        return
    if download_osm_only:
        return
    imgs, svg = draw_bev(bbox, osm_data, configuration, meters_per_pixel, heading)
    if imgs is None:
        return
    if bev_output_fp:
        svg.saveas(bev_output_fp)
    # Set heading to None because we already rotated in draw_bev
    cropped_imgs = [process_img(img, num_pixels, heading=None) for img in imgs]
    # A pixel belongs to a class when the last image channel is non-zero
    # (presumably the alpha channel of the rendered PNG - TODO confirm).
    masks = [np.array(img)[..., -1] != 0 for img in cropped_imgs]
    extracted_mask = np.stack(masks, axis=0)
    # Channels follow the order of COLORS: clear channel 2
    # ("explicit_pedestrian") wherever channel 0 ("road") is set.
    extracted_mask[2][extracted_mask[0]] = 0
    if final_downsample > 1:
        max_pool_layer = nn.MaxPool2d(kernel_size=final_downsample, stride=final_downsample)
        # Apply max pooling
        mask_tensor = torch.tensor(extracted_mask, dtype=torch.float32).unsqueeze(0)
        max_pool_tensor = max_pool_layer(mask_tensor)
        # Remove the batch dimension and move channels last, then convert to numpy
        multilabel_mask_downsampled = max_pool_tensor.squeeze(0).permute(1, 2, 0).numpy()
    else:
        multilabel_mask_downsampled = extracted_mask.transpose(1, 2, 0)
    # Save npz files for semantic masks
    if mask_output_fp:
        np.savez_compressed(mask_output_fp, multilabel_mask_downsampled)
    # Save rendered BEV map if we want for visualization
    if rendered_mask_output_fp:
        rgb = mask2rgb(multilabel_mask_downsampled)
        plt.imsave(rendered_mask_output_fp.with_suffix('.png'), rgb)
def get_bev_from_bbox_worker_init(osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir,
                                  scheme_path, map_length, meters_per_pixel,
                                  osm_data, redownload, download_osm_only, store_osm_per_id,
                                  use_osm_cache_lock, final_downsample):
    """Initialize the per-process state read by ``get_bev_from_bbox_worker``.

    Every argument is stored under its own name in the module-level
    ``worker_kwargs`` dict, together with a freshly built map configuration
    under the key ``"configuration"``.
    """
    global worker_kwargs
    worker_kwargs = dict(
        osm_cache_dir=osm_cache_dir,
        bev_dir=bev_dir,
        semantic_mask_dir=semantic_mask_dir,
        rendered_mask_dir=rendered_mask_dir,
        scheme_path=scheme_path,
        map_length=map_length,
        meters_per_pixel=meters_per_pixel,
        osm_data=osm_data,
        redownload=redownload,
        download_osm_only=download_osm_only,
        store_osm_per_id=store_osm_per_id,
        use_osm_cache_lock=use_osm_cache_lock,
        final_downsample=final_downsample,
    )
    # MapConfiguration is not picklable so we have to initialize it for each worker
    map_configuration = MapConfiguration(Scheme.from_file(Path(scheme_path)))
    map_configuration.show_credit = False
    worker_kwargs["configuration"] = map_configuration
    logger.info(f"Worker {os.getpid()} started.")
def get_bev_from_bbox_worker(job_dict):
    """Process one job: derive the output paths and render the BEV unless already done.

    :param job_dict: mapping with the keys ``'id'``, ``'bbox_formatted'`` and
        ``'computed_compass_angle'``

    Settings come from the module-level ``worker_kwargs`` filled in by
    ``get_bev_from_bbox_worker_init``.
    """
    sample_id = job_dict['id']
    bbox = BoundaryBox.from_text(job_dict['bbox_formatted'])
    heading = job_dict['computed_compass_angle']

    # Setting a directory to None disables storing that file
    bev_dir = worker_kwargs["bev_dir"]
    bev_fp = bev_dir / f"{sample_id}.svg" if bev_dir else bev_dir
    semantic_mask_dir = worker_kwargs["semantic_mask_dir"]
    semantic_mask_fp = semantic_mask_dir / f"{sample_id}.npz" if semantic_mask_dir else semantic_mask_dir
    rendered_mask_dir = worker_kwargs["rendered_mask_dir"]
    rendered_mask_fp = rendered_mask_dir / f"{sample_id}.png" if rendered_mask_dir else rendered_mask_dir

    # The OSM cache file is keyed either by sample id or by bounding box.
    osm_file_stem = sample_id if worker_kwargs["store_osm_per_id"] else bbox.get_format()
    osm_output_fp = worker_kwargs["osm_cache_dir"] / f"{osm_file_stem}.osm"

    # Nothing to do when every requested output already exists,
    # unless a redownload was requested.
    outputs_ready = all(fp is None or fp.exists()
                        for fp in (bev_fp, semantic_mask_fp, rendered_mask_fp))
    if outputs_ready and not worker_kwargs["redownload"]:
        return

    get_bev_from_bbox(
        bbox=bbox,
        num_pixels=worker_kwargs["map_length"],
        meters_per_pixel=worker_kwargs["meters_per_pixel"],
        configuration=worker_kwargs["configuration"],
        osm_output_fp=osm_output_fp,
        bev_output_fp=bev_fp,
        mask_output_fp=semantic_mask_fp,
        rendered_mask_output_fp=rendered_mask_fp,
        osm_data=worker_kwargs["osm_data"],
        heading=heading,
        final_downsample=worker_kwargs["final_downsample"],
        download_osm_only=worker_kwargs["download_osm_only"],
        use_osm_cache_lock=worker_kwargs["use_osm_cache_lock"],
    )
def main(dataset_dir, locations, args):
    """Generate BEV data for every location of a dataset.

    For each location the image metadata parquet file is read, a bounding box
    is computed around every image position and the BEV outputs are produced,
    either sequentially or with a pool of worker processes.

    :param dataset_dir: root directory containing one sub-directory per location
    :param locations: iterable of mappings providing at least a ``"name"`` entry
    :param args: options object; the attributes read here are
        ``store_all_steps``, ``n_samples``, ``padding``, ``map_length``,
        ``meters_per_pixel``, ``osm_fp``, ``one_big_osm``, ``redownload``,
        ``n_workers``, ``store_osm_per_id``, ``map_machine_scheme``,
        ``download_osm_only``, ``final_downsample`` and ``store_sat``
    :raises ValueError: for an unsupported ``osm_fp`` suffix or when no point
        lies inside the bounds of the provided OSM file
    """
    # setup directory paths
    dataset_dir = Path(dataset_dir)
    for loc in locations:
        loc_name = loc["name"].lower().replace(" ", "_")
        location_dir = dataset_dir / loc_name
        osm_cache_dir = location_dir / "osm_cache"
        bev_dir = location_dir / "bev_raw" if args.store_all_steps else None
        semantic_mask_dir = location_dir / "semantic_masks"
        rendered_mask_dir = location_dir / "rendered_semantic_masks" if args.store_all_steps else None
        for d in [location_dir, osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir]:
            if d:
                d.mkdir(parents=True, exist_ok=True)

        # read the parquet file
        parquet_fp = location_dir / "image_metadata_filtered_processed.parquet"
        logger.info(f"Reading parquet file from {parquet_fp}.")
        df = pd.read_parquet(parquet_fp)
        if args.n_samples > 0:  # If -1, use all samples
            # Sampling without replacement cannot return more rows than exist.
            n_samples = min(args.n_samples, len(df))
            logger.info(f"Sampling {n_samples} rows.")
            df = df.sample(n_samples, replace=False, random_state=1)
            df.reset_index(drop=True, inplace=True)
        logger.info(f"Read {len(df)} rows from the parquet file.")

        # convert pandas dataframe to geopandas dataframe
        gdf = gpd.GeoDataFrame(df,
                               geometry=gpd.points_from_xy(
                                   df['computed_geometry.long'],
                                   df['computed_geometry.lat']),
                               crs=4326)
        # convert the geopandas dataframe to UTM
        utm_crs = gdf.estimate_utm_crs()
        gdf_utm = gdf.to_crs(utm_crs)
        transformer = Transformer.from_crs(utm_crs, 4326)
        logger.info(f"UTM zone for {loc_name} is {utm_crs.to_epsg()}.")

        # load OSM data, if available
        padding = args.padding
        # calculate the required distance from the center to the edge of the image
        # so that the image will not be out of bounds when we rotate it
        map_length = args.map_length
        map_length = ceil(sqrt(map_length**2 + map_length**2))
        distance = map_length * args.meters_per_pixel / 2
        logger.info(f"Each image will be {map_length:.2f} x {map_length:.2f} pixels. The distance from the center to the edge is {distance:.2f} meters.")
        osm_data = None
        if args.osm_fp:
            logger.info(f"Loading OSM data from {args.osm_fp}.")
            osm_fp = Path(args.osm_fp)
            osm_data = OSMData()
            if osm_fp.suffix == '.osm':
                osm_data.parse_osm_file(osm_fp)
            elif osm_fp.suffix == '.json':
                osm_data.parse_overpass(osm_fp)
            else:
                raise ValueError(f"OSM file format {osm_fp.suffix} is not supported.")
            # make sure that the loaded osm data at least covers some points in the dataframe
            bbox = osm_data.boundary_box
            shapely_bbox = box(bbox.left, bbox.bottom, bbox.right, bbox.top)
            logger.warning("Clipping the geopandas dataframe to the OSM boundary box. May result in loss of points.")
            gdf = gpd.clip(gdf, shapely_bbox)
            if gdf.empty:
                raise ValueError("Clipped geopandas dataframe is empty. Exiting.")
            logger.info(f"Clipped geopandas dataframe is left with {len(gdf)} points.")
            # The jobs below are built from gdf_utm, so it has to be derived
            # again from the clipped frame; otherwise the clipping has no
            # effect and points outside the OSM data are still processed.
            gdf_utm = gdf.to_crs(utm_crs)
        elif args.one_big_osm:
            osm_fp = location_dir / "one_big_map.osm"
            min_long = gdf_utm.geometry.x.min() - distance - padding
            max_long = gdf_utm.geometry.x.max() + distance + padding
            min_lat = gdf_utm.geometry.y.min() - distance - padding
            max_lat = gdf_utm.geometry.y.max() + distance + padding
            # The padding is only applied to the big map, not to each bounding box.
            padding = 0
            big_bbox = transformer.transform_bounds(left=min_long, bottom=min_lat, right=max_long, top=max_lat)
            # NOTE(review): the swap below is presumably required because the
            # transformer is created without always_xy=True, so results follow
            # the EPSG:4326 axis order (lat, long) - confirm against pyproj docs.
            big_bbox = (big_bbox[1], big_bbox[0], big_bbox[3], big_bbox[2])
            big_bbox_fmt = ",".join([str(x) for x in big_bbox])
            logger.info(f"Fetching one big osm file using coordinates {big_bbox_fmt}.")
            big_bbox = BoundaryBox.from_text(big_bbox_fmt)
            osm_data, _ = get_osm_data(big_bbox, osm_fp, overwrite=args.redownload)

        # create bounding boxes for each point
        gdf_utm['bounding_box_utm_p1'] = gdf_utm.apply(lambda row: (
            row.geometry.x - distance - padding,
            row.geometry.y - distance - padding,
        ), axis=1)
        gdf_utm['bounding_box_utm_p2'] = gdf_utm.apply(lambda row: (
            row.geometry.x + distance + padding,
            row.geometry.y + distance + padding,
        ), axis=1)
        # convert the bounding box back to lat, long
        gdf_utm['bounding_box_lat_long_p1'] = gdf_utm.apply(lambda row: transformer.transform(*row['bounding_box_utm_p1']), axis=1)
        gdf_utm['bounding_box_lat_long_p2'] = gdf_utm.apply(lambda row: transformer.transform(*row['bounding_box_utm_p2']), axis=1)
        gdf_utm['bbox_min_lat'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[0])
        gdf_utm['bbox_min_long'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[1])
        gdf_utm['bbox_max_lat'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[0])
        gdf_utm['bbox_max_long'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[1])
        gdf_utm['bbox_formatted'] = gdf_utm.apply(lambda row: f"{row['bbox_min_long']},{row['bbox_min_lat']},{row['bbox_max_long']},{row['bbox_max_lat']}", axis=1)

        # iterate over the dataframe and get BEV images
        # only the id, bbox_formatted and compass angle columns are needed for the jobs
        jobs = gdf_utm[['id', 'bbox_formatted', 'computed_compass_angle']]
        jobs = jobs.to_dict(orient='records')
        use_osm_cache_lock = args.n_workers > 0 and not args.store_osm_per_id
        if use_osm_cache_lock:
            logger.info("Using osm cache locks to prevent race conditions since number of workers > 0 and store_osm_per_id is false")
        # The workers receive the requested map_length: the enlarged value
        # computed above only determines the size of the bounding boxes.
        init_args = [osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir,
                     args.map_machine_scheme,
                     args.map_length, args.meters_per_pixel,
                     osm_data, args.redownload, args.download_osm_only,
                     args.store_osm_per_id, use_osm_cache_lock, args.final_downsample]
        if args.n_workers > 0:
            logger.info(f"Launching {args.n_workers} workers to fetch BEVs for {len(jobs)} bounding boxes.")
            with mp.Pool(args.n_workers,
                         initializer=get_bev_from_bbox_worker_init,
                         initargs=init_args) as pool:
                for _ in tqdm(pool.imap_unordered(get_bev_from_bbox_worker, jobs, chunksize=16),
                              total=len(jobs), desc="Getting BEV images"):
                    pass
        else:
            get_bev_from_bbox_worker_init(*init_args)
            pbar = tqdm(jobs, desc="Getting BEV images")
            for job_dict in pbar:
                get_bev_from_bbox_worker(job_dict)

        # Download satellite images if needed
        if args.store_sat:
            logger.info("Downloading sattelite images.")
            # The directory name is kept as is so that existing outputs are still found.
            sat_dir = location_dir / "sattelite"
            sat_dir.mkdir(parents=True, exist_ok=True)
            pbar = tqdm(jobs, desc="Getting Sattelite images")
            for job_dict in pbar:
                sample_id = job_dict['id']
                sat_fp = sat_dir / f"{sample_id}.png"
                if sat_fp.exists() and not args.redownload:
                    continue
                bbox = [float(x) for x in job_dict['bbox_formatted'].split(",")]
                try:
                    get_satellite_from_bbox(bbox, sat_fp, heading=job_dict['computed_compass_angle'], num_pixels=args.map_length)
                except Exception as e:
                    logger.error(f"Failed to get sattelite image for bbox {job_dict['bbox_formatted']}. Exception {repr(e)}")
        # TODO: Post BEV retrieval filtering
        # df.to_parquet(location_dir / "image_metadata_bev_processed.parquet")
if __name__ == "__main__":
    # Script entry point: read the yaml config given by --cfg and run main().
    parser = argparse.ArgumentParser(description="Get BEV images from a dataset of locations using MapMachine.")
    parser.add_argument("--cfg", type=str, default="mia/conf/example.yaml", help="Path to config yaml file.")
    args = parser.parse_args()
    cfgs = OmegaConf.load(args.cfg)

    bev_options = cfgs.bev_options
    if bev_options.store_sat:
        if bev_options.n_workers > 0:
            logger.fatal("Satellite download is not multiprocessed yet !!")
        # Earth Engine is only imported when satellite images are requested.
        import ee
        ee.Initialize()

    banner = "=" * 80
    logger.info(banner)
    logger.info("Running get_bev.py")
    logger.info("Arguments:")
    for arg_name, arg_value in vars(args).items():
        logger.info(f"- {arg_name}: {arg_value}")
    logger.info(banner)

    main(cfgs.dataset_dir, cfgs.cities, bev_options)