File size: 6,478 Bytes
df0b32f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import glob
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

# 1. Dataset Definition
class CatLandmarkDataset(Dataset):
    """Cat landmark dataset built from ``<name>.jpg`` / ``<name>.jpg.cat`` pairs.

    Each ``.jpg`` found directly inside one of ``root_dirs`` is kept only if a
    sibling ``.jpg.cat`` annotation file exists. The annotation file is read
    as whitespace-separated numbers. The first token is skipped (presumably
    the point count of the CAT annotation format; TODO confirm). The rest are
    consumed as ``x y`` pixel-coordinate pairs.

    Samples are ``(image, landmarks)``:
      * ``image``: float32 tensor ``(3, img_size, img_size)``, RGB, scaled to
        [0, 1] (no mean/std normalization is applied).
      * ``landmarks``: flat float32 tensor ``[x0, y0, x1, y1, ...]``. Each x is
        divided by the original image width and each y by the original height.

    Args:
        root_dirs: Iterable of directories to scan (non-recursive). Entries
            that are not existing directories are skipped.
        img_size: Side length in pixels of the square resized image.
    """

    def __init__(self, root_dirs, img_size=224):
        self.img_size = img_size
        self.image_paths = []
        self.label_paths = []

        for folder in root_dirs:
            if not os.path.isdir(folder):
                continue
            # Escape the folder so characters like "[" or "*" in its name are
            # matched literally rather than as glob wildcards.
            pattern = os.path.join(glob.escape(folder), "*.jpg")
            # Sort so sample order (and therefore any seeded random_split) does
            # not depend on the filesystem's directory-listing order.
            for img_path in sorted(glob.glob(pattern)):
                label_path = img_path + ".cat"
                if os.path.exists(label_path):
                    self.image_paths.append(img_path)
                    self.label_paths.append(label_path)

        print(f"[DATA] Total matching cat images: {len(self.image_paths)}")

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _read_landmarks(label_path):
        """Parse a ``.cat`` file into a ``(num_points, 2)`` float32 pixel-coordinate array.

        Raises:
            ValueError: If the file holds an odd number of coordinate values.
        """
        with open(label_path, "r") as f:
            tokens = f.read().split()
        # Skip the leading header token; the remainder are x/y pairs.
        coords = np.array([float(token) for token in tokens[1:]], dtype=np.float32)
        if coords.size % 2:
            raise ValueError(
                f"{label_path}: expected an even number of coordinate values, got {coords.size}"
            )
        return coords.reshape(-1, 2)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports missing/corrupt/unsupported files by returning None.
            raise IOError(f"Could not read image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img.shape[:2]

        landmarks = self._read_landmarks(self.label_paths[idx])
        # The resize is a pure per-axis scale. Scaling to the resized image and
        # then dividing by img_size is therefore just division by the original
        # width/height.
        landmarks[:, 0] /= orig_w
        landmarks[:, 1] /= orig_h

        img_resized = cv2.resize(img, (self.img_size, self.img_size))
        # HWC uint8 -> CHW float32 in [0, 1].
        img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0
        landmarks_tensor = torch.from_numpy(landmarks.reshape(-1))

        return img_tensor, landmarks_tensor

# 2. Model Architecture (MobileNetV3 Small)
def get_model(num_points=9, pretrained=True):
    """Build a MobileNetV3-Small regressor that predicts landmark coordinates.

    MobileNetV3-Small is a light architecture suited to low-end devices. The
    last linear layer of its classification head (``classifier[3]``) is
    replaced by a linear layer that outputs ``2 * num_points`` values. The
    outputs are ordered ``[x0, y0, x1, y1, ...]``, matching the flattened
    targets produced by ``CatLandmarkDataset``. Training uses plain regression,
    not classification.

    Args:
        num_points: Number of (x, y) landmarks to predict. The default of 9
            gives the original 18 outputs.
        pretrained: If True, start from ``MobileNet_V3_Small_Weights.DEFAULT``
            (downloaded on first use). If False, use random initialization,
            which needs no network access.

    Returns:
        The modified ``torchvision`` MobileNetV3-Small module.
    """
    weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
    model = models.mobilenet_v3_small(weights=weights)

    # Swap the ImageNet classifier output for a coordinate-regression layer.
    in_features = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(in_features, 2 * num_points)

    return model

# 3. Training Function
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, device="cpu"):
    model = model.to(device)
    criterion = nn.MSELoss() # Mean Squared Error is used for coordinate predictions
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    print(f"\n[TRAINING] Starting... Device: {device}")
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for images, landmarks in train_loader:
            images = images.to(device)
            landmarks = landmarks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, landmarks)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * images.size(0)
            
        train_loss /= len(train_loader.dataset)
        
        # Validation Phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, landmarks in val_loader:
                images = images.to(device)
                landmarks = landmarks.to(device)
                outputs = model(images)
                loss = criterion(outputs, landmarks)
                val_loss += loss.item() * images.size(0)
        val_loss /= len(val_loader.dataset)
        
        print(f"Epoch [{epoch+1}/{epochs}] -> Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")
        
    return model

# 4. Export to ONNX Format
def export_to_onnx(model, save_path="cat_landmark_model.onnx", img_size=224, dynamic_batch=False):
    """Export ``model`` to an ONNX file for inference on low-end devices.

    The model is switched to eval mode and traced with a random dummy input of
    shape ``(1, 3, img_size, img_size)``. The graph's input and output are
    named ``input`` and ``output``. Weights are embedded, opset 11 is used,
    and constant folding is enabled.

    Args:
        model: Trained ``nn.Module`` taking a ``(N, 3, img_size, img_size)`` batch.
        save_path: Destination path of the ``.onnx`` file.
        img_size: Spatial size of the traced input. It should match the size
            the model was trained on (224 by default, like the dataset).
        dynamic_batch: If True, mark dim 0 of ``input``/``output`` as dynamic
            so batch sizes other than 1 are accepted at inference time. The
            default False keeps the fixed batch size of 1.
    """
    model.eval()
    # Build the dummy input on the same device as the weights. Fall back to
    # CPU for parameter-less modules instead of raising StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 3, img_size, img_size, device=device)

    dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}} if dynamic_batch else None

    print("\n[ONNX] Converting model to ONNX format...")
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"[ONNX] Successfully saved: {save_path}")

# Main Execution
def main():
    """Train the cat landmark model on the CAT_00..CAT_06 folders and export it.

    Writes ``cat_landmark_model.pth`` (PyTorch state dict) and
    ``cat_landmark_model.onnx`` to the current working directory.

    Raises:
        SystemExit: With a non-zero status when fewer than two annotated
            images are found, because both the training and the validation
            split must be non-empty.
    """
    # Folder paths (update these to match your file structure).
    data_dirs = [f"/content/CAT_{i:02d}" for i in range(7)]

    # Use the GPU when CUDA is available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load data.
    full_dataset = CatLandmarkDataset(root_dirs=data_dirs, img_size=224)
    if len(full_dataset) == 0:
        raise SystemExit("[ERROR] No data found in the specified folders! Please check file paths.")
    if len(full_dataset) < 2:
        raise SystemExit("[ERROR] At least 2 annotated images are needed to build train/validation splits.")

    # Split 90% training / 10% validation. The seeded generator keeps the
    # split reproducible across runs.
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # 2. Build the model.
    cat_model = get_model()

    # 3. Train (5 epochs keeps a Colab run quick; increase if desired).
    trained_model = train_model(cat_model, train_loader, val_loader, epochs=5, lr=0.001, device=device)

    # 4. Save the PyTorch weights as a backup.
    torch.save(trained_model.state_dict(), "cat_landmark_model.pth")
    print("\n[SAVE] PyTorch weights saved (cat_landmark_model.pth)")

    # 5. Convert to ONNX for running on low-end devices.
    export_to_onnx(trained_model, save_path="cat_landmark_model.onnx")


if __name__ == "__main__":
    main()