Vision / training /train_denoiser.py
Eli-Iustus's picture
Upload 321 files
2013cf0 verified
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
# Import the architecture we defined in the pipeline
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pipeline.preprocessor import LightCNN_Denoiser
class DenoiserDataset(Dataset):
    """Paired (noisy input, clean target) images from the synthetic dataset.

    ``dataset_dir`` must contain two sub-directories, ``clean`` and
    ``noisy``.  A clean image is paired with the noisy file of the same
    name.  Only names that are regular, non-hidden files in *both*
    directories are indexed, and they are kept in sorted order, so the
    index -> file mapping is deterministic and a stray entry (hidden file,
    sub-directory, unpaired image) cannot fail in the middle of an epoch.

    Args:
        dataset_dir: Root folder holding the ``clean`` and ``noisy``
            sub-directories.

    Raises:
        FileNotFoundError: If either sub-directory does not exist.
    """

    def __init__(self, dataset_dir="dataset"):
        self.clean_dir = os.path.join(dataset_dir, "clean")
        self.noisy_dir = os.path.join(dataset_dir, "noisy")
        self.image_files = self._list_pairs(self.clean_dir, self.noisy_dir)
        self.transform = transforms.Compose([
            # Resize for consistent CNN batching
            transforms.Resize((64, 256)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _list_pairs(clean_dir, noisy_dir):
        """Return the sorted file names usable as (noisy, clean) pairs.

        A name qualifies when it does not start with ``.`` and refers to a
        regular file in both ``clean_dir`` and ``noisy_dir``.

        Raises:
            FileNotFoundError: If either directory does not exist.
        """
        # List the clean directory first so a completely missing dataset
        # is still reported against the clean folder, as before.
        clean_names = os.listdir(clean_dir)
        noisy_names = set(os.listdir(noisy_dir))
        return sorted(
            name
            for name in clean_names
            if not name.startswith(".")
            and name in noisy_names
            and os.path.isfile(os.path.join(clean_dir, name))
            and os.path.isfile(os.path.join(noisy_dir, name))
        )

    def __len__(self):
        """Return the number of indexed image pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one pair and return ``(noisy_tensor, clean_tensor)``.

        Both images are converted to RGB and passed through the same
        ``self.transform`` pipeline (resize, then ``ToTensor``).
        """
        filename = self.image_files[idx]
        # The context managers release the file handles as soon as the
        # pixel data has been copied into the converted image.
        with Image.open(os.path.join(self.clean_dir, filename)) as clean_file:
            clean_img = clean_file.convert("RGB")
        with Image.open(os.path.join(self.noisy_dir, filename)) as noisy_file:
            noisy_img = noisy_file.convert("RGB")
        clean_tensor = self.transform(clean_img)
        noisy_tensor = self.transform(noisy_img)
        return noisy_tensor, clean_tensor
def train_model(*, dataset_dir="dataset", epochs=5, batch_size=16,
                learning_rate=0.001,
                save_path="weights/lightcnn_weights.pth"):
    """Train ``LightCNN_Denoiser`` on the synthetic dataset and save weights.

    Builds a ``DenoiserDataset`` from ``dataset_dir``, optimizes the model
    with Adam against a pixel-wise MSE loss, prints the mean batch loss
    after every epoch and finally writes the model's ``state_dict`` to
    ``save_path``.  Called without arguments it behaves like the original
    hard-coded showcase.

    Args:
        dataset_dir: Folder passed to ``DenoiserDataset``.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size for the ``DataLoader``.
        learning_rate: Learning rate for the Adam optimizer.
        save_path: Destination file for the trained weights; its parent
            directory is created if necessary.

    Returns:
        None.  If the dataset is missing or empty an error is printed and
        the function returns without training.
    """
    print("==================================================")
    print("Initializing LightCNN Denoising Training Showcase")
    print("==================================================")

    # Check for GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load Data
    print("Loading synthetic dataset...")
    try:
        dataset = DenoiserDataset(dataset_dir)
    except FileNotFoundError:
        print("ERROR: Dataset not found. Please run generate_dataset.py first!")
        return
    if len(dataset) == 0:
        # DataLoader(shuffle=True) rejects an empty dataset with an opaque
        # ValueError; report the actual problem instead.
        print("ERROR: Dataset is empty. Please run generate_dataset.py first!")
        return
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 2. Initialize Model
    model = LightCNN_Denoiser().to(device)
    criterion = nn.MSELoss()  # Measure the difference between pixels
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # 3. Training Loop
    print(f"Starting training for {epochs} epochs...")
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for noisy_inputs, clean_targets in dataloader:
            noisy_inputs = noisy_inputs.to(device)
            clean_targets = clean_targets.to(device)

            # Zero gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(noisy_inputs)
            # Calculate pixel error
            loss = criterion(outputs, clean_targets)
            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {running_loss/len(dataloader):.4f}")

    # 4. Save Weights
    # A bare file name has no directory component; fall back to the CWD.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print("==================================================")
    print(f"Training Complete. Showcase weights saved to: {save_path}")
    print("To use this in production, set use_dl_cnn=True in pipeline/preprocessor.py")
# Script entry point: start the training showcase only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train_model()