Isa0 committed on
Commit
df0b32f
·
1 Parent(s): a2bf021

feat: add training code

Browse files
Files changed (1) hide show
  1. train.py +166 -0
train.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import torchvision.models as models
10
+
11
+ # 1. Dataset Definition
12
class CatLandmarkDataset(Dataset):
    """Dataset of cat images paired with ``<image>.jpg.cat`` landmark files.

    Each item is a tuple ``(image, landmarks)``:

    * ``image`` is a float32 tensor of shape ``(3, img_size, img_size)`` in RGB
      channel order, with pixel values divided by 255.
    * ``landmarks`` is a flat float32 tensor of ``x, y`` pairs, where every
      ``x`` is divided by the original image width and every ``y`` by the
      original image height.
    """

    def __init__(self, root_dirs, img_size=224):
        """Collect every ``*.jpg`` in *root_dirs* that has a ``.cat`` sidecar.

        Args:
            root_dirs: A folder path, or an iterable of folder paths, to scan.
                Folders that do not exist are silently skipped.
            img_size: Side length in pixels of the square resized image.
        """
        self.img_size = img_size
        self.image_paths = []
        self.label_paths = []

        # A bare string would otherwise be iterated character by character.
        if isinstance(root_dirs, (str, os.PathLike)):
            root_dirs = [root_dirs]

        for folder in root_dirs:
            if not os.path.exists(folder):
                continue
            # glob order depends on the filesystem; sort so that sample
            # indices (and therefore seeded splits) are stable across runs.
            for img_path in sorted(glob.glob(os.path.join(folder, "*.jpg"))):
                cat_path = img_path + ".cat"
                if os.path.exists(cat_path):
                    self.image_paths.append(img_path)
                    self.label_paths.append(cat_path)

        print(f"[DATA] Total matching cat images: {len(self.image_paths)}")

    def __len__(self):
        """Return the number of image/label pairs found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, resize and normalise the sample at position *idx*.

        Returns:
            ``(img_tensor, landmarks_tensor)`` as described on the class.

        Raises:
            OSError: If the image file cannot be decoded.
        """
        img_path = self.image_paths[idx]

        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img.shape[:2]

        # The first whitespace-separated token is skipped
        # (presumably the number of points -- TODO confirm against the data).
        with open(self.label_paths[idx], "r") as f:
            tokens = f.read().split()
        landmarks = np.array([float(tok) for tok in tokens[1:]], dtype=np.float32)
        landmarks = landmarks.reshape(-1, 2)

        img_resized = cv2.resize(img, (self.img_size, self.img_size))

        # Scaling to the resized image and then dividing by img_size cancels
        # out, so the normalised coordinate is simply original / original size.
        landmarks[:, 0] /= orig_w
        landmarks[:, 1] /= orig_h

        # (H, W, C) -> (C, H, W), the layout PyTorch conv layers expect.
        img_tensor = torch.tensor(img_resized, dtype=torch.float32).permute(2, 0, 1) / 255.0
        landmarks_tensor = torch.tensor(landmarks.flatten(), dtype=torch.float32)

        return img_tensor, landmarks_tensor
57
+
58
+ # 2. Model Architecture (MobileNetV3 Small)
59
def get_model(num_outputs=18, pretrained=True):
    """Build a MobileNetV3-Small network with a linear regression head.

    The last layer of the torchvision classifier (``classifier[3]``) is
    replaced by a fresh ``nn.Linear`` that emits *num_outputs* values.

    Args:
        num_outputs: Number of regressed values. The default of 18
            corresponds to 9 landmark points x 2 coordinates.
        pretrained: When true (default), start from
            ``MobileNet_V3_Small_Weights.DEFAULT``; when false, use random
            initialisation so no weight download is required.

    Returns:
        The modified ``torch.nn.Module``.
    """
    weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
    model = models.mobilenet_v3_small(weights=weights)

    # Swap the final classification layer for a regression layer of the
    # same input width.
    in_features = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(in_features, num_outputs)

    return model
70
+
71
+ # 3. Training Function
72
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over *loader* and return the mean per-sample loss.

    When *optimizer* is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled. Returns ``nan`` if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    num_samples = 0
    with torch.set_grad_enabled(training):
        for images, landmarks in loader:
            images = images.to(device)
            landmarks = landmarks.to(device)

            if training:
                optimizer.zero_grad()
            loss = criterion(model(images), landmarks)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns a batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

    # Avoid ZeroDivisionError on an empty split.
    return total_loss / num_samples if num_samples else float("nan")


def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, device="cpu"):
    """Train *model* with Adam and MSE loss, reporting losses every epoch.

    Args:
        model: The ``torch.nn.Module`` to optimise (moved to *device*).
        train_loader: Iterable of ``(images, landmarks)`` training batches.
        val_loader: Iterable of ``(images, landmarks)`` validation batches,
            or ``None`` to skip validation (reported as ``nan``).
        epochs: Number of passes over *train_loader*.
        lr: Adam learning rate.
        device: Device on which to run the model and batches.

    Returns:
        The trained model (the result of ``model.to(device)``).
    """
    model = model.to(device)
    criterion = nn.MSELoss()  # Mean squared error over the coordinate vector
    optimizer = optim.Adam(model.parameters(), lr=lr)

    print(f"\n[TRAINING] Starting... Device: {device}")

    for epoch in range(epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)

        if val_loader is None:
            val_loss = float("nan")
        else:
            val_loss = _run_epoch(model, val_loader, criterion, device)

        print(f"Epoch [{epoch+1}/{epochs}] -> Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

    return model
112
+
113
+ # 4. Export to ONNX Format
114
def export_to_onnx(model, save_path="cat_landmark_model.onnx", img_size=224, opset_version=11):
    """Export *model* to an ONNX file by tracing it with a dummy input.

    Args:
        model: The ``torch.nn.Module`` to export; it is switched to eval mode.
        save_path: Destination path of the ``.onnx`` file.
        img_size: Side length of the square dummy input. It should match the
            size the model was trained with (224 by default).
        opset_version: ONNX opset passed to ``torch.onnx.export``.
    """
    model.eval()

    # Create the dummy input on the model's own device; fall back to CPU for
    # a model without parameters instead of raising StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Dummy input of shape (batch=1, channels=3, H=img_size, W=img_size)
    dummy_input = torch.randn(1, 3, img_size, img_size, device=device)

    print("\n[ONNX] Converting model to ONNX format...")
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
    )
    print(f"[ONNX] Successfully saved: {save_path}")
131
+
132
+ # Main Execution
133
if __name__ == "__main__":
    # Dataset folders; adjust these to match your own directory layout.
    data_dirs = [
        '/content/CAT_00',
        '/content/CAT_01',
        '/content/CAT_02',
        '/content/CAT_03',
        '/content/CAT_04',
        '/content/CAT_05',
        '/content/CAT_06',
    ]

    # Prefer CUDA when it is available, otherwise run on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: index the image/label pairs.
    full_dataset = CatLandmarkDataset(root_dirs=data_dirs, img_size=224)
    total_samples = len(full_dataset)

    if total_samples != 0:
        # 90% of the samples go to training, the remainder to validation.
        train_size = int(0.9 * total_samples)
        split_sizes = [train_size, total_samples - train_size]
        train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, split_sizes)

        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

        # Steps 2-3: build the network and train it. Five epochs keeps a
        # Colab run short; raise the count for a better fit.
        trained_model = train_model(
            get_model(),
            train_loader,
            val_loader,
            epochs=5,
            lr=0.001,
            device=device,
        )

        # Step 4: keep the raw PyTorch weights as a backup.
        torch.save(trained_model.state_dict(), "cat_landmark_model.pth")
        print("\n[SAVE] PyTorch weights saved (cat_landmark_model.pth)")

        # Step 5: export to ONNX for inference on low-end devices.
        export_to_onnx(trained_model, save_path="cat_landmark_model.onnx")
    else:
        print("[ERROR] No data found in the specified folders! Please check file paths.")