| import os |
| import glob |
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| import torchvision.models as models |
|
|
| |
class CatLandmarkDataset(Dataset):
    """Cat face images paired with facial-landmark annotations.

    Every ``*.jpg`` inside each of ``root_dirs`` that has a sibling
    ``<image>.jpg.cat`` label file becomes one sample. A label file holds a
    point count ``N`` followed by ``N`` ``x y`` pixel-coordinate pairs, all
    whitespace-separated.

    Args:
        root_dirs: Iterable of directories to scan; non-directories are skipped.
        img_size: Side length, in pixels, of the square images returned.

    Each item is ``(image, landmarks)``:
        image: float32 tensor of shape ``(3, img_size, img_size)``, RGB,
            values in ``[0, 1]``.
        landmarks: float32 tensor ``[x0, y0, x1, y1, ...]`` of length ``2 * N``
            with x divided by the original image width and y by its height.
    """

    def __init__(self, root_dirs, img_size=224):
        self.img_size = img_size
        self.image_paths = []
        self.label_paths = []

        for folder in root_dirs:
            if not os.path.isdir(folder):
                continue
            # Escape the folder so glob metacharacters in its name ("[", "*")
            # are matched literally; sort so sample order is reproducible.
            pattern = os.path.join(glob.escape(folder), "*.jpg")
            for img_path in sorted(glob.glob(pattern)):
                cat_path = img_path + ".cat"
                if os.path.exists(cat_path):
                    self.image_paths.append(img_path)
                    self.label_paths.append(cat_path)

        print(f"[DATA] Total matching cat images: {len(self.image_paths)}")

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _read_landmarks(label_path):
        """Parse a ``.cat`` label file into an ``(N, 2)`` float32 array of pixel coords.

        Raises:
            ValueError: if the file is empty or holds fewer than ``2 * N``
                coordinates after the count header. Extra trailing tokens are
                ignored.
        """
        with open(label_path, "r", encoding="utf-8") as f:
            tokens = f.read().split()
        if not tokens:
            raise ValueError(f"Empty landmark file: {label_path}")

        num_points = int(tokens[0])
        coords = tokens[1:1 + 2 * num_points]
        if len(coords) != 2 * num_points:
            raise ValueError(
                f"{label_path}: header declares {num_points} points but "
                f"{len(coords)} coordinates were found"
            )
        return np.array([float(v) for v in coords], dtype=np.float32).reshape(num_points, 2)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise OSError(f"Failed to read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img.shape[:2]

        landmarks = self._read_landmarks(self.label_paths[idx])

        img_resized = cv2.resize(img, (self.img_size, self.img_size))

        # Normalise by the ORIGINAL size: the relative position is unchanged by
        # the resize above, so no resize factor is needed.
        landmarks[:, 0] /= orig_w
        landmarks[:, 1] /= orig_h

        img_tensor = torch.tensor(img_resized, dtype=torch.float32).permute(2, 0, 1) / 255.0
        landmarks_tensor = torch.tensor(landmarks.reshape(-1), dtype=torch.float32)

        return img_tensor, landmarks_tensor
|
|
| |
def get_model(num_landmarks=9, pretrained=True):
    """Build a MobileNetV3-Small landmark regressor.

    The final Linear layer of the classifier head is replaced by one that
    outputs ``2 * num_landmarks`` values, one ``(x, y)`` pair per landmark. This
    matches the flattened targets produced by ``CatLandmarkDataset``.

    Args:
        num_landmarks: Number of landmarks to regress. The default of 9 gives
            the original 18 outputs.
        pretrained: Start from ``MobileNet_V3_Small_Weights.DEFAULT`` when
            True; use random initialisation when False.

    Returns:
        The ``torchvision`` MobileNetV3-Small model with the replaced head.
    """
    weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
    model = models.mobilenet_v3_small(weights=weights)

    # Index from the end rather than hard-coding 3: the head's last module is
    # the Linear layer being replaced.
    last_linear = model.classifier[-1]
    model.classifier[-1] = nn.Linear(last_linear.in_features, 2 * num_landmarks)

    return model
|
|
| |
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    With an ``optimizer`` the pass trains: train mode, gradients on, one
    optimisation step per batch. Without one it evaluates: eval mode, gradients
    off. Returns ``nan`` when the loader yields no samples, instead of dividing
    by zero.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    num_samples = 0

    with torch.set_grad_enabled(training):
        for images, landmarks in loader:
            images = images.to(device)
            landmarks = landmarks.to(device)

            outputs = model(images)
            loss = criterion(outputs, landmarks)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Weight by the batch size so a short last batch is averaged correctly.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

    # Divide by samples actually seen (correct even with drop_last).
    return total_loss / num_samples if num_samples else float("nan")


def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, device="cpu"):
    """Train ``model`` with Adam on an MSE objective, validating after every epoch.

    Args:
        model: Module to train. It is moved to ``device``.
        train_loader: Iterable of ``(images, targets)`` batches used for updates.
        val_loader: Iterable of ``(images, targets)`` batches used for validation.
        epochs: Number of full passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device the model and batches are moved to.

    Returns:
        The trained model (the same module object, left in eval mode). Per-epoch
        train and validation losses are printed; an empty loader reports ``nan``.
    """
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    print(f"\n[TRAINING] Starting... Device: {device}")

    for epoch in range(epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion, device)

        print(f"Epoch [{epoch+1}/{epochs}] -> Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

    return model
|
|
| |
def export_to_onnx(model, save_path="cat_landmark_model.onnx", img_size=224, dynamic_batch=False):
    """Export ``model`` to an ONNX file at ``save_path``.

    The model is put in eval mode and traced with a random
    ``(1, 3, img_size, img_size)`` input on the same device as its parameters.
    The graph's input is named ``"input"`` and its output ``"output"``.

    Args:
        model: Module to export. It is switched to eval mode as a side effect.
        save_path: Destination ``.onnx`` file path.
        img_size: Spatial size of the dummy input. It should match the
            training resolution.
        dynamic_batch: When True, mark dimension 0 of input and output as a
            dynamic ``"batch"`` axis. Otherwise the exported graph keeps the
            dummy input's batch size of 1.
    """
    model.eval()

    device = next(model.parameters()).device
    dummy_input = torch.randn(1, 3, img_size, img_size, device=device)

    # Without dynamic_axes every input/output dimension is frozen to the shape
    # of the dummy input.
    dynamic_axes = (
        {"input": {0: "batch"}, "output": {0: "batch"}} if dynamic_batch else None
    )

    print("\n[ONNX] Converting model to ONNX format...")
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"[ONNX] Successfully saved: {save_path}")
|
|
| |
if __name__ == "__main__":
    # Dataset sub-folders /content/CAT_00 ... /content/CAT_06.
    data_dirs = [f"/content/CAT_{i:02d}" for i in range(7)]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    full_dataset = CatLandmarkDataset(root_dirs=data_dirs, img_size=224)

    if len(full_dataset) == 0:
        print("[ERROR] No data found in the specified folders! Please check file paths.")
    else:
        # 90% / 10% train/validation split. The generator is seeded so the same
        # images land in the validation set on every run.
        train_size = int(0.9 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        split_generator = torch.Generator().manual_seed(42)
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size], generator=split_generator
        )

        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

        cat_model = get_model()

        trained_model = train_model(cat_model, train_loader, val_loader, epochs=5, lr=0.001, device=device)

        torch.save(trained_model.state_dict(), "cat_landmark_model.pth")
        print("\n[SAVE] PyTorch weights saved (cat_landmark_model.pth)")

        export_to_onnx(trained_model, save_path="cat_landmark_model.onnx")
|
|