# Provenance: uploaded by mohammed-aljafry (Hugging Face), commit f08920f (verified), 3.54 kB
# Commit message: "Final fix v9: Re-upload with fully self-contained and correct package structure"
# =========================================================
# File 3: simulation/data.py
# =========================================================
import json
import os
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class LMDriveDataset(Dataset):
    """PyTorch dataset for LMDrive-style driving logs.

    Expected layout under ``data_dir``:
        measurements/NNNN.json  -- per-frame driving measurements
        rgb_full/NNNN.jpg       -- four 600x800 camera views stacked vertically
        lidar/NNNN.png          -- optional lidar BEV image (zeros if absent)

    Each sample is a dict with four normalized 224x224 RGB view tensors, a
    112x112 lidar tensor, a 10-dim measurement vector, and a 2-dim target
    point, all as ``torch.float32`` tensors.
    """

    def __init__(self, data_dir, transform=None, lidar_transform=None):
        """Index every frame that has both a measurement JSON and an RGB image.

        Args:
            data_dir: root directory containing measurements/, rgb_full/, lidar/.
            transform: optional transform applied to each camera crop
                (default: resize to 224x224 + ImageNet normalization).
            lidar_transform: optional transform applied to the lidar image
                (default: resize to 112x112 + [-1, 1] normalization).
        """
        self.data_dir = Path(data_dir)
        self.transform = transform or self.get_default_transform()
        self.lidar_transform = lidar_transform or self.get_default_lidar_transform()
        self.samples = []

        measurement_dir = self.data_dir / "measurements"
        image_dir = self.data_dir / "rgb_full"
        if not measurement_dir.exists() or not image_dir.exists():
            # Best-effort: an empty dataset instead of a crash on bad paths.
            print(f"Warning: Data directory {data_dir} not found or incomplete.")
            return

        for measurement_path in sorted(measurement_dir.glob("*.json")):
            try:
                frame_id = int(measurement_path.stem)
            except ValueError:
                # Skip stray JSON files whose name is not a frame number
                # (the original would have crashed here).
                continue
            image_path = image_dir / f"{frame_id:04d}.jpg"
            if not image_path.exists():
                continue
            with open(measurement_path, "r") as f:
                measurements_data = json.load(f)
            self.samples.append({
                "image_path": str(image_path),
                "measurement_path": str(measurement_path),
                "frame_id": frame_id,
                "measurements": measurements_data,
            })

    def __len__(self):
        """Number of indexed frames."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, crop, and transform one frame.

        Raises:
            ValueError: if the RGB image cannot be read from disk.
        """
        sample = self.samples[idx]
        full_image = cv2.imread(sample["image_path"])
        if full_image is None:
            raise ValueError(f"Failed to load image: {sample['image_path']}")
        full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)

        # The full image stacks four 600x800 views vertically:
        # front, left, right, center (top to bottom).
        front_image = full_image[:600, :800]
        left_image = full_image[600:1200, :800]
        right_image = full_image[1200:1800, :800]
        center_image = full_image[1800:2400, :800]

        front_tensor = self.transform(front_image)
        left_tensor = self.transform(left_image)
        right_tensor = self.transform(right_image)
        center_tensor = self.transform(center_image)

        lidar_tensor = self.lidar_transform(self._load_lidar(sample["frame_id"]))

        m = sample["measurements"]
        target_point = torch.tensor(
            [m.get("x_command", 0.0), m.get("y_command", 0.0)], dtype=torch.float32
        )
        measurements = torch.tensor(
            [
                m.get("x", 0.0),
                m.get("y", 0.0),
                m.get("theta", 0.0),
                m.get("speed", 0.0),
                m.get("steer", 0.0),
                m.get("throttle", 0.0),
                int(m.get("brake", False)),
                m.get("command", 0),
                int(m.get("is_junction", False)),
                int(m.get("should_brake", 0)),
            ],
            dtype=torch.float32,
        )

        return {
            "rgb": front_tensor,
            "rgb_left": left_tensor,
            "rgb_right": right_tensor,
            "rgb_center": center_tensor,
            "lidar": lidar_tensor,
            "measurements": measurements,
            "target_point": target_point,
        }

    def _load_lidar(self, frame_id):
        """Return the lidar image for ``frame_id`` as an RGB uint8 array.

        Falls back to a black 112x112 image when the file is missing or
        unreadable. Fixes two defects in the original: ``np`` was used
        without being imported, and a corrupt PNG (cv2.imread -> None)
        crashed on ``lidar.shape``.
        """
        lidar_path = str(self.data_dir / "lidar" / f"{frame_id:04d}.png")
        lidar = cv2.imread(lidar_path) if os.path.exists(lidar_path) else None
        if lidar is None:
            return np.zeros((112, 112, 3), dtype=np.uint8)
        if len(lidar.shape) == 2:
            lidar = cv2.cvtColor(lidar, cv2.COLOR_GRAY2BGR)
        return cv2.cvtColor(lidar, cv2.COLOR_BGR2RGB)

    def get_default_transform(self):
        """Default camera transform: 224x224 resize + ImageNet normalization."""
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def get_default_lidar_transform(self):
        """Default lidar transform: 112x112 resize + [-1, 1] normalization."""
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])