Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| # Import the architecture we defined in the pipeline | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from pipeline.preprocessor import LightCNN_Denoiser | |
class DenoiserDataset(Dataset):
    """
    Loads pairs of (Noisy Input -> Clean Output) from the synthetic dataset.

    Expects ``dataset_dir`` to contain two sub-directories, ``clean`` and
    ``noisy``, holding identically-named image files; the noisy image is the
    model input and the clean image is the regression target.
    """

    def __init__(self, dataset_dir="dataset"):
        self.clean_dir = os.path.join(dataset_dir, "clean")
        self.noisy_dir = os.path.join(dataset_dir, "noisy")
        # Sort so the index -> filename mapping is deterministic across runs;
        # os.listdir order is filesystem-dependent and unspecified.
        self.image_files = sorted(os.listdir(self.clean_dir))
        self.transform = transforms.Compose([
            # Resize for consistent CNN batching
            transforms.Resize((64, 256)),
            transforms.ToTensor()
        ])

    def __len__(self):
        # One sample per clean image; a matching noisy file is assumed to exist.
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the (noisy_tensor, clean_tensor) pair at position ``idx``."""
        filename = self.image_files[idx]
        clean_img = Image.open(os.path.join(self.clean_dir, filename)).convert("RGB")
        noisy_img = Image.open(os.path.join(self.noisy_dir, filename)).convert("RGB")
        clean_tensor = self.transform(clean_img)
        noisy_tensor = self.transform(noisy_img)
        # Input first, target second — matches the training loop's unpacking.
        return noisy_tensor, clean_tensor
def train_model():
    """
    Train the LightCNN denoiser on synthetic (noisy, clean) image pairs.

    Loads the dataset produced by generate_dataset.py, runs a short MSE
    training loop, and saves the learned weights to weights/lightcnn_weights.pth.
    Prints an error and returns early if the dataset is missing or empty.
    """
    print("==================================================")
    print("Initializing LightCNN Denoising Training Showcase")
    print("==================================================")

    # Prefer GPU when available; the rest of the loop is device-agnostic.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load Data
    print("Loading synthetic dataset...")
    try:
        # Only the dataset constructor touches the filesystem, so keep the
        # try body minimal.
        dataset = DenoiserDataset()
    except FileNotFoundError:
        print("ERROR: Dataset not found. Please run generate_dataset.py first!")
        return
    if len(dataset) == 0:
        # Guard: an empty dataset would make the per-epoch average below
        # divide by zero (len(dataloader) == 0).
        print("ERROR: Dataset not found. Please run generate_dataset.py first!")
        return
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # 2. Initialize Model
    model = LightCNN_Denoiser().to(device)
    criterion = nn.MSELoss()  # Measure the difference between pixels
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    epochs = 5

    # 3. Training Loop
    print(f"Starting training for {epochs} epochs...")
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for noisy_inputs, clean_targets in dataloader:
            noisy_inputs = noisy_inputs.to(device)
            clean_targets = clean_targets.to(device)

            # Zero gradients from the previous step.
            optimizer.zero_grad()
            # Forward pass: predict a denoised image.
            outputs = model(noisy_inputs)
            # Pixel-wise reconstruction error against the clean target.
            loss = criterion(outputs, clean_targets)
            # Backward pass and optimize.
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {running_loss/len(dataloader):.4f}")

    # 4. Save Weights (os.path.join keeps the path portable across platforms)
    os.makedirs("weights", exist_ok=True)
    save_path = os.path.join("weights", "lightcnn_weights.pth")
    torch.save(model.state_dict(), save_path)
    print("==================================================")
    print(f"Training Complete. Showcase weights saved to: {save_path}")
    print("To use this in production, set use_dl_cnn=True in pipeline/preprocessor.py")
| if __name__ == "__main__": | |
| train_model() | |