# INKVISION / training / generate_dataset.py
# Eli-Iustus's picture
# Upload 320 files
# 3503b36 verified
import os
import cv2
import numpy as np
def _write_image(path, image):
    """Write ``image`` to ``path`` and raise ``OSError`` if the write fails."""
    # cv2.imwrite reports most failures by returning False rather than raising,
    # so an unchecked call can leave a silently incomplete dataset behind.
    if not cv2.imwrite(path, image):
        raise OSError(f"Failed to write image: {path}")


def _make_shadow(height, width, rng):
    """Build a uint8 ``(height, width, 3)`` linear brightness ramp.

    Both ends of the ramp are drawn from [50, 150), and the ramp runs either
    horizontally or vertically (chosen at random), so the darkening actually
    varies across the image.
    """
    start, end = rng.randint(50, 150, size=2)
    if rng.rand() < 0.5:
        ramp = np.linspace(start, end, width, dtype=np.float32)[np.newaxis, :]
    else:
        ramp = np.linspace(start, end, height, dtype=np.float32)[:, np.newaxis]
    plane = np.broadcast_to(ramp, (height, width))
    # Replicate the single plane across the three colour channels.
    return np.repeat(plane[:, :, np.newaxis], 3, axis=2).astype(np.uint8)


def generate_synthetic_data(num_samples=100, output_dir="dataset", seed=None):
    """
    Generate a small synthetic dataset for training the LightCNN denoiser.

    Writes pairs of (clean image, noisy image) as ``NNN.jpg`` files with
    matching names under ``<output_dir>/clean`` and ``<output_dir>/noisy``.
    The clean image is black text on a white 128x512 canvas; the noisy image
    adds a gradient shadow, salt-and-pepper noise and, half of the time, a
    slight blur.

    Args:
        num_samples: Number of image pairs to generate.
        output_dir: Root directory for the dataset; created if missing.
        seed: Optional seed for reproducible output. When ``None`` the global
            ``np.random`` state is used, so callers that seed NumPy themselves
            still get deterministic results.

    Raises:
        OSError: If a directory cannot be created or an image cannot be written.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    clean_dir = os.path.join(output_dir, "clean")
    noisy_dir = os.path.join(output_dir, "noisy")
    os.makedirs(clean_dir, exist_ok=True)
    os.makedirs(noisy_dir, exist_ok=True)

    print(f"Generating {num_samples} synthetic training pairs...")

    font = cv2.FONT_HERSHEY_SIMPLEX
    s_vs_p = 0.5   # share of corrupted pixels that become salt (white)
    amount = 0.04  # overall fraction of pixels hit by salt-and-pepper noise

    for i in range(num_samples):
        # 1. Create a clean digital "handwritten" image on a white background.
        img = np.full((128, 512, 3), 255, dtype=np.uint8)

        # Draw some random text to simulate handwriting.
        text = f"Sample Text {rng.randint(1000, 9999)}"
        thickness = rng.randint(2, 5)
        x, y = rng.randint(10, 50), rng.randint(50, 90)
        cv2.putText(img, text, (x, y), font, 1.5, (0, 0, 0), thickness, cv2.LINE_AA)

        # Save the clean ground truth (y).
        _write_image(os.path.join(clean_dir, f"{i:03d}.jpg"), img)

        # 2. Add realistic noise to simulate a bad photo (x).
        h, w = img.shape[:2]

        # Blend in a real gradient shadow. The previous implementation filled
        # the whole canvas with one value before blurring, which produced a
        # uniform darkening rather than a gradient.
        shadow = _make_shadow(h, w, rng)
        # addWeighted allocates its result, so img itself is left untouched.
        noisy = cv2.addWeighted(img, 0.7, shadow, 0.3, 0)

        # Salt-and-pepper noise: one random draw per pixel, applied to all
        # three channels through the (h, w) boolean mask.
        noisy_pixels = rng.rand(h, w)
        noisy[noisy_pixels < amount * s_vs_p] = 255            # salt
        noisy[noisy_pixels > 1 - amount * (1 - s_vs_p)] = 0    # pepper

        # Add a slight blur to simulate bad focus (about half of the samples).
        if rng.rand() > 0.5:
            noisy = cv2.GaussianBlur(noisy, (5, 5), 0)

        # Save the dirty input (x).
        _write_image(os.path.join(noisy_dir, f"{i:03d}.jpg"), noisy)

    print(f"Dataset generated in '{output_dir}'.")
    print(f" Clean labels (y): {clean_dir}")
    print(f" Noisy inputs (x): {noisy_dir}")
if __name__ == "__main__":
    # A small sample count is enough for a quick toy training run.
    toy_sample_count = 150
    generate_synthetic_data(num_samples=toy_sample_count)