# cat-detection / train.py
# Isa0's picture
# feat: add training code
# df0b32f
import os
import glob
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
# 1. Dataset Definition
class CatLandmarkDataset(Dataset):
    """Dataset of cat images paired with landmark annotation files.

    Every ``*.jpg`` file in the given folders that has a sibling
    ``<image>.jpg.cat`` annotation file becomes one sample. Each sample is a
    ``(image, landmarks)`` pair:

    * ``image``: float32 tensor of shape ``(3, img_size, img_size)`` in RGB
      channel order, with pixel values divided by 255.
    * ``landmarks``: flat float32 tensor ``[x0, y0, x1, y1, ...]`` where each
      x is divided by the original image width and each y by the original
      image height.
    """

    def __init__(self, root_dirs, img_size=224):
        """Collect the image/annotation path pairs.

        Args:
            root_dirs: Iterable of folder paths to scan. Folders that do not
                exist are skipped silently.
            img_size: Side length in pixels of the square image returned by
                ``__getitem__``.
        """
        self.img_size = img_size
        self.image_paths = []
        self.label_paths = []

        for folder in root_dirs:
            if not os.path.exists(folder):
                continue
            # glob yields files in arbitrary, filesystem-dependent order;
            # sorting keeps sample indices stable across runs and machines.
            for img_path in sorted(glob.glob(os.path.join(folder, "*.jpg"))):
                cat_path = img_path + ".cat"
                if os.path.exists(cat_path):
                    self.image_paths.append(img_path)
                    self.label_paths.append(cat_path)

        print(f"[DATA] Total matching cat images: {len(self.image_paths)}")

    def __len__(self):
        """Return the number of image/annotation pairs found."""
        return len(self.image_paths)

    @staticmethod
    def _read_landmarks(label_path):
        """Parse a ``.cat`` file into an ``(N, 2)`` float32 array of (x, y).

        Raises:
            ValueError: If the file holds no coordinates or an odd number of
                coordinate values.
        """
        with open(label_path, "r") as f:
            tokens = f.read().split()

        # The first token is a header field (presumably the point count —
        # confirm against the dataset format); only the rest are coordinates.
        values = tokens[1:]
        if not values or len(values) % 2 != 0:
            raise ValueError(
                f"Malformed landmark file {label_path}: expected an even, "
                f"non-zero number of coordinate values, got {len(values)}"
            )
        coords = np.array([float(tok) for tok in values], dtype=np.float32)
        return coords.reshape(-1, 2)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, landmarks_tensor)``.

        Raises:
            OSError: If the image file cannot be decoded.
            ValueError: If the annotation file is malformed.
        """
        img_path = self.image_paths[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None rather than
            # raising, so fail here with a message that names the file.
            raise OSError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img.shape[:2]

        landmarks = self._read_landmarks(self.label_paths[idx])

        img_resized = cv2.resize(img, (self.img_size, self.img_size))

        # Rescaling a coordinate to the resized image and then dividing by
        # img_size is the same as dividing by the original dimension.
        landmarks[:, 0] /= orig_w
        landmarks[:, 1] /= orig_h

        # HWC uint8 image -> CHW float32 tensor.
        img_tensor = torch.tensor(img_resized, dtype=torch.float32).permute(2, 0, 1) / 255.0
        landmarks_tensor = torch.tensor(landmarks.flatten(), dtype=torch.float32)
        return img_tensor, landmarks_tensor
# 2. Model Architecture (MobileNetV3 Small)
def get_model(num_outputs=18):
    """Build a MobileNetV3-Small network with a regression head.

    Loads torchvision's ``mobilenet_v3_small`` with the
    ``MobileNet_V3_Small_Weights.DEFAULT`` pre-trained weights and replaces
    the layer at ``classifier[3]`` with a new ``nn.Linear`` so the network
    outputs coordinate values instead of class scores.

    Args:
        num_outputs: Number of values the network predicts. Defaults to 18
            (9 landmark points x 2 coordinates).

    Returns:
        The modified model.
    """
    # MobileNetV3-Small is chosen as a light architecture for low-end devices.
    model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)

    # Keep the replaced layer's input width; only the output size changes.
    old_head = model.classifier[3]
    model.classifier[3] = nn.Linear(old_head.in_features, num_outputs)
    return model
# 3. Training Function
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in eval mode and run with gradients
    disabled. Returns ``nan`` when the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    samples_seen = 0
    with torch.set_grad_enabled(training):
        for images, landmarks in loader:
            images = images.to(device)
            landmarks = landmarks.to(device)

            if training:
                optimizer.zero_grad()
            loss = criterion(model(images), landmarks)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns a batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            samples_seen += batch_size

    # Divide by the samples actually processed rather than len(dataset):
    # this stays correct with samplers or drop_last, and an empty loader
    # yields nan instead of raising ZeroDivisionError.
    return total_loss / samples_seen if samples_seen else float("nan")


def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, device="cpu"):
    """Train ``model`` with MSE loss and Adam, validating after every epoch.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Iterable of ``(images, landmarks)`` training batches.
        val_loader: Iterable of ``(images, landmarks)`` validation batches.
            May be empty, in which case the validation loss prints as ``nan``.
        epochs: Number of passes over the training data.
        lr: Learning rate passed to Adam.
        device: Device the model and batches are moved to.

    Returns:
        The trained model, as left after the final epoch's validation pass.
    """
    model = model.to(device)
    criterion = nn.MSELoss()  # Mean Squared Error is used for coordinate predictions
    optimizer = optim.Adam(model.parameters(), lr=lr)

    print(f"\n[TRAINING] Starting... Device: {device}")
    for epoch in range(epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion, device)
        print(f"Epoch [{epoch+1}/{epochs}] -> Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")
    return model
# 4. Export to ONNX Format
def export_to_onnx(model, save_path="cat_landmark_model.onnx", img_size=224):
    """Export ``model`` to an ONNX file.

    The model is switched to eval mode and exported with a random dummy
    input of shape ``(1, 3, img_size, img_size)``, created on the same device
    as the model's parameters.

    Args:
        model: Trained network to export.
        save_path: Destination path of the ONNX file.
        img_size: Height and width of the dummy input. Defaults to 224;
            pass the value the model was trained with if it differs.
    """
    model.eval()

    # The dummy input must live on the model's device for the export to run.
    device = next(model.parameters()).device
    dummy_input = torch.randn(1, 3, img_size, img_size, device=device)

    print("\n[ONNX] Converting model to ONNX format...")
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
    )
    print(f"[ONNX] Successfully saved: {save_path}")
# Main Execution
if __name__ == "__main__":
    # Dataset folders; adjust these to match your own file layout.
    data_dirs = [
        '/content/CAT_00',
        '/content/CAT_01',
        '/content/CAT_02',
        '/content/CAT_03',
        '/content/CAT_04',
        '/content/CAT_05',
        '/content/CAT_06',
    ]

    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: gather every image that has a matching annotation file.
    full_dataset = CatLandmarkDataset(root_dirs=data_dirs, img_size=224)
    total_samples = len(full_dataset)

    if total_samples != 0:
        # 90% of the samples go to training, the remainder to validation.
        n_train = int(0.9 * total_samples)
        n_val = total_samples - n_train
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [n_train, n_val]
        )
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

        # Step 2: build the network.
        cat_model = get_model()

        # Step 3: train. Five epochs keeps a Colab run short; raise if desired.
        trained_model = train_model(
            cat_model,
            train_loader,
            val_loader,
            epochs=5,
            lr=0.001,
            device=device,
        )

        # Step 4: keep the PyTorch weights as a backup.
        torch.save(trained_model.state_dict(), "cat_landmark_model.pth")
        print("\n[SAVE] PyTorch weights saved (cat_landmark_model.pth)")

        # Step 5: export to ONNX for running on low-end devices.
        export_to_onnx(trained_model, save_path="cat_landmark_model.onnx")
    else:
        print("[ERROR] No data found in the specified folders! Please check file paths.")