Glas42 commited on
Commit
dac17cf
·
verified ·
1 Parent(s): f3441f6

Upload NavigationDetectionAI-Train.py

Browse files
Files changed (1) hide show
  1. other/NavigationDetectionAI-Train.py +174 -0
other/NavigationDetectionAI-Train.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ print("\n------------------------------------\n\nImporting libraries...")
2
+
3
+ from torchvision.transforms.functional import to_pil_image
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torchvision import transforms
6
+ import torch.optim as optim
7
+ import multiprocessing
8
+ import torch.nn as nn
9
+ from PIL import Image
10
+ import numpy as np
11
+ import datetime
12
+ import torch
13
+ import time
14
+ import cv2
15
+ import os
16
+
17
+ # Constants
18
+ SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
19
+ DATA_PATH = "C:/Users/olefr/Downloads/AIDATA"
20
+ MODEL_PATH = SCRIPT_PATH
21
+ IMG_HEIGHT = 220
22
+ IMG_WIDTH = 420
23
+ NUM_EPOCHS = 50
24
+ BATCH_SIZE = 64
25
+ OUTPUTS = 8
26
+
27
+ print("\n------------------------------------\n")
28
+
29
+ print(f"CUDA available: {torch.cuda.is_available()}")
30
+
31
+ # Check for CUDA availability
32
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
+ print(f"Using {device} for training")
34
+
35
+ # Determine the number of CPU cores
36
+ num_cpu_cores = multiprocessing.cpu_count()
37
+ print('Number of CPU cores:', num_cpu_cores)
38
+
39
+ image_count = 0
40
+ for file in os.listdir(DATA_PATH):
41
+ if file.endswith(".png"):
42
+ image_count += 1
43
+
44
+ print("\nTraining settings:")
45
+ print("> Epochs:", NUM_EPOCHS)
46
+ print("> Batch size:", BATCH_SIZE)
47
+ print("> Image width:", IMG_WIDTH)
48
+ print("> Image height:", IMG_HEIGHT)
49
+ print("> Images:", image_count)
50
+
51
+ print("\n------------------------------------\n")
52
+
53
+ print("Loading...")
54
+
55
+ # Define custom dataset
56
+ class CustomDataset(Dataset):
57
+ def __init__(self, data_path, transform=None):
58
+ self.data_path = data_path
59
+ self.transform = transform
60
+ self.images, self.user_inputs = self.load_data(data_path)
61
+
62
+ def load_data(self, data_path):
63
+ images = []
64
+ user_inputs = []
65
+ for file in os.listdir(data_path):
66
+ if file.endswith(".png"):
67
+ # Load image
68
+ img = Image.open(os.path.join(data_path, file))
69
+ img = np.array(img)
70
+ img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
71
+ img_array = np.array(img) / 255.0
72
+
73
+ # Load steering angle if corresponding file exists
74
+ user_inputs_file = os.path.join(data_path, file.replace(".png", ".txt"))
75
+ if os.path.exists(user_inputs_file):
76
+ with open(user_inputs_file, 'r') as f:
77
+ user_input = [float(val if type(val) != str else (1 if val == "True" else 0)) for val in f.read().strip().split(',')]
78
+ images.append(img_array)
79
+ user_inputs.append(user_input)
80
+ else:
81
+ pass
82
+
83
+ return np.array(images), np.array(user_inputs)
84
+
85
+ def __len__(self):
86
+ return len(self.images)
87
+
88
+ def __getitem__(self, idx):
89
+ image = self.images[idx]
90
+ user_input = self.user_inputs[idx]
91
+ if self.transform:
92
+ image = self.transform(image)
93
+ return image, user_input
94
+
95
+ # Define transformation
96
+ transform = transforms.Compose([
97
+ transforms.Lambda(lambda x: to_pil_image(x)), # Convert to PIL Image
98
+ transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
99
+ transforms.Lambda(lambda x: x.convert("L")), # Convert to grayscale
100
+ transforms.Lambda(lambda x: x.point(lambda p: p > 128 and 255)), # Convert to binary
101
+ transforms.ToTensor()
102
+ ])
103
+
104
+ # Load data
105
+ dataset = CustomDataset(DATA_PATH, transform=transform)
106
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
107
+
108
+ # Define model
109
+ class Net(nn.Module):
110
+ def __init__(self):
111
+ super(Net, self).__init__()
112
+ self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1) # Adjust input channels to 1
113
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
114
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
115
+ self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
116
+ self.fc_input_size = self._get_fc_input_size()
117
+ self.fc1 = nn.Linear(self.fc_input_size, 64)
118
+ self.fc2 = nn.Linear(64, OUTPUTS)
119
+
120
+ def _get_fc_input_size(self):
121
+ # Create a sample tensor and propagate it through the network to get the output shape
122
+ with torch.no_grad():
123
+ sample_tensor = torch.zeros(1, 1, IMG_HEIGHT, IMG_WIDTH)
124
+ sample_tensor = self.pool(torch.relu(self.conv1(sample_tensor)))
125
+ sample_tensor = self.pool(torch.relu(self.conv2(sample_tensor)))
126
+ sample_tensor = self.pool(torch.relu(self.conv3(sample_tensor)))
127
+ return sample_tensor.view(1, -1).shape[1]
128
+
129
+ def forward(self, x):
130
+ x = self.pool(torch.relu(self.conv1(x)))
131
+ x = self.pool(torch.relu(self.conv2(x)))
132
+ x = self.pool(torch.relu(self.conv3(x)))
133
+ x = x.view(-1, self.fc_input_size)
134
+ x = torch.relu(self.fc1(x))
135
+ x = self.fc2(x)
136
+ return x
137
+
138
+ model = Net().to(device) # Move model to GPU if available
139
+
140
+ # Define loss function and optimizer
141
+ criterion = nn.MSELoss()
142
+ optimizer = optim.Adam(model.parameters())
143
+
144
+ print("Starting training...")
145
+ print("\n--------------------------------------------------------------\n")
146
+ start_time = time.time()
147
+ update_time = start_time
148
+
149
+ # Train model
150
+ for epoch in range(NUM_EPOCHS):
151
+ running_loss = 0.0
152
+ for i, data in enumerate(dataloader, 0):
153
+ inputs, labels = data
154
+ inputs, labels = inputs.to(device), labels.to(device)
155
+ # Explicitly convert inputs and labels to torch.float32
156
+ inputs = inputs.float()
157
+ labels = labels.float()
158
+ optimizer.zero_grad()
159
+ outputs = model(inputs) # No need to call .float() here
160
+ loss = criterion(outputs, labels)
161
+ loss.backward()
162
+ optimizer.step()
163
+ running_loss += loss.item()
164
+ print(f"\rEpoch {epoch+1}, Loss: {running_loss / len(dataloader)}, {round((time.time() - update_time) if time.time() - update_time > 1 else (time.time() - update_time) * 1000, 2)}{'s' if time.time() - update_time > 1 else 'ms'}/Epoch, ETA: {time.strftime('%H:%M:%S', time.gmtime(round((time.time() - start_time) / (epoch + 1) * NUM_EPOCHS - (time.time() - start_time), 2)))} " + "\n\n--------------------------------------------------------------", end='', flush=True)
165
+ update_time = time.time()
166
+
167
+ print("\n\nTraining completed in " + time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time)))
168
+
169
+ # Save model
170
+ print("Saving model...")
171
+ torch.save(model.state_dict(), os.path.join(MODEL_PATH, f"EPOCHS-{NUM_EPOCHS}_BATCH-{BATCH_SIZE}_RES-{IMG_WIDTH}x{IMG_HEIGHT}_IMAGES-{len(dataset)}_TRAININGTIME-{time.strftime('%H-%M-%S', time.gmtime(time.time() - start_time))}_DATE-{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.pt"))
172
+ print("Model saved successfully.")
173
+
174
+ print("\n------------------------------------\n")