File size: 4,069 Bytes
9602c20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Standalone inference script for single image prediction
"""

import torch
import numpy as np
from PIL import Image
import argparse
from pathlib import Path
import sys

sys.path.append(str(Path(__file__).parent))

import config
from src.feature_extractor import FeatureExtractor, extract_embeddings
from src.padim import PaDiM
from src.visualize import save_prediction


def predict_single_image(image_path: str,
                         model_path: "str | Path | None" = None,
                         threshold: float = 15.0,
                         save_result: bool = True) -> dict:
    """
    Run PaDiM anomaly inference on a single image.

    Args:
        image_path: Path to the input image.
        model_path: Path to the trained PaDiM model. When ``None``, defaults
            to ``config.MODEL_DIR / "padim_model.pkl"``.
        threshold: Anomaly threshold; the image is classified as DEFECTIVE
            when its anomaly score is strictly greater than this value.
        save_result: Whether to save a visualization under ``config.RESULTS_DIR``.

    Returns:
        Dictionary with keys ``image_path`` (str), ``prediction``
        ("DEFECTIVE" or "NORMAL"), ``anomaly_score`` (float), ``threshold``
        (float) and ``is_defective`` (built-in bool).

    Raises:
        FileNotFoundError: If the image or the model file does not exist.
    """
    if model_path is None:
        model_path = config.MODEL_DIR / "padim_model.pkl"

    # Validate inputs before any expensive model / device setup.
    if not Path(image_path).exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    if not Path(model_path).exists():
        raise FileNotFoundError(f"Model not found: {model_path}. Run train.py first.")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model
    print("Loading model...")
    padim_model = PaDiM()
    padim_model.load(model_path)

    # Load feature extractor
    print("Loading feature extractor...")
    extractor = FeatureExtractor(
        backbone=config.BACKBONE,
        layers=config.FEATURE_LAYERS
    ).to(device)
    # Inference must run with BatchNorm/Dropout in eval mode; otherwise a
    # single-image batch produces batch-statistics-dependent features.
    if isinstance(extractor, torch.nn.Module):
        extractor.eval()

    # Load an RGB copy for the visualization; the context manager closes
    # the underlying file handle as soon as the conversion is done.
    print(f"Processing image: {image_path}")
    with Image.open(image_path) as pil_image:
        image = pil_image.convert("RGB")

    # Project-local loader, imported lazily as in the original script.
    from src.data_loader import load_single_image
    img_tensor, _original = load_single_image(image_path)
    img_tensor = img_tensor.to(device)

    # Extract features (no autograd bookkeeping needed at inference time)
    print("Extracting features...")
    with torch.no_grad():
        embeddings = extract_embeddings(extractor, img_tensor)

    # Predict
    print("Computing anomaly score...")
    embeddings_np = embeddings.cpu().numpy()
    anomaly_score, anomaly_map = padim_model.predict(embeddings_np)

    # Coerce to built-in types: comparing a numpy scalar yields numpy.bool_,
    # which would make the returned dict non-JSON-serializable.
    score = float(anomaly_score)
    is_defective = score > threshold
    prediction = "DEFECTIVE" if is_defective else "NORMAL"

    # Print results
    print("\n" + "=" * 60)
    print(f"PREDICTION: {prediction}")
    print(f"Anomaly Score: {score:.4f}")
    print(f"Threshold: {threshold:.4f}")
    print("=" * 60)

    # Save visualization
    if save_result:
        output_path = config.RESULTS_DIR / f"prediction_{Path(image_path).stem}.png"
        # Ensure the results directory exists before writing into it.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        save_prediction(image, anomaly_score, anomaly_map, str(output_path), threshold)
        print(f"\nResult saved to: {output_path}")

    return {
        'image_path': str(image_path),
        'prediction': prediction,
        'anomaly_score': score,
        'threshold': float(threshold),
        'is_defective': is_defective,
    }


def main():
    """Command-line entry point: parse arguments and run single-image inference."""
    cli = argparse.ArgumentParser(description="Run inference on a single tablet image")

    # Each entry: (positional name or flag, keyword options for add_argument).
    argument_specs = (
        ('image_path', dict(type=str, help='Path to input image')),
        ('--model', dict(type=str,
                         default=None,
                         help='Path to trained model (default: models/padim_model.pkl)')),
        ('--threshold', dict(type=float,
                             default=15.0,
                             help='Anomaly threshold for Mahalanobis distance (default: 15.0)')),
        ('--no-save', dict(action='store_true',
                           help='Do not save result visualization')),
    )
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    # --no-save is an opt-out switch, so invert it for the save_result flag.
    predict_single_image(
        image_path=opts.image_path,
        model_path=opts.model,
        threshold=opts.threshold,
        save_result=not opts.no_save,
    )


# Run the command-line interface only when this file is executed as a
# script, so importing it (e.g. to call predict_single_image) has no side effects.
if __name__ == "__main__":
    main()