"""
Training script for PaDiM anomaly detection model
"""

import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
import sys

# Add parent directory to path so the project-local `config` and `src`
# modules imported below resolve when this script is run directly.
sys.path.append(str(Path(__file__).parent.parent))

import config
from src.data_loader import get_dataloader
from src.feature_extractor import FeatureExtractor, extract_embeddings
from src.padim import PaDiM


def _collect_embeddings(extractor, loader, device):
    """Run ``extract_embeddings`` over every batch of ``loader`` and stack results.

    Args:
        extractor: Feature extractor handed to ``extract_embeddings``.
        loader: Iterable of 3-tuples whose first element is an image tensor;
            the other two elements are ignored here.
        device: torch device the images are moved to before extraction.

    Returns:
        NumPy array of the per-batch embeddings concatenated along axis 0.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    batch_embeddings = []
    with torch.no_grad():
        for images, _, _ in tqdm(loader):
            images = images.to(device)
            # Extract multi-scale embeddings
            embeddings = extract_embeddings(extractor, images)
            batch_embeddings.append(embeddings.cpu().numpy())

    if not batch_embeddings:
        # np.concatenate([]) would raise an opaque "need at least one array"
        # error; report the actual cause instead (same exception type).
        raise ValueError(
            "Training data loader yielded no batches; "
            "cannot fit PaDiM without normal samples."
        )
    return np.concatenate(batch_embeddings, axis=0)


def train_padim():
    """Train PaDiM model on normal training data.

    Builds the feature extractor configured in ``config``, extracts embeddings
    for every sample in ``config.TRAIN_DIR``, fits a ``PaDiM`` model on them
    and saves it to ``config.MODEL_DIR / "padim_model.pkl"``.

    Returns:
        Tuple ``(padim_model, extractor)``: the fitted model and the feature
        extractor used to produce its training embeddings.

    Raises:
        ValueError: if the training data loader yields no batches.
    """
    print("=" * 60)
    print("AUTOMATED TABLET DEFECT DETECTION - TRAINING")
    print("=" * 60)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize feature extractor
    print("\nInitializing feature extractor...")
    extractor = FeatureExtractor(
        backbone=config.BACKBONE,
        layers=config.FEATURE_LAYERS
    ).to(device)
    # PaDiM needs deterministic features from a frozen backbone: switch off
    # training-mode behaviour (batch-statistics BatchNorm, dropout) if the
    # extractor is a torch module. Guarded because its base class is not
    # visible from this file.
    if isinstance(extractor, torch.nn.Module):
        extractor.eval()

    # Display feature dimensions
    dims = extractor.get_feature_dimensions()
    print("\nFeature dimensions:")
    for layer, dim_info in dims.items():
        print(f"  {layer}: {dim_info}")

    # Load training data (only good samples)
    print(f"\nLoading training data from {config.TRAIN_DIR}...")
    train_loader = get_dataloader(
        config.TRAIN_DIR,
        batch_size=config.BATCH_SIZE,
        shuffle=False
    )
    print(f"Training samples: {len(train_loader.dataset)}")

    # Extract embeddings from all training samples
    print("\nExtracting features from training data...")
    all_embeddings = _collect_embeddings(extractor, train_loader, device)
    print(f"Embeddings shape: {all_embeddings.shape}")

    # Train PaDiM model
    print("\nTraining PaDiM model...")
    padim_model = PaDiM(
        reduce_dim=config.REDUCE_DIM,
        epsilon=config.EPSILON
    )
    padim_model.fit(all_embeddings)

    # Save model; create the output directory first so saving does not
    # depend on MODEL_DIR already existing.
    model_path = config.MODEL_DIR / "padim_model.pkl"
    model_path.parent.mkdir(parents=True, exist_ok=True)
    padim_model.save(model_path)

    print("\n" + "=" * 60)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 60)
    print(f"Model saved to: {model_path}")

    return padim_model, extractor


if __name__ == "__main__":
    train_padim()