File size: 3,538 Bytes
f08920f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# =========================================================
# الملف 3: simulation/data.py
# =========================================================
import json
import os
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class LMDriveDataset(Dataset):
    """Map-style dataset over one LMDrive-style frame directory.

    Paths read under ``data_dir``:

    * ``measurements/<frame>.json`` -- per-frame measurement dict (parsed eagerly).
    * ``rgb_full/<frame:04d>.jpg``  -- four camera tiles stacked vertically.
    * ``lidar/<frame:04d>.png``     -- optional lidar image; a zero image is used
      when the file is absent.

    A frame is indexed only when both its measurement JSON and RGB image exist.
    """

    # Size of one camera tile inside ``rgb_full``. The tiles are sliced top to
    # bottom as front, left, right, center.
    # NOTE(review): geometry is hard-coded; confirm against the data collector
    # if the capture resolution changes.
    _VIEW_HEIGHT = 600
    _VIEW_WIDTH = 800
    # Placeholder returned when a frame has no lidar PNG.
    _LIDAR_FALLBACK_SHAPE = (112, 112, 3)

    def __init__(self, data_dir, transform=None, lidar_transform=None):
        """Index every frame that has both a measurement JSON and an RGB image.

        Args:
            data_dir: Root directory containing ``measurements/`` and ``rgb_full/``.
            transform: Callable applied to each HxWx3 RGB camera tile. Defaults
                to :meth:`get_default_transform`.
            lidar_transform: Callable applied to the HxWx3 RGB lidar image.
                Defaults to :meth:`get_default_lidar_transform`.

        If ``measurements/`` or ``rgb_full/`` is missing, a warning is printed
        and the dataset is left empty.

        Raises:
            ValueError: If a measurement file stem is not an integer frame id.
        """
        self.data_dir = Path(data_dir)
        # ``is None`` instead of ``or`` so a caller-supplied transform is always
        # honoured, whatever its truthiness.
        self.transform = (
            transform if transform is not None else self.get_default_transform()
        )
        self.lidar_transform = (
            lidar_transform
            if lidar_transform is not None
            else self.get_default_lidar_transform()
        )
        self.samples = []

        measurement_dir = self.data_dir / "measurements"
        image_dir = self.data_dir / "rgb_full"
        if not measurement_dir.exists() or not image_dir.exists():
            print(f"Warning: Data directory {data_dir} not found or incomplete.")
            return

        for measurement_path in sorted(measurement_dir.glob("*.json")):
            frame_id = int(measurement_path.stem)
            image_path = image_dir / f"{frame_id:04d}.jpg"
            if not image_path.exists():
                continue
            # JSON is UTF-8 by specification; don't depend on the platform locale.
            with open(measurement_path, "r", encoding="utf-8") as f:
                measurements_data = json.load(f)
            self.samples.append({
                "image_path": str(image_path),
                "measurement_path": str(measurement_path),
                "frame_id": frame_id,
                "measurements": measurements_data,
            })

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load frame ``idx`` and return its model inputs.

        Returns:
            dict with keys ``rgb`` (front tile), ``rgb_left``, ``rgb_right``,
            ``rgb_center``, ``lidar`` (outputs of the respective transforms),
            ``measurements`` (10-element float32 tensor) and ``target_point``
            (2-element float32 tensor).

        Raises:
            ValueError: If the RGB image, or a lidar file that exists, cannot be
                decoded by OpenCV.
        """
        sample = self.samples[idx]
        full_image = cv2.imread(sample["image_path"])
        if full_image is None:
            raise ValueError(f"Failed to load image: {sample['image_path']}")
        full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)

        # Slice the vertically stacked tiles in order: front, left, right, center.
        h, w = self._VIEW_HEIGHT, self._VIEW_WIDTH
        front_tensor, left_tensor, right_tensor, center_tensor = (
            self.transform(full_image[i * h:(i + 1) * h, :w]) for i in range(4)
        )

        lidar_path = self.data_dir / "lidar" / f"{sample['frame_id']:04d}.png"
        lidar_tensor = self.lidar_transform(self._load_lidar(lidar_path))

        target_point, measurements = self._measurement_tensors(sample["measurements"])
        return {
            "rgb": front_tensor,
            "rgb_left": left_tensor,
            "rgb_right": right_tensor,
            "rgb_center": center_tensor,
            "lidar": lidar_tensor,
            "measurements": measurements,
            "target_point": target_point,
        }

    @staticmethod
    def _load_lidar(lidar_path):
        """Read the lidar PNG as an RGB uint8 array, or a zero placeholder if absent.

        Raises:
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        lidar_path = Path(lidar_path)
        if not lidar_path.exists():
            return np.zeros(LMDriveDataset._LIDAR_FALLBACK_SHAPE, dtype=np.uint8)
        lidar = cv2.imread(str(lidar_path))
        if lidar is None:
            raise ValueError(f"Failed to load lidar: {lidar_path}")
        # Defensive: imread's default flag already yields 3 channels, but keep
        # grayscale handling in case the read flags change.
        if lidar.ndim == 2:
            lidar = cv2.cvtColor(lidar, cv2.COLOR_GRAY2BGR)
        return cv2.cvtColor(lidar, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _measurement_tensors(m):
        """Build ``(target_point, measurements)`` float32 tensors from a measurement dict.

        Missing keys default to 0 / False; boolean flags are cast to int.
        """
        target_point = torch.tensor(
            [m.get("x_command", 0.0), m.get("y_command", 0.0)], dtype=torch.float32
        )
        measurements = torch.tensor(
            [
                m.get("x", 0.0),
                m.get("y", 0.0),
                m.get("theta", 0.0),
                m.get("speed", 0.0),
                m.get("steer", 0.0),
                m.get("throttle", 0.0),
                int(m.get("brake", False)),
                m.get("command", 0),
                int(m.get("is_junction", False)),
                int(m.get("should_brake", 0)),
            ],
            dtype=torch.float32,
        )
        return target_point, measurements

    def get_default_transform(self):
        """Return the default camera transform.

        HxWx3 uint8 array -> PIL -> 224x224 -> float tensor in [0, 1] ->
        per-channel normalisation (the commonly used ImageNet mean/std).
        """
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def get_default_lidar_transform(self):
        """Return the default lidar transform.

        HxWx3 uint8 array -> PIL -> 112x112 -> float tensor in [0, 1] ->
        (x - 0.5) / 0.5, i.e. [-1, 1]. The single mean/std broadcasts over
        all channels.
        """
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])