arunimas1107 commited on
Commit
8ca8664
·
verified ·
1 Parent(s): 34db880

Upload 8 files

Browse files
Files changed (8) hide show
  1. README.md +121 -3
  2. casting_autoencoder.onnx +3 -0
  3. casting_autoencoder.pth +3 -0
  4. config.yaml +30 -0
  5. model.bin +3 -0
  6. model.xml +885 -0
  7. requirements.txt +3 -0
  8. train_model.py +97 -0
README.md CHANGED
@@ -1,3 +1,121 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Anomaly Detection Model – Edge AI for Casting Defect Inspection
2
+
3
+ ## Overview
4
+ The **Anomaly Detection Model** is an **autoencoder-based anomaly detection system** fine-tuned for industrial **casting defect inspection**. It identifies whether a metal casting image is *normal (OK)* or *defective* by reconstructing input images and analyzing reconstruction errors.
5
+
6
+ This model is designed for **Edge AI deployment**, optimized via **ONNX** and **OpenVINO IR** formats to run efficiently on low-power Intel edge devices.
7
+
8
+ ---
9
+
10
+ ## Model Details
11
+ - **Architecture:** Convolutional Autoencoder
12
+ - **Framework:** PyTorch
13
+ - **Training Objective:** Minimize reconstruction loss (MSE) for normal samples
14
+ - **Optimization:** ONNX and OpenVINO IR export for edge inference
15
+ - **Task:** Unsupervised anomaly detection
16
+ - **Domain:** Industrial visual inspection
17
+
18
+ ---
19
+
20
+ ## Repository Structure
21
+ ```
22
+ ├── config.yaml # Configuration file for training
23
+ ├── train_model.py # Training script
24
+ ├── casting_autoencoder.pth # Trained PyTorch model
25
+ ├── casting_autoencoder.onnx # ONNX export
26
+ ├── model.bin # OpenVINO IR model (bin)
27
+ ├── model.xml # OpenVINO IR model (xml)
28
+ ├── requirements.txt # Dependencies
29
+ └── README.md # Model card (this file)
30
+ ```
31
+
32
+ ---
33
+
34
+ ## Dataset
35
+ **Dataset:** Casting Product Image Dataset (Kaggle)
36
+ - **Classes:** Defective / Normal
37
+ - **Modality:** Grayscale industrial images
38
+ - **Training Strategy:** Only *normal* samples used for training the autoencoder.
39
+
40
+ ---
41
+
42
+ ## Training Configuration
43
+ | Parameter | Value |
44
+ |------------|--------|
45
+ | Batch Size | 32 |
46
+ | Epochs | 50 |
47
+ | Optimizer | Adam |
48
+ | Learning Rate | 1e-3 |
49
+ | Loss Function | MSELoss |
50
+
51
+ ---
52
+
53
+ ## Export & Deployment
54
+ | Format | Purpose |
55
+ |---------|----------|
56
+ | `.pth` | Original PyTorch model |
57
+ | `.onnx` | Framework-independent inference |
58
+ | `.xml` / `.bin` | OpenVINO IR format for edge devices |
59
+
60
+ **Edge Optimization:** Model converted and optimized using `openvino.convert_model()`.
61
+
62
+ ---
63
+
64
+ ## Inference Example
65
+ ```python
66
+ from openvino.runtime import Core
67
+ import cv2
68
+ import numpy as np
69
+
70
+ ie = Core()
71
+ model = ie.read_model(model="casting_ir/model.xml")
72
+ compiled_model = ie.compile_model(model=model, device_name="CPU")
73
+
74
+ # Load and preprocess image
75
+ img = cv2.imread('sample_casting.png', cv2.IMREAD_GRAYSCALE)
76
+ img = cv2.resize(img, (128, 128)) / 255.0
77
+ img = np.expand_dims(img, (0,1)).astype(np.float32)
78
+
79
+ # Run inference
80
+ infer_request = compiled_model.create_infer_request()
81
+ result = infer_request.infer(inputs={compiled_model.inputs[0]: img})
82
+
83
+ reconstructed = result[compiled_model.outputs[0]]
84
+ error = np.mean((img - reconstructed)**2)
85
+ if error > 0.01:
86
+ print("Defective Casting Detected")
87
+ else:
88
+ print("Casting OK")
89
+ ```
90
+
91
+ ---
92
+
93
+ ## Intended Use
94
+ - Automated visual inspection for manufacturing/QA systems.
95
+ - Real-time edge deployment in industrial environments.
96
+
97
+ **Not recommended for:**
98
+ - Non-industrial datasets.
99
+ - Scenarios with significant domain drift (e.g., lighting changes or non-casting objects).
100
+
101
+ ---
102
+
103
+ ## Limitations
104
+ - Accuracy depends on lighting and background consistency.
105
+ - Model trained primarily on grayscale casting images.
106
+ - Thresholds for anomaly detection must be tuned for specific deployment environments.
107
+
108
+ ---
109
+
110
+ ## License
111
+ This project is released under the [MIT License](LICENSE).
112
+
113
+ ---
114
+
115
+ ## Author
116
+ **Arunima Surendran**
117
+ Applied AI Engineer
118
+ [GitHub Repository](https://github.com/arunimakanavu/anomalydetectionmodel)
119
+
120
+ ---
121
+
casting_autoencoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ac75b27b96e87240a5a516f1a745cdf615e53ebd682c9cea9b8d620483ef6bc
3
+ size 191021
casting_autoencoder.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5aed04db97f4fb56e7e71f8cabdfdbaf8a7153a6f79f0e2f4d91aebd2082aa37
3
+ size 193559
config.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ckpt_path: null
2
+ seed_everything: 42
3
+
4
+ data:
5
+ class_path: anomalib.data.Folder
6
+ init_args:
7
+ root: ./casting_data/train
8
+ normal_dir: ok_front
9
+ abnormal_dir: def_front
10
+ task: classification
11
+ image_size: [256, 256]
12
+ train_batch_size: 32
13
+ eval_batch_size: 32
14
+ num_workers: 4
15
+
16
+ model:
17
+ class_path: anomalib.models.Patchcore
18
+ init_args:
19
+ backbone: resnet18
20
+ layers:
21
+ - layer2
22
+ - layer3
23
+
24
+ trainer:
25
+ accelerator: auto
26
+ devices: 1
27
+ max_epochs: 1
28
+
29
+ logging:
30
+ log_graph: false
model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd301efb30159cd017a7a1d04c43b79a4ffcd43050b17b231c7067a39b1de28c
3
+ size 94214
model.xml ADDED
@@ -0,0 +1,885 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0"?>
2
+ <net name="main_graph" version="11">
3
+ <layers>
4
+ <layer id="0" name="input" type="Parameter" version="opset1">
5
+ <data shape="1,3,304,304" element_type="f32" />
6
+ <output>
7
+ <port id="0" precision="FP32" names="input">
8
+ <dim>1</dim>
9
+ <dim>3</dim>
10
+ <dim>304</dim>
11
+ <dim>304</dim>
12
+ </port>
13
+ </output>
14
+ </layer>
15
+ <layer id="1" name="encoder.0.weight_compressed" type="Const" version="opset1">
16
+ <data element_type="f16" shape="16, 3, 3, 3" offset="0" size="864" />
17
+ <output>
18
+ <port id="0" precision="FP16">
19
+ <dim>16</dim>
20
+ <dim>3</dim>
21
+ <dim>3</dim>
22
+ <dim>3</dim>
23
+ </port>
24
+ </output>
25
+ </layer>
26
+ <layer id="2" name="encoder.0.weight" type="Convert" version="opset1">
27
+ <data destination_type="f32" />
28
+ <rt_info>
29
+ <attribute name="decompression" version="0" />
30
+ </rt_info>
31
+ <input>
32
+ <port id="0" precision="FP16">
33
+ <dim>16</dim>
34
+ <dim>3</dim>
35
+ <dim>3</dim>
36
+ <dim>3</dim>
37
+ </port>
38
+ </input>
39
+ <output>
40
+ <port id="1" precision="FP32" names="encoder.0.weight">
41
+ <dim>16</dim>
42
+ <dim>3</dim>
43
+ <dim>3</dim>
44
+ <dim>3</dim>
45
+ </port>
46
+ </output>
47
+ </layer>
48
+ <layer id="3" name="/encoder/encoder.0/Conv/WithoutBiases" type="Convolution" version="opset1">
49
+ <data strides="2, 2" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="explicit" />
50
+ <input>
51
+ <port id="0" precision="FP32">
52
+ <dim>1</dim>
53
+ <dim>3</dim>
54
+ <dim>304</dim>
55
+ <dim>304</dim>
56
+ </port>
57
+ <port id="1" precision="FP32">
58
+ <dim>16</dim>
59
+ <dim>3</dim>
60
+ <dim>3</dim>
61
+ <dim>3</dim>
62
+ </port>
63
+ </input>
64
+ <output>
65
+ <port id="2" precision="FP32">
66
+ <dim>1</dim>
67
+ <dim>16</dim>
68
+ <dim>152</dim>
69
+ <dim>152</dim>
70
+ </port>
71
+ </output>
72
+ </layer>
73
+ <layer id="4" name="Reshape_25_compressed" type="Const" version="opset1">
74
+ <data element_type="f16" shape="1, 16, 1, 1" offset="864" size="32" />
75
+ <output>
76
+ <port id="0" precision="FP16">
77
+ <dim>1</dim>
78
+ <dim>16</dim>
79
+ <dim>1</dim>
80
+ <dim>1</dim>
81
+ </port>
82
+ </output>
83
+ </layer>
84
+ <layer id="5" name="Reshape_25" type="Convert" version="opset1">
85
+ <data destination_type="f32" />
86
+ <rt_info>
87
+ <attribute name="decompression" version="0" />
88
+ </rt_info>
89
+ <input>
90
+ <port id="0" precision="FP16">
91
+ <dim>1</dim>
92
+ <dim>16</dim>
93
+ <dim>1</dim>
94
+ <dim>1</dim>
95
+ </port>
96
+ </input>
97
+ <output>
98
+ <port id="1" precision="FP32">
99
+ <dim>1</dim>
100
+ <dim>16</dim>
101
+ <dim>1</dim>
102
+ <dim>1</dim>
103
+ </port>
104
+ </output>
105
+ </layer>
106
+ <layer id="6" name="/encoder/encoder.0/Conv" type="Add" version="opset1">
107
+ <data auto_broadcast="numpy" />
108
+ <input>
109
+ <port id="0" precision="FP32">
110
+ <dim>1</dim>
111
+ <dim>16</dim>
112
+ <dim>152</dim>
113
+ <dim>152</dim>
114
+ </port>
115
+ <port id="1" precision="FP32">
116
+ <dim>1</dim>
117
+ <dim>16</dim>
118
+ <dim>1</dim>
119
+ <dim>1</dim>
120
+ </port>
121
+ </input>
122
+ <output>
123
+ <port id="2" precision="FP32" names="/encoder/encoder.0/Conv_output_0">
124
+ <dim>1</dim>
125
+ <dim>16</dim>
126
+ <dim>152</dim>
127
+ <dim>152</dim>
128
+ </port>
129
+ </output>
130
+ </layer>
131
+ <layer id="7" name="/encoder/encoder.1/Relu" type="ReLU" version="opset1">
132
+ <input>
133
+ <port id="0" precision="FP32">
134
+ <dim>1</dim>
135
+ <dim>16</dim>
136
+ <dim>152</dim>
137
+ <dim>152</dim>
138
+ </port>
139
+ </input>
140
+ <output>
141
+ <port id="1" precision="FP32" names="/encoder/encoder.1/Relu_output_0">
142
+ <dim>1</dim>
143
+ <dim>16</dim>
144
+ <dim>152</dim>
145
+ <dim>152</dim>
146
+ </port>
147
+ </output>
148
+ </layer>
149
+ <layer id="8" name="encoder.2.weight_compressed" type="Const" version="opset1">
150
+ <data element_type="f16" shape="32, 16, 3, 3" offset="896" size="9216" />
151
+ <output>
152
+ <port id="0" precision="FP16">
153
+ <dim>32</dim>
154
+ <dim>16</dim>
155
+ <dim>3</dim>
156
+ <dim>3</dim>
157
+ </port>
158
+ </output>
159
+ </layer>
160
+ <layer id="9" name="encoder.2.weight" type="Convert" version="opset1">
161
+ <data destination_type="f32" />
162
+ <rt_info>
163
+ <attribute name="decompression" version="0" />
164
+ </rt_info>
165
+ <input>
166
+ <port id="0" precision="FP16">
167
+ <dim>32</dim>
168
+ <dim>16</dim>
169
+ <dim>3</dim>
170
+ <dim>3</dim>
171
+ </port>
172
+ </input>
173
+ <output>
174
+ <port id="1" precision="FP32" names="encoder.2.weight">
175
+ <dim>32</dim>
176
+ <dim>16</dim>
177
+ <dim>3</dim>
178
+ <dim>3</dim>
179
+ </port>
180
+ </output>
181
+ </layer>
182
+ <layer id="10" name="/encoder/encoder.2/Conv/WithoutBiases" type="Convolution" version="opset1">
183
+ <data strides="2, 2" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="explicit" />
184
+ <input>
185
+ <port id="0" precision="FP32">
186
+ <dim>1</dim>
187
+ <dim>16</dim>
188
+ <dim>152</dim>
189
+ <dim>152</dim>
190
+ </port>
191
+ <port id="1" precision="FP32">
192
+ <dim>32</dim>
193
+ <dim>16</dim>
194
+ <dim>3</dim>
195
+ <dim>3</dim>
196
+ </port>
197
+ </input>
198
+ <output>
199
+ <port id="2" precision="FP32">
200
+ <dim>1</dim>
201
+ <dim>32</dim>
202
+ <dim>76</dim>
203
+ <dim>76</dim>
204
+ </port>
205
+ </output>
206
+ </layer>
207
+ <layer id="11" name="Reshape_39_compressed" type="Const" version="opset1">
208
+ <data element_type="f16" shape="1, 32, 1, 1" offset="10112" size="64" />
209
+ <output>
210
+ <port id="0" precision="FP16">
211
+ <dim>1</dim>
212
+ <dim>32</dim>
213
+ <dim>1</dim>
214
+ <dim>1</dim>
215
+ </port>
216
+ </output>
217
+ </layer>
218
+ <layer id="12" name="Reshape_39" type="Convert" version="opset1">
219
+ <data destination_type="f32" />
220
+ <rt_info>
221
+ <attribute name="decompression" version="0" />
222
+ </rt_info>
223
+ <input>
224
+ <port id="0" precision="FP16">
225
+ <dim>1</dim>
226
+ <dim>32</dim>
227
+ <dim>1</dim>
228
+ <dim>1</dim>
229
+ </port>
230
+ </input>
231
+ <output>
232
+ <port id="1" precision="FP32">
233
+ <dim>1</dim>
234
+ <dim>32</dim>
235
+ <dim>1</dim>
236
+ <dim>1</dim>
237
+ </port>
238
+ </output>
239
+ </layer>
240
+ <layer id="13" name="/encoder/encoder.2/Conv" type="Add" version="opset1">
241
+ <data auto_broadcast="numpy" />
242
+ <input>
243
+ <port id="0" precision="FP32">
244
+ <dim>1</dim>
245
+ <dim>32</dim>
246
+ <dim>76</dim>
247
+ <dim>76</dim>
248
+ </port>
249
+ <port id="1" precision="FP32">
250
+ <dim>1</dim>
251
+ <dim>32</dim>
252
+ <dim>1</dim>
253
+ <dim>1</dim>
254
+ </port>
255
+ </input>
256
+ <output>
257
+ <port id="2" precision="FP32" names="/encoder/encoder.2/Conv_output_0">
258
+ <dim>1</dim>
259
+ <dim>32</dim>
260
+ <dim>76</dim>
261
+ <dim>76</dim>
262
+ </port>
263
+ </output>
264
+ </layer>
265
+ <layer id="14" name="/encoder/encoder.3/Relu" type="ReLU" version="opset1">
266
+ <input>
267
+ <port id="0" precision="FP32">
268
+ <dim>1</dim>
269
+ <dim>32</dim>
270
+ <dim>76</dim>
271
+ <dim>76</dim>
272
+ </port>
273
+ </input>
274
+ <output>
275
+ <port id="1" precision="FP32" names="/encoder/encoder.3/Relu_output_0">
276
+ <dim>1</dim>
277
+ <dim>32</dim>
278
+ <dim>76</dim>
279
+ <dim>76</dim>
280
+ </port>
281
+ </output>
282
+ </layer>
283
+ <layer id="15" name="encoder.4.weight_compressed" type="Const" version="opset1">
284
+ <data element_type="f16" shape="64, 32, 3, 3" offset="10176" size="36864" />
285
+ <output>
286
+ <port id="0" precision="FP16">
287
+ <dim>64</dim>
288
+ <dim>32</dim>
289
+ <dim>3</dim>
290
+ <dim>3</dim>
291
+ </port>
292
+ </output>
293
+ </layer>
294
+ <layer id="16" name="encoder.4.weight" type="Convert" version="opset1">
295
+ <data destination_type="f32" />
296
+ <rt_info>
297
+ <attribute name="decompression" version="0" />
298
+ </rt_info>
299
+ <input>
300
+ <port id="0" precision="FP16">
301
+ <dim>64</dim>
302
+ <dim>32</dim>
303
+ <dim>3</dim>
304
+ <dim>3</dim>
305
+ </port>
306
+ </input>
307
+ <output>
308
+ <port id="1" precision="FP32" names="encoder.4.weight">
309
+ <dim>64</dim>
310
+ <dim>32</dim>
311
+ <dim>3</dim>
312
+ <dim>3</dim>
313
+ </port>
314
+ </output>
315
+ </layer>
316
+ <layer id="17" name="/encoder/encoder.4/Conv/WithoutBiases" type="Convolution" version="opset1">
317
+ <data strides="2, 2" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="explicit" />
318
+ <input>
319
+ <port id="0" precision="FP32">
320
+ <dim>1</dim>
321
+ <dim>32</dim>
322
+ <dim>76</dim>
323
+ <dim>76</dim>
324
+ </port>
325
+ <port id="1" precision="FP32">
326
+ <dim>64</dim>
327
+ <dim>32</dim>
328
+ <dim>3</dim>
329
+ <dim>3</dim>
330
+ </port>
331
+ </input>
332
+ <output>
333
+ <port id="2" precision="FP32">
334
+ <dim>1</dim>
335
+ <dim>64</dim>
336
+ <dim>38</dim>
337
+ <dim>38</dim>
338
+ </port>
339
+ </output>
340
+ </layer>
341
+ <layer id="18" name="Reshape_53_compressed" type="Const" version="opset1">
342
+ <data element_type="f16" shape="1, 64, 1, 1" offset="47040" size="128" />
343
+ <output>
344
+ <port id="0" precision="FP16">
345
+ <dim>1</dim>
346
+ <dim>64</dim>
347
+ <dim>1</dim>
348
+ <dim>1</dim>
349
+ </port>
350
+ </output>
351
+ </layer>
352
+ <layer id="19" name="Reshape_53" type="Convert" version="opset1">
353
+ <data destination_type="f32" />
354
+ <rt_info>
355
+ <attribute name="decompression" version="0" />
356
+ </rt_info>
357
+ <input>
358
+ <port id="0" precision="FP16">
359
+ <dim>1</dim>
360
+ <dim>64</dim>
361
+ <dim>1</dim>
362
+ <dim>1</dim>
363
+ </port>
364
+ </input>
365
+ <output>
366
+ <port id="1" precision="FP32">
367
+ <dim>1</dim>
368
+ <dim>64</dim>
369
+ <dim>1</dim>
370
+ <dim>1</dim>
371
+ </port>
372
+ </output>
373
+ </layer>
374
+ <layer id="20" name="/encoder/encoder.4/Conv" type="Add" version="opset1">
375
+ <data auto_broadcast="numpy" />
376
+ <input>
377
+ <port id="0" precision="FP32">
378
+ <dim>1</dim>
379
+ <dim>64</dim>
380
+ <dim>38</dim>
381
+ <dim>38</dim>
382
+ </port>
383
+ <port id="1" precision="FP32">
384
+ <dim>1</dim>
385
+ <dim>64</dim>
386
+ <dim>1</dim>
387
+ <dim>1</dim>
388
+ </port>
389
+ </input>
390
+ <output>
391
+ <port id="2" precision="FP32" names="/encoder/encoder.4/Conv_output_0">
392
+ <dim>1</dim>
393
+ <dim>64</dim>
394
+ <dim>38</dim>
395
+ <dim>38</dim>
396
+ </port>
397
+ </output>
398
+ </layer>
399
+ <layer id="21" name="/encoder/encoder.5/Relu" type="ReLU" version="opset1">
400
+ <input>
401
+ <port id="0" precision="FP32">
402
+ <dim>1</dim>
403
+ <dim>64</dim>
404
+ <dim>38</dim>
405
+ <dim>38</dim>
406
+ </port>
407
+ </input>
408
+ <output>
409
+ <port id="1" precision="FP32" names="/encoder/encoder.5/Relu_output_0">
410
+ <dim>1</dim>
411
+ <dim>64</dim>
412
+ <dim>38</dim>
413
+ <dim>38</dim>
414
+ </port>
415
+ </output>
416
+ </layer>
417
+ <layer id="22" name="decoder.0.weight_compressed" type="Const" version="opset1">
418
+ <data element_type="f16" shape="64, 32, 3, 3" offset="47168" size="36864" />
419
+ <output>
420
+ <port id="0" precision="FP16">
421
+ <dim>64</dim>
422
+ <dim>32</dim>
423
+ <dim>3</dim>
424
+ <dim>3</dim>
425
+ </port>
426
+ </output>
427
+ </layer>
428
+ <layer id="23" name="decoder.0.weight" type="Convert" version="opset1">
429
+ <data destination_type="f32" />
430
+ <rt_info>
431
+ <attribute name="decompression" version="0" />
432
+ </rt_info>
433
+ <input>
434
+ <port id="0" precision="FP16">
435
+ <dim>64</dim>
436
+ <dim>32</dim>
437
+ <dim>3</dim>
438
+ <dim>3</dim>
439
+ </port>
440
+ </input>
441
+ <output>
442
+ <port id="1" precision="FP32" names="decoder.0.weight">
443
+ <dim>64</dim>
444
+ <dim>32</dim>
445
+ <dim>3</dim>
446
+ <dim>3</dim>
447
+ </port>
448
+ </output>
449
+ </layer>
450
+ <layer id="24" name="ConvolutionBackpropData_56" type="ConvolutionBackpropData" version="opset1">
451
+ <data strides="2, 2" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="explicit" output_padding="1, 1" />
452
+ <input>
453
+ <port id="0" precision="FP32">
454
+ <dim>1</dim>
455
+ <dim>64</dim>
456
+ <dim>38</dim>
457
+ <dim>38</dim>
458
+ </port>
459
+ <port id="1" precision="FP32">
460
+ <dim>64</dim>
461
+ <dim>32</dim>
462
+ <dim>3</dim>
463
+ <dim>3</dim>
464
+ </port>
465
+ </input>
466
+ <output>
467
+ <port id="2" precision="FP32">
468
+ <dim>1</dim>
469
+ <dim>32</dim>
470
+ <dim>76</dim>
471
+ <dim>76</dim>
472
+ </port>
473
+ </output>
474
+ </layer>
475
+ <layer id="25" name="Reshape_58_compressed" type="Const" version="opset1">
476
+ <data element_type="f16" shape="1, 32, 1, 1" offset="84032" size="64" />
477
+ <output>
478
+ <port id="0" precision="FP16">
479
+ <dim>1</dim>
480
+ <dim>32</dim>
481
+ <dim>1</dim>
482
+ <dim>1</dim>
483
+ </port>
484
+ </output>
485
+ </layer>
486
+ <layer id="26" name="Reshape_58" type="Convert" version="opset1">
487
+ <data destination_type="f32" />
488
+ <rt_info>
489
+ <attribute name="decompression" version="0" />
490
+ </rt_info>
491
+ <input>
492
+ <port id="0" precision="FP16">
493
+ <dim>1</dim>
494
+ <dim>32</dim>
495
+ <dim>1</dim>
496
+ <dim>1</dim>
497
+ </port>
498
+ </input>
499
+ <output>
500
+ <port id="1" precision="FP32">
501
+ <dim>1</dim>
502
+ <dim>32</dim>
503
+ <dim>1</dim>
504
+ <dim>1</dim>
505
+ </port>
506
+ </output>
507
+ </layer>
508
+ <layer id="27" name="/decoder/decoder.0/ConvTranspose" type="Add" version="opset1">
509
+ <data auto_broadcast="numpy" />
510
+ <input>
511
+ <port id="0" precision="FP32">
512
+ <dim>1</dim>
513
+ <dim>32</dim>
514
+ <dim>76</dim>
515
+ <dim>76</dim>
516
+ </port>
517
+ <port id="1" precision="FP32">
518
+ <dim>1</dim>
519
+ <dim>32</dim>
520
+ <dim>1</dim>
521
+ <dim>1</dim>
522
+ </port>
523
+ </input>
524
+ <output>
525
+ <port id="2" precision="FP32" names="/decoder/decoder.0/ConvTranspose_output_0">
526
+ <dim>1</dim>
527
+ <dim>32</dim>
528
+ <dim>76</dim>
529
+ <dim>76</dim>
530
+ </port>
531
+ </output>
532
+ </layer>
533
+ <layer id="28" name="/decoder/decoder.1/Relu" type="ReLU" version="opset1">
534
+ <input>
535
+ <port id="0" precision="FP32">
536
+ <dim>1</dim>
537
+ <dim>32</dim>
538
+ <dim>76</dim>
539
+ <dim>76</dim>
540
+ </port>
541
+ </input>
542
+ <output>
543
+ <port id="1" precision="FP32" names="/decoder/decoder.1/Relu_output_0">
544
+ <dim>1</dim>
545
+ <dim>32</dim>
546
+ <dim>76</dim>
547
+ <dim>76</dim>
548
+ </port>
549
+ </output>
550
+ </layer>
551
+ <layer id="29" name="decoder.2.weight_compressed" type="Const" version="opset1">
552
+ <data element_type="f16" shape="32, 16, 3, 3" offset="84096" size="9216" />
553
+ <output>
554
+ <port id="0" precision="FP16">
555
+ <dim>32</dim>
556
+ <dim>16</dim>
557
+ <dim>3</dim>
558
+ <dim>3</dim>
559
+ </port>
560
+ </output>
561
+ </layer>
562
+ <layer id="30" name="decoder.2.weight" type="Convert" version="opset1">
563
+ <data destination_type="f32" />
564
+ <rt_info>
565
+ <attribute name="decompression" version="0" />
566
+ </rt_info>
567
+ <input>
568
+ <port id="0" precision="FP16">
569
+ <dim>32</dim>
570
+ <dim>16</dim>
571
+ <dim>3</dim>
572
+ <dim>3</dim>
573
+ </port>
574
+ </input>
575
+ <output>
576
+ <port id="1" precision="FP32" names="decoder.2.weight">
577
+ <dim>32</dim>
578
+ <dim>16</dim>
579
+ <dim>3</dim>
580
+ <dim>3</dim>
581
+ </port>
582
+ </output>
583
+ </layer>
584
+ <layer id="31" name="ConvolutionBackpropData_61" type="ConvolutionBackpropData" version="opset1">
585
+ <data strides="2, 2" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="explicit" output_padding="1, 1" />
586
+ <input>
587
+ <port id="0" precision="FP32">
588
+ <dim>1</dim>
589
+ <dim>32</dim>
590
+ <dim>76</dim>
591
+ <dim>76</dim>
592
+ </port>
593
+ <port id="1" precision="FP32">
594
+ <dim>32</dim>
595
+ <dim>16</dim>
596
+ <dim>3</dim>
597
+ <dim>3</dim>
598
+ </port>
599
+ </input>
600
+ <output>
601
+ <port id="2" precision="FP32">
602
+ <dim>1</dim>
603
+ <dim>16</dim>
604
+ <dim>152</dim>
605
+ <dim>152</dim>
606
+ </port>
607
+ </output>
608
+ </layer>
609
+ <layer id="32" name="Reshape_63_compressed" type="Const" version="opset1">
610
+ <data element_type="f16" shape="1, 16, 1, 1" offset="93312" size="32" />
611
+ <output>
612
+ <port id="0" precision="FP16">
613
+ <dim>1</dim>
614
+ <dim>16</dim>
615
+ <dim>1</dim>
616
+ <dim>1</dim>
617
+ </port>
618
+ </output>
619
+ </layer>
620
+ <layer id="33" name="Reshape_63" type="Convert" version="opset1">
621
+ <data destination_type="f32" />
622
+ <rt_info>
623
+ <attribute name="decompression" version="0" />
624
+ </rt_info>
625
+ <input>
626
+ <port id="0" precision="FP16">
627
+ <dim>1</dim>
628
+ <dim>16</dim>
629
+ <dim>1</dim>
630
+ <dim>1</dim>
631
+ </port>
632
+ </input>
633
+ <output>
634
+ <port id="1" precision="FP32">
635
+ <dim>1</dim>
636
+ <dim>16</dim>
637
+ <dim>1</dim>
638
+ <dim>1</dim>
639
+ </port>
640
+ </output>
641
+ </layer>
642
+ <layer id="34" name="/decoder/decoder.2/ConvTranspose" type="Add" version="opset1">
643
+ <data auto_broadcast="numpy" />
644
+ <input>
645
+ <port id="0" precision="FP32">
646
+ <dim>1</dim>
647
+ <dim>16</dim>
648
+ <dim>152</dim>
649
+ <dim>152</dim>
650
+ </port>
651
+ <port id="1" precision="FP32">
652
+ <dim>1</dim>
653
+ <dim>16</dim>
654
+ <dim>1</dim>
655
+ <dim>1</dim>
656
+ </port>
657
+ </input>
658
+ <output>
659
+ <port id="2" precision="FP32" names="/decoder/decoder.2/ConvTranspose_output_0">
660
+ <dim>1</dim>
661
+ <dim>16</dim>
662
+ <dim>152</dim>
663
+ <dim>152</dim>
664
+ </port>
665
+ </output>
666
+ </layer>
667
+ <layer id="35" name="/decoder/decoder.3/Relu" type="ReLU" version="opset1">
668
+ <input>
669
+ <port id="0" precision="FP32">
670
+ <dim>1</dim>
671
+ <dim>16</dim>
672
+ <dim>152</dim>
673
+ <dim>152</dim>
674
+ </port>
675
+ </input>
676
+ <output>
677
+ <port id="1" precision="FP32" names="/decoder/decoder.3/Relu_output_0">
678
+ <dim>1</dim>
679
+ <dim>16</dim>
680
+ <dim>152</dim>
681
+ <dim>152</dim>
682
+ </port>
683
+ </output>
684
+ </layer>
685
+ <layer id="36" name="decoder.4.weight_compressed" type="Const" version="opset1">
686
+ <data element_type="f16" shape="16, 3, 3, 3" offset="93344" size="864" />
687
+ <output>
688
+ <port id="0" precision="FP16">
689
+ <dim>16</dim>
690
+ <dim>3</dim>
691
+ <dim>3</dim>
692
+ <dim>3</dim>
693
+ </port>
694
+ </output>
695
+ </layer>
696
+ <layer id="37" name="decoder.4.weight" type="Convert" version="opset1">
697
+ <data destination_type="f32" />
698
+ <rt_info>
699
+ <attribute name="decompression" version="0" />
700
+ </rt_info>
701
+ <input>
702
+ <port id="0" precision="FP16">
703
+ <dim>16</dim>
704
+ <dim>3</dim>
705
+ <dim>3</dim>
706
+ <dim>3</dim>
707
+ </port>
708
+ </input>
709
+ <output>
710
+ <port id="1" precision="FP32" names="decoder.4.weight">
711
+ <dim>16</dim>
712
+ <dim>3</dim>
713
+ <dim>3</dim>
714
+ <dim>3</dim>
715
+ </port>
716
+ </output>
717
+ </layer>
718
+ <layer id="38" name="ConvolutionBackpropData_66" type="ConvolutionBackpropData" version="opset1">
719
+ <data strides="2, 2" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="explicit" output_padding="1, 1" />
720
+ <input>
721
+ <port id="0" precision="FP32">
722
+ <dim>1</dim>
723
+ <dim>16</dim>
724
+ <dim>152</dim>
725
+ <dim>152</dim>
726
+ </port>
727
+ <port id="1" precision="FP32">
728
+ <dim>16</dim>
729
+ <dim>3</dim>
730
+ <dim>3</dim>
731
+ <dim>3</dim>
732
+ </port>
733
+ </input>
734
+ <output>
735
+ <port id="2" precision="FP32">
736
+ <dim>1</dim>
737
+ <dim>3</dim>
738
+ <dim>304</dim>
739
+ <dim>304</dim>
740
+ </port>
741
+ </output>
742
+ </layer>
743
+ <layer id="39" name="Reshape_68_compressed" type="Const" version="opset1">
744
+ <data element_type="f16" shape="1, 3, 1, 1" offset="94208" size="6" />
745
+ <output>
746
+ <port id="0" precision="FP16">
747
+ <dim>1</dim>
748
+ <dim>3</dim>
749
+ <dim>1</dim>
750
+ <dim>1</dim>
751
+ </port>
752
+ </output>
753
+ </layer>
754
+ <layer id="40" name="Reshape_68" type="Convert" version="opset1">
755
+ <data destination_type="f32" />
756
+ <rt_info>
757
+ <attribute name="decompression" version="0" />
758
+ </rt_info>
759
+ <input>
760
+ <port id="0" precision="FP16">
761
+ <dim>1</dim>
762
+ <dim>3</dim>
763
+ <dim>1</dim>
764
+ <dim>1</dim>
765
+ </port>
766
+ </input>
767
+ <output>
768
+ <port id="1" precision="FP32">
769
+ <dim>1</dim>
770
+ <dim>3</dim>
771
+ <dim>1</dim>
772
+ <dim>1</dim>
773
+ </port>
774
+ </output>
775
+ </layer>
776
+ <layer id="41" name="/decoder/decoder.4/ConvTranspose" type="Add" version="opset1">
777
+ <data auto_broadcast="numpy" />
778
+ <input>
779
+ <port id="0" precision="FP32">
780
+ <dim>1</dim>
781
+ <dim>3</dim>
782
+ <dim>304</dim>
783
+ <dim>304</dim>
784
+ </port>
785
+ <port id="1" precision="FP32">
786
+ <dim>1</dim>
787
+ <dim>3</dim>
788
+ <dim>1</dim>
789
+ <dim>1</dim>
790
+ </port>
791
+ </input>
792
+ <output>
793
+ <port id="2" precision="FP32" names="/decoder/decoder.4/ConvTranspose_output_0">
794
+ <dim>1</dim>
795
+ <dim>3</dim>
796
+ <dim>304</dim>
797
+ <dim>304</dim>
798
+ </port>
799
+ </output>
800
+ </layer>
801
+ <layer id="42" name="output" type="Sigmoid" version="opset1">
802
+ <input>
803
+ <port id="0" precision="FP32">
804
+ <dim>1</dim>
805
+ <dim>3</dim>
806
+ <dim>304</dim>
807
+ <dim>304</dim>
808
+ </port>
809
+ </input>
810
+ <output>
811
+ <port id="1" precision="FP32" names="output">
812
+ <dim>1</dim>
813
+ <dim>3</dim>
814
+ <dim>304</dim>
815
+ <dim>304</dim>
816
+ </port>
817
+ </output>
818
+ </layer>
819
+ <layer id="43" name="output/sink_port_0" type="Result" version="opset1">
820
+ <input>
821
+ <port id="0" precision="FP32">
822
+ <dim>1</dim>
823
+ <dim>3</dim>
824
+ <dim>304</dim>
825
+ <dim>304</dim>
826
+ </port>
827
+ </input>
828
+ </layer>
829
+ </layers>
830
+ <edges>
831
+ <edge from-layer="0" from-port="0" to-layer="3" to-port="0" />
832
+ <edge from-layer="1" from-port="0" to-layer="2" to-port="0" />
833
+ <edge from-layer="2" from-port="1" to-layer="3" to-port="1" />
834
+ <edge from-layer="3" from-port="2" to-layer="6" to-port="0" />
835
+ <edge from-layer="4" from-port="0" to-layer="5" to-port="0" />
836
+ <edge from-layer="5" from-port="1" to-layer="6" to-port="1" />
837
+ <edge from-layer="6" from-port="2" to-layer="7" to-port="0" />
838
+ <edge from-layer="7" from-port="1" to-layer="10" to-port="0" />
839
+ <edge from-layer="8" from-port="0" to-layer="9" to-port="0" />
840
+ <edge from-layer="9" from-port="1" to-layer="10" to-port="1" />
841
+ <edge from-layer="10" from-port="2" to-layer="13" to-port="0" />
842
+ <edge from-layer="11" from-port="0" to-layer="12" to-port="0" />
843
+ <edge from-layer="12" from-port="1" to-layer="13" to-port="1" />
844
+ <edge from-layer="13" from-port="2" to-layer="14" to-port="0" />
845
+ <edge from-layer="14" from-port="1" to-layer="17" to-port="0" />
846
+ <edge from-layer="15" from-port="0" to-layer="16" to-port="0" />
847
+ <edge from-layer="16" from-port="1" to-layer="17" to-port="1" />
848
+ <edge from-layer="17" from-port="2" to-layer="20" to-port="0" />
849
+ <edge from-layer="18" from-port="0" to-layer="19" to-port="0" />
850
+ <edge from-layer="19" from-port="1" to-layer="20" to-port="1" />
851
+ <edge from-layer="20" from-port="2" to-layer="21" to-port="0" />
852
+ <edge from-layer="21" from-port="1" to-layer="24" to-port="0" />
853
+ <edge from-layer="22" from-port="0" to-layer="23" to-port="0" />
854
+ <edge from-layer="23" from-port="1" to-layer="24" to-port="1" />
855
+ <edge from-layer="24" from-port="2" to-layer="27" to-port="0" />
856
+ <edge from-layer="25" from-port="0" to-layer="26" to-port="0" />
857
+ <edge from-layer="26" from-port="1" to-layer="27" to-port="1" />
858
+ <edge from-layer="27" from-port="2" to-layer="28" to-port="0" />
859
+ <edge from-layer="28" from-port="1" to-layer="31" to-port="0" />
860
+ <edge from-layer="29" from-port="0" to-layer="30" to-port="0" />
861
+ <edge from-layer="30" from-port="1" to-layer="31" to-port="1" />
862
+ <edge from-layer="31" from-port="2" to-layer="34" to-port="0" />
863
+ <edge from-layer="32" from-port="0" to-layer="33" to-port="0" />
864
+ <edge from-layer="33" from-port="1" to-layer="34" to-port="1" />
865
+ <edge from-layer="34" from-port="2" to-layer="35" to-port="0" />
866
+ <edge from-layer="35" from-port="1" to-layer="38" to-port="0" />
867
+ <edge from-layer="36" from-port="0" to-layer="37" to-port="0" />
868
+ <edge from-layer="37" from-port="1" to-layer="38" to-port="1" />
869
+ <edge from-layer="38" from-port="2" to-layer="41" to-port="0" />
870
+ <edge from-layer="39" from-port="0" to-layer="40" to-port="0" />
871
+ <edge from-layer="40" from-port="1" to-layer="41" to-port="1" />
872
+ <edge from-layer="41" from-port="2" to-layer="42" to-port="0" />
873
+ <edge from-layer="42" from-port="1" to-layer="43" to-port="0" />
874
+ </edges>
875
+ <rt_info>
876
+ <MO_version value="2024.6.0-17404-4c0f47d2335-releases/2024/6" />
877
+ <Runtime_version value="2024.6.0-17404-4c0f47d2335-releases/2024/6" />
878
+ <conversion_parameters>
879
+ <input_model value="DIR/casting_autoencoder.onnx" />
880
+ <is_python_api_used value="False" />
881
+ <output_dir value="/home/arunima/intel/casting_data/./casting_ir" />
882
+ </conversion_parameters>
883
+ <legacy_frontend value="False" />
884
+ </rt_info>
885
+ </net>
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
train_model.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader
4
+ from torchvision import datasets, transforms
5
+ import numpy as np
6
+ import os
7
+
8
+ # =============== 1. CONFIG =================
9
+ IMG_SIZE = 304
10
+ BATCH_SIZE = 32
11
+ EPOCHS = 10
12
+ LR = 1e-3
13
+ MODEL_PATH = "casting_autoencoder.pth"
14
+ ONNX_PATH = "casting_autoencoder.onnx"
15
+
16
+ TRAIN_DIR = "casting_data/train" # only OK parts
17
+ TEST_DEFECT_DIR = "casting_data/test" # defects for thresholding
18
+
19
+ # =============== 2. DATA PIPELINE =================
20
+ transform = transforms.Compose([
21
+ transforms.Grayscale(num_output_channels=3),
22
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
23
+ transforms.ToTensor()
24
+ ])
25
+
26
+ train_data = datasets.ImageFolder(root=TRAIN_DIR, transform=transform)
27
+ train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
28
+
29
+ # =============== 3. MODEL =================
30
+ class Autoencoder(nn.Module):
31
+ def __init__(self):
32
+ super().__init__()
33
+ self.encoder = nn.Sequential(
34
+ nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
35
+ nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
36
+ nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
37
+ )
38
+ self.decoder = nn.Sequential(
39
+ nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), nn.ReLU(),
40
+ nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), nn.ReLU(),
41
+ nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1), nn.Sigmoid()
42
+ )
43
+
44
+ def forward(self, x):
45
+ x = self.encoder(x)
46
+ x = self.decoder(x)
47
+ return x
48
+
49
+ # =============== 4. TRAINING LOOP =================
50
+ device = "cuda" if torch.cuda.is_available() else "cpu"
51
+ model = Autoencoder().to(device)
52
+ criterion = nn.MSELoss()
53
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR)
54
+
55
+ print(" Training started...")
56
+ for epoch in range(EPOCHS):
57
+ total_loss = 0
58
+ for imgs, _ in train_loader:
59
+ imgs = imgs.to(device)
60
+ output = model(imgs)
61
+ loss = criterion(output, imgs)
62
+ optimizer.zero_grad()
63
+ loss.backward()
64
+ optimizer.step()
65
+ total_loss += loss.item()
66
+ print(f"Epoch [{epoch+1}/{EPOCHS}] - Loss: {total_loss/len(train_loader):.4f}")
67
+
68
+ torch.save(model.state_dict(), MODEL_PATH)
69
+ print(f" Model saved to {MODEL_PATH}")
70
+
71
+ # =============== 5. THRESHOLD CALIBRATION =================
72
+ defect_data = datasets.ImageFolder(root=TEST_DEFECT_DIR, transform=transform)
73
+ defect_loader = DataLoader(defect_data, batch_size=1)
74
+
75
+ model.eval()
76
+ errors = []
77
+ with torch.no_grad():
78
+ for img, _ in defect_loader:
79
+ img = img.to(device)
80
+ out = model(img)
81
+ err = criterion(out, img).item()
82
+ errors.append(err)
83
+
84
+ threshold = np.mean(errors) * 0.8
85
+ print(f"⚡ Suggested anomaly threshold: {threshold:.4f}")
86
+
87
+ # =============== 6. EXPORT TO ONNX =================
88
+ dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
89
+ torch.onnx.export(
90
+ model,
91
+ dummy,
92
+ ONNX_PATH,
93
+ input_names=["input"],
94
+ output_names=["output"],
95
+ opset_version=12
96
+ )
97
+ print(f" ONNX model exported to {ONNX_PATH}")