|
|
|
|
|
|
|
|
|
|
|
import json
import os
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
|
|
|
|
|
class LMDriveDataset(Dataset):
    """Dataset of LMDrive-style driving frames.

    Each sample pairs a vertically stacked multi-camera RGB frame
    (``rgb_full/NNNN.jpg``) with its measurement record
    (``measurements/NNNN.json``) and, when present, a rasterized lidar
    image (``lidar/NNNN.png``).

    Args:
        data_dir: Root directory containing ``measurements``, ``rgb_full``
            and (optionally) ``lidar`` subdirectories.
        transform: Callable applied to each RGB crop; defaults to
            224x224 resize + ImageNet normalization.
        lidar_transform: Callable applied to the lidar image; defaults to
            112x112 resize + normalization to roughly [-1, 1].
    """

    def __init__(self, data_dir, transform=None, lidar_transform=None):
        self.data_dir = Path(data_dir)
        self.transform = transform or self.get_default_transform()
        self.lidar_transform = lidar_transform or self.get_default_lidar_transform()
        self.samples = []

        measurement_dir = self.data_dir / "measurements"
        image_dir = self.data_dir / "rgb_full"
        if not measurement_dir.exists() or not image_dir.exists():
            # Stay usable (empty) instead of crashing; callers can check
            # len(dataset) == 0 to detect the problem.
            print(f"Warning: Data directory {data_dir} not found or incomplete.")
            return

        for measurement_path in sorted(measurement_dir.glob("*.json")):
            try:
                frame_id = int(measurement_path.stem)
            except ValueError:
                # Skip stray non-numeric files (e.g. metadata dumps) rather
                # than aborting the whole scan.
                continue
            image_path = image_dir / f"{frame_id:04d}.jpg"
            if not image_path.exists():
                continue
            with open(measurement_path, "r") as f:
                measurements_data = json.load(f)
            self.samples.append({
                "image_path": str(image_path),
                "measurement_path": str(measurement_path),
                "frame_id": frame_id,
                "measurements": measurements_data,
            })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict of tensors for sample ``idx``.

        Keys: ``rgb``, ``rgb_left``, ``rgb_right``, ``rgb_center``,
        ``lidar``, ``measurements`` (10-dim float32), ``target_point``
        (2-dim float32).

        Raises:
            ValueError: If the RGB image cannot be read from disk.
        """
        sample = self.samples[idx]
        full_image = cv2.imread(sample["image_path"])
        if full_image is None:
            raise ValueError(f"Failed to load image: {sample['image_path']}")
        full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)

        # The stored frame stacks four 600x800 camera views vertically:
        # front, left, right, center (top to bottom) — TODO confirm order
        # against the data-collection script.
        front_image = full_image[:600, :800]
        left_image = full_image[600:1200, :800]
        right_image = full_image[1200:1800, :800]
        center_image = full_image[1800:2400, :800]

        front_tensor = self.transform(front_image)
        left_tensor = self.transform(left_image)
        right_tensor = self.transform(right_image)
        center_tensor = self.transform(center_image)

        lidar_path = str(self.data_dir / "lidar" / f"{sample['frame_id']:04d}.png")
        lidar = cv2.imread(lidar_path) if os.path.exists(lidar_path) else None
        if lidar is None:
            # Missing OR unreadable lidar frame: substitute a black image so
            # batch shapes stay consistent. (The original crashed with
            # AttributeError when imread failed on an existing file.)
            lidar = np.zeros((112, 112, 3), dtype=np.uint8)
        else:
            if len(lidar.shape) == 2:
                # Defensive: handle single-channel PNGs.
                lidar = cv2.cvtColor(lidar, cv2.COLOR_GRAY2BGR)
            lidar = cv2.cvtColor(lidar, cv2.COLOR_BGR2RGB)
        lidar_tensor = self.lidar_transform(lidar)

        m = sample["measurements"]
        # Driving command target (goal waypoint) in ego coordinates —
        # presumably meters; verify against the recorder.
        target_point = torch.tensor(
            [m.get("x_command", 0.0), m.get("y_command", 0.0)],
            dtype=torch.float32,
        )
        # Fixed 10-dim vector; missing keys default to 0 so partial logs
        # still load.
        measurements = torch.tensor(
            [
                m.get("x", 0.0),
                m.get("y", 0.0),
                m.get("theta", 0.0),
                m.get("speed", 0.0),
                m.get("steer", 0.0),
                m.get("throttle", 0.0),
                int(m.get("brake", False)),
                m.get("command", 0),
                int(m.get("is_junction", False)),
                int(m.get("should_brake", 0)),
            ],
            dtype=torch.float32,
        )

        return {
            "rgb": front_tensor,
            "rgb_left": left_tensor,
            "rgb_right": right_tensor,
            "rgb_center": center_tensor,
            "lidar": lidar_tensor,
            "measurements": measurements,
            "target_point": target_point,
        }

    def get_default_transform(self):
        """Default RGB preprocessing: 224x224 + ImageNet normalization."""
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ])

    def get_default_lidar_transform(self):
        """Default lidar preprocessing: 112x112, scaled to about [-1, 1]."""
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            # Single-element mean/std broadcasts across all channels.
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
|
|
|