zyyyy committed on
Commit f755531 · verified · 1 Parent(s): 468f9cd

Delete src/photosketch

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. src/photosketch/README.md +0 -57
  2. src/photosketch/__init__.py +0 -0
  3. src/photosketch/data/__init__.py +0 -0
  4. src/photosketch/data/__pycache__/__init__.cpython-312.pyc +0 -0
  5. src/photosketch/data/__pycache__/data_loader.cpython-312.pyc +0 -0
  6. src/photosketch/data/base_data_loader.py +0 -14
  7. src/photosketch/data/base_dataset.py +0 -45
  8. src/photosketch/data/custom_dataset_data_loader.py +0 -44
  9. src/photosketch/data/data_loader.py +0 -7
  10. src/photosketch/data/image_folder.py +0 -75
  11. src/photosketch/data/paired_1_to_n_dataset.py +0 -99
  12. src/photosketch/data/rotate_and_crop.py +0 -63
  13. src/photosketch/data/test_dir_dataset.py +0 -38
  14. src/photosketch/examples/00000932.jpg +0 -0
  15. src/photosketch/examples/bedroom.jpg +0 -0
  16. src/photosketch/examples/input_image_vermeer.png +0 -3
  17. src/photosketch/inference_utils.py +0 -109
  18. src/photosketch/models/__init__.py +0 -0
  19. src/photosketch/models/__pycache__/__init__.cpython-312.pyc +0 -0
  20. src/photosketch/models/__pycache__/base_model.cpython-312.pyc +0 -0
  21. src/photosketch/models/__pycache__/models.cpython-312.pyc +0 -0
  22. src/photosketch/models/__pycache__/networks.cpython-312.pyc +0 -0
  23. src/photosketch/models/__pycache__/pix2pix_model.cpython-312.pyc +0 -0
  24. src/photosketch/models/base_model.py +0 -69
  25. src/photosketch/models/models.py +0 -12
  26. src/photosketch/models/networks.py +0 -557
  27. src/photosketch/models/pix2pix_model.py +0 -174
  28. src/photosketch/options/__init__.py +0 -0
  29. src/photosketch/options/__pycache__/__init__.cpython-312.pyc +0 -0
  30. src/photosketch/options/__pycache__/base_options.cpython-312.pyc +0 -0
  31. src/photosketch/options/__pycache__/inference_options.cpython-312.pyc +0 -0
  32. src/photosketch/options/base_options.py +0 -83
  33. src/photosketch/options/inference_options.py +0 -43
  34. src/photosketch/options/opt.json +0 -56
  35. src/photosketch/options/test_options.py +0 -15
  36. src/photosketch/options/train_options.py +0 -31
  37. src/photosketch/pretrained/inference_config.txt +0 -56
  38. src/photosketch/pretrained/latest_net_D.pth +0 -3
  39. src/photosketch/pretrained/latest_net_G.pth +0 -3
  40. src/photosketch/scripts/test.cmd +0 -19
  41. src/photosketch/scripts/test.sh +0 -19
  42. src/photosketch/scripts/test_pretrained.cmd +0 -15
  43. src/photosketch/scripts/test_pretrained.sh +0 -15
  44. src/photosketch/scripts/train.cmd +0 -22
  45. src/photosketch/scripts/train.sh +0 -22
  46. src/photosketch/test.py +0 -37
  47. src/photosketch/test_pretrained.py +0 -29
  48. src/photosketch/train.py +0 -61
  49. src/photosketch/util/.DS_Store +0 -0
  50. src/photosketch/util/__init__.py +0 -0
src/photosketch/README.md DELETED
@@ -1,57 +0,0 @@
- # Photo-Sketching: Inferring Contour Drawings from Images
-
- <p align="center"><img alt="Teaser" src="doc/teaser.jpg"></p>
-
- This repo contains the training & testing code for our sketch generator. We also provide a [[pre-trained model]](https://drive.google.com/file/d/1TQf-LyS8rRDDapdcTnEgWzYJllPgiXdj/view).
-
- For technical details and the dataset, please refer to the [**[paper]**](https://arxiv.org/abs/1901.00542) and the [**[project page]**](http://www.cs.cmu.edu/~mengtial/proj/sketch/).
-
- # Setting up
-
- The code is now updated to use PyTorch 0.4 and runs on Windows, Mac and Linux. For the obsolete version with PyTorch 0.3 (Linux only), please check out the branch [pytorch-0.3-obsolete](../../tree/pytorch-0.3-obsolete).
-
- Windows users should find the corresponding `.cmd` files instead of the `.sh` files mentioned below.
-
- ## One-line installation (with Conda environments)
- `conda env create -f environment.yml`
-
- Then activate the environment (sketch) and you are ready to go!
-
- See [here](https://conda.io/docs/user-guide/tasks/manage-environments.html) for more information about conda environments.
-
- ## Manual installation
- See `environment.yml` for a list of dependencies.
-
- # Using the pre-trained model
-
- - Download the [pre-trained model](https://drive.google.com/file/d/1TQf-LyS8rRDDapdcTnEgWzYJllPgiXdj/view)
- - Modify the path in `scripts/test_pretrained.sh`
- - From the repo's **root directory**, run `scripts/test_pretrained.sh`
-
- It supports a folder of images as input.
-
- # Train & test on our contour drawing dataset
-
- - Download the images and the rendered sketch from the [project page](http://www.cs.cmu.edu/~mengtial/proj/sketch/)
- - Unzip and organize them into the following structure:
- <p align="center"><img alt="File structure" src="doc/file_structure.png"></p>
-
- - Modify the path in `scripts/train.sh` and `scripts/test.sh`
- - From the repo's **root directory**, run `scripts/train.sh` to train the model
- - From the repo's **root directory**, run `scripts/test.sh` to test on the val set or the test set (specified by the phase flag)
-
- ## Citation
- If you use the code or the data for your research, please cite the paper:
-
- ```
- @article{LIPS2019,
-   title={Photo-Sketching: Inferring Contour Drawings from Images},
-   author={Li, Mengtian and Lin, Zhe and M\v ech, Radom\'ir and Yumer, Ersin and Ramanan, Deva},
-   journal={WACV},
-   year={2019}
- }
- ```
-
- ## Acknowledgement
- This code is based on an old version of [pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/).
-
src/photosketch/__init__.py DELETED
File without changes
src/photosketch/data/__init__.py DELETED
File without changes
src/photosketch/data/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (184 Bytes)
 
src/photosketch/data/__pycache__/data_loader.cpython-312.pyc DELETED
Binary file (567 Bytes)
 
src/photosketch/data/base_data_loader.py DELETED
@@ -1,14 +0,0 @@
-
- class BaseDataLoader():
-     def __init__(self):
-         pass
-
-     def initialize(self, opt):
-         self.opt = opt
-         pass
-
-     def load_data():
-         return None
-
-
-
src/photosketch/data/base_dataset.py DELETED
@@ -1,45 +0,0 @@
- import torch.utils.data as data
- from PIL import Image
- import torchvision.transforms as transforms
-
- class BaseDataset(data.Dataset):
-     def __init__(self):
-         super(BaseDataset, self).__init__()
-
-     def name(self):
-         return 'BaseDataset'
-
-     def initialize(self, opt):
-         pass
-
- def get_transform(opt):
-     transform_list = []
-     if opt.resize_or_crop == 'resize_and_crop':
-         osize = [opt.loadSize, opt.loadSize]
-         transform_list.append(transforms.Scale(osize, Image.BICUBIC))
-         transform_list.append(transforms.RandomCrop(opt.fineSize))
-     elif opt.resize_or_crop == 'crop':
-         transform_list.append(transforms.RandomCrop(opt.fineSize))
-     elif opt.resize_or_crop == 'scale_width':
-         transform_list.append(transforms.Lambda(
-             lambda img: __scale_width(img, opt.fineSize)))
-     elif opt.resize_or_crop == 'scale_width_and_crop':
-         transform_list.append(transforms.Lambda(
-             lambda img: __scale_width(img, opt.loadSize)))
-         transform_list.append(transforms.RandomCrop(opt.fineSize))
-
-     if opt.isTrain and not opt.no_flip:
-         transform_list.append(transforms.RandomHorizontalFlip())
-
-     transform_list += [transforms.ToTensor(),
-                        transforms.Normalize((0.5, 0.5, 0.5),
-                                             (0.5, 0.5, 0.5))]
-     return transforms.Compose(transform_list)
-
- def __scale_width(img, target_width):
-     ow, oh = img.size
-     if (ow == target_width):
-         return img
-     w = target_width
-     h = int(target_width * oh / ow)
-     return img.resize((w, h), Image.BICUBIC)
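The deleted `__scale_width` fixes the output width and derives the height from the aspect ratio. A minimal sketch of just that size computation, without the PIL dependency (the function name `scale_width_dims` is ours):

```python
def scale_width_dims(ow, oh, target_width):
    """Output size chosen by __scale_width: fix the width at target_width
    and scale the height proportionally, truncating to an integer."""
    if ow == target_width:
        return ow, oh
    return target_width, int(target_width * oh / ow)
```

Note that `transforms.Scale`, used in `get_transform` above, was deprecated in later torchvision releases in favor of `transforms.Resize`, which is presumably one reason this vendored copy pinned an old API.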
src/photosketch/data/custom_dataset_data_loader.py DELETED
@@ -1,44 +0,0 @@
- import torch.utils.data
- from data.base_data_loader import BaseDataLoader
-
-
- def CreateDataset(opt):
-     dataset = None
-     if opt.dataset_mode == '1_to_n':
-         from data.paired_1_to_n_dataset import Paired1ToNDataset
-         dataset = Paired1ToNDataset()
-     elif opt.dataset_mode == 'test_dir':
-         from data.test_dir_dataset import TestDirDataset
-         dataset = TestDirDataset()
-     else:
-         raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
-
-     print("dataset [%s] was created" % (dataset.name()))
-     dataset.initialize(opt)
-     return dataset
-
-
- class CustomDatasetDataLoader(BaseDataLoader):
-     def name(self):
-         return 'CustomDatasetDataLoader'
-
-     def initialize(self, opt):
-         BaseDataLoader.initialize(self, opt)
-         self.dataset = CreateDataset(opt)
-         self.dataloader = torch.utils.data.DataLoader(
-             self.dataset,
-             batch_size=opt.batchSize,
-             shuffle=not opt.serial_batches,
-             num_workers=int(opt.nThreads))
-
-     def load_data(self):
-         return self
-
-     def __len__(self):
-         return min(len(self.dataset), self.opt.max_dataset_size)
-
-     def __iter__(self):
-         for i, data in enumerate(self.dataloader):
-             if i >= self.opt.max_dataset_size:
-                 break
-             yield data
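`CreateDataset` in the deleted file is a string-keyed factory over an if/elif chain. The same dispatch can be sketched with a plain registry dict; the names here are illustrative, not part of the original API:

```python
def create_from_registry(mode, registry):
    """Dispatch a dataset-mode name to its class, mirroring CreateDataset.

    registry is a dict mapping mode names (e.g. '1_to_n', 'test_dir')
    to dataset classes; unknown modes raise ValueError just as above.
    """
    if mode not in registry:
        raise ValueError("Dataset [%s] not recognized." % mode)
    return registry[mode]()
```

A registry keeps the dispatch table in one place, so adding a dataset mode does not require editing the factory body.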
src/photosketch/data/data_loader.py DELETED
@@ -1,7 +0,0 @@
-
- def CreateDataLoader(opt):
-     from data.custom_dataset_data_loader import CustomDatasetDataLoader
-     data_loader = CustomDatasetDataLoader()
-     print(data_loader.name())
-     data_loader.initialize(opt)
-     return data_loader
src/photosketch/data/image_folder.py DELETED
@@ -1,75 +0,0 @@
- ###############################################################################
- # Code from
- # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
- # Modified the original code so that it also loads images from the current
- # directory as well as the subdirectories
- ###############################################################################
-
- import torch.utils.data as data
-
- from PIL import Image
- import os
- import os.path
-
- import platform
-
- if platform.system() == 'Windows':
-     IMG_EXTENSIONS = [
-         '.jpg', '.jpeg', '.png', '.ppm', '.bmp',
-     ]
- else:
-     IMG_EXTENSIONS = [
-         '.jpg', '.JPG', '.jpeg', '.JPEG',
-         '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
-     ]
-
-
- def is_image_file(filename):
-     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
-
-
- def make_dataset(dir):
-     images = []
-     assert os.path.isdir(dir), '%s is not a valid directory' % dir
-
-     for root, _, fnames in sorted(os.walk(dir)):
-         for fname in fnames:
-             if is_image_file(fname):
-                 path = os.path.join(root, fname)
-                 images.append(path)
-
-     return images
-
-
- def default_loader(path):
-     return Image.open(path).convert('RGB')
-
-
- class ImageFolder(data.Dataset):
-
-     def __init__(self, root, transform=None, return_paths=False,
-                  loader=default_loader):
-         imgs = make_dataset(root)
-         if len(imgs) == 0:
-             raise(RuntimeError("Found 0 images in: " + root + "\n"
-                                "Supported image extensions are: " +
-                                ",".join(IMG_EXTENSIONS)))
-
-         self.root = root
-         self.imgs = imgs
-         self.transform = transform
-         self.return_paths = return_paths
-         self.loader = loader
-
-     def __getitem__(self, index):
-         path = self.imgs[index]
-         img = self.loader(path)
-         if self.transform is not None:
-             img = self.transform(img)
-         if self.return_paths:
-             return img, path
-         else:
-             return img
-
-     def __len__(self):
-         return len(self.imgs)
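The per-platform extension lists above exist only because the original `is_image_file` does a case-sensitive `endswith`: on case-sensitive Linux filesystems `.JPG` and `.jpg` are distinct, while Windows matching is effectively case-insensitive. A case-insensitive check (our variant, not the committed code) covers both platforms with a single list:

```python
# One lowercase list replaces the Windows/Linux split above.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')

def is_image_file(filename):
    # Lower-casing the name makes the '.JPG' vs '.jpg' distinction irrelevant.
    return filename.lower().endswith(IMG_EXTENSIONS)
```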
src/photosketch/data/paired_1_to_n_dataset.py DELETED
@@ -1,99 +0,0 @@
- import os.path
- import random
- import torchvision.transforms as transforms
- import torch
- from data.base_dataset import BaseDataset
- from data.image_folder import make_dataset
- from PIL import Image
- from data.rotate_and_crop import rotate_and_crop
-
- class Paired1ToNDataset(BaseDataset):
-     def initialize(self, opt):
-         self.opt = opt
-         self.root = opt.dataroot
-         self.if_aug = opt.phase == 'train'
-         self.img_dir = os.path.join(self.root, 'image')
-         self.skt_dir = os.path.join(self.root, opt.render_dir, opt.aug_folder)
-         list_file = os.path.join(self.root, 'list', opt.phase + '.txt')
-         with open(list_file) as f:
-             content = f.readlines()
-         self.list = sorted([x.strip() for x in content])
-
-     def __getitem__(self, index):
-         N = self.opt.nGT
-         filename = self.list[index]
-
-         if_crop_or_rotate = self.if_aug and ((self.opt.crop or self.opt.rotate) and random.random() < 0.8)
-         # even with augmentation on, 20% chance does nothing
-         if_flip = self.if_aug and ((not self.opt.no_flip) and random.random() < 0.5)
-
-         pathA = os.path.join(self.img_dir, filename + '.jpg')
-         img = Image.open(pathA)
-         fine_size = self.opt.fineSize
-
-         if if_flip:
-             img = img.transpose(Image.FLIP_LEFT_RIGHT)
-         if self.if_aug and self.opt.color_jitter:
-             jitter_amount = 0.02
-             img = transforms.ColorJitter(jitter_amount, jitter_amount, jitter_amount, jitter_amount)(img)
-         if if_crop_or_rotate:
-             load_size = self.opt.loadSize
-             img = img.resize((load_size, load_size), Image.BICUBIC)
-             if self.opt.rotate:
-                 rot_deg = 5*random.randint(-3, 3)
-                 img = rotate_and_crop(img, rot_deg, True)
-             if self.opt.crop:
-                 img = transforms.ToTensor()(img)
-                 w_offset = random.randint(0, max(0, load_size - fine_size - 1))
-                 h_offset = random.randint(0, max(0, load_size - fine_size - 1))
-                 img = img[:, h_offset:h_offset + fine_size,
-                           w_offset:w_offset + fine_size]
-             else:
-                 img = img.resize((fine_size, fine_size), Image.BICUBIC)
-                 img = transforms.ToTensor()(img)
-         else:
-             img = img.resize((fine_size, fine_size), Image.BICUBIC)
-             img = transforms.ToTensor()(img)
-
-         if self.opt.inverse_gamma:
-             linear_mask = (img <= 0.04045).float()
-             exponential_mask = (img > 0.04045).float()
-             img = (img / 12.92 * linear_mask) + (((img + 0.055) / 1.055) ** 2.4) * exponential_mask
-
-         A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
-
-         B = N*[None]
-         for i in range(N):
-             pathB = os.path.join(self.skt_dir, '%s_%02d.png' % (filename, i+1))
-             img = Image.open(pathB)
-             if if_flip:
-                 img = img.transpose(Image.FLIP_LEFT_RIGHT)
-             if if_crop_or_rotate:
-                 load_size = self.opt.loadSize
-                 img = img.resize((load_size, load_size), Image.BICUBIC)
-                 if self.opt.rotate:
-                     img = rotate_and_crop(img, rot_deg, True)
-                 if self.opt.crop:
-                     img = transforms.ToTensor()(img)
-                     img = img[:, h_offset:h_offset + fine_size,
-                               w_offset:w_offset + fine_size]
-                 else:
-                     img.resize((fine_size, fine_size), Image.BICUBIC)
-                     img = transforms.ToTensor()(img)
-             else:
-                 img = img.resize((fine_size, fine_size), Image.BICUBIC)
-                 img = transforms.ToTensor()(img)
-
-             B[i] = transforms.Normalize((0.5,), (0.5,))(img)
-
-         B = torch.cat(B, 0)
-
-         return {'A': A, 'B': B,
-                 'A_paths': pathA, 'B_paths': pathB}
-
-     def __len__(self):
-         return len(self.list)
-
-     def name(self):
-         return 'Paired1ToNDataset'
-
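The `inverse_gamma` branch above applies the standard sRGB-to-linear transfer function elementwise, using the two masks to select the linear and exponential segments. Isolated per channel value (our scalar sketch of the same formula):

```python
def srgb_to_linear(c):
    """Per-channel sRGB -> linear transfer for c in [0, 1], matching the
    inverse_gamma branch: linear segment below 0.04045, power law above."""
    if c <= 0.04045:
        return c / 12.92
    return ((c + 0.055) / 1.055) ** 2.4
```

The breakpoint 0.04045 and exponent 2.4 are the sRGB standard constants, which is why they appear verbatim in the dataset code.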
src/photosketch/data/rotate_and_crop.py DELETED
@@ -1,63 +0,0 @@
- import numpy as np
- from PIL import Image
-
- def perp(a):
-     # https://stackoverflow.com/questions/3252194/numpy-and-line-intersections
-     b = np.empty_like(a)
-     b[0] = -a[1]
-     b[1] = a[0]
-     return b
-
- def seg_intersect(a1, a2, b1, b2):
-     # https://stackoverflow.com/questions/3252194/numpy-and-line-intersections
-     da = a2-a1
-     db = b2-b1
-     dp = a1-b1
-     dap = perp(da)
-     denom = np.dot(dap, db)
-     num = np.dot(dap, dp)
-     return (num / denom.astype(float))*db + b1
-
- def rotate_and_crop(img, deg, same_size=False, interp=Image.BICUBIC):
-     # let the four corners of a rectangle be ABCD, clockwise
-     if deg == 0:
-         return img
-
-     w, h = img.size
-
-     A = np.array([-w/2, h/2])
-     B = np.array([w/2, h/2])
-     C = np.array([w/2, -h/2])
-     D = np.array([-w/2, -h/2])
-
-     rad = np.radians(deg)
-     c, s = np.cos(rad), np.sin(rad)
-     R = np.array([[c, -s], [s, c]]).T
-
-     Arot = np.dot(A, R)
-     Brot = np.dot(B, R)
-     if deg > 0:
-         X = seg_intersect(A, C, Arot, Brot)
-         offset = X - A
-         offset[1] = -offset[1]
-     else:
-         X = seg_intersect(B, D, Arot, Brot)
-         offset = B - X
-
-     if same_size:
-         wh_org = np.array([w, h])
-         wh = np.ceil(np.divide(np.square(wh_org), wh_org - 2*offset)).astype(np.int32)
-         offset = (wh - wh_org)/2
-         img = img.resize(wh, interp)
-         w = wh[0]
-         h = wh[1]
-     else:
-         offset = np.ceil(offset)
-     img = img.rotate(deg, interp)
-     return img.crop(
-         (offset[0],
-          offset[1],
-          w - offset[0],
-          h - offset[1])
-     )
-
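`seg_intersect` above finds where the (infinite) lines through the two point pairs cross, using the perpendicular-dot-product form: with `dap = perp(da)`, the intersection is `b1 + (dap·dp / dap·db) * db`. The same computation on plain `(x, y)` tuples, as a NumPy-free sketch (the name `line_intersect` is ours):

```python
def line_intersect(a1, a2, b1, b2):
    """Intersection of the lines through (a1, a2) and (b1, b2).

    Points are (x, y) tuples. Parallel lines are not handled, matching
    the original (there the division simply produces inf/nan).
    """
    dax, day = a2[0] - a1[0], a2[1] - a1[1]   # da
    dbx, dby = b2[0] - b1[0], b2[1] - b1[1]   # db
    dpx, dpy = a1[0] - b1[0], a1[1] - b1[1]   # dp
    # dap = perp(da) = (-day, dax); denom = dap . db, num = dap . dp
    denom = -day * dbx + dax * dby
    t = (-day * dpx + dax * dpy) / denom
    return (b1[0] + t * dbx, b1[1] + t * dby)
```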
src/photosketch/data/test_dir_dataset.py DELETED
@@ -1,38 +0,0 @@
- import os.path
- from glob import glob
- import torchvision.transforms as transforms
- import torch
- from data.base_dataset import BaseDataset
- from data.image_folder import make_dataset, IMG_EXTENSIONS
- from PIL import Image
-
-
- class TestDirDataset(BaseDataset):
-     def initialize(self, opt):
-         self.opt = opt
-         self.in_dir = opt.dataroot
-         if opt.file_name:
-             self.list = [os.path.join(self.in_dir, opt.file_name)]
-         else:
-             self.list = []
-             for ext in IMG_EXTENSIONS:
-                 self.list.extend(glob(os.path.join(self.in_dir, '*' + ext)))
-
-     def __getitem__(self, index):
-         file_path = self.list[index]
-         img = Image.open(file_path)
-         w, h = img.size
-         fine_size = self.opt.fineSize
-         img = img.resize((fine_size, fine_size), Image.BICUBIC)
-         img = transforms.ToTensor()(img)
-         A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
-         B = A.clone()[0]
-
-         return {'A': A, 'B': B, 'A_paths': file_path, 'B_paths': file_path, 'w': w, 'h': h}
-
-     def __len__(self):
-         return len(self.list)
-
-     def name(self):
-         return 'TestDirDataset'
-
src/photosketch/examples/00000932.jpg DELETED
Binary file (40.9 kB)
 
src/photosketch/examples/bedroom.jpg DELETED
Binary file (22.7 kB)
 
src/photosketch/examples/input_image_vermeer.png DELETED

Git LFS Details

  • SHA256: 22a917b48d85467516de1cbb783a1ace299a6aa320d11ac6a0d4144ab0694978
  • Pointer size: 131 Bytes
  • Size of remote file: 411 kB
src/photosketch/inference_utils.py DELETED
@@ -1,109 +0,0 @@
- import os
- import json
-
- from PIL import Image
- from .models.models import create_model
- from .data.data_loader import CreateDataLoader
- from .options.inference_options import InferenceOptions
- from torchvision import transforms
- import numpy as np
-
- class SketchGenerator:
-
-     def __init__(self, config_path, photosketch_dir, results_dir, img_dir, name='pretrained', nThreads=1, batchSize=1, serial_batches=True, no_flip=True):
-         self.config_path = config_path
-         self.photosketch_dir = photosketch_dir
-         self.results_dir = results_dir
-         self.img_dir = img_dir
-
-         self.opt = self._create_options(name=name, nThreads=nThreads, batchSize=batchSize, serial_batches=serial_batches, no_flip=no_flip)
-         self.model = create_model(self.opt)
-
-     def img_to_sketch(self, img_path):
-         input_data = self.preprocess_input(img_path, self.opt)
-         self.model.set_input(input_data)
-         img_path = self.model.get_image_paths()
-         self.model.test()
-         image_numpy = self.model.fake_B.detach()[0][0].cpu().float().numpy()
-         image_numpy = (image_numpy + 1) / 2.0 * 255.0
-         image_pil = Image.fromarray(image_numpy.astype(np.uint8))
-         image_pil = image_pil.resize((self.model.input_w[0], self.model.input_h[0]), Image.BICUBIC)
-         return image_pil
-
-     def _read_default_options(self):
-         with open(self.config_path) as f:
-             opt = json.load(f)
-         return opt
-
-     def _create_options(self, name='pretrained', nThreads=1, batchSize=1, serial_batches=True, no_flip=True):
-         loaded_opt = self._read_default_options()
-         loaded_opt['checkpoints_dir'] = self.photosketch_dir
-         loaded_opt['results_dir'] = self.results_dir
-         loaded_opt['name'] = name
-         loaded_opt['dataroot'] = self.img_dir
-         opt = InferenceOptions().parse(loaded_opt)
-         opt.nThreads = nThreads
-         opt.batchSize = batchSize
-         opt.serial_batches = serial_batches
-         opt.no_flip = no_flip
-         return opt
-
-     def preprocess_input(self, image_path):
-         """
-         Given an image path, preprocess the image to be used as input for the model
-         """
-         img = Image.open(image_path)
-         w, h = img.size
-         fine_size = self.opt.fineSize
-         img = img.resize((fine_size, fine_size), Image.BICUBIC)
-         img = transforms.ToTensor()(img)
-         A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img).unsqueeze(0)
-         B = A.clone()[0]
-         return {'A': A, 'B': B, 'A_paths': [image_path], 'B_paths': [image_path], 'w': [w], 'h': [h]}
-
-     def preprocess_PIL(self, img):
-         """
-         Given a PIL image, preprocess the image to be used as input for the model
-         """
-         w, h = img.size
-         fine_size = self.opt.fineSize
-         img = img.resize((fine_size, fine_size), Image.BICUBIC)
-         img = transforms.ToTensor()(img)
-         A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img).unsqueeze(0)
-         B = A.clone()[0]
-         return {'A': A, 'B': B, 'A_paths': [''], 'B_paths': [""], 'w': [w], 'h': [h]}
-
-     def get_dataloader(self):
-         data_loader = CreateDataLoader(self.opt)
-         dataset = data_loader.load_data()
-         return dataset
-
-
-     def path_to_sketch(self, img_path: str) -> Image:
-         """
-         Given a single path to an image, generate a sketch and return the PIL image
-         """
-         input_data = self.preprocess_input(img_path)
-         self.model.set_input(input_data)
-         img_path = self.model.get_image_paths()
-         self.model.test()
-         image_numpy = self.model.fake_B.detach()[0][0].cpu().float().numpy()
-         image_numpy = (image_numpy + 1) / 2.0 * 255.0
-         image_pil = Image.fromarray(image_numpy.astype(np.uint8))
-         image_pil = image_pil.resize((self.model.input_w[0], self.model.input_h[0]), Image.BICUBIC)
-         return image_pil
-
-     def PIL_to_sketch(self, img: Image) -> Image:
-         """
-         Given a single PIL image, generate a sketch and return the PIL image
-         """
-         input_data = self.preprocess_PIL(img)
-         self.model.set_input(input_data)
-         img_path = self.model.get_image_paths()
-         self.model.test()
-         image_numpy = self.model.fake_B.detach()[0][0].cpu().float().numpy()
-         image_numpy = (image_numpy + 1) / 2.0 * 255.0
-         image_pil = Image.fromarray(image_numpy.astype(np.uint8))
-         image_pil = image_pil.resize((self.model.input_w[0], self.model.input_h[0]), Image.BICUBIC)
-         return image_pil
-
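The three `*_to_sketch` methods in the deleted file share the same post-processing: the generator's output in [-1, 1] is mapped back to 8-bit pixel values before building the PIL image. That mapping isolated as a scalar helper (the name `to_pixel` is ours):

```python
def to_pixel(v):
    """Map a generator output value in [-1, 1] to an 8-bit pixel value,
    mirroring (image_numpy + 1) / 2.0 * 255.0 followed by the uint8 cast."""
    return int((v + 1) / 2.0 * 255.0)
```

Since this block is duplicated verbatim in `img_to_sketch`, `path_to_sketch`, and `PIL_to_sketch`, factoring it into one helper like this would have removed the repetition.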
src/photosketch/models/__init__.py DELETED
File without changes
src/photosketch/models/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (186 Bytes)
 
src/photosketch/models/__pycache__/base_model.cpython-312.pyc DELETED
Binary file (4.19 kB)
 
src/photosketch/models/__pycache__/models.cpython-312.pyc DELETED
Binary file (786 Bytes)
 
src/photosketch/models/__pycache__/networks.cpython-312.pyc DELETED
Binary file (26.7 kB)
 
src/photosketch/models/__pycache__/pix2pix_model.cpython-312.pyc DELETED
Binary file (12.8 kB)
 
src/photosketch/models/base_model.py DELETED
@@ -1,69 +0,0 @@
- import os
- import torch
-
-
- class BaseModel():
-     def name(self):
-         return 'BaseModel'
-
-     def initialize(self, opt):
-         self.opt = opt
-         self.isTrain = opt.isTrain
-         self.device = torch.device("cuda" if opt.use_cuda else "cpu")
-         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
-
-     def set_input(self, input):
-         self.input = input
-
-     def forward(self):
-         pass
-
-     # used in test time, no backprop
-     def test(self):
-         pass
-
-     def get_image_paths(self):
-         pass
-
-     def optimize_parameters(self):
-         pass
-
-     def get_current_visuals(self):
-         return self.input
-
-     def get_current_errors(self):
-         return {}
-
-     def save(self, label):
-         pass
-
-     # helper saving function that can be used by subclasses
-     def save_network(self, network, network_label, epoch_label):
-         save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
-         save_path = os.path.join(self.save_dir, save_filename)
-         torch.save(network.cpu().state_dict(), save_path)
-         network = network.to(self.device)
-
-     # helper loading function that can be used by subclasses
-     def load_network(self, network, network_label, epoch_label):
-         save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
-         if self.opt.pretrain_path:
-             save_path = os.path.join(self.opt.pretrain_path, save_filename)
-         else:
-             save_path = os.path.join(self.save_dir, save_filename)
-         network.load_state_dict(torch.load(save_path))
-
-     # update learning rate (called once every epoch)
-     def update_learning_rate(self):
-         for scheduler in self.schedulers:
-             scheduler.step()
-         lr = self.optimizers[0].param_groups[0]['lr']
-         print('learning rate = %.7f' % lr)
-
-     def set_requires_grad(self, nets, requires_grad=False):
-         if not isinstance(nets, list):
-             nets = [nets]
-         for net in nets:
-             if net is not None:
-                 for param in net.parameters():
-                     param.requires_grad = requires_grad
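`save_network` and `load_network` above agree on one filename scheme, and the pretrained weights deleted in this commit (`latest_net_G.pth`, `latest_net_D.pth`) follow it with `epoch_label='latest'`. The scheme, isolated as a sketch (the helper name is ours):

```python
def checkpoint_filename(epoch_label, network_label):
    """Checkpoint name shared by save_network and load_network:
    '<epoch_label>_net_<network_label>.pth'."""
    return '%s_net_%s.pth' % (epoch_label, network_label)
```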
src/photosketch/models/models.py DELETED
@@ -1,12 +0,0 @@
-
- def create_model(opt):
-     model = None
-     print(opt.model)
-     if opt.model == 'pix2pix':
-         from .pix2pix_model import Pix2PixModel
-         model = Pix2PixModel()
-     else:
-         raise ValueError("Model [%s] not recognized." % opt.model)
-     model.initialize(opt)
-     print("model [%s] was created" % (model.name()))
-     return model
src/photosketch/models/networks.py DELETED
@@ -1,557 +0,0 @@
- import torch
- import torch.nn as nn
- from torch.nn import init
- import functools
- from torch.autograd import Variable
- from torch.optim import lr_scheduler
- import numpy as np
- ###############################################################################
- # Functions
- ###############################################################################
-
-
- def weights_init_normal(m):
-     classname = m.__class__.__name__
-     # print(classname)
-     if classname.find('Conv') != -1:
-         init.normal_(m.weight.data, 0.0, 0.02)
-     elif classname.find('Linear') != -1:
-         init.normal_(m.weight.data, 0.0, 0.02)
-     elif classname.find('BatchNorm2d') != -1:
-         init.normal_(m.weight.data, 1.0, 0.02)
-         init.constant_(m.bias.data, 0.0)
-
-
- def weights_init_xavier(m):
-     classname = m.__class__.__name__
-     # print(classname)
-     if classname.find('Conv') != -1:
-         init.xavier_normal_(m.weight.data, gain=0.02)
-     elif classname.find('Linear') != -1:
-         init.xavier_normal_(m.weight.data, gain=0.02)
-     elif classname.find('BatchNorm2d') != -1:
-         init.normal_(m.weight.data, 1.0, 0.02)
-         init.constant_(m.bias.data, 0.0)
-
-
- def weights_init_kaiming(m):
-     classname = m.__class__.__name__
-     # print(classname)
-     if classname.find('Conv') != -1:
-         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
-     elif classname.find('Linear') != -1:
-         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
-     elif classname.find('BatchNorm2d') != -1:
-         init.normal_(m.weight.data, 1.0, 0.02)
-         init.constant_(m.bias.data, 0.0)
-
-
- def weights_init_orthogonal(m):
-     classname = m.__class__.__name__
-     print(classname)
-     if classname.find('Conv') != -1:
-         init.orthogonal_(m.weight.data, gain=1)
-     elif classname.find('Linear') != -1:
-         init.orthogonal_(m.weight.data, gain=1)
-     elif classname.find('BatchNorm2d') != -1:
-         init.normal_(m.weight.data, 1.0, 0.02)
-         init.constant_(m.bias.data, 0.0)
-
-
- def init_weights(net, init_type='normal'):
-     print('initialization method [%s]' % init_type)
-     if init_type == 'normal':
-         net.apply(weights_init_normal)
-     elif init_type == 'xavier':
-         net.apply(weights_init_xavier)
-     elif init_type == 'kaiming':
-         net.apply(weights_init_kaiming)
-     elif init_type == 'orthogonal':
-         net.apply(weights_init_orthogonal)
-     else:
-         raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
-
-
- def get_norm_layer(norm_type='instance'):
-     if norm_type == 'batch':
-         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
-     elif norm_type == 'instance':
-         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
-     elif norm_type == 'none':
-         norm_layer = None
-     else:
-         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
-     return norm_layer
-
-
- def get_scheduler(optimizer, opt):
-     if opt.lr_policy == 'lambda':
-         def lambda_rule(epoch):
-             lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
-             return lr_l
-         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
-     elif opt.lr_policy == 'step':
-         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
-     elif opt.lr_policy == 'plateau':
-         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
-     else:
-         return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
-     return scheduler
-
-
- def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal'):
-     netG = None
-     norm_layer = get_norm_layer(norm_type=norm)
-
-     if which_model_netG == 'resnet_9blocks':
-         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
-     elif which_model_netG == 'resnet_6blocks':
-         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
-     elif which_model_netG == 'unet_128':
-         netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
-     elif which_model_netG == 'unet_256':
-         netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
-     else:
-         raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
-
-     init_weights(netG, init_type=init_type)
-     return netG
-
-
- def define_D(input_nc, ndf, which_model_netD,
-              n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal'):
-     netD = None
-     norm_layer = get_norm_layer(norm_type=norm)
-
-     if which_model_netD == 'basic':
-         netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
-     elif which_model_netD == 'n_layers':
-         netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
-     elif which_model_netD == 'pixel':
-         netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
-     elif which_model_netD == 'global':
-         netD = GlobalDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
-     elif which_model_netD == 'global_np':
-         netD = GlobalNPDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
-     else:
-         raise NotImplementedError('Discriminator model name [%s] is not recognized' %
-                                   which_model_netD)
-     init_weights(netD, init_type=init_type)
-     return netD
-
-
- def print_network(net):
-     num_params = 0
-     for param in net.parameters():
-         num_params += param.numel()
-     print(net)
-     print('Total number of parameters: %d' % num_params)
-
-
- ##############################################################################
- # Classes
- ##############################################################################
-
-
- # Defines the GAN loss which uses either LSGAN or the regular GAN.
- # When LSGAN is used, it is basically the same as MSELoss,
- # but it abstracts away the need to create the target label tensor
- # that has the same size as the input
- class GANLoss(nn.Module):
-     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
-                  device='cpu'):
163
- super(GANLoss, self).__init__()
164
- self.device = device
165
- self.real_label = target_real_label
166
- self.fake_label = target_fake_label
167
- self.real_label_var = None
168
- self.fake_label_var = None
169
- if use_lsgan:
170
- self.loss = nn.MSELoss().to(device)
171
- else:
172
- self.loss = nn.BCELoss().to(device)
173
-
174
- def get_target_tensor(self, input, target_is_real):
175
- target_tensor = None
176
- if target_is_real:
177
- create_label = ((self.real_label_var is None) or
178
- (self.real_label_var.numel() != input.numel()))
179
- if create_label:
180
- self.real_label_var = torch.full(input.size(), self.real_label, requires_grad=False, device=self.device)
181
- target_tensor = self.real_label_var
182
- else:
183
- create_label = ((self.fake_label_var is None) or
184
- (self.fake_label_var.numel() != input.numel()))
185
- if create_label:
186
- self.fake_label_var = torch.full(input.size(), self.fake_label, requires_grad=False, device=self.device)
187
- target_tensor = self.fake_label_var
188
- return target_tensor
189
-
190
- def __call__(self, input, target_is_real):
191
- target_tensor = self.get_target_tensor(input, target_is_real)
192
- return self.loss(input, target_tensor)
193
-
194
-
195
- # Defines the generator that consists of Resnet blocks between a few
196
- # downsampling/upsampling operations.
197
- # Code and idea originally from Justin Johnson's architecture.
198
- # https://github.com/jcjohnson/fast-neural-style/
199
- class ResnetGenerator(nn.Module):
200
- def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
201
- assert(n_blocks >= 0)
202
- super(ResnetGenerator, self).__init__()
203
- self.input_nc = input_nc
204
- self.output_nc = output_nc
205
- self.ngf = ngf
206
- if type(norm_layer) == functools.partial:
207
- use_bias = norm_layer.func == nn.InstanceNorm2d
208
- else:
209
- use_bias = norm_layer == nn.InstanceNorm2d
210
-
211
- model = [nn.ReflectionPad2d(3),
212
- nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
213
- bias=use_bias),
214
- norm_layer(ngf),
215
- nn.ReLU(True)]
216
-
217
- n_downsampling = 2
218
- for i in range(n_downsampling):
219
- mult = 2**i
220
- model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
221
- stride=2, padding=1, bias=use_bias),
222
- norm_layer(ngf * mult * 2),
223
- nn.ReLU(True)]
224
-
225
- mult = 2**n_downsampling
226
- for i in range(n_blocks):
227
- model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
228
-
229
- for i in range(n_downsampling):
230
- mult = 2**(n_downsampling - i)
231
- model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
232
- kernel_size=3, stride=2,
233
- padding=1, output_padding=1,
234
- bias=use_bias),
235
- norm_layer(int(ngf * mult / 2)),
236
- nn.ReLU(True)]
237
- model += [nn.ReflectionPad2d(3)]
238
- model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
239
- model += [nn.Tanh()]
240
-
241
- self.model = nn.Sequential(*model)
242
-
243
- def forward(self, input):
244
- return self.model(input)
245
-
246
-
247
- # Define a resnet block
248
- class ResnetBlock(nn.Module):
249
- def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
250
- super(ResnetBlock, self).__init__()
251
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
252
-
253
- def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
254
- conv_block = []
255
- p = 0
256
- if padding_type == 'reflect':
257
- conv_block += [nn.ReflectionPad2d(1)]
258
- elif padding_type == 'replicate':
259
- conv_block += [nn.ReplicationPad2d(1)]
260
- elif padding_type == 'zero':
261
- p = 1
262
- else:
263
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
264
-
265
- conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
266
- norm_layer(dim),
267
- nn.ReLU(True)]
268
- if use_dropout:
269
- conv_block += [nn.Dropout(0.5)]
270
- else:
271
- conv_block += [nn.Dropout(0)]
272
-
273
- p = 0
274
- if padding_type == 'reflect':
275
- conv_block += [nn.ReflectionPad2d(1)]
276
- elif padding_type == 'replicate':
277
- conv_block += [nn.ReplicationPad2d(1)]
278
- elif padding_type == 'zero':
279
- p = 1
280
- else:
281
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
282
- conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
283
- norm_layer(dim)]
284
-
285
- return nn.Sequential(*conv_block)
286
-
287
- def forward(self, x):
288
- out = x + self.conv_block(x)
289
- return out
290
-
291
- # Defines the Unet generator.
292
- # |num_downs|: number of downsamplings in UNet. For example,
293
- # if |num_downs| == 7, image of size 128x128 will become of size 1x1
294
- # at the bottleneck
295
- class UnetGenerator(nn.Module):
296
- def __init__(self, input_nc, output_nc, num_downs, ngf=64,
297
- norm_layer=nn.BatchNorm2d, use_dropout=False):
298
- super(UnetGenerator, self).__init__()
299
-
300
- # construct unet structure
301
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
302
- for i in range(num_downs - 5):
303
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
304
- unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
305
- unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
306
- unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
307
- unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
308
-
309
- self.model = unet_block
310
-
311
- def forward(self, input):
312
- return self.model(input)
313
-
314
-
315
- # Defines the submodule with skip connection.
316
- # X -------------------identity---------------------- X
317
- # |-- downsampling -- |submodule| -- upsampling --|
318
- class UnetSkipConnectionBlock(nn.Module):
319
- def __init__(self, outer_nc, inner_nc, input_nc=None,
320
- submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
321
- super(UnetSkipConnectionBlock, self).__init__()
322
- self.outermost = outermost
323
- if type(norm_layer) == functools.partial:
324
- use_bias = norm_layer.func == nn.InstanceNorm2d
325
- else:
326
- use_bias = norm_layer == nn.InstanceNorm2d
327
- if input_nc is None:
328
- input_nc = outer_nc
329
- downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
330
- stride=2, padding=1, bias=use_bias)
331
- downrelu = nn.LeakyReLU(0.2, True)
332
- downnorm = norm_layer(inner_nc)
333
- uprelu = nn.ReLU(True)
334
- upnorm = norm_layer(outer_nc)
335
-
336
- if outermost:
337
- upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
338
- kernel_size=4, stride=2,
339
- padding=1)
340
- down = [downconv]
341
- up = [uprelu, upconv, nn.Tanh()]
342
- model = down + [submodule] + up
343
- elif innermost:
344
- upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
345
- kernel_size=4, stride=2,
346
- padding=1, bias=use_bias)
347
- down = [downrelu, downconv]
348
- up = [uprelu, upconv, upnorm]
349
- model = down + up
350
- else:
351
- upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
352
- kernel_size=4, stride=2,
353
- padding=1, bias=use_bias)
354
- down = [downrelu, downconv, downnorm]
355
- up = [uprelu, upconv, upnorm]
356
-
357
- if use_dropout:
358
- model = down + [submodule] + up + [nn.Dropout(0.5)]
359
- else:
360
- model = down + [submodule] + up
361
-
362
- self.model = nn.Sequential(*model)
363
-
364
- def forward(self, x):
365
- if self.outermost:
366
- return self.model(x)
367
- else:
368
- return torch.cat([x, self.model(x)], 1)
369
-
370
-
371
- # Defines the PatchGAN discriminator with the specified arguments.
372
- class NLayerDiscriminator(nn.Module):
373
- def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
374
- super(NLayerDiscriminator, self).__init__()
375
- if type(norm_layer) == functools.partial:
376
- use_bias = norm_layer.func == nn.InstanceNorm2d
377
- else:
378
- use_bias = norm_layer == nn.InstanceNorm2d
379
-
380
- # 256
381
- kw = 4
382
- padw = 1
383
- sequence = [
384
- nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
385
- nn.LeakyReLU(0.2, True)
386
- ]
387
-
388
- # 128
389
-
390
- nf_mult = 1
391
- nf_mult_prev = 1
392
- for n in range(1, n_layers):
393
- nf_mult_prev = nf_mult
394
- nf_mult = min(2**n, 8)
395
- sequence += [
396
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
397
- kernel_size=kw, stride=2, padding=padw, bias=use_bias),
398
- norm_layer(ndf * nf_mult),
399
- nn.LeakyReLU(0.2, True)
400
- ]
401
- # 64
402
- # 32
403
-
404
- nf_mult_prev = nf_mult
405
- nf_mult = min(2**n_layers, 8)
406
- sequence += [
407
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
408
- kernel_size=kw, stride=1, padding=padw, bias=use_bias),
409
- norm_layer(ndf * nf_mult),
410
- nn.LeakyReLU(0.2, True)
411
- ]
412
- # 31
413
-
414
- sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
415
- # 30
416
-
417
- if use_sigmoid:
418
- sequence += [nn.Sigmoid()]
419
-
420
- self.model = nn.Sequential(*sequence)
421
-
422
- def forward(self, input):
423
- return self.model(input)
424
-
425
- class GlobalDiscriminator(nn.Module):
426
- def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
427
- super(GlobalDiscriminator, self).__init__()
428
-
429
- if type(norm_layer) == functools.partial:
430
- use_bias = norm_layer.func == nn.InstanceNorm2d
431
- else:
432
- use_bias = norm_layer == nn.InstanceNorm2d
433
-
434
- # 256
435
- kw = 4
436
- padw = 1
437
- sequence = [
438
- nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
439
- nn.LeakyReLU(0.2, True)
440
- ]
441
-
442
- # 128
443
- nf_mult = 1
444
- nf_mult_prev = 1
445
- for n in range(1, n_layers):
446
- nf_mult_prev = nf_mult
447
- nf_mult = min(2**n, 8)
448
- sequence += [
449
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
450
- kernel_size=kw, stride=2, padding=padw, bias=use_bias),
451
- norm_layer(ndf * nf_mult),
452
- nn.LeakyReLU(0.2, True)
453
- ]
454
- # 64
455
- # 32
456
-
457
- nf_mult_prev = nf_mult
458
- nf_mult = min(2**n_layers, 8)
459
- sequence += [
460
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
461
- kernel_size=kw, stride=2, padding=padw, bias=use_bias),
462
- norm_layer(ndf * nf_mult),
463
- nn.LeakyReLU(0.2, True)
464
- ]
465
- # 16
466
-
467
- sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=2, padding=0)]
468
- sequence += [nn.Conv2d(1, 1, kernel_size=7, stride=1, padding=0)]
469
-
470
- if use_sigmoid:
471
- sequence += [nn.Sigmoid()]
472
-
473
- self.model = nn.Sequential(*sequence)
474
-
475
- def forward(self, input):
476
- return self.model(input)
477
-
478
- class GlobalNPDiscriminator(nn.Module):
479
- # no padding
480
- def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
481
- super(GlobalNPDiscriminator, self).__init__()
482
-
483
- if type(norm_layer) == functools.partial:
484
- use_bias = norm_layer.func == nn.InstanceNorm2d
485
- else:
486
- use_bias = norm_layer == nn.InstanceNorm2d
487
-
488
- # 256
489
- kw = [8, 3, 4]
490
- padw = 0
491
- sequence = [
492
- nn.Conv2d(input_nc, ndf, kernel_size=kw[0], stride=2, padding=padw),
493
- nn.LeakyReLU(0.2, True)
494
- ]
495
- # 125
496
- nf_mult = 1
497
- nf_mult_prev = 1
498
- for n in range(1, n_layers):
499
- nf_mult_prev = nf_mult
500
- nf_mult = min(2**n, 8)
501
- sequence += [
502
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
503
- kernel_size=kw[n], stride=2, padding=padw, bias=use_bias),
504
- norm_layer(ndf * nf_mult),
505
- nn.LeakyReLU(0.2, True)
506
- ]
507
- # 62
508
- # 30
509
-
510
- nf_mult_prev = nf_mult
511
- nf_mult = min(2**n_layers, 8)
512
- sequence += [
513
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
514
- kernel_size=4, stride=2, padding=padw, bias=use_bias),
515
- norm_layer(ndf * nf_mult),
516
- nn.LeakyReLU(0.2, True)
517
- ]
518
- # 14
519
-
520
- sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=2, padding=0)]
521
- # 6
522
- sequence += [nn.Conv2d(1, 1, kernel_size=6, stride=1, padding=0, bias=use_bias)]
523
- # 1
524
-
525
- if use_sigmoid:
526
- sequence += [nn.Sigmoid()]
527
-
528
- self.model = nn.Sequential(*sequence)
529
-
530
- def forward(self, input):
531
- return self.model(input)
532
-
533
- class PixelDiscriminator(nn.Module):
534
- def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
535
- super(PixelDiscriminator, self).__init__()
536
-
537
- if type(norm_layer) == functools.partial:
538
- use_bias = norm_layer.func == nn.InstanceNorm2d
539
- else:
540
- use_bias = norm_layer == nn.InstanceNorm2d
541
-
542
- self.net = [
543
- nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
544
- nn.LeakyReLU(0.2, True),
545
- nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
546
- norm_layer(ndf * 2),
547
- nn.LeakyReLU(0.2, True),
548
- nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
549
-
550
- if use_sigmoid:
551
- self.net.append(nn.Sigmoid())
552
-
553
- self.net = nn.Sequential(*self.net)
554
-
555
- def forward(self, input):
556
- return self.net(input)
557
-
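The `'lambda'` policy in the deleted `get_scheduler` keeps the learning-rate multiplier at 1.0 for the first `niter` epochs and then decays it linearly to zero over `niter_decay` epochs. A minimal standalone sketch of that rule, with illustrative default values (the real values come from the training options):

```python
def lambda_rule(epoch, epoch_count=1, niter=100, niter_decay=100):
    # Multiplier applied to the base LR: flat at 1.0 while
    # epoch + epoch_count <= niter, then linear decay to 0.0.
    return 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)

start, mid, end = lambda_rule(0), lambda_rule(150), lambda_rule(199)
```

With `epoch_count=1`, the multiplier reaches exactly 0.0 at epoch `niter + niter_decay - 1`, matching `LambdaLR`'s behavior when driven by this closure.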
src/photosketch/models/pix2pix_model.py DELETED
@@ -1,174 +0,0 @@
- import numpy as np
- import torch
- import os
- from collections import OrderedDict
- from torch.autograd import Variable
- from ..util import util
- from ..util.image_pool import ImagePool
- from .base_model import BaseModel
- from . import networks
- from PIL import Image
-
-
- class Pix2PixModel(BaseModel):
-     def name(self):
-         return 'Pix2PixModel'
-
-     def initialize(self, opt):
-         BaseModel.initialize(self, opt)
-         self.isTrain = opt.isTrain
-
-         # load/define networks
-         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
-                                       opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type)
-
-         if self.isTrain:
-             use_sigmoid = opt.no_lsgan
-             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
-                                           opt.which_model_netD,
-                                           opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type)
-
-         if not self.isTrain or opt.continue_train:
-             self.load_network(self.netG, 'G', opt.which_epoch)
-             if self.isTrain:
-                 self.load_network(self.netD, 'D', opt.which_epoch)
-
-         self.netG = self.netG.to(self.device)
-         if self.isTrain:
-             self.netD = self.netD.to(self.device)
-
-         if self.isTrain:
-             self.fake_AB_pool = ImagePool(opt.pool_size)
-             self.old_lr = opt.lr
-             # define loss functions
-             self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, device=self.device)
-             self.criterionL1 = torch.nn.L1Loss().to(self.device)
-
-             # initialize optimizers
-             self.schedulers = []
-             self.optimizers = []
-             self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
-                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
-             self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
-                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
-             self.optimizers.append(self.optimizer_G)
-             self.optimizers.append(self.optimizer_D)
-             for optimizer in self.optimizers:
-                 self.schedulers.append(networks.get_scheduler(optimizer, opt))
-
-         print('---------- Networks initialized -------------')
-         networks.print_network(self.netG)
-         if self.isTrain:
-             networks.print_network(self.netD)
-         print('-----------------------------------------------')
-
-     def set_input(self, input):
-         AtoB = self.opt.which_direction == 'AtoB'
-         self.input_A = input['A' if AtoB else 'B'].to(self.device)
-         self.input_B = input['B' if AtoB else 'A'].to(self.device)
-         self.image_paths = input['A_paths' if AtoB else 'B_paths']
-         if 'w' in input:
-             self.input_w = input['w']
-         if 'h' in input:
-             self.input_h = input['h']
-
-     def forward(self):
-         self.real_A = self.input_A
-         self.fake_B = self.netG(self.real_A)
-         self.real_B = self.input_B
-
-     # no backprop gradients
-     def test(self):
-         with torch.no_grad():
-             self.forward()
-
-     # get image paths
-     def get_image_paths(self):
-         return self.image_paths
-
-     def backward_D(self):
-         # Fake
-         # stop backprop to the generator by detaching fake_B
-         fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).detach())
-         pred_fake = self.netD(fake_AB.detach())
-         self.loss_D_fake = self.criterionGAN(pred_fake, False)
-
-         # Real
-         n = self.real_B.shape[1]
-         loss_D_real_set = torch.empty(n, device=self.device)
-         for i in range(n):
-             sel_B = self.real_B[:, i, :, :].unsqueeze(1)
-             real_AB = torch.cat((self.real_A, sel_B), 1)
-             pred_real = self.netD(real_AB)
-             loss_D_real_set[i] = self.criterionGAN(pred_real, True)
-         self.loss_D_real = torch.mean(loss_D_real_set)
-
-         # Combined loss
-         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 * self.opt.lambda_G
-
-         self.loss_D.backward()
-
-     def backward_G(self):
-         # First, G(A) should fake the discriminator
-         fake_AB = torch.cat((self.real_A, self.fake_B), 1)
-         pred_fake = self.netD(fake_AB)
-         self.loss_G_GAN = self.criterionGAN(pred_fake, True) * self.opt.lambda_G
-
-         # Second, G(A) = B
-         n = self.real_B.shape[1]
-         fake_B_expand = self.fake_B.expand(-1, n, -1, -1)
-         L1 = torch.abs(fake_B_expand - self.real_B)
-         L1 = L1.view(-1, n, self.real_B.shape[2]*self.real_B.shape[3])
-         L1 = torch.mean(L1, 2)
-         min_L1, min_idx = torch.min(L1, 1)
-         self.loss_G_L1 = torch.mean(min_L1) * self.opt.lambda_A
-         self.min_idx = min_idx
-
-         self.loss_G = self.loss_G_GAN + self.loss_G_L1
-
-         self.loss_G.backward()
-
-     def optimize_parameters(self):
-         self.forward()
-         # update D
-         self.set_requires_grad(self.netD, True)
-         self.optimizer_D.zero_grad()
-         self.backward_D()
-         self.optimizer_D.step()
-
-         # update G
-         self.set_requires_grad(self.netD, False)
-         self.optimizer_G.zero_grad()
-         self.backward_G()
-         self.optimizer_G.step()
-
-     def get_current_errors(self):
-         return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
-                             ('G_L1', self.loss_G_L1.item()),
-                             ('D_real', self.loss_D_real.item()),
-                             ('D_fake', self.loss_D_fake.item())
-                             ])
-
-     def get_current_visuals(self):
-         real_A = util.tensor2im(self.real_A.detach())
-         fake_B = util.tensor2im(self.fake_B.detach())
-         if self.isTrain:
-             sel_B = self.real_B[:, self.min_idx[0], :, :]
-         else:
-             sel_B = self.real_B[:, 0, :, :]
-         real_B = util.tensor2im(sel_B.unsqueeze(1).detach())
-         return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])
-
-     def save(self, label):
-         self.save_network(self.netG, 'G', label)
-         self.save_network(self.netD, 'D', label)
-
-     def write_image(self, out_dir):
-         image_numpy = self.fake_B.detach()[0][0].cpu().float().numpy()
-         image_numpy = (image_numpy + 1) / 2.0 * 255.0
-         image_pil = Image.fromarray(image_numpy.astype(np.uint8))
-         image_pil = image_pil.resize((self.input_w[0], self.input_h[0]), Image.BICUBIC)
-         name, _ = os.path.splitext(os.path.basename(self.image_paths[0]))
-         out_path = os.path.join(out_dir, name + self.opt.suffix + '.png')
-         image_pil.save(out_path)
-
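The deleted `backward_G` computes the generator's L1 term against `n` candidate ground-truth sketches and keeps only the best match (the min over candidates), which is how the model tolerates multiple valid sketches per photo. A pure-Python sketch of that min-over-candidates reduction on flattened toy data (values are illustrative, not from the repo):

```python
def min_l1(fake, reals):
    """Mean absolute difference against each candidate; return the
    smallest loss and the index of the winning candidate, mirroring
    torch.min over the per-candidate L1 means in backward_G."""
    losses = [sum(abs(f - r) for f, r in zip(fake, real)) / len(fake)
              for real in reals]
    best = min(range(len(losses)), key=losses.__getitem__)
    return losses[best], best

fake = [0.2, 0.4, 0.6]                      # one flattened fake sketch
reals = [[0.0, 0.0, 0.0], [0.2, 0.5, 0.6]]  # two candidate ground truths
loss, idx = min_l1(fake, reals)
```

The winning index is also what `get_current_visuals` uses (`self.min_idx`) to display the ground-truth sketch the generator was actually pulled toward.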
src/photosketch/options/__init__.py DELETED
File without changes
src/photosketch/options/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (187 Bytes)
 
src/photosketch/options/__pycache__/base_options.cpython-312.pyc DELETED
Binary file (9.53 kB)
 
src/photosketch/options/__pycache__/inference_options.cpython-312.pyc DELETED
Binary file (4.04 kB)
 
src/photosketch/options/base_options.py DELETED
@@ -1,83 +0,0 @@
- import argparse
- import os
- from ..util import util
- import torch
-
-
- class BaseOptions():
-     def __init__(self):
-         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-         self.initialized = False
-
-     def initialize(self):
-         self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
-         self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
-         self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
-         self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
-         self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
-         self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
-         self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
-         self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
-         self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
-         self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
-         self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
-         self.parser.add_argument('--no-cuda', action='store_true', default=False, help='disable CUDA training (please use CUDA_VISIBLE_DEVICES to select GPU)')
-         self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
-         self.parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
-         self.parser.add_argument('--model', type=str, default='cycle_gan',
-                                  help='chooses which model to use. cycle_gan, pix2pix, test')
-         self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
-         self.parser.add_argument('--nThreads', default=6, type=int, help='# threads for loading data')
-         self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
-         self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
-         self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
-         self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
-         self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
-         self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
-         self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
-         self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
-         self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
-         self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
-         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
-         self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
-         self.parser.add_argument('--render_dir', type=str, default='sketch-rendered')
-         self.parser.add_argument('--aug_folder', type=str, default='width-5')
-         self.parser.add_argument('--stroke_dir', type=str, default='')
-         self.parser.add_argument('--crop', action='store_true')
-         self.parser.add_argument('--rotate', action='store_true')
-         self.parser.add_argument('--color_jitter', action='store_true')
-         self.parser.add_argument('--stroke_no_couple', action='store_true', help='')
-         self.parser.add_argument('--pretrain_path', type=str, default='')
-         self.parser.add_argument('--nGT', type=int, default=5)
-         self.parser.add_argument('--rot_int_max', type=int, default=3)
-         self.parser.add_argument('--jitter_amount', type=float, default=0.02)
-         self.parser.add_argument('--inverse_gamma', action='store_true')
-         self.parser.add_argument('--img_mean', type=float, nargs='+')
-         self.parser.add_argument('--img_std', type=float, nargs='+')
-         self.parser.add_argument('--lst_file', type=str)
-         self.initialized = True
-
-     def parse(self):
-         if not self.initialized:
-             self.initialize()
-         self.opt = self.parser.parse_args()
-         self.opt.isTrain = self.isTrain  # train or test
-
-         self.opt.use_cuda = not self.opt.no_cuda and torch.cuda.is_available()
-         args = vars(self.opt)
-
-         print('------------ Options -------------')
-         for k, v in sorted(args.items()):
-             print('%s: %s' % (str(k), str(v)))
-         print('-------------- End ----------------')
-
-         # save to the disk
-         expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
-         util.mkdirs(expr_dir)
-         file_name = os.path.join(expr_dir, 'opt.txt')
-         with open(file_name, 'wt') as opt_file:
-             opt_file.write('------------ Options -------------\n')
-             for k, v in sorted(args.items()):
-                 opt_file.write('%s: %s\n' % (str(k), str(v)))
-             opt_file.write('-------------- End ----------------\n')
-         return self.opt
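The deleted `BaseOptions` is a thin wrapper over `argparse`: one parser with defaults, then a `parse()` that post-processes the namespace. A self-contained miniature of the same pattern with a few of the flags above (subset chosen for illustration; the real class defines many more):

```python
import argparse

# Same formatter the deleted class used, so --help shows defaults.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--norm', type=str, default='instance', help='instance or batch normalization')
parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')

# Passing an explicit argv list instead of reading sys.argv.
opt = parser.parse_args(['--norm', 'batch', '--no_dropout'])
```

Unspecified flags keep their defaults (`opt.batchSize == 1` here), which is why the repo can ship a full `opt.json` snapshot of every option after parsing.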
src/photosketch/options/inference_options.py DELETED
@@ -1,43 +0,0 @@
- from .base_options import BaseOptions
- from argparse import Namespace
- import os
- from ..util import util
- import torch
-
-
- class InferenceOptions(BaseOptions):
-     def initialize(self):
-         BaseOptions.initialize(self)
-         self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
-         self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
-         self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
-         self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
-         self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
-         self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
-         self.parser.add_argument('--file_name', type=str, default='')
-         self.parser.add_argument('--suffix', type=str, default='')
-         self.isTrain = False
-
-     def parse(self, args_dict):
-         self.opt = Namespace(**args_dict)
-         self.opt.max_dataset_size = float(self.opt.max_dataset_size)
-         self.isTrain = False
-         self.opt.isTrain = self.isTrain
-         self.opt.use_cuda = not self.opt.no_cuda and torch.cuda.is_available()
-         args = vars(self.opt)
-
-         print('------------ Options -------------')
-         for k, v in sorted(args.items()):
-             print('%s: %s' % (str(k), str(v)))
-         print('-------------- End ----------------')
-
-         # save to the disk
-         expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
-         util.mkdirs(expr_dir)
-         file_name = os.path.join(expr_dir, 'inference_config.txt')
-         with open(file_name, 'wt') as opt_file:
-             opt_file.write('------------ Options -------------\n')
-             for k, v in sorted(args.items()):
-                 opt_file.write('%s: %s\n' % (str(k), str(v)))
-             opt_file.write('-------------- End ----------------\n')
-         return self.opt
src/photosketch/options/opt.json DELETED
@@ -1,56 +0,0 @@
- {
-     "aspect_ratio": 1.0,
-     "aug_folder": "width-5",
-     "batchSize": 1,
-     "checkpoints_dir": "",
-     "color_jitter": false,
-     "crop": false,
-     "dataroot": "examples/",
-     "dataset_mode": "test_dir",
-     "display_id": 1,
-     "display_port": 8097,
-     "display_server": "http://localhost",
-     "display_winsize": 256,
-     "file_name": "",
-     "fineSize": 256,
-     "how_many": 50,
-     "img_mean": null,
-     "img_std": null,
-     "init_type": "normal",
-     "input_nc": 3,
-     "inverse_gamma": false,
-     "isTrain": false,
-     "jitter_amount": 0.02,
-     "loadSize": 286,
-     "lst_file": null,
-     "max_dataset_size": "inf",
-     "model": "pix2pix",
-     "nGT": 5,
-     "nThreads": 6,
-     "n_layers_D": 3,
-     "name": "",
-     "ndf": 64,
-     "ngf": 64,
-     "no_cuda": false,
-     "no_dropout": true,
-     "no_flip": false,
-     "norm": "batch",
-     "ntest": "inf",
-     "output_nc": 1,
-     "phase": "test",
-     "pretrain_path": "",
-     "render_dir": "sketch-rendered",
-     "resize_or_crop": "resize_and_crop",
-     "results_dir": "",
-     "rot_int_max": 3,
-     "rotate": false,
-     "serial_batches": false,
-     "stroke_dir": "",
-     "stroke_no_couple": false,
-     "suffix": "",
-     "use_cuda": true,
-     "which_direction": "AtoB",
-     "which_epoch": "latest",
-     "which_model_netD": "basic",
-     "which_model_netG": "resnet_9blocks"
- }
src/photosketch/options/test_options.py DELETED
@@ -1,15 +0,0 @@
- from .base_options import BaseOptions
-
-
- class TestOptions(BaseOptions):
-     def initialize(self):
-         BaseOptions.initialize(self)
-         self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
-         self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
-         self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
-         self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
-         self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
-         self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
-         self.parser.add_argument('--file_name', type=str, default='')
-         self.parser.add_argument('--suffix', type=str, default='')
-         self.isTrain = False
src/photosketch/options/train_options.py DELETED
@@ -1,31 +0,0 @@
- from .base_options import BaseOptions
-
-
- class TrainOptions(BaseOptions):
-     def initialize(self):
-         BaseOptions.initialize(self)
-         self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
-         self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
-         self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
-         self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
-         self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
-         self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
-         self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
-         self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
-         self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
-         self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
-         self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
-         self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
-         self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
-         self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
-         self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
-         self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
-         self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
-         self.parser.add_argument('--lambda_stroke', type=float, default=100, help='weight for stroke generation')
-         self.parser.add_argument('--lambda_G', type=float, default=1, help='weight for GAN')
-         self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
-         self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
-         self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
-         self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
-         self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
-         self.isTrain = True
src/photosketch/pretrained/inference_config.txt DELETED
@@ -1,56 +0,0 @@
- ------------ Options -------------
- aspect_ratio: 1.0
- aug_folder: width-5
- batchSize: 1
- checkpoints_dir: src/photosketch
- color_jitter: False
- crop: False
- dataroot: src/photosketch/inputs
- dataset_mode: test_dir
- display_id: 1
- display_port: 8097
- display_server: http://localhost
- display_winsize: 256
- file_name:
- fineSize: 256
- how_many: 50
- img_mean: None
- img_std: None
- init_type: normal
- input_nc: 3
- inverse_gamma: False
- isTrain: False
- jitter_amount: 0.02
- loadSize: 286
- lst_file: None
- max_dataset_size: inf
- model: pix2pix
- nGT: 5
- nThreads: 6
- n_layers_D: 3
- name: pretrained
- ndf: 64
- ngf: 64
- no_cuda: False
- no_dropout: True
- no_flip: False
- norm: batch
- ntest: inf
- output_nc: 1
- phase: test
- pretrain_path:
- render_dir: sketch-rendered
- resize_or_crop: resize_and_crop
- results_dir: src/photosketch/outputs
- rot_int_max: 3
- rotate: False
- serial_batches: False
- stroke_dir:
- stroke_no_couple: False
- suffix:
- use_cuda: True
- which_direction: AtoB
- which_epoch: latest
- which_model_netD: basic
- which_model_netG: resnet_9blocks
- -------------- End ----------------
src/photosketch/pretrained/latest_net_D.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:1b40d1a853a9b2a4b4f1b618f095728651dc8c119dfa093eb84ee7884dfc2d44
- size 10896061
src/photosketch/pretrained/latest_net_G.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:442a5cf113b4b544698c1ea71dff20890685b821d0ae853b4442f9ed2fb98f5b
- size 45566034
src/photosketch/scripts/test.cmd DELETED
@@ -1,19 +0,0 @@
- @set "dataDir="
-
- python test.py ^
- --name default ^
- --dataroot %dataDir%\ContourDrawing\ ^
- --phase val ^
- --how_many 100 ^
- --checkpoints_dir %dataDir%\Exp\PhotoSketch\Checkpoints\ ^
- --results_dir %dataDir%\Exp\PhotoSketch\Results\ ^
- --model pix2pix ^
- --which_direction AtoB ^
- --dataset_mode 1_to_n ^
- --norm batch ^
- --input_nc 3 ^
- --output_nc 1 ^
- --which_model_netG resnet_9blocks ^
- --which_model_netD global_np ^
- --aug_folder width-5 ^
- --no_dropout ^
src/photosketch/scripts/test.sh DELETED
@@ -1,19 +0,0 @@
- dataDir=
-
- python test.py \
- --name default \
- --dataroot ${dataDir}/ContourDrawing/ \
- --phase val \
- --how_many 100 \
- --checkpoints_dir ${dataDir}/Exp/PhotoSketch/Checkpoints/ \
- --results_dir ${dataDir}/Exp/PhotoSketch/Results/ \
- --model pix2pix \
- --which_direction AtoB \
- --dataset_mode 1_to_n \
- --norm batch \
- --input_nc 3 \
- --output_nc 1 \
- --which_model_netG resnet_9blocks \
- --which_model_netD global_np \
- --aug_folder width-5 \
- --no_dropout \
src/photosketch/scripts/test_pretrained.cmd DELETED
@@ -1,15 +0,0 @@
- @set "dataDir="
-
- python test_pretrained.py ^
- --name pretrained ^
- --dataset_mode test_dir ^
- --dataroot examples\ ^
- --results_dir %dataDir%\Exp\PhotoSketch\Results\ ^
- --checkpoints_dir %dataDir%\Exp\PhotoSketch\Checkpoints\ ^
- --model pix2pix ^
- --which_direction AtoB ^
- --norm batch ^
- --input_nc 3 ^
- --output_nc 1 ^
- --which_model_netG resnet_9blocks ^
- --no_dropout ^
src/photosketch/scripts/test_pretrained.sh DELETED
@@ -1,15 +0,0 @@
- dataDir=/home/fgirella/sketch2img/PhotoSketch
-
- python test_pretrained.py \
- --name pretrained \
- --dataset_mode test_dir \
- --dataroot examples/ \
- --results_dir ${dataDir}/Runs/Results/ \
- --checkpoints_dir ${dataDir}/ \
- --model pix2pix \
- --which_direction AtoB \
- --norm batch \
- --input_nc 3 \
- --output_nc 1 \
- --which_model_netG resnet_9blocks \
- --no_dropout \
src/photosketch/scripts/train.cmd DELETED
@@ -1,22 +0,0 @@
- @set "dataDir="
-
- python train.py ^
- --name default ^
- --dataroot %dataDir%\ContourDrawing\ ^
- --checkpoints_dir %dataDir%\Exp\PhotoSketch\Checkpoints\ ^
- --model pix2pix ^
- --which_direction AtoB ^
- --dataset_mode 1_to_n ^
- --no_lsgan ^
- --norm batch ^
- --pool_size 0 ^
- --output_nc 1 ^
- --which_model_netG resnet_9blocks ^
- --which_model_netD global_np ^
- --batchSize 2 ^
- --lambda_A 200 ^
- --lr 0.0002 ^
- --aug_folder width-5 ^
- --crop --rotate --color_jitter ^
- --niter 400 ^
- --niter_decay 400 ^
src/photosketch/scripts/train.sh DELETED
@@ -1,22 +0,0 @@
- dataDir=
-
- python train.py \
- --name default \
- --dataroot ${dataDir}/ContourDrawing/ \
- --checkpoints_dir ${dataDir}/Exp/PhotoSketch/Checkpoints/ \
- --model pix2pix \
- --which_direction AtoB \
- --dataset_mode 1_to_n \
- --no_lsgan \
- --norm batch \
- --pool_size 0 \
- --output_nc 1 \
- --which_model_netG resnet_9blocks \
- --which_model_netD global_np \
- --batchSize 2 \
- --lambda_A 200 \
- --lr 0.0002 \
- --aug_folder width-5 \
- --crop --rotate --color_jitter \
- --niter 400 \
- --niter_decay 400 \
src/photosketch/test.py DELETED
@@ -1,37 +0,0 @@
- import time
- import os
- from options.test_options import TestOptions
- from data.data_loader import CreateDataLoader
- from models.models import create_model
- from util.visualizer import Visualizer
- from util import html
-
- def main():
-     opt = TestOptions().parse()
-     opt.nThreads = 1   # test code only supports nThreads = 1
-     opt.batchSize = 1  # test code only supports batchSize = 1
-     opt.serial_batches = True  # no shuffle
-     opt.no_flip = True  # no flip
-
-     data_loader = CreateDataLoader(opt)
-     dataset = data_loader.load_data()
-     model = create_model(opt)
-     visualizer = Visualizer(opt)
-     # create website
-     web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
-     webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
-     # test
-     for i, data in enumerate(dataset):
-         if i >= opt.how_many:
-             break
-         model.set_input(data)
-         model.test()
-         visuals = model.get_current_visuals()
-         img_path = model.get_image_paths()
-         print('%04d: process image... %s' % (i, img_path))
-         visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)
-
-     webpage.save()
-
- if __name__ == '__main__':
-     main()
src/photosketch/test_pretrained.py DELETED
@@ -1,29 +0,0 @@
- import os
- from options.test_options import TestOptions
- from data.data_loader import CreateDataLoader
- from models.models import create_model
-
- def main():
-     opt = TestOptions().parse()
-     opt.nThreads = 1   # test code only supports nThreads = 1
-     opt.batchSize = 1  # test code only supports batchSize = 1
-     opt.serial_batches = True  # no shuffle
-     opt.no_flip = True  # no flip
-
-     if not os.path.isdir(opt.results_dir):
-         os.makedirs(opt.results_dir)
-
-     data_loader = CreateDataLoader(opt)
-     dataset = data_loader.load_data()
-     model = create_model(opt)
-
-     # test
-     for i, data in enumerate(dataset):
-         model.set_input(data)
-         img_path = model.get_image_paths()
-         print('Processing %04d (%s)' % (i+1, img_path[0]))
-         model.test()
-         model.write_image(opt.results_dir)
-
- if __name__ == '__main__':
-     main()
src/photosketch/train.py DELETED
@@ -1,61 +0,0 @@
- import time
- from options.train_options import TrainOptions
- from data.data_loader import CreateDataLoader
- from models.models import create_model
- from util.visualizer import Visualizer
-
- import torch, random
- torch.manual_seed(0)
- random.seed(0)
-
- def main():
-     opt = TrainOptions().parse()
-     data_loader = CreateDataLoader(opt)
-     dataset = data_loader.load_data()
-     dataset_size = len(data_loader)
-     print('#training images = %d' % dataset_size)
-
-     model = create_model(opt)
-     visualizer = Visualizer(opt)
-     total_steps = 0
-
-     for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
-         epoch_start_time = time.time()
-         epoch_iter = 0
-
-         for i, data in enumerate(dataset):
-             iter_start_time = time.time()
-             visualizer.reset()
-             total_steps += opt.batchSize
-             epoch_iter += opt.batchSize
-             model.set_input(data)
-             model.optimize_parameters()
-
-             if total_steps % opt.display_freq == 0:
-                 save_result = total_steps % opt.update_html_freq == 0
-                 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
-
-             if total_steps % opt.print_freq == 0:
-                 errors = model.get_current_errors()
-                 t = (time.time() - iter_start_time) / opt.batchSize
-                 visualizer.print_current_errors(epoch, epoch_iter, errors, t)
-                 if opt.display_id > 0:
-                     visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
-
-             if total_steps % opt.save_latest_freq == 0:
-                 print('saving the latest model (epoch %d, total_steps %d)' %
-                       (epoch, total_steps))
-                 model.save('latest')
-
-         if epoch % opt.save_epoch_freq == 0:
-             print('saving the model at the end of epoch %d, iters %d' %
-                   (epoch, total_steps))
-             model.save('latest')
-             model.save(epoch)
-
-         print('End of epoch %d / %d \t Time Taken: %d sec' %
-               (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
-         model.update_learning_rate()
-
- if __name__ == '__main__':
-     main()
src/photosketch/util/.DS_Store DELETED
Binary file (6.15 kB)
 
src/photosketch/util/__init__.py DELETED
File without changes