Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Dataset | |
| import os | |
| import cv2 | |
| # @Time : 2023-02-13 22:56 | |
| # @Author : Wang Zhen | |
| # @Email : frozenzhencola@163.com | |
| # @File : SatelliteTool.py | |
| # @Project : TGRS_seqmatch_2023_1 | |
| import numpy as np | |
| import random | |
| from utils.geo import BoundaryBox, Projection | |
| from osm.tiling import TileManager,MapTileManager | |
| from pathlib import Path | |
| from torchvision import transforms | |
| from torch.utils.data import DataLoader | |
class UavMapPair(Dataset):
    """Dataset pairing UAV images with map arrays for a single city.

    Expected layout under ``root / city``:

    * ``uav/``      -- UAV images, read with ``cv2.imread``.
    * ``map/``      -- map arrays, read with ``np.load``.
    * ``map_vis/``  -- path stored as ``self.map_vis``; not read by this class.
    * ``info.csv``  -- one header line followed by one row per sample with 13
      comma-separated columns: id, uav_name, map_name, uav_long, uav_lat,
      map_long, map_lat, tile_size_meters, pixel_per_meter, u, v, yaw, dis.
    """

    def __init__(
        self,
        root: Path,
        city: str,
        training: bool = False,
        transform=None,
    ):
        """Index the samples of ``city`` below ``root``.

        Args:
            root: Dataset root directory (``str`` or ``Path``).
            city: Sub-directory name of the city to load.
            training: When true, ``__getitem__`` applies a random square
                center crop to every UAV image before the transform.
            transform: Optional callable applied to the RGB UAV image
                (a NumPy array produced by OpenCV); skipped when falsy.
        """
        super().__init__()
        # Accept plain strings as well; the ``/`` joins below need a Path.
        root = Path(root)
        self.root = root
        self.city = city
        self.uav_image_path = root / city / 'uav'
        self.map_path = root / city / 'map'
        self.map_vis = root / city / 'map_vis'
        info_path = root / city / 'info.csv'
        # ndmin=2 keeps a 2-D (rows, columns) table even when the CSV has a
        # single data row; otherwise len() would count columns and
        # ``self.info[index]`` would yield one string instead of a row.
        self.info = np.loadtxt(
            str(info_path), dtype=str, delimiter=",", skiprows=1, ndmin=2
        )
        self.transform = transform
        self.training = training

    def random_center_crop(self, image):
        """Return a square crop of random size taken from the image center.

        The side length is drawn uniformly from
        ``[max(1, min(h, w) // 2), min(h, w)]``, so the crop keeps at least
        half of the shorter side and is never empty. The result is a view
        into ``image`` (basic slicing), not a copy.
        """
        height, width = image.shape[:2]
        short_side = min(height, width)
        # Randomly choose the crop size; clamp the lower bound to 1 so images
        # with a 1-pixel side cannot produce an empty crop.
        crop_size = random.randint(max(1, short_side // 2), short_side)
        # Top-left corner that centers the crop.
        start_x = (width - crop_size) // 2
        start_y = (height - crop_size) // 2
        return image[start_y:start_y + crop_size, start_x:start_x + crop_size]

    def __getitem__(self, index: int):
        """Load sample ``index``.

        Returns:
            dict with keys ``'map'`` (int64 tensor loaded from the map file),
            ``'image'`` (UAV image tensor after the optional crop/transform),
            ``'roll_pitch_yaw'`` (float32 ``[0, 0, yaw]``),
            ``'pixels_per_meter'`` (float32 scalar) and ``'uv'``
            (float32 ``[u, v]``; presumably the UAV position in map pixels --
            TODO confirm against the dataset generator).

        Raises:
            FileNotFoundError: if OpenCV cannot read the UAV image.
        """
        (_sample_id, uav_name, map_name,
         _uav_long, _uav_lat,
         _map_long, _map_lat,
         _tile_size_meters, pixel_per_meter,
         u, v, yaw, _dis) = self.info[index]

        uav_file = self.uav_image_path / uav_name
        uav_image = cv2.imread(str(uav_file))
        # cv2.imread signals failure by returning None rather than raising.
        if uav_image is None:
            raise FileNotFoundError(f"Could not read UAV image: {uav_file}")
        if self.training:
            uav_image = self.random_center_crop(uav_image)
        # OpenCV loads images as BGR; convert to RGB before any transform.
        uav_image = cv2.cvtColor(uav_image, cv2.COLOR_BGR2RGB)
        if self.transform:
            uav_image = self.transform(uav_image)
        # Transforms such as ToTensor already return a tensor; wrapping it in
        # torch.tensor() again would copy it and emit a UserWarning.
        if isinstance(uav_image, torch.Tensor):
            image_tensor = uav_image
        else:
            image_tensor = torch.as_tensor(np.ascontiguousarray(uav_image))

        map_array = np.load(str(self.map_path / map_name))
        return {
            'map': torch.from_numpy(np.ascontiguousarray(map_array)).long(),
            'image': image_tensor,
            'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float(),
            'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float(),
            "uv": torch.tensor([float(u), float(v)]).float(),
        }

    def __len__(self):
        """Return the number of samples (data rows in ``info.csv``)."""
        return len(self.info)
if __name__ == '__main__':
    # Smoke test: build the dataset for one city and iterate over every batch
    # so that file loading, the transform and DataLoader collation all run.
    root = Path('/root/DATASET/OrienterNet/UavMap/')
    city = 'NewYork'
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset = UavMapPair(
        root=root,
        city=city,
        # Pass explicitly; False disables the random center crop. Set True to
        # exercise the training-time augmentation instead.
        training=False,
        transform=transform,
    )
    dataloader = DataLoader(dataset, batch_size=3)
    for _batch_idx, _batch in enumerate(dataloader):
        pass