# Safetyinspector_AI / train_autoencoder.py
# Trains a convolutional autoencoder on unlabeled images and saves the weights
# to models/autoencoder.pth.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#Feature Extractor (ResNet)
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet = nn.Sequential(*list(resnet.children())[:-1])
resnet.eval().to(device)
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
class ImageDataset(Dataset):
def __init__(self, img_dir):
self.img_dir = img_dir
self.img_names = [f for f in os.listdir(img_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
if len(self.img_names) == 0:
raise ValueError(f"No images found in {img_dir}. Please upload images first!")
def __len__(self):
return len(self.img_names)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_names[idx])
img = Image.open(img_path).convert("RGB")
return transform(img)
#Autoencoder Model
class ConvAutoencoder(nn.Module):
def __init__(self):
super(ConvAutoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 32, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1),
nn.ReLU(),
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1,output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1,output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1,output_padding=1),
nn.Sigmoid(),
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
#Training
def train_autoencoder(data_dir="data/unlabeled", epochs=10, lr=1e-3):
try:
dataset = ImageDataset(data_dir)
except ValueError as e:
print(e)
print("Tip: Upload images to data/unlabeled/ using the Colab file uploader or download sample images.")
return
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
model = ConvAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
print(f"Training autoencoder on {len(dataset)} images...")
for epoch in range(epochs):
model.train()
total_loss = 0
for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
batch = batch.to(device)
recon = model(batch)
loss = criterion(recon, batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(),"models/autoencoder.pth")
print("Autoencoder saved to models/autoencoder.pth")
if __name__ == "__main__":
train_autoencoder()