Ameya729 commited on
Commit
9602c20
·
verified ·
1 Parent(s): 3d21b4a

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +144 -144
inference.py CHANGED
@@ -1,144 +1,144 @@
1
- """
2
- Standalone inference script for single image prediction
3
- """
4
-
5
- import torch
6
- import numpy as np
7
- from PIL import Image
8
- import argparse
9
- from pathlib import Path
10
- import sys
11
-
12
- sys.path.append(str(Path(__file__).parent))
13
-
14
- import config
15
- from src.feature_extractor import FeatureExtractor, extract_embeddings
16
- from src.padim import PaDiM
17
- from src.visualize import save_prediction
18
-
19
-
20
- def predict_single_image(image_path: str,
21
- model_path: str = None,
22
- threshold: float = 0.5,
23
- save_result: bool = True) -> dict:
24
- """
25
- Run inference on a single image
26
-
27
- Args:
28
- image_path: Path to input image
29
- model_path: Path to trained PaDiM model (default: models/padim_model.pkl)
30
- threshold: Anomaly threshold
31
- save_result: Whether to save visualization
32
-
33
- Returns:
34
- Dictionary with prediction results
35
- """
36
- if model_path is None:
37
- model_path = config.MODEL_DIR / "padim_model.pkl"
38
-
39
- # Check files exist
40
- if not Path(image_path).exists():
41
- raise FileNotFoundError(f"Image not found: {image_path}")
42
-
43
- if not Path(model_path).exists():
44
- raise FileNotFoundError(f"Model not found: {model_path}. Run train.py first.")
45
-
46
- # Set device
47
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
- print(f"Using device: {device}")
49
-
50
- # Load model
51
- print("Loading model...")
52
- padim_model = PaDiM()
53
- padim_model.load(model_path)
54
-
55
- # Load feature extractor
56
- print("Loading feature extractor...")
57
- extractor = FeatureExtractor(
58
- backbone=config.BACKBONE,
59
- layers=config.FEATURE_LAYERS
60
- ).to(device)
61
-
62
- # Load and preprocess image
63
- print(f"Processing image: {image_path}")
64
- image = Image.open(image_path).convert("RGB")
65
-
66
- from src.data_loader import load_single_image
67
- img_tensor, original = load_single_image(image_path)
68
- img_tensor = img_tensor.to(device)
69
-
70
- # Extract features
71
- print("Extracting features...")
72
- with torch.no_grad():
73
- embeddings = extract_embeddings(extractor, img_tensor)
74
-
75
- # Predict
76
- print("Computing anomaly score...")
77
- embeddings_np = embeddings.cpu().numpy()
78
- anomaly_score, anomaly_map = padim_model.predict(embeddings_np)
79
-
80
- # Make decision
81
- is_defective = anomaly_score > threshold
82
- prediction = "DEFECTIVE" if is_defective else "NORMAL"
83
-
84
- # Print results
85
- print("\n" + "=" * 60)
86
- print(f"PREDICTION: {prediction}")
87
- print(f"Anomaly Score: {anomaly_score:.4f}")
88
- print(f"Threshold: {threshold:.4f}")
89
- print("=" * 60)
90
-
91
- # Save visualization
92
- if save_result:
93
- output_path = config.RESULTS_DIR / f"prediction_{Path(image_path).stem}.png"
94
- save_prediction(image, anomaly_score, anomaly_map, str(output_path), threshold)
95
- print(f"\nResult saved to: {output_path}")
96
-
97
- return {
98
- 'image_path': str(image_path),
99
- 'prediction': prediction,
100
- 'anomaly_score': float(anomaly_score),
101
- 'threshold': threshold,
102
- 'is_defective': is_defective
103
- }
104
-
105
-
106
- def main():
107
- parser = argparse.ArgumentParser(
108
- description="Run inference on a single tablet image"
109
- )
110
- parser.add_argument(
111
- 'image_path',
112
- type=str,
113
- help='Path to input image'
114
- )
115
- parser.add_argument(
116
- '--model',
117
- type=str,
118
- default=None,
119
- help='Path to trained model (default: models/padim_model.pkl)'
120
- )
121
- parser.add_argument(
122
- '--threshold',
123
- type=float,
124
- default=0.5,
125
- help='Anomaly threshold (default: 0.5)'
126
- )
127
- parser.add_argument(
128
- '--no-save',
129
- action='store_true',
130
- help='Do not save result visualization'
131
- )
132
-
133
- args = parser.parse_args()
134
-
135
- predict_single_image(
136
- image_path=args.image_path,
137
- model_path=args.model,
138
- threshold=args.threshold,
139
- save_result=not args.no_save
140
- )
141
-
142
-
143
- if __name__ == "__main__":
144
- main()
 
1
+ """
2
+ Standalone inference script for single image prediction
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ import argparse
9
+ from pathlib import Path
10
+ import sys
11
+
12
+ sys.path.append(str(Path(__file__).parent))
13
+
14
+ import config
15
+ from src.feature_extractor import FeatureExtractor, extract_embeddings
16
+ from src.padim import PaDiM
17
+ from src.visualize import save_prediction
18
+
19
+
20
+ def predict_single_image(image_path: str,
21
+ model_path: str = None,
22
+ threshold: float = 15.0,
23
+ save_result: bool = True) -> dict:
24
+ """
25
+ Run inference on a single image
26
+
27
+ Args:
28
+ image_path: Path to input image
29
+ model_path: Path to trained PaDiM model (default: models/padim_model.pkl)
30
+ threshold: Anomaly threshold
31
+ save_result: Whether to save visualization
32
+
33
+ Returns:
34
+ Dictionary with prediction results
35
+ """
36
+ if model_path is None:
37
+ model_path = config.MODEL_DIR / "padim_model.pkl"
38
+
39
+ # Check files exist
40
+ if not Path(image_path).exists():
41
+ raise FileNotFoundError(f"Image not found: {image_path}")
42
+
43
+ if not Path(model_path).exists():
44
+ raise FileNotFoundError(f"Model not found: {model_path}. Run train.py first.")
45
+
46
+ # Set device
47
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+ print(f"Using device: {device}")
49
+
50
+ # Load model
51
+ print("Loading model...")
52
+ padim_model = PaDiM()
53
+ padim_model.load(model_path)
54
+
55
+ # Load feature extractor
56
+ print("Loading feature extractor...")
57
+ extractor = FeatureExtractor(
58
+ backbone=config.BACKBONE,
59
+ layers=config.FEATURE_LAYERS
60
+ ).to(device)
61
+
62
+ # Load and preprocess image
63
+ print(f"Processing image: {image_path}")
64
+ image = Image.open(image_path).convert("RGB")
65
+
66
+ from src.data_loader import load_single_image
67
+ img_tensor, original = load_single_image(image_path)
68
+ img_tensor = img_tensor.to(device)
69
+
70
+ # Extract features
71
+ print("Extracting features...")
72
+ with torch.no_grad():
73
+ embeddings = extract_embeddings(extractor, img_tensor)
74
+
75
+ # Predict
76
+ print("Computing anomaly score...")
77
+ embeddings_np = embeddings.cpu().numpy()
78
+ anomaly_score, anomaly_map = padim_model.predict(embeddings_np)
79
+
80
+ # Make decision
81
+ is_defective = anomaly_score > threshold
82
+ prediction = "DEFECTIVE" if is_defective else "NORMAL"
83
+
84
+ # Print results
85
+ print("\n" + "=" * 60)
86
+ print(f"PREDICTION: {prediction}")
87
+ print(f"Anomaly Score: {anomaly_score:.4f}")
88
+ print(f"Threshold: {threshold:.4f}")
89
+ print("=" * 60)
90
+
91
+ # Save visualization
92
+ if save_result:
93
+ output_path = config.RESULTS_DIR / f"prediction_{Path(image_path).stem}.png"
94
+ save_prediction(image, anomaly_score, anomaly_map, str(output_path), threshold)
95
+ print(f"\nResult saved to: {output_path}")
96
+
97
+ return {
98
+ 'image_path': str(image_path),
99
+ 'prediction': prediction,
100
+ 'anomaly_score': float(anomaly_score),
101
+ 'threshold': threshold,
102
+ 'is_defective': is_defective
103
+ }
104
+
105
+
106
+ def main():
107
+ parser = argparse.ArgumentParser(
108
+ description="Run inference on a single tablet image"
109
+ )
110
+ parser.add_argument(
111
+ 'image_path',
112
+ type=str,
113
+ help='Path to input image'
114
+ )
115
+ parser.add_argument(
116
+ '--model',
117
+ type=str,
118
+ default=None,
119
+ help='Path to trained model (default: models/padim_model.pkl)'
120
+ )
121
+ parser.add_argument(
122
+ '--threshold',
123
+ type=float,
124
+ default=15.0,
125
+ help='Anomaly threshold for Mahalanobis distance (default: 15.0)'
126
+ )
127
+ parser.add_argument(
128
+ '--no-save',
129
+ action='store_true',
130
+ help='Do not save result visualization'
131
+ )
132
+
133
+ args = parser.parse_args()
134
+
135
+ predict_single_image(
136
+ image_path=args.image_path,
137
+ model_path=args.model,
138
+ threshold=args.threshold,
139
+ save_result=not args.no_save
140
+ )
141
+
142
+
143
+ if __name__ == "__main__":
144
+ main()