Upload folder using huggingface_hub

- .gitattributes +5 -35
- Dockerfile +12 -0
- README.md +0 -3
- requirements.txt +14 -0
- src/app.py +41 -0
- src/dashboard.py +56 -0
- src/gradcam.py +60 -0
- src/model.py +19 -0
- src/preprocess.py +35 -0
- src/train.py +74 -0
.gitattributes
CHANGED
@@ -1,35 +1,5 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,12 @@
+FROM python:3.10-slim
+
+WORKDIR /app
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+
+COPY src/ ./src/
+COPY models/ ./models/
+
+EXPOSE 8000 8501
+
+CMD ["sh", "-c", "uvicorn src.app:app --host 0.0.0.0 --port 8000 & streamlit run src/dashboard.py --server.port 8501 --server.address 0.0.0.0"]
README.md
CHANGED
@@ -1,3 +0,0 @@
----
-license: apache-2.0
----
requirements.txt
ADDED
@@ -0,0 +1,14 @@
+# requirements.txt
+torch==2.0.1
+torchvision==0.15.2
+opencv-python==4.8.1.78
+fastapi==0.104.1
+uvicorn==0.24.0
+streamlit==1.28.1
+onnxruntime==1.16.1
+numpy==1.24.3
+pillow==10.0.1
+transformers==4.35.0  # Optional, but useful for some models
+scikit-learn==1.3.0
+matplotlib==3.7.2
+requests==2.31.0
src/app.py
ADDED
@@ -0,0 +1,41 @@
+from fastapi import FastAPI, UploadFile, File
+from fastapi.responses import JSONResponse
+import onnxruntime as ort
+import numpy as np
+from src.preprocess import preprocess_image
+from src.gradcam import GradCAM  # note: GradCAM uses PyTorch, so full CAM is skipped in the ONNX path
+import cv2
+from PIL import Image
+import io
+
+app = FastAPI(title="AutoVision API")
+
+# Load ONNX model
+ort_session = ort.InferenceSession('../models/resnet18_anomaly.onnx')
+classes = ['crazing', 'inclusion', 'patches', 'pitted_surface', 'rolled-in_scale', 'scratches']  # keep in sync with the training class folders
+
+@app.post("/predict")
+async def predict(file: UploadFile = File(...)):
+    contents = await file.read()
+    image = Image.open(io.BytesIO(contents)).convert('RGB')  # force 3-channel RGB
+    image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)  # preprocess_image expects a BGR array (OpenCV convention)
+
+    input_data = preprocess_image(image_np)
+
+    ort_inputs = {ort_session.get_inputs()[0].name: input_data.astype(np.float32)}
+    ort_outs = ort_session.run(None, ort_inputs)
+    pred = np.argmax(ort_outs[0])
+    confidence = np.max(ort_outs[0])  # max raw logit, not a softmax probability
+
+    # For Grad-CAM, the PyTorch model would need to be loaded separately for explainability.
+    # Here we return only the prediction and confidence; the overlay is handled in the UI.
+
+    return JSONResponse({
+        "prediction": classes[pred],
+        "confidence": float(confidence),
+        "class_id": int(pred)
+    })
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
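A minimal client sketch for exercising the /predict endpoint above, assuming the API is running locally on port 8000; sample.jpg is a hypothetical test image path, not a file from the repo.

import requests

# Hypothetical test image; replace with a real file from the dataset
with open("sample.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/predict",
        files={"file": ("sample.jpg", f, "image/jpeg")},
    )
print(resp.json())  # e.g. {"prediction": "...", "confidence": ..., "class_id": ...}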
src/dashboard.py
ADDED
@@ -0,0 +1,56 @@
+import streamlit as st
+import cv2
+import numpy as np
+from src.preprocess import preprocess_image, camera_stream, overlay_heatmap
+from src.gradcam import GradCAM
+import torch
+from PIL import Image
+import requests
+
+st.title("AutoVision: Real-Time Defect Detection")
+
+# Load models
+@st.cache_resource
+def load_models():
+    gradcam = GradCAM('../models/resnet18_anomaly.pth')
+    return gradcam
+
+gradcam = load_models()
+classes = ['normal', 'crazing', 'inclusion', 'pits', 'pitted_surface', 'rolled-in_scale', 'scratches']
+
+# Real-time camera feed
+frame_placeholder = st.empty()
+prediction_placeholder = st.empty()
+
+# Optional: use the API for prediction (if the backend is running)
+use_api = st.checkbox("Use FastAPI Backend for Inference")
+
+for frame in camera_stream():
+    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+    # Preprocess for inference (preprocess_image expects a BGR frame straight from OpenCV)
+    input_data = preprocess_image(frame)
+    input_tensor = torch.from_numpy(input_data).float()
+
+    if use_api:
+        # Upload to the API (simplified; in practice, send the frame as a multipart file upload)
+        # For this demo, fall through to the local PyTorch model
+        pass
+
+    # Local inference with PyTorch (for Grad-CAM compatibility)
+    with torch.no_grad():
+        output = gradcam.model(input_tensor)
+        pred = output.argmax().item()
+        confidence = torch.softmax(output, dim=1).max().item()
+
+    # Generate Grad-CAM
+    heatmap = gradcam.generate(input_tensor, pred)
+
+    # Overlay
+    overlaid = overlay_heatmap(rgb_frame, heatmap)
+
+    frame_placeholder.image(overlaid, channels="RGB")
+
+    prediction_placeholder.markdown(f"**Prediction:** {classes[pred]} ({confidence:.2%})")
+
+st.info("Press Ctrl+C to stop camera.")
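The use_api branch above is left as a stub. A hedged sketch of what it could do, assuming the FastAPI backend from src/app.py is reachable at localhost:8000: JPEG-encode the BGR frame in memory and post it as a multipart upload.

import cv2
import requests

def predict_via_api(frame_bgr, url="http://localhost:8000/predict"):
    """Send one OpenCV (BGR) frame to the /predict endpoint and return its JSON response."""
    ok, buf = cv2.imencode(".jpg", frame_bgr)  # in-memory JPEG encoding
    if not ok:
        raise RuntimeError("Failed to encode frame")
    files = {"file": ("frame.jpg", buf.tobytes(), "image/jpeg")}
    resp = requests.post(url, files=files, timeout=5)
    resp.raise_for_status()
    return resp.json()  # {"prediction": ..., "confidence": ..., "class_id": ...}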
src/gradcam.py
ADDED
@@ -0,0 +1,60 @@
+import torch
+import torch.nn.functional as F
+import cv2
+import numpy as np
+from src.model import get_model
+
+class GradCAM:
+    def __init__(self, model_path, target_layer='layer4'):
+        self.model = get_model(pretrained=False)
+        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
+        self.model.eval()
+        self.target_layer = target_layer
+        self.gradients = None
+        self.activations = None
+        self.hooks = []
+        self._register_hooks()
+
+    def _register_hooks(self):
+        def backward_hook(module, grad_input, grad_output):
+            self.gradients = grad_output[0]
+
+        def forward_hook(module, input, output):
+            self.activations = output
+
+        for name, module in self.model.named_modules():
+            if name == self.target_layer:  # hook only the target layer itself, not its submodules
+                self.hooks.append(module.register_forward_hook(forward_hook))
+                self.hooks.append(module.register_full_backward_hook(backward_hook))
+
+    def generate(self, input_tensor, class_idx=None):
+        self.model.zero_grad()
+        output = self.model(input_tensor)
+        if class_idx is None:
+            class_idx = output.argmax().item()
+        score = output[0, class_idx]
+        score.backward()
+
+        gradients = self.gradients[0]
+        activations = self.activations[0]
+        weights = torch.mean(gradients, dim=(1, 2), keepdim=True)
+
+        cam = torch.sum(weights * activations, dim=1, keepdim=True)
+        cam = F.relu(cam)
+        cam = cam.squeeze().detach().cpu().numpy()
+        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
+
+        # Hooks stay registered here so that generate() can be called
+        # repeatedly (e.g. once per camera frame); they are removed when
+        # the object is destroyed, in __del__ below
+
+        return cam
+
+    def __del__(self):
+        for hook in self.hooks:
+            hook.remove()
+
+# Usage example:
+# gradcam = GradCAM('../models/resnet18_anomaly.pth')
+# input_tensor = torch.randn(1, 3, 224, 224)  # from preprocess
+# heatmap = gradcam.generate(input_tensor)
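A slightly fuller version of that usage example, tying GradCAM to the repo's own preprocessing helpers; it assumes a trained checkpoint at ../models/resnet18_anomaly.pth and uses test.jpg as a placeholder input path.

import cv2
import torch
from src.gradcam import GradCAM
from src.preprocess import preprocess_image, overlay_heatmap

gradcam = GradCAM('../models/resnet18_anomaly.pth')
frame = cv2.imread('test.jpg')  # BGR image, as preprocess_image expects
input_tensor = torch.from_numpy(preprocess_image(frame)).float()
heatmap = gradcam.generate(input_tensor)  # defaults to the top-1 class
overlaid = overlay_heatmap(frame, heatmap)
cv2.imwrite('test_gradcam.jpg', overlaid)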
src/model.py
ADDED
@@ -0,0 +1,19 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from torchvision import transforms
+
+NUM_CLASSES = 7  # 6 defects + normal
+
+def get_model(pretrained=True):
+    model = models.resnet18(pretrained=pretrained)
+    num_ftrs = model.fc.in_features
+    model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
+    return model
+
+def get_transforms():
+    return transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
src/preprocess.py
ADDED
@@ -0,0 +1,35 @@
+import cv2
+import numpy as np
+from PIL import Image
+from src.model import get_transforms
+
+def preprocess_image(image_path_or_np):
+    """Preprocess a single image (file path or BGR numpy array) for inference."""
+    if isinstance(image_path_or_np, str):
+        image = cv2.imread(image_path_or_np)
+    else:
+        image = image_path_or_np
+
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = Image.fromarray(image)
+    transform = get_transforms()
+    return transform(image).unsqueeze(0).numpy()  # to numpy for ONNX
+
+def camera_stream():
+    """Generator yielding BGR frames from the default camera."""
+    cap = cv2.VideoCapture(0)  # use default camera
+    while True:
+        ret, frame = cap.read()
+        if ret:
+            yield frame
+        else:
+            break
+    cap.release()
+
+def overlay_heatmap(frame, heatmap, alpha=0.4):
+    """Overlay a Grad-CAM heatmap on a frame."""
+    heatmap = cv2.resize(heatmap, (frame.shape[1], frame.shape[0]))
+    heatmap = np.uint8(255 * heatmap)
+    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+    superimposed = cv2.addWeighted(frame, 1 - alpha, heatmap, alpha, 0)
+    return superimposed
src/train.py
ADDED
@@ -0,0 +1,74 @@
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from src.model import get_model, get_transforms
+import numpy as np
+from sklearn.metrics import accuracy_score
+
+# Assume data/ has train/ and val/ folders with subfolders for classes: normal, crazing, inclusion, etc.
+# https://www.kaggle.com/datasets/kaustubhdikshit/neu-surface-defect-database
+DATA_DIR = '../data/neu_surface_defect_database'
+BATCH_SIZE = 32
+EPOCHS = 10
+LEARNING_RATE = 0.001
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def main():
+    transform = get_transforms()
+
+    train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, 'train'), transform=transform)
+    val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, 'val'), transform=transform)
+
+    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
+    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
+
+    model = get_model(pretrained=True).to(DEVICE)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+    for epoch in range(EPOCHS):
+        model.train()
+        running_loss = 0.0
+        for inputs, labels in train_loader:
+            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            running_loss += loss.item()
+
+        print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss / len(train_loader):.4f}')
+
+        # Validation
+        model.eval()
+        preds, trues = [], []
+        with torch.no_grad():
+            for inputs, labels in val_loader:
+                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
+                outputs = model(inputs)
+                _, predicted = torch.max(outputs, 1)
+                preds.extend(predicted.cpu().numpy())
+                trues.extend(labels.cpu().numpy())
+
+        acc = accuracy_score(trues, preds)
+        print(f'Validation Accuracy: {acc:.4f}')
+
+    # Save PyTorch model
+    torch.save(model.state_dict(), '../models/resnet18_anomaly.pth')
+
+    # Export to ONNX
+    model.eval()
+    dummy_input = torch.randn(1, 3, 224, 224).to(DEVICE)
+    torch.onnx.export(model, dummy_input, '../models/resnet18_anomaly.onnx',
+                      export_params=True, opset_version=11,
+                      do_constant_folding=True,
+                      input_names=['input'], output_names=['output'])
+
+    print('Model trained and exported to ONNX!')
+
+if __name__ == '__main__':
+    main()
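After training, a quick sanity check (a sketch, assuming both files were written to ../models/ by the script above) that the ONNX export agrees with the PyTorch weights:

import numpy as np
import onnxruntime as ort
import torch
from src.model import get_model

model = get_model(pretrained=False)
model.load_state_dict(torch.load('../models/resnet18_anomaly.pth', map_location='cpu'))
model.eval()

sess = ort.InferenceSession('../models/resnet18_anomaly.onnx')
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    torch_out = model(x).numpy()
onnx_out = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})[0]

# The two sets of logits should match to small numerical tolerance
print(np.max(np.abs(torch_out - onnx_out)))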