Spaces:
Runtime error
Runtime error
| import random | |
| from operator import itemgetter | |
| from data_enrich import DataEnrich | |
class DataLoader:
    """Turn enriched trajectories into train / validation / test batches.

    The raw data is fetched from ``DataEnrich.get_enriched_data`` and is
    expected to be an iterable of table-like objects (presumably pandas
    DataFrames -- confirm against ``data_enrich``) exposing a ``"label"``
    column and the columns named in ``fields_to_feed``.

    Call ``prepare_data()`` before using any of the batch generators or size
    getters; it is the method that populates ``train_data``, ``val_data``,
    ``test_data`` and the matching ``*_labels`` attributes.
    """

    # Integer class id used for every label string.
    label_mapping = {
        'car': 0,
        'walk': 1,
        'bus': 2,
        'train': 3,
        'subway': 4,
        'bike': 5,
        'run': 6,
        'boat': 7,
        'airplane': 8,
        'motorcycle': 9,
        'taxi': 10
    }

    # Columns extracted (in this order) as the per-point feature vector.
    fields_to_feed = ["dist", "speed", "accel", "timedelta", "jerk", "bearing", "bearing_rate"]

    # Trajectories containing any of these labels are discarded entirely.
    labels_to_remove = ["boat", "motorcycle", "airplane", "run", "bike"]

    def __init__(self, test_ratio=0.2, val_ratio=0.1, batchsize=4, read_from_pickle=True):
        """Load the enriched data and store the split / batching settings.

        Args:
            test_ratio: Fraction of trajectories held out for testing.
            val_ratio: Fraction of trajectories used for validation; they are
                taken out of the non-test portion.
            batchsize: Number of trajectories per yielded batch.
            read_from_pickle: Forwarded unchanged to
                ``DataEnrich.get_enriched_data``.
        """
        de = DataEnrich()
        self._raw = de.get_enriched_data(read_from_pickle)
        self._test_ratio = test_ratio
        self._val_ratio = val_ratio
        self._batchsize = batchsize

    def _remove_traj_containing_labels(self):
        """Drop empty trajectories and those containing an unwanted label."""
        # Build the lookup once instead of re-listing the labels per check.
        unwanted = set(self.labels_to_remove)
        cleaned = []
        for elem in self._raw:
            if len(elem) == 0:
                continue
            if unwanted.isdisjoint(elem["label"]):
                cleaned.append(elem)
        self._raw = cleaned

    def _merge_labels(self, target_label, label_to_remove):
        """Rewrite every ``label_to_remove`` entry as ``target_label``."""
        for elem in self._raw:
            # Only touch trajectories that actually contain the label.
            if (elem["label"] == label_to_remove).any():
                elem["label"] = elem["label"].replace(to_replace=label_to_remove, value=target_label)

    def _labels_to_int_repr(self):
        """Replace label strings by their ids from ``label_mapping``.

        Raises:
            KeyError: If a label is not a key of ``label_mapping``.
        """
        for elem in self._raw:
            elem["label"] = elem["label"].apply(lambda x: self.label_mapping[x])

    def _get_split_indices(self, traj):
        """Randomly partition the positions of ``traj`` into three index lists.

        The three lists are pairwise disjoint and together cover every
        position exactly once.

        Args:
            traj: Sequence whose positions are split; only its length is used.

        Returns:
            Tuple ``(train_indices, test_indices, validation_indices)``.
        """
        total = len(traj)
        pool_size = int((1 - self._test_ratio) * total)
        val_size = total - int((1 - self._val_ratio) * total)
        # Validation is drawn from the non-test pool, so it cannot be larger
        # than that pool (this matters for very small inputs).
        val_size = min(val_size, pool_size)

        indices = list(range(total))
        pool = random.sample(indices, pool_size)
        indices_for_validation = random.sample(pool, val_size)

        held_for_validation = set(indices_for_validation)
        indices_for_training = [i for i in pool if i not in held_for_validation]

        # Test is everything outside the train+validation pool; subtracting
        # only the training indices would leak validation items into test.
        in_pool = set(pool)
        indices_for_testing = [i for i in indices if i not in in_pool]
        return indices_for_training, indices_for_testing, indices_for_validation

    def _set_splitted_data(self, traj, labels):
        """Split ``traj`` / ``labels`` and store the parts on the instance.

        Sets ``train_data``, ``val_data``, ``test_data`` and the matching
        ``*_labels`` attributes. Each is always a list of trajectories, also
        when a split holds zero or one element.
        """
        i_train, i_test, i_val = self._get_split_indices(traj)
        random.shuffle(i_train)

        # Plain indexing keeps one-element and empty splits as lists.
        self.test_data = [traj[i] for i in i_test]
        self.val_data = [traj[i] for i in i_val]
        self.train_data = [traj[i] for i in i_train]
        self.test_labels = [labels[i] for i in i_test]
        self.val_labels = [labels[i] for i in i_val]
        self.train_labels = [labels[i] for i in i_train]

    def _split_too_long_traj(self, traj, labels, max_points):
        """Cut a trajectory longer than ``2 * max_points`` into chunks.

        Every chunk holds ``max_points`` points except the last one, which
        also takes the remainder so that no point is lost.

        Args:
            traj: Sequence of points.
            labels: Per-point labels, same length as ``traj``.
            max_points: Chunk length; must be positive.

        Returns:
            Tuple ``(list_of_trajectories, list_of_label_sequences)``. When no
            split is needed both lists contain the single original sequence.

        Raises:
            ValueError: If ``max_points`` is not positive.
        """
        if max_points <= 0:
            raise ValueError("max_points must be a positive integer")

        total = len(traj)
        if total <= max_points * 2:
            return [traj], [labels]

        splitted_traj, splitted_labels = [], []
        num_subsets = total // max_points
        print("Splitting trajectory with length ", len(traj), "in ", num_subsets, "trajectories")
        for i in range(num_subsets):
            start = i * max_points
            # Slice ends are exclusive, so the last chunk must end at `total`
            # (not `total - 1`) to keep the final point.
            end = total if i == num_subsets - 1 else start + max_points
            traj_subset = traj[start:end]
            labels_subset = labels[start:end]
            assert len(traj_subset) == len(labels_subset)
            splitted_traj.append(traj_subset)
            splitted_labels.append(labels_subset)
        return splitted_traj, splitted_labels

    def prepare_data(self, max_points=350):
        """Clean, relabel, chunk and split the raw data.

        Steps: remove trajectories with unwanted labels, merge ``"taxi"``
        into ``"car"``, convert labels to integer ids, split over-long
        trajectories, then create the train / validation / test splits.

        Args:
            max_points: Chunk length handed to ``_split_too_long_traj``.
        """
        trajs = []
        labels = []
        self._remove_traj_containing_labels()
        self._merge_labels("car", "taxi")
        self._labels_to_int_repr()
        for elem in self._raw:
            assert len(elem) > 0
            data_ = elem[self.fields_to_feed].values.tolist()
            label_ = elem["label"].values.tolist()
            data_, label_ = self._split_too_long_traj(data_, label_, max_points)
            trajs.extend(data_)
            labels.extend(label_)
        self._set_splitted_data(trajs, labels)

    def _iter_batches(self, data, labels):
        """Yield ``(data_batch, label_batch)`` pairs of ``_batchsize`` items.

        Within a batch the items are ordered from longest to shortest
        sequence. A trailing batch with fewer than ``_batchsize`` items is
        dropped.
        """
        size = self._batchsize
        for start in range(0, len(data), size):
            data_chunk = data[start:start + size]
            if len(data_chunk) < size:
                break  # drop last incomplete batch
            label_chunk = labels[start:start + size]
            assert len(label_chunk) == len(data_chunk)
            # Sort the pairs together so data and labels cannot drift apart.
            paired = sorted(zip(data_chunk, label_chunk), key=lambda pair: len(pair[0]), reverse=True)
            data_sorted = [item for item, _ in paired]
            labels_sorted = [label for _, label in paired]
            for item, label in paired:
                assert len(label) == len(item)
            yield data_sorted, labels_sorted

    def batches(self):
        """Yield length-sorted training batches (see ``_iter_batches``)."""
        yield from self._iter_batches(self.train_data, self.train_labels)

    def val_batches(self):
        """Yield length-sorted validation batches (see ``_iter_batches``)."""
        yield from self._iter_batches(self.val_data, self.val_labels)

    def test_batches(self):
        """Yield length-sorted test batches (see ``_iter_batches``)."""
        yield from self._iter_batches(self.test_data, self.test_labels)

    def get_train_size(self):
        """Return the number of training trajectories."""
        return len(self.train_data)

    def get_val_size(self):
        """Return the number of validation trajectories."""
        return len(self.val_data)

    def get_test_size(self):
        """Return the number of test trajectories."""
        return len(self.test_data)