Senum2001 committed on
Commit
9cf599c
·
1 Parent(s): 8a7ddf0

Deploy Anomaly Detection API

Browse files
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile for Hugging Face Spaces
2
+ FROM python:3.10-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies for OpenCV
8
+ RUN apt-get update && apt-get install -y \
9
+ libgl1 \
10
+ libglib2.0-0 \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy requirements and install dependencies
14
+ COPY requirements.txt ./
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Download model checkpoint from Google Drive
18
+ RUN pip install gdown && \
19
+ gdown --id 1ftzxTJUnlxpQFqPlaUozG_JUbl1Qi5tQ -O /app/model_checkpoint.ckpt
20
+
21
+ # Copy all project files
22
+ COPY . .
23
+
24
+ # Expose port for Hugging Face Spaces
25
+ EXPOSE 7860
26
+
27
+ # Set environment variables
28
+ ENV PYTHONUNBUFFERED=1
29
+ ENV PORT=7860
30
+
31
+ # Start Flask app (direct JSON responses)
32
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,10 +1,98 @@
1
  ---
2
- title: Anomaly Detection Api
3
- emoji: 💻
4
- colorFrom: red
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Anomaly Detection API
3
+ emoji: 🔍
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: docker
7
  pinned: false
8
+ license: mit
9
  ---
10
 
11
+ # 🔍 Anomaly Detection API
12
+
13
+ Real-time anomaly detection for electrical components using PatchCore + OpenCV classification.
14
+
15
+ ## 🚀 Quick Start
16
+
17
+ ### API Endpoint
18
+
19
+ **POST** `/infer`
20
+
21
+ **Request:**
22
+ ```json
23
+ {
24
+ "image_url": "https://example.com/your-image.jpg"
25
+ }
26
+ ```
27
+
28
+ **Response:**
29
+ ```json
30
+ {
31
+ "label": "Normal",
32
+ "boxed_url": "https://cloudinary.com/boxed_image.jpg",
33
+ "mask_url": "https://cloudinary.com/anomaly_mask.png",
34
+ "filtered_url": "https://cloudinary.com/filtered_anomalies.png",
35
+ "boxes": []
36
+ }
37
+ ```
38
+
39
+ ### Example Usage
40
+
41
+ ```bash
42
+ curl -X POST "https://YOUR_USERNAME-anomaly-detection-api.hf.space/infer" \
43
+ -H "Content-Type: application/json" \
44
+ -d '{"image_url": "https://example.com/test.jpg"}'
45
+ ```
46
+
47
+ ```python
48
+ import requests
49
+
50
+ response = requests.post(
51
+ "https://YOUR_USERNAME-anomaly-detection-api.hf.space/infer",
52
+ json={"image_url": "https://example.com/test.jpg"}
53
+ )
54
+
55
+ result = response.json()
56
+ print(f"Classification: {result['label']}")
57
+ print(f"Boxed Image: {result['boxed_url']}")
58
+ ```
59
+
60
+ ## 📋 Classification Labels
61
+
62
+ - **Normal** - No anomalies detected
63
+ - **Full Wire Overload** - Entire wire showing overload
64
+ - **Point Overload (Faulty)** - Localized overload points
65
+
66
+ ## 🔧 Technical Details
67
+
68
+ - **Model:** PatchCore (anomaly detection)
69
+ - **Classification:** OpenCV-based heuristics
70
+ - **Response Time:** ~5 seconds
71
+ - **Max Image Size:** Unlimited (auto-resized)
72
+
73
+ ## 🌐 Endpoints
74
+
75
+ | Endpoint | Method | Description |
76
+ |----------|--------|-------------|
77
+ | `/` | GET | API documentation |
78
+ | `/health` | GET | Health check |
79
+ | `/infer` | POST | Run inference |
80
+
81
+ ## 📦 Output Files
82
+
83
+ All processed images are uploaded to Cloudinary:
84
+ - **boxed_url:** Original image with bounding boxes
85
+ - **mask_url:** Grayscale anomaly heatmap
86
+ - **filtered_url:** Filtered image showing only anomalous regions
87
+
88
+ ## 🛠️ Built With
89
+
90
+ - PyTorch 2.4.1
91
+ - Anomalib (PatchCore)
92
+ - OpenCV
93
+ - Flask
94
+ - Cloudinary
95
+
96
+ ## 📄 License
97
+
98
+ MIT License
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Spaces API wrapper
3
+ Provides direct JSON responses without job queues
4
+ """
5
+ from flask import Flask, request, jsonify
6
+ from inference_core import run_pipeline_for_image, download_image_from_url, upload_to_cloudinary
7
+ import os
8
+
9
+ app = Flask(__name__)
10
+
11
+
12
@app.route("/", methods=["GET"])
def home():
    """Home page with API documentation"""
    # Build the documentation payload first, then serialize it once.
    info = {
        "service": "Anomaly Detection API",
        "version": "1.0",
        "endpoints": {
            "/health": "GET - Health check",
            "/infer": "POST - Run inference on image URL"
        },
        "example_request": {
            "method": "POST",
            "url": "/infer",
            "body": {
                "image_url": "https://example.com/image.jpg"
            }
        }
    }
    return jsonify(info)
30
+
31
+
32
@app.route("/health", methods=["GET"])
def health():
    """Health check endpoint"""
    payload = {"status": "healthy"}
    return jsonify(payload), 200
36
+
37
+
38
@app.route("/infer", methods=["POST"])
def infer():
    """
    Inference endpoint - returns direct JSON response.

    Request JSON: {"image_url": "https://..."}
    Returns 400 on a missing image_url, 500 on any pipeline failure.
    """
    try:
        data = request.get_json()
        if not data or "image_url" not in data:
            return jsonify({"error": "Missing image_url"}), 400

        image_url = data["image_url"]

        # Download image to a temp file.
        local_path = download_image_from_url(image_url)

        # Run pipeline; delete the temp download even if the pipeline raises
        # (the original code leaked the file on any exception).
        try:
            results = run_pipeline_for_image(local_path)
        finally:
            if os.path.exists(local_path):
                os.remove(local_path)

        # Upload each produced artifact; a path may be None when PatchCore
        # yields no anomaly map.
        def _upload(path):
            return upload_to_cloudinary(path, folder="pipeline_outputs") if path else None

        # Direct JSON response (no job queue wrapper)
        return jsonify({
            "label": results["label"],
            "boxed_url": _upload(results["boxed_path"]),
            "mask_url": _upload(results["mask_path"]),
            "filtered_url": _upload(results["filtered_path"]),
            "boxes": results.get("boxes", [])
        })

    except Exception as e:
        return jsonify({"error": str(e)}), 500
77
+
78
+
79
+ if __name__ == "__main__":
80
+ # For Hugging Face Spaces, use port 7860
81
+ port = int(os.environ.get("PORT", 7860))
82
+ app.run(host="0.0.0.0", port=port, debug=False)
configs/patchcore_transformers.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ class_path: anomalib.models.Patchcore
3
+ init_args:
4
+ backbone: "wide_resnet50_2"
5
+ layers: ["layer2", "layer3"]
6
+ coreset_sampling_ratio: 0.1
7
+ num_neighbors: 9
8
+
9
+ data:
10
+ # You can also use: class_path: anomalib.data.Folder (re-export)
11
+ class_path: anomalib.data.datamodules.image.folder.Folder
12
+ init_args:
13
+ name: "transformers"
14
+ root: "./dataset"
15
+ normal_dir: "train/normal" # ONLY normal images here
16
+ abnormal_dir: "test/faulty" # faulty images for eval
17
+ normal_test_dir: "test/normal" # a few normals for eval
18
+ train_batch_size: 4
19
+ eval_batch_size: 4
20
+ num_workers: 4
21
+ # (optional) add augmentations later via train_augmentations/val_augmentations/test_augmentations
22
+
23
+ trainer:
24
+ max_epochs: 1
25
+ accelerator: "auto"
26
+ devices: 1
27
+ enable_checkpointing: true
inference_core.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core inference module - contains model loading and inference functions
3
+ Can be imported by both Flask app and RunPod handler
4
+ """
5
+ import os
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torch
10
+ import subprocess
11
+ import sys
12
+ import requests
13
+ import tempfile
14
+ import cloudinary
15
+ import cloudinary.uploader
16
+
17
+ # ---- Import your PatchCore API ----
18
+ from scripts.patchcore_api_inference import Patchcore, config, device
19
+
20
+ # ---- Output directories ----
21
+ OUT_MASK_DIR = "api_inference_pred_masks_pipeline"
22
+ OUT_FILTERED_DIR = "api_inference_filtered_pipeline"
23
+ OUT_BOXED_DIR = "api_inference_labeled_boxes_pipeline"
24
+
25
+ os.makedirs(OUT_MASK_DIR, exist_ok=True)
26
+ os.makedirs(OUT_FILTERED_DIR, exist_ok=True)
27
+ os.makedirs(OUT_BOXED_DIR, exist_ok=True)
28
+
29
# ---- Cloudinary config ----
# SECURITY: credentials are taken from the environment when available. The
# literal fallbacks below reproduce the previously committed values for
# backward compatibility only — they should be rotated and removed; never
# commit API secrets to source control.
cloudinary.config(
    cloud_name=os.environ.get("CLOUDINARY_CLOUD_NAME", "dtyjmwyrp"),
    api_key=os.environ.get("CLOUDINARY_API_KEY", "619824242791553"),
    api_secret=os.environ.get("CLOUDINARY_API_SECRET", "l8hHU1GIg1FJ8rDgvHd4Sf7BWMk")
)
35
+
36
# ---- Load model once ----
# Google Drive file id of the trained PatchCore checkpoint (the Dockerfile
# also bakes this file into the image at build time).
GDRIVE_URL = "1ftzxTJUnlxpQFqPlaUozG_JUbl1Qi5tQ"
MODEL_CKPT_PATH = os.path.abspath("model_checkpoint.ckpt")

try:
    import gdown
except ImportError:
    # gdown is only needed for the download fallback; install on demand.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"])
    import gdown

if not os.path.exists(MODEL_CKPT_PATH):
    # The Docker build should have fetched the checkpoint already; the original
    # code imported gdown but never used it and just raised here. Fall back to
    # downloading at startup before giving up.
    print(f"[INFO] Checkpoint missing, downloading Google Drive id {GDRIVE_URL} ...")
    gdown.download(id=GDRIVE_URL, output=MODEL_CKPT_PATH, quiet=False)
    if not os.path.exists(MODEL_CKPT_PATH):
        raise FileNotFoundError(
            f"Model checkpoint not found at {MODEL_CKPT_PATH} and the download failed. "
            "Please rebuild the Docker image to include the model."
        )
else:
    print(f"[INFO] Model checkpoint already exists at {MODEL_CKPT_PATH}, skipping download.")

model = Patchcore.load_from_checkpoint(MODEL_CKPT_PATH, **config.model.init_args)
model.eval()
model = model.to(device)
print("[INFO] Model loaded and ready for inference")
54
+
55
+
56
def infer_single_image_with_patchcore(image_path: str):
    """Run PatchCore on one image and save its anomaly mask and filtered view.

    Returns a dict with the absolute original path, the saved mask/filtered
    paths (None when the model produced no anomaly map), and the original
    image size as (width, height).
    """
    fixed_path = os.path.abspath(os.path.normpath(image_path))
    orig_img = Image.open(fixed_path).convert("RGB")
    orig_w, orig_h = orig_img.size

    # Model input: 256x256 float tensor in [0, 1], NCHW layout.
    resized = orig_img.resize((256, 256))
    batch = torch.from_numpy(np.array(resized)).permute(2, 0, 1).float() / 255.0
    batch = batch.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)
        # The model may expose the map as an attribute or as the second
        # element of a tuple/list, depending on the anomalib version.
        if hasattr(output, "anomaly_map"):
            anomaly_map = output.anomaly_map.squeeze().detach().cpu().numpy()
        elif isinstance(output, (tuple, list)) and len(output) > 1:
            anomaly_map = output[1].squeeze().detach().cpu().numpy()
        else:
            anomaly_map = None

    base = os.path.splitext(os.path.basename(fixed_path))[0]
    mask_path = None
    filtered_path = None

    if anomaly_map is None:
        print("[PatchCore] No anomaly_map produced by model.")
    else:
        # Normalize scores to 0..255 and strip any leftover singleton dims.
        norm_map = (255 * (anomaly_map - anomaly_map.min()) / (np.ptp(anomaly_map) + 1e-8)).astype(np.uint8)
        if norm_map.ndim > 2:
            norm_map = np.squeeze(norm_map)
        if norm_map.ndim > 2:
            norm_map = norm_map[0]

        # Upscale the 256x256 mask back to the original resolution.
        mask_img = Image.fromarray(norm_map).resize((orig_w, orig_h), resample=Image.BILINEAR)
        mask_path = os.path.join(OUT_MASK_DIR, f"{base}_mask.png")
        mask_img.save(mask_path)

        # Keep only the pixels the model flagged (mask > 128); black elsewhere.
        keep = np.array(mask_img) > 128
        orig_np = np.array(orig_img)
        filtered_np = np.zeros_like(orig_np)
        filtered_np[keep] = orig_np[keep]
        filtered_path = os.path.join(OUT_FILTERED_DIR, f"{base}_filtered.png")
        Image.fromarray(filtered_np).save(filtered_path)

        print(f"[PatchCore] Saved mask -> {mask_path}")
        print(f"[PatchCore] Saved filtered -> {filtered_path}")

    return {
        "orig_path": fixed_path,
        "mask_path": mask_path,
        "filtered_path": filtered_path,
        "orig_size": (orig_w, orig_h),
    }
112
+
113
+
114
def classify_filtered_image(filtered_img_path: str):
    """Classify a PatchCore-filtered image using HSV color heuristics.

    Returns (label, box_list, label_list, bgr_image) where box_list holds
    (x, y, w, h) tuples and label_list the matching per-box labels.
    Raises FileNotFoundError when the image cannot be read.
    """
    img = cv2.imread(filtered_img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")

    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    # Per-color pixel masks in HSV space (red wraps around hue 0/180).
    blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
    black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
    yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
    orange_mask = cv2.inRange(hsv, (10, 100, 100), (25, 255, 255))
    red_mask = cv2.bitwise_or(
        cv2.inRange(hsv, (0, 100, 100), (10, 255, 255)),
        cv2.inRange(hsv, (160, 100, 100), (180, 255, 255)),
    )

    total = img.shape[0] * img.shape[1]
    blue_count = np.sum(blue_mask > 0)
    black_count = np.sum(black_mask > 0)
    yellow_count = np.sum(yellow_mask > 0)
    orange_count = np.sum(orange_mask > 0)
    red_count = np.sum(red_mask > 0)

    label = "Unknown"
    box_list, label_list = [], []

    if (blue_count + black_count) / total > 0.8:
        # Cool colors dominate: nothing hot in the filtered regions.
        label = "Normal"
    elif (red_count + orange_count + yellow_count) / total > 0.7:
        # Hot colors cover most of the frame: whole-wire overload.
        label = "Full Wire Overload"
        box_list.append((0, 0, img.shape[1], img.shape[0]))
        label_list.append(label)
    else:
        # Localized hot spots: keep red blobs within the area bounds.
        min_area_faulty = 120
        max_area = 0.05 * total

        contours, _ = cv2.findContours(red_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            if min_area_faulty < cv2.contourArea(cnt) < max_area:
                x, y, w, h = cv2.boundingRect(cnt)
                box_list.append((x, y, w, h))
                label_list.append("Point Overload (Faulty)")

    return label, box_list, label_list, img
162
+
163
+
164
def run_pipeline_for_image(image_path: str):
    """Full pipeline: PatchCore masking, color classification, box drawing.

    Returns a dict with the final label, paths to the boxed/mask/filtered
    images, and a JSON-friendly list of detected boxes.
    """
    # 1) PatchCore inference
    pc_out = infer_single_image_with_patchcore(image_path)
    orig_path = pc_out["orig_path"]
    filtered_path = pc_out["filtered_path"]
    # Fall back to the original image when PatchCore produced no filtered one.
    if filtered_path is None:
        filtered_path = orig_path

    # 2) Classify
    label, boxes, labels, _filtered_bgr = classify_filtered_image(filtered_path)

    # 3) Draw boxes on original image
    draw_img = cv2.imread(orig_path)
    if draw_img is None:
        raise FileNotFoundError(f"Could not read original image: {orig_path}")

    for (x, y, w, h), box_label in zip(boxes, labels):
        cv2.rectangle(draw_img, (x, y), (x + w, y + h), (0, 0, 255), 2)
        cv2.putText(draw_img, box_label, (x, max(0, y - 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

    if not boxes:
        # No boxes: stamp the overall label in the top-left corner instead.
        cv2.putText(draw_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

    base, ext = os.path.splitext(os.path.basename(orig_path))
    out_boxed_path = os.path.join(OUT_BOXED_DIR, f"{base}_boxed{ext if ext else '.png'}")
    if not cv2.imwrite(out_boxed_path, draw_img):
        # Unsupported extension: retry as PNG.
        out_boxed_path = os.path.join(OUT_BOXED_DIR, f"{base}_boxed.png")
        cv2.imwrite(out_boxed_path, draw_img)

    print(f"[Pipeline] Classification label: {label}")
    print(f"[Pipeline] Saved boxes-on-original -> {out_boxed_path}")

    return {
        "label": label,
        "boxed_path": out_boxed_path,
        "mask_path": pc_out["mask_path"],
        "filtered_path": pc_out["filtered_path"],
        "boxes": [
            {"box": [int(x), int(y), int(w), int(h)], "type": box_label}
            for (x, y, w, h), box_label in zip(boxes, labels)
        ]
    }
210
+
211
+
212
def download_image_from_url(url):
    """Download an image from *url* into a temp .jpg file and return its path.

    Raises Exception on a non-200 response (kept generic to match the broad
    error handling in the Flask endpoint).
    """
    # requests/tempfile are already imported at module level; the original
    # re-imported them locally. A timeout prevents a hung download from
    # blocking the request forever.
    response = requests.get(url, stream=True, timeout=30)
    if response.status_code != 200:
        raise Exception(f"Failed to download image from {url}")
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    try:
        for chunk in response.iter_content(1024):
            tmp.write(chunk)
    finally:
        # Always close the handle, even if the stream aborts mid-transfer.
        tmp.close()
    return tmp.name
224
+
225
+
226
def upload_to_cloudinary(file_path, folder=None):
    """Upload *file_path* to Cloudinary and return its public HTTPS URL."""
    options = {"resource_type": "image"}
    if folder:
        options["folder"] = folder
    return cloudinary.uploader.upload(file_path, **options)["secure_url"]
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ anomalib[full]
2
+ cloudinary
3
+ flask
4
+ gdown
5
+ numpy
6
+ opencv-python
7
+ omegaconf
8
+ pillow
9
+ python-dotenv
10
+ pytorch-lightning==2.4.0
11
+ requests
12
+ torch==2.4.1
13
+ torchvision
14
+ gunicorn
15
+ runpod
scripts/__pycache__/patchcore_api_inference.cpython-312.pyc ADDED
Binary file (5.46 kB). View file
 
scripts/augment_patchcore_dataset.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random

from PIL import Image, ImageOps, ImageEnhance

# Number of images each directory should contain after augmentation.
TARGET_COUNT = 100

# Label-preserving augmentations: light geometry + photometric jitter.
AUGS = [
    lambda img: img.rotate(random.randint(-30, 30)),
    lambda img: ImageOps.mirror(img),
    lambda img: ImageOps.flip(img),
    lambda img: ImageEnhance.Brightness(img).enhance(random.uniform(0.7, 1.3)),
    lambda img: ImageEnhance.Contrast(img).enhance(random.uniform(0.7, 1.3)),
    lambda img: ImageEnhance.Color(img).enhance(random.uniform(0.7, 1.3)),
]


def augment_folder(directory, target_count):
    """Augment images in *directory* in place until it holds *target_count* files.

    Exits the process with status 1 when the directory has no source images.
    The original script duplicated this logic verbatim for the normal and
    faulty folders; it is now shared.
    """
    images = [f for f in os.listdir(directory) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    current_count = len(images)
    print(f"[INFO] Found {current_count} images in {directory}")

    if current_count == 0:
        print("[ERROR] No images found to augment.")
        exit(1)

    aug_idx = 0
    while current_count < target_count:
        for fname in images:
            if current_count >= target_count:
                break
            img = Image.open(os.path.join(directory, fname)).convert('RGB')
            aug_img = random.choice(AUGS)(img)
            aug_fname = f"aug_{aug_idx}_{fname}"
            aug_img.save(os.path.join(directory, aug_fname))
            aug_idx += 1
            current_count += 1
            print(f"[AUG] Saved {aug_fname}")

    print(f"[DONE] {directory} now contains {current_count} images.")


# Normal training images, then faulty evaluation images.
augment_folder('dataset/train/normal', TARGET_COUNT)
augment_folder('dataset/test/faulty', TARGET_COUNT)
scripts/balance_train_folders.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image, ImageOps, ImageEnhance
3
+ import random
4
+ import shutil
5
+
6
+ TRAIN_NORMAL = 'dataset/train/normal'
7
+ TRAIN_FAULTY = 'dataset/train/faulty'
8
+ TEST_NORMAL = 'dataset/test/normal'
9
+ TEST_FAULTY = 'dataset/test/faulty'
10
+ TARGET_COUNT = 100
11
+
12
+ os.makedirs(TRAIN_NORMAL, exist_ok=True)
13
+ os.makedirs(TRAIN_FAULTY, exist_ok=True)
14
+
15
# Helper: move images from test to train if needed
def move_images(src, dst, needed):
    """Move up to *needed* image files from src to dst; return the count moved."""
    candidates = [f for f in os.listdir(src) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    moved = 0
    for name in candidates:
        if moved >= needed:
            break
        shutil.move(os.path.join(src, name), os.path.join(dst, name))
        moved += 1
    return moved
25
+
26
+ # 1. Move images from test to train if train folders have < TARGET_COUNT
27
+ normal_needed = TARGET_COUNT - len([f for f in os.listdir(TRAIN_NORMAL) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
28
+ faulty_needed = TARGET_COUNT - len([f for f in os.listdir(TRAIN_FAULTY) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
29
+ if normal_needed > 0:
30
+ move_images(TEST_NORMAL, TRAIN_NORMAL, normal_needed)
31
+ if faulty_needed > 0:
32
+ move_images(TEST_FAULTY, TRAIN_FAULTY, faulty_needed)
33
+
34
+ # 2. Augment if still not enough
35
+ AUGS = [
36
+ lambda img: img.rotate(random.randint(-30, 30)),
37
+ lambda img: ImageOps.mirror(img),
38
+ lambda img: ImageOps.flip(img),
39
+ lambda img: ImageEnhance.Brightness(img).enhance(random.uniform(0.7, 1.3)),
40
+ lambda img: ImageEnhance.Contrast(img).enhance(random.uniform(0.7, 1.3)),
41
+ lambda img: ImageEnhance.Color(img).enhance(random.uniform(0.7, 1.3)),
42
+ ]
43
+
44
def augment_to_count(folder, target):
    """Augment images in *folder* in place until it holds *target* image files."""
    originals = [f for f in os.listdir(folder) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    count = len(originals)
    suffix = 0
    # Cycle over the source images, saving one random augmentation per pass,
    # until the folder reaches the requested size.
    while count < target:
        for fname in originals:
            if count >= target:
                break
            source = Image.open(os.path.join(folder, fname)).convert('RGB')
            transformed = random.choice(AUGS)(source)
            new_name = f"aug_{suffix}_{fname}"
            transformed.save(os.path.join(folder, new_name))
            suffix += 1
            count += 1
            print(f"[AUG] Saved {new_name} in {folder}")
61
+
62
+ augment_to_count(TRAIN_NORMAL, TARGET_COUNT)
63
+ augment_to_count(TRAIN_FAULTY, TARGET_COUNT)
64
+
65
+ print(f"[DONE] {TRAIN_NORMAL} images: {len(os.listdir(TRAIN_NORMAL))}")
66
+ print(f"[DONE] {TRAIN_FAULTY} images: {len(os.listdir(TRAIN_FAULTY))}")
scripts/batch_patchcore_infer.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

from anomalib.models.image.patchcore import Patchcore

# Path to your trained checkpoint
CKPT_PATH = "results/Patchcore/transformers/v2/weights/lightning/model.ckpt"
# Folder with images to score
INPUT_DIR = "./dataset/test/faulty"  # Change as needed
# Output CSV for scores
OUTPUT_CSV = "patchcore_batch_scores.csv"

IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

# Load PatchCore once and move it to the best available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Patchcore.load_from_checkpoint(CKPT_PATH)
model = model.to(device)
model.eval()  # original called eval() twice; once is enough

# Define image transforms (must match the training transforms).
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Change to your image_size if different
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def load_image(path):
    """Open *path* as RGB and apply the inference transforms."""
    return transform(Image.open(path).convert("RGB"))


# Score all images in the input directory.
results = []
for fname in sorted(os.listdir(INPUT_DIR)):
    if not fname.lower().endswith(IMAGE_EXTS):
        continue
    img_tensor = load_image(os.path.join(INPUT_DIR, fname)).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img_tensor)
    # PatchCore may return a namedtuple with .anomaly_score or a tuple/list.
    if hasattr(output, 'anomaly_score'):
        score = output.anomaly_score.item()
    elif isinstance(output, (tuple, list)):
        score = float(output[0])
    else:
        raise RuntimeError("Unknown PatchCore output type: {}".format(type(output)))
    results.append({"image": fname, "anomaly_score": score})
    print(f"{fname}: {score:.4f}")

# Save results to CSV.
pd.DataFrame(results).to_csv(OUTPUT_CSV, index=False)
print(f"Saved batch scores to {OUTPUT_CSV}")
scripts/classify_filtered_images_opencv.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+
5
+ # Directory containing filtered images
6
+ dir_path = 'api_inference_filtered'
7
+ output_dir = 'api_inference_labeled_boxes'
8
+ os.makedirs(output_dir, exist_ok=True)
9
+
10
+ # IOU function for non-max suppression
11
def iou(boxA, boxB):
    """Intersection-over-union of two (x, y, w, h) boxes."""
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
    bottom = min(boxA[1] + boxA[3], boxB[1] + boxB[3])
    # Negative extents mean no overlap; clamp to zero.
    inter = max(0, right - left) * max(0, bottom - top)
    union = boxA[2] * boxA[3] + boxB[2] * boxB[3] - inter
    # Epsilon guards against division by zero for degenerate boxes.
    return inter / float(union + 1e-6)
23
+
24
+ # Merge close bounding boxes (same label, centers within dist_thresh)
25
def merge_close_boxes(boxes, labels, dist_thresh=20):
    """Greedily merge same-label boxes whose centers lie within dist_thresh.

    NOTE(review): the merged width/height are measured from box i's ORIGINAL
    right/bottom edge rather than the running merged extent — preserved as-is
    to keep behavior identical; confirm whether that is intended.
    """
    merged, merged_labels = [], []
    consumed = [False] * len(boxes)
    for i, ((x1, y1, w1, h1), label1) in enumerate(zip(boxes, labels)):
        if consumed[i]:
            continue
        mx, my, mw, mh = x1, y1, w1, h1
        cx1, cy1 = x1 + w1 // 2, y1 + h1 // 2
        for j in range(i + 1, len(boxes)):
            if consumed[j]:
                continue
            bx, by, bw, bh = boxes[j]
            cx2, cy2 = bx + bw // 2, by + bh // 2
            close = abs(cx1 - cx2) < dist_thresh and abs(cy1 - cy2) < dist_thresh
            if close and label1 == labels[j]:
                # Absorb box j into the merged extent.
                mx = min(mx, bx)
                my = min(my, by)
                mw = max(x1 + w1, bx + bw) - mx
                mh = max(y1 + h1, by + bh) - my
                consumed[j] = True
        merged.append((mx, my, mw, mh))
        merged_labels.append(label1)
        consumed[i] = True
    return merged, merged_labels
53
+
54
+ # Non-max suppression using IOU
55
def non_max_suppression_iou(boxes, labels, iou_thresh=0.4):
    """Greedy NMS: keep largest-area boxes first, suppress overlaps > iou_thresh."""
    if not boxes:
        return [], []
    # Candidate indices ordered by descending area.
    order = np.argsort([bw * bh for (_, _, bw, bh) in boxes])[::-1]
    kept_boxes, kept_labels = [], []
    while len(order) > 0:
        best = order[0]
        kept_boxes.append(boxes[best])
        kept_labels.append(labels[best])
        # Remove the winner plus everything it suppresses.
        drop = [0]
        for pos in range(1, len(order)):
            if iou(boxes[best], boxes[order[pos]]) > iou_thresh:
                drop.append(pos)
        order = np.delete(order, drop)
    return kept_boxes, kept_labels
71
+
72
+ # Filter out potential boxes that contain a faulty box inside
73
def filter_faulty_inside_potential(boxes, labels):
    """Drop 'Potential' boxes that fully contain a 'Faulty' box; keep the rest."""
    def contains(outer, inner):
        ox, oy, ow, oh = outer
        ix, iy, iw, ih = inner
        return ix >= ox and iy >= oy and ix + iw <= ox + ow and iy + ih <= oy + oh

    out_boxes, out_labels = [], []
    for box, label in zip(boxes, labels):
        if label == 'Point Overload (Potential)':
            has_faulty_inside = any(
                other_label == 'Point Overload (Faulty)' and contains(box, other_box)
                for other_box, other_label in zip(boxes, labels)
            )
            if has_faulty_inside:
                # The faulty detection supersedes the enclosing potential box.
                continue
        out_boxes.append(box)
        out_labels.append(label)
    return out_boxes, out_labels
95
+
96
+ # Remove potential boxes that overlap with a faulty box (not just inside)
97
def filter_faulty_overlapping_potential(boxes, labels):
    """Drop 'Potential' boxes that intersect any 'Faulty' box at all."""
    def intersects(a, b):
        left = max(a[0], b[0])
        top = max(a[1], b[1])
        right = min(a[0] + a[2], b[0] + b[2])
        bottom = min(a[1] + a[3], b[1] + b[3])
        return right > left and bottom > top

    out_boxes, out_labels = [], []
    for box, label in zip(boxes, labels):
        overlapped = label == 'Point Overload (Potential)' and any(
            other_label == 'Point Overload (Faulty)' and intersects(box, other_box)
            for other_box, other_label in zip(boxes, labels)
        )
        if overlapped:
            # Any faulty overlap invalidates the weaker potential box.
            continue
        out_boxes.append(box)
        out_labels.append(label)
    return out_boxes, out_labels
122
+
123
+ # Heuristic classification function
124
+ def classify_image(img_path):
125
+ img = cv2.imread(img_path)
126
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
127
+
128
+ # Color masks
129
+ blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
130
+ black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
131
+ yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255)) # increased threshold
132
+ orange_mask = cv2.inRange(hsv, (10, 100, 100), (25, 255, 255))
133
+ red_mask1 = cv2.inRange(hsv, (0, 100, 100), (10, 255, 255))
134
+ red_mask2 = cv2.inRange(hsv, (160, 100, 100), (180, 255, 255))
135
+ red_mask = cv2.bitwise_or(red_mask1, red_mask2)
136
+
137
+ total = img.shape[0] * img.shape[1]
138
+ blue_count = np.sum(blue_mask > 0)
139
+ black_count = np.sum(black_mask > 0)
140
+ yellow_count = np.sum(yellow_mask > 0)
141
+ orange_count = np.sum(orange_mask > 0)
142
+ red_count = np.sum(red_mask > 0)
143
+
144
+ label = 'Unknown'
145
+ box_list = []
146
+ label_list = []
147
+
148
+ # Full image checks
149
+ if (blue_count + black_count) / total > 0.8:
150
+ label = 'Normal'
151
+ elif (red_count + orange_count) / total > 0.5:
152
+ label = 'Full Wire Overload'
153
+ elif (yellow_count) / total > 0.5:
154
+ label = 'Full Wire Overload'
155
+ # Check for full wire overload (entire image reddish or yellowish)
156
+ full_wire_thresh = 0.7 # 70% of image is reddish or yellowish
157
+ if (red_count + orange_count + yellow_count) / total > full_wire_thresh:
158
+ label = 'Full Wire Overload'
159
+ # Add a box covering the whole image
160
+ box_list.append((0, 0, img.shape[1], img.shape[0]))
161
+ label_list.append(label)
162
+ else:
163
+ # Small spot checks (improved: filter tiny spots, merge overlapping boxes)
164
+ min_area_faulty = 120 # increased min area for red/orange (faulty)
165
+ min_area_potential = 1000 # much higher min area for yellow (potential)
166
+ max_area = 0.05 * total
167
+ # Faulty (red/orange) spots
168
+ for mask, spot_label, min_a in [
169
+ (red_mask, 'Point Overload (Faulty)', min_area_faulty),
170
+ (yellow_mask, 'Point Overload (Potential)', min_area_potential)
171
+ ]:
172
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
173
+ for cnt in contours:
174
+ area = cv2.contourArea(cnt)
175
+ if min_a < area < max_area:
176
+ x, y, w, h = cv2.boundingRect(cnt)
177
+ box_list.append((x, y, w, h))
178
+ label_list.append(spot_label)
179
+ # Middle area checks
180
+ h, w = img.shape[:2]
181
+ center = img[h//4:3*h//4, w//4:3*w//4]
182
+ center_hsv = cv2.cvtColor(center, cv2.COLOR_BGR2HSV)
183
+ center_yellow = cv2.inRange(center_hsv, (20, 130, 130), (35, 255, 255))
184
+ center_orange = cv2.inRange(center_hsv, (10, 100, 100), (25, 255, 255))
185
+ center_red1 = cv2.inRange(center_hsv, (0, 100, 100), (10, 255, 255))
186
+ center_red2 = cv2.inRange(center_hsv, (160, 100, 100), (180, 255, 255))
187
+ center_red = cv2.bitwise_or(center_red1, center_red2)
188
+ if np.sum(center_red > 0) + np.sum(center_orange > 0) > 0.1 * center.size:
189
+ label = 'Loose Joint (Faulty)'
190
+ box_list.append((w//4, h//4, w//2, h//2))
191
+ label_list.append(label)
192
+ elif np.sum(center_yellow > 0) > 0.1 * center.size:
193
+ label = 'Loose Joint (Potential)'
194
+ box_list.append((w//4, h//4, w//2, h//2))
195
+ label_list.append(label)
196
+ # Always check for tiny spots, even if image is labeled as Normal
197
+ min_area_tiny = 10
198
+ max_area_tiny = 30
199
+ for mask, spot_label in [
200
+ (red_mask, 'Tiny Faulty Spot'),
201
+ (yellow_mask, 'Tiny Potential Spot')
202
+ ]:
203
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
204
+ for cnt in contours:
205
+ area = cv2.contourArea(cnt)
206
+ if min_area_tiny < area < max_area_tiny:
207
+ x, y, w, h = cv2.boundingRect(cnt)
208
+ box_list.append((x, y, w, h))
209
+ label_list.append(spot_label)
210
+ # Detect wire-shaped (long, thin) regions for wire overloads only
211
+ aspect_ratio_thresh = 5
212
+ min_strip_area = 0.01 * total
213
+ wire_boxes = []
214
+ wire_labels = []
215
+ for mask, strip_label in [
216
+ (red_mask, 'Wire Overload (Red Strip)'),
217
+ (yellow_mask, 'Wire Overload (Yellow Strip)'),
218
+ (orange_mask, 'Wire Overload (Orange Strip)')
219
+ ]:
220
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
221
+ for cnt in contours:
222
+ area = cv2.contourArea(cnt)
223
+ if area > min_strip_area:
224
+ x, y, w, h = cv2.boundingRect(cnt)
225
+ aspect_ratio = max(w, h) / (min(w, h) + 1e-6)
226
+ if aspect_ratio > aspect_ratio_thresh:
227
+ wire_boxes.append((x, y, w, h))
228
+ wire_labels.append(strip_label)
229
+ # Add wire overloads to box_list/label_list
230
+ box_list = wire_boxes[:]
231
+ label_list = wire_labels[:]
232
+ # For point overloads, do not require wire shape
233
+ min_area_faulty = 120
234
+ min_area_potential = 1000
235
+ max_area = 0.05 * total
236
+ for mask, spot_label, min_a in [
237
+ (red_mask, 'Point Overload (Faulty)', min_area_faulty),
238
+ (yellow_mask, 'Point Overload (Potential)', min_area_potential)
239
+ ]:
240
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
241
+ for cnt in contours:
242
+ area = cv2.contourArea(cnt)
243
+ if min_a < area < max_area:
244
+ x, y, w, h = cv2.boundingRect(cnt)
245
+ box_list.append((x, y, w, h))
246
+ label_list.append(spot_label)
247
+ # Remove overlapping boxes using IOU
248
+ box_list, label_list = non_max_suppression_iou(box_list, label_list, iou_thresh=0.4)
249
+ box_list, label_list = filter_faulty_inside_potential(box_list, label_list)
250
+ box_list, label_list = filter_faulty_overlapping_potential(box_list, label_list)
251
+ box_list, label_list = merge_close_boxes(box_list, label_list, dist_thresh=100)
252
+ return label, box_list, label_list, img
253
+
254
# Batch process all images in the directory
for fname in os.listdir(dir_path):
    if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue

    label, box_list, label_list, img = classify_image(os.path.join(dir_path, fname))

    # Prefer drawing on the original (unfiltered) image when it is available.
    orig_dir = 'api_inference_pred_masks'  # or the directory with original images
    candidate = os.path.join(orig_dir, fname)
    draw_img = cv2.imread(candidate) if os.path.exists(candidate) else None
    if draw_img is None:
        # Fall back to the filtered image used for classification.
        draw_img = img.copy()

    # Overlay every detected box together with its label text.
    for (x, y, w, h), box_label in zip(box_list, label_list):
        cv2.rectangle(draw_img, (x, y), (x + w, y + h), (0, 0, 255), 2)
        cv2.putText(draw_img, box_label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

    # No boxes at all: stamp the whole-image verdict in the corner instead.
    if not box_list:
        cv2.putText(draw_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

    cv2.imwrite(os.path.join(output_dir, fname), draw_img)
    print(f"{fname}: {label} (saved with boxes on original image)")
scripts/classify_single_image_opencv.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+
5
+ # Directory containing filtered images
6
+ dir_path = 'api_inference_filtered_pipeline'
7
+ output_dir = 'api_inference_labeled_boxes_pipeline'
8
+ os.makedirs(output_dir, exist_ok=True)
9
+
10
# IOU function for non-max suppression
def iou(boxA, boxB):
    """Return the intersection-over-union of two (x, y, w, h) boxes."""
    ax, ay, aw, ah = boxA
    bx, by, bw, bh = boxB
    # Intersection rectangle, clamped to zero when the boxes are disjoint.
    ix1 = max(ax, bx)
    iy1 = max(ay, by)
    ix2 = min(ax + aw, bx + bw)
    iy2 = min(ay + ah, by + bh)
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = aw * ah + bw * bh - inter
    # Tiny epsilon keeps the division safe for degenerate (zero-area) boxes.
    return inter / float(union + 1e-6)
23
+
24
+ # Merge close bounding boxes (same label, centers within dist_thresh)
25
+ def merge_close_boxes(boxes, labels, dist_thresh=20):
26
+ merged = []
27
+ merged_labels = []
28
+ used = [False]*len(boxes)
29
+ for i in range(len(boxes)):
30
+ if used[i]:
31
+ continue
32
+ x1, y1, w1, h1 = boxes[i]
33
+ label1 = labels[i]
34
+ x2, y2, w2, h2 = x1, y1, w1, h1
35
+ for j in range(i+1, len(boxes)):
36
+ if used[j]:
37
+ continue
38
+ bx, by, bw, bh = boxes[j]
39
+ # If boxes are close (distance between centers < dist_thresh)
40
+ cx1, cy1 = x1 + w1//2, y1 + h1//2
41
+ cx2, cy2 = bx + bw//2, by + bh//2
42
+ if abs(cx1-cx2) < dist_thresh and abs(cy1-cy2) < dist_thresh and label1 == labels[j]:
43
+ # Merge boxes
44
+ x2 = min(x2, bx)
45
+ y2 = min(y2, by)
46
+ w2 = max(x1+w1, bx+bw) - x2
47
+ h2 = max(y1+h1, by+bh) - y2
48
+ used[j] = True
49
+ merged.append((x2, y2, w2, h2))
50
+ merged_labels.append(label1)
51
+ used[i] = True
52
+ return merged, merged_labels
53
+
54
# Non-max suppression using IOU
def non_max_suppression_iou(boxes, labels, iou_thresh=0.4):
    """Keep the largest boxes, dropping any box that overlaps a kept one.

    Boxes are visited in order of decreasing area; a box is suppressed when
    its IOU with an already-kept box exceeds ``iou_thresh``.
    """
    if not boxes:
        return [], []
    # Indices sorted by box area, largest first.
    order = np.argsort([w * h for (x, y, w, h) in boxes])[::-1]
    keep = []
    keep_labels = []
    while len(order) > 0:
        best = order[0]
        keep.append(boxes[best])
        keep_labels.append(labels[best])
        # Drop the winner plus everything that overlaps it too much.
        drop = [0] + [k for k in range(1, len(order))
                      if iou(boxes[best], boxes[order[k]]) > iou_thresh]
        order = np.delete(order, drop)
    return keep, keep_labels
71
+
72
# Filter out potential boxes that contain a faulty box inside
def filter_faulty_inside_potential(boxes, labels):
    """Drop every 'Point Overload (Potential)' box that fully contains a
    'Point Overload (Faulty)' box — the faulty detection takes precedence."""
    def contains_faulty(outer):
        # True if any faulty box lies entirely within ``outer``.
        ox, oy, ow, oh = outer
        for inner, inner_label in zip(boxes, labels):
            if inner_label != 'Point Overload (Faulty)':
                continue
            fx, fy, fw, fh = inner
            if fx >= ox and fy >= oy and fx + fw <= ox + ow and fy + fh <= oy + oh:
                return True
        return False

    kept_boxes = []
    kept_labels = []
    for box, label in zip(boxes, labels):
        if label == 'Point Overload (Potential)' and contains_faulty(box):
            continue
        kept_boxes.append(box)
        kept_labels.append(label)
    return kept_boxes, kept_labels
95
+
96
# Remove potential boxes that overlap with a faulty box (not just inside)
def filter_faulty_overlapping_potential(boxes, labels):
    """Drop every 'Point Overload (Potential)' box that intersects any
    'Point Overload (Faulty)' box, even partially."""
    def overlaps(a, b):
        # Positive-width and positive-height intersection rectangle required.
        left = max(a[0], b[0])
        top = max(a[1], b[1])
        right = min(a[0] + a[2], b[0] + b[2])
        bottom = min(a[1] + a[3], b[1] + b[3])
        return right > left and bottom > top

    faulty = [b for b, lab in zip(boxes, labels) if lab == 'Point Overload (Faulty)']
    kept_boxes = []
    kept_labels = []
    for box, label in zip(boxes, labels):
        if label == 'Point Overload (Potential)' and any(overlaps(box, f) for f in faulty):
            continue
        kept_boxes.append(box)
        kept_labels.append(label)
    return kept_boxes, kept_labels
+ return filtered_boxes, filtered_labels
122
+
123
+ # Heuristic classification function
124
+ def classify_image(img_path):
125
+ img = cv2.imread(img_path)
126
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
127
+
128
+ # Color masks
129
+ blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
130
+ black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
131
+ yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255)) # increased threshold
132
+ orange_mask = cv2.inRange(hsv, (10, 100, 100), (25, 255, 255))
133
+ red_mask1 = cv2.inRange(hsv, (0, 100, 100), (10, 255, 255))
134
+ red_mask2 = cv2.inRange(hsv, (160, 100, 100), (180, 255, 255))
135
+ red_mask = cv2.bitwise_or(red_mask1, red_mask2)
136
+
137
+ total = img.shape[0] * img.shape[1]
138
+ blue_count = np.sum(blue_mask > 0)
139
+ black_count = np.sum(black_mask > 0)
140
+ yellow_count = np.sum(yellow_mask > 0)
141
+ orange_count = np.sum(orange_mask > 0)
142
+ red_count = np.sum(red_mask > 0)
143
+
144
+ label = 'Unknown'
145
+ box_list = []
146
+ label_list = []
147
+
148
+ # Full image checks
149
+ if (blue_count + black_count) / total > 0.8:
150
+ label = 'Normal'
151
+ elif (red_count + orange_count) / total > 0.5:
152
+ label = 'Full Wire Overload'
153
+ elif (yellow_count) / total > 0.5:
154
+ label = 'Full Wire Overload'
155
+ # Check for full wire overload (entire image reddish or yellowish)
156
+ full_wire_thresh = 0.7 # 70% of image is reddish or yellowish
157
+ if (red_count + orange_count + yellow_count) / total > full_wire_thresh:
158
+ label = 'Full Wire Overload'
159
+ # Add a box covering the whole image
160
+ box_list.append((0, 0, img.shape[1], img.shape[0]))
161
+ label_list.append(label)
162
+ else:
163
+ # Small spot checks (improved: filter tiny spots, merge overlapping boxes)
164
+ min_area_faulty = 120 # increased min area for red/orange (faulty)
165
+ min_area_potential = 1000 # much higher min area for yellow (potential)
166
+ max_area = 0.05 * total
167
+ # Faulty (red/orange) spots
168
+ for mask, spot_label, min_a in [
169
+ (red_mask, 'Point Overload (Faulty)', min_area_faulty),
170
+ (yellow_mask, 'Point Overload (Potential)', min_area_potential)
171
+ ]:
172
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
173
+ for cnt in contours:
174
+ area = cv2.contourArea(cnt)
175
+ if min_a < area < max_area:
176
+ x, y, w, h = cv2.boundingRect(cnt)
177
+ box_list.append((x, y, w, h))
178
+ label_list.append(spot_label)
179
+ # Middle area checks
180
+ h, w = img.shape[:2]
181
+ center = img[h//4:3*h//4, w//4:3*w//4]
182
+ center_hsv = cv2.cvtColor(center, cv2.COLOR_BGR2HSV)
183
+ center_yellow = cv2.inRange(center_hsv, (20, 130, 130), (35, 255, 255))
184
+ center_orange = cv2.inRange(center_hsv, (10, 100, 100), (25, 255, 255))
185
+ center_red1 = cv2.inRange(center_hsv, (0, 100, 100), (10, 255, 255))
186
+ center_red2 = cv2.inRange(center_hsv, (160, 100, 100), (180, 255, 255))
187
+ center_red = cv2.bitwise_or(center_red1, center_red2)
188
+ if np.sum(center_red > 0) + np.sum(center_orange > 0) > 0.1 * center.size:
189
+ label = 'Loose Joint (Faulty)'
190
+ box_list.append((w//4, h//4, w//2, h//2))
191
+ label_list.append(label)
192
+ elif np.sum(center_yellow > 0) > 0.1 * center.size:
193
+ label = 'Loose Joint (Potential)'
194
+ box_list.append((w//4, h//4, w//2, h//2))
195
+ label_list.append(label)
196
+ # Always check for tiny spots, even if image is labeled as Normal
197
+ min_area_tiny = 10
198
+ max_area_tiny = 30
199
+ for mask, spot_label in [
200
+ (red_mask, 'Tiny Faulty Spot'),
201
+ (yellow_mask, 'Tiny Potential Spot')
202
+ ]:
203
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
204
+ for cnt in contours:
205
+ area = cv2.contourArea(cnt)
206
+ if min_area_tiny < area < max_area_tiny:
207
+ x, y, w, h = cv2.boundingRect(cnt)
208
+ box_list.append((x, y, w, h))
209
+ label_list.append(spot_label)
210
+ # Detect wire-shaped (long, thin) regions for wire overloads only
211
+ aspect_ratio_thresh = 5
212
+ min_strip_area = 0.01 * total
213
+ wire_boxes = []
214
+ wire_labels = []
215
+ for mask, strip_label in [
216
+ (red_mask, 'Wire Overload (Red Strip)'),
217
+ (yellow_mask, 'Wire Overload (Yellow Strip)'),
218
+ (orange_mask, 'Wire Overload (Orange Strip)')
219
+ ]:
220
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
221
+ for cnt in contours:
222
+ area = cv2.contourArea(cnt)
223
+ if area > min_strip_area:
224
+ x, y, w, h = cv2.boundingRect(cnt)
225
+ aspect_ratio = max(w, h) / (min(w, h) + 1e-6)
226
+ if aspect_ratio > aspect_ratio_thresh:
227
+ wire_boxes.append((x, y, w, h))
228
+ wire_labels.append(strip_label)
229
+ # Add wire overloads to box_list/label_list
230
+ box_list = wire_boxes[:]
231
+ label_list = wire_labels[:]
232
+ # For point overloads, do not require wire shape
233
+ min_area_faulty = 120
234
+ min_area_potential = 1000
235
+ max_area = 0.05 * total
236
+ for mask, spot_label, min_a in [
237
+ (red_mask, 'Point Overload (Faulty)', min_area_faulty),
238
+ (yellow_mask, 'Point Overload (Potential)', min_area_potential)
239
+ ]:
240
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
241
+ for cnt in contours:
242
+ area = cv2.contourArea(cnt)
243
+ if min_a < area < max_area:
244
+ x, y, w, h = cv2.boundingRect(cnt)
245
+ box_list.append((x, y, w, h))
246
+ label_list.append(spot_label)
247
+ # Remove overlapping boxes using IOU
248
+ box_list, label_list = non_max_suppression_iou(box_list, label_list, iou_thresh=0.4)
249
+ box_list, label_list = filter_faulty_inside_potential(box_list, label_list)
250
+ box_list, label_list = filter_faulty_overlapping_potential(box_list, label_list)
251
+ box_list, label_list = merge_close_boxes(box_list, label_list, dist_thresh=100)
252
+ return label, box_list, label_list, img
253
+
254
+ # Batch process all images in the directory
255
+ for fname in os.listdir(dir_path):
256
+ if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
257
+ continue
258
+ label, box_list, label_list, img = classify_image(os.path.join(dir_path, fname))
259
+ # Load the original (unfiltered) image for drawing boxes
260
+ orig_dir = 'api_inference_pred_masks' # or the directory with original images
261
+ orig_img_path = os.path.join(orig_dir, fname)
262
+ if os.path.exists(orig_img_path):
263
+ draw_img = cv2.imread(orig_img_path)
264
+ if draw_img is None:
265
+ draw_img = img.copy()
266
+ else:
267
+ draw_img = img.copy()
268
+ # Draw bounding boxes and labels on the original image
269
+ for (x, y, w, h), l in zip(box_list, label_list):
270
+ cv2.rectangle(draw_img, (x, y), (x+w, y+h), (0, 0, 255), 2)
271
+ cv2.putText(draw_img, l, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
272
+ if not box_list:
273
+ cv2.putText(draw_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
274
+ out_path = os.path.join(output_dir, fname)
275
+ cv2.imwrite(out_path, draw_img)
276
+ print(f"{fname}: {label} (saved with boxes on original image)")
scripts/crop_train_images.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from PIL import Image

# In-place crop of all training images: trim CROP_PERCENT from both the
# left and right edges.
TRAIN_NORMAL = 'dataset/train/normal'
TRAIN_FAULTY = 'dataset/train/faulty'
CROP_PERCENT = 0.125  # fraction trimmed from each side (12.5%)

for folder in [TRAIN_NORMAL, TRAIN_FAULTY]:
    for fname in os.listdir(folder):
        if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        fpath = os.path.join(folder, fname)
        # BUG FIX: open inside a context manager so the source file handle is
        # released before the file is overwritten in place (overwriting a
        # still-open file fails on Windows and leaks handles elsewhere).
        with Image.open(fpath) as img:
            w, h = img.size
            crop_w = int(w * CROP_PERCENT)
            # Crop 12.5% from left and right
            cropped = img.crop((crop_w, 0, w - crop_w, h))
            # crop() is lazy — force the pixel data to load before closing
            # the source file.
            cropped.load()
        cropped.save(fpath)
        print(f"[CROP] {fpath} -> size {cropped.size}")

print("[DONE] All training images cropped 12.5% from both sides.")
scripts/fill_test_from_TX.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import shutil
import random

# Destination test folders and the source tree to draw from.
TEST_NORMAL = 'dataset/test/normal'
TEST_FAULTY = 'dataset/test/faulty'
TX_DIR = 'TX'
TARGET_COUNT = 100

os.makedirs(TEST_NORMAL, exist_ok=True)
os.makedirs(TEST_FAULTY, exist_ok=True)

# Helper to collect images from TX folders
def collect_images(src_pattern, dst_folder, needed):
    """Copy up to ``needed`` images from TX sub-folders whose path contains
    ``src_pattern`` into ``dst_folder``, skipping names already present."""
    copied = 0
    for root, _dirs, files in os.walk(TX_DIR):
        if src_pattern not in root:
            continue
        candidates = [f for f in files if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        # Shuffle so the sample is random rather than directory-ordered.
        random.shuffle(candidates)
        for name in candidates:
            if copied >= needed:
                return
            src_path = os.path.join(root, name)
            dst_path = os.path.join(dst_folder, name)
            if os.path.exists(dst_path):
                continue
            shutil.copyfile(src_path, dst_path)
            copied += 1
            print(f"[COPY] {src_path} -> {dst_path}")

# Fill test/normal and test/faulty to TARGET_COUNT from TX (do not touch train)
collect_images('normal', TEST_NORMAL, TARGET_COUNT)
collect_images('faulty', TEST_FAULTY, TARGET_COUNT)

print(f"[DONE] {TEST_NORMAL} images: {len(os.listdir(TEST_NORMAL))}")
print(f"[DONE] {TEST_FAULTY} images: {len(os.listdir(TEST_FAULTY))}")
scripts/patchcore_api_inference.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from anomalib.models import Patchcore
from anomalib.data import Folder
from pytorch_lightning import Trainer

# --- Load config ---
CONFIG_PATH = "configs/patchcore_transformers.yaml"
CKPT_PATH = "results/Patchcore/transformers/v7/weights/lightning/model.ckpt"

OUT_MASK_DIR = "api_inference_pred_pipeline"
OUT_FILTERED_DIR = "api_inference_filtered_pipeline"

os.makedirs(OUT_MASK_DIR, exist_ok=True)
os.makedirs(OUT_FILTERED_DIR, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load config
config = OmegaConf.load(CONFIG_PATH)

if __name__ == "__main__":
    # Datamodule for prediction: mirror the YAML config's Folder arguments.
    data_module = Folder(
        name=config.data.init_args.name,
        root=config.data.init_args.root,
        normal_dir=config.data.init_args.normal_dir,
        abnormal_dir=config.data.init_args.abnormal_dir,
        normal_test_dir=config.data.init_args.normal_test_dir,
        train_batch_size=config.data.init_args.train_batch_size,
        eval_batch_size=config.data.init_args.eval_batch_size,
        num_workers=config.data.init_args.num_workers,
    )
    data_module.setup()

    # Restore the trained PatchCore model and move it to the chosen device.
    model = Patchcore.load_from_checkpoint(CKPT_PATH, **config.model.init_args)
    model.eval()
    model = model.to(device)

    # Inference loop over the test split.
    for batch in data_module.test_dataloader():
        image = batch.image.to(device)
        fname = batch.image_path[0]
        with torch.no_grad():
            output = model(image)

        # The anomaly map is exposed either as an attribute or as the second
        # element of a tuple, depending on the anomalib version.
        if hasattr(output, 'anomaly_map'):
            anomaly_map = output.anomaly_map.squeeze().cpu().numpy()
        elif isinstance(output, (tuple, list)) and len(output) > 1:
            anomaly_map = output[1].squeeze().cpu().numpy()
        else:
            anomaly_map = None

        if anomaly_map is None:
            print(f"No mask generated for {fname}")
            continue

        # Normalize to 0-255 for visualization
        norm_map = (255 * (anomaly_map - anomaly_map.min()) / (np.ptp(anomaly_map) + 1e-8)).astype(np.uint8)
        # Ensure norm_map is 2D for PIL
        if norm_map.ndim > 2:
            norm_map = np.squeeze(norm_map)
        if norm_map.ndim > 2:
            norm_map = norm_map[0]
        mask_img = Image.fromarray(norm_map)
        stem = os.path.splitext(os.path.basename(fname))[0]
        mask_img.save(os.path.join(OUT_MASK_DIR, stem + "_mask.png"))
        print(f"Saved mask for {fname}")

        # Save the masked (filtered) part of the original transformer image.
        orig_img = Image.open(fname).convert("RGB")
        if mask_img.size != orig_img.size:
            mask_resized = mask_img.resize(orig_img.size, resample=Image.BILINEAR)
        else:
            mask_resized = mask_img
        # Binarize mask (threshold at 128) and keep only the flagged pixels.
        bin_mask = np.array(mask_resized) > 128
        orig_np = np.array(orig_img)
        filtered_np = np.zeros_like(orig_np)
        filtered_np[bin_mask] = orig_np[bin_mask]
        Image.fromarray(filtered_np).save(os.path.join(OUT_FILTERED_DIR, stem + "_filtered.png"))
        print(f"Saved filtered image for {fname}")

    print(f"All masks saved to {OUT_MASK_DIR}")
scripts/patchcore_inference.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# PatchCore import for your config (update if needed for your Anomalib version)
from anomalib.models.image.patchcore import Patchcore

# Config from your YAML
CKPT_PATH = "results/Patchcore/transformers/v2/weights/lightning/model.ckpt"
IMAGE_SIZE = (256, 256)  # Update if you used a different size in training

# Inference directory (change as needed)
INFER_DIR = "./dataset/test/faulty"

# Output directory for masks
OUT_MASK_DIR = "inference_masks"
os.makedirs(OUT_MASK_DIR, exist_ok=True)

# Restore the trained model on the best available device.
model = Patchcore.load_from_checkpoint(CKPT_PATH)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Preprocessing — must match the training configuration.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

VALID_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

# Inference loop
for fname in sorted(os.listdir(INFER_DIR)):
    if not fname.lower().endswith(VALID_EXTS):
        continue
    image = Image.open(os.path.join(INFER_DIR, fname)).convert("RGB")
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(tensor)

    # Anomaly map location differs across anomalib versions: attribute or
    # second tuple element.
    if hasattr(output, 'anomaly_map'):
        anomaly_map = output.anomaly_map.squeeze().cpu().numpy()
    elif isinstance(output, (tuple, list)) and len(output) > 1:
        anomaly_map = output[1].squeeze().cpu().numpy()
    else:
        anomaly_map = None

    if anomaly_map is None:
        print(f"No mask generated for {fname}")
        continue

    # Normalize to 0-255 for visualization
    norm_map = (255 * (anomaly_map - anomaly_map.min()) / (np.ptp(anomaly_map) + 1e-8)).astype(np.uint8)
    Image.fromarray(norm_map).save(os.path.join(OUT_MASK_DIR, f"{os.path.splitext(fname)[0]}_mask.png"))
    print(f"Saved mask for {fname}")

print(f"All masks saved to {OUT_MASK_DIR}")
scripts/patchcore_single_image.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import os
import torch
import numpy as np
from PIL import Image
from patchcore_api_inference import Patchcore, config, device

# Output directories (should match those in patchcore_api_inference.py)
OUT_MASK_DIR = "api_inference_pred_masks_pipeline"
OUT_FILTERED_DIR = "api_inference_filtered_pipeline"
os.makedirs(OUT_MASK_DIR, exist_ok=True)
os.makedirs(OUT_FILTERED_DIR, exist_ok=True)

# Load model
model = Patchcore.load_from_checkpoint(
    "results/Patchcore/transformers/v7/weights/lightning/model.ckpt",
    **config.model.init_args
)
model.eval()
model = model.to(device)


def infer_single_image(image_path):
    """Run PatchCore on a single image, saving its anomaly mask and the
    mask-filtered copy of the original image."""
    # Normalize path for cross-platform compatibility
    fixed_path = os.path.abspath(os.path.normpath(image_path))
    orig_img = Image.open(fixed_path).convert("RGB")

    # Manual preprocessing: resize + scale to [0, 1].
    # NOTE(review): keep this in sync with the training transforms.
    resized = orig_img.resize((256, 256))  # Change if your model uses a different size
    tensor = torch.from_numpy(np.array(resized)).permute(2, 0, 1).float() / 255.0
    tensor = tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor)

    # Anomaly map location differs across anomalib versions.
    if hasattr(output, 'anomaly_map'):
        anomaly_map = output.anomaly_map.squeeze().cpu().numpy()
    elif isinstance(output, (tuple, list)) and len(output) > 1:
        anomaly_map = output[1].squeeze().cpu().numpy()
    else:
        anomaly_map = None

    if anomaly_map is None:
        print(f"No mask generated for {image_path}")
        return

    # Normalize to 0-255 and force a 2-D array for PIL.
    norm_map = (255 * (anomaly_map - anomaly_map.min()) / (np.ptp(anomaly_map) + 1e-8)).astype(np.uint8)
    if norm_map.ndim > 2:
        norm_map = np.squeeze(norm_map)
    if norm_map.ndim > 2:
        norm_map = norm_map[0]

    stem = os.path.splitext(os.path.basename(image_path))[0]
    mask_img = Image.fromarray(norm_map)
    mask_img.save(os.path.join(OUT_MASK_DIR, stem + "_mask.png"))
    print(f"Saved mask for {image_path}")

    # Resize mask to match original image size if needed
    if mask_img.size != orig_img.size:
        mask_resized = mask_img.resize(orig_img.size, resample=Image.BILINEAR)
    else:
        mask_resized = mask_img
    # Binarize (threshold 128) and keep only anomalous pixels.
    bin_mask = np.array(mask_resized) > 128
    orig_np = np.array(orig_img)
    filtered_np = np.zeros_like(orig_np)
    filtered_np[bin_mask] = orig_np[bin_mask]
    Image.fromarray(filtered_np).save(os.path.join(OUT_FILTERED_DIR, stem + "_filtered.png"))
    print(f"Saved filtered image for {image_path}")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python patchcore_single_image.py <image_path>")
        sys.exit(1)
    infer_single_image(sys.argv[1])
scripts/prepare_patchcore_dataset.sh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
set -euo pipefail

# Prepare the PatchCore dataset layout:
#   dataset/train/normal - normal images for training
#   dataset/test/normal  - held-out normal images (~20%)
#   dataset/test/faulty  - faulty images for evaluation

# 1. Gather all normal images from TX folders into train/normal
mkdir -p dataset/train/normal
mkdir -p dataset/test/normal

find TX/ -type f -name '*.jpg' -path '*/normal/*' -exec cp {} dataset/train/normal/ \;

# 2. Split ~20% of the normal images into test/normal.
# BUG FIX: from dataset/train/normal, "../test/normal" is
# dataset/train/test/normal — the files ended up in the wrong directory.
# Also use null-delimited pipelines so filenames with spaces survive
# (the previous `ls | shuf | head | xargs` broke on such names).
cd dataset/train/normal
count=$(find . -maxdepth 1 -type f | wc -l)
test_count=$((count / 5))
if [ "$test_count" -gt 0 ]; then
  find . -maxdepth 1 -type f -print0 | shuf -z | head -z -n "$test_count" \
    | xargs -0 -I{} mv {} ../../test/normal/
fi
cd ../../..

# 3. Gather all faulty images from TX folders into test/faulty (for evaluation)
mkdir -p dataset/test/faulty
find TX/ -type f -name '*.jpg' -path '*/faulty/*' -exec cp {} dataset/test/faulty/ \;

# 4. Remove duplicates between train/normal and test/normal
# (assumes filenames are unique; use a checksum-based method otherwise)
# BUG FIX: "[ -e ... ] && rm ..." aborts the script under `set -e` when the
# test fails; use an explicit if instead.
cd dataset/test/normal
for f in *; do
  if [ -e "../../train/normal/$f" ]; then
    rm -f "../../train/normal/$f"
  fi
done
cd ../../..

# 5. Print summary (find-based counts are safe for odd filenames)
train_n=$(find dataset/train/normal -maxdepth 1 -type f | wc -l)
test_n=$(find dataset/test/normal -maxdepth 1 -type f | wc -l)
test_f=$(find dataset/test/faulty -maxdepth 1 -type f | wc -l)
echo "[✓] Normal images in train: $train_n"
echo "[✓] Normal images in test: $test_n"
echo "[✓] Faulty images in test: $test_f"
scripts/recreate_train_from_TX.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import random
4
+
5
+ TX_DIR = 'TX'
6
+ TRAIN_NORMAL = 'dataset/train/normal'
7
+ TRAIN_FAULTY = 'dataset/train/faulty'
8
+ TARGET_COUNT = 100
9
+
10
+ # Clean train folders
11
+ shutil.rmtree(TRAIN_NORMAL, ignore_errors=True)
12
+ os.makedirs(TRAIN_NORMAL, exist_ok=True)
13
+ shutil.rmtree(TRAIN_FAULTY, ignore_errors=True)
14
+ os.makedirs(TRAIN_FAULTY, exist_ok=True)
15
+
16
+ # Helper to collect images from TX folders
17
+ def collect_images(src_pattern, dst_folder, needed):
18
+ collected = 0
19
+ for root, dirs, files in os.walk(TX_DIR):
20
+ if src_pattern in root:
21
+ imgs = [f for f in files if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
22
+ random.shuffle(imgs)
23
+ for f in imgs:
24
+ if collected >= needed:
25
+ return
26
+ src_path = os.path.join(root, f)
27
+ dst_path = os.path.join(dst_folder, f)
28
+ if not os.path.exists(dst_path):
29
+ shutil.copyfile(src_path, dst_path)
30
+ collected += 1
31
+ print(f"[COPY] {src_path} -> {dst_path}")
32
+
33
+ # Fill train/normal and train/faulty to TARGET_COUNT from TX
34
+ collect_images('normal', TRAIN_NORMAL, TARGET_COUNT)
35
+ collect_images('faulty', TRAIN_FAULTY, TARGET_COUNT)
36
+
37
+ print(f"[DONE] {TRAIN_NORMAL} images: {len(os.listdir(TRAIN_NORMAL))}")
38
+ print(f"[DONE] {TRAIN_FAULTY} images: {len(os.listdir(TRAIN_FAULTY))}")
scripts/run_anomalib.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
set -euo pipefail

# 0) Environment
python3 -V || true
echo "[*] Installing python3-full and python3-venv if needed…"
sudo apt update && sudo apt install -y python3-full python3.12-venv

echo "[*] Creating virtual environment if not exists…"
if [ ! -d ~/anomalib_env ]; then
  python3 -m venv ~/anomalib_env
fi

# Activate and install/update dependencies.
# CLEANUP: these lines were previously duplicated verbatim in both branches
# of the if above; factored out, behavior unchanged.
source ~/anomalib_env/bin/activate
pip install -U pip
pip install "anomalib[full]" flask requests cloudinary pillow numpy opencv-python omegaconf torch


# # 1) Train (PatchCore builds the memory bank from normals)
# anomalib train \
#   --config configs/patchcore_transformers.yaml
#
# CKPT=$(ls -1t results/transformers/patchcore/*/weights/*.ckpt | head -n 1)
# echo "[*] Using checkpoint: $CKPT"
#
# # 2) Test/Eval on test/{normal,faulty}
# anomalib test \
#   --config configs/patchcore_transformers.yaml \
#   --ckpt_path "$CKPT"

echo
echo "[✓] Done. Check:"
echo " • results/transformers/patchcore/**/images/ (heatmaps & overlays)"
echo " • results/transformers/patchcore/**/metrics.csv (AUROC/F1 etc.)"