| | """Script to get BEV images from a dataset of locations. |
| | |
| | Example usage: |
| | python3.9 -m mia.bev.get_bev |
| | """ |
| |
|
| | import argparse |
| | import multiprocessing as mp |
| | from pathlib import Path |
| | import io |
| | import os |
| | import requests |
| | import contextlib |
| | import traceback |
| | import colour |
| |
|
| | import numpy as np |
| | from matplotlib import pyplot as plt |
| | import pandas as pd |
| | import geopandas as gpd |
| | import torch.nn as nn |
| | import torch |
| | from tqdm import tqdm |
| | from filelock import FileLock |
| | from math import sqrt, ceil |
| | import svgwrite |
| | import cairosvg |
| | from PIL import Image |
| | from xml.etree import ElementTree as ET |
| | from pyproj.transformer import Transformer |
| | from shapely.geometry import box |
| | from omegaconf import OmegaConf |
| | import urllib3 |
| |
|
| | from map_machine.map_configuration import MapConfiguration |
| | from map_machine.scheme import Scheme |
| | from map_machine.geometry.boundary_box import BoundaryBox |
| | from map_machine.osm.osm_getter import NetworkError |
| | from map_machine.osm.osm_reader import OSMData |
| | from map_machine.geometry.flinger import MercatorFlinger |
| | from map_machine.pictogram.icon import ShapeExtractor |
| | from map_machine.workspace import workspace |
| | from map_machine.mapper import Map |
| | from map_machine.constructor import Constructor |
| |
|
| | from .. import logger |
| | from .image import center_crop_to_size, center_pad |
| |
|
| | |
| | COLORS = { |
| | "road": "#000", |
| | "crossing": "#F00", |
| | "explicit_pedestrian": "#FF0", |
| | "park": "#0F0", |
| | "building": "#F0F", |
| | "water": "#00F", |
| | "terrain": "#0FF", |
| | "parking": "#AAA", |
| | "train": "#555" |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | PRETTY_COLORS = { |
| | "road": "#444", |
| | "crossing": "#F4A261", |
| | "explicit_pedestrian": "#E9C46A", |
| | "park": None, |
| | "building": "#E76F51", |
| | "water": None, |
| | "terrain": "#2A9D8F", |
| | "parking": "#CCC", |
| | "train": None |
| | } |
| |
|
| | |
| | VIS_ORDER = ["terrain", "water", "park", "parking", "train", |
| | "road", "explicit_pedestrian", "crossing", "building"] |
| |
|
| | def checkColor(code): |
| |
|
| | def check_ele(ele): |
| | isColor = False |
| | if "stroke" in ele.attribs: |
| | if ele.attribs["stroke"] != "none": |
| | color = colour.Color(ele.attribs["stroke"]) |
| | isColor |= color == colour.Color(code) |
| | |
| | if "fill" in ele.attribs: |
| | if ele.attribs["fill"] != "none": |
| | color = colour.Color(ele.attribs["fill"]) |
| | isColor |= color == colour.Color(code) |
| |
|
| | return isColor |
| | |
| | return check_ele |
| |
|
| | def hex2rgb(hex_str): |
| | hex_str = hex_str.lstrip('#') |
| | if len(hex_str) == 3: |
| | hex_str = "".join([hex_str[i//2] for i in range(6)]) |
| | return tuple(int(hex_str[i:i+2], 16) for i in (0, 2, 4)) |
| |
|
| | def mask2rgb(mask, pretty=True): |
| | H,W,N = mask.shape |
| | rgb = np.ones((H,W,3), dtype=np.uint8)*255 |
| | cmap = PRETTY_COLORS if pretty else COLORS |
| | key2mask_i = dict(zip(cmap.keys(), range(N))) |
| | for k in VIS_ORDER: |
| | if cmap[k]: |
| | rgb[mask[:,:, key2mask_i[k]]>0.5] = (np.array(hex2rgb(cmap[k]))) |
| | |
| | return rgb |
| |
|
| | def draw_bev(bbox: BoundaryBox, osm_data: OSMData, |
| | configuration: MapConfiguration, meters_per_pixel: float, heading: float): |
| | """Rasterize OSM data as a BEV image""" |
| | lat = bbox.center()[0] |
| | |
| | |
| | z = np.log2(np.abs(osm_data.equator_length*np.cos(np.deg2rad(lat))/meters_per_pixel/256)) |
| | flinger = MercatorFlinger(bbox, z, osm_data.equator_length) |
| |
|
| | size = flinger.size |
| | svg: svgwrite.Drawing = svgwrite.Drawing(None, size) |
| | |
| | icon_extractor: ShapeExtractor = ShapeExtractor( |
| | workspace.ICONS_PATH, workspace.ICONS_CONFIG_PATH |
| | ) |
| | constructor: Constructor = Constructor( |
| | osm_data=osm_data, |
| | flinger=flinger, |
| | extractor=icon_extractor, |
| | configuration=configuration, |
| | ) |
| | constructor.construct() |
| | map_: Map = Map(flinger=flinger, svg=svg, configuration=configuration) |
| | try: |
| | imgs = [] |
| |
|
| | map_.draw(constructor) |
| |
|
| | |
| | for ele in svg.elements: |
| | ele.rotate(360 - heading, (size[0]/2, size[1]/2)) |
| |
|
| | for k, v in COLORS.items(): |
| | svg_new = svg.copy() |
| | svg_new.elements = list(filter(checkColor(v), svg_new.elements)) |
| | |
| | png_byte_string = cairosvg.svg2png(bytestring=svg_new.tostring(), |
| | output_width=size[0], |
| | output_height=size[1]) |
| | img = Image.open(io.BytesIO(png_byte_string)) |
| |
|
| | imgs.append(img) |
| |
|
| | except Exception as e: |
| | |
| | stack_trace = traceback.format_exc() |
| | logger.error(f"Failed to render BEV for bbox {bbox.get_format()}. Exception: {repr(e)}. Skipping.. Stack trace: {stack_trace}") |
| | return None, None |
| |
|
| | return imgs, svg |
| |
|
| |
|
| | def process_img(img, num_pixels, heading=None): |
| | """Rotate + Crop to correct for heading and ensure correct dimensions""" |
| |
|
| | img = center_pad(img, num_pixels, num_pixels) |
| | s = min(img.size) |
| | squared_img = center_crop_to_size(img, s, s) |
| | if heading: |
| | squared_img = squared_img.rotate(heading, expand=False, resample=Image.Resampling.BILINEAR) |
| | center_cropped_bev_img = center_crop_to_size(squared_img, num_pixels, num_pixels) |
| | |
| | return center_cropped_bev_img |
| |
|
| |
|
| | def get_satellite_from_bbox(bbox, output_fp, num_pixels, heading): |
| | |
| | |
| |
|
| | region = ee.Geometry.Rectangle(bbox, proj="EPSG:4326", geodesic=False) |
| | |
| | image = ee.ImageCollection('USDA/NAIP/DOQQ') \ |
| | .filterBounds(region) \ |
| | .filterDate('2022-01-01', '2022-12-31') \ |
| | .sort('CLOUDY_PIXEL_PERCENTAGE') \ |
| | .first().select(['R', 'G', 'B']) |
| |
|
| | |
| | image = image.reproject(crs='EPSG:4326', scale=0.5) |
| |
|
| | |
| | url = image.getThumbURL({'min': 0, 'max': 255, 'region': region.getInfo()['coordinates']}) |
| |
|
| | |
| | response = requests.get(url) |
| | img = Image.open(io.BytesIO(response.content)) |
| | robot_cropped_bev_img = process_img(img, num_pixels, heading) |
| | robot_cropped_bev_img.save(output_fp) |
| |
|
| |
|
| | def get_data(address: str, parameters: dict[str, str]) -> bytes: |
| | """ |
| | Construct Internet page URL and get its descriptor. |
| | |
| | :param address: URL without parameters |
| | :param parameters: URL parameters |
| | :return: connection descriptor |
| | """ |
| | for _ in range(50): |
| | http = urllib3.PoolManager() |
| |
|
| | urllib3.disable_warnings() |
| |
|
| | try: |
| | result = http.request("GET", address, fields=parameters) |
| | except urllib3.exceptions.MaxRetryError: |
| | continue |
| | |
| | if result.status == 200: |
| | break |
| | else: |
| | print(result.data) |
| | raise NetworkError(f"Cannot download data: {result.status} {result.reason}") |
| | |
| | http.clear() |
| | return result.data |
| |
|
| |
|
| | def get_osm_data(bbox: BoundaryBox, osm_output_fp: Path, |
| | overwrite=False, use_lock=False) -> OSMData: |
| | """ |
| | Get OSM data within bounding box from usingoverpass APIs and |
| | write data to osm_output_fp. |
| | """ |
| |
|
| | OVERPASS_ENDPOINTS = [ |
| | "http://overpass-api.de/api/map", |
| | "http://overpass.kumi.systems/api/map", |
| | "http://maps.mail.ru/osm/tools/overpass/api/map" |
| | ] |
| |
|
| | RETRIES = 10 |
| | osm_data = None |
| | overpass_endpoints_i = 0 |
| |
|
| | for retry in range(RETRIES): |
| | try: |
| | |
| | |
| | |
| | |
| | if use_lock: |
| | lock_fp = osm_output_fp.parent.parent / (osm_output_fp.parent.name + "_tmp_locks") / (osm_output_fp.name + ".lock") |
| | lock = FileLock(lock_fp) |
| | else: |
| | lock = contextlib.nullcontext() |
| | |
| | with lock: |
| | if not overwrite and osm_output_fp.is_file(): |
| | with osm_output_fp.open(encoding="utf-8") as output_file: |
| | xml_str = output_file.read() |
| | else: |
| | content: bytes = get_data( |
| | address=OVERPASS_ENDPOINTS[overpass_endpoints_i], |
| | parameters={"bbox": bbox.get_format()} |
| | ) |
| |
|
| | xml_str = content.decode("utf-8") |
| |
|
| | if not content.startswith(b"<?xml"): |
| | raise Exception(f"Invalid content received: '{xml_str}'") |
| |
|
| | with osm_output_fp.open("bw+") as output_file: |
| | output_file.write(content) |
| |
|
| | |
| | tree = ET.fromstring(xml_str) |
| | osm_data = OSMData() |
| | osm_data.parse_osm(tree, parse_nodes=True, |
| | parse_relations=True, parse_ways=True) |
| | break |
| | |
| | except Exception as e: |
| | msg = f"Error: Unable to fetch OSM data for bbox {bbox.get_format()} "\ |
| | f"for file {osm_output_fp} after {retry+1}/{RETRIES} attempts. Exception: {repr(e)}." |
| | |
| | if retry < RETRIES-1: |
| | overpass_endpoints_i = (overpass_endpoints_i+1) % len(OVERPASS_ENDPOINTS) |
| | logger.error(f"{msg}. Retrying with {OVERPASS_ENDPOINTS[overpass_endpoints_i]} endpoint..") |
| | continue |
| | else: |
| | logger.error(f"{msg}. Skipping..") |
| | break |
| |
|
| | return osm_data, retry+1 |
| |
|
| | def get_bev_from_bbox( |
| | bbox: BoundaryBox, |
| | |
| | num_pixels: int, |
| | meters_per_pixel: float, |
| | configuration: MapConfiguration, |
| | |
| | osm_output_fp: Path, |
| | bev_output_fp: Path, |
| | mask_output_fp: Path, |
| | rendered_mask_output_fp: Path, |
| | |
| | osm_data: OSMData=None, |
| | heading: float=0, |
| | final_downsample: int=1, |
| | |
| | download_osm_only: bool=False, |
| | use_osm_cache_lock: bool=False, |
| | ) -> None: |
| | """Get BEV image from a boundary box. Optionally rotate, crop and save the extracted semantic mask.""" |
| |
|
| | if osm_data is None: |
| | if osm_output_fp.is_file(): |
| | |
| | try: |
| | osm_data = OSMData() |
| | with osm_output_fp.open(encoding="utf-8") as output_file: |
| | xml_str = output_file.read() |
| | tree = ET.fromstring(xml_str) |
| | osm_data.parse_osm(tree, parse_nodes=True, |
| | parse_relations=True, parse_ways=True) |
| | except Exception as e: |
| | osm_data, _ = get_osm_data(bbox, osm_output_fp, use_lock=use_osm_cache_lock) |
| | else: |
| | |
| | osm_data, _ = get_osm_data(bbox, osm_output_fp, use_lock=use_osm_cache_lock) |
| | |
| | if osm_data is None: |
| | return |
| |
|
| | if download_osm_only: |
| | return |
| | |
| | imgs, svg = draw_bev(bbox, osm_data, configuration, meters_per_pixel, heading) |
| | if imgs is None: |
| | return |
| | |
| | if bev_output_fp: |
| | svg.saveas(bev_output_fp) |
| |
|
| | cropped_imgs = [] |
| | for img in imgs: |
| | |
| | cropped_imgs.append(process_img(img, num_pixels, heading=None)) |
| |
|
| | masks = [] |
| | for img in cropped_imgs: |
| | arr = np.array(img) |
| | masks.append(arr[..., -1] != 0) |
| | |
| | extracted_mask = np.stack(masks, axis=0) |
| | extracted_mask[2][extracted_mask[0]] = 0 |
| |
|
| | if final_downsample > 1: |
| | max_pool_layer = nn.MaxPool2d(kernel_size=final_downsample, stride=final_downsample) |
| | |
| | mask_tensor = torch.tensor(extracted_mask, dtype=torch.float32).unsqueeze(0) |
| | max_pool_tensor = max_pool_layer(mask_tensor) |
| | |
| | multilabel_mask_downsampled = max_pool_tensor.squeeze(0).permute(1, 2, 0).numpy() |
| | else: |
| | multilabel_mask_downsampled = extracted_mask.transpose(1, 2, 0) |
| |
|
| |
|
| | |
| | if mask_output_fp: |
| | np.savez_compressed(mask_output_fp, multilabel_mask_downsampled) |
| |
|
| | |
| | if rendered_mask_output_fp: |
| | rgb = mask2rgb(multilabel_mask_downsampled) |
| | plt.imsave(rendered_mask_output_fp.with_suffix('.png'), rgb) |
| |
|
| |
|
| | def get_bev_from_bbox_worker_init(osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir, |
| | scheme_path, map_length, meters_per_pixel, |
| | osm_data, redownload, download_osm_only, store_osm_per_id, |
| | use_osm_cache_lock, final_downsample): |
| | global worker_kwargs |
| | worker_kwargs=locals() |
| | |
| | scheme = Scheme.from_file(Path(scheme_path)) |
| | configuration = MapConfiguration(scheme) |
| | configuration.show_credit = False |
| | worker_kwargs["configuration"] = configuration |
| | logger.info(f"Worker {os.getpid()} started.") |
| | |
| |
|
| | def get_bev_from_bbox_worker(job_dict): |
| | id = job_dict['id'] |
| | bbox = job_dict['bbox_formatted'] |
| | bbox = BoundaryBox.from_text(bbox) |
| | heading = job_dict['computed_compass_angle'] |
| |
|
| | |
| | bev_fp = worker_kwargs["bev_dir"] |
| | if bev_fp: |
| | bev_fp = bev_fp / f"{id}.svg" |
| | |
| | semantic_mask_fp = worker_kwargs["semantic_mask_dir"] |
| | if semantic_mask_fp: |
| | semantic_mask_fp = semantic_mask_fp / f"{id}.npz" |
| | |
| | rendered_mask_fp = worker_kwargs["rendered_mask_dir"] |
| | if rendered_mask_fp: |
| | rendered_mask_fp = rendered_mask_fp / f"{id}.png" |
| |
|
| | if worker_kwargs["store_osm_per_id"]: |
| | osm_output_fp = worker_kwargs["osm_cache_dir"] / f"{id}.osm" |
| | else: |
| | osm_output_fp = worker_kwargs["osm_cache_dir"] / f"{bbox.get_format()}.osm" |
| |
|
| | |
| | if ( (bev_fp is None or bev_fp.exists() ) |
| | and (semantic_mask_fp is None or semantic_mask_fp.exists()) |
| | and (rendered_mask_fp is None or rendered_mask_fp.exists()) |
| | and not worker_kwargs["redownload"]): |
| | return |
| | |
| | get_bev_from_bbox(bbox=bbox, |
| | num_pixels=worker_kwargs["map_length"], |
| | meters_per_pixel=worker_kwargs["meters_per_pixel"], |
| | configuration=worker_kwargs["configuration"], |
| | osm_output_fp=osm_output_fp, |
| | bev_output_fp=bev_fp, |
| | mask_output_fp=semantic_mask_fp, |
| | rendered_mask_output_fp=rendered_mask_fp, |
| | osm_data=worker_kwargs["osm_data"], |
| | heading=heading, |
| | final_downsample=worker_kwargs["final_downsample"], |
| | download_osm_only=worker_kwargs["download_osm_only"], |
| | use_osm_cache_lock=worker_kwargs["use_osm_cache_lock"]) |
| |
|
| | def main(dataset_dir, locations, args): |
| | |
| | dataset_dir = Path(dataset_dir) |
| |
|
| | for loc in locations: |
| | loc_name = loc["name"].lower().replace(" ", "_") |
| | location_dir = dataset_dir / loc_name |
| | osm_cache_dir = location_dir / "osm_cache" |
| | bev_dir = location_dir / "bev_raw" if args.store_all_steps else None |
| | semantic_mask_dir = location_dir / "semantic_masks" |
| | rendered_mask_dir = location_dir / "rendered_semantic_masks" if args.store_all_steps else None |
| |
|
| | for d in [location_dir, osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir]: |
| | if d: |
| | d.mkdir(parents=True, exist_ok=True) |
| | |
| | |
| | parquet_fp = location_dir / f"image_metadata_filtered_processed.parquet" |
| | logger.info(f"Reading parquet file from {parquet_fp}.") |
| | df = pd.read_parquet(parquet_fp) |
| |
|
| | if args.n_samples > 0: |
| | logger.info(f"Sampling {args.n_samples} rows.") |
| | df = df.sample(args.n_samples, replace=False, random_state=1) |
| |
|
| | df.reset_index(drop=True, inplace=True) |
| | logger.info(f"Read {len(df)} rows from the parquet file.") |
| | |
| | |
| | gdf = gpd.GeoDataFrame(df, |
| | geometry=gpd.points_from_xy( |
| | df['computed_geometry.long'], |
| | df['computed_geometry.lat']), |
| | crs=4326) |
| | |
| | |
| | utm_crs = gdf.estimate_utm_crs() |
| | gdf_utm = gdf.to_crs(utm_crs) |
| | transformer = Transformer.from_crs(utm_crs, 4326) |
| | logger.info(f"UTM zone for {loc_name} is {utm_crs.to_epsg()}.") |
| |
|
| | |
| | padding = args.padding |
| | |
| | |
| | map_length = args.map_length |
| | map_length = ceil(sqrt(map_length**2 + map_length**2)) |
| | distance = map_length * args.meters_per_pixel / 2 |
| | logger.info(f"Each image will be {map_length:.2f} x {map_length:.2f} pixels. The distance from the center to the edge is {distance:.2f} meters.") |
| |
|
| | osm_data = None |
| | if args.osm_fp: |
| | logger.info(f"Loading OSM data from {args.osm_fp}.") |
| | osm_fp = Path(args.osm_fp) |
| | osm_data = OSMData() |
| | if osm_fp.suffix == '.osm': |
| | osm_data.parse_osm_file(osm_fp) |
| | elif osm_fp.suffix == '.json': |
| | osm_data.parse_overpass(osm_fp) |
| | else: |
| | raise ValueError(f"OSM file format {osm_fp.suffix} is not supported.") |
| | |
| | bbox = osm_data.boundary_box |
| | shapely_bbox = box(bbox.left, bbox.bottom, bbox.right, bbox.top) |
| | logger.warning(f"Clipping the geopandas dataframe to the OSM boundary box. May result in loss of points.") |
| | gdf = gpd.clip(gdf, shapely_bbox) |
| | if gdf.empty: |
| | raise ValueError("Clipped geopandas dataframe is empty. Exiting.") |
| | logger.info(f"Clipped geopandas dataframe is left with {len(gdf)} points.") |
| |
|
| | elif args.one_big_osm: |
| | osm_fp = location_dir / "one_big_map.osm" |
| | min_long = gdf_utm.geometry.x.min() - distance - padding |
| | max_long = gdf_utm.geometry.x.max() + distance + padding |
| | min_lat = gdf_utm.geometry.y.min() - distance - padding |
| | max_lat = gdf_utm.geometry.y.max() + distance + padding |
| | padding = 0 |
| | big_bbox = transformer.transform_bounds(left=min_long, bottom=min_lat, right=max_long, top=max_lat) |
| | |
| | big_bbox = (big_bbox[1], big_bbox[0], big_bbox[3], big_bbox[2]) |
| | big_bbox_fmt = ",".join([str(x) for x in big_bbox]) |
| | logger.info(f"Fetching one big osm file using coordinates {big_bbox_fmt}.") |
| | big_bbox = BoundaryBox.from_text(big_bbox_fmt) |
| | osm_data, retries = get_osm_data(big_bbox, osm_fp, overwrite=args.redownload) |
| |
|
| | |
| | gdf_utm['bounding_box_utm_p1'] = gdf_utm.apply(lambda row: ( |
| | row.geometry.x - distance - padding, |
| | row.geometry.y - distance - padding, |
| | ), axis=1) |
| |
|
| | gdf_utm['bounding_box_utm_p2'] = gdf_utm.apply(lambda row: ( |
| | row.geometry.x + distance + padding, |
| | row.geometry.y + distance + padding, |
| | ), axis=1) |
| |
|
| | |
| | gdf_utm['bounding_box_lat_long_p1'] = gdf_utm.apply(lambda row: transformer.transform(*row['bounding_box_utm_p1']), axis=1) |
| | gdf_utm['bounding_box_lat_long_p2'] = gdf_utm.apply(lambda row: transformer.transform(*row['bounding_box_utm_p2']), axis=1) |
| | gdf_utm['bbox_min_lat'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[0]) |
| | gdf_utm['bbox_min_long'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[1]) |
| | gdf_utm['bbox_max_lat'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[0]) |
| | gdf_utm['bbox_max_long'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[1]) |
| | gdf_utm['bbox_formatted'] = gdf_utm.apply(lambda row: f"{row['bbox_min_long']},{row['bbox_min_lat']},{row['bbox_max_long']},{row['bbox_max_lat']}", axis=1) |
| |
|
| | |
| | jobs = gdf_utm[['id', 'bbox_formatted', 'computed_compass_angle']] |
| | jobs = jobs.to_dict(orient='records').copy() |
| |
|
| | use_osm_cache_lock = args.n_workers > 0 and not args.store_osm_per_id |
| | if use_osm_cache_lock: |
| | logger.info("Using osm cache locks to prevent race conditions since number of workers > 0 and store_osm_per_id is false") |
| | |
| | init_args = [osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir, |
| | args.map_machine_scheme, |
| | args.map_length, args.meters_per_pixel, |
| | osm_data, args.redownload, args.download_osm_only, |
| | args.store_osm_per_id, use_osm_cache_lock, args.final_downsample] |
| | |
| | if args.n_workers > 0: |
| | logger.info(f"Launching {args.n_workers} workers to fetch BEVs for {len(jobs)} bounding boxes.") |
| | with mp.Pool(args.n_workers, |
| | initializer=get_bev_from_bbox_worker_init, |
| | initargs=init_args) as pool: |
| | for _ in tqdm(pool.imap_unordered(get_bev_from_bbox_worker, jobs, chunksize=16), |
| | total=len(jobs), desc="Getting BEV images"): |
| | pass |
| | else: |
| | get_bev_from_bbox_worker_init(*init_args) |
| | pbar = tqdm(jobs, desc="Getting BEV images") |
| | for job_dict in pbar: |
| | get_bev_from_bbox_worker(job_dict) |
| | |
| | |
| | if args.store_sat: |
| | logger.info("Downloading sattelite images.") |
| | sat_dir = location_dir / "sattelite" |
| | sat_dir.mkdir(parents=True, exist_ok=True) |
| | pbar = tqdm(jobs, desc="Getting Sattelite images") |
| | for job_dict in pbar: |
| | id = job_dict['id'] |
| | sat_fp = sat_dir / f"{id}.png" |
| | if sat_fp.exists() and not args.redownload: |
| | continue |
| | bbox = [float(x) for x in job_dict['bbox_formatted'].split(",")] |
| | try: |
| | get_satellite_from_bbox(bbox, sat_fp, heading=job_dict['computed_compass_angle'], num_pixels=args.map_length) |
| | except Exception as e: |
| | logger.error(f"Failed to get sattelite image for bbox {job_dict['bbox_formatted']}. Exception {repr(e)}") |
| |
|
| | |
| | |
| |
|
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser(description="Get BEV images from a dataset of locations using MapMachine.") |
| | parser.add_argument("--cfg", type=str, default="mia/conf/example.yaml", help="Path to config yaml file.") |
| | args = parser.parse_args() |
| |
|
| | cfgs = OmegaConf.load(args.cfg) |
| |
|
| | if cfgs.bev_options.store_sat: |
| | if cfgs.bev_options.n_workers > 0: |
| | logger.fatal("Satellite download is not multiprocessed yet !!") |
| | import ee |
| | ee.Initialize() |
| |
|
| | logger.info("="*80) |
| | logger.info("Running get_bev.py") |
| | logger.info("Arguments:") |
| | for arg in vars(args): |
| | logger.info(f"- {arg}: {getattr(args, arg)}") |
| | logger.info("="*80) |
| | main(cfgs.dataset_dir, cfgs.cities, cfgs.bev_options) |