ali-ghamdan commited on
Commit
31a913a
Β·
1 Parent(s): 9be6adc
This view is limited to 50 files because it contains too many changes. Β  See raw diff
.github/CODEOWNERS ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # See: https://help.github.com/en/articles/about-code-owners
2
+ #
3
+ # Owners will be requested for review when someone opens a pull request.
4
+ * @jantic @alexandrevicenzi
.gitignore ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+
106
+ # DeOldify
107
+ data
108
+ *SymbolicLinks.sh
109
+ *.ipynb_checkpoints*
110
+ ColorizeTraining*[0-9]*.ipynb
111
+ *Colorizer[0-9]*.ipynb
112
+ lesson7-superres*.ipynb
113
+ test.py
114
+ result_images
115
+ *.prof
116
+ *.pth
117
+ video
118
+ /test_images
119
+ deoldify/.ipynb_checkpoints/*-checkpoint.py
120
+ tmp*
.pre-commit-config.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/ambv/black
3
+ rev: stable
4
+ hooks:
5
+ - id: black
6
+ args: [-S]
7
+ language_version: python3.6
.pylintrc ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [MASTER]
2
+
3
+ # A comma-separated list of package or module names from where C extensions may
4
+ # be loaded. Extensions are loading into the active Python interpreter and may
5
+ # run arbitrary code.
6
+ extension-pkg-whitelist=
7
+
8
+ # Add files or directories to the blacklist. They should be base names, not
9
+ # paths.
10
+ ignore=CVS
11
+
12
+ # Add files or directories matching the regex patterns to the blacklist. The
13
+ # regex matches against base names, not paths.
14
+ ignore-patterns=
15
+
16
+ # Python code to execute, usually for sys.path manipulation such as
17
+ # pygtk.require().
18
+ #init-hook='import sys; sys.path.append("./venv/lib/python3.7/site-packages")'
19
+
20
+ # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
21
+ # number of processors available to use.
22
+ jobs=1
23
+
24
+ # Control the amount of potential inferred values when inferring a single
25
+ # object. This can help the performance when dealing with large functions or
26
+ # complex, nested conditions.
27
+ limit-inference-results=100
28
+
29
+ # List of plugins (as comma separated values of python modules names) to load,
30
+ # usually to register additional checkers.
31
+ load-plugins=
32
+
33
+ # Pickle collected data for later comparisons.
34
+ persistent=yes
35
+
36
+ # Specify a configuration file.
37
+ #rcfile=
38
+
39
+ # When enabled, pylint would attempt to guess common misconfiguration and emit
40
+ # user-friendly hints instead of false-positive error messages.
41
+ suggestion-mode=yes
42
+
43
+ # Allow loading of arbitrary C extensions. Extensions are imported into the
44
+ # active Python interpreter and may run arbitrary code.
45
+ unsafe-load-any-extension=no
46
+
47
+
48
+ [MESSAGES CONTROL]
49
+
50
+ # Only show warnings with the listed confidence levels. Leave empty to show
51
+ # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
52
+ confidence=
53
+
54
+ # Disable the message, report, category or checker with the given id(s). You
55
+ # can either give multiple identifiers separated by comma (,) or put this
56
+ # option multiple times (only on the command line, not in the configuration
57
+ # file where it should appear only once). You can also use "--disable=all" to
58
+ # disable everything first and then reenable specific checks. For example, if
59
+ # you want to run only the similarities checker, you can use "--disable=all
60
+ # --enable=similarities". If you want to run only the classes checker, but have
61
+ # no Warning level messages displayed, use "--disable=all --enable=classes
62
+ # --disable=W".
63
+ disable=print-statement,
64
+ parameter-unpacking,
65
+ unpacking-in-except,
66
+ old-raise-syntax,
67
+ backtick,
68
+ long-suffix,
69
+ old-ne-operator,
70
+ old-octal-literal,
71
+ import-star-module-level,
72
+ non-ascii-bytes-literal,
73
+ raw-checker-failed,
74
+ bad-inline-option,
75
+ locally-disabled,
76
+ locally-enabled,
77
+ file-ignored,
78
+ suppressed-message,
79
+ useless-suppression,
80
+ deprecated-pragma,
81
+ use-symbolic-message-instead,
82
+ apply-builtin,
83
+ basestring-builtin,
84
+ buffer-builtin,
85
+ cmp-builtin,
86
+ coerce-builtin,
87
+ execfile-builtin,
88
+ file-builtin,
89
+ long-builtin,
90
+ raw_input-builtin,
91
+ reduce-builtin,
92
+ standarderror-builtin,
93
+ unicode-builtin,
94
+ xrange-builtin,
95
+ coerce-method,
96
+ delslice-method,
97
+ getslice-method,
98
+ setslice-method,
99
+ no-absolute-import,
100
+ old-division,
101
+ dict-iter-method,
102
+ dict-view-method,
103
+ next-method-called,
104
+ metaclass-assignment,
105
+ indexing-exception,
106
+ raising-string,
107
+ reload-builtin,
108
+ oct-method,
109
+ hex-method,
110
+ nonzero-method,
111
+ cmp-method,
112
+ input-builtin,
113
+ round-builtin,
114
+ intern-builtin,
115
+ unichr-builtin,
116
+ map-builtin-not-iterating,
117
+ zip-builtin-not-iterating,
118
+ range-builtin-not-iterating,
119
+ filter-builtin-not-iterating,
120
+ using-cmp-argument,
121
+ eq-without-hash,
122
+ div-method,
123
+ idiv-method,
124
+ rdiv-method,
125
+ exception-message-attribute,
126
+ invalid-str-codec,
127
+ sys-max-int,
128
+ bad-python3-import,
129
+ deprecated-string-function,
130
+ deprecated-str-translate-call,
131
+ deprecated-itertools-function,
132
+ deprecated-types-field,
133
+ next-method-defined,
134
+ dict-items-not-iterating,
135
+ dict-keys-not-iterating,
136
+ dict-values-not-iterating,
137
+ deprecated-operator-function,
138
+ deprecated-urllib-function,
139
+ xreadlines-attribute,
140
+ deprecated-sys-function,
141
+ exception-escape,
142
+ comprehension-escape,
143
+ # Disabled due Black
144
+ bad-continuation,
145
+ bad-whitespace,
146
+ # We don't care about these
147
+ redundant-keyword-arg,
148
+
149
+ # Enable the message, report, category or checker with the given id(s). You can
150
+ # either give multiple identifier separated by comma (,) or put this option
151
+ # multiple time (only on the command line, not in the configuration file where
152
+ # it should appear only once). See also the "--disable" option for examples.
153
+ enable=c-extension-no-member
154
+
155
+
156
+ [REPORTS]
157
+
158
+ # Python expression which should return a note less than 10 (10 is the highest
159
+ # note). You have access to the variables errors warning, statement which
160
+ # respectively contain the number of errors / warnings messages and the total
161
+ # number of statements analyzed. This is used by the global evaluation report
162
+ # (RP0004).
163
+ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
164
+
165
+ # Template used to display messages. This is a python new-style format string
166
+ # used to format the message information. See doc for all details.
167
+ #msg-template=
168
+
169
+ # Set the output format. Available formats are text, parseable, colorized, json
170
+ # and msvs (visual studio). You can also give a reporter class, e.g.
171
+ # mypackage.mymodule.MyReporterClass.
172
+ output-format=text
173
+
174
+ # Tells whether to display a full report or only the messages.
175
+ reports=no
176
+
177
+ # Activate the evaluation score.
178
+ score=yes
179
+
180
+
181
+ [REFACTORING]
182
+
183
+ # Maximum number of nested blocks for function / method body
184
+ max-nested-blocks=5
185
+
186
+ # Complete name of functions that never returns. When checking for
187
+ # inconsistent-return-statements if a never returning function is called then
188
+ # it will be considered as an explicit return statement and no message will be
189
+ # printed.
190
+ never-returning-functions=sys.exit
191
+
192
+
193
+ [LOGGING]
194
+
195
+ # Logging modules to check that the string format arguments are in logging
196
+ # function parameter format.
197
+ logging-modules=logging
198
+
199
+
200
+ [SIMILARITIES]
201
+
202
+ # Ignore comments when computing similarities.
203
+ ignore-comments=yes
204
+
205
+ # Ignore docstrings when computing similarities.
206
+ ignore-docstrings=yes
207
+
208
+ # Ignore imports when computing similarities.
209
+ ignore-imports=no
210
+
211
+ # Minimum lines number of a similarity.
212
+ min-similarity-lines=4
213
+
214
+
215
+ [MISCELLANEOUS]
216
+
217
+ # List of note tags to take in consideration, separated by a comma.
218
+ notes=FIXME,
219
+ XXX,
220
+ TODO
221
+
222
+
223
+ [FORMAT]
224
+
225
+ # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
226
+ expected-line-ending-format=
227
+
228
+ # Regexp for a line that is allowed to be longer than the limit.
229
+ ignore-long-lines=^\s*(# )?<?https?://\S+>?$
230
+
231
+ # Number of spaces of indent required inside a hanging or continued line.
232
+ indent-after-paren=4
233
+
234
+ # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
235
+ # tab).
236
+ indent-string=' '
237
+
238
+ # Maximum number of characters on a single line.
239
+ max-line-length=100
240
+
241
+ # Maximum number of lines in a module.
242
+ max-module-lines=1000
243
+
244
+ # List of optional constructs for which whitespace checking is disabled. `dict-
245
+ # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
246
+ # `trailing-comma` allows a space between comma and closing bracket: (a, ).
247
+ # `empty-line` allows space-only lines.
248
+ no-space-check=trailing-comma,
249
+ dict-separator
250
+
251
+ # Allow the body of a class to be on the same line as the declaration if body
252
+ # contains single statement.
253
+ single-line-class-stmt=no
254
+
255
+ # Allow the body of an if to be on the same line as the test if there is no
256
+ # else.
257
+ single-line-if-stmt=no
258
+
259
+
260
+ [BASIC]
261
+
262
+ # Naming style matching correct argument names.
263
+ argument-naming-style=snake_case
264
+
265
+ # Regular expression matching correct argument names. Overrides argument-
266
+ # naming-style.
267
+ #argument-rgx=
268
+
269
+ # Naming style matching correct attribute names.
270
+ attr-naming-style=snake_case
271
+
272
+ # Regular expression matching correct attribute names. Overrides attr-naming-
273
+ # style.
274
+ #attr-rgx=
275
+
276
+ # Bad variable names which should always be refused, separated by a comma.
277
+ bad-names=foo,
278
+ bar,
279
+ baz,
280
+ toto,
281
+ tutu,
282
+ tata
283
+
284
+ # Naming style matching correct class attribute names.
285
+ class-attribute-naming-style=any
286
+
287
+ # Regular expression matching correct class attribute names. Overrides class-
288
+ # attribute-naming-style.
289
+ #class-attribute-rgx=
290
+
291
+ # Naming style matching correct class names.
292
+ class-naming-style=PascalCase
293
+
294
+ # Regular expression matching correct class names. Overrides class-naming-
295
+ # style.
296
+ #class-rgx=
297
+
298
+ # Naming style matching correct constant names.
299
+ const-naming-style=UPPER_CASE
300
+
301
+ # Regular expression matching correct constant names. Overrides const-naming-
302
+ # style.
303
+ #const-rgx=
304
+
305
+ # Minimum line length for functions/classes that require docstrings, shorter
306
+ # ones are exempt.
307
+ docstring-min-length=-1
308
+
309
+ # Naming style matching correct function names.
310
+ function-naming-style=snake_case
311
+
312
+ # Regular expression matching correct function names. Overrides function-
313
+ # naming-style.
314
+ #function-rgx=
315
+
316
+ # Good variable names which should always be accepted, separated by a comma.
317
+ good-names=f,
318
+ i,
319
+ j,
320
+ k,
321
+ s,
322
+ t,
323
+ ex,
324
+ Run,
325
+ _
326
+
327
+ # Include a hint for the correct naming format with invalid-name.
328
+ include-naming-hint=no
329
+
330
+ # Naming style matching correct inline iteration names.
331
+ inlinevar-naming-style=any
332
+
333
+ # Regular expression matching correct inline iteration names. Overrides
334
+ # inlinevar-naming-style.
335
+ #inlinevar-rgx=
336
+
337
+ # Naming style matching correct method names.
338
+ method-naming-style=snake_case
339
+
340
+ # Regular expression matching correct method names. Overrides method-naming-
341
+ # style.
342
+ #method-rgx=
343
+
344
+ # Naming style matching correct module names.
345
+ module-naming-style=snake_case
346
+
347
+ # Regular expression matching correct module names. Overrides module-naming-
348
+ # style.
349
+ #module-rgx=
350
+
351
+ # Colon-delimited sets of names that determine each other's naming style when
352
+ # the name regexes allow several styles.
353
+ name-group=
354
+
355
+ # Regular expression which should only match function or class names that do
356
+ # not require a docstring.
357
+ no-docstring-rgx=^_
358
+
359
+ # List of decorators that produce properties, such as abc.abstractproperty. Add
360
+ # to this list to register other decorators that produce valid properties.
361
+ # These decorators are taken in consideration only for invalid-name.
362
+ property-classes=abc.abstractproperty
363
+
364
+ # Naming style matching correct variable names.
365
+ variable-naming-style=snake_case
366
+
367
+ # Regular expression matching correct variable names. Overrides variable-
368
+ # naming-style.
369
+ variable-rgx=_?[a-z][A-Za-z0-9_]{0,30}$
370
+ argument-rgx=_?[a-z][A-Za-z0-9_]{0,30}$
371
+
372
+
373
+ [TYPECHECK]
374
+
375
+ # List of decorators that produce context managers, such as
376
+ # contextlib.contextmanager. Add to this list to register other decorators that
377
+ # produce valid context managers.
378
+ contextmanager-decorators=contextlib.contextmanager
379
+
380
+ # List of members which are set dynamically and missed by pylint inference
381
+ # system, and so shouldn't trigger E1101 when accessed. Python regular
382
+ # expressions are accepted.
383
+ generated-members=torch.mm,
384
+ torch.diag,
385
+ torch.symeig,
386
+ torch.sqrt,
387
+ torch.cat,
388
+ cv2.cvtColor,
389
+ cv2.COLOR_BGR2YUV,
390
+ cv2.COLOR_YUV2BGR,
391
+
392
+ # Tells whether missing members accessed in mixin class should be ignored. A
393
+ # mixin class is detected if its name ends with "mixin" (case insensitive).
394
+ ignore-mixin-members=yes
395
+
396
+ # Tells whether to warn about missing members when the owner of the attribute
397
+ # is inferred to be None.
398
+ ignore-none=yes
399
+
400
+ # This flag controls whether pylint should warn about no-member and similar
401
+ # checks whenever an opaque object is returned when inferring. The inference
402
+ # can return multiple potential results while evaluating a Python object, but
403
+ # some branches might not be evaluated, which results in partial inference. In
404
+ # that case, it might be useful to still emit no-member and other checks for
405
+ # the rest of the inferred objects.
406
+ ignore-on-opaque-inference=yes
407
+
408
+ # List of class names for which member attributes should not be checked (useful
409
+ # for classes with dynamically set attributes). This supports the use of
410
+ # qualified names.
411
+ ignored-classes=optparse.Values,thread._local,_thread._local
412
+
413
+ # List of module names for which member attributes should not be checked
414
+ # (useful for modules/projects where namespaces are manipulated during runtime
415
+ # and thus existing member attributes cannot be deduced by static analysis. It
416
+ # supports qualified module names, as well as Unix pattern matching.
417
+ ignored-modules=
418
+
419
+ # Show a hint with possible names when a member name was not found. The aspect
420
+ # of finding the hint is based on edit distance.
421
+ missing-member-hint=yes
422
+
423
+ # The minimum edit distance a name should have in order to be considered a
424
+ # similar match for a missing member name.
425
+ missing-member-hint-distance=1
426
+
427
+ # The total number of similar names that should be taken in consideration when
428
+ # showing a hint for a missing member.
429
+ missing-member-max-choices=1
430
+
431
+
432
+ [VARIABLES]
433
+
434
+ # List of additional names supposed to be defined in builtins. Remember that
435
+ # you should avoid to define new builtins when possible.
436
+ additional-builtins=
437
+
438
+ # Tells whether unused global variables should be treated as a violation.
439
+ allow-global-unused-variables=yes
440
+
441
+ # List of strings which can identify a callback function by name. A callback
442
+ # name must start or end with one of those strings.
443
+ callbacks=cb_,
444
+ _cb
445
+
446
+ # A regular expression matching the name of dummy variables (i.e. expected to
447
+ # not be used).
448
+ dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
449
+
450
+ # Argument names that match this expression will be ignored. Default to name
451
+ # with leading underscore.
452
+ ignored-argument-names=_.*|^ignored_|^unused_
453
+
454
+ # Tells whether we should check for unused import in __init__ files.
455
+ init-import=no
456
+
457
+ # List of qualified module names which can have objects that can redefine
458
+ # builtins.
459
+ redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
460
+
461
+
462
+ [SPELLING]
463
+
464
+ # Limits count of emitted suggestions for spelling mistakes.
465
+ max-spelling-suggestions=4
466
+
467
+ # Spelling dictionary name. Available dictionaries: en_IE (myspell), en_ZM
468
+ # (myspell), en_GB (myspell), en_HK (myspell), en_BZ (myspell), en_PH
469
+ # (myspell), en_ZA (myspell), en_MW (myspell), en_AU (myspell), en_CA
470
+ # (myspell), en_JM (myspell), en_GH (myspell), en_TT (myspell), en_SG
471
+ # (myspell), en_BW (myspell), en_US (myspell), en_NZ (myspell), en_AG
472
+ # (myspell), en_ZW (myspell), en_NA (myspell), en_IN (myspell), en_BS
473
+ # (myspell), en_DK (myspell), en_NG (myspell)..
474
+ spelling-dict=
475
+
476
+ # List of comma separated words that should not be checked.
477
+ spelling-ignore-words=
478
+
479
+ # A path to a file that contains private dictionary; one word per line.
480
+ spelling-private-dict-file=
481
+
482
+ # Tells whether to store unknown words to indicated private dictionary in
483
+ # --spelling-private-dict-file option instead of raising a message.
484
+ spelling-store-unknown-words=no
485
+
486
+
487
+ [IMPORTS]
488
+
489
+ # Allow wildcard imports from modules that define __all__.
490
+ allow-wildcard-with-all=no
491
+
492
+ # Analyse import fallback blocks. This can be used to support both Python 2 and
493
+ # 3 compatible code, which means that the block might have code that exists
494
+ # only in one or another interpreter, leading to false positives when analysed.
495
+ analyse-fallback-blocks=no
496
+
497
+ # Deprecated modules which should not be used, separated by a comma.
498
+ deprecated-modules=optparse,tkinter.tix
499
+
500
+ # Create a graph of external dependencies in the given file (report RP0402 must
501
+ # not be disabled).
502
+ ext-import-graph=
503
+
504
+ # Create a graph of every (i.e. internal and external) dependencies in the
505
+ # given file (report RP0402 must not be disabled).
506
+ import-graph=
507
+
508
+ # Create a graph of internal dependencies in the given file (report RP0402 must
509
+ # not be disabled).
510
+ int-import-graph=
511
+
512
+ # Force import order to recognize a module as part of the standard
513
+ # compatibility libraries.
514
+ known-standard-library=
515
+
516
+ # Force import order to recognize a module as part of a third party library.
517
+ known-third-party=enchant
518
+
519
+
520
+ [CLASSES]
521
+
522
+ # List of method names used to declare (i.e. assign) instance attributes.
523
+ defining-attr-methods=__init__,
524
+ __new__,
525
+ setUp
526
+
527
+ # List of member names, which should be excluded from the protected access
528
+ # warning.
529
+ exclude-protected=_asdict,
530
+ _fields,
531
+ _replace,
532
+ _source,
533
+ _make
534
+
535
+ # List of valid names for the first argument in a class method.
536
+ valid-classmethod-first-arg=cls
537
+
538
+ # List of valid names for the first argument in a metaclass class method.
539
+ valid-metaclass-classmethod-first-arg=cls
540
+
541
+
542
+ [DESIGN]
543
+
544
+ # Maximum number of arguments for function / method.
545
+ max-args=5
546
+
547
+ # Maximum number of attributes for a class (see R0902).
548
+ max-attributes=7
549
+
550
+ # Maximum number of boolean expressions in an if statement.
551
+ max-bool-expr=5
552
+
553
+ # Maximum number of branch for function / method body.
554
+ max-branches=12
555
+
556
+ # Maximum number of locals for function / method body.
557
+ max-locals=15
558
+
559
+ # Maximum number of parents for a class (see R0901).
560
+ max-parents=7
561
+
562
+ # Maximum number of public methods for a class (see R0904).
563
+ max-public-methods=20
564
+
565
+ # Maximum number of return / yield for function / method body.
566
+ max-returns=6
567
+
568
+ # Maximum number of statements in function / method body.
569
+ max-statements=50
570
+
571
+ # Minimum number of public methods for a class (see R0903).
572
+ min-public-methods=2
573
+
574
+
575
+ [EXCEPTIONS]
576
+
577
+ # Exceptions that will emit a warning when being caught. Defaults to
578
+ # "Exception".
579
+ overgeneral-exceptions=Exception
.travis.yml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ sudo: false
2
+ language: python
3
+ install: pip install tox
4
+ matrix:
5
+ include:
6
+ - python: "3.6"
7
+ env: TOX_ENV=static
8
+ - python: "3.6"
9
+ env: TOX_ENV=format
10
+ script: tox -e $TOX_ENV
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python.formatting.provider": "yapf"
3
+ }
ColorFIDBenchmarkArtistic.ipynb ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Color FID Benchmark (HQ)"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os\n",
17
+ "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
18
+ "os.environ['OMP_NUM_THREADS']='1'"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "import statistics\n",
28
+ "from fastai import *\n",
29
+ "from deoldify.visualize import *\n",
30
+ "import cv2\n",
31
+ "from fid.fid_score import *\n",
32
+ "from fid.inception import *\n",
33
+ "import imageio\n",
34
+ "plt.style.use('dark_background')\n",
35
+ "torch.backends.cudnn.benchmark=True\n",
36
+ "import warnings\n",
37
+ "warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"torch.nn.functional\")\n",
38
+ "warnings.filterwarnings(\"ignore\", category=UserWarning, message='.*?retrieve source code for container of type.*?')"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "metadata": {},
44
+ "source": [
45
+ "## Setup"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "#NOTE: Data should come from here: 'https://datasets.figure-eight.com/figure_eight_datasets/open-images/test_challenge.zip'\n",
55
+ "#NOTE: Minimum recommmended number of samples is 10K. Source: https://github.com/bioinf-jku/TTUR\n",
56
+ "\n",
57
+ "path = Path('data/ColorBenchmark')\n",
58
+ "path_hr = path/'source'\n",
59
+ "path_lr = path/'bandw'\n",
60
+ "path_results = Path('./result_images/ColorBenchmarkFID/artistic')\n",
61
+ "path_rendered = path_results/'rendered'\n",
62
+ "\n",
63
+ "#path = Path('data/DeOldifyColor')\n",
64
+ "#path_hr = path\n",
65
+ "#path_lr = path/'bandw'\n",
66
+ "#path_results = Path('./result_images/ColorBenchmark/edge')\n",
67
+ "#path_rendered = path_results/'rendered'\n",
68
+ "\n",
69
+ "#num_images = 2048\n",
70
+ "#num_images = 15000\n",
71
+ "num_images = 50000\n",
72
+ "render_factor=35\n",
73
+ "fid_batch_size = 4\n",
74
+ "eval_size=299"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "def inception_model(dims:int):\n",
84
+ " block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]\n",
85
+ " model = InceptionV3([block_idx])\n",
86
+ " model.cuda()\n",
87
+ " return model"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "def create_before_images(fn,i):\n",
97
+ " dest = path_lr/fn.relative_to(path_hr)\n",
98
+ " dest.parent.mkdir(parents=True, exist_ok=True)\n",
99
+ " img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
100
+ " img.save(dest) "
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "def render_images(colorizer, source_dir:Path, filtered_dir:Path, target_dir:Path, render_factor:int, num_images:int)->[(Path, Path, Path)]:\n",
110
+ " results = []\n",
111
+ " bandw_list = ImageList.from_folder(path_lr)\n",
112
+ " bandw_list = bandw_list[:num_images]\n",
113
+ "\n",
114
+ " if len(bandw_list.items) == 0: return results\n",
115
+ "\n",
116
+ " results = []\n",
117
+ " img_iterator = progress_bar(bandw_list.items)\n",
118
+ "\n",
119
+ " for bandw_path in img_iterator:\n",
120
+ " target_path = target_dir/bandw_path.relative_to(source_dir)\n",
121
+ "\n",
122
+ " try:\n",
123
+ " result_image = colorizer.get_transformed_image(path=bandw_path, render_factor=render_factor)\n",
124
+ " result_path = Path(str(path_results) + '/' + bandw_path.parent.name + '/' + bandw_path.name)\n",
125
+ " if not result_path.parent.exists():\n",
126
+ " result_path.parent.mkdir(parents=True, exist_ok=True)\n",
127
+ " result_image.save(result_path)\n",
128
+ " results.append((result_path, bandw_path, target_path))\n",
129
+ " except Exception as err:\n",
130
+ " print('Failed to render image. Skipping. Details: {0}'.format(err))\n",
131
+ " \n",
132
+ " return results "
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "def calculate_fid_score(render_results, bs:int, eval_size:int):\n",
142
+ " dims = 2048\n",
143
+ " cuda = True\n",
144
+ " model = inception_model(dims=dims)\n",
145
+ " rendered_paths = []\n",
146
+ " target_paths = []\n",
147
+ " \n",
148
+ " for render_result in render_results:\n",
149
+ " rendered_path, _, target_path = render_result\n",
150
+ " rendered_paths.append(str(rendered_path))\n",
151
+ " target_paths.append(str(target_path))\n",
152
+ " \n",
153
+ " rendered_m, rendered_s = calculate_activation_statistics(files=rendered_paths, model=model, batch_size=bs, dims=dims, cuda=cuda)\n",
154
+ " target_m, target_s = calculate_activation_statistics(files=target_paths, model=model, batch_size=bs, dims=dims, cuda=cuda)\n",
155
+ " fid_score = calculate_frechet_distance(rendered_m, rendered_s, target_m, target_s)\n",
156
+ " del model\n",
157
+ " return fid_score"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "markdown",
162
+ "metadata": {},
163
+ "source": [
164
+ "## Create black and whites source images"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "metadata": {},
170
+ "source": [
171
+ "Only runs if the directory isn't already created."
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": null,
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "if not path_lr.exists():\n",
181
+ " il = ImageList.from_folder(path_hr)\n",
182
+ " parallel(create_before_images, il.items)"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": [
191
+ "path_results.parent.mkdir(parents=True, exist_ok=True)"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "markdown",
196
+ "metadata": {},
197
+ "source": [
198
+ "### Rendering"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "colorizer = get_image_colorizer(artistic=True)"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "render_results = render_images(colorizer=colorizer, source_dir=path_lr, target_dir=path_hr, filtered_dir=path_results, render_factor=render_factor, num_images=num_images)"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "markdown",
221
+ "metadata": {},
222
+ "source": [
223
+ "### Colorizaton Scoring"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": [
232
+ "fid_score = calculate_fid_score(render_results, bs=fid_batch_size, eval_size=eval_size)"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "print('FID Score: ' + str(fid_score))"
242
+ ]
243
+ }
244
+ ],
245
+ "metadata": {
246
+ "kernelspec": {
247
+ "display_name": "Python 3",
248
+ "language": "python",
249
+ "name": "python3"
250
+ },
251
+ "language_info": {
252
+ "codemirror_mode": {
253
+ "name": "ipython",
254
+ "version": 3
255
+ },
256
+ "file_extension": ".py",
257
+ "mimetype": "text/x-python",
258
+ "name": "python",
259
+ "nbconvert_exporter": "python",
260
+ "pygments_lexer": "ipython3",
261
+ "version": "3.7.0"
262
+ }
263
+ },
264
+ "nbformat": 4,
265
+ "nbformat_minor": 2
266
+ }
ColorizeTrainingArtistic.ipynb ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Artistic Model Training"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "#### NOTES: \n",
15
+ "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
16
+ "* This model prioritizes colorful renderings. It has higher variation in renderings at different resolutions compared to the \"stable\" model"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "#NOTE: This must be the first call in order to work properly!\n",
26
+ "from deoldify import device\n",
27
+ "from deoldify.device_id import DeviceId\n",
28
+ "#choices: CPU, GPU0...GPU7\n",
29
+ "device.set(device=DeviceId.GPU0)"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import os\n",
39
+ "import fastai\n",
40
+ "from fastai import *\n",
41
+ "from fastai.vision import *\n",
42
+ "from fastai.callbacks.tensorboard import *\n",
43
+ "from fastai.vision.gan import *\n",
44
+ "from deoldify.generators import *\n",
45
+ "from deoldify.critics import *\n",
46
+ "from deoldify.dataset import *\n",
47
+ "from deoldify.loss import *\n",
48
+ "from deoldify.save import *\n",
49
+ "from PIL import Image, ImageDraw, ImageFont\n",
50
+ "from PIL import ImageFile"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "metadata": {},
56
+ "source": [
57
+ "## Setup"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
67
+ "path_hr = path\n",
68
+ "path_lr = path/'bandw'\n",
69
+ "\n",
70
+ "proj_id = 'ArtisticModel'\n",
71
+ "\n",
72
+ "gen_name = proj_id + '_gen'\n",
73
+ "pre_gen_name = gen_name + '_0'\n",
74
+ "crit_name = proj_id + '_crit'\n",
75
+ "\n",
76
+ "name_gen = proj_id + '_image_gen'\n",
77
+ "path_gen = path/name_gen\n",
78
+ "\n",
79
+ "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
80
+ "\n",
81
+ "nf_factor = 1.5\n",
82
+ "pct_start = 1e-8"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "def get_data(bs:int, sz:int, keep_pct:float):\n",
92
+ " return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
93
+ " random_seed=None, keep_pct=keep_pct)\n",
94
+ "\n",
95
+ "def get_crit_data(classes, bs, sz):\n",
96
+ " src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\n",
97
+ " ll = src.label_from_folder(classes=classes)\n",
98
+ " data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
99
+ " .databunch(bs=bs).normalize(imagenet_stats))\n",
100
+ " return data\n",
101
+ "\n",
102
+ "def create_training_images(fn,i):\n",
103
+ " dest = path_lr/fn.relative_to(path_hr)\n",
104
+ " dest.parent.mkdir(parents=True, exist_ok=True)\n",
105
+ " img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
106
+ " img.save(dest) \n",
107
+ " \n",
108
+ "def save_preds(dl):\n",
109
+ " i=0\n",
110
+ " names = dl.dataset.items\n",
111
+ " \n",
112
+ " for b in dl:\n",
113
+ " preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
114
+ " for o in preds:\n",
115
+ " o.save(path_gen/names[i].name)\n",
116
+ " i += 1\n",
117
+ " \n",
118
+ "def save_gen_images():\n",
119
+ " if path_gen.exists(): shutil.rmtree(path_gen)\n",
120
+ " path_gen.mkdir(exist_ok=True)\n",
121
+ " data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n",
122
+ " save_preds(data_gen.fix_dl)\n",
123
+ " PIL.Image.open(path_gen.ls()[0])"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "markdown",
128
+ "metadata": {},
129
+ "source": [
130
+ "## Create black and white training images"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "markdown",
135
+ "metadata": {},
136
+ "source": [
137
+ "Only runs if the directory isn't already created."
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "if not path_lr.exists():\n",
147
+ " il = ImageList.from_folder(path_hr)\n",
148
+ " parallel(create_training_images, il.items)"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "metadata": {},
154
+ "source": [
155
+ "## Pre-train generator"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "markdown",
160
+ "metadata": {},
161
+ "source": [
162
+ "#### NOTE\n",
163
+ "Most of the training takes place here in pretraining for NoGAN. The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training."
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "metadata": {},
169
+ "source": [
170
+ "### 64px"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "bs=88\n",
180
+ "sz=64\n",
181
+ "keep_pct=1.0"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "metadata": {},
188
+ "outputs": [],
189
+ "source": [
190
+ "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "learn_gen.fit_one_cycle(1, pct_start=0.8, max_lr=slice(1e-3))"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": [
226
+ "learn_gen.save(pre_gen_name)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "learn_gen.unfreeze()"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": null,
241
+ "metadata": {},
242
+ "outputs": [],
243
+ "source": [
244
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(3e-7, 3e-4))"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": [
253
+ "learn_gen.save(pre_gen_name)"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "markdown",
258
+ "metadata": {},
259
+ "source": [
260
+ "### 128px"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": [
269
+ "bs=22\n",
270
+ "sz=128\n",
271
+ "keep_pct=1.0"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "metadata": {},
278
+ "outputs": [],
279
+ "source": [
280
+ "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
+ "learn_gen.unfreeze()"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(1e-7,1e-4))"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "learn_gen.save(pre_gen_name)"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "markdown",
312
+ "metadata": {},
313
+ "source": [
314
+ "### 192px"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {},
321
+ "outputs": [],
322
+ "source": [
323
+ "bs=11\n",
324
+ "sz=192\n",
325
+ "keep_pct=0.50"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "metadata": {},
332
+ "outputs": [],
333
+ "source": [
334
+ "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "code",
339
+ "execution_count": null,
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "learn_gen.unfreeze()"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "execution_count": null,
349
+ "metadata": {},
350
+ "outputs": [],
351
+ "source": [
352
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "metadata": {},
359
+ "outputs": [],
360
+ "source": [
361
+ "learn_gen.save(pre_gen_name)"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "markdown",
366
+ "metadata": {},
367
+ "source": [
368
+ "## Repeatable GAN Cycle"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "markdown",
373
+ "metadata": {},
374
+ "source": [
375
+ "#### NOTE\n",
376
+ "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality). Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old. "
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "metadata": {},
383
+ "outputs": [],
384
+ "source": [
385
+ "old_checkpoint_num = 0\n",
386
+ "checkpoint_num = old_checkpoint_num + 1\n",
387
+ "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
388
+ "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
389
+ "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
390
+ "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "metadata": {},
396
+ "source": [
397
+ "### Save Generated Images"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "bs=8\n",
407
+ "sz=192"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": null,
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": [
416
+ "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "save_gen_images()"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "markdown",
430
+ "metadata": {},
431
+ "source": [
432
+ "### Pretrain Critic"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "markdown",
437
+ "metadata": {},
438
+ "source": [
439
+ "##### Only need full pretraining of critic when starting from scratch. Otherwise, just finetune!"
440
+ ]
441
+ },
442
+ {
443
+ "cell_type": "code",
444
+ "execution_count": null,
445
+ "metadata": {},
446
+ "outputs": [],
447
+ "source": [
448
+ "if old_checkpoint_num == 0:\n",
449
+ " bs=64\n",
450
+ " sz=128\n",
451
+ " learn_gen=None\n",
452
+ " gc.collect()\n",
453
+ " data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
454
+ " data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
455
+ " learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n",
456
+ " learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n",
457
+ " learn_critic.fit_one_cycle(6, 1e-3)\n",
458
+ " learn_critic.save(crit_old_checkpoint_name)"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": null,
464
+ "metadata": {},
465
+ "outputs": [],
466
+ "source": [
467
+ "bs=16\n",
468
+ "sz=192"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": [
477
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": null,
483
+ "metadata": {},
484
+ "outputs": [],
485
+ "source": [
486
+ "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "metadata": {},
493
+ "outputs": [],
494
+ "source": [
495
+ "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": null,
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": [
504
+ "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "metadata": {},
511
+ "outputs": [],
512
+ "source": [
513
+ "learn_critic.fit_one_cycle(4, 1e-4)"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "metadata": {},
520
+ "outputs": [],
521
+ "source": [
522
+ "learn_critic.save(crit_new_checkpoint_name)"
523
+ ]
524
+ },
525
+ {
526
+ "cell_type": "markdown",
527
+ "metadata": {},
528
+ "source": [
529
+ "### GAN"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": null,
535
+ "metadata": {},
536
+ "outputs": [],
537
+ "source": [
538
+ "learn_crit=None\n",
539
+ "learn_gen=None\n",
540
+ "gc.collect()"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": null,
546
+ "metadata": {},
547
+ "outputs": [],
548
+ "source": [
549
+ "lr=1e-5\n",
550
+ "sz=192\n",
551
+ "bs=9"
552
+ ]
553
+ },
554
+ {
555
+ "cell_type": "code",
556
+ "execution_count": null,
557
+ "metadata": {},
558
+ "outputs": [],
559
+ "source": [
560
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "metadata": {},
567
+ "outputs": [],
568
+ "source": [
569
+ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": [
578
+ "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "code",
583
+ "execution_count": null,
584
+ "metadata": {},
585
+ "outputs": [],
586
+ "source": [
587
+ "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
588
+ "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n",
589
+ " opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
590
+ "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
591
+ "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
592
+ "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
593
+ ]
594
+ },
595
+ {
596
+ "cell_type": "markdown",
597
+ "metadata": {},
598
+ "source": [
599
+ "#### Instructions: \n",
600
+ "Find the checkpoint just before where glitches start to be introduced. This is all very new so you may need to play around with just how far you go here with keep_pct."
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": null,
606
+ "metadata": {},
607
+ "outputs": [],
608
+ "source": [
609
+ "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
610
+ "learn_gen.freeze_to(-1)\n",
611
+ "learn.fit(1,lr)"
612
+ ]
613
+ }
614
+ ],
615
+ "metadata": {
616
+ "kernelspec": {
617
+ "display_name": "Python 3",
618
+ "language": "python",
619
+ "name": "python3"
620
+ },
621
+ "language_info": {
622
+ "codemirror_mode": {
623
+ "name": "ipython",
624
+ "version": 3
625
+ },
626
+ "file_extension": ".py",
627
+ "mimetype": "text/x-python",
628
+ "name": "python",
629
+ "nbconvert_exporter": "python",
630
+ "pygments_lexer": "ipython3",
631
+ "version": "3.7.0"
632
+ }
633
+ },
634
+ "nbformat": 4,
635
+ "nbformat_minor": 4
636
+ }
ColorizeTrainingStable.ipynb ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Stable Model Training"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "#### NOTES: \n",
15
+ "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
16
+ "* This model prioritizes stable and reliable renderings. It does particularly well on portraits and landscapes. It's not as colorful as the artistic model."
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "#NOTE: This must be the first call in order to work properly!\n",
26
+ "from deoldify import device\n",
27
+ "from deoldify.device_id import DeviceId\n",
28
+ "#choices: CPU, GPU0...GPU7\n",
29
+ "device.set(device=DeviceId.GPU0)"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import os\n",
39
+ "import fastai\n",
40
+ "from fastai import *\n",
41
+ "from fastai.vision import *\n",
42
+ "from fastai.callbacks.tensorboard import *\n",
43
+ "from fastai.vision.gan import *\n",
44
+ "from deoldify.generators import *\n",
45
+ "from deoldify.critics import *\n",
46
+ "from deoldify.dataset import *\n",
47
+ "from deoldify.loss import *\n",
48
+ "from deoldify.save import *\n",
49
+ "from PIL import Image, ImageDraw, ImageFont\n",
50
+ "from PIL import ImageFile"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "metadata": {},
56
+ "source": [
57
+ "## Setup"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
67
+ "path_hr = path\n",
68
+ "path_lr = path/'bandw'\n",
69
+ "\n",
70
+ "proj_id = 'StableModel'\n",
71
+ "\n",
72
+ "gen_name = proj_id + '_gen'\n",
73
+ "pre_gen_name = gen_name + '_0'\n",
74
+ "crit_name = proj_id + '_crit'\n",
75
+ "\n",
76
+ "name_gen = proj_id + '_image_gen'\n",
77
+ "path_gen = path/name_gen\n",
78
+ "\n",
79
+ "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
80
+ "\n",
81
+ "nf_factor = 2\n",
82
+ "pct_start = 1e-8"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "def get_data(bs:int, sz:int, keep_pct:float):\n",
92
+ " return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
93
+ " random_seed=None, keep_pct=keep_pct)\n",
94
+ "\n",
95
+ "def get_crit_data(classes, bs, sz):\n",
96
+ " src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\n",
97
+ " ll = src.label_from_folder(classes=classes)\n",
98
+ " data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
99
+ " .databunch(bs=bs).normalize(imagenet_stats))\n",
100
+ " return data\n",
101
+ "\n",
102
+ "def create_training_images(fn,i):\n",
103
+ " dest = path_lr/fn.relative_to(path_hr)\n",
104
+ " dest.parent.mkdir(parents=True, exist_ok=True)\n",
105
+ " img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
106
+ " img.save(dest) \n",
107
+ " \n",
108
+ "def save_preds(dl):\n",
109
+ " i=0\n",
110
+ " names = dl.dataset.items\n",
111
+ " \n",
112
+ " for b in dl:\n",
113
+ " preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
114
+ " for o in preds:\n",
115
+ " o.save(path_gen/names[i].name)\n",
116
+ " i += 1\n",
117
+ " \n",
118
+ "def save_gen_images():\n",
119
+ " if path_gen.exists(): shutil.rmtree(path_gen)\n",
120
+ " path_gen.mkdir(exist_ok=True)\n",
121
+ " data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n",
122
+ " save_preds(data_gen.fix_dl)\n",
123
+ " PIL.Image.open(path_gen.ls()[0])"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "markdown",
128
+ "metadata": {},
129
+ "source": [
130
+ "## Create black and white training images"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "markdown",
135
+ "metadata": {},
136
+ "source": [
137
+ "Only runs if the directory isn't already created."
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "if not path_lr.exists():\n",
147
+ " il = ImageList.from_folder(path_hr)\n",
148
+ " parallel(create_training_images, il.items)"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "metadata": {},
154
+ "source": [
155
+ "## Pre-train generator"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "markdown",
160
+ "metadata": {},
161
+ "source": [
162
+ "#### NOTE\n",
163
+ "Most of the training takes place here in pretraining for NoGAN. The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training."
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "metadata": {},
169
+ "source": [
170
+ "### 64px"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "bs=88\n",
180
+ "sz=64\n",
181
+ "keep_pct=1.0"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "metadata": {},
188
+ "outputs": [],
189
+ "source": [
190
+ "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "learn_gen.fit_one_cycle(1, pct_start=0.8, max_lr=slice(1e-3))"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": [
226
+ "learn_gen.save(pre_gen_name)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "learn_gen.unfreeze()"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": null,
241
+ "metadata": {},
242
+ "outputs": [],
243
+ "source": [
244
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(3e-7, 3e-4))"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": [
253
+ "learn_gen.save(pre_gen_name)"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "markdown",
258
+ "metadata": {},
259
+ "source": [
260
+ "### 128px"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": [
269
+ "bs=20\n",
270
+ "sz=128\n",
271
+ "keep_pct=1.0"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "metadata": {},
278
+ "outputs": [],
279
+ "source": [
280
+ "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
+ "learn_gen.unfreeze()"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(1e-7,1e-4))"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "learn_gen.save(pre_gen_name)"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "markdown",
312
+ "metadata": {},
313
+ "source": [
314
+ "### 192px"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {},
321
+ "outputs": [],
322
+ "source": [
323
+ "bs=8\n",
324
+ "sz=192\n",
325
+ "keep_pct=0.50"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "metadata": {},
332
+ "outputs": [],
333
+ "source": [
334
+ "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "code",
339
+ "execution_count": null,
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "learn_gen.unfreeze()"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "execution_count": null,
349
+ "metadata": {},
350
+ "outputs": [],
351
+ "source": [
352
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "metadata": {},
359
+ "outputs": [],
360
+ "source": [
361
+ "learn_gen.save(pre_gen_name)"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "markdown",
366
+ "metadata": {},
367
+ "source": [
368
+ "## Repeatable GAN Cycle"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "markdown",
373
+ "metadata": {},
374
+ "source": [
375
+ "#### NOTE\n",
376
+ "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality). Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old. "
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "metadata": {},
383
+ "outputs": [],
384
+ "source": [
385
+ "old_checkpoint_num = 0\n",
386
+ "checkpoint_num = old_checkpoint_num + 1\n",
387
+ "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
388
+ "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
389
+ "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
390
+ "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "metadata": {},
396
+ "source": [
397
+ "### Save Generated Images"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "bs=8\n",
407
+ "sz=192"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": null,
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": [
416
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "save_gen_images()"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "markdown",
430
+ "metadata": {},
431
+ "source": [
432
+ "### Pretrain Critic"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "markdown",
437
+ "metadata": {},
438
+ "source": [
439
+ "##### Only need full pretraining of critic when starting from scratch. Otherwise, just finetune!"
440
+ ]
441
+ },
442
+ {
443
+ "cell_type": "code",
444
+ "execution_count": null,
445
+ "metadata": {},
446
+ "outputs": [],
447
+ "source": [
448
+ "if old_checkpoint_num == 0:\n",
449
+ " bs=64\n",
450
+ " sz=128\n",
451
+ " learn_gen=None\n",
452
+ " gc.collect()\n",
453
+ " data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
454
+ " data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
455
+ " learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n",
456
+ " learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n",
457
+ " learn_critic.fit_one_cycle(6, 1e-3)\n",
458
+ " learn_critic.save(crit_old_checkpoint_name)"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": null,
464
+ "metadata": {},
465
+ "outputs": [],
466
+ "source": [
467
+ "bs=16\n",
468
+ "sz=192"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": [
477
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": null,
483
+ "metadata": {},
484
+ "outputs": [],
485
+ "source": [
486
+ "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "metadata": {},
493
+ "outputs": [],
494
+ "source": [
495
+ "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": null,
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": [
504
+ "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "metadata": {},
511
+ "outputs": [],
512
+ "source": [
513
+ "learn_critic.fit_one_cycle(4, 1e-4)"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "metadata": {},
520
+ "outputs": [],
521
+ "source": [
522
+ "learn_critic.save(crit_new_checkpoint_name)"
523
+ ]
524
+ },
525
+ {
526
+ "cell_type": "markdown",
527
+ "metadata": {},
528
+ "source": [
529
+ "### GAN"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": null,
535
+ "metadata": {},
536
+ "outputs": [],
537
+ "source": [
538
+ "learn_crit=None\n",
539
+ "learn_gen=None\n",
540
+ "gc.collect()"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": null,
546
+ "metadata": {},
547
+ "outputs": [],
548
+ "source": [
549
+ "lr=2e-5\n",
550
+ "sz=192\n",
551
+ "bs=5"
552
+ ]
553
+ },
554
+ {
555
+ "cell_type": "code",
556
+ "execution_count": null,
557
+ "metadata": {},
558
+ "outputs": [],
559
+ "source": [
560
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "metadata": {},
567
+ "outputs": [],
568
+ "source": [
569
+ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": [
578
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "code",
583
+ "execution_count": null,
584
+ "metadata": {},
585
+ "outputs": [],
586
+ "source": [
587
+ "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
588
+ "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
589
+ " opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
590
+ "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
591
+ "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
592
+ "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
593
+ ]
594
+ },
595
+ {
596
+ "cell_type": "markdown",
597
+ "metadata": {},
598
+ "source": [
599
+ "#### Instructions: \n",
600
+ "Find the checkpoint just before where glitches start to be introduced. This is all very new so you may need to play around with just how far you go here with keep_pct."
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": null,
606
+ "metadata": {},
607
+ "outputs": [],
608
+ "source": [
609
+ "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
610
+ "learn_gen.freeze_to(-1)\n",
611
+ "learn.fit(1,lr)"
612
+ ]
613
+ }
614
+ ],
615
+ "metadata": {
616
+ "kernelspec": {
617
+ "display_name": "Python 3",
618
+ "language": "python",
619
+ "name": "python3"
620
+ },
621
+ "language_info": {
622
+ "codemirror_mode": {
623
+ "name": "ipython",
624
+ "version": 3
625
+ },
626
+ "file_extension": ".py",
627
+ "mimetype": "text/x-python",
628
+ "name": "python",
629
+ "nbconvert_exporter": "python",
630
+ "pygments_lexer": "ipython3",
631
+ "version": "3.7.6"
632
+ }
633
+ },
634
+ "nbformat": 4,
635
+ "nbformat_minor": 4
636
+ }
ColorizeTrainingStableLargeBatch.ipynb ADDED
@@ -0,0 +1,779 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Stable Model Training (Large Batch/Limited GPU Memory Support)"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "## IMPORTANT: Training has -not- been verified by myself for this notebook ~jantic"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {},
20
+ "source": [
21
+ "#### NOTES: \n",
22
+ "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
23
+ "* This model prioritizes stable and reliable renderings. It does particularly well on portraits and landscapes. It's not as colorful as the artistic model."
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "import os\n",
33
+ "os.environ['CUDA_VISIBLE_DEVICES']='0' "
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "import fastai\n",
43
+ "from fastai import *\n",
44
+ "from fastai.vision import *\n",
45
+ "from fastai.callbacks.tensorboard import *\n",
46
+ "from fastai.vision.gan import *\n",
47
+ "from deoldify.generators import *\n",
48
+ "from deoldify.critics import *\n",
49
+ "from deoldify.dataset import *\n",
50
+ "from deoldify.loss import *\n",
51
+ "from deoldify.save import *\n",
52
+ "from PIL import Image, ImageDraw, ImageFont\n",
53
+ "from PIL import ImageFile"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "markdown",
58
+ "metadata": {},
59
+ "source": [
60
+ "## Setup"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "markdown",
65
+ "metadata": {},
66
+ "source": [
67
+ "### Activate Large Model Support for PyTorch\n",
68
+ "This will allow us to fit the model within a GPU with smaller memory capacity (e.g. GTX 1070 8Gb)."
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "metadata": {},
74
+ "source": [
75
+ "Large Model Support (LMS) is a feature provided in IBM Watson Machine Learning Community Edition (WML-CE) PyTorch V1.1.0 that allows the successful training of deep learning models that would otherwise exhaust GPU memory and abort with β€œout-of-memory” errors. LMS manages this oversubscription of GPU memory by temporarily swapping tensors to host memory when they are not needed. One or more elements of a deep learning model can lead to GPU memory exhaustion.\n",
76
+ "\n",
77
+ "Requires the use of IBM WML-CE (Available here: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/welcome/welcome.html)\n",
78
+ "\n",
79
+ "Further Reading on PyTorch with Large Model Support: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/navigation/wmlce_getstarted_pytorch.html"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "import shutil"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# Set limit of GPU used before swapping to tensors to host memory\n",
98
+ "max_gpu_mem = 7\n",
99
+ "\n",
100
+ "def gb_to_bytes(gb):\n",
101
+ " return gb*1024*1024*1024\n",
102
+ "\n",
103
+ "# Enable PyTorch LMS\n",
104
+ "torch.cuda.set.enabled_lms(True)\n",
105
+ "# Set LMS limit\n",
106
+ "torch.cuda.set_limit_lms(gb_to_bytes(max_gpu_mem))"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "# Check LMS is enabled\n",
116
+ "torch.cuda.get_enabled_lms()"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": null,
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": [
125
+ "# Check LMS Limit has been set\n",
126
+ "torch.cuda.get_limit_lms()"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "markdown",
131
+ "metadata": {},
132
+ "source": [
133
+ " "
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "# Path to Training Data\n",
143
+ "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
144
+ "path_hr = path\n",
145
+ "\n",
146
+ "# Path to Black and White images\n",
147
+ "path_bandw = Path('/training/DeOldify')\n",
148
+ "path_lr = path_bandw/'bandw'\n",
149
+ "\n",
150
+ "# Name of Model\n",
151
+ "proj_id = 'StableModel'\n",
152
+ "\n",
153
+ "# Name of Generator\n",
154
+ "gen_name = proj_id + '_gen'\n",
155
+ "pre_gen_name = gen_name + '_0'\n",
156
+ "\n",
157
+ "# Name of Critic\n",
158
+ "crit_name = proj_id + '_crit'\n",
159
+ "\n",
160
+ "# Name of Generated Images folder, located within the Black and White folder\n",
161
+ "name_gen = proj_id + '_image_gen'\n",
162
+ "path_gen = path/name_gen\n",
163
+ "\n",
164
+ "# Path to tensorboard data\n",
165
+ "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
166
+ "\n",
167
+ "nf_factor = 2\n",
168
+ "pct_start = 1e-8\n",
169
+ "\n",
170
+ "# Number of workers for DataLoader\n",
171
+ "num_works = 2"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": null,
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "def get_data(bs:int, sz:int, keep_pct:float):\n",
181
+ " return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
182
+ " random_seed=None, keep_pct=keep_pct, num_workers=num_works)\n",
183
+ "\n",
184
+ "def get_crit_data(classes, bs, sz):\n",
185
+ " src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\n",
186
+ " ll = src.label_from_folder(classes=classes)\n",
187
+ " data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
188
+ " .databunch(bs=bs).normalize(imagenet_stats))\n",
189
+ " return data\n",
190
+ "\n",
191
+ "def create_training_images(fn,i):\n",
192
+ " dest = path_lr/fn.relative_to(path_hr)\n",
193
+ " dest.parent.mkdir(parents=True, exist_ok=True)\n",
194
+ " img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
195
+ " img.save(dest) \n",
196
+ " \n",
197
+ "def save_preds(dl):\n",
198
+ " i=0\n",
199
+ " names = dl.dataset.items\n",
200
+ " \n",
201
+ " for b in dl:\n",
202
+ " preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
203
+ " for o in preds:\n",
204
+ " o.save(path_gen/names[i].name)\n",
205
+ " i += 1\n",
206
+ " \n",
207
+ "def save_gen_images():\n",
208
+ " if path_gen.exists(): shutil.rmtree(path_gen)\n",
209
+ " path_gen.mkdir(exist_ok=True)\n",
210
+ " data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n",
211
+ " save_preds(data_gen.fix_dl)\n",
212
+ " PIL.Image.open(path_gen.ls()[0])"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "markdown",
217
+ "metadata": {},
218
+ "source": [
219
+ "## Create black and white training images"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "markdown",
224
+ "metadata": {},
225
+ "source": [
226
+ "Only runs if the directory isn't already created."
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "if not path_lr.exists():\n",
236
+ " il = ImageList.from_folder(path_hr)\n",
237
+ " parallel(create_training_images, il.items)"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "markdown",
242
+ "metadata": {},
243
+ "source": [
244
+ "## Pre-train generator"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "markdown",
249
+ "metadata": {},
250
+ "source": [
251
+ "#### NOTE\n",
252
+ "Most of the training takes place here in pretraining for NoGAN. The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training."
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "markdown",
257
+ "metadata": {},
258
+ "source": [
259
+ "### 64px"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": null,
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": [
268
+ "bs=88 # This can be increased if using PyTorch LMS, training could be slower.\n",
269
+ "sz=64\n",
270
+ "keep_pct=1.0"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": null,
276
+ "metadata": {},
277
+ "outputs": [],
278
+ "source": [
279
+ "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "execution_count": null,
285
+ "metadata": {},
286
+ "outputs": [],
287
+ "source": [
288
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": null,
303
+ "metadata": {},
304
+ "outputs": [],
305
+ "source": [
306
+ "learn_gen.fit_one_cycle(1, pct_start=0.8, max_lr=slice(1e-3))"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
+ "learn_gen.save(pre_gen_name)"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": null,
321
+ "metadata": {},
322
+ "outputs": [],
323
+ "source": [
324
+ "learn_gen.unfreeze()"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "metadata": {},
331
+ "outputs": [],
332
+ "source": [
333
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(3e-7, 3e-4))"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "metadata": {},
340
+ "outputs": [],
341
+ "source": [
342
+ "learn_gen.save(pre_gen_name)"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "markdown",
347
+ "metadata": {},
348
+ "source": [
349
+ "### 128px"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "bs=40 # This can be increased if using PyTorch LMS, training could be slower.\n",
359
+ "sz=128\n",
360
+ "keep_pct=1.0"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "metadata": {},
367
+ "outputs": [],
368
+ "source": [
369
+ "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": null,
375
+ "metadata": {},
376
+ "outputs": [],
377
+ "source": [
378
+ "learn_gen.unfreeze()"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": null,
384
+ "metadata": {},
385
+ "outputs": [],
386
+ "source": [
387
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(1e-7,1e-4))"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "learn_gen.save(pre_gen_name)"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "markdown",
401
+ "metadata": {},
402
+ "source": [
403
+ "### 192px"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": [
412
+ "bs=16 # This can be increased if using PyTorch LMS, training could be slower.\n",
413
+ "sz=192\n",
414
+ "keep_pct=0.50"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "execution_count": null,
420
+ "metadata": {},
421
+ "outputs": [],
422
+ "source": [
423
+ "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": null,
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "learn_gen.unfreeze()"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "metadata": {},
439
+ "outputs": [],
440
+ "source": [
441
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "metadata": {},
448
+ "outputs": [],
449
+ "source": [
450
+ "learn_gen.save(pre_gen_name)"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "markdown",
455
+ "metadata": {},
456
+ "source": [
457
+ "### 256px"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": null,
463
+ "metadata": {},
464
+ "outputs": [],
465
+ "source": [
466
+ "bs=8 # This can be increased if using PyTorch LMS, training could be slower.\n",
467
+ "sz=256\n",
468
+ "keep_pct=0.50"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": [
477
+ "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": null,
483
+ "metadata": {},
484
+ "outputs": [],
485
+ "source": [
486
+ "learn_gen.unfreeze()"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "metadata": {},
493
+ "outputs": [],
494
+ "source": [
495
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": null,
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": [
504
+ "learn_gen.save(pre_gen_name)"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "markdown",
509
+ "metadata": {},
510
+ "source": [
511
+ "## Repeatable GAN Cycle"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "markdown",
516
+ "metadata": {},
517
+ "source": [
518
+ "#### NOTE\n",
519
+ "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality). Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old. "
520
+ ]
521
+ },
522
+ {
523
+ "cell_type": "code",
524
+ "execution_count": null,
525
+ "metadata": {},
526
+ "outputs": [],
527
+ "source": [
528
+ "old_checkpoint_num = 0\n",
529
+ "checkpoint_num = old_checkpoint_num + 1\n",
530
+ "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
531
+ "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
532
+ "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
533
+ "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
534
+ ]
535
+ },
536
+ {
537
+ "cell_type": "markdown",
538
+ "metadata": {},
539
+ "source": [
540
+ "### Save Generated Images"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": null,
546
+ "metadata": {},
547
+ "outputs": [],
548
+ "source": [
549
+ "bs=8\n",
550
+ "sz=256"
551
+ ]
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "execution_count": null,
556
+ "metadata": {},
557
+ "outputs": [],
558
+ "source": [
559
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "execution_count": null,
565
+ "metadata": {},
566
+ "outputs": [],
567
+ "source": [
568
+ "save_gen_images()"
569
+ ]
570
+ },
571
+ {
572
+ "cell_type": "markdown",
573
+ "metadata": {},
574
+ "source": [
575
+ "### Pretrain Critic"
576
+ ]
577
+ },
578
+ {
579
+ "cell_type": "markdown",
580
+ "metadata": {},
581
+ "source": [
582
+ "##### Only need full pretraining of critic when starting from scratch. Otherwise, just finetune!"
583
+ ]
584
+ },
585
+ {
586
+ "cell_type": "code",
587
+ "execution_count": null,
588
+ "metadata": {},
589
+ "outputs": [],
590
+ "source": [
591
+ "if old_checkpoint_num == 0:\n",
592
+ " bs=64\n",
593
+ " sz=128\n",
594
+ " learn_gen=None\n",
595
+ " gc.collect()\n",
596
+ " data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
597
+ " data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
598
+ " learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n",
599
+ " learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n",
600
+ " learn_critic.fit_one_cycle(6, 1e-3)\n",
601
+ " learn_critic.save(crit_old_checkpoint_name)"
602
+ ]
603
+ },
604
+ {
605
+ "cell_type": "code",
606
+ "execution_count": null,
607
+ "metadata": {},
608
+ "outputs": [],
609
+ "source": [
610
+ "bs=8\n",
611
+ "sz=256"
612
+ ]
613
+ },
614
+ {
615
+ "cell_type": "code",
616
+ "execution_count": null,
617
+ "metadata": {},
618
+ "outputs": [],
619
+ "source": [
620
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
621
+ ]
622
+ },
623
+ {
624
+ "cell_type": "code",
625
+ "execution_count": null,
626
+ "metadata": {},
627
+ "outputs": [],
628
+ "source": [
629
+ "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
630
+ ]
631
+ },
632
+ {
633
+ "cell_type": "code",
634
+ "execution_count": null,
635
+ "metadata": {},
636
+ "outputs": [],
637
+ "source": [
638
+ "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
639
+ ]
640
+ },
641
+ {
642
+ "cell_type": "code",
643
+ "execution_count": null,
644
+ "metadata": {},
645
+ "outputs": [],
646
+ "source": [
647
+ "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
648
+ ]
649
+ },
650
+ {
651
+ "cell_type": "code",
652
+ "execution_count": null,
653
+ "metadata": {},
654
+ "outputs": [],
655
+ "source": [
656
+ "learn_critic.fit_one_cycle(4, 1e-4)"
657
+ ]
658
+ },
659
+ {
660
+ "cell_type": "code",
661
+ "execution_count": null,
662
+ "metadata": {},
663
+ "outputs": [],
664
+ "source": [
665
+ "learn_critic.save(crit_new_checkpoint_name)"
666
+ ]
667
+ },
668
+ {
669
+ "cell_type": "markdown",
670
+ "metadata": {},
671
+ "source": [
672
+ "### GAN"
673
+ ]
674
+ },
675
+ {
676
+ "cell_type": "code",
677
+ "execution_count": null,
678
+ "metadata": {},
679
+ "outputs": [],
680
+ "source": [
681
+ "learn_crit=None\n",
682
+ "learn_gen=None\n",
683
+ "gc.collect()"
684
+ ]
685
+ },
686
+ {
687
+ "cell_type": "code",
688
+ "execution_count": null,
689
+ "metadata": {},
690
+ "outputs": [],
691
+ "source": [
692
+ "lr=2e-5\n",
693
+ "sz=256\n",
694
+ "bs=5"
695
+ ]
696
+ },
697
+ {
698
+ "cell_type": "code",
699
+ "execution_count": null,
700
+ "metadata": {},
701
+ "outputs": [],
702
+ "source": [
703
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
704
+ ]
705
+ },
706
+ {
707
+ "cell_type": "code",
708
+ "execution_count": null,
709
+ "metadata": {},
710
+ "outputs": [],
711
+ "source": [
712
+ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
713
+ ]
714
+ },
715
+ {
716
+ "cell_type": "code",
717
+ "execution_count": null,
718
+ "metadata": {},
719
+ "outputs": [],
720
+ "source": [
721
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
722
+ ]
723
+ },
724
+ {
725
+ "cell_type": "code",
726
+ "execution_count": null,
727
+ "metadata": {},
728
+ "outputs": [],
729
+ "source": [
730
+ "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
731
+ "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
732
+ " opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
733
+ "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
734
+ "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
735
+ "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
736
+ ]
737
+ },
738
+ {
739
+ "cell_type": "markdown",
740
+ "metadata": {},
741
+ "source": [
742
+ "#### Instructions: \n",
743
+ "Find the checkpoint just before where glitches start to be introduced. This is all very new so you may need to play around with just how far you go here with keep_pct."
744
+ ]
745
+ },
746
+ {
747
+ "cell_type": "code",
748
+ "execution_count": null,
749
+ "metadata": {},
750
+ "outputs": [],
751
+ "source": [
752
+ "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
753
+ "learn_gen.freeze_to(-1)\n",
754
+ "learn.fit(1,lr)"
755
+ ]
756
+ }
757
+ ],
758
+ "metadata": {
759
+ "kernelspec": {
760
+ "display_name": "Python 3",
761
+ "language": "python",
762
+ "name": "python3"
763
+ },
764
+ "language_info": {
765
+ "codemirror_mode": {
766
+ "name": "ipython",
767
+ "version": 3
768
+ },
769
+ "file_extension": ".py",
770
+ "mimetype": "text/x-python",
771
+ "name": "python",
772
+ "nbconvert_exporter": "python",
773
+ "pygments_lexer": "ipython3",
774
+ "version": "3.7.0"
775
+ }
776
+ },
777
+ "nbformat": 4,
778
+ "nbformat_minor": 2
779
+ }
ColorizeTrainingVideo.ipynb ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Video Model Training"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "#### NOTES: \n",
15
+ "* It's assumed that there's a pretrained generator from the ColorizeTrainingStable notebook available at the specified path.\n",
16
+ "* This is \"NoGAN\" based training, described in the DeOldify readme."
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "#NOTE: This must be the first call in order to work properly!\n",
26
+ "from deoldify import device\n",
27
+ "from deoldify.device_id import DeviceId\n",
28
+ "#choices: CPU, GPU0...GPU7\n",
29
+ "device.set(device=DeviceId.GPU0)"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import os\n",
39
+ "import fastai\n",
40
+ "from fastai import *\n",
41
+ "from fastai.vision import *\n",
42
+ "from fastai.callbacks.tensorboard import *\n",
43
+ "from fastai.vision.gan import *\n",
44
+ "from deoldify.generators import *\n",
45
+ "from deoldify.critics import *\n",
46
+ "from deoldify.dataset import *\n",
47
+ "from deoldify.loss import *\n",
48
+ "from deoldify.save import *\n",
49
+ "from deoldify.augs import noisify \n",
50
+ "from PIL import Image, ImageDraw, ImageFont\n",
51
+ "from PIL import ImageFile"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "markdown",
56
+ "metadata": {},
57
+ "source": [
58
+ "## Setup"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
68
+ "path_hr = path\n",
69
+ "path_lr = path/'bandw'\n",
70
+ "\n",
71
+ "proj_id = 'VideoModel'\n",
72
+ "gen_name = proj_id + '_gen'\n",
73
+ "pre_gen_name = gen_name + '_0'\n",
74
+ "crit_name = proj_id + '_crit'\n",
75
+ "\n",
76
+ "name_gen = proj_id + '_image_gen'\n",
77
+ "path_gen = path/name_gen\n",
78
+ "\n",
79
+ "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
80
+ "\n",
81
+ "nf_factor = 2\n",
82
+ "xtra_tfms=[noisify(p=0.8)]\n",
83
+ "pct_start = 1e-8"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "def get_data(bs:int, sz:int, keep_pct:float):\n",
93
+ " return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
94
+ " random_seed=None, keep_pct=keep_pct, xtra_tfms=xtra_tfms)\n",
95
+ "\n",
96
+ "def get_crit_data(classes, bs, sz):\n",
97
+ " src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\n",
98
+ " ll = src.label_from_folder(classes=classes)\n",
99
+ " data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
100
+ " .databunch(bs=bs).normalize(imagenet_stats))\n",
101
+ " return data\n",
102
+ "\n",
103
+ "def create_training_images(fn,i):\n",
104
+ " dest = path_lr/fn.relative_to(path_hr)\n",
105
+ " dest.parent.mkdir(parents=True, exist_ok=True)\n",
106
+ " img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
107
+ " img.save(dest) \n",
108
+ " \n",
109
+ "def save_preds(dl):\n",
110
+ " i=0\n",
111
+ " names = dl.dataset.items\n",
112
+ " \n",
113
+ " for b in dl:\n",
114
+ " preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
115
+ " for o in preds:\n",
116
+ " o.save(path_gen/names[i].name)\n",
117
+ " i += 1\n",
118
+ " \n",
119
+ "def save_gen_images():\n",
120
+ " if path_gen.exists(): shutil.rmtree(path_gen)\n",
121
+ " path_gen.mkdir(exist_ok=True)\n",
122
+ " data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n",
123
+ " save_preds(data_gen.fix_dl)\n",
124
+ " PIL.Image.open(path_gen.ls()[0])"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "markdown",
129
+ "metadata": {},
130
+ "source": [
131
+ "## Create black and white training images"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "markdown",
136
+ "metadata": {},
137
+ "source": [
138
+ "Only runs if the directory isn't already created."
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "if not path_lr.exists():\n",
148
+ " il = ImageList.from_folder(path_hr)\n",
149
+ " parallel(create_training_images, il.items)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "markdown",
154
+ "metadata": {},
155
+ "source": [
156
+ "## Finetune Generator With Noise Augmented Images."
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "metadata": {},
162
+ "source": [
163
+ "##### This helps the generator better deal with noisy/grainy video (which is pretty normal)."
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "metadata": {},
170
+ "outputs": [],
171
+ "source": [
172
+ "bs=8\n",
173
+ "sz=192\n",
174
+ "keep_pct=0.25"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "metadata": {},
181
+ "outputs": [],
182
+ "source": [
183
+ "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "learn_gen = learn_gen.load(pre_gen_name, with_opt=False)"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": null,
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "learn_gen.unfreeze()"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": null,
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "metadata": {},
235
+ "outputs": [],
236
+ "source": [
237
+ "learn_gen.save(pre_gen_name)"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "markdown",
242
+ "metadata": {},
243
+ "source": [
244
+ "## Repeatable GAN Cycle"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "markdown",
249
+ "metadata": {},
250
+ "source": [
251
+ "#### NOTE\n",
252
+ "Best results so far have been based only doing a single run of the cells below (otherwise glitches are introduced that are visible in video). "
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "old_checkpoint_num = 0\n",
262
+ "checkpoint_num = old_checkpoint_num + 1\n",
263
+ "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
264
+ "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
265
+ "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
266
+ "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "markdown",
271
+ "metadata": {},
272
+ "source": [
273
+ "### Save Generated Images"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "bs=8\n",
283
+ "sz=192"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": null,
289
+ "metadata": {},
290
+ "outputs": [],
291
+ "source": [
292
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": null,
298
+ "metadata": {},
299
+ "outputs": [],
300
+ "source": [
301
+ "save_gen_images()"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "markdown",
306
+ "metadata": {},
307
+ "source": [
308
+ "### Pretrain Critic"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "code",
313
+ "execution_count": null,
314
+ "metadata": {},
315
+ "outputs": [],
316
+ "source": [
317
+ "bs=16\n",
318
+ "sz=192"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": null,
324
+ "metadata": {},
325
+ "outputs": [],
326
+ "source": [
327
+ "learn_gen=None\n",
328
+ "gc.collect()"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": null,
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": null,
343
+ "metadata": {},
344
+ "outputs": [],
345
+ "source": [
346
+ "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": null,
352
+ "metadata": {},
353
+ "outputs": [],
354
+ "source": [
355
+ "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "metadata": {},
362
+ "outputs": [],
363
+ "source": [
364
+ "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": null,
370
+ "metadata": {},
371
+ "outputs": [],
372
+ "source": [
373
+ "learn_critic.fit_one_cycle(4, 1e-4)"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "code",
378
+ "execution_count": null,
379
+ "metadata": {},
380
+ "outputs": [],
381
+ "source": [
382
+ "learn_critic.save(crit_new_checkpoint_name)"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "markdown",
387
+ "metadata": {},
388
+ "source": [
389
+ "### GAN"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": null,
395
+ "metadata": {},
396
+ "outputs": [],
397
+ "source": [
398
+ "learn_crit=None\n",
399
+ "learn_gen=None\n",
400
+ "gc.collect()"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "metadata": {},
407
+ "outputs": [],
408
+ "source": [
409
+ "lr=5e-6\n",
410
+ "sz=192\n",
411
+ "bs=5"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": null,
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": [
420
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "metadata": {},
427
+ "outputs": [],
428
+ "source": [
429
+ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "metadata": {},
436
+ "outputs": [],
437
+ "source": [
438
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
448
+ "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
449
+ " opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
450
+ "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
451
+ "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100, stats_iters=10, loss_iters=1))\n",
452
+ "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "markdown",
457
+ "metadata": {},
458
+ "source": [
459
+ "#### Instructions: \n",
460
+ "Find the checkpoint just before where glitches start to be introduced. So far this has been found at the point of iterating through 1.4% of the data when using learning rate of 1e-5, and at 2.2% of the data for 5e-6."
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": null,
466
+ "metadata": {},
467
+ "outputs": [],
468
+ "source": [
469
+ "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
470
+ "learn_gen.freeze_to(-1)\n",
471
+ "learn.fit(1,lr)"
472
+ ]
473
+ }
474
+ ],
475
+ "metadata": {
476
+ "kernelspec": {
477
+ "display_name": "Python 3",
478
+ "language": "python",
479
+ "name": "python3"
480
+ },
481
+ "language_info": {
482
+ "codemirror_mode": {
483
+ "name": "ipython",
484
+ "version": 3
485
+ },
486
+ "file_extension": ".py",
487
+ "mimetype": "text/x-python",
488
+ "name": "python",
489
+ "nbconvert_exporter": "python",
490
+ "pygments_lexer": "ipython3",
491
+ "version": "3.7.6"
492
+ }
493
+ },
494
+ "nbformat": 4,
495
+ "nbformat_minor": 4
496
+ }
ColorizeTrainingWandb.ipynb ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Stable Model Training with monitoring through Weights & Biases"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "#### NOTES: \n",
15
+ "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
16
+ "* This model prioritizes stable and reliable renderings. It does particularly well on portraits and landscapes. It's not as colorful as the artistic model.\n",
17
+ "* Training with this notebook has been logged and monitored through [Weights & Biases](https://www.wandb.com/). Refer to [W&B Report](https://app.wandb.ai/borisd13/DeOldify/reports?view=borisd13%2FDeOldify).\n",
18
+ "* It is **highly** recommended to use a 11 Go GPU to run this notebook. Anything lower will require to reduce the batch size (leading to moro instability) or use of \"Large Model Support\" from IBM WML-CE (not so easy to setup). An alternative is to rent ressources online."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "# Install W&B Callback\n",
28
+ "#!pip install wandb"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "#NOTE: This must be the first call in order to work properly!\n",
38
+ "from deoldify import device\n",
39
+ "from deoldify.device_id import DeviceId\n",
40
+ "#choices: CPU, GPU0...GPU7\n",
41
+ "device.set(device=DeviceId.GPU0)"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "import os\n",
51
+ "import fastai\n",
52
+ "from fastai import *\n",
53
+ "from fastai.vision import *\n",
54
+ "from fastai.vision.gan import *\n",
55
+ "from deoldify.generators import *\n",
56
+ "from deoldify.critics import *\n",
57
+ "from deoldify.dataset import *\n",
58
+ "from deoldify.loss import *\n",
59
+ "from deoldify.save import *\n",
60
+ "from PIL import Image, ImageDraw, ImageFont\n",
61
+ "from PIL import ImageFile\n",
62
+ "from torch.utils.data.sampler import RandomSampler, SequentialSampler\n",
63
+ "from tqdm import tqdm\n",
64
+ "import wandb\n",
65
+ "from wandb.fastai import WandbCallback"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {},
71
+ "source": [
72
+ "## Setup"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "# Set up W&B: checks user can connect to W&B servers\n",
82
+ "# Note: set up API key the first time\n",
83
+ "wandb.login()"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "# Dataset can be downloaded from https://www.kaggle.com/c/imagenet-object-localization-challenge/data\n",
93
+ "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
94
+ "path_hr = path\n",
95
+ "path_lr = path/'bandw'\n",
96
+ "\n",
97
+ "proj_id = 'StableModel'\n",
98
+ "\n",
99
+ "gen_name = proj_id + '_gen'\n",
100
+ "pre_gen_name = gen_name + '_0'\n",
101
+ "crit_name = proj_id + '_crit'\n",
102
+ "\n",
103
+ "name_gen = proj_id + '_image_gen'\n",
104
+ "path_gen = path/name_gen\n",
105
+ "\n",
106
+ "nf_factor = 2\n",
107
+ "pct_start = 1e-8"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "markdown",
112
+ "metadata": {},
113
+ "source": [
114
+ "## Iterating through the dataset\n",
115
+ "\n",
116
+ "The dataset is very large and it would take a long time to iterate through all the samples at each epoch.\n",
117
+ "\n",
118
+ "We use custom samplers in order to limit epochs to subsets of data while still iterating slowly through the entire dataset (epoch after epoch). This let us run the validation loop more often where we log metrics as well as prediction samples on validation data."
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "# Reduce quantity of samples per training epoch\n",
128
+ "# Adapted from https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10\n",
129
+ "\n",
130
+ "@classmethod\n",
131
+ "def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,\n",
132
+ " val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,\n",
133
+ " device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, sampler=None, **dl_kwargs)->'DataBunch':\n",
134
+ " \"Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. Passes `**dl_kwargs` to `DataLoader()`\"\n",
135
+ " datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n",
136
+ " val_bs = ifnone(val_bs, bs)\n",
137
+ " if sampler is None: sampler = [RandomSampler] + 3*[SequentialSampler]\n",
138
+ " dls = [DataLoader(d, b, sampler=sa(d), drop_last=sh, num_workers=num_workers, **dl_kwargs) for d,b,sh,sa in\n",
139
+ " zip(datasets, (bs,val_bs,val_bs,val_bs), (True,False,False,False), sampler) if d is not None]\n",
140
+ " return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)\n",
141
+ "\n",
142
+ "ImageDataBunch.create = create\n",
143
+ "ImageImageList._bunch = ImageDataBunch\n",
144
+ "\n",
145
+ "class FixedLenRandomSampler(RandomSampler):\n",
146
+ " def __init__(self, data_source, epoch_size):\n",
147
+ " super().__init__(data_source)\n",
148
+ " self.epoch_size = epoch_size\n",
149
+ " self.not_sampled = np.array([True]*len(data_source))\n",
150
+ " \n",
151
+ " @property\n",
152
+ " def reset_state(self): self.not_sampled[:] = True\n",
153
+ " \n",
154
+ " def __iter__(self):\n",
155
+ " ns = sum(self.not_sampled)\n",
156
+ " idx_last = []\n",
157
+ " if ns >= len(self):\n",
158
+ " idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self), replace=False).tolist()\n",
159
+ " if ns == len(self): self.reset_state\n",
160
+ " else:\n",
161
+ " idx_last = np.where(self.not_sampled)[0].tolist()\n",
162
+ " self.reset_state\n",
163
+ " idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self)-len(idx_last), replace=False).tolist()\n",
164
+ " self.not_sampled[idx] = False\n",
165
+ " idx = [*idx_last, *idx]\n",
166
+ " return iter(idx)\n",
167
+ " \n",
168
+ " def __len__(self):\n",
169
+ " return self.epoch_size"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "def get_data(bs:int, sz:int, keep_pct=1.0, random_seed=None, valid_pct=0.2, epoch_size=1000):\n",
179
+ " \n",
180
+ " # Create samplers\n",
181
+ " train_sampler = partial(FixedLenRandomSampler, epoch_size=epoch_size)\n",
182
+ " samplers = [train_sampler, SequentialSampler, SequentialSampler, SequentialSampler]\n",
183
+ "\n",
184
+ " return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, random_seed=random_seed,\n",
185
+ " keep_pct=keep_pct, samplers=samplers, valid_pct=valid_pct)\n",
186
+ "\n",
187
+ "# Function modified to allow use of custom samplers\n",
188
+ "def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_seed:int=None,\n",
189
+ " keep_pct:float=1.0, num_workers:int=8, samplers=None, valid_pct=0.2, xtra_tfms=[])->ImageDataBunch:\n",
190
+ " src = (ImageImageList.from_folder(crappy_path, convert_mode='RGB')\n",
191
+ " .use_partial_data(sample_pct=keep_pct, seed=random_seed)\n",
192
+ " .split_by_rand_pct(valid_pct, seed=random_seed))\n",
193
+ " data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))\n",
194
+ " .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms), size=sz, tfm_y=True)\n",
195
+ " .databunch(bs=bs, num_workers=num_workers, sampler=samplers, no_check=True)\n",
196
+ " .normalize(imagenet_stats, do_y=True))\n",
197
+ " data.c = 3\n",
198
+ " return data\n",
199
+ "\n",
200
+ "# Function to limit amount of data in critic\n",
201
+ "def filter_data(pct=1.0):\n",
202
+ " def _f(fname):\n",
203
+ " if 'test' in str(fname):\n",
204
+ " if np.random.random_sample() > pct:\n",
205
+ " return False\n",
206
+ " return True\n",
207
+ " return _f\n",
208
+ "\n",
209
+ "def get_crit_data(classes, bs, sz, pct=1.0):\n",
210
+ " src = ImageList.from_folder(path, include=classes, recurse=True).filter_by_func(filter_data(pct)).split_by_rand_pct(0.1)\n",
211
+ " ll = src.label_from_folder(classes=classes)\n",
212
+ " data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
213
+ " .databunch(bs=bs).normalize(imagenet_stats))\n",
214
+ " return data\n",
215
+ "\n",
216
+ "def create_training_images(fn,i):\n",
217
+ " dest = path_lr/fn.relative_to(path_hr)\n",
218
+ " dest.parent.mkdir(parents=True, exist_ok=True)\n",
219
+ " img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
220
+ " img.save(dest) \n",
221
+ " \n",
222
+ "def save_preds(dl):\n",
223
+ " i=0\n",
224
+ " names = dl.dataset.items \n",
225
+ " for b in tqdm(dl):\n",
226
+ " preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
227
+ " for o in preds:\n",
228
+ " o.save(path_gen/names[i].name)\n",
229
+ " i += 1\n",
230
+ " \n",
231
+ "def save_gen_images(keep_pct):\n",
232
+ " if path_gen.exists(): shutil.rmtree(path_gen)\n",
233
+ " path_gen.mkdir(exist_ok=True)\n",
234
+ " data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\n",
235
+ " save_preds(data_gen.fix_dl)"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "markdown",
240
+ "metadata": {},
241
+ "source": [
242
+ "## Create black and white training images"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "markdown",
247
+ "metadata": {},
248
+ "source": [
249
+ "Only runs if the directory isn't already created."
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": null,
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "if not path_lr.exists():\n",
259
+ " il = ImageList.from_folder(path_hr)\n",
260
+ " parallel(create_training_images, il.items)"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": [
269
+ "# Number of black & white images\n",
270
+ "data_size = len(list(path_lr.rglob('*.*')))\n",
271
+ "print('Number of black & white images:', data_size)"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "markdown",
276
+ "metadata": {},
277
+ "source": [
278
+ "## Pre-train generator"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "markdown",
283
+ "metadata": {},
284
+ "source": [
285
+ "#### NOTE\n",
286
+ "Most of the training takes place here in pretraining for NoGAN. The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training."
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "markdown",
291
+ "metadata": {},
292
+ "source": [
293
+ "### 64px"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": null,
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "# Init logging of a new run\n",
303
+ "wandb.init(tags=['Pre-train Gen']) # tags are optional"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": null,
309
+ "metadata": {},
310
+ "outputs": [],
311
+ "source": [
312
+ "bs=88\n",
313
+ "sz=64\n",
314
+ "\n",
315
+ "# Define target number of training/validation samples as well as number of epochs\n",
316
+ "epoch_train_size = 100 * bs\n",
317
+ "epoch_valid_size = 10 * bs\n",
318
+ "valid_pct = epoch_valid_size / data_size\n",
319
+ "number_epochs = (data_size - epoch_valid_size) // epoch_train_size\n",
320
+ "\n",
321
+ "# Log hyper parameters\n",
322
+ "wandb.config.update({\"Step 1 - batch size\": bs, \"Step 1 - image size\": sz,\n",
323
+ " \"Step 1 - epoch size\": epoch_train_size, \"Step 1 - number epochs\": number_epochs})"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": null,
329
+ "metadata": {},
330
+ "outputs": [],
331
+ "source": [
332
+ "data_gen = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "execution_count": null,
338
+ "metadata": {},
339
+ "outputs": [],
340
+ "source": [
341
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "learn_gen.callback_fns.append(partial(WandbCallback,\n",
351
+ " input_type='images')) # log prediction samples"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": null,
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "learn_gen.fit_one_cycle(number_epochs, pct_start=0.8, max_lr=slice(1e-3))"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "metadata": {},
367
+ "outputs": [],
368
+ "source": [
369
+ "learn_gen.save(pre_gen_name)"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": null,
375
+ "metadata": {},
376
+ "outputs": [],
377
+ "source": [
378
+ "learn_gen.unfreeze()"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": null,
384
+ "metadata": {},
385
+ "outputs": [],
386
+ "source": [
387
+ "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(3e-7, 3e-4))"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "learn_gen.save(pre_gen_name)"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "markdown",
401
+ "metadata": {},
402
+ "source": [
403
+ "### 128px"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": [
412
+ "bs=20\n",
413
+ "sz=128\n",
414
+ "\n",
415
+ "# Define target number of training/validation samples as well as number of epochs\n",
416
+ "epoch_train_size = 100 * bs\n",
417
+ "epoch_valid_size = 10 * bs\n",
418
+ "valid_pct = epoch_valid_size / data_size\n",
419
+ "number_epochs = (data_size - epoch_valid_size) // epoch_train_size\n",
420
+ "\n",
421
+ "# Log hyper parameters\n",
422
+ "wandb.config.update({\"Step 2 - batch size\": bs, \"Step 2 - image size\": sz,\n",
423
+ " \"Step 2 - epoch size\": epoch_train_size, \"Step 2 - number epochs\": number_epochs})"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": null,
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "metadata": {},
439
+ "outputs": [],
440
+ "source": [
441
+ "learn_gen.unfreeze()"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "metadata": {},
448
+ "outputs": [],
449
+ "source": [
450
+ "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(1e-7,1e-4))"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": null,
456
+ "metadata": {},
457
+ "outputs": [],
458
+ "source": [
459
+ "learn_gen.save(pre_gen_name)"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "markdown",
464
+ "metadata": {},
465
+ "source": [
466
+ "### 192px"
467
+ ]
468
+ },
469
+ {
470
+ "cell_type": "code",
471
+ "execution_count": null,
472
+ "metadata": {},
473
+ "outputs": [],
474
+ "source": [
475
+ "bs=8\n",
476
+ "sz=192\n",
477
+ "\n",
478
+ "# Define target number of training/validation samples as well as number of epochs\n",
479
+ "epoch_train_size = 100 * bs\n",
480
+ "epoch_valid_size = 10 * bs\n",
481
+ "valid_pct = epoch_valid_size / data_size\n",
482
+ "number_epochs = (data_size - epoch_valid_size) // epoch_train_size // 2 # Training is long - we use half of data\n",
483
+ "\n",
484
+ "# Log hyper parameters\n",
485
+ "wandb.config.update({\"Step 3 - batch size\": bs, \"Step 3 - image size\": sz,\n",
486
+ " \"Step 3 - epoch size\": epoch_train_size, \"Step 3 - number epochs\": number_epochs})"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "metadata": {},
493
+ "outputs": [],
494
+ "source": [
495
+ "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": null,
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": [
504
+ "learn_gen.unfreeze()"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "metadata": {},
511
+ "outputs": [],
512
+ "source": [
513
+ "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "metadata": {},
520
+ "outputs": [],
521
+ "source": [
522
+ "learn_gen.save(pre_gen_name)"
523
+ ]
524
+ },
525
+ {
526
+ "cell_type": "code",
527
+ "execution_count": null,
528
+ "metadata": {},
529
+ "outputs": [],
530
+ "source": [
531
+ "# End logging of current session run\n",
532
+ "# Note: this is optional and would be automatically triggered when stopping the kernel\n",
533
+ "wandb.join()"
534
+ ]
535
+ },
536
+ {
537
+ "cell_type": "markdown",
538
+ "metadata": {},
539
+ "source": [
540
+ "## Repeatable GAN Cycle"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "markdown",
545
+ "metadata": {},
546
+ "source": [
547
+ "#### NOTE\n",
548
+ "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality). Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old. "
549
+ ]
550
+ },
551
+ {
552
+ "cell_type": "code",
553
+ "execution_count": null,
554
+ "metadata": {},
555
+ "outputs": [],
556
+ "source": [
557
+ "old_checkpoint_num = 0\n",
558
+ "checkpoint_num = old_checkpoint_num + 1\n",
559
+ "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
560
+ "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
561
+ "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
562
+ "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "markdown",
567
+ "metadata": {},
568
+ "source": [
569
+ "### Save Generated Images"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": [
578
+ "bs=8\n",
579
+ "sz=192\n",
580
+ "\n",
581
+ "# Define target number of training/validation samples as well as number of epochs\n",
582
+ "epoch_train_size = 100 * bs\n",
583
+ "epoch_valid_size = 10 * bs\n",
584
+ "valid_pct = epoch_valid_size / data_size\n",
585
+ "number_epochs = (data_size - epoch_valid_size) // epoch_train_size"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "execution_count": null,
591
+ "metadata": {},
592
+ "outputs": [],
593
+ "source": [
594
+ "data_gen = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "execution_count": null,
600
+ "metadata": {},
601
+ "outputs": [],
602
+ "source": [
603
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "code",
608
+ "execution_count": null,
609
+ "metadata": {},
610
+ "outputs": [],
611
+ "source": [
612
+ "save_gen_images(0.1)"
613
+ ]
614
+ },
615
+ {
616
+ "cell_type": "markdown",
617
+ "metadata": {},
618
+ "source": [
619
+ "### Pretrain Critic"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "markdown",
624
+ "metadata": {},
625
+ "source": [
626
+ "##### Only need full pretraining of critic when starting from scratch. Otherwise, just finetune!"
627
+ ]
628
+ },
629
+ {
630
+ "cell_type": "code",
631
+ "execution_count": null,
632
+ "metadata": {},
633
+ "outputs": [],
634
+ "source": [
635
+ "if old_checkpoint_num == 0:\n",
636
+ " \n",
637
+ " # Init logging of a new run\n",
638
+ " wandb.init(tags=['Pre-train Crit']) # tags are optional\n",
639
+ " \n",
640
+ " bs=64\n",
641
+ " sz=128\n",
642
+ " learn_gen=None\n",
643
+ " \n",
644
+ " # Log hyper parameters\n",
645
+ " wandb.config.update({\"Step 1 - batch size\": bs, \"Step 1 - image size\": sz})\n",
646
+ "\n",
647
+ " gc.collect() \n",
648
+ " data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
649
+ " data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
650
+ " learn_crit = colorize_crit_learner(data=data_crit, nf=256)\n",
651
+ " learn_crit.callback_fns.append(partial(WandbCallback)) # log prediction samples\n",
652
+ " learn_crit.fit_one_cycle(6, 1e-3)\n",
653
+ " learn_crit.save(crit_old_checkpoint_name)"
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "code",
658
+ "execution_count": null,
659
+ "metadata": {},
660
+ "outputs": [],
661
+ "source": [
662
+ "bs=16\n",
663
+ "sz=192\n",
664
+ "\n",
665
+ "# Log hyper parameters\n",
666
+ "wandb.config.update({\"Step 2 - batch size\": bs, \"Step 2 - image size\": sz})"
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "execution_count": null,
672
+ "metadata": {},
673
+ "outputs": [],
674
+ "source": [
675
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
676
+ ]
677
+ },
678
+ {
679
+ "cell_type": "code",
680
+ "execution_count": null,
681
+ "metadata": {},
682
+ "outputs": [],
683
+ "source": [
684
+ "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": null,
690
+ "metadata": {},
691
+ "outputs": [],
692
+ "source": [
693
+ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
694
+ ]
695
+ },
696
+ {
697
+ "cell_type": "code",
698
+ "execution_count": null,
699
+ "metadata": {},
700
+ "outputs": [],
701
+ "source": [
702
+ "learn_crit.fit_one_cycle(4, 1e-4)"
703
+ ]
704
+ },
705
+ {
706
+ "cell_type": "code",
707
+ "execution_count": null,
708
+ "metadata": {},
709
+ "outputs": [],
710
+ "source": [
711
+ "learn_crit.save(crit_new_checkpoint_name)"
712
+ ]
713
+ },
714
+ {
715
+ "cell_type": "markdown",
716
+ "metadata": {},
717
+ "source": [
718
+ "### GAN"
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "code",
723
+ "execution_count": null,
724
+ "metadata": {},
725
+ "outputs": [],
726
+ "source": [
727
+ "# free up memory\n",
728
+ "learn_crit=None\n",
729
+ "learn_gen=None\n",
730
+ "learn=None\n",
731
+ "gc.collect()"
732
+ ]
733
+ },
734
+ {
735
+ "cell_type": "code",
736
+ "execution_count": null,
737
+ "metadata": {},
738
+ "outputs": [],
739
+ "source": [
740
+ "# Set old_checkpoint_num to last iteration\n",
741
+ "old_checkpoint_num = 0\n",
742
+ "save_checkpoints = False\n",
743
+ "batch_per_epoch = 200\n",
744
+ "\n",
745
+ "checkpoint_num = old_checkpoint_num + 1\n",
746
+ "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
747
+ "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
748
+ "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
749
+ "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num) \n",
750
+ "\n",
751
+ "if False: # need only to do it once\n",
752
+ " \n",
753
+ " # Generate data\n",
754
+ " print('Generating data…')\n",
755
+ " bs=8\n",
756
+ " sz=192\n",
757
+ " epoch_train_size = batch_per_epoch * bs\n",
758
+ " epoch_valid_size = batch_per_epoch * bs // 10\n",
759
+ " valid_pct = epoch_valid_size / data_size\n",
760
+ " data_gen = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
761
+ " learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\n",
762
+ " save_gen_images(0.02)\n",
763
+ "\n",
764
+ " # Pre-train critic\n",
765
+ " print('Pre-training critic…')\n",
766
+ " bs=16\n",
767
+ " sz=192\n",
768
+ "\n",
769
+ " len_test = len(list((path / 'test').rglob('*.*')))\n",
770
+ " len_gen = len(list((path / name_gen).rglob('*.*')))\n",
771
+ " keep_test_pct = len_gen / len_test * 2\n",
772
+ "\n",
773
+ " data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz, pct=keep_test_pct)\n",
774
+ " learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)\n",
775
+ " learn_crit.fit_one_cycle(1, 1e-4)\n",
776
+ " learn_crit.save(crit_new_checkpoint_name)\n",
777
+ "\n",
778
+ "# Creating GAN\n",
779
+ "print('Creating GAN…')\n",
780
+ "sz=192\n",
781
+ "bs=8\n",
782
+ "lr_GAN=2e-5\n",
783
+ "epoch_train_size = batch_per_epoch * bs\n",
784
+ "epoch_valid_size = batch_per_epoch * bs // 10\n",
785
+ "valid_pct = epoch_valid_size / data_size\n",
786
+ "len_test = len(list((path / 'test').rglob('*.*')))\n",
787
+ "len_gen = len(list((path / name_gen).rglob('*.*')))\n",
788
+ "keep_test_pct = len_gen / len_test * 2\n",
789
+ "\n",
790
+ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz, pct=keep_test_pct)\n",
791
+ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)\n",
792
+ "data_gen = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
793
+ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\n",
794
+ "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
795
+ "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
796
+ " opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
797
+ "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
798
+ "learn.callback_fns.append(partial(WandbCallback, input_type='images', seed=None, save_model=False))\n",
799
+ "learn.data = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
800
+ "\n",
801
+ "# Start logging to W&B\n",
802
+ "wandb.init(tags=['GAN'])\n",
803
+ "wandb.config.update({\"learning rate\": lr_GAN}) \n",
804
+ "\n",
805
+ "# Run the loop until satisfied with the results\n",
806
+ "while True:\n",
807
+ "\n",
808
+ " # Current loop\n",
809
+ " checkpoint_num = old_checkpoint_num + 1\n",
810
+ " gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
811
+ " gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
812
+ " crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
813
+ " crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num) \n",
814
+ " \n",
815
+ " \n",
816
+ " # GAN for 10 epochs between each checkpoint\n",
817
+ " try:\n",
818
+ " learn.fit(1, lr_GAN)\n",
819
+ " except:\n",
820
+ " # Sometimes we get an error for some unknown reason during callbacks\n",
821
+ " learn.callback_fns[-1](learn).on_epoch_end(old_checkpoint_num, None, [])\n",
822
+ " \n",
823
+ " if save_checkpoints:\n",
824
+ " learn_crit.save(crit_new_checkpoint_name)\n",
825
+ " learn_gen.save(gen_new_checkpoint_name)\n",
826
+ " \n",
827
+ " old_checkpoint_num += 1"
828
+ ]
829
+ },
830
+ {
831
+ "cell_type": "code",
832
+ "execution_count": null,
833
+ "metadata": {},
834
+ "outputs": [],
835
+ "source": [
836
+ "# End logging of current session run\n",
837
+ "# Note: this is optional and would be automatically triggered when stopping the kernel\n",
838
+ "wandb.join()"
839
+ ]
840
+ },
841
+ {
842
+ "cell_type": "code",
843
+ "execution_count": null,
844
+ "metadata": {},
845
+ "outputs": [],
846
+ "source": []
847
+ }
848
+ ],
849
+ "metadata": {
850
+ "kernelspec": {
851
+ "display_name": "Python 3",
852
+ "language": "python",
853
+ "name": "python3"
854
+ },
855
+ "language_info": {
856
+ "codemirror_mode": {
857
+ "name": "ipython",
858
+ "version": 3
859
+ },
860
+ "file_extension": ".py",
861
+ "mimetype": "text/x-python",
862
+ "name": "python",
863
+ "nbconvert_exporter": "python",
864
+ "pygments_lexer": "ipython3",
865
+ "version": "3.7.6"
866
+ }
867
+ },
868
+ "nbformat": 4,
869
+ "nbformat_minor": 4
870
+ }
ImageColorizer.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
ImageColorizerArtisticTests.ipynb ADDED
@@ -0,0 +1,3319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "#NOTE: This must be the first call in order to work properly!\n",
10
+ "from deoldify import device\n",
11
+ "from deoldify.device_id import DeviceId\n",
12
+ "#choices: CPU, GPU0...GPU7\n",
13
+ "device.set(device=DeviceId.GPU0)"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "from deoldify.visualize import *\n",
23
+ "plt.style.use('dark_background')\n",
24
+ "import warnings\n",
25
+ "warnings.filterwarnings(\"ignore\", category=UserWarning, message=\".*?Your .*? set is empty.*?\")"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "#Adjust render_factor (int) if image doesn't look quite right (max 64 on 11GB GPU). The default here works for most photos. \n",
35
+ "#It literally just is a number multiplied by 16 to get the square render resolution. \n",
36
+ "#Note that this doesn't affect the resolution of the final output- the output is the same resolution as the input.\n",
37
+ "#Example: render_factor=21 => color is rendered at 16x21 = 336x336 px. \n",
38
+ "render_factor=35"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "vis = get_image_colorizer(render_factor=render_factor, artistic=True)"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "vis.plot_transformed_image(\"test_images/poolparty.jpg\", render_factor=38, compare=True)"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "vis.plot_transformed_image(\"test_images/1852GatekeepersWindsor.jpg\", render_factor=45, compare=True)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "vis.plot_transformed_image(\"test_images/Chief.jpg\", render_factor=14, compare=True)"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "vis.plot_transformed_image(\"test_images/1850SchoolForGirls.jpg\", render_factor=46, compare=True)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "vis.plot_transformed_image(\"test_images/AtlanticCityBeach1905.jpg\", render_factor=30, compare=True)"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "vis.plot_transformed_image(\"test_images/CottonMillWorkers1913.jpg\", render_factor=45, compare=True)"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "vis.plot_transformed_image(\"test_images/BrooklynNavyYardHospital.jpg\", compare=True)"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": null,
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "vis.plot_transformed_image(\"test_images/FinnishPeasant1867.jpg\", render_factor=30, compare=True)"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "vis.plot_transformed_image(\"test_images/AtlanticCity1905.png\", render_factor=25, compare=True)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "vis.plot_transformed_image(\"test_images/PushingCart.jpg\", render_factor=21, compare=True)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "vis.plot_transformed_image(\"test_images/Drive1905.jpg\", compare=True)"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "vis.plot_transformed_image(\"test_images/IronLung.png\", render_factor=21, compare=True)"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "vis.plot_transformed_image(\"test_images/FamilyWithDog.jpg\", render_factor=21, compare=True)"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "vis.plot_transformed_image(\"test_images/DayAtSeaBelgium.jpg\", render_factor=30, compare=True)"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "vis.plot_transformed_image(\"test_images/marilyn_woods.jpg\", render_factor=29, compare=True)"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": [
191
+ "vis.plot_transformed_image(\"test_images/OldWomanSweden1904.jpg\", render_factor=36, compare=True)"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "metadata": {},
198
+ "outputs": [],
199
+ "source": [
200
+ "vis.plot_transformed_image(\"test_images/WomenTapingPlanes.jpg\", render_factor=32, compare=True)"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": [
209
+ "vis.plot_transformed_image(\"test_images/overmiller.jpg\", render_factor=13, compare=True)"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "vis.plot_transformed_image(\"test_images/BritishDispatchRider.jpg\", render_factor=19, compare=True)"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "vis.plot_transformed_image(\"test_images/MuseauNacionalDosCoches.jpg\", render_factor=17, compare=True)"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": [
236
+ "vis.plot_transformed_image(\"test_images/abe.jpg\", render_factor=15, compare=True)"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "vis.plot_transformed_image(\"test_images/RossCorbettHouseCork.jpg\", render_factor=30, compare=True)"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "metadata": {},
252
+ "outputs": [],
253
+ "source": [
254
+ "vis.plot_transformed_image(\"test_images/HPLabelleOfficeMontreal.jpg\", render_factor=40, compare=True)"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "vis.plot_transformed_image(\"test_images/einstein_beach.jpg\", render_factor=29, compare=True)"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "execution_count": null,
269
+ "metadata": {},
270
+ "outputs": [],
271
+ "source": [
272
+ "vis.plot_transformed_image(\"test_images/airmen1943.jpg\", render_factor=25, compare=True)"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": null,
278
+ "metadata": {},
279
+ "outputs": [],
280
+ "source": [
281
+ "vis.plot_transformed_image(\"test_images/20sWoman.jpg\", render_factor=22, compare=True)"
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "code",
286
+ "execution_count": null,
287
+ "metadata": {},
288
+ "outputs": [],
289
+ "source": [
290
+ "vis.plot_transformed_image(\"test_images/egypt-1.jpg\", render_factor=15, compare=True)"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "metadata": {},
297
+ "outputs": [],
298
+ "source": [
299
+ "vis.plot_transformed_image(\"test_images/Rutherford_Hayes.jpg\", render_factor=15, compare=True)"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": null,
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "vis.plot_transformed_image(\"test_images/einstein_portrait.jpg\", render_factor=15, compare=True)"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "code",
313
+ "execution_count": null,
314
+ "metadata": {},
315
+ "outputs": [],
316
+ "source": [
317
+ "vis.plot_transformed_image(\"test_images/pinkerton.jpg\", render_factor=13, compare=True)"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "vis.plot_transformed_image(\"test_images/WaltWhitman.jpg\", render_factor=12, compare=True)"
327
+ ]
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "execution_count": null,
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "vis.plot_transformed_image(\"test_images/dorothea-lange.jpg\", render_factor=25, compare=True)"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": null,
341
+ "metadata": {},
342
+ "outputs": [],
343
+ "source": [
344
+ "vis.plot_transformed_image(\"test_images/Hemmingway2.jpg\", render_factor=15, compare=True)"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": null,
350
+ "metadata": {},
351
+ "outputs": [],
352
+ "source": [
353
+ "vis.plot_transformed_image(\"test_images/hemmingway.jpg\", render_factor=9, compare=True)"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": [
362
+ "vis.plot_transformed_image(\"test_images/smoking_kid.jpg\", render_factor=30, compare=True)"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "vis.plot_transformed_image(\"test_images/teddy_rubble.jpg\", render_factor=45, compare=True)"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": null,
377
+ "metadata": {},
378
+ "outputs": [],
379
+ "source": [
380
+ "vis.plot_transformed_image(\"test_images/dustbowl_2.jpg\", render_factor=16, compare=True)"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": null,
386
+ "metadata": {},
387
+ "outputs": [],
388
+ "source": [
389
+ "vis.plot_transformed_image(\"test_images/camera_man.jpg\", render_factor=23, compare=True)"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": null,
395
+ "metadata": {},
396
+ "outputs": [],
397
+ "source": [
398
+ "vis.plot_transformed_image(\"test_images/migrant_mother.jpg\", render_factor=35, compare=True)"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "metadata": {},
405
+ "outputs": [],
406
+ "source": [
407
+ "vis.plot_transformed_image(\"test_images/marktwain.jpg\", render_factor=10, compare=True)"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": null,
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": [
416
+ "vis.plot_transformed_image(\"test_images/HelenKeller.jpg\", render_factor=45, compare=True)"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "vis.plot_transformed_image(\"test_images/Evelyn_Nesbit.jpg\", render_factor=21, compare=True)"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": null,
431
+ "metadata": {},
432
+ "outputs": [],
433
+ "source": [
434
+ "vis.plot_transformed_image(\"test_images/Eddie-Adams.jpg\", render_factor=22, compare=True)"
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "execution_count": null,
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "vis.plot_transformed_image(\"test_images/soldier_kids.jpg\", render_factor=18, compare=True)"
444
+ ]
445
+ },
446
+ {
447
+ "cell_type": "code",
448
+ "execution_count": null,
449
+ "metadata": {},
450
+ "outputs": [],
451
+ "source": [
452
+ "vis.plot_transformed_image(\"test_images/AnselAdamsYosemite.jpg\", compare=True)"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "code",
457
+ "execution_count": null,
458
+ "metadata": {},
459
+ "outputs": [],
460
+ "source": [
461
+ "vis.plot_transformed_image(\"test_images/unnamed.jpg\", render_factor=40, compare=True)"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "code",
466
+ "execution_count": null,
467
+ "metadata": {},
468
+ "outputs": [],
469
+ "source": [
470
+ "vis.plot_transformed_image(\"test_images/workers_canyon.jpg\", render_factor=48, compare=True)"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {},
477
+ "outputs": [],
478
+ "source": [
479
+ "vis.plot_transformed_image(\"test_images/CottonMill.jpg\", render_factor=16, compare=True)"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": null,
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": [
488
+ "vis.plot_transformed_image(\"test_images/JudyGarland.jpeg\", render_factor=25, compare=True)"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "metadata": {},
495
+ "outputs": [],
496
+ "source": [
497
+ "vis.plot_transformed_image(\"test_images/kids_pit.jpg\", render_factor=35, compare=True)"
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "execution_count": null,
503
+ "metadata": {},
504
+ "outputs": [],
505
+ "source": [
506
+ "vis.plot_transformed_image(\"test_images/last_samurai.jpg\", render_factor=15, compare=True)"
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "code",
511
+ "execution_count": null,
512
+ "metadata": {},
513
+ "outputs": [],
514
+ "source": [
515
+ "vis.plot_transformed_image(\"test_images/AnselAdamsWhiteChurch.jpg\", render_factor=21, compare=True)"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": null,
521
+ "metadata": {},
522
+ "outputs": [],
523
+ "source": [
524
+ "vis.plot_transformed_image(\"test_images/opium.jpg\", render_factor=30, compare=True)"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": null,
530
+ "metadata": {},
531
+ "outputs": [],
532
+ "source": [
533
+ "vis.plot_transformed_image(\"test_images/dorothea_lange_2.jpg\", render_factor=22, compare=True)"
534
+ ]
535
+ },
536
+ {
537
+ "cell_type": "code",
538
+ "execution_count": null,
539
+ "metadata": {},
540
+ "outputs": [],
541
+ "source": [
542
+ "vis.plot_transformed_image(\"test_images/rgs.jpg\", render_factor=46, compare=True)"
543
+ ]
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": null,
548
+ "metadata": {},
549
+ "outputs": [],
550
+ "source": [
551
+ "vis.plot_transformed_image(\"test_images/wh-auden.jpg\", render_factor=24, compare=True)"
552
+ ]
553
+ },
554
+ {
555
+ "cell_type": "code",
556
+ "execution_count": null,
557
+ "metadata": {},
558
+ "outputs": [],
559
+ "source": [
560
+ "vis.plot_transformed_image(\"test_images/w-b-yeats.jpg\", render_factor=16, compare=True)"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "metadata": {},
567
+ "outputs": [],
568
+ "source": [
569
+ "vis.plot_transformed_image(\"test_images/marilyn_portrait.jpg\", render_factor=30, compare=True)"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": [
578
+ "vis.plot_transformed_image(\"test_images/wilson-slaverevivalmeeting.jpg\", render_factor=38, compare=True)"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "code",
583
+ "execution_count": null,
584
+ "metadata": {},
585
+ "outputs": [],
586
+ "source": [
587
+ "vis.plot_transformed_image(\"test_images/ww1_trench.jpg\", render_factor=18, compare=True)"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": null,
593
+ "metadata": {},
594
+ "outputs": [],
595
+ "source": [
596
+ "vis.plot_transformed_image(\"test_images/women-bikers.png\", render_factor=47, compare=True)"
597
+ ]
598
+ },
599
+ {
600
+ "cell_type": "code",
601
+ "execution_count": null,
602
+ "metadata": {},
603
+ "outputs": [],
604
+ "source": [
605
+ "vis.plot_transformed_image(\"test_images/Unidentified1855.jpg\", render_factor=32, compare=True)"
606
+ ]
607
+ },
608
+ {
609
+ "cell_type": "code",
610
+ "execution_count": null,
611
+ "metadata": {},
612
+ "outputs": [],
613
+ "source": [
614
+ "vis.plot_transformed_image(\"test_images/skycrapper_lunch.jpg\", render_factor=32, compare=True)"
615
+ ]
616
+ },
617
+ {
618
+ "cell_type": "code",
619
+ "execution_count": null,
620
+ "metadata": {},
621
+ "outputs": [],
622
+ "source": [
623
+ "vis.plot_transformed_image(\"test_images/sioux.jpg\", render_factor=35, compare=True)"
624
+ ]
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "execution_count": null,
629
+ "metadata": {},
630
+ "outputs": [],
631
+ "source": [
632
+ "vis.plot_transformed_image(\"test_images/school_kids.jpg\", render_factor=26, compare=True)"
633
+ ]
634
+ },
635
+ {
636
+ "cell_type": "code",
637
+ "execution_count": null,
638
+ "metadata": {},
639
+ "outputs": [],
640
+ "source": [
641
+ "vis.plot_transformed_image(\"test_images/royal_family.jpg\", render_factor=33, compare=True)"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": null,
647
+ "metadata": {},
648
+ "outputs": [],
649
+ "source": [
650
+ "vis.plot_transformed_image(\"test_images/redwood_lumberjacks.jpg\", render_factor=47, compare=True)"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "code",
655
+ "execution_count": null,
656
+ "metadata": {},
657
+ "outputs": [],
658
+ "source": [
659
+ "vis.plot_transformed_image(\"test_images/poverty.jpg\", render_factor=26, compare=True)"
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "code",
664
+ "execution_count": null,
665
+ "metadata": {},
666
+ "outputs": [],
667
+ "source": [
668
+ "vis.plot_transformed_image(\"test_images/paperboy.jpg\", render_factor=40, compare=True)"
669
+ ]
670
+ },
671
+ {
672
+ "cell_type": "code",
673
+ "execution_count": null,
674
+ "metadata": {},
675
+ "outputs": [],
676
+ "source": [
677
+ "vis.plot_transformed_image(\"test_images/NativeAmericans.jpg\", render_factor=22, compare=True)"
678
+ ]
679
+ },
680
+ {
681
+ "cell_type": "code",
682
+ "execution_count": null,
683
+ "metadata": {},
684
+ "outputs": [],
685
+ "source": [
686
+ "vis.plot_transformed_image(\"test_images/helmut_newton-.jpg\", render_factor=43, compare=True)"
687
+ ]
688
+ },
689
+ {
690
+ "cell_type": "code",
691
+ "execution_count": null,
692
+ "metadata": {},
693
+ "outputs": [],
694
+ "source": [
695
+ "vis.plot_transformed_image(\"test_images/Greece1911.jpg\", render_factor=26, compare=True)"
696
+ ]
697
+ },
698
+ {
699
+ "cell_type": "code",
700
+ "execution_count": null,
701
+ "metadata": {},
702
+ "outputs": [],
703
+ "source": [
704
+ "vis.plot_transformed_image(\"test_images/FatMenClub.jpg\", render_factor=35, compare=True)"
705
+ ]
706
+ },
707
+ {
708
+ "cell_type": "code",
709
+ "execution_count": null,
710
+ "metadata": {},
711
+ "outputs": [],
712
+ "source": [
713
+ "vis.plot_transformed_image(\"test_images/EgyptColosus.jpg\", render_factor=35, compare=True)"
714
+ ]
715
+ },
716
+ {
717
+ "cell_type": "code",
718
+ "execution_count": null,
719
+ "metadata": {},
720
+ "outputs": [],
721
+ "source": [
722
+ "vis.plot_transformed_image(\"test_images/egypt-2.jpg\", render_factor=22, compare=True)"
723
+ ]
724
+ },
725
+ {
726
+ "cell_type": "code",
727
+ "execution_count": null,
728
+ "metadata": {},
729
+ "outputs": [],
730
+ "source": [
731
+ "vis.plot_transformed_image(\"test_images/dustbowl_sd.jpg\", render_factor=12, compare=True)"
732
+ ]
733
+ },
734
+ {
735
+ "cell_type": "code",
736
+ "execution_count": null,
737
+ "metadata": {},
738
+ "outputs": [],
739
+ "source": [
740
+ "vis.plot_transformed_image(\"test_images/dustbowl_people.jpg\", render_factor=24, compare=True)"
741
+ ]
742
+ },
743
+ {
744
+ "cell_type": "code",
745
+ "execution_count": null,
746
+ "metadata": {},
747
+ "outputs": [],
748
+ "source": [
749
+ "vis.plot_transformed_image(\"test_images/dustbowl_5.jpg\", render_factor=18, compare=True)"
750
+ ]
751
+ },
752
+ {
753
+ "cell_type": "code",
754
+ "execution_count": null,
755
+ "metadata": {},
756
+ "outputs": [],
757
+ "source": [
758
+ "vis.plot_transformed_image(\"test_images/dustbowl_1.jpg\", render_factor=15, compare=True)"
759
+ ]
760
+ },
761
+ {
762
+ "cell_type": "code",
763
+ "execution_count": null,
764
+ "metadata": {},
765
+ "outputs": [],
766
+ "source": [
767
+ "vis.plot_transformed_image(\"test_images/DriveThroughGiantTree.jpg\", render_factor=39, compare=True)"
768
+ ]
769
+ },
770
+ {
771
+ "cell_type": "code",
772
+ "execution_count": null,
773
+ "metadata": {},
774
+ "outputs": [],
775
+ "source": [
776
+ "vis.plot_transformed_image(\"test_images/covered-wagons-traveling.jpg\", render_factor=18, compare=True)"
777
+ ]
778
+ },
779
+ {
780
+ "cell_type": "code",
781
+ "execution_count": null,
782
+ "metadata": {},
783
+ "outputs": [],
784
+ "source": [
785
+ "vis.plot_transformed_image(\"test_images/civil-war_2.jpg\", render_factor=12, compare=True)"
786
+ ]
787
+ },
788
+ {
789
+ "cell_type": "code",
790
+ "execution_count": null,
791
+ "metadata": {},
792
+ "outputs": [],
793
+ "source": [
794
+ "vis.plot_transformed_image(\"test_images/civil_war_4.jpg\", render_factor=15, compare=True)"
795
+ ]
796
+ },
797
+ {
798
+ "cell_type": "code",
799
+ "execution_count": null,
800
+ "metadata": {},
801
+ "outputs": [],
802
+ "source": [
803
+ "vis.plot_transformed_image(\"test_images/civil_war_3.jpg\", render_factor=46, compare=True)"
804
+ ]
805
+ },
806
+ {
807
+ "cell_type": "code",
808
+ "execution_count": null,
809
+ "metadata": {},
810
+ "outputs": [],
811
+ "source": [
812
+ "vis.plot_transformed_image(\"test_images/civil_war.jpg\", render_factor=45, compare=True)"
813
+ ]
814
+ },
815
+ {
816
+ "cell_type": "code",
817
+ "execution_count": null,
818
+ "metadata": {},
819
+ "outputs": [],
820
+ "source": [
821
+ "vis.plot_transformed_image(\"test_images/BritishSlum.jpg\", render_factor=45, compare=True)"
822
+ ]
823
+ },
824
+ {
825
+ "cell_type": "code",
826
+ "execution_count": null,
827
+ "metadata": {},
828
+ "outputs": [],
829
+ "source": [
830
+ "vis.plot_transformed_image(\"test_images/bicycles.jpg\", render_factor=33, compare=True)"
831
+ ]
832
+ },
833
+ {
834
+ "cell_type": "code",
835
+ "execution_count": null,
836
+ "metadata": {},
837
+ "outputs": [],
838
+ "source": [
839
+ "vis.plot_transformed_image(\"test_images/brooklyn_girls_1940s.jpg\", render_factor=35, compare=True)"
840
+ ]
841
+ },
842
+ {
843
+ "cell_type": "code",
844
+ "execution_count": null,
845
+ "metadata": {},
846
+ "outputs": [],
847
+ "source": [
848
+ "vis.plot_transformed_image(\"test_images/40sCouple.jpg\", render_factor=20, compare=True)"
849
+ ]
850
+ },
851
+ {
852
+ "cell_type": "code",
853
+ "execution_count": null,
854
+ "metadata": {},
855
+ "outputs": [],
856
+ "source": [
857
+ "vis.plot_transformed_image(\"test_images/1946Wedding.jpg\", render_factor=30, compare=True)"
858
+ ]
859
+ },
860
+ {
861
+ "cell_type": "code",
862
+ "execution_count": null,
863
+ "metadata": {},
864
+ "outputs": [],
865
+ "source": [
866
+ "vis.plot_transformed_image(\"test_images/Dolores1920s.jpg\", render_factor=35, compare=True)"
867
+ ]
868
+ },
869
+ {
870
+ "cell_type": "code",
871
+ "execution_count": null,
872
+ "metadata": {},
873
+ "outputs": [],
874
+ "source": [
875
+ "vis.plot_transformed_image(\"test_images/TitanicGym.jpg\", render_factor=31, compare=True)"
876
+ ]
877
+ },
878
+ {
879
+ "cell_type": "code",
880
+ "execution_count": null,
881
+ "metadata": {},
882
+ "outputs": [],
883
+ "source": [
884
+ "vis.plot_transformed_image(\"test_images/FrenchVillage1950s.jpg\", render_factor=38, compare=True)"
885
+ ]
886
+ },
887
+ {
888
+ "cell_type": "code",
889
+ "execution_count": null,
890
+ "metadata": {},
891
+ "outputs": [],
892
+ "source": [
893
+ "vis.plot_transformed_image(\"test_images/ClassDivide1930sBrittain.jpg\", render_factor=30, compare=True)"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": null,
899
+ "metadata": {},
900
+ "outputs": [],
901
+ "source": [
902
+ "vis.plot_transformed_image(\"test_images/1870sSphinx.jpg\", render_factor=15, compare=True)"
903
+ ]
904
+ },
905
+ {
906
+ "cell_type": "code",
907
+ "execution_count": null,
908
+ "metadata": {},
909
+ "outputs": [],
910
+ "source": [
911
+ "vis.plot_transformed_image(\"test_images/1890Surfer.png\", render_factor=30, compare=True)"
912
+ ]
913
+ },
914
+ {
915
+ "cell_type": "code",
916
+ "execution_count": null,
917
+ "metadata": {},
918
+ "outputs": [],
919
+ "source": []
920
+ },
921
+ {
922
+ "cell_type": "code",
923
+ "execution_count": null,
924
+ "metadata": {},
925
+ "outputs": [],
926
+ "source": []
927
+ },
928
+ {
929
+ "cell_type": "code",
930
+ "execution_count": null,
931
+ "metadata": {},
932
+ "outputs": [],
933
+ "source": [
934
+ "vis.plot_transformed_image(\"test_images/TV1930s.jpg\", render_factor=30, compare=True)"
935
+ ]
936
+ },
937
+ {
938
+ "cell_type": "code",
939
+ "execution_count": null,
940
+ "metadata": {},
941
+ "outputs": [],
942
+ "source": [
943
+ "vis.plot_transformed_image(\"test_images/1864UnionSoldier.jpg\", render_factor=13, compare=True)"
944
+ ]
945
+ },
946
+ {
947
+ "cell_type": "code",
948
+ "execution_count": null,
949
+ "metadata": {},
950
+ "outputs": [],
951
+ "source": [
952
+ "vis.plot_transformed_image(\"test_images/1890sMedStudents.jpg\", render_factor=23, compare=True)"
953
+ ]
954
+ },
955
+ {
956
+ "cell_type": "code",
957
+ "execution_count": null,
958
+ "metadata": {},
959
+ "outputs": [],
960
+ "source": [
961
+ "vis.plot_transformed_image(\"test_images/BellyLaughWWI.jpg\", render_factor=13, compare=True)"
962
+ ]
963
+ },
964
+ {
965
+ "cell_type": "code",
966
+ "execution_count": null,
967
+ "metadata": {},
968
+ "outputs": [],
969
+ "source": [
970
+ "vis.plot_transformed_image(\"test_images/PiggyBackRide.jpg\", render_factor=20, compare=True)"
971
+ ]
972
+ },
973
+ {
974
+ "cell_type": "code",
975
+ "execution_count": null,
976
+ "metadata": {},
977
+ "outputs": [],
978
+ "source": [
979
+ "vis.plot_transformed_image(\"test_images/HealingTree.jpg\", render_factor=13, compare=True)"
980
+ ]
981
+ },
982
+ {
983
+ "cell_type": "code",
984
+ "execution_count": null,
985
+ "metadata": {},
986
+ "outputs": [],
987
+ "source": [
988
+ "vis.plot_transformed_image(\"test_images/ManPile.jpg\", render_factor=30, compare=True)"
989
+ ]
990
+ },
991
+ {
992
+ "cell_type": "code",
993
+ "execution_count": null,
994
+ "metadata": {},
995
+ "outputs": [],
996
+ "source": [
997
+ "vis.plot_transformed_image(\"test_images/1910Bike.jpg\", render_factor=20, compare=True)"
998
+ ]
999
+ },
1000
+ {
1001
+ "cell_type": "code",
1002
+ "execution_count": null,
1003
+ "metadata": {},
1004
+ "outputs": [],
1005
+ "source": [
1006
+ "vis.plot_transformed_image(\"test_images/FreeportIL.jpg\", render_factor=36, compare=True)"
1007
+ ]
1008
+ },
1009
+ {
1010
+ "cell_type": "code",
1011
+ "execution_count": null,
1012
+ "metadata": {},
1013
+ "outputs": [],
1014
+ "source": [
1015
+ "vis.plot_transformed_image(\"test_images/DutchBabyCoupleEllis.jpg\", render_factor=25, compare=True)"
1016
+ ]
1017
+ },
1018
+ {
1019
+ "cell_type": "code",
1020
+ "execution_count": null,
1021
+ "metadata": {},
1022
+ "outputs": [],
1023
+ "source": [
1024
+ "vis.plot_transformed_image(\"test_images/InuitWoman1903.png\", render_factor=33, compare=True)"
1025
+ ]
1026
+ },
1027
+ {
1028
+ "cell_type": "code",
1029
+ "execution_count": null,
1030
+ "metadata": {},
1031
+ "outputs": [],
1032
+ "source": [
1033
+ "vis.plot_transformed_image(\"test_images/1920sDancing.jpg\", render_factor=16, compare=True)"
1034
+ ]
1035
+ },
1036
+ {
1037
+ "cell_type": "code",
1038
+ "execution_count": null,
1039
+ "metadata": {},
1040
+ "outputs": [],
1041
+ "source": [
1042
+ "vis.plot_transformed_image(\"test_images/AirmanDad.jpg\", render_factor=16, compare=True)"
1043
+ ]
1044
+ },
1045
+ {
1046
+ "cell_type": "code",
1047
+ "execution_count": null,
1048
+ "metadata": {},
1049
+ "outputs": [],
1050
+ "source": [
1051
+ "vis.plot_transformed_image(\"test_images/1910Racket.png\", render_factor=34, compare=True)"
1052
+ ]
1053
+ },
1054
+ {
1055
+ "cell_type": "code",
1056
+ "execution_count": null,
1057
+ "metadata": {},
1058
+ "outputs": [],
1059
+ "source": [
1060
+ "vis.plot_transformed_image(\"test_images/1880Paris.jpg\", render_factor=30, compare=True)"
1061
+ ]
1062
+ },
1063
+ {
1064
+ "cell_type": "code",
1065
+ "execution_count": null,
1066
+ "metadata": {},
1067
+ "outputs": [],
1068
+ "source": [
1069
+ "vis.plot_transformed_image(\"test_images/Deadwood1860s.jpg\", render_factor=38, compare=True)"
1070
+ ]
1071
+ },
1072
+ {
1073
+ "cell_type": "code",
1074
+ "execution_count": null,
1075
+ "metadata": {},
1076
+ "outputs": [],
1077
+ "source": [
1078
+ "vis.plot_transformed_image(\"test_images/1860sSamauris.jpg\", render_factor=34, compare=True)"
1079
+ ]
1080
+ },
1081
+ {
1082
+ "cell_type": "code",
1083
+ "execution_count": null,
1084
+ "metadata": {},
1085
+ "outputs": [],
1086
+ "source": [
1087
+ "vis.plot_transformed_image(\"test_images/LondonUnderground1860.jpg\", render_factor=40, compare=True)"
1088
+ ]
1089
+ },
1090
+ {
1091
+ "cell_type": "code",
1092
+ "execution_count": null,
1093
+ "metadata": {},
1094
+ "outputs": [],
1095
+ "source": [
1096
+ "vis.plot_transformed_image(\"test_images/Mid1800sSisters.jpg\", render_factor=22, compare=True)"
1097
+ ]
1098
+ },
1099
+ {
1100
+ "cell_type": "code",
1101
+ "execution_count": null,
1102
+ "metadata": {},
1103
+ "outputs": [],
1104
+ "source": [
1105
+ "vis.plot_transformed_image(\"test_images/1860Girls.jpg\", render_factor=45, compare=True)"
1106
+ ]
1107
+ },
1108
+ {
1109
+ "cell_type": "code",
1110
+ "execution_count": null,
1111
+ "metadata": {},
1112
+ "outputs": [],
1113
+ "source": [
1114
+ "vis.plot_transformed_image(\"test_images/SanFran1851.jpg\", render_factor=22, compare=True)"
1115
+ ]
1116
+ },
1117
+ {
1118
+ "cell_type": "code",
1119
+ "execution_count": null,
1120
+ "metadata": {},
1121
+ "outputs": [],
1122
+ "source": [
1123
+ "vis.plot_transformed_image(\"test_images/Kabuki1870s.png\", render_factor=25, compare=True)"
1124
+ ]
1125
+ },
1126
+ {
1127
+ "cell_type": "code",
1128
+ "execution_count": null,
1129
+ "metadata": {},
1130
+ "outputs": [],
1131
+ "source": [
1132
+ "vis.plot_transformed_image(\"test_images/Mormons1870s.jpg\", render_factor=47, compare=True)"
1133
+ ]
1134
+ },
1135
+ {
1136
+ "cell_type": "code",
1137
+ "execution_count": null,
1138
+ "metadata": {},
1139
+ "outputs": [],
1140
+ "source": [
1141
+ "vis.plot_transformed_image(\"test_images/EgyptianWomenLate1800s.jpg\", render_factor=7, compare=True)"
1142
+ ]
1143
+ },
1144
+ {
1145
+ "cell_type": "code",
1146
+ "execution_count": null,
1147
+ "metadata": {},
1148
+ "outputs": [],
1149
+ "source": [
1150
+ "vis.plot_transformed_image(\"test_images/PicadillyLate1800s.jpg\", render_factor=46, compare=True)"
1151
+ ]
1152
+ },
1153
+ {
1154
+ "cell_type": "code",
1155
+ "execution_count": null,
1156
+ "metadata": {},
1157
+ "outputs": [],
1158
+ "source": [
1159
+ "vis.plot_transformed_image(\"test_images/SutroBaths1880s.jpg\", render_factor=18, compare=True)"
1160
+ ]
1161
+ },
1162
+ {
1163
+ "cell_type": "code",
1164
+ "execution_count": null,
1165
+ "metadata": {},
1166
+ "outputs": [],
1167
+ "source": [
1168
+ "vis.plot_transformed_image(\"test_images/1880sBrooklynBridge.jpg\", render_factor=18, compare=True)"
1169
+ ]
1170
+ },
1171
+ {
1172
+ "cell_type": "code",
1173
+ "execution_count": null,
1174
+ "metadata": {},
1175
+ "outputs": [],
1176
+ "source": [
1177
+ "vis.plot_transformed_image(\"test_images/ChinaOpiumc1880.jpg\", render_factor=43, compare=True)"
1178
+ ]
1179
+ },
1180
+ {
1181
+ "cell_type": "code",
1182
+ "execution_count": null,
1183
+ "metadata": {},
1184
+ "outputs": [],
1185
+ "source": [
1186
+ "vis.plot_transformed_image(\"test_images/Locomotive1880s.jpg\", render_factor=10, compare=True)"
1187
+ ]
1188
+ },
1189
+ {
1190
+ "cell_type": "code",
1191
+ "execution_count": null,
1192
+ "metadata": {},
1193
+ "outputs": [],
1194
+ "source": [
1195
+ "vis.plot_transformed_image(\"test_images/ViennaBoys1880s.png\", render_factor=19, compare=True)"
1196
+ ]
1197
+ },
1198
+ {
1199
+ "cell_type": "code",
1200
+ "execution_count": null,
1201
+ "metadata": {},
1202
+ "outputs": [],
1203
+ "source": [
1204
+ "vis.plot_transformed_image(\"test_images/VictorianDragQueen1880s.png\", render_factor=13, compare=True)"
1205
+ ]
1206
+ },
1207
+ {
1208
+ "cell_type": "code",
1209
+ "execution_count": null,
1210
+ "metadata": {},
1211
+ "outputs": [],
1212
+ "source": [
1213
+ "vis.plot_transformed_image(\"test_images/Sami1880s.jpg\", render_factor=39, compare=True)"
1214
+ ]
1215
+ },
1216
+ {
1217
+ "cell_type": "code",
1218
+ "execution_count": null,
1219
+ "metadata": {},
1220
+ "outputs": [],
1221
+ "source": [
1222
+ "vis.plot_transformed_image(\"test_images/Ballet1890Russia.jpg\", render_factor=32, compare=True)"
1223
+ ]
1224
+ },
1225
+ {
1226
+ "cell_type": "code",
1227
+ "execution_count": null,
1228
+ "metadata": {},
1229
+ "outputs": [],
1230
+ "source": [
1231
+ "vis.plot_transformed_image(\"test_images/Rottindean1890s.png\", render_factor=22, compare=True)"
1232
+ ]
1233
+ },
1234
+ {
1235
+ "cell_type": "code",
1236
+ "execution_count": null,
1237
+ "metadata": {},
1238
+ "outputs": [],
1239
+ "source": [
1240
+ "vis.plot_transformed_image(\"test_images/1890sPingPong.jpg\", render_factor=15, compare=True)"
1241
+ ]
1242
+ },
1243
+ {
1244
+ "cell_type": "code",
1245
+ "execution_count": null,
1246
+ "metadata": {},
1247
+ "outputs": [],
1248
+ "source": [
1249
+ "vis.plot_transformed_image(\"test_images/London1937.png\", render_factor=36, compare=True)"
1250
+ ]
1251
+ },
1252
+ {
1253
+ "cell_type": "code",
1254
+ "execution_count": null,
1255
+ "metadata": {},
1256
+ "outputs": [],
1257
+ "source": [
1258
+ "vis.plot_transformed_image(\"test_images/Harlem1932.jpg\", render_factor=27, compare=True)"
1259
+ ]
1260
+ },
1261
+ {
1262
+ "cell_type": "code",
1263
+ "execution_count": null,
1264
+ "metadata": {},
1265
+ "outputs": [],
1266
+ "source": [
1267
+ "vis.plot_transformed_image(\"test_images/OregonTrail1870s.jpg\", render_factor=25, compare=True)"
1268
+ ]
1269
+ },
1270
+ {
1271
+ "cell_type": "code",
1272
+ "execution_count": null,
1273
+ "metadata": {},
1274
+ "outputs": [],
1275
+ "source": [
1276
+ "vis.plot_transformed_image(\"test_images/EasterNyc1911.jpg\", render_factor=20, compare=True)"
1277
+ ]
1278
+ },
1279
+ {
1280
+ "cell_type": "code",
1281
+ "execution_count": null,
1282
+ "metadata": {},
1283
+ "outputs": [],
1284
+ "source": [
1285
+ "vis.plot_transformed_image(\"test_images/1899NycBlizzard.jpg\", render_factor=20, compare=True)"
1286
+ ]
1287
+ },
1288
+ {
1289
+ "cell_type": "code",
1290
+ "execution_count": null,
1291
+ "metadata": {},
1292
+ "outputs": [],
1293
+ "source": [
1294
+ "vis.plot_transformed_image(\"test_images/Edinburgh1920s.jpg\", render_factor=21, compare=True)"
1295
+ ]
1296
+ },
1297
+ {
1298
+ "cell_type": "code",
1299
+ "execution_count": null,
1300
+ "metadata": {},
1301
+ "outputs": [],
1302
+ "source": [
1303
+ "vis.plot_transformed_image(\"test_images/1890sShoeShopOhio.jpg\", render_factor=46, compare=True)"
1304
+ ]
1305
+ },
1306
+ {
1307
+ "cell_type": "code",
1308
+ "execution_count": null,
1309
+ "metadata": {},
1310
+ "outputs": [],
1311
+ "source": [
1312
+ "vis.plot_transformed_image(\"test_images/1890sTouristsEgypt.png\", render_factor=40, compare=True)"
1313
+ ]
1314
+ },
1315
+ {
1316
+ "cell_type": "code",
1317
+ "execution_count": null,
1318
+ "metadata": {},
1319
+ "outputs": [],
1320
+ "source": [
1321
+ "vis.plot_transformed_image(\"test_images/1938Reading.jpg\", render_factor=27, compare=True)"
1322
+ ]
1323
+ },
1324
+ {
1325
+ "cell_type": "code",
1326
+ "execution_count": null,
1327
+ "metadata": {},
1328
+ "outputs": [],
1329
+ "source": [
1330
+ "vis.plot_transformed_image(\"test_images/1850Geography.jpg\", render_factor=22, compare=True)"
1331
+ ]
1332
+ },
1333
+ {
1334
+ "cell_type": "code",
1335
+ "execution_count": null,
1336
+ "metadata": {},
1337
+ "outputs": [],
1338
+ "source": [
1339
+ "vis.plot_transformed_image(\"test_images/1901Electrophone.jpg\", render_factor=7, compare=True)"
1340
+ ]
1341
+ },
1342
+ {
1343
+ "cell_type": "code",
1344
+ "execution_count": null,
1345
+ "metadata": {},
1346
+ "outputs": [],
1347
+ "source": [
1348
+ "vis.plot_transformed_image(\"test_images/Texas1938Woman.png\", render_factor=35, compare=True)"
1349
+ ]
1350
+ },
1351
+ {
1352
+ "cell_type": "code",
1353
+ "execution_count": null,
1354
+ "metadata": {},
1355
+ "outputs": [],
1356
+ "source": [
1357
+ "vis.plot_transformed_image(\"test_images/MaioreWoman1895NZ.jpg\", render_factor=43, compare=True)"
1358
+ ]
1359
+ },
1360
+ {
1361
+ "cell_type": "code",
1362
+ "execution_count": null,
1363
+ "metadata": {},
1364
+ "outputs": [],
1365
+ "source": [
1366
+ "vis.plot_transformed_image(\"test_images/WestVirginiaHouse.jpg\", render_factor=30, compare=True)"
1367
+ ]
1368
+ },
1369
+ {
1370
+ "cell_type": "code",
1371
+ "execution_count": null,
1372
+ "metadata": {},
1373
+ "outputs": [],
1374
+ "source": [
1375
+ "vis.plot_transformed_image(\"test_images/1920sGuadalope.jpg\", render_factor=33, compare=True)"
1376
+ ]
1377
+ },
1378
+ {
1379
+ "cell_type": "code",
1380
+ "execution_count": null,
1381
+ "metadata": {},
1382
+ "outputs": [],
1383
+ "source": [
1384
+ "vis.plot_transformed_image(\"test_images/1909Chicago.jpg\", render_factor=14, compare=True)"
1385
+ ]
1386
+ },
1387
+ {
1388
+ "cell_type": "code",
1389
+ "execution_count": null,
1390
+ "metadata": {},
1391
+ "outputs": [],
1392
+ "source": [
1393
+ "vis.plot_transformed_image(\"test_images/1920sFarmKid.jpg\", render_factor=12, compare=True)"
1394
+ ]
1395
+ },
1396
+ {
1397
+ "cell_type": "code",
1398
+ "execution_count": null,
1399
+ "metadata": {},
1400
+ "outputs": [],
1401
+ "source": [
1402
+ "vis.plot_transformed_image(\"test_images/ParisLate1800s.jpg\", render_factor=18, compare=True)"
1403
+ ]
1404
+ },
1405
+ {
1406
+ "cell_type": "code",
1407
+ "execution_count": null,
1408
+ "metadata": {},
1409
+ "outputs": [],
1410
+ "source": [
1411
+ "vis.plot_transformed_image(\"test_images/1900sDaytonaBeach.png\", render_factor=24, compare=True)"
1412
+ ]
1413
+ },
1414
+ {
1415
+ "cell_type": "code",
1416
+ "execution_count": null,
1417
+ "metadata": {},
1418
+ "outputs": [],
1419
+ "source": [
1420
+ "vis.plot_transformed_image(\"test_images/1930sGeorgia.jpg\", render_factor=17, compare=True)"
1421
+ ]
1422
+ },
1423
+ {
1424
+ "cell_type": "code",
1425
+ "execution_count": null,
1426
+ "metadata": {},
1427
+ "outputs": [],
1428
+ "source": [
1429
+ "vis.plot_transformed_image(\"test_images/NorwegianBride1920s.jpg\", render_factor=40, compare=True)"
1430
+ ]
1431
+ },
1432
+ {
1433
+ "cell_type": "code",
1434
+ "execution_count": null,
1435
+ "metadata": {},
1436
+ "outputs": [],
1437
+ "source": [
1438
+ "vis.plot_transformed_image(\"test_images/Depression.jpg\", render_factor=15, compare=True)"
1439
+ ]
1440
+ },
1441
+ {
1442
+ "cell_type": "code",
1443
+ "execution_count": null,
1444
+ "metadata": {},
1445
+ "outputs": [],
1446
+ "source": [
1447
+ "vis.plot_transformed_image(\"test_images/1888Slum.jpg\", render_factor=32, compare=True)"
1448
+ ]
1449
+ },
1450
+ {
1451
+ "cell_type": "code",
1452
+ "execution_count": null,
1453
+ "metadata": {},
1454
+ "outputs": [],
1455
+ "source": [
1456
+ "vis.plot_transformed_image(\"test_images/LivingRoom1920Sweden.jpg\", render_factor=46, compare=True)"
1457
+ ]
1458
+ },
1459
+ {
1460
+ "cell_type": "code",
1461
+ "execution_count": null,
1462
+ "metadata": {},
1463
+ "outputs": [],
1464
+ "source": [
1465
+ "vis.plot_transformed_image(\"test_images/1896NewsBoyGirl.jpg\", render_factor=21, compare=True)"
1466
+ ]
1467
+ },
1468
+ {
1469
+ "cell_type": "code",
1470
+ "execution_count": null,
1471
+ "metadata": {},
1472
+ "outputs": [],
1473
+ "source": [
1474
+ "vis.plot_transformed_image(\"test_images/PetDucks1927.jpg\", compare=True)"
1475
+ ]
1476
+ },
1477
+ {
1478
+ "cell_type": "code",
1479
+ "execution_count": null,
1480
+ "metadata": {},
1481
+ "outputs": [],
1482
+ "source": [
1483
+ "vis.plot_transformed_image(\"test_images/1899SodaFountain.jpg\", render_factor=46, compare=True)"
1484
+ ]
1485
+ },
1486
+ {
1487
+ "cell_type": "code",
1488
+ "execution_count": null,
1489
+ "metadata": {},
1490
+ "outputs": [],
1491
+ "source": [
1492
+ "vis.plot_transformed_image(\"test_images/TimesSquare1955.jpg\", render_factor=42, compare=True)"
1493
+ ]
1494
+ },
1495
+ {
1496
+ "cell_type": "code",
1497
+ "execution_count": null,
1498
+ "metadata": {},
1499
+ "outputs": [],
1500
+ "source": [
1501
+ "vis.plot_transformed_image(\"test_images/PuppyGify.jpg\", render_factor=22, compare=True)"
1502
+ ]
1503
+ },
1504
+ {
1505
+ "cell_type": "code",
1506
+ "execution_count": null,
1507
+ "metadata": {},
1508
+ "outputs": [],
1509
+ "source": [
1510
+ "vis.plot_transformed_image(\"test_images/1890CliffHouseSF.jpg\", render_factor=30, compare=True)"
1511
+ ]
1512
+ },
1513
+ {
1514
+ "cell_type": "code",
1515
+ "execution_count": null,
1516
+ "metadata": {},
1517
+ "outputs": [],
1518
+ "source": [
1519
+ "vis.plot_transformed_image(\"test_images/1908FamilyPhoto.jpg\", render_factor=35, compare=True)"
1520
+ ]
1521
+ },
1522
+ {
1523
+ "cell_type": "code",
1524
+ "execution_count": null,
1525
+ "metadata": {},
1526
+ "outputs": [],
1527
+ "source": [
1528
+ "vis.plot_transformed_image(\"test_images/1900sSaloon.jpg\", render_factor=30, compare=True)"
1529
+ ]
1530
+ },
1531
+ {
1532
+ "cell_type": "code",
1533
+ "execution_count": null,
1534
+ "metadata": {},
1535
+ "outputs": [],
1536
+ "source": [
1537
+ "vis.plot_transformed_image(\"test_images/1890BostonHospital.jpg\", render_factor=19, compare=True)"
1538
+ ]
1539
+ },
1540
+ {
1541
+ "cell_type": "code",
1542
+ "execution_count": null,
1543
+ "metadata": {},
1544
+ "outputs": [],
1545
+ "source": [
1546
+ "vis.plot_transformed_image(\"test_images/1870Girl.jpg\", render_factor=9, compare=True)"
1547
+ ]
1548
+ },
1549
+ {
1550
+ "cell_type": "code",
1551
+ "execution_count": null,
1552
+ "metadata": {},
1553
+ "outputs": [],
1554
+ "source": [
1555
+ "vis.plot_transformed_image(\"test_images/AustriaHungaryWomen1890s.jpg\", render_factor=15, compare=True)"
1556
+ ]
1557
+ },
1558
+ {
1559
+ "cell_type": "code",
1560
+ "execution_count": null,
1561
+ "metadata": {},
1562
+ "outputs": [],
1563
+ "source": [
1564
+ "vis.plot_transformed_image(\"test_images/Shack.jpg\",render_factor=43, compare=True)"
1565
+ ]
1566
+ },
1567
+ {
1568
+ "cell_type": "code",
1569
+ "execution_count": null,
1570
+ "metadata": {},
1571
+ "outputs": [],
1572
+ "source": [
1573
+ "vis.plot_transformed_image(\"test_images/Apsaroke1908.png\", render_factor=15, compare=True)"
1574
+ ]
1575
+ },
1576
+ {
1577
+ "cell_type": "code",
1578
+ "execution_count": null,
1579
+ "metadata": {},
1580
+ "outputs": [],
1581
+ "source": [
1582
+ "vis.plot_transformed_image(\"test_images/1948CarsGrandma.jpg\", render_factor=14, compare=True)"
1583
+ ]
1584
+ },
1585
+ {
1586
+ "cell_type": "code",
1587
+ "execution_count": null,
1588
+ "metadata": {},
1589
+ "outputs": [],
1590
+ "source": [
1591
+ "vis.plot_transformed_image(\"test_images/PlanesManhattan1931.jpg\", render_factor=11, compare=True)"
1592
+ ]
1593
+ },
1594
+ {
1595
+ "cell_type": "code",
1596
+ "execution_count": null,
1597
+ "metadata": {},
1598
+ "outputs": [],
1599
+ "source": [
1600
+ "vis.plot_transformed_image(\"test_images/WorriedKid1940sNyc.jpg\", render_factor=25, compare=True)"
1601
+ ]
1602
+ },
1603
+ {
1604
+ "cell_type": "code",
1605
+ "execution_count": null,
1606
+ "metadata": {},
1607
+ "outputs": [],
1608
+ "source": [
1609
+ "vis.plot_transformed_image(\"test_images/1920sFamilyPhoto.jpg\", render_factor=13, compare=True)"
1610
+ ]
1611
+ },
1612
+ {
1613
+ "cell_type": "code",
1614
+ "execution_count": null,
1615
+ "metadata": {},
1616
+ "outputs": [],
1617
+ "source": [
1618
+ "vis.plot_transformed_image(\"test_images/CatWash1931.jpg\", render_factor=34, compare=True)"
1619
+ ]
1620
+ },
1621
+ {
1622
+ "cell_type": "code",
1623
+ "execution_count": null,
1624
+ "metadata": {},
1625
+ "outputs": [],
1626
+ "source": [
1627
+ "vis.plot_transformed_image(\"test_images/1940sBeerRiver.jpg\", render_factor=46, compare=True)"
1628
+ ]
1629
+ },
1630
+ {
1631
+ "cell_type": "code",
1632
+ "execution_count": null,
1633
+ "metadata": {},
1634
+ "outputs": [],
1635
+ "source": [
1636
+ "vis.plot_transformed_image(\"test_images/VictorianLivingRoom.jpg\", render_factor=47, compare=True)"
1637
+ ]
1638
+ },
1639
+ {
1640
+ "cell_type": "code",
1641
+ "execution_count": null,
1642
+ "metadata": {},
1643
+ "outputs": [],
1644
+ "source": [
1645
+ "vis.plot_transformed_image(\"test_images/1897BlindmansBluff.jpg\", render_factor=23, compare=True)"
1646
+ ]
1647
+ },
1648
+ {
1649
+ "cell_type": "code",
1650
+ "execution_count": null,
1651
+ "metadata": {},
1652
+ "outputs": [],
1653
+ "source": [
1654
+ "vis.plot_transformed_image(\"test_images/1874Mexico.png\", render_factor=25, compare=True)"
1655
+ ]
1656
+ },
1657
+ {
1658
+ "cell_type": "code",
1659
+ "execution_count": null,
1660
+ "metadata": {},
1661
+ "outputs": [],
1662
+ "source": [
1663
+ "vis.plot_transformed_image(\"test_images/MadisonSquare1900.jpg\", render_factor=45, compare=True)"
1664
+ ]
1665
+ },
1666
+ {
1667
+ "cell_type": "code",
1668
+ "execution_count": null,
1669
+ "metadata": {},
1670
+ "outputs": [],
1671
+ "source": [
1672
+ "vis.plot_transformed_image(\"test_images/1867MusicianConstantinople.jpg\", render_factor=11, compare=True)"
1673
+ ]
1674
+ },
1675
+ {
1676
+ "cell_type": "code",
1677
+ "execution_count": null,
1678
+ "metadata": {},
1679
+ "outputs": [],
1680
+ "source": [
1681
+ "vis.plot_transformed_image(\"test_images/1925Girl.jpg\", render_factor=20, compare=True)"
1682
+ ]
1683
+ },
1684
+ {
1685
+ "cell_type": "code",
1686
+ "execution_count": null,
1687
+ "metadata": {},
1688
+ "outputs": [],
1689
+ "source": [
1690
+ "vis.plot_transformed_image(\"test_images/1907Cowboys.jpg\", render_factor=22, compare=True)"
1691
+ ]
1692
+ },
1693
+ {
1694
+ "cell_type": "code",
1695
+ "execution_count": null,
1696
+ "metadata": {},
1697
+ "outputs": [],
1698
+ "source": [
1699
+ "vis.plot_transformed_image(\"test_images/WWIIPeeps.jpg\", render_factor=26, compare=True)"
1700
+ ]
1701
+ },
1702
+ {
1703
+ "cell_type": "code",
1704
+ "execution_count": null,
1705
+ "metadata": {},
1706
+ "outputs": [],
1707
+ "source": [
1708
+ "vis.plot_transformed_image(\"test_images/BabyBigBoots.jpg\", render_factor=17, compare=True)"
1709
+ ]
1710
+ },
1711
+ {
1712
+ "cell_type": "code",
1713
+ "execution_count": null,
1714
+ "metadata": {},
1715
+ "outputs": [],
1716
+ "source": [
1717
+ "vis.plot_transformed_image(\"test_images/1895BikeMaidens.jpg\", render_factor=8, compare=True)"
1718
+ ]
1719
+ },
1720
+ {
1721
+ "cell_type": "code",
1722
+ "execution_count": null,
1723
+ "metadata": {},
1724
+ "outputs": [],
1725
+ "source": [
1726
+ "vis.plot_transformed_image(\"test_images/IrishLate1800s.jpg\", render_factor=13, compare=True)"
1727
+ ]
1728
+ },
1729
+ {
1730
+ "cell_type": "code",
1731
+ "execution_count": null,
1732
+ "metadata": {},
1733
+ "outputs": [],
1734
+ "source": [
1735
+ "vis.plot_transformed_image(\"test_images/LibraryOfCongress1910.jpg\", render_factor=33, compare=True)"
1736
+ ]
1737
+ },
1738
+ {
1739
+ "cell_type": "code",
1740
+ "execution_count": null,
1741
+ "metadata": {},
1742
+ "outputs": [],
1743
+ "source": [
1744
+ "vis.plot_transformed_image(\"test_images/1875Olds.jpg\", render_factor=15, compare=True)"
1745
+ ]
1746
+ },
1747
+ {
1748
+ "cell_type": "code",
1749
+ "execution_count": null,
1750
+ "metadata": {},
1751
+ "outputs": [],
1752
+ "source": [
1753
+ "vis.plot_transformed_image(\"test_images/SenecaNative1908.jpg\", render_factor=22, compare=True)"
1754
+ ]
1755
+ },
1756
+ {
1757
+ "cell_type": "code",
1758
+ "execution_count": null,
1759
+ "metadata": {},
1760
+ "outputs": [],
1761
+ "source": [
1762
+ "vis.plot_transformed_image(\"test_images/WWIHospital.jpg\", render_factor=40, compare=True)"
1763
+ ]
1764
+ },
1765
+ {
1766
+ "cell_type": "code",
1767
+ "execution_count": null,
1768
+ "metadata": {},
1769
+ "outputs": [],
1770
+ "source": [
1771
+ "vis.plot_transformed_image(\"test_images/1892WaterLillies.jpg\", render_factor=45, compare=True)"
1772
+ ]
1773
+ },
1774
+ {
1775
+ "cell_type": "code",
1776
+ "execution_count": null,
1777
+ "metadata": {},
1778
+ "outputs": [],
1779
+ "source": [
1780
+ "vis.plot_transformed_image(\"test_images/GreekImmigrants1905.jpg\", render_factor=25, compare=True)"
1781
+ ]
1782
+ },
1783
+ {
1784
+ "cell_type": "code",
1785
+ "execution_count": null,
1786
+ "metadata": {},
1787
+ "outputs": [],
1788
+ "source": [
1789
+ "vis.plot_transformed_image(\"test_images/FatMensShop.jpg\", render_factor=24, compare=True)"
1790
+ ]
1791
+ },
1792
+ {
1793
+ "cell_type": "code",
1794
+ "execution_count": null,
1795
+ "metadata": {},
1796
+ "outputs": [],
1797
+ "source": [
1798
+ "vis.plot_transformed_image(\"test_images/KidCage1930s.png\", compare=True)"
1799
+ ]
1800
+ },
1801
+ {
1802
+ "cell_type": "code",
1803
+ "execution_count": null,
1804
+ "metadata": {},
1805
+ "outputs": [],
1806
+ "source": [
1807
+ "vis.plot_transformed_image(\"test_images/FarmWomen1895.jpg\", compare=True)"
1808
+ ]
1809
+ },
1810
+ {
1811
+ "cell_type": "code",
1812
+ "execution_count": null,
1813
+ "metadata": {},
1814
+ "outputs": [],
1815
+ "source": [
1816
+ "vis.plot_transformed_image(\"test_images/NewZealand1860s.jpg\", compare=True)"
1817
+ ]
1818
+ },
1819
+ {
1820
+ "cell_type": "code",
1821
+ "execution_count": null,
1822
+ "metadata": {},
1823
+ "outputs": [],
1824
+ "source": [
1825
+ "vis.plot_transformed_image(\"test_images/JerseyShore1905.jpg\", render_factor=43, compare=True)"
1826
+ ]
1827
+ },
1828
+ {
1829
+ "cell_type": "code",
1830
+ "execution_count": null,
1831
+ "metadata": {},
1832
+ "outputs": [],
1833
+ "source": [
1834
+ "vis.plot_transformed_image(\"test_images/LondonKidsEarly1900s.jpg\", compare=True)"
1835
+ ]
1836
+ },
1837
+ {
1838
+ "cell_type": "code",
1839
+ "execution_count": null,
1840
+ "metadata": {},
1841
+ "outputs": [],
1842
+ "source": [
1843
+ "vis.plot_transformed_image(\"test_images/NYStreetClean1906.jpg\", compare=True)"
1844
+ ]
1845
+ },
1846
+ {
1847
+ "cell_type": "code",
1848
+ "execution_count": null,
1849
+ "metadata": {},
1850
+ "outputs": [],
1851
+ "source": [
1852
+ "vis.plot_transformed_image(\"test_images/Boston1937.jpg\", compare=True)"
1853
+ ]
1854
+ },
1855
+ {
1856
+ "cell_type": "code",
1857
+ "execution_count": null,
1858
+ "metadata": {},
1859
+ "outputs": [],
1860
+ "source": [
1861
+ "vis.plot_transformed_image(\"test_images/Cork1905.jpg\", render_factor=37, compare=True)"
1862
+ ]
1863
+ },
1864
+ {
1865
+ "cell_type": "code",
1866
+ "execution_count": null,
1867
+ "metadata": {},
1868
+ "outputs": [],
1869
+ "source": [
1870
+ "vis.plot_transformed_image(\"test_images/BoxedBedEarly1900s.jpg\", compare=True)"
1871
+ ]
1872
+ },
1873
+ {
1874
+ "cell_type": "code",
1875
+ "execution_count": null,
1876
+ "metadata": {},
1877
+ "outputs": [],
1878
+ "source": [
1879
+ "vis.plot_transformed_image(\"test_images/ZoologischerGarten1898.jpg\", compare=True)"
1880
+ ]
1881
+ },
1882
+ {
1883
+ "cell_type": "code",
1884
+ "execution_count": null,
1885
+ "metadata": {},
1886
+ "outputs": [],
1887
+ "source": [
1888
+ "vis.plot_transformed_image(\"test_images/EmpireState1930.jpg\", compare=True)"
1889
+ ]
1890
+ },
1891
+ {
1892
+ "cell_type": "code",
1893
+ "execution_count": null,
1894
+ "metadata": {},
1895
+ "outputs": [],
1896
+ "source": [
1897
+ "vis.plot_transformed_image(\"test_images/Agamemnon1919.jpg\", render_factor=40, compare=True)"
1898
+ ]
1899
+ },
1900
+ {
1901
+ "cell_type": "code",
1902
+ "execution_count": null,
1903
+ "metadata": {},
1904
+ "outputs": [],
1905
+ "source": [
1906
+ "vis.plot_transformed_image(\"test_images/AppalachianLoggers1901.jpg\", compare=True)"
1907
+ ]
1908
+ },
1909
+ {
1910
+ "cell_type": "code",
1911
+ "execution_count": null,
1912
+ "metadata": {},
1913
+ "outputs": [],
1914
+ "source": [
1915
+ "vis.plot_transformed_image(\"test_images/WWISikhs.jpg\", compare=True)"
1916
+ ]
1917
+ },
1918
+ {
1919
+ "cell_type": "code",
1920
+ "execution_count": null,
1921
+ "metadata": {},
1922
+ "outputs": [],
1923
+ "source": [
1924
+ "vis.plot_transformed_image(\"test_images/MementoMori1865.jpg\", compare=True)"
1925
+ ]
1926
+ },
1927
+ {
1928
+ "cell_type": "code",
1929
+ "execution_count": null,
1930
+ "metadata": {},
1931
+ "outputs": [],
1932
+ "source": [
1933
+ "vis.plot_transformed_image(\"test_images/RepBrennanRadio1922.jpg\", render_factor=43, compare=True)"
1934
+ ]
1935
+ },
1936
+ {
1937
+ "cell_type": "code",
1938
+ "execution_count": null,
1939
+ "metadata": {},
1940
+ "outputs": [],
1941
+ "source": [
1942
+ "vis.plot_transformed_image(\"test_images/Late1800sNative.jpg\", render_factor=20, compare=True)"
1943
+ ]
1944
+ },
1945
+ {
1946
+ "cell_type": "code",
1947
+ "execution_count": null,
1948
+ "metadata": {},
1949
+ "outputs": [],
1950
+ "source": [
1951
+ "vis.plot_transformed_image(\"test_images/GasPrices1939.jpg\", render_factor=30, compare=True)"
1952
+ ]
1953
+ },
1954
+ {
1955
+ "cell_type": "code",
1956
+ "execution_count": null,
1957
+ "metadata": {},
1958
+ "outputs": [],
1959
+ "source": [
1960
+ "vis.plot_transformed_image(\"test_images/1933RockefellerCenter.jpg\", compare=True)"
1961
+ ]
1962
+ },
1963
+ {
1964
+ "cell_type": "code",
1965
+ "execution_count": null,
1966
+ "metadata": {},
1967
+ "outputs": [],
1968
+ "source": [
1969
+ "vis.plot_transformed_image(\"test_images/Scotland1919.jpg\", compare=True)"
1970
+ ]
1971
+ },
1972
+ {
1973
+ "cell_type": "code",
1974
+ "execution_count": null,
1975
+ "metadata": {},
1976
+ "outputs": [],
1977
+ "source": [
1978
+ "vis.plot_transformed_image(\"test_images/1920CobblersShopLondon.jpg\", compare=True)"
1979
+ ]
1980
+ },
1981
+ {
1982
+ "cell_type": "code",
1983
+ "execution_count": null,
1984
+ "metadata": {},
1985
+ "outputs": [],
1986
+ "source": [
1987
+ "vis.plot_transformed_image(\"test_images/1909ParisFirstFemaleTaxisDriver.jpg\", compare=True)"
1988
+ ]
1989
+ },
1990
+ {
1991
+ "cell_type": "code",
1992
+ "execution_count": null,
1993
+ "metadata": {},
1994
+ "outputs": [],
1995
+ "source": [
1996
+ "vis.plot_transformed_image(\"test_images/HoovervilleSeattle1932.jpg\", compare=True)"
1997
+ ]
1998
+ },
1999
+ {
2000
+ "cell_type": "code",
2001
+ "execution_count": null,
2002
+ "metadata": {},
2003
+ "outputs": [],
2004
+ "source": [
2005
+ "vis.plot_transformed_image(\"test_images/ElephantLondon1934.png\", compare=True)"
2006
+ ]
2007
+ },
2008
+ {
2009
+ "cell_type": "code",
2010
+ "execution_count": null,
2011
+ "metadata": {},
2012
+ "outputs": [],
2013
+ "source": [
2014
+ "vis.plot_transformed_image(\"test_images/Jane_Addams.jpg\", compare=True)"
2015
+ ]
2016
+ },
2017
+ {
2018
+ "cell_type": "code",
2019
+ "execution_count": null,
2020
+ "metadata": {},
2021
+ "outputs": [],
2022
+ "source": [
2023
+ "vis.plot_transformed_image(\"test_images/AnselAdamsAdobe.jpg\", compare=True)"
2024
+ ]
2025
+ },
2026
+ {
2027
+ "cell_type": "code",
2028
+ "execution_count": null,
2029
+ "metadata": {},
2030
+ "outputs": [],
2031
+ "source": [
2032
+ "vis.plot_transformed_image(\"test_images/CricketLondon1930.jpg\", render_factor=45, compare=True)"
2033
+ ]
2034
+ },
2035
+ {
2036
+ "cell_type": "code",
2037
+ "execution_count": null,
2038
+ "metadata": {},
2039
+ "outputs": [],
2040
+ "source": [
2041
+ "vis.plot_transformed_image(\"test_images/Donegal1907Yarn.jpg\", render_factor=32, compare=True)"
2042
+ ]
2043
+ },
2044
+ {
2045
+ "cell_type": "code",
2046
+ "execution_count": null,
2047
+ "metadata": {},
2048
+ "outputs": [],
2049
+ "source": [
2050
+ "vis.plot_transformed_image(\"test_images/AnselAdamsChurch.jpg\", compare=True)"
2051
+ ]
2052
+ },
2053
+ {
2054
+ "cell_type": "code",
2055
+ "execution_count": null,
2056
+ "metadata": {},
2057
+ "outputs": [],
2058
+ "source": [
2059
+ "vis.plot_transformed_image(\"test_images/BreadDelivery1920sIreland.jpg\", render_factor=20, compare=True)"
2060
+ ]
2061
+ },
2062
+ {
2063
+ "cell_type": "code",
2064
+ "execution_count": null,
2065
+ "metadata": {},
2066
+ "outputs": [],
2067
+ "source": [
2068
+ "vis.plot_transformed_image(\"test_images/BritishTeaBombay1890s.png\", render_factor=30, compare=True)"
2069
+ ]
2070
+ },
2071
+ {
2072
+ "cell_type": "code",
2073
+ "execution_count": null,
2074
+ "metadata": {},
2075
+ "outputs": [],
2076
+ "source": [
2077
+ "vis.plot_transformed_image(\"test_images/CafeParis1928.jpg\", render_factor=45, compare=True)"
2078
+ ]
2079
+ },
2080
+ {
2081
+ "cell_type": "code",
2082
+ "execution_count": null,
2083
+ "metadata": {},
2084
+ "outputs": [],
2085
+ "source": [
2086
+ "vis.plot_transformed_image(\"test_images/BigManTavern1908NYC.jpg\", compare=True)"
2087
+ ]
2088
+ },
2089
+ {
2090
+ "cell_type": "code",
2091
+ "execution_count": null,
2092
+ "metadata": {},
2093
+ "outputs": [],
2094
+ "source": [
2095
+ "vis.plot_transformed_image(\"test_images/Cars1890sIreland.jpg\", compare=True)"
2096
+ ]
2097
+ },
2098
+ {
2099
+ "cell_type": "code",
2100
+ "execution_count": null,
2101
+ "metadata": {},
2102
+ "outputs": [],
2103
+ "source": [
2104
+ "vis.plot_transformed_image(\"test_images/GalwayIreland1902.jpg\", render_factor=47, compare=True)"
2105
+ ]
2106
+ },
2107
+ {
2108
+ "cell_type": "code",
2109
+ "execution_count": null,
2110
+ "metadata": {},
2111
+ "outputs": [],
2112
+ "source": [
2113
+ "vis.plot_transformed_image(\"test_images/HomeIreland1924.jpg\", render_factor=40, compare=True)"
2114
+ ]
2115
+ },
2116
+ {
2117
+ "cell_type": "code",
2118
+ "execution_count": null,
2119
+ "metadata": {},
2120
+ "outputs": [],
2121
+ "source": [
2122
+ "vis.plot_transformed_image(\"test_images/HydeParkLondon1920s.jpg\", render_factor=30, compare=True)"
2123
+ ]
2124
+ },
2125
+ {
2126
+ "cell_type": "code",
2127
+ "execution_count": null,
2128
+ "metadata": {},
2129
+ "outputs": [],
2130
+ "source": [
2131
+ "vis.plot_transformed_image(\"test_images/1929LondonOverFleetSt.jpg\", render_factor=25, compare=True)"
2132
+ ]
2133
+ },
2134
+ {
2135
+ "cell_type": "code",
2136
+ "execution_count": null,
2137
+ "metadata": {},
2138
+ "outputs": [],
2139
+ "source": [
2140
+ "vis.plot_transformed_image(\"test_images/AccordianKid1900Paris.jpg\", compare=True)"
2141
+ ]
2142
+ },
2143
+ {
2144
+ "cell_type": "code",
2145
+ "execution_count": null,
2146
+ "metadata": {},
2147
+ "outputs": [],
2148
+ "source": [
2149
+ "vis.plot_transformed_image(\"test_images/AnselAdamsBuildings.jpg\", render_factor=45, compare=True)"
2150
+ ]
2151
+ },
2152
+ {
2153
+ "cell_type": "code",
2154
+ "execution_count": null,
2155
+ "metadata": {},
2156
+ "outputs": [],
2157
+ "source": [
2158
+ "vis.plot_transformed_image(\"test_images/AthleticClubParis1913.jpg\", render_factor=42, compare=True)"
2159
+ ]
2160
+ },
2161
+ {
2162
+ "cell_type": "code",
2163
+ "execution_count": null,
2164
+ "metadata": {},
2165
+ "outputs": [],
2166
+ "source": [
2167
+ "vis.plot_transformed_image(\"test_images/BombedLibraryLondon1940.jpg\", compare=True)"
2168
+ ]
2169
+ },
2170
+ {
2171
+ "cell_type": "code",
2172
+ "execution_count": null,
2173
+ "metadata": {},
2174
+ "outputs": [],
2175
+ "source": [
2176
+ "vis.plot_transformed_image(\"test_images/Boston1937.jpg\", render_factor=30, compare=True)"
2177
+ ]
2178
+ },
2179
+ {
2180
+ "cell_type": "code",
2181
+ "execution_count": null,
2182
+ "metadata": {},
2183
+ "outputs": [],
2184
+ "source": [
2185
+ "vis.plot_transformed_image(\"test_images/BoulevardDuTemple1838.jpg\", render_factor=25, compare=True)"
2186
+ ]
2187
+ },
2188
+ {
2189
+ "cell_type": "code",
2190
+ "execution_count": null,
2191
+ "metadata": {},
2192
+ "outputs": [],
2193
+ "source": [
2194
+ "vis.plot_transformed_image(\"test_images/BumperCarsParis1930.jpg\", render_factor=25, compare=True)"
2195
+ ]
2196
+ },
2197
+ {
2198
+ "cell_type": "code",
2199
+ "execution_count": null,
2200
+ "metadata": {},
2201
+ "outputs": [],
2202
+ "source": [
2203
+ "vis.plot_transformed_image(\"test_images/CafeTerrace1925Paris.jpg\", render_factor=35, compare=True)"
2204
+ ]
2205
+ },
2206
+ {
2207
+ "cell_type": "code",
2208
+ "execution_count": null,
2209
+ "metadata": {},
2210
+ "outputs": [],
2211
+ "source": [
2212
+ "vis.plot_transformed_image(\"test_images/CoalDeliveryParis1915.jpg\", render_factor=37, compare=True)"
2213
+ ]
2214
+ },
2215
+ {
2216
+ "cell_type": "code",
2217
+ "execution_count": null,
2218
+ "metadata": {},
2219
+ "outputs": [],
2220
+ "source": [
2221
+ "vis.plot_transformed_image(\"test_images/CorkKids1910.jpg\", render_factor=32, compare=True)"
2222
+ ]
2223
+ },
2224
+ {
2225
+ "cell_type": "code",
2226
+ "execution_count": null,
2227
+ "metadata": {},
2228
+ "outputs": [],
2229
+ "source": [
2230
+ "vis.plot_transformed_image(\"test_images/DeepSeaDiver1915.png\", render_factor=16, compare=True)"
2231
+ ]
2232
+ },
2233
+ {
2234
+ "cell_type": "code",
2235
+ "execution_count": null,
2236
+ "metadata": {},
2237
+ "outputs": [],
2238
+ "source": [
2239
+ "vis.plot_transformed_image(\"test_images/EastEndLondonStreetKids1901.jpg\", compare=True)"
2240
+ ]
2241
+ },
2242
+ {
2243
+ "cell_type": "code",
2244
+ "execution_count": null,
2245
+ "metadata": {},
2246
+ "outputs": [],
2247
+ "source": [
2248
+ "vis.plot_transformed_image(\"test_images/FreightTrainTeens1934.jpg\", compare=True)"
2249
+ ]
2250
+ },
2251
+ {
2252
+ "cell_type": "code",
2253
+ "execution_count": null,
2254
+ "metadata": {},
2255
+ "outputs": [],
2256
+ "source": [
2257
+ "vis.plot_transformed_image(\"test_images/HarrodsLondon1920.jpg\", render_factor=45, compare=True)"
2258
+ ]
2259
+ },
2260
+ {
2261
+ "cell_type": "code",
2262
+ "execution_count": null,
2263
+ "metadata": {},
2264
+ "outputs": [],
2265
+ "source": [
2266
+ "vis.plot_transformed_image(\"test_images/HerbSeller1899Paris.jpg\", render_factor=17, compare=True)"
2267
+ ]
2268
+ },
2269
+ {
2270
+ "cell_type": "code",
2271
+ "execution_count": null,
2272
+ "metadata": {},
2273
+ "outputs": [],
2274
+ "source": [
2275
+ "vis.plot_transformed_image(\"test_images/CalcuttaPoliceman1920.jpg\", render_factor=20, compare=True)"
2276
+ ]
2277
+ },
2278
+ {
2279
+ "cell_type": "code",
2280
+ "execution_count": null,
2281
+ "metadata": {},
2282
+ "outputs": [],
2283
+ "source": [
2284
+ "vis.plot_transformed_image(\"test_images/ElectricScooter1915.jpeg\", render_factor=20, compare=True)"
2285
+ ]
2286
+ },
2287
+ {
2288
+ "cell_type": "code",
2289
+ "execution_count": null,
2290
+ "metadata": {},
2291
+ "outputs": [],
2292
+ "source": [
2293
+ "vis.plot_transformed_image(\"test_images/GreatGrandparentsIrelandEarly1900s.jpg\", compare=True)"
2294
+ ]
2295
+ },
2296
+ {
2297
+ "cell_type": "code",
2298
+ "execution_count": null,
2299
+ "metadata": {},
2300
+ "outputs": [],
2301
+ "source": [
2302
+ "vis.plot_transformed_image(\"test_images/HalloweenEarly1900s.jpg\", render_factor=11, compare=True)"
2303
+ ]
2304
+ },
2305
+ {
2306
+ "cell_type": "code",
2307
+ "execution_count": null,
2308
+ "metadata": {},
2309
+ "outputs": [],
2310
+ "source": [
2311
+ "vis.plot_transformed_image(\"test_images/IceManLondon1919.jpg\", compare=True)"
2312
+ ]
2313
+ },
2314
+ {
2315
+ "cell_type": "code",
2316
+ "execution_count": null,
2317
+ "metadata": {},
2318
+ "outputs": [],
2319
+ "source": [
2320
+ "vis.plot_transformed_image(\"test_images/LeBonMarcheParis1875.jpg\", compare=True)"
2321
+ ]
2322
+ },
2323
+ {
2324
+ "cell_type": "code",
2325
+ "execution_count": null,
2326
+ "metadata": {},
2327
+ "outputs": [],
2328
+ "source": [
2329
+ "vis.plot_transformed_image(\"test_images/LittleAirplane1934.jpg\", render_factor=47, compare=True)"
2330
+ ]
2331
+ },
2332
+ {
2333
+ "cell_type": "code",
2334
+ "execution_count": null,
2335
+ "metadata": {},
2336
+ "outputs": [],
2337
+ "source": [
2338
+ "vis.plot_transformed_image(\"test_images/RoyalUniversityMedStudent1900Ireland.jpg\", render_factor=24, compare=True)"
2339
+ ]
2340
+ },
2341
+ {
2342
+ "cell_type": "code",
2343
+ "execution_count": null,
2344
+ "metadata": {},
2345
+ "outputs": [],
2346
+ "source": [
2347
+ "vis.plot_transformed_image(\"test_images/LewisTomalinLondon1895.png\", render_factor=35, compare=True)"
2348
+ ]
2349
+ },
2350
+ {
2351
+ "cell_type": "code",
2352
+ "execution_count": null,
2353
+ "metadata": {},
2354
+ "outputs": [],
2355
+ "source": [
2356
+ "vis.plot_transformed_image(\"test_images/SunHelmetsLondon1933.jpg\", render_factor=40, compare=True)"
2357
+ ]
2358
+ },
2359
+ {
2360
+ "cell_type": "code",
2361
+ "execution_count": null,
2362
+ "metadata": {},
2363
+ "outputs": [],
2364
+ "source": [
2365
+ "vis.plot_transformed_image(\"test_images/Killarney1910.jpg\", render_factor=45, compare=True)"
2366
+ ]
2367
+ },
2368
+ {
2369
+ "cell_type": "code",
2370
+ "execution_count": null,
2371
+ "metadata": {},
2372
+ "outputs": [],
2373
+ "source": [
2374
+ "vis.plot_transformed_image(\"test_images/LondonSheep1920s.png\", compare=True)"
2375
+ ]
2376
+ },
2377
+ {
2378
+ "cell_type": "code",
2379
+ "execution_count": null,
2380
+ "metadata": {},
2381
+ "outputs": [],
2382
+ "source": [
2383
+ "vis.plot_transformed_image(\"test_images/PostOfficeVermont1914.png\", compare=True)"
2384
+ ]
2385
+ },
2386
+ {
2387
+ "cell_type": "code",
2388
+ "execution_count": null,
2389
+ "metadata": {},
2390
+ "outputs": [],
2391
+ "source": [
2392
+ "vis.plot_transformed_image(\"test_images/ServantsBessboroughHouse1908Ireland.jpg\", compare=True)"
2393
+ ]
2394
+ },
2395
+ {
2396
+ "cell_type": "code",
2397
+ "execution_count": null,
2398
+ "metadata": {},
2399
+ "outputs": [],
2400
+ "source": [
2401
+ "vis.plot_transformed_image(\"test_images/WaterfordIreland1909.jpg\", render_factor=47, compare=True)"
2402
+ ]
2403
+ },
2404
+ {
2405
+ "cell_type": "code",
2406
+ "execution_count": null,
2407
+ "metadata": {},
2408
+ "outputs": [],
2409
+ "source": [
2410
+ "vis.plot_transformed_image(\"test_images/Lisbon1919.jpg\", compare=True)"
2411
+ ]
2412
+ },
2413
+ {
2414
+ "cell_type": "code",
2415
+ "execution_count": null,
2416
+ "metadata": {},
2417
+ "outputs": [],
2418
+ "source": [
2419
+ "vis.plot_transformed_image(\"test_images/London1918WartimeClothesManufacture.jpg\", render_factor=45, compare=True)"
2420
+ ]
2421
+ },
2422
+ {
2423
+ "cell_type": "code",
2424
+ "execution_count": null,
2425
+ "metadata": {},
2426
+ "outputs": [],
2427
+ "source": [
2428
+ "vis.plot_transformed_image(\"test_images/LondonHeatWave1935.png\", compare=True)"
2429
+ ]
2430
+ },
2431
+ {
2432
+ "cell_type": "code",
2433
+ "execution_count": null,
2434
+ "metadata": {},
2435
+ "outputs": [],
2436
+ "source": [
2437
+ "vis.plot_transformed_image(\"test_images/LondonsSmallestShop1900.jpg\", compare=True)"
2438
+ ]
2439
+ },
2440
+ {
2441
+ "cell_type": "code",
2442
+ "execution_count": null,
2443
+ "metadata": {},
2444
+ "outputs": [],
2445
+ "source": [
2446
+ "vis.plot_transformed_image(\"test_images/MetropolitanDistrictRailway1869London.jpg\", compare=True)"
2447
+ ]
2448
+ },
2449
+ {
2450
+ "cell_type": "code",
2451
+ "execution_count": null,
2452
+ "metadata": {},
2453
+ "outputs": [],
2454
+ "source": [
2455
+ "vis.plot_transformed_image(\"test_images/NativeWoman1926.jpg\", render_factor=43, compare=True)"
2456
+ ]
2457
+ },
2458
+ {
2459
+ "cell_type": "code",
2460
+ "execution_count": null,
2461
+ "metadata": {},
2462
+ "outputs": [],
2463
+ "source": [
2464
+ "vis.plot_transformed_image(\"test_images/PaddysMarketCork1900s.jpg\", compare=True)"
2465
+ ]
2466
+ },
2467
+ {
2468
+ "cell_type": "code",
2469
+ "execution_count": null,
2470
+ "metadata": {},
2471
+ "outputs": [],
2472
+ "source": [
2473
+ "vis.plot_transformed_image(\"test_images/PaddysMarketCork1900s.jpg\", render_factor=i, compare=True)"
2474
+ ]
2475
+ },
2476
+ {
2477
+ "cell_type": "code",
2478
+ "execution_count": null,
2479
+ "metadata": {},
2480
+ "outputs": [],
2481
+ "source": [
2482
+ "vis.plot_transformed_image(\"test_images/Paris1920Cart.jpg\", compare=True)"
2483
+ ]
2484
+ },
2485
+ {
2486
+ "cell_type": "code",
2487
+ "execution_count": null,
2488
+ "metadata": {},
2489
+ "outputs": [],
2490
+ "source": [
2491
+ "vis.plot_transformed_image(\"test_images/ParisLadies1910.jpg\", render_factor=38, compare=True)"
2492
+ ]
2493
+ },
2494
+ {
2495
+ "cell_type": "code",
2496
+ "execution_count": null,
2497
+ "metadata": {},
2498
+ "outputs": [],
2499
+ "source": [
2500
+ "vis.plot_transformed_image(\"test_images/ParisLadies1930s.jpg\", render_factor=18, compare=True)"
2501
+ ]
2502
+ },
2503
+ {
2504
+ "cell_type": "code",
2505
+ "execution_count": null,
2506
+ "metadata": {},
2507
+ "outputs": [],
2508
+ "source": [
2509
+ "vis.plot_transformed_image(\"test_images/Sphinx.jpeg\") "
2510
+ ]
2511
+ },
2512
+ {
2513
+ "cell_type": "code",
2514
+ "execution_count": null,
2515
+ "metadata": {},
2516
+ "outputs": [],
2517
+ "source": [
2518
+ "vis.plot_transformed_image(\"test_images/TheatreGroupBombay1875.jpg\", render_factor=45, compare=True)"
2519
+ ]
2520
+ },
2521
+ {
2522
+ "cell_type": "code",
2523
+ "execution_count": null,
2524
+ "metadata": {},
2525
+ "outputs": [],
2526
+ "source": [
2527
+ "vis.plot_transformed_image(\"test_images/WorldsFair1900Paris.jpg\", compare=True)"
2528
+ ]
2529
+ },
2530
+ {
2531
+ "cell_type": "code",
2532
+ "execution_count": null,
2533
+ "metadata": {},
2534
+ "outputs": [],
2535
+ "source": [
2536
+ "vis.plot_transformed_image(\"test_images/London1850Coach.jpg\", render_factor=25, compare=True)"
2537
+ ]
2538
+ },
2539
+ {
2540
+ "cell_type": "code",
2541
+ "execution_count": null,
2542
+ "metadata": {},
2543
+ "outputs": [],
2544
+ "source": [
2545
+ "vis.plot_transformed_image(\"test_images/London1900EastEndBlacksmith.jpg\", compare=True)"
2546
+ ]
2547
+ },
2548
+ {
2549
+ "cell_type": "code",
2550
+ "execution_count": null,
2551
+ "metadata": {},
2552
+ "outputs": [],
2553
+ "source": [
2554
+ "vis.plot_transformed_image(\"test_images/London1930sCheetah.jpg\", render_factor=42, compare=True)"
2555
+ ]
2556
+ },
2557
+ {
2558
+ "cell_type": "code",
2559
+ "execution_count": null,
2560
+ "metadata": {},
2561
+ "outputs": [],
2562
+ "source": [
2563
+ "vis.plot_transformed_image(\"test_images/LondonFireBrigadeMember1926.jpg\", compare=True)"
2564
+ ]
2565
+ },
2566
+ {
2567
+ "cell_type": "code",
2568
+ "execution_count": null,
2569
+ "metadata": {},
2570
+ "outputs": [],
2571
+ "source": [
2572
+ "vis.plot_transformed_image(\"test_images/LondonGarbageTruck1910.jpg\", compare=True)"
2573
+ ]
2574
+ },
2575
+ {
2576
+ "cell_type": "code",
2577
+ "execution_count": null,
2578
+ "metadata": {},
2579
+ "outputs": [],
2580
+ "source": [
2581
+ "vis.plot_transformed_image(\"test_images/LondonRailwayWork1931.jpg\", render_factor=45, compare=True)"
2582
+ ]
2583
+ },
2584
+ {
2585
+ "cell_type": "code",
2586
+ "execution_count": null,
2587
+ "metadata": {},
2588
+ "outputs": [],
2589
+ "source": [
2590
+ "vis.plot_transformed_image(\"test_images/LondonStreets1900.jpg\", compare=True)"
2591
+ ]
2592
+ },
2593
+ {
2594
+ "cell_type": "code",
2595
+ "execution_count": null,
2596
+ "metadata": {},
2597
+ "outputs": [],
2598
+ "source": [
2599
+ "vis.plot_transformed_image(\"test_images/MuffinManlLondon1910.jpg\", render_factor=40, compare=True)"
2600
+ ]
2601
+ },
2602
+ {
2603
+ "cell_type": "code",
2604
+ "execution_count": null,
2605
+ "metadata": {},
2606
+ "outputs": [],
2607
+ "source": [
2608
+ "vis.plot_transformed_image(\"test_images/NativeCouple1912.jpg\", render_factor=21, compare=True)"
2609
+ ]
2610
+ },
2611
+ {
2612
+ "cell_type": "code",
2613
+ "execution_count": null,
2614
+ "metadata": {},
2615
+ "outputs": [],
2616
+ "source": [
2617
+ "vis.plot_transformed_image(\"test_images/NewspaperCivilWar1863.jpg\", compare=True)"
2618
+ ]
2619
+ },
2620
+ {
2621
+ "cell_type": "code",
2622
+ "execution_count": null,
2623
+ "metadata": {},
2624
+ "outputs": [],
2625
+ "source": [
2626
+ "vis.plot_transformed_image(\"test_images/PaddingtonStationLondon1907.jpg\", render_factor=45, compare=True)"
2627
+ ]
2628
+ },
2629
+ {
2630
+ "cell_type": "code",
2631
+ "execution_count": null,
2632
+ "metadata": {},
2633
+ "outputs": [],
2634
+ "source": [
2635
+ "vis.plot_transformed_image(\"test_images/Paris1899StreetDig.jpg\", compare=True)"
2636
+ ]
2637
+ },
2638
+ {
2639
+ "cell_type": "code",
2640
+ "execution_count": null,
2641
+ "metadata": {},
2642
+ "outputs": [],
2643
+ "source": [
2644
+ "vis.plot_transformed_image(\"test_images/Paris1926.jpg\", compare=True)"
2645
+ ]
2646
+ },
2647
+ {
2648
+ "cell_type": "code",
2649
+ "execution_count": null,
2650
+ "metadata": {},
2651
+ "outputs": [],
2652
+ "source": [
2653
+ "vis.plot_transformed_image(\"test_images/ParisWomenFurs1920s.jpg\", render_factor=15, compare=True)"
2654
+ ]
2655
+ },
2656
+ {
2657
+ "cell_type": "code",
2658
+ "execution_count": null,
2659
+ "metadata": {},
2660
+ "outputs": [],
2661
+ "source": [
2662
+ "vis.plot_transformed_image(\"test_images/PeddlerParis1899.jpg\", render_factor=35, compare=True)"
2663
+ ]
2664
+ },
2665
+ {
2666
+ "cell_type": "code",
2667
+ "execution_count": null,
2668
+ "metadata": {},
2669
+ "outputs": [],
2670
+ "source": [
2671
+ "vis.plot_transformed_image(\"test_images/SchoolKidsConnemaraIreland1901.jpg\", render_factor=18, compare=True)"
2672
+ ]
2673
+ },
2674
+ {
2675
+ "cell_type": "code",
2676
+ "execution_count": null,
2677
+ "metadata": {},
2678
+ "outputs": [],
2679
+ "source": [
2680
+ "vis.plot_transformed_image(\"test_images/SecondHandClothesLondonLate1800s.jpg\", render_factor=44, compare=True)"
2681
+ ]
2682
+ },
2683
+ {
2684
+ "cell_type": "code",
2685
+ "execution_count": null,
2686
+ "metadata": {},
2687
+ "outputs": [],
2688
+ "source": [
2689
+ "vis.plot_transformed_image(\"test_images/SoapBoxRacerParis1920s.jpg\", compare=True)"
2690
+ ]
2691
+ },
2692
+ {
2693
+ "cell_type": "code",
2694
+ "execution_count": null,
2695
+ "metadata": {},
2696
+ "outputs": [],
2697
+ "source": [
2698
+ "vis.plot_transformed_image(\"test_images/SoccerMotorcycles1923London.jpg\", compare=True)"
2699
+ ]
2700
+ },
2701
+ {
2702
+ "cell_type": "code",
2703
+ "execution_count": null,
2704
+ "metadata": {},
2705
+ "outputs": [],
2706
+ "source": [
2707
+ "vis.plot_transformed_image(\"test_images/WalkingLibraryLondon1930.jpg\", compare=True)"
2708
+ ]
2709
+ },
2710
+ {
2711
+ "cell_type": "code",
2712
+ "execution_count": null,
2713
+ "metadata": {},
2714
+ "outputs": [],
2715
+ "source": [
2716
+ "vis.plot_transformed_image(\"test_images/LondonStreetDoctor1877.png\", render_factor=19, compare=True)"
2717
+ ]
2718
+ },
2719
+ {
2720
+ "cell_type": "code",
2721
+ "execution_count": null,
2722
+ "metadata": {},
2723
+ "outputs": [],
2724
+ "source": [
2725
+ "vis.plot_transformed_image(\"test_images/jacksonville.jpg\", compare=True)"
2726
+ ]
2727
+ },
2728
+ {
2729
+ "cell_type": "code",
2730
+ "execution_count": null,
2731
+ "metadata": {},
2732
+ "outputs": [],
2733
+ "source": [
2734
+ "vis.plot_transformed_image(\"test_images/ZebraCarriageLondon1900.jpg\", compare=True)"
2735
+ ]
2736
+ },
2737
+ {
2738
+ "cell_type": "code",
2739
+ "execution_count": null,
2740
+ "metadata": {},
2741
+ "outputs": [],
2742
+ "source": [
2743
+ "vis.plot_transformed_image(\"test_images/StreetGramaphonePlayerLondon1920s.png\", compare=True)"
2744
+ ]
2745
+ },
2746
+ {
2747
+ "cell_type": "code",
2748
+ "execution_count": null,
2749
+ "metadata": {},
2750
+ "outputs": [],
2751
+ "source": [
2752
+ "vis.plot_transformed_image(\"test_images/YaleBranchBarnardsExpress.jpg\", compare=True)"
2753
+ ]
2754
+ },
2755
+ {
2756
+ "cell_type": "code",
2757
+ "execution_count": null,
2758
+ "metadata": {},
2759
+ "outputs": [],
2760
+ "source": [
2761
+ "vis.plot_transformed_image(\"test_images/SynagogueInterior.PNG\", compare=True)"
2762
+ ]
2763
+ },
2764
+ {
2765
+ "cell_type": "code",
2766
+ "execution_count": null,
2767
+ "metadata": {},
2768
+ "outputs": [],
2769
+ "source": [
2770
+ "vis.plot_transformed_image(\"test_images/ArmisticeDay1918.jpg\", compare=True)"
2771
+ ]
2772
+ },
2773
+ {
2774
+ "cell_type": "code",
2775
+ "execution_count": null,
2776
+ "metadata": {},
2777
+ "outputs": [],
2778
+ "source": [
2779
+ "vis.plot_transformed_image(\"test_images/FlyingMachinesParis1909.jpg\", render_factor=25, compare=True)"
2780
+ ]
2781
+ },
2782
+ {
2783
+ "cell_type": "code",
2784
+ "execution_count": null,
2785
+ "metadata": {},
2786
+ "outputs": [],
2787
+ "source": [
2788
+ "vis.plot_transformed_image(\"test_images/GreatAunt1920.jpg\", compare=True)"
2789
+ ]
2790
+ },
2791
+ {
2792
+ "cell_type": "code",
2793
+ "execution_count": null,
2794
+ "metadata": {},
2795
+ "outputs": [],
2796
+ "source": [
2797
+ "vis.plot_transformed_image(\"test_images/NewBrunswick1915.jpg\", compare=True)"
2798
+ ]
2799
+ },
2800
+ {
2801
+ "cell_type": "code",
2802
+ "execution_count": null,
2803
+ "metadata": {},
2804
+ "outputs": [],
2805
+ "source": [
2806
+ "vis.plot_transformed_image(\"test_images/ShoeMakerLate1800s.jpg\", compare=True)"
2807
+ ]
2808
+ },
2809
+ {
2810
+ "cell_type": "code",
2811
+ "execution_count": null,
2812
+ "metadata": {},
2813
+ "outputs": [],
2814
+ "source": [
2815
+ "vis.plot_transformed_image(\"test_images/SpottedBull1908.jpg\", compare=True)"
2816
+ ]
2817
+ },
2818
+ {
2819
+ "cell_type": "code",
2820
+ "execution_count": null,
2821
+ "metadata": {},
2822
+ "outputs": [],
2823
+ "source": [
2824
+ "vis.plot_transformed_image(\"test_images/TouristsGermany1904.jpg\", compare=True)"
2825
+ ]
2826
+ },
2827
+ {
2828
+ "cell_type": "code",
2829
+ "execution_count": null,
2830
+ "metadata": {},
2831
+ "outputs": [],
2832
+ "source": [
2833
+ "vis.plot_transformed_image(\"test_images/TunisianStudents1914.jpg\", compare=True)"
2834
+ ]
2835
+ },
2836
+ {
2837
+ "cell_type": "code",
2838
+ "execution_count": null,
2839
+ "metadata": {},
2840
+ "outputs": [],
2841
+ "source": [
2842
+ "vis.plot_transformed_image(\"test_images/Yorktown1862.jpg\", compare=True)"
2843
+ ]
2844
+ },
2845
+ {
2846
+ "cell_type": "code",
2847
+ "execution_count": null,
2848
+ "metadata": {},
2849
+ "outputs": [],
2850
+ "source": [
2851
+ "vis.plot_transformed_image(\"test_images/LondonFashion1911.png\", compare=True)"
2852
+ ]
2853
+ },
2854
+ {
2855
+ "cell_type": "code",
2856
+ "execution_count": null,
2857
+ "metadata": {},
2858
+ "outputs": [],
2859
+ "source": [
2860
+ "vis.plot_transformed_image(\"test_images/1939GypsyKids.jpg\", compare=True)"
2861
+ ]
2862
+ },
2863
+ {
2864
+ "cell_type": "code",
2865
+ "execution_count": null,
2866
+ "metadata": {},
2867
+ "outputs": [],
2868
+ "source": [
2869
+ "vis.plot_transformed_image(\"test_images/1936OpiumShanghai.jpg\", compare=True)"
2870
+ ]
2871
+ },
2872
+ {
2873
+ "cell_type": "code",
2874
+ "execution_count": null,
2875
+ "metadata": {},
2876
+ "outputs": [],
2877
+ "source": [
2878
+ "vis.plot_transformed_image(\"test_images/1923HollandTunnel.jpg\", compare=True)"
2879
+ ]
2880
+ },
2881
+ {
2882
+ "cell_type": "code",
2883
+ "execution_count": null,
2884
+ "metadata": {},
2885
+ "outputs": [],
2886
+ "source": [
2887
+ "vis.plot_transformed_image(\"test_images/1939YakimaWAGirl.jpg\", compare=True)"
2888
+ ]
2889
+ },
2890
+ {
2891
+ "cell_type": "code",
2892
+ "execution_count": null,
2893
+ "metadata": {},
2894
+ "outputs": [],
2895
+ "source": [
2896
+ "vis.plot_transformed_image(\"test_images/GoldenGateConstruction.jpg\", render_factor=45, compare=True)"
2897
+ ]
2898
+ },
2899
+ {
2900
+ "cell_type": "code",
2901
+ "execution_count": null,
2902
+ "metadata": {},
2903
+ "outputs": [],
2904
+ "source": [
2905
+ "vis.plot_transformed_image(\"test_images/PostCivilWarAncestors.jpg\", compare=True)"
2906
+ ]
2907
+ },
2908
+ {
2909
+ "cell_type": "code",
2910
+ "execution_count": null,
2911
+ "metadata": {},
2912
+ "outputs": [],
2913
+ "source": [
2914
+ "vis.plot_transformed_image(\"test_images/1939SewingBike.png\", compare=True)"
2915
+ ]
2916
+ },
2917
+ {
2918
+ "cell_type": "code",
2919
+ "execution_count": null,
2920
+ "metadata": {},
2921
+ "outputs": [],
2922
+ "source": [
2923
+ "vis.plot_transformed_image(\"test_images/1930MaineSchoolBus.jpg\", compare=True)"
2924
+ ]
2925
+ },
2926
+ {
2927
+ "cell_type": "code",
2928
+ "execution_count": null,
2929
+ "metadata": {},
2930
+ "outputs": [],
2931
+ "source": [
2932
+ "vis.plot_transformed_image(\"test_images/1913NewYorkConstruction.jpg\", compare=True)"
2933
+ ]
2934
+ },
2935
+ {
2936
+ "cell_type": "code",
2937
+ "execution_count": null,
2938
+ "metadata": {},
2939
+ "outputs": [],
2940
+ "source": [
2941
+ "vis.plot_transformed_image(\"test_images/1945HiroshimaChild.jpg\", compare=True)"
2942
+ ]
2943
+ },
2944
+ {
2945
+ "cell_type": "code",
2946
+ "execution_count": null,
2947
+ "metadata": {},
2948
+ "outputs": [],
2949
+ "source": [
2950
+ "vis.plot_transformed_image(\"test_images/1941GeorgiaFarmhouse.jpg\", render_factor=47, compare=True)"
2951
+ ]
2952
+ },
2953
+ {
2954
+ "cell_type": "code",
2955
+ "execution_count": null,
2956
+ "metadata": {},
2957
+ "outputs": [],
2958
+ "source": [
2959
+ "vis.plot_transformed_image(\"test_images/1934UmbriaItaly.jpg\", render_factor=21, compare=True)"
2960
+ ]
2961
+ },
2962
+ {
2963
+ "cell_type": "code",
2964
+ "execution_count": null,
2965
+ "metadata": {},
2966
+ "outputs": [],
2967
+ "source": [
2968
+ "vis.plot_transformed_image(\"test_images/1900sLadiesTeaParty.jpg\", compare=True)"
2969
+ ]
2970
+ },
2971
+ {
2972
+ "cell_type": "code",
2973
+ "execution_count": null,
2974
+ "metadata": {},
2975
+ "outputs": [],
2976
+ "source": [
2977
+ "vis.plot_transformed_image(\"test_images/1919WWIAviationOxygenMask.jpg\", compare=True)"
2978
+ ]
2979
+ },
2980
+ {
2981
+ "cell_type": "code",
2982
+ "execution_count": null,
2983
+ "metadata": {},
2984
+ "outputs": [],
2985
+ "source": [
2986
+ "vis.plot_transformed_image(\"test_images/1900NJThanksgiving.jpg\", compare=True)"
2987
+ ]
2988
+ },
2989
+ {
2990
+ "cell_type": "code",
2991
+ "execution_count": null,
2992
+ "metadata": {},
2993
+ "outputs": [],
2994
+ "source": [
2995
+ "vis.plot_transformed_image(\"test_images/1940Connecticut.jpg\", render_factor=42, compare=True)"
2996
+ ]
2997
+ },
2998
+ {
2999
+ "cell_type": "code",
3000
+ "execution_count": null,
3001
+ "metadata": {},
3002
+ "outputs": [],
3003
+ "source": [
3004
+ "vis.plot_transformed_image(\"test_images/1911ThanksgivingMaskers.jpg\", render_factor=36, compare=True)"
3005
+ ]
3006
+ },
3007
+ {
3008
+ "cell_type": "code",
3009
+ "execution_count": null,
3010
+ "metadata": {},
3011
+ "outputs": [],
3012
+ "source": [
3013
+ "vis.plot_transformed_image(\"test_images/1910ThanksgivingMaskersII.jpg\", compare=True)"
3014
+ ]
3015
+ },
3016
+ {
3017
+ "cell_type": "code",
3018
+ "execution_count": null,
3019
+ "metadata": {},
3020
+ "outputs": [],
3021
+ "source": [
3022
+ "vis.plot_transformed_image(\"test_images/1936PetToad.jpg\", compare=True)"
3023
+ ]
3024
+ },
3025
+ {
3026
+ "cell_type": "code",
3027
+ "execution_count": null,
3028
+ "metadata": {},
3029
+ "outputs": [],
3030
+ "source": [
3031
+ "vis.plot_transformed_image(\"test_images/1908RookeriesLondon.jpg\", compare=True)"
3032
+ ]
3033
+ },
3034
+ {
3035
+ "cell_type": "code",
3036
+ "execution_count": null,
3037
+ "metadata": {},
3038
+ "outputs": [],
3039
+ "source": [
3040
+ "vis.plot_transformed_image(\"test_images/1890sChineseImmigrants.jpg\", render_factor=36, compare=True)"
3041
+ ]
3042
+ },
3043
+ {
3044
+ "cell_type": "code",
3045
+ "execution_count": null,
3046
+ "metadata": {},
3047
+ "outputs": [],
3048
+ "source": [
3049
+ "vis.plot_transformed_image(\"test_images/1897VancouverAmberlamps.jpg\", compare=True)"
3050
+ ]
3051
+ },
3052
+ {
3053
+ "cell_type": "code",
3054
+ "execution_count": null,
3055
+ "metadata": {},
3056
+ "outputs": [],
3057
+ "source": [
3058
+ "vis.plot_transformed_image(\"test_images/1929VictorianCosplayLondon.jpg\", render_factor=30, compare=True)"
3059
+ ]
3060
+ },
3061
+ {
3062
+ "cell_type": "code",
3063
+ "execution_count": null,
3064
+ "metadata": {},
3065
+ "outputs": [],
3066
+ "source": [
3067
+ "vis.plot_transformed_image(\"test_images/1959ParisFriends.png\", render_factor=45, compare=True)"
3068
+ ]
3069
+ },
3070
+ {
3071
+ "cell_type": "code",
3072
+ "execution_count": null,
3073
+ "metadata": {},
3074
+ "outputs": [],
3075
+ "source": [
3076
+ "vis.plot_transformed_image(\"test_images/1925GypsyCampMaryland.jpg\", render_factor=45, compare=True)"
3077
+ ]
3078
+ },
3079
+ {
3080
+ "cell_type": "code",
3081
+ "execution_count": null,
3082
+ "metadata": {},
3083
+ "outputs": [],
3084
+ "source": [
3085
+ "vis.plot_transformed_image(\"test_images/1941PoolTableGeorgia.jpg\", render_factor=47, compare=True)"
3086
+ ]
3087
+ },
3088
+ {
3089
+ "cell_type": "code",
3090
+ "execution_count": null,
3091
+ "metadata": {},
3092
+ "outputs": [],
3093
+ "source": [
3094
+ "vis.plot_transformed_image(\"test_images/1900ParkDog.jpg\", compare=True)"
3095
+ ]
3096
+ },
3097
+ {
3098
+ "cell_type": "code",
3099
+ "execution_count": null,
3100
+ "metadata": {},
3101
+ "outputs": [],
3102
+ "source": [
3103
+ "vis.plot_transformed_image(\"test_images/1886Hoop.jpg\", compare=True)"
3104
+ ]
3105
+ },
3106
+ {
3107
+ "cell_type": "code",
3108
+ "execution_count": null,
3109
+ "metadata": {},
3110
+ "outputs": [],
3111
+ "source": [
3112
+ "vis.plot_transformed_image(\"test_images/1950sLondonPoliceChild.jpg\", compare=True)"
3113
+ ]
3114
+ },
3115
+ {
3116
+ "cell_type": "code",
3117
+ "execution_count": null,
3118
+ "metadata": {},
3119
+ "outputs": [],
3120
+ "source": [
3121
+ "vis.plot_transformed_image(\"test_images/1886ProspectPark.jpg\", render_factor=45, compare=True)"
3122
+ ]
3123
+ },
3124
+ {
3125
+ "cell_type": "code",
3126
+ "execution_count": null,
3127
+ "metadata": {},
3128
+ "outputs": [],
3129
+ "source": [
3130
+ "vis.plot_transformed_image(\"test_images/1930sRooftopPoland.jpg\", render_factor=37, compare=True)"
3131
+ ]
3132
+ },
3133
+ {
3134
+ "cell_type": "code",
3135
+ "execution_count": null,
3136
+ "metadata": {},
3137
+ "outputs": [],
3138
+ "source": [
3139
+ "vis.plot_transformed_image(\"test_images/1919RevereBeach.jpg\", render_factor=20, compare=True)"
3140
+ ]
3141
+ },
3142
+ {
3143
+ "cell_type": "code",
3144
+ "execution_count": null,
3145
+ "metadata": {},
3146
+ "outputs": [],
3147
+ "source": [
3148
+ "vis.plot_transformed_image(\"test_images/1936ParisCafe.jpg\", render_factor=47, compare=True)"
3149
+ ]
3150
+ },
3151
+ {
3152
+ "cell_type": "code",
3153
+ "execution_count": null,
3154
+ "metadata": {},
3155
+ "outputs": [],
3156
+ "source": [
3157
+ "vis.plot_transformed_image(\"test_images/1902FrenchYellowBellies.jpg\", render_factor=35, compare=True)"
3158
+ ]
3159
+ },
3160
+ {
3161
+ "cell_type": "code",
3162
+ "execution_count": null,
3163
+ "metadata": {},
3164
+ "outputs": [],
3165
+ "source": [
3166
+ "vis.plot_transformed_image(\"test_images/1940PAFamily.jpg\", render_factor=34, compare=True)"
3167
+ ]
3168
+ },
3169
+ {
3170
+ "cell_type": "code",
3171
+ "execution_count": null,
3172
+ "metadata": {},
3173
+ "outputs": [],
3174
+ "source": [
3175
+ "vis.plot_transformed_image(\"test_images/1910Finland.jpg\", render_factor=40, compare=True)"
3176
+ ]
3177
+ },
3178
+ {
3179
+ "cell_type": "code",
3180
+ "execution_count": null,
3181
+ "metadata": {},
3182
+ "outputs": [],
3183
+ "source": [
3184
+ "vis.plot_transformed_image(\"test_images/ZebraCarriageLondon1900.jpg\", render_factor=21, compare=True)"
3185
+ ]
3186
+ },
3187
+ {
3188
+ "cell_type": "code",
3189
+ "execution_count": null,
3190
+ "metadata": {},
3191
+ "outputs": [],
3192
+ "source": [
3193
+ "vis.plot_transformed_image(\"test_images/1904ChineseMan.jpg\", render_factor=14, compare=True)"
3194
+ ]
3195
+ },
3196
+ {
3197
+ "cell_type": "code",
3198
+ "execution_count": null,
3199
+ "metadata": {},
3200
+ "outputs": [],
3201
+ "source": [
3202
+ "vis.plot_transformed_image(\"test_images/CrystalPalaceLondon1854.PNG\", render_factor=15, compare=True)"
3203
+ ]
3204
+ },
3205
+ {
3206
+ "cell_type": "code",
3207
+ "execution_count": null,
3208
+ "metadata": {},
3209
+ "outputs": [],
3210
+ "source": [
3211
+ "vis.plot_transformed_image(\"test_images/James1.jpg\", render_factor=15, compare=True)"
3212
+ ]
3213
+ },
3214
+ {
3215
+ "cell_type": "code",
3216
+ "execution_count": null,
3217
+ "metadata": {},
3218
+ "outputs": [],
3219
+ "source": [
3220
+ "vis.plot_transformed_image(\"test_images/James2.jpg\", render_factor=20, compare=True)"
3221
+ ]
3222
+ },
3223
+ {
3224
+ "cell_type": "code",
3225
+ "execution_count": null,
3226
+ "metadata": {},
3227
+ "outputs": [],
3228
+ "source": [
3229
+ "vis.plot_transformed_image(\"test_images/James3.jpg\", render_factor=19, compare=True)"
3230
+ ]
3231
+ },
3232
+ {
3233
+ "cell_type": "code",
3234
+ "execution_count": null,
3235
+ "metadata": {},
3236
+ "outputs": [],
3237
+ "source": [
3238
+ "vis.plot_transformed_image(\"test_images/James4.jpg\", render_factor=30, compare=True)"
3239
+ ]
3240
+ },
3241
+ {
3242
+ "cell_type": "code",
3243
+ "execution_count": null,
3244
+ "metadata": {},
3245
+ "outputs": [],
3246
+ "source": [
3247
+ "vis.plot_transformed_image(\"test_images/James5.jpg\", render_factor=32, compare=True)"
3248
+ ]
3249
+ },
3250
+ {
3251
+ "cell_type": "code",
3252
+ "execution_count": null,
3253
+ "metadata": {},
3254
+ "outputs": [],
3255
+ "source": [
3256
+ "vis.plot_transformed_image(\"test_images/James6.jpg\", render_factor=28, compare=True)"
3257
+ ]
3258
+ },
3259
+ {
3260
+ "cell_type": "code",
3261
+ "execution_count": null,
3262
+ "metadata": {},
3263
+ "outputs": [],
3264
+ "source": []
3265
+ },
3266
+ {
3267
+ "cell_type": "code",
3268
+ "execution_count": null,
3269
+ "metadata": {},
3270
+ "outputs": [],
3271
+ "source": []
3272
+ }
3273
+ ],
3274
+ "metadata": {
3275
+ "kernelspec": {
3276
+ "display_name": "Python 3",
3277
+ "language": "python",
3278
+ "name": "python3"
3279
+ },
3280
+ "language_info": {
3281
+ "codemirror_mode": {
3282
+ "name": "ipython",
3283
+ "version": 3
3284
+ },
3285
+ "file_extension": ".py",
3286
+ "mimetype": "text/x-python",
3287
+ "name": "python",
3288
+ "nbconvert_exporter": "python",
3289
+ "pygments_lexer": "ipython3",
3290
+ "version": "3.7.6"
3291
+ },
3292
+ "toc": {
3293
+ "colors": {
3294
+ "hover_highlight": "#DAA520",
3295
+ "navigate_num": "#000000",
3296
+ "navigate_text": "#333333",
3297
+ "running_highlight": "#FF0000",
3298
+ "selected_highlight": "#FFD700",
3299
+ "sidebar_border": "#EEEEEE",
3300
+ "wrapper_background": "#FFFFFF"
3301
+ },
3302
+ "moveMenuLeft": true,
3303
+ "nav_menu": {
3304
+ "height": "67px",
3305
+ "width": "252px"
3306
+ },
3307
+ "navigate_menu": true,
3308
+ "number_sections": true,
3309
+ "sideBar": true,
3310
+ "threshold": 4,
3311
+ "toc_cell": false,
3312
+ "toc_section_display": "block",
3313
+ "toc_window_display": false,
3314
+ "widenNotebook": false
3315
+ }
3316
+ },
3317
+ "nbformat": 4,
3318
+ "nbformat_minor": 4
3319
+ }
ImageColorizerColab.ipynb ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {},
16
+ "source": [
17
+ "### **<font color='blue'> Artistic Colorizer </font>**"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {
23
+ "colab_type": "text",
24
+ "id": "663IVxfrpIAb"
25
+ },
26
+ "source": [
27
+ "#β—’ DeOldify - Colorize your own photos!\n",
28
+ "\n",
29
+ "####**Credits:**\n",
30
+ "\n",
31
+ "Special thanks to:\n",
32
+ "\n",
33
+ "Matt Robinson and MarΓ­a Benavente for pioneering the DeOldify image colab notebook. \n",
34
+ "\n",
35
+ "Dana Kelley for doing things, breaking stuff & having an opinion on everything."
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {
41
+ "colab_type": "text",
42
+ "id": "ZjPqTBNoohK9"
43
+ },
44
+ "source": [
45
+ "\n",
46
+ "\n",
47
+ "---\n",
48
+ "\n",
49
+ "\n",
50
+ "#β—’ Verify Correct Runtime Settings\n",
51
+ "\n",
52
+ "**<font color='#FF000'> IMPORTANT </font>**\n",
53
+ "\n",
54
+ "In the \"Runtime\" menu for the notebook window, select \"Change runtime type.\" Ensure that the following are selected:\n",
55
+ "* Runtime Type = Python 3\n",
56
+ "* Hardware Accelerator = GPU \n"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "metadata": {
62
+ "colab_type": "text",
63
+ "id": "gaEJBGDlptEo"
64
+ },
65
+ "source": [
66
+ "#β—’ Git clone and install DeOldify"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {
73
+ "colab": {},
74
+ "colab_type": "code",
75
+ "id": "-T-svuHytJ-8"
76
+ },
77
+ "outputs": [],
78
+ "source": [
79
+ "!git clone https://github.com/jantic/DeOldify.git DeOldify "
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "cd DeOldify"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {
94
+ "colab_type": "text",
95
+ "id": "BDFjbNxaadNK"
96
+ },
97
+ "source": [
98
+ "#β—’ Setup"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "metadata": {
105
+ "colab": {},
106
+ "colab_type": "code",
107
+ "id": "00_GcC_trpdE"
108
+ },
109
+ "outputs": [],
110
+ "source": [
111
+ "#NOTE: This must be the first call in order to work properly!\n",
112
+ "from deoldify import device\n",
113
+ "from deoldify.device_id import DeviceId\n",
114
+ "#choices: CPU, GPU0...GPU7\n",
115
+ "device.set(device=DeviceId.GPU0)\n",
116
+ "\n",
117
+ "import torch\n",
118
+ "\n",
119
+ "if not torch.cuda.is_available():\n",
120
+ " print('GPU not available.')"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {
127
+ "colab": {},
128
+ "colab_type": "code",
129
+ "id": "Lsx7xCXNSVt6"
130
+ },
131
+ "outputs": [],
132
+ "source": [
133
+ "!pip install -r requirements-colab.txt"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {
140
+ "colab": {},
141
+ "colab_type": "code",
142
+ "id": "MsJa69CMwj3l"
143
+ },
144
+ "outputs": [],
145
+ "source": [
146
+ "import fastai\n",
147
+ "from deoldify.visualize import *\n",
148
+ "import warnings\n",
149
+ "warnings.filterwarnings(\"ignore\", category=UserWarning, message=\".*?Your .*? set is empty.*?\")"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "!mkdir 'models'\n",
159
+ "!wget https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth -O ./models/ColorizeArtistic_gen.pth"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {
166
+ "colab": {},
167
+ "colab_type": "code",
168
+ "id": "tzHVnegp21hC"
169
+ },
170
+ "outputs": [],
171
+ "source": [
172
+ "colorizer = get_image_colorizer(artistic=True)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "metadata": {
178
+ "colab_type": "text",
179
+ "id": "BDFjbNxaadNJ"
180
+ },
181
+ "source": [
182
+ "#β—’ Instructions"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "markdown",
187
+ "metadata": {},
188
+ "source": [
189
+ "### source_url\n",
190
+ "Type in a url to a direct link of an image. Usually that means they'll end in .png, .jpg, etc. NOTE: If you want to use your own image, upload it first to a site like Imgur. \n",
191
+ "\n",
192
+ "### render_factor\n",
193
+ "The default value of 35 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the color portion of the image is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality images in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality images, but the colors may get slightly washed out. \n",
194
+ "\n",
195
+ "### watermarked\n",
196
+ "Selected by default, this places a watermark icon of a palette at the bottom left corner of the image. This is intended to be a standard way to convey to others viewing the image that it is colorized by AI. We want to help promote this as a standard, especially as the technology continues to improve and the distinction between real and fake becomes harder to discern. This palette watermark practice was initiated and lead by the company MyHeritage in the MyHeritage In Color feature (which uses a newer version of DeOldify than what you're using here).\n",
197
+ "\n",
198
+ "#### How to Download a Copy\n",
199
+ "Simply right click on the displayed image and click \"Save image as...\"!\n",
200
+ "\n",
201
+ "## Pro Tips\n",
202
+ "\n",
203
+ "You can evaluate how well the image is rendered at each render_factor by using the code at the bottom (that cell under \"See how well render_factor values perform on a frame here\"). "
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "markdown",
208
+ "metadata": {
209
+ "colab_type": "text",
210
+ "id": "sUQrbSYipiJn"
211
+ },
212
+ "source": [
213
+ "#β—’ Colorize!!"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "source_url = '' #@param {type:\"string\"}\n",
223
+ "render_factor = 35 #@param {type: \"slider\", min: 7, max: 40}\n",
224
+ "watermarked = True #@param {type:\"boolean\"}\n",
225
+ "\n",
226
+ "if source_url is not None and source_url !='':\n",
227
+ " image_path = colorizer.plot_transformed_image_from_url(url=source_url, render_factor=render_factor, compare=True, watermarked=watermarked)\n",
228
+ " show_image_in_notebook(image_path)\n",
229
+ "else:\n",
230
+ " print('Provide an image url and try again.')"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "markdown",
235
+ "metadata": {},
236
+ "source": [
237
+ "## See how well render_factor values perform on the image here"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": null,
243
+ "metadata": {},
244
+ "outputs": [],
245
+ "source": [
246
+ "for i in range(10,40,2):\n",
247
+ " colorizer.plot_transformed_image('test_images/image.png', render_factor=i, display_render_factor=True, figsize=(8,8))"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {
253
+ "colab_type": "text",
254
+ "id": "X7Ycv_Y9xAHp"
255
+ },
256
+ "source": [
257
+ "---\n",
258
+ "#βš™ Recommended image sources \n",
259
+ "* [/r/TheWayWeWere](https://www.reddit.com/r/TheWayWeWere/)"
260
+ ]
261
+ }
262
+ ],
263
+ "metadata": {
264
+ "accelerator": "GPU",
265
+ "colab": {
266
+ "collapsed_sections": [],
267
+ "name": "ImageColorizerColab.ipynb",
268
+ "provenance": [],
269
+ "toc_visible": true,
270
+ "version": "0.3.2"
271
+ },
272
+ "kernelspec": {
273
+ "display_name": "Python 3",
274
+ "language": "python",
275
+ "name": "python3"
276
+ },
277
+ "language_info": {
278
+ "codemirror_mode": {
279
+ "name": "ipython",
280
+ "version": 3
281
+ },
282
+ "file_extension": ".py",
283
+ "mimetype": "text/x-python",
284
+ "name": "python",
285
+ "nbconvert_exporter": "python",
286
+ "pygments_lexer": "ipython3",
287
+ "version": "3.7.6"
288
+ }
289
+ },
290
+ "nbformat": 4,
291
+ "nbformat_minor": 4
292
+ }
ImageColorizerColabStable.ipynb ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColabStable.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {},
16
+ "source": [
17
+ "### **<font color='blue'> Stable Colorizer </font>**"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {
23
+ "colab_type": "text",
24
+ "id": "663IVxfrpIAb"
25
+ },
26
+ "source": [
27
+ "#β—’ DeOldify - Colorize your own photos!\n",
28
+ "\n",
29
+ "####**Credits:**\n",
30
+ "\n",
31
+ "Special thanks to:\n",
32
+ "\n",
33
+ "Matt Robinson and MarΓ­a Benavente for pioneering the DeOldify image colab notebook. \n",
34
+ "\n",
35
+ "Dana Kelley for doing things, breaking stuff & having an opinion on everything."
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {
41
+ "colab_type": "text",
42
+ "id": "ZjPqTBNoohK9"
43
+ },
44
+ "source": [
45
+ "\n",
46
+ "\n",
47
+ "---\n",
48
+ "\n",
49
+ "\n",
50
+ "#β—’ Verify Correct Runtime Settings\n",
51
+ "\n",
52
+ "**<font color='#FF000'> IMPORTANT </font>**\n",
53
+ "\n",
54
+ "In the \"Runtime\" menu for the notebook window, select \"Change runtime type.\" Ensure that the following are selected:\n",
55
+ "* Runtime Type = Python 3\n",
56
+ "* Hardware Accelerator = GPU \n"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "metadata": {
62
+ "colab_type": "text",
63
+ "id": "gaEJBGDlptEo"
64
+ },
65
+ "source": [
66
+ "#β—’ Git clone and install DeOldify"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {
73
+ "colab": {},
74
+ "colab_type": "code",
75
+ "id": "-T-svuHytJ-8"
76
+ },
77
+ "outputs": [],
78
+ "source": [
79
+ "!git clone https://github.com/jantic/DeOldify.git DeOldify "
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "cd DeOldify"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {
94
+ "colab_type": "text",
95
+ "id": "BDFjbNxaadNK"
96
+ },
97
+ "source": [
98
+ "#β—’ Setup"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "metadata": {
105
+ "colab": {},
106
+ "colab_type": "code",
107
+ "id": "00_GcC_trpdE"
108
+ },
109
+ "outputs": [],
110
+ "source": [
111
+ "#NOTE: This must be the first call in order to work properly!\n",
112
+ "from deoldify import device\n",
113
+ "from deoldify.device_id import DeviceId\n",
114
+ "#choices: CPU, GPU0...GPU7\n",
115
+ "device.set(device=DeviceId.GPU0)\n",
116
+ "\n",
117
+ "import torch\n",
118
+ "\n",
119
+ "if not torch.cuda.is_available():\n",
120
+ " print('GPU not available.')"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {
127
+ "colab": {},
128
+ "colab_type": "code",
129
+ "id": "Lsx7xCXNSVt6"
130
+ },
131
+ "outputs": [],
132
+ "source": [
133
+ "!pip install -r requirements-colab.txt"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {
140
+ "colab": {},
141
+ "colab_type": "code",
142
+ "id": "MsJa69CMwj3l"
143
+ },
144
+ "outputs": [],
145
+ "source": [
146
+ "import fastai\n",
147
+ "from deoldify.visualize import *\n",
148
+ "\n",
149
+ "torch.backends.cudnn.benchmark = True"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "!mkdir 'models'\n",
159
+ "!wget https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0 -O ./models/ColorizeStable_gen.pth"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {
166
+ "colab": {},
167
+ "colab_type": "code",
168
+ "id": "tzHVnegp21hC"
169
+ },
170
+ "outputs": [],
171
+ "source": [
172
+ "colorizer = get_image_colorizer(artistic=False)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "metadata": {
178
+ "colab_type": "text",
179
+ "id": "BDFjbNxaadNJ"
180
+ },
181
+ "source": [
182
+ "#β—’ Instructions"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "markdown",
187
+ "metadata": {},
188
+ "source": [
189
+ "### source_url\n",
190
+ "Type in a url to a direct link of an image. Usually that means they'll end in .png, .jpg, etc. NOTE: If you want to use your own image, upload it first to a site like Imgur. \n",
191
+ "\n",
192
+ "### render_factor\n",
193
+ "The default value of 35 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the color portion of the image is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality images in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality images, but the colors may get slightly washed out. \n",
194
+ "\n",
195
+ "### watermarked\n",
196
+ "Selected by default, this places a watermark icon of a palette at the bottom left corner of the image. This is intended to be a standard way to convey to others viewing the image that it is colorized by AI. We want to help promote this as a standard, especially as the technology continues to improve and the distinction between real and fake becomes harder to discern. This palette watermark practice was initiated and lead by the company MyHeritage in the MyHeritage In Color feature (which uses a newer version of DeOldify than what you're using here).\n",
197
+ "\n",
198
+ "#### How to Download a Copy\n",
199
+ "Simply right click on the displayed image and click \"Save image as...\"!\n",
200
+ "\n",
201
+ "## Pro Tips\n",
202
+ "\n",
203
+ "You can evaluate how well the image is rendered at each render_factor by using the code at the bottom (that cell under \"See how well render_factor values perform on a frame here\"). "
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "markdown",
208
+ "metadata": {
209
+ "colab_type": "text",
210
+ "id": "sUQrbSYipiJn"
211
+ },
212
+ "source": [
213
+ "#β—’ Colorize!!"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "source_url = '' #@param {type:\"string\"}\n",
223
+ "render_factor = 35 #@param {type: \"slider\", min: 7, max: 40}\n",
224
+ "watermarked = True #@param {type:\"boolean\"}\n",
225
+ "\n",
226
+ "if source_url is not None and source_url !='':\n",
227
+ " image_path = colorizer.plot_transformed_image_from_url(url=source_url, render_factor=render_factor, compare=True, watermarked=watermarked)\n",
228
+ " show_image_in_notebook(image_path)\n",
229
+ "else:\n",
230
+ " print('Provide an image url and try again.')"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "markdown",
235
+ "metadata": {},
236
+ "source": [
237
+ "## See how well render_factor values perform on the image here"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": null,
243
+ "metadata": {},
244
+ "outputs": [],
245
+ "source": [
246
+ "for i in range(10,40,2):\n",
247
+ " colorizer.plot_transformed_image('test_images/image.png', render_factor=i, display_render_factor=True, figsize=(8,8))"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {
253
+ "colab_type": "text",
254
+ "id": "X7Ycv_Y9xAHp"
255
+ },
256
+ "source": [
257
+ "---\n",
258
+ "#βš™ Recommended image sources \n",
259
+ "* [/r/TheWayWeWere](https://www.reddit.com/r/TheWayWeWere/)"
260
+ ]
261
+ }
262
+ ],
263
+ "metadata": {
264
+ "accelerator": "GPU",
265
+ "colab": {
266
+ "collapsed_sections": [],
267
+ "name": "ImageColorizerColabStable.ipynb",
268
+ "provenance": [],
269
+ "toc_visible": true,
270
+ "version": "0.3.2"
271
+ },
272
+ "kernelspec": {
273
+ "display_name": "Python 3",
274
+ "language": "python",
275
+ "name": "python3"
276
+ },
277
+ "language_info": {
278
+ "codemirror_mode": {
279
+ "name": "ipython",
280
+ "version": 3
281
+ },
282
+ "file_extension": ".py",
283
+ "mimetype": "text/x-python",
284
+ "name": "python",
285
+ "nbconvert_exporter": "python",
286
+ "pygments_lexer": "ipython3",
287
+ "version": "3.7.6"
288
+ }
289
+ },
290
+ "nbformat": 4,
291
+ "nbformat_minor": 4
292
+ }
ImageColorizerStableTests.ipynb ADDED
@@ -0,0 +1,3334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "#NOTE: This must be the first call in order to work properly!\n",
10
+ "from deoldify import device\n",
11
+ "from deoldify.device_id import DeviceId\n",
12
+ "#choices: CPU, GPU0...GPU7\n",
13
+ "device.set(device=DeviceId.GPU0)"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "from deoldify.visualize import *\n",
23
+ "plt.style.use('dark_background')\n",
24
+ "import warnings\n",
25
+ "warnings.filterwarnings(\"ignore\", category=UserWarning, message=\".*?Your .*? set is empty.*?\")"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "#Adjust render_factor (int) if image doesn't look quite right (max 64 on 11GB GPU). The default here works for most photos. \n",
35
+ "#It literally just is a number multiplied by 16 to get the square render resolution. \n",
36
+ "#Note that this doesn't affect the resolution of the final output- the output is the same resolution as the input.\n",
37
+ "#Example: render_factor=21 => color is rendered at 16x21 = 336x336 px. \n",
38
+ "render_factor=35"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "vis = get_image_colorizer(render_factor=render_factor, artistic=False)\n",
48
+ "#vis = get_video_colorizer(render_factor=render_factor).vis"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "vis.plot_transformed_image(\"test_images/poolparty.jpg\", render_factor=45, compare=True)"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "vis.plot_transformed_image(\"test_images/1852GatekeepersWindsor.jpg\", render_factor=44, compare=True)"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "vis.plot_transformed_image(\"test_images/Chief.jpg\", render_factor=10, compare=True)"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "vis.plot_transformed_image(\"test_images/1850SchoolForGirls.jpg\", render_factor=42, compare=True)"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "vis.plot_transformed_image(\"test_images/AtlanticCityBeach1905.jpg\", render_factor=32, compare=True)"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "vis.plot_transformed_image(\"test_images/CottonMillWorkers1913.jpg\", compare=True)"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "vis.plot_transformed_image(\"test_images/BrooklynNavyYardHospital.jpg\", compare=True)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "vis.plot_transformed_image(\"test_images/FinnishPeasant1867.jpg\", compare=True)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "vis.plot_transformed_image(\"test_images/AtlanticCity1905.png\", render_factor=40, compare=True)"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "vis.plot_transformed_image(\"test_images/PushingCart.jpg\", render_factor=24, compare=True)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "vis.plot_transformed_image(\"test_images/Drive1905.jpg\", compare=True)"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "vis.plot_transformed_image(\"test_images/IronLung.png\", render_factor=26, compare=True)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "vis.plot_transformed_image(\"test_images/FamilyWithDog.jpg\", compare=True)"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "vis.plot_transformed_image(\"test_images/DayAtSeaBelgium.jpg\", compare=True)"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "metadata": {},
181
+ "outputs": [],
182
+ "source": [
183
+ "vis.plot_transformed_image(\"test_images/marilyn_woods.jpg\", render_factor=16, compare=True)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "vis.plot_transformed_image(\"test_images/OldWomanSweden1904.jpg\", render_factor=20, compare=True)"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "vis.plot_transformed_image(\"test_images/WomenTapingPlanes.jpg\", compare=True)"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "vis.plot_transformed_image(\"test_images/overmiller.jpg\", render_factor=30, compare=True)"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": null,
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "vis.plot_transformed_image(\"test_images/BritishDispatchRider.jpg\", render_factor=16, compare=True)"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": null,
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "vis.plot_transformed_image(\"test_images/MuseauNacionalDosCoches.jpg\", render_factor=19, compare=True)"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "metadata": {},
235
+ "outputs": [],
236
+ "source": [
237
+ "vis.plot_transformed_image(\"test_images/abe.jpg\", render_factor=13, compare=True)"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": null,
243
+ "metadata": {},
244
+ "outputs": [],
245
+ "source": [
246
+ "vis.plot_transformed_image(\"test_images/RossCorbettHouseCork.jpg\", render_factor=40, compare=True)"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": null,
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": [
255
+ "vis.plot_transformed_image(\"test_images/HPLabelleOfficeMontreal.jpg\", render_factor=44, compare=True)"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": [
264
+ "vis.plot_transformed_image(\"test_images/einstein_beach.jpg\", render_factor=32, compare=True)"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "metadata": {},
271
+ "outputs": [],
272
+ "source": [
273
+ "vis.plot_transformed_image(\"test_images/airmen1943.jpg\", compare=True)"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "vis.plot_transformed_image(\"test_images/20sWoman.jpg\", render_factor=24, compare=True)"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "metadata": {},
289
+ "outputs": [],
290
+ "source": [
291
+ "vis.plot_transformed_image(\"test_images/egypt-1.jpg\", render_factor=18, compare=True)"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": null,
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": [
300
+ "vis.plot_transformed_image(\"test_images/Rutherford_Hayes.jpg\", compare=True)"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "metadata": {},
307
+ "outputs": [],
308
+ "source": [
309
+ "vis.plot_transformed_image(\"test_images/einstein_portrait.jpg\", render_factor=15, compare=True)"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": null,
315
+ "metadata": {},
316
+ "outputs": [],
317
+ "source": [
318
+ "vis.plot_transformed_image(\"test_images/pinkerton.jpg\", render_factor=7, compare=True)"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": null,
324
+ "metadata": {},
325
+ "outputs": [],
326
+ "source": [
327
+ "vis.plot_transformed_image(\"test_images/WaltWhitman.jpg\", render_factor=9, compare=True)"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": null,
333
+ "metadata": {},
334
+ "outputs": [],
335
+ "source": [
336
+ "vis.plot_transformed_image(\"test_images/dorothea-lange.jpg\", render_factor=18, compare=True)"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "metadata": {},
343
+ "outputs": [],
344
+ "source": [
345
+ "vis.plot_transformed_image(\"test_images/Hemmingway2.jpg\", render_factor=22, compare=True)"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": null,
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "vis.plot_transformed_image(\"test_images/hemmingway.jpg\", render_factor=14, compare=True)"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "vis.plot_transformed_image(\"test_images/smoking_kid.jpg\", render_factor=35, compare=True)"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": null,
369
+ "metadata": {},
370
+ "outputs": [],
371
+ "source": [
372
+ "vis.plot_transformed_image(\"test_images/teddy_rubble.jpg\", render_factor=42, compare=True)"
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "code",
377
+ "execution_count": null,
378
+ "metadata": {},
379
+ "outputs": [],
380
+ "source": [
381
+ "vis.plot_transformed_image(\"test_images/dustbowl_2.jpg\", render_factor=16, compare=True)"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": null,
387
+ "metadata": {},
388
+ "outputs": [],
389
+ "source": [
390
+ "vis.plot_transformed_image(\"test_images/camera_man.jpg\", render_factor=25, compare=True)"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "execution_count": null,
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": [
399
+ "vis.plot_transformed_image(\"test_images/migrant_mother.jpg\", render_factor=32, compare=True)"
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "code",
404
+ "execution_count": null,
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": [
408
+ "vis.plot_transformed_image(\"test_images/marktwain.jpg\", render_factor=14, compare=True)"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": null,
414
+ "metadata": {},
415
+ "outputs": [],
416
+ "source": [
417
+ "vis.plot_transformed_image(\"test_images/HelenKeller.jpg\", render_factor=35, compare=True)"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": null,
423
+ "metadata": {},
424
+ "outputs": [],
425
+ "source": [
426
+ "vis.plot_transformed_image(\"test_images/Evelyn_Nesbit.jpg\", render_factor=25, compare=True)"
427
+ ]
428
+ },
429
+ {
430
+ "cell_type": "code",
431
+ "execution_count": null,
432
+ "metadata": {},
433
+ "outputs": [],
434
+ "source": [
435
+ "vis.plot_transformed_image(\"test_images/Eddie-Adams.jpg\", compare=True)"
436
+ ]
437
+ },
438
+ {
439
+ "cell_type": "code",
440
+ "execution_count": null,
441
+ "metadata": {},
442
+ "outputs": [],
443
+ "source": [
444
+ "vis.plot_transformed_image(\"test_images/soldier_kids.jpg\", compare=True)"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "metadata": {},
451
+ "outputs": [],
452
+ "source": [
453
+ "vis.plot_transformed_image(\"test_images/AnselAdamsYosemite.jpg\", compare=True)"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": null,
459
+ "metadata": {},
460
+ "outputs": [],
461
+ "source": [
462
+ "vis.plot_transformed_image(\"test_images/unnamed.jpg\", render_factor=28, compare=True)"
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": null,
468
+ "metadata": {},
469
+ "outputs": [],
470
+ "source": [
471
+ "vis.plot_transformed_image(\"test_images/workers_canyon.jpg\", render_factor=45, compare=True)"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": null,
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": [
480
+ "vis.plot_transformed_image(\"test_images/CottonMill.jpg\", compare=True)"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": null,
486
+ "metadata": {},
487
+ "outputs": [],
488
+ "source": [
489
+ "vis.plot_transformed_image(\"test_images/JudyGarland.jpeg\", compare=True)"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": [
498
+ "vis.plot_transformed_image(\"test_images/kids_pit.jpg\", render_factor=30, compare=True)"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "code",
503
+ "execution_count": null,
504
+ "metadata": {},
505
+ "outputs": [],
506
+ "source": [
507
+ "vis.plot_transformed_image(\"test_images/last_samurai.jpg\", render_factor=22, compare=True)"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "code",
512
+ "execution_count": null,
513
+ "metadata": {},
514
+ "outputs": [],
515
+ "source": [
516
+ "vis.plot_transformed_image(\"test_images/AnselAdamsWhiteChurch.jpg\", render_factor=25, compare=True)"
517
+ ]
518
+ },
519
+ {
520
+ "cell_type": "code",
521
+ "execution_count": null,
522
+ "metadata": {},
523
+ "outputs": [],
524
+ "source": [
525
+ "vis.plot_transformed_image(\"test_images/opium.jpg\", render_factor=30, compare=True)"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": null,
531
+ "metadata": {},
532
+ "outputs": [],
533
+ "source": [
534
+ "vis.plot_transformed_image(\"test_images/dorothea_lange_2.jpg\", render_factor=42, compare=True)"
535
+ ]
536
+ },
537
+ {
538
+ "cell_type": "code",
539
+ "execution_count": null,
540
+ "metadata": {},
541
+ "outputs": [],
542
+ "source": [
543
+ "vis.plot_transformed_image(\"test_images/rgs.jpg\", compare=True)"
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "code",
548
+ "execution_count": null,
549
+ "metadata": {},
550
+ "outputs": [],
551
+ "source": [
552
+ "vis.plot_transformed_image(\"test_images/wh-auden.jpg\", compare=True)"
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "code",
557
+ "execution_count": null,
558
+ "metadata": {},
559
+ "outputs": [],
560
+ "source": [
561
+ "vis.plot_transformed_image(\"test_images/w-b-yeats.jpg\", compare=True)"
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "code",
566
+ "execution_count": null,
567
+ "metadata": {},
568
+ "outputs": [],
569
+ "source": [
570
+ "vis.plot_transformed_image(\"test_images/marilyn_portrait.jpg\", compare=True)"
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": null,
576
+ "metadata": {},
577
+ "outputs": [],
578
+ "source": [
579
+ "vis.plot_transformed_image(\"test_images/wilson-slaverevivalmeeting.jpg\", render_factor=45, compare=True)"
580
+ ]
581
+ },
582
+ {
583
+ "cell_type": "code",
584
+ "execution_count": null,
585
+ "metadata": {},
586
+ "outputs": [],
587
+ "source": [
588
+ "vis.plot_transformed_image(\"test_images/ww1_trench.jpg\", render_factor=18, compare=True)"
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "code",
593
+ "execution_count": null,
594
+ "metadata": {},
595
+ "outputs": [],
596
+ "source": [
597
+ "vis.plot_transformed_image(\"test_images/women-bikers.png\", render_factor=23, compare=True)"
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "code",
602
+ "execution_count": null,
603
+ "metadata": {},
604
+ "outputs": [],
605
+ "source": [
606
+ "vis.plot_transformed_image(\"test_images/Unidentified1855.jpg\", render_factor=19, compare=True)"
607
+ ]
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "execution_count": null,
612
+ "metadata": {},
613
+ "outputs": [],
614
+ "source": [
615
+ "vis.plot_transformed_image(\"test_images/skycrapper_lunch.jpg\", render_factor=25, compare=True)"
616
+ ]
617
+ },
618
+ {
619
+ "cell_type": "code",
620
+ "execution_count": null,
621
+ "metadata": {},
622
+ "outputs": [],
623
+ "source": [
624
+ "vis.plot_transformed_image(\"test_images/sioux.jpg\", render_factor=28, compare=True)"
625
+ ]
626
+ },
627
+ {
628
+ "cell_type": "code",
629
+ "execution_count": null,
630
+ "metadata": {},
631
+ "outputs": [],
632
+ "source": [
633
+ "vis.plot_transformed_image(\"test_images/school_kids.jpg\", render_factor=20, compare=True)"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": null,
639
+ "metadata": {},
640
+ "outputs": [],
641
+ "source": [
642
+ "vis.plot_transformed_image(\"test_images/royal_family.jpg\", render_factor=42, compare=True)"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "execution_count": null,
648
+ "metadata": {},
649
+ "outputs": [],
650
+ "source": [
651
+ "vis.plot_transformed_image(\"test_images/redwood_lumberjacks.jpg\", render_factor=45, compare=True)"
652
+ ]
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "execution_count": null,
657
+ "metadata": {},
658
+ "outputs": [],
659
+ "source": [
660
+ "vis.plot_transformed_image(\"test_images/poverty.jpg\", render_factor=40, compare=True)"
661
+ ]
662
+ },
663
+ {
664
+ "cell_type": "code",
665
+ "execution_count": null,
666
+ "metadata": {},
667
+ "outputs": [],
668
+ "source": [
669
+ "vis.plot_transformed_image(\"test_images/paperboy.jpg\", render_factor=45, compare=True)"
670
+ ]
671
+ },
672
+ {
673
+ "cell_type": "code",
674
+ "execution_count": null,
675
+ "metadata": {},
676
+ "outputs": [],
677
+ "source": [
678
+ "vis.plot_transformed_image(\"test_images/NativeAmericans.jpg\", render_factor=21, compare=True)"
679
+ ]
680
+ },
681
+ {
682
+ "cell_type": "code",
683
+ "execution_count": null,
684
+ "metadata": {},
685
+ "outputs": [],
686
+ "source": [
687
+ "vis.plot_transformed_image(\"test_images/helmut_newton-.jpg\", compare=True)"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "code",
692
+ "execution_count": null,
693
+ "metadata": {},
694
+ "outputs": [],
695
+ "source": [
696
+ "vis.plot_transformed_image(\"test_images/Greece1911.jpg\", render_factor=44, compare=True)"
697
+ ]
698
+ },
699
+ {
700
+ "cell_type": "code",
701
+ "execution_count": null,
702
+ "metadata": {},
703
+ "outputs": [],
704
+ "source": [
705
+ "vis.plot_transformed_image(\"test_images/FatMenClub.jpg\", render_factor=18, compare=True)"
706
+ ]
707
+ },
708
+ {
709
+ "cell_type": "code",
710
+ "execution_count": null,
711
+ "metadata": {},
712
+ "outputs": [],
713
+ "source": [
714
+ "vis.plot_transformed_image(\"test_images/EgyptColosus.jpg\", compare=True)"
715
+ ]
716
+ },
717
+ {
718
+ "cell_type": "code",
719
+ "execution_count": null,
720
+ "metadata": {},
721
+ "outputs": [],
722
+ "source": [
723
+ "vis.plot_transformed_image(\"test_images/egypt-2.jpg\", compare=True)"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": null,
729
+ "metadata": {},
730
+ "outputs": [],
731
+ "source": [
732
+ "vis.plot_transformed_image(\"test_images/dustbowl_sd.jpg\", compare=True)"
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "code",
737
+ "execution_count": null,
738
+ "metadata": {},
739
+ "outputs": [],
740
+ "source": [
741
+ "vis.plot_transformed_image(\"test_images/dustbowl_people.jpg\", render_factor=24, compare=True)"
742
+ ]
743
+ },
744
+ {
745
+ "cell_type": "code",
746
+ "execution_count": null,
747
+ "metadata": {},
748
+ "outputs": [],
749
+ "source": [
750
+ "vis.plot_transformed_image(\"test_images/dustbowl_5.jpg\", compare=True)"
751
+ ]
752
+ },
753
+ {
754
+ "cell_type": "code",
755
+ "execution_count": null,
756
+ "metadata": {},
757
+ "outputs": [],
758
+ "source": [
759
+ "vis.plot_transformed_image(\"test_images/dustbowl_1.jpg\", compare=True)"
760
+ ]
761
+ },
762
+ {
763
+ "cell_type": "code",
764
+ "execution_count": null,
765
+ "metadata": {},
766
+ "outputs": [],
767
+ "source": [
768
+ "vis.plot_transformed_image(\"test_images/DriveThroughGiantTree.jpg\", render_factor=21, compare=True)"
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "code",
773
+ "execution_count": null,
774
+ "metadata": {},
775
+ "outputs": [],
776
+ "source": [
777
+ "vis.plot_transformed_image(\"test_images/covered-wagons-traveling.jpg\", compare=True)"
778
+ ]
779
+ },
780
+ {
781
+ "cell_type": "code",
782
+ "execution_count": null,
783
+ "metadata": {},
784
+ "outputs": [],
785
+ "source": [
786
+ "vis.plot_transformed_image(\"test_images/civil-war_2.jpg\", render_factor=42, compare=True)"
787
+ ]
788
+ },
789
+ {
790
+ "cell_type": "code",
791
+ "execution_count": null,
792
+ "metadata": {},
793
+ "outputs": [],
794
+ "source": [
795
+ "vis.plot_transformed_image(\"test_images/civil_war_4.jpg\", compare=True)"
796
+ ]
797
+ },
798
+ {
799
+ "cell_type": "code",
800
+ "execution_count": null,
801
+ "metadata": {},
802
+ "outputs": [],
803
+ "source": [
804
+ "vis.plot_transformed_image(\"test_images/civil_war_3.jpg\", render_factor=28, compare=True)"
805
+ ]
806
+ },
807
+ {
808
+ "cell_type": "code",
809
+ "execution_count": null,
810
+ "metadata": {},
811
+ "outputs": [],
812
+ "source": [
813
+ "vis.plot_transformed_image(\"test_images/civil_war.jpg\", compare=True)"
814
+ ]
815
+ },
816
+ {
817
+ "cell_type": "code",
818
+ "execution_count": null,
819
+ "metadata": {},
820
+ "outputs": [],
821
+ "source": [
822
+ "vis.plot_transformed_image(\"test_images/BritishSlum.jpg\", render_factor=30, compare=True)"
823
+ ]
824
+ },
825
+ {
826
+ "cell_type": "code",
827
+ "execution_count": null,
828
+ "metadata": {},
829
+ "outputs": [],
830
+ "source": [
831
+ "vis.plot_transformed_image(\"test_images/bicycles.jpg\", render_factor=27, compare=True)"
832
+ ]
833
+ },
834
+ {
835
+ "cell_type": "code",
836
+ "execution_count": null,
837
+ "metadata": {},
838
+ "outputs": [],
839
+ "source": [
840
+ "vis.plot_transformed_image(\"test_images/brooklyn_girls_1940s.jpg\", compare=True)"
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "code",
845
+ "execution_count": null,
846
+ "metadata": {},
847
+ "outputs": [],
848
+ "source": [
849
+ "vis.plot_transformed_image(\"test_images/40sCouple.jpg\", render_factor=21, compare=True)"
850
+ ]
851
+ },
852
+ {
853
+ "cell_type": "code",
854
+ "execution_count": null,
855
+ "metadata": {},
856
+ "outputs": [],
857
+ "source": [
858
+ "vis.plot_transformed_image(\"test_images/1946Wedding.jpg\", compare=True)"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "code",
863
+ "execution_count": null,
864
+ "metadata": {},
865
+ "outputs": [],
866
+ "source": [
867
+ "vis.plot_transformed_image(\"test_images/Dolores1920s.jpg\", render_factor=18, compare=True)"
868
+ ]
869
+ },
870
+ {
871
+ "cell_type": "code",
872
+ "execution_count": null,
873
+ "metadata": {},
874
+ "outputs": [],
875
+ "source": [
876
+ "vis.plot_transformed_image(\"test_images/TitanicGym.jpg\", render_factor=26, compare=True)"
877
+ ]
878
+ },
879
+ {
880
+ "cell_type": "code",
881
+ "execution_count": null,
882
+ "metadata": {},
883
+ "outputs": [],
884
+ "source": [
885
+ "vis.plot_transformed_image(\"test_images/FrenchVillage1950s.jpg\", render_factor=41, compare=True)"
886
+ ]
887
+ },
888
+ {
889
+ "cell_type": "code",
890
+ "execution_count": null,
891
+ "metadata": {},
892
+ "outputs": [],
893
+ "source": [
894
+ "vis.plot_transformed_image(\"test_images/FrenchVillage1950s.jpg\", render_factor=32, compare=True)"
895
+ ]
896
+ },
897
+ {
898
+ "cell_type": "code",
899
+ "execution_count": null,
900
+ "metadata": {},
901
+ "outputs": [],
902
+ "source": [
903
+ "vis.plot_transformed_image(\"test_images/ClassDivide1930sBrittain.jpg\", render_factor=45, compare=True)"
904
+ ]
905
+ },
906
+ {
907
+ "cell_type": "code",
908
+ "execution_count": null,
909
+ "metadata": {},
910
+ "outputs": [],
911
+ "source": [
912
+ "vis.plot_transformed_image(\"test_images/1870sSphinx.jpg\", compare=True)"
913
+ ]
914
+ },
915
+ {
916
+ "cell_type": "code",
917
+ "execution_count": null,
918
+ "metadata": {},
919
+ "outputs": [],
920
+ "source": [
921
+ "vis.plot_transformed_image(\"test_images/1890Surfer.png\", render_factor=37, compare=True)"
922
+ ]
923
+ },
924
+ {
925
+ "cell_type": "code",
926
+ "execution_count": null,
927
+ "metadata": {},
928
+ "outputs": [],
929
+ "source": [
930
+ "vis.plot_transformed_image(\"test_images/TV1930s.jpg\", render_factor=43, compare=True)"
931
+ ]
932
+ },
933
+ {
934
+ "cell_type": "code",
935
+ "execution_count": null,
936
+ "metadata": {},
937
+ "outputs": [],
938
+ "source": [
939
+ "vis.plot_transformed_image(\"test_images/1864UnionSoldier.jpg\", compare=True)"
940
+ ]
941
+ },
942
+ {
943
+ "cell_type": "code",
944
+ "execution_count": null,
945
+ "metadata": {},
946
+ "outputs": [],
947
+ "source": [
948
+ "vis.plot_transformed_image(\"test_images/1890sMedStudents.jpg\", render_factor=18, compare=True)"
949
+ ]
950
+ },
951
+ {
952
+ "cell_type": "code",
953
+ "execution_count": null,
954
+ "metadata": {},
955
+ "outputs": [],
956
+ "source": [
957
+ "vis.plot_transformed_image(\"test_images/BellyLaughWWI.jpg\", compare=True)"
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "execution_count": null,
963
+ "metadata": {},
964
+ "outputs": [],
965
+ "source": [
966
+ "vis.plot_transformed_image(\"test_images/PiggyBackRide.jpg\", compare=True)"
967
+ ]
968
+ },
969
+ {
970
+ "cell_type": "code",
971
+ "execution_count": null,
972
+ "metadata": {},
973
+ "outputs": [],
974
+ "source": [
975
+ "vis.plot_transformed_image(\"test_images/HealingTree.jpg\", compare=True)"
976
+ ]
977
+ },
978
+ {
979
+ "cell_type": "code",
980
+ "execution_count": null,
981
+ "metadata": {},
982
+ "outputs": [],
983
+ "source": [
984
+ "vis.plot_transformed_image(\"test_images/ManPile.jpg\", compare=True)"
985
+ ]
986
+ },
987
+ {
988
+ "cell_type": "code",
989
+ "execution_count": null,
990
+ "metadata": {},
991
+ "outputs": [],
992
+ "source": [
993
+ "vis.plot_transformed_image(\"test_images/1910Bike.jpg\", compare=True)"
994
+ ]
995
+ },
996
+ {
997
+ "cell_type": "code",
998
+ "execution_count": null,
999
+ "metadata": {},
1000
+ "outputs": [],
1001
+ "source": [
1002
+ "vis.plot_transformed_image(\"test_images/FreeportIL.jpg\", compare=True)"
1003
+ ]
1004
+ },
1005
+ {
1006
+ "cell_type": "code",
1007
+ "execution_count": null,
1008
+ "metadata": {},
1009
+ "outputs": [],
1010
+ "source": [
1011
+ "vis.plot_transformed_image(\"test_images/DutchBabyCoupleEllis.jpg\", compare=True)"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "cell_type": "code",
1016
+ "execution_count": null,
1017
+ "metadata": {},
1018
+ "outputs": [],
1019
+ "source": [
1020
+ "vis.plot_transformed_image(\"test_images/InuitWoman1903.png\", compare=True)"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "cell_type": "code",
1025
+ "execution_count": null,
1026
+ "metadata": {},
1027
+ "outputs": [],
1028
+ "source": [
1029
+ "vis.plot_transformed_image(\"test_images/1920sDancing.jpg\", compare=True)"
1030
+ ]
1031
+ },
1032
+ {
1033
+ "cell_type": "code",
1034
+ "execution_count": null,
1035
+ "metadata": {},
1036
+ "outputs": [],
1037
+ "source": [
1038
+ "vis.plot_transformed_image(\"test_images/AirmanDad.jpg\", render_factor=13, compare=True)"
1039
+ ]
1040
+ },
1041
+ {
1042
+ "cell_type": "code",
1043
+ "execution_count": null,
1044
+ "metadata": {},
1045
+ "outputs": [],
1046
+ "source": [
1047
+ "vis.plot_transformed_image(\"test_images/1910Racket.png\", render_factor=30, compare=True)"
1048
+ ]
1049
+ },
1050
+ {
1051
+ "cell_type": "code",
1052
+ "execution_count": null,
1053
+ "metadata": {},
1054
+ "outputs": [],
1055
+ "source": [
1056
+ "vis.plot_transformed_image(\"test_images/1880Paris.jpg\", render_factor=16, compare=True)"
1057
+ ]
1058
+ },
1059
+ {
1060
+ "cell_type": "code",
1061
+ "execution_count": null,
1062
+ "metadata": {},
1063
+ "outputs": [],
1064
+ "source": [
1065
+ "vis.plot_transformed_image(\"test_images/Deadwood1860s.jpg\", render_factor=13, compare=True)"
1066
+ ]
1067
+ },
1068
+ {
1069
+ "cell_type": "code",
1070
+ "execution_count": null,
1071
+ "metadata": {},
1072
+ "outputs": [],
1073
+ "source": [
1074
+ "vis.plot_transformed_image(\"test_images/1860sSamauris.jpg\", render_factor=43, compare=True)"
1075
+ ]
1076
+ },
1077
+ {
1078
+ "cell_type": "code",
1079
+ "execution_count": null,
1080
+ "metadata": {},
1081
+ "outputs": [],
1082
+ "source": [
1083
+ "vis.plot_transformed_image(\"test_images/LondonUnderground1860.jpg\", render_factor=45, compare=True)"
1084
+ ]
1085
+ },
1086
+ {
1087
+ "cell_type": "code",
1088
+ "execution_count": null,
1089
+ "metadata": {},
1090
+ "outputs": [],
1091
+ "source": [
1092
+ "vis.plot_transformed_image(\"test_images/Mid1800sSisters.jpg\", compare=True)"
1093
+ ]
1094
+ },
1095
+ {
1096
+ "cell_type": "code",
1097
+ "execution_count": null,
1098
+ "metadata": {},
1099
+ "outputs": [],
1100
+ "source": [
1101
+ "vis.plot_transformed_image(\"test_images/1860Girls.jpg\", render_factor=45, compare=True)"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "cell_type": "code",
1106
+ "execution_count": null,
1107
+ "metadata": {},
1108
+ "outputs": [],
1109
+ "source": [
1110
+ "vis.plot_transformed_image(\"test_images/SanFran1851.jpg\", render_factor=44, compare=True)"
1111
+ ]
1112
+ },
1113
+ {
1114
+ "cell_type": "code",
1115
+ "execution_count": null,
1116
+ "metadata": {},
1117
+ "outputs": [],
1118
+ "source": [
1119
+ "vis.plot_transformed_image(\"test_images/Kabuki1870s.png\", render_factor=8, compare=True)"
1120
+ ]
1121
+ },
1122
+ {
1123
+ "cell_type": "code",
1124
+ "execution_count": null,
1125
+ "metadata": {},
1126
+ "outputs": [],
1127
+ "source": [
1128
+ "vis.plot_transformed_image(\"test_images/Mormons1870s.jpg\", render_factor=44, compare=True)"
1129
+ ]
1130
+ },
1131
+ {
1132
+ "cell_type": "code",
1133
+ "execution_count": null,
1134
+ "metadata": {},
1135
+ "outputs": [],
1136
+ "source": [
1137
+ "vis.plot_transformed_image(\"test_images/EgyptianWomenLate1800s.jpg\", render_factor=44, compare=True)"
1138
+ ]
1139
+ },
1140
+ {
1141
+ "cell_type": "code",
1142
+ "execution_count": null,
1143
+ "metadata": {},
1144
+ "outputs": [],
1145
+ "source": [
1146
+ "vis.plot_transformed_image(\"test_images/PicadillyLate1800s.jpg\", render_factor=26, compare=True)"
1147
+ ]
1148
+ },
1149
+ {
1150
+ "cell_type": "code",
1151
+ "execution_count": null,
1152
+ "metadata": {},
1153
+ "outputs": [],
1154
+ "source": [
1155
+ "vis.plot_transformed_image(\"test_images/SutroBaths1880s.jpg\", compare=True)"
1156
+ ]
1157
+ },
1158
+ {
1159
+ "cell_type": "code",
1160
+ "execution_count": null,
1161
+ "metadata": {},
1162
+ "outputs": [],
1163
+ "source": [
1164
+ "vis.plot_transformed_image(\"test_images/1880sBrooklynBridge.jpg\", compare=True)"
1165
+ ]
1166
+ },
1167
+ {
1168
+ "cell_type": "code",
1169
+ "execution_count": null,
1170
+ "metadata": {},
1171
+ "outputs": [],
1172
+ "source": [
1173
+ "vis.plot_transformed_image(\"test_images/ChinaOpiumc1880.jpg\", render_factor=30, compare=True)"
1174
+ ]
1175
+ },
1176
+ {
1177
+ "cell_type": "code",
1178
+ "execution_count": null,
1179
+ "metadata": {},
1180
+ "outputs": [],
1181
+ "source": [
1182
+ "vis.plot_transformed_image(\"test_images/Locomotive1880s.jpg\", render_factor=9, compare=True)"
1183
+ ]
1184
+ },
1185
+ {
1186
+ "cell_type": "code",
1187
+ "execution_count": null,
1188
+ "metadata": {},
1189
+ "outputs": [],
1190
+ "source": [
1191
+ "vis.plot_transformed_image(\"test_images/ViennaBoys1880s.png\", compare=True)"
1192
+ ]
1193
+ },
1194
+ {
1195
+ "cell_type": "code",
1196
+ "execution_count": null,
1197
+ "metadata": {},
1198
+ "outputs": [],
1199
+ "source": [
1200
+ "vis.plot_transformed_image(\"test_images/VictorianDragQueen1880s.png\", compare=True)"
1201
+ ]
1202
+ },
1203
+ {
1204
+ "cell_type": "code",
1205
+ "execution_count": null,
1206
+ "metadata": {},
1207
+ "outputs": [],
1208
+ "source": [
1209
+ "vis.plot_transformed_image(\"test_images/Sami1880s.jpg\", render_factor=44, compare=True)"
1210
+ ]
1211
+ },
1212
+ {
1213
+ "cell_type": "code",
1214
+ "execution_count": null,
1215
+ "metadata": {},
1216
+ "outputs": [],
1217
+ "source": [
1218
+ "vis.plot_transformed_image(\"test_images/ArkansasCowboys1880s.jpg\", render_factor=22, compare=True)"
1219
+ ]
1220
+ },
1221
+ {
1222
+ "cell_type": "code",
1223
+ "execution_count": null,
1224
+ "metadata": {},
1225
+ "outputs": [],
1226
+ "source": [
1227
+ "vis.plot_transformed_image(\"test_images/Ballet1890Russia.jpg\", render_factor=40, compare=True)"
1228
+ ]
1229
+ },
1230
+ {
1231
+ "cell_type": "code",
1232
+ "execution_count": null,
1233
+ "metadata": {},
1234
+ "outputs": [],
1235
+ "source": [
1236
+ "vis.plot_transformed_image(\"test_images/Rottindean1890s.png\", render_factor=20, compare=True)"
1237
+ ]
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": null,
1242
+ "metadata": {},
1243
+ "outputs": [],
1244
+ "source": [
1245
+ "vis.plot_transformed_image(\"test_images/1890sPingPong.jpg\", compare=True)"
1246
+ ]
1247
+ },
1248
+ {
1249
+ "cell_type": "code",
1250
+ "execution_count": null,
1251
+ "metadata": {},
1252
+ "outputs": [],
1253
+ "source": [
1254
+ "vis.plot_transformed_image(\"test_images/London1937.png\", render_factor=45, compare=True)"
1255
+ ]
1256
+ },
1257
+ {
1258
+ "cell_type": "code",
1259
+ "execution_count": null,
1260
+ "metadata": {},
1261
+ "outputs": [],
1262
+ "source": [
1263
+ "vis.plot_transformed_image(\"test_images/Harlem1932.jpg\", render_factor=37, compare=True)"
1264
+ ]
1265
+ },
1266
+ {
1267
+ "cell_type": "code",
1268
+ "execution_count": null,
1269
+ "metadata": {},
1270
+ "outputs": [],
1271
+ "source": [
1272
+ "vis.plot_transformed_image(\"test_images/OregonTrail1870s.jpg\", render_factor=40, compare=True)"
1273
+ ]
1274
+ },
1275
+ {
1276
+ "cell_type": "code",
1277
+ "execution_count": null,
1278
+ "metadata": {},
1279
+ "outputs": [],
1280
+ "source": [
1281
+ "vis.plot_transformed_image(\"test_images/EasterNyc1911.jpg\", render_factor=19, compare=True)"
1282
+ ]
1283
+ },
1284
+ {
1285
+ "cell_type": "code",
1286
+ "execution_count": null,
1287
+ "metadata": {},
1288
+ "outputs": [],
1289
+ "source": [
1290
+ "vis.plot_transformed_image(\"test_images/1899NycBlizzard.jpg\", render_factor=45, compare=True)"
1291
+ ]
1292
+ },
1293
+ {
1294
+ "cell_type": "code",
1295
+ "execution_count": null,
1296
+ "metadata": {},
1297
+ "outputs": [],
1298
+ "source": [
1299
+ "vis.plot_transformed_image(\"test_images/Edinburgh1920s.jpg\", render_factor=17, compare=True)"
1300
+ ]
1301
+ },
1302
+ {
1303
+ "cell_type": "code",
1304
+ "execution_count": null,
1305
+ "metadata": {},
1306
+ "outputs": [],
1307
+ "source": [
1308
+ "vis.plot_transformed_image(\"test_images/1890sShoeShopOhio.jpg\", render_factor=46, compare=True)"
1309
+ ]
1310
+ },
1311
+ {
1312
+ "cell_type": "code",
1313
+ "execution_count": null,
1314
+ "metadata": {},
1315
+ "outputs": [],
1316
+ "source": [
1317
+ "vis.plot_transformed_image(\"test_images/1890sTouristsEgypt.png\", render_factor=40, compare=True)"
1318
+ ]
1319
+ },
1320
+ {
1321
+ "cell_type": "code",
1322
+ "execution_count": null,
1323
+ "metadata": {},
1324
+ "outputs": [],
1325
+ "source": [
1326
+ "vis.plot_transformed_image(\"test_images/1938Reading.jpg\", render_factor=19, compare=True)"
1327
+ ]
1328
+ },
1329
+ {
1330
+ "cell_type": "code",
1331
+ "execution_count": null,
1332
+ "metadata": {},
1333
+ "outputs": [],
1334
+ "source": [
1335
+ "vis.plot_transformed_image(\"test_images/1850Geography.jpg\", compare=True)"
1336
+ ]
1337
+ },
1338
+ {
1339
+ "cell_type": "code",
1340
+ "execution_count": null,
1341
+ "metadata": {},
1342
+ "outputs": [],
1343
+ "source": [
1344
+ "vis.plot_transformed_image(\"test_images/1901Electrophone.jpg\", render_factor=10, compare=True)"
1345
+ ]
1346
+ },
1347
+ {
1348
+ "cell_type": "code",
1349
+ "execution_count": null,
1350
+ "metadata": {},
1351
+ "outputs": [],
1352
+ "source": [
1353
+ "for i in range(8, 47):\n",
1354
+ " vis.plot_transformed_image(\"test_images/1901Electrophone.jpg\", render_factor=i, compare=True)"
1355
+ ]
1356
+ },
1357
+ {
1358
+ "cell_type": "code",
1359
+ "execution_count": null,
1360
+ "metadata": {},
1361
+ "outputs": [],
1362
+ "source": [
1363
+ "vis.plot_transformed_image(\"test_images/Texas1938Woman.png\", render_factor=38, compare=True)"
1364
+ ]
1365
+ },
1366
+ {
1367
+ "cell_type": "code",
1368
+ "execution_count": null,
1369
+ "metadata": {},
1370
+ "outputs": [],
1371
+ "source": [
1372
+ "vis.plot_transformed_image(\"test_images/MaioreWoman1895NZ.jpg\", compare=True)"
1373
+ ]
1374
+ },
1375
+ {
1376
+ "cell_type": "code",
1377
+ "execution_count": null,
1378
+ "metadata": {},
1379
+ "outputs": [],
1380
+ "source": [
1381
+ "vis.plot_transformed_image(\"test_images/WestVirginiaHouse.jpg\", compare=True)"
1382
+ ]
1383
+ },
1384
+ {
1385
+ "cell_type": "code",
1386
+ "execution_count": null,
1387
+ "metadata": {},
1388
+ "outputs": [],
1389
+ "source": [
1390
+ "vis.plot_transformed_image(\"test_images/1920sGuadalope.jpg\", compare=True)"
1391
+ ]
1392
+ },
1393
+ {
1394
+ "cell_type": "code",
1395
+ "execution_count": null,
1396
+ "metadata": {},
1397
+ "outputs": [],
1398
+ "source": [
1399
+ "vis.plot_transformed_image(\"test_images/1909Chicago.jpg\", render_factor=45, compare=True)"
1400
+ ]
1401
+ },
1402
+ {
1403
+ "cell_type": "code",
1404
+ "execution_count": null,
1405
+ "metadata": {},
1406
+ "outputs": [],
1407
+ "source": [
1408
+ "vis.plot_transformed_image(\"test_images/1920sFarmKid.jpg\", compare=True)"
1409
+ ]
1410
+ },
1411
+ {
1412
+ "cell_type": "code",
1413
+ "execution_count": null,
1414
+ "metadata": {},
1415
+ "outputs": [],
1416
+ "source": [
1417
+ "vis.plot_transformed_image(\"test_images/ParisLate1800s.jpg\", render_factor=45, compare=True)"
1418
+ ]
1419
+ },
1420
+ {
1421
+ "cell_type": "code",
1422
+ "execution_count": null,
1423
+ "metadata": {},
1424
+ "outputs": [],
1425
+ "source": [
1426
+ "vis.plot_transformed_image(\"test_images/1900sDaytonaBeach.png\", render_factor=23, compare=True)"
1427
+ ]
1428
+ },
1429
+ {
1430
+ "cell_type": "code",
1431
+ "execution_count": null,
1432
+ "metadata": {},
1433
+ "outputs": [],
1434
+ "source": [
1435
+ "vis.plot_transformed_image(\"test_images/1930sGeorgia.jpg\", compare=True)"
1436
+ ]
1437
+ },
1438
+ {
1439
+ "cell_type": "code",
1440
+ "execution_count": null,
1441
+ "metadata": {},
1442
+ "outputs": [],
1443
+ "source": [
1444
+ "vis.plot_transformed_image(\"test_images/NorwegianBride1920s.jpg\", render_factor=30, compare=True)"
1445
+ ]
1446
+ },
1447
+ {
1448
+ "cell_type": "code",
1449
+ "execution_count": null,
1450
+ "metadata": {},
1451
+ "outputs": [],
1452
+ "source": [
1453
+ "vis.plot_transformed_image(\"test_images/Depression.jpg\", compare=True)"
1454
+ ]
1455
+ },
1456
+ {
1457
+ "cell_type": "code",
1458
+ "execution_count": null,
1459
+ "metadata": {},
1460
+ "outputs": [],
1461
+ "source": [
1462
+ "vis.plot_transformed_image(\"test_images/1888Slum.jpg\", render_factor=30, compare=True)"
1463
+ ]
1464
+ },
1465
+ {
1466
+ "cell_type": "code",
1467
+ "execution_count": null,
1468
+ "metadata": {},
1469
+ "outputs": [],
1470
+ "source": [
1471
+ "vis.plot_transformed_image(\"test_images/LivingRoom1920Sweden.jpg\", render_factor=45, compare=True)"
1472
+ ]
1473
+ },
1474
+ {
1475
+ "cell_type": "code",
1476
+ "execution_count": null,
1477
+ "metadata": {},
1478
+ "outputs": [],
1479
+ "source": [
1480
+ "vis.plot_transformed_image(\"test_images/1896NewsBoyGirl.jpg\", compare=True)"
1481
+ ]
1482
+ },
1483
+ {
1484
+ "cell_type": "code",
1485
+ "execution_count": null,
1486
+ "metadata": {},
1487
+ "outputs": [],
1488
+ "source": [
1489
+ "vis.plot_transformed_image(\"test_images/PetDucks1927.jpg\", compare=True)"
1490
+ ]
1491
+ },
1492
+ {
1493
+ "cell_type": "code",
1494
+ "execution_count": null,
1495
+ "metadata": {},
1496
+ "outputs": [],
1497
+ "source": [
1498
+ "vis.plot_transformed_image(\"test_images/1899SodaFountain.jpg\", render_factor=46, compare=True)"
1499
+ ]
1500
+ },
1501
+ {
1502
+ "cell_type": "code",
1503
+ "execution_count": null,
1504
+ "metadata": {},
1505
+ "outputs": [],
1506
+ "source": [
1507
+ "vis.plot_transformed_image(\"test_images/TimesSquare1955.jpg\", compare=True)"
1508
+ ]
1509
+ },
1510
+ {
1511
+ "cell_type": "code",
1512
+ "execution_count": null,
1513
+ "metadata": {},
1514
+ "outputs": [],
1515
+ "source": [
1516
+ "vis.plot_transformed_image(\"test_images/PuppyGify.jpg\", compare=True)"
1517
+ ]
1518
+ },
1519
+ {
1520
+ "cell_type": "code",
1521
+ "execution_count": null,
1522
+ "metadata": {},
1523
+ "outputs": [],
1524
+ "source": [
1525
+ "vis.plot_transformed_image(\"test_images/1890CliffHouseSF.jpg\", render_factor=30, compare=True)"
1526
+ ]
1527
+ },
1528
+ {
1529
+ "cell_type": "code",
1530
+ "execution_count": null,
1531
+ "metadata": {},
1532
+ "outputs": [],
1533
+ "source": [
1534
+ "vis.plot_transformed_image(\"test_images/1908FamilyPhoto.jpg\", render_factor=45, compare=True)"
1535
+ ]
1536
+ },
1537
+ {
1538
+ "cell_type": "code",
1539
+ "execution_count": null,
1540
+ "metadata": {},
1541
+ "outputs": [],
1542
+ "source": [
1543
+ "vis.plot_transformed_image(\"test_images/1900sSaloon.jpg\", render_factor=43, compare=True)"
1544
+ ]
1545
+ },
1546
+ {
1547
+ "cell_type": "code",
1548
+ "execution_count": null,
1549
+ "metadata": {},
1550
+ "outputs": [],
1551
+ "source": [
1552
+ "vis.plot_transformed_image(\"test_images/1890BostonHospital.jpg\", render_factor=40, compare=True)"
1553
+ ]
1554
+ },
1555
+ {
1556
+ "cell_type": "code",
1557
+ "execution_count": null,
1558
+ "metadata": {},
1559
+ "outputs": [],
1560
+ "source": [
1561
+ "vis.plot_transformed_image(\"test_images/1870Girl.jpg\", compare=True)"
1562
+ ]
1563
+ },
1564
+ {
1565
+ "cell_type": "code",
1566
+ "execution_count": null,
1567
+ "metadata": {},
1568
+ "outputs": [],
1569
+ "source": [
1570
+ "vis.plot_transformed_image(\"test_images/AustriaHungaryWomen1890s.jpg\", compare=True)"
1571
+ ]
1572
+ },
1573
+ {
1574
+ "cell_type": "code",
1575
+ "execution_count": null,
1576
+ "metadata": {},
1577
+ "outputs": [],
1578
+ "source": [
1579
+ "vis.plot_transformed_image(\"test_images/Shack.jpg\",render_factor=42, compare=True)"
1580
+ ]
1581
+ },
1582
+ {
1583
+ "cell_type": "code",
1584
+ "execution_count": null,
1585
+ "metadata": {},
1586
+ "outputs": [],
1587
+ "source": [
1588
+ "vis.plot_transformed_image(\"test_images/Apsaroke1908.png\", render_factor=35, compare=True)"
1589
+ ]
1590
+ },
1591
+ {
1592
+ "cell_type": "code",
1593
+ "execution_count": null,
1594
+ "metadata": {},
1595
+ "outputs": [],
1596
+ "source": [
1597
+ "vis.plot_transformed_image(\"test_images/1948CarsGrandma.jpg\", compare=True)"
1598
+ ]
1599
+ },
1600
+ {
1601
+ "cell_type": "code",
1602
+ "execution_count": null,
1603
+ "metadata": {},
1604
+ "outputs": [],
1605
+ "source": [
1606
+ "vis.plot_transformed_image(\"test_images/PlanesManhattan1931.jpg\", compare=True)"
1607
+ ]
1608
+ },
1609
+ {
1610
+ "cell_type": "code",
1611
+ "execution_count": null,
1612
+ "metadata": {},
1613
+ "outputs": [],
1614
+ "source": [
1615
+ "vis.plot_transformed_image(\"test_images/WorriedKid1940sNyc.jpg\", compare=True)"
1616
+ ]
1617
+ },
1618
+ {
1619
+ "cell_type": "code",
1620
+ "execution_count": null,
1621
+ "metadata": {},
1622
+ "outputs": [],
1623
+ "source": [
1624
+ "vis.plot_transformed_image(\"test_images/1920sFamilyPhoto.jpg\", compare=True)"
1625
+ ]
1626
+ },
1627
+ {
1628
+ "cell_type": "code",
1629
+ "execution_count": null,
1630
+ "metadata": {},
1631
+ "outputs": [],
1632
+ "source": [
1633
+ "vis.plot_transformed_image(\"test_images/CatWash1931.jpg\", compare=True)"
1634
+ ]
1635
+ },
1636
+ {
1637
+ "cell_type": "code",
1638
+ "execution_count": null,
1639
+ "metadata": {},
1640
+ "outputs": [],
1641
+ "source": [
1642
+ "vis.plot_transformed_image(\"test_images/1940sBeerRiver.jpg\", compare=True)"
1643
+ ]
1644
+ },
1645
+ {
1646
+ "cell_type": "code",
1647
+ "execution_count": null,
1648
+ "metadata": {},
1649
+ "outputs": [],
1650
+ "source": [
1651
+ "vis.plot_transformed_image(\"test_images/VictorianLivingRoom.jpg\", render_factor=45, compare=True)"
1652
+ ]
1653
+ },
1654
+ {
1655
+ "cell_type": "code",
1656
+ "execution_count": null,
1657
+ "metadata": {},
1658
+ "outputs": [],
1659
+ "source": [
1660
+ "vis.plot_transformed_image(\"test_images/1897BlindmansBluff.jpg\", compare=True)"
1661
+ ]
1662
+ },
1663
+ {
1664
+ "cell_type": "code",
1665
+ "execution_count": null,
1666
+ "metadata": {},
1667
+ "outputs": [],
1668
+ "source": [
1669
+ "vis.plot_transformed_image(\"test_images/1874Mexico.png\", compare=True)"
1670
+ ]
1671
+ },
1672
+ {
1673
+ "cell_type": "code",
1674
+ "execution_count": null,
1675
+ "metadata": {},
1676
+ "outputs": [],
1677
+ "source": [
1678
+ "vis.plot_transformed_image(\"test_images/MadisonSquare1900.jpg\", render_factor=46, compare=True)"
1679
+ ]
1680
+ },
1681
+ {
1682
+ "cell_type": "code",
1683
+ "execution_count": null,
1684
+ "metadata": {},
1685
+ "outputs": [],
1686
+ "source": [
1687
+ "vis.plot_transformed_image(\"test_images/1867MusicianConstantinople.jpg\", compare=True)"
1688
+ ]
1689
+ },
1690
+ {
1691
+ "cell_type": "code",
1692
+ "execution_count": null,
1693
+ "metadata": {},
1694
+ "outputs": [],
1695
+ "source": [
1696
+ "vis.plot_transformed_image(\"test_images/1925Girl.jpg\", render_factor=25, compare=True)"
1697
+ ]
1698
+ },
1699
+ {
1700
+ "cell_type": "code",
1701
+ "execution_count": null,
1702
+ "metadata": {},
1703
+ "outputs": [],
1704
+ "source": [
1705
+ "vis.plot_transformed_image(\"test_images/1907Cowboys.jpg\", render_factor=28, compare=True)"
1706
+ ]
1707
+ },
1708
+ {
1709
+ "cell_type": "code",
1710
+ "execution_count": null,
1711
+ "metadata": {},
1712
+ "outputs": [],
1713
+ "source": [
1714
+ "vis.plot_transformed_image(\"test_images/WWIIPeeps.jpg\", render_factor=37, compare=True)"
1715
+ ]
1716
+ },
1717
+ {
1718
+ "cell_type": "code",
1719
+ "execution_count": null,
1720
+ "metadata": {},
1721
+ "outputs": [],
1722
+ "source": [
1723
+ "vis.plot_transformed_image(\"test_images/BabyBigBoots.jpg\", render_factor=40, compare=True)"
1724
+ ]
1725
+ },
1726
+ {
1727
+ "cell_type": "code",
1728
+ "execution_count": null,
1729
+ "metadata": {},
1730
+ "outputs": [],
1731
+ "source": [
1732
+ "vis.plot_transformed_image(\"test_images/1895BikeMaidens.jpg\", render_factor=25, compare=True)"
1733
+ ]
1734
+ },
1735
+ {
1736
+ "cell_type": "code",
1737
+ "execution_count": null,
1738
+ "metadata": {},
1739
+ "outputs": [],
1740
+ "source": [
1741
+ "vis.plot_transformed_image(\"test_images/IrishLate1800s.jpg\", render_factor=25, compare=True)"
1742
+ ]
1743
+ },
1744
+ {
1745
+ "cell_type": "code",
1746
+ "execution_count": null,
1747
+ "metadata": {},
1748
+ "outputs": [],
1749
+ "source": [
1750
+ "vis.plot_transformed_image(\"test_images/LibraryOfCongress1910.jpg\", render_factor=21, compare=True)"
1751
+ ]
1752
+ },
1753
+ {
1754
+ "cell_type": "code",
1755
+ "execution_count": null,
1756
+ "metadata": {},
1757
+ "outputs": [],
1758
+ "source": [
1759
+ "vis.plot_transformed_image(\"test_images/1875Olds.jpg\", render_factor=16, compare=True)"
1760
+ ]
1761
+ },
1762
+ {
1763
+ "cell_type": "code",
1764
+ "execution_count": null,
1765
+ "metadata": {},
1766
+ "outputs": [],
1767
+ "source": [
1768
+ "vis.plot_transformed_image(\"test_images/SenecaNative1908.jpg\", render_factor=30, compare=True)"
1769
+ ]
1770
+ },
1771
+ {
1772
+ "cell_type": "code",
1773
+ "execution_count": null,
1774
+ "metadata": {},
1775
+ "outputs": [],
1776
+ "source": [
1777
+ "vis.plot_transformed_image(\"test_images/WWIHospital.jpg\", render_factor=40, compare=True)"
1778
+ ]
1779
+ },
1780
+ {
1781
+ "cell_type": "code",
1782
+ "execution_count": null,
1783
+ "metadata": {},
1784
+ "outputs": [],
1785
+ "source": [
1786
+ "vis.plot_transformed_image(\"test_images/1892WaterLillies.jpg\", render_factor=45, compare=True)"
1787
+ ]
1788
+ },
1789
+ {
1790
+ "cell_type": "code",
1791
+ "execution_count": null,
1792
+ "metadata": {},
1793
+ "outputs": [],
1794
+ "source": [
1795
+ "vis.plot_transformed_image(\"test_images/GreekImmigrants1905.jpg\", render_factor=25, compare=True)"
1796
+ ]
1797
+ },
1798
+ {
1799
+ "cell_type": "code",
1800
+ "execution_count": null,
1801
+ "metadata": {},
1802
+ "outputs": [],
1803
+ "source": [
1804
+ "vis.plot_transformed_image(\"test_images/FatMensShop.jpg\", render_factor=21, compare=True)"
1805
+ ]
1806
+ },
1807
+ {
1808
+ "cell_type": "code",
1809
+ "execution_count": null,
1810
+ "metadata": {},
1811
+ "outputs": [],
1812
+ "source": [
1813
+ "vis.plot_transformed_image(\"test_images/KidCage1930s.png\", compare=True)"
1814
+ ]
1815
+ },
1816
+ {
1817
+ "cell_type": "code",
1818
+ "execution_count": null,
1819
+ "metadata": {},
1820
+ "outputs": [],
1821
+ "source": [
1822
+ "vis.plot_transformed_image(\"test_images/FarmWomen1895.jpg\", compare=True)"
1823
+ ]
1824
+ },
1825
+ {
1826
+ "cell_type": "code",
1827
+ "execution_count": null,
1828
+ "metadata": {},
1829
+ "outputs": [],
1830
+ "source": [
1831
+ "vis.plot_transformed_image(\"test_images/NewZealand1860s.jpg\", compare=True)"
1832
+ ]
1833
+ },
1834
+ {
1835
+ "cell_type": "code",
1836
+ "execution_count": null,
1837
+ "metadata": {},
1838
+ "outputs": [],
1839
+ "source": [
1840
+ "vis.plot_transformed_image(\"test_images/JerseyShore1905.jpg\", render_factor=45, compare=True)"
1841
+ ]
1842
+ },
1843
+ {
1844
+ "cell_type": "code",
1845
+ "execution_count": null,
1846
+ "metadata": {},
1847
+ "outputs": [],
1848
+ "source": [
1849
+ "vis.plot_transformed_image(\"test_images/LondonKidsEarly1900s.jpg\", compare=True)"
1850
+ ]
1851
+ },
1852
+ {
1853
+ "cell_type": "code",
1854
+ "execution_count": null,
1855
+ "metadata": {},
1856
+ "outputs": [],
1857
+ "source": [
1858
+ "vis.plot_transformed_image(\"test_images/NYStreetClean1906.jpg\", compare=True)"
1859
+ ]
1860
+ },
1861
+ {
1862
+ "cell_type": "code",
1863
+ "execution_count": null,
1864
+ "metadata": {},
1865
+ "outputs": [],
1866
+ "source": [
1867
+ "vis.plot_transformed_image(\"test_images/Boston1937.jpg\", compare=True)"
1868
+ ]
1869
+ },
1870
+ {
1871
+ "cell_type": "code",
1872
+ "execution_count": null,
1873
+ "metadata": {},
1874
+ "outputs": [],
1875
+ "source": [
1876
+ "vis.plot_transformed_image(\"test_images/Cork1905.jpg\", render_factor=28, compare=True)"
1877
+ ]
1878
+ },
1879
+ {
1880
+ "cell_type": "code",
1881
+ "execution_count": null,
1882
+ "metadata": {},
1883
+ "outputs": [],
1884
+ "source": [
1885
+ "vis.plot_transformed_image(\"test_images/BoxedBedEarly1900s.jpg\", compare=True)"
1886
+ ]
1887
+ },
1888
+ {
1889
+ "cell_type": "code",
1890
+ "execution_count": null,
1891
+ "metadata": {},
1892
+ "outputs": [],
1893
+ "source": [
1894
+ "vis.plot_transformed_image(\"test_images/ZoologischerGarten1898.jpg\", compare=True)"
1895
+ ]
1896
+ },
1897
+ {
1898
+ "cell_type": "code",
1899
+ "execution_count": null,
1900
+ "metadata": {},
1901
+ "outputs": [],
1902
+ "source": [
1903
+ "vis.plot_transformed_image(\"test_images/EmpireState1930.jpg\", compare=True)"
1904
+ ]
1905
+ },
1906
+ {
1907
+ "cell_type": "code",
1908
+ "execution_count": null,
1909
+ "metadata": {},
1910
+ "outputs": [],
1911
+ "source": [
1912
+ "vis.plot_transformed_image(\"test_images/Agamemnon1919.jpg\", render_factor=40, compare=True)"
1913
+ ]
1914
+ },
1915
+ {
1916
+ "cell_type": "code",
1917
+ "execution_count": null,
1918
+ "metadata": {},
1919
+ "outputs": [],
1920
+ "source": [
1921
+ "vis.plot_transformed_image(\"test_images/AppalachianLoggers1901.jpg\", compare=True)"
1922
+ ]
1923
+ },
1924
+ {
1925
+ "cell_type": "code",
1926
+ "execution_count": null,
1927
+ "metadata": {},
1928
+ "outputs": [],
1929
+ "source": [
1930
+ "vis.plot_transformed_image(\"test_images/WWISikhs.jpg\", compare=True)"
1931
+ ]
1932
+ },
1933
+ {
1934
+ "cell_type": "code",
1935
+ "execution_count": null,
1936
+ "metadata": {},
1937
+ "outputs": [],
1938
+ "source": [
1939
+ "vis.plot_transformed_image(\"test_images/MementoMori1865.jpg\", compare=True)"
1940
+ ]
1941
+ },
1942
+ {
1943
+ "cell_type": "code",
1944
+ "execution_count": null,
1945
+ "metadata": {},
1946
+ "outputs": [],
1947
+ "source": [
1948
+ "vis.plot_transformed_image(\"test_images/RepBrennanRadio1922.jpg\", render_factor=43, compare=True)"
1949
+ ]
1950
+ },
1951
+ {
1952
+ "cell_type": "code",
1953
+ "execution_count": null,
1954
+ "metadata": {},
1955
+ "outputs": [],
1956
+ "source": [
1957
+ "vis.plot_transformed_image(\"test_images/Late1800sNative.jpg\", render_factor=20, compare=True)"
1958
+ ]
1959
+ },
1960
+ {
1961
+ "cell_type": "code",
1962
+ "execution_count": null,
1963
+ "metadata": {},
1964
+ "outputs": [],
1965
+ "source": [
1966
+ "vis.plot_transformed_image(\"test_images/GasPrices1939.jpg\", render_factor=30, compare=True)"
1967
+ ]
1968
+ },
1969
+ {
1970
+ "cell_type": "code",
1971
+ "execution_count": null,
1972
+ "metadata": {},
1973
+ "outputs": [],
1974
+ "source": [
1975
+ "vis.plot_transformed_image(\"test_images/1933RockefellerCenter.jpg\", compare=True)"
1976
+ ]
1977
+ },
1978
+ {
1979
+ "cell_type": "code",
1980
+ "execution_count": null,
1981
+ "metadata": {},
1982
+ "outputs": [],
1983
+ "source": [
1984
+ "vis.plot_transformed_image(\"test_images/Scotland1919.jpg\", compare=True)"
1985
+ ]
1986
+ },
1987
+ {
1988
+ "cell_type": "code",
1989
+ "execution_count": null,
1990
+ "metadata": {},
1991
+ "outputs": [],
1992
+ "source": [
1993
+ "vis.plot_transformed_image(\"test_images/1920CobblersShopLondon.jpg\", compare=True)"
1994
+ ]
1995
+ },
1996
+ {
1997
+ "cell_type": "code",
1998
+ "execution_count": null,
1999
+ "metadata": {},
2000
+ "outputs": [],
2001
+ "source": [
2002
+ "vis.plot_transformed_image(\"test_images/1909ParisFirstFemaleTaxisDriver.jpg\", compare=True)"
2003
+ ]
2004
+ },
2005
+ {
2006
+ "cell_type": "code",
2007
+ "execution_count": null,
2008
+ "metadata": {},
2009
+ "outputs": [],
2010
+ "source": [
2011
+ "vis.plot_transformed_image(\"test_images/HoovervilleSeattle1932.jpg\", compare=True)"
2012
+ ]
2013
+ },
2014
+ {
2015
+ "cell_type": "code",
2016
+ "execution_count": null,
2017
+ "metadata": {},
2018
+ "outputs": [],
2019
+ "source": [
2020
+ "vis.plot_transformed_image(\"test_images/ElephantLondon1934.png\", compare=True)"
2021
+ ]
2022
+ },
2023
+ {
2024
+ "cell_type": "code",
2025
+ "execution_count": null,
2026
+ "metadata": {},
2027
+ "outputs": [],
2028
+ "source": [
2029
+ "vis.plot_transformed_image(\"test_images/Jane_Addams.jpg\", compare=True)"
2030
+ ]
2031
+ },
2032
+ {
2033
+ "cell_type": "code",
2034
+ "execution_count": null,
2035
+ "metadata": {},
2036
+ "outputs": [],
2037
+ "source": [
2038
+ "vis.plot_transformed_image(\"test_images/AnselAdamsAdobe.jpg\", compare=True)"
2039
+ ]
2040
+ },
2041
+ {
2042
+ "cell_type": "code",
2043
+ "execution_count": null,
2044
+ "metadata": {},
2045
+ "outputs": [],
2046
+ "source": [
2047
+ "vis.plot_transformed_image(\"test_images/CricketLondon1930.jpg\", render_factor=45, compare=True)"
2048
+ ]
2049
+ },
2050
+ {
2051
+ "cell_type": "code",
2052
+ "execution_count": null,
2053
+ "metadata": {},
2054
+ "outputs": [],
2055
+ "source": [
2056
+ "vis.plot_transformed_image(\"test_images/Donegal1907Yarn.jpg\", render_factor=32, compare=True)"
2057
+ ]
2058
+ },
2059
+ {
2060
+ "cell_type": "code",
2061
+ "execution_count": null,
2062
+ "metadata": {},
2063
+ "outputs": [],
2064
+ "source": [
2065
+ "vis.plot_transformed_image(\"test_images/AnselAdamsChurch.jpg\", compare=True)"
2066
+ ]
2067
+ },
2068
+ {
2069
+ "cell_type": "code",
2070
+ "execution_count": null,
2071
+ "metadata": {},
2072
+ "outputs": [],
2073
+ "source": [
2074
+ "vis.plot_transformed_image(\"test_images/BreadDelivery1920sIreland.jpg\", render_factor=20, compare=True)"
2075
+ ]
2076
+ },
2077
+ {
2078
+ "cell_type": "code",
2079
+ "execution_count": null,
2080
+ "metadata": {},
2081
+ "outputs": [],
2082
+ "source": [
2083
+ "vis.plot_transformed_image(\"test_images/BritishTeaBombay1890s.png\", compare=True)"
2084
+ ]
2085
+ },
2086
+ {
2087
+ "cell_type": "code",
2088
+ "execution_count": null,
2089
+ "metadata": {},
2090
+ "outputs": [],
2091
+ "source": [
2092
+ "vis.plot_transformed_image(\"test_images/CafeParis1928.jpg\", render_factor=35, compare=True)"
2093
+ ]
2094
+ },
2095
+ {
2096
+ "cell_type": "code",
2097
+ "execution_count": null,
2098
+ "metadata": {},
2099
+ "outputs": [],
2100
+ "source": [
2101
+ "vis.plot_transformed_image(\"test_images/BigManTavern1908NYC.jpg\", compare=True)"
2102
+ ]
2103
+ },
2104
+ {
2105
+ "cell_type": "code",
2106
+ "execution_count": null,
2107
+ "metadata": {},
2108
+ "outputs": [],
2109
+ "source": [
2110
+ "vis.plot_transformed_image(\"test_images/Cars1890sIreland.jpg\", compare=True)"
2111
+ ]
2112
+ },
2113
+ {
2114
+ "cell_type": "code",
2115
+ "execution_count": null,
2116
+ "metadata": {},
2117
+ "outputs": [],
2118
+ "source": [
2119
+ "vis.plot_transformed_image(\"test_images/GalwayIreland1902.jpg\", render_factor=35, compare=True)"
2120
+ ]
2121
+ },
2122
+ {
2123
+ "cell_type": "code",
2124
+ "execution_count": null,
2125
+ "metadata": {},
2126
+ "outputs": [],
2127
+ "source": [
2128
+ "vis.plot_transformed_image(\"test_images/HomeIreland1924.jpg\", render_factor=40, compare=True)"
2129
+ ]
2130
+ },
2131
+ {
2132
+ "cell_type": "code",
2133
+ "execution_count": null,
2134
+ "metadata": {},
2135
+ "outputs": [],
2136
+ "source": [
2137
+ "vis.plot_transformed_image(\"test_images/HydeParkLondon1920s.jpg\", render_factor=30, compare=True)"
2138
+ ]
2139
+ },
2140
+ {
2141
+ "cell_type": "code",
2142
+ "execution_count": null,
2143
+ "metadata": {},
2144
+ "outputs": [],
2145
+ "source": [
2146
+ "vis.plot_transformed_image(\"test_images/1929LondonOverFleetSt.jpg\", render_factor=25, compare=True)"
2147
+ ]
2148
+ },
2149
+ {
2150
+ "cell_type": "code",
2151
+ "execution_count": null,
2152
+ "metadata": {},
2153
+ "outputs": [],
2154
+ "source": [
2155
+ "vis.plot_transformed_image(\"test_images/AccordianKid1900Paris.jpg\", compare=True)"
2156
+ ]
2157
+ },
2158
+ {
2159
+ "cell_type": "code",
2160
+ "execution_count": null,
2161
+ "metadata": {},
2162
+ "outputs": [],
2163
+ "source": [
2164
+ "vis.plot_transformed_image(\"test_images/AnselAdamsBuildings.jpg\", render_factor=45, compare=True)"
2165
+ ]
2166
+ },
2167
+ {
2168
+ "cell_type": "code",
2169
+ "execution_count": null,
2170
+ "metadata": {},
2171
+ "outputs": [],
2172
+ "source": [
2173
+ "vis.plot_transformed_image(\"test_images/AthleticClubParis1913.jpg\", render_factor=42, compare=True)"
2174
+ ]
2175
+ },
2176
+ {
2177
+ "cell_type": "code",
2178
+ "execution_count": null,
2179
+ "metadata": {},
2180
+ "outputs": [],
2181
+ "source": [
2182
+ "vis.plot_transformed_image(\"test_images/BombedLibraryLondon1940.jpg\", compare=True)"
2183
+ ]
2184
+ },
2185
+ {
2186
+ "cell_type": "code",
2187
+ "execution_count": null,
2188
+ "metadata": {},
2189
+ "outputs": [],
2190
+ "source": [
2191
+ "vis.plot_transformed_image(\"test_images/Boston1937.jpg\", render_factor=30, compare=True)"
2192
+ ]
2193
+ },
2194
+ {
2195
+ "cell_type": "code",
2196
+ "execution_count": null,
2197
+ "metadata": {},
2198
+ "outputs": [],
2199
+ "source": [
2200
+ "vis.plot_transformed_image(\"test_images/BoulevardDuTemple1838.jpg\", render_factor=25, compare=True)"
2201
+ ]
2202
+ },
2203
+ {
2204
+ "cell_type": "code",
2205
+ "execution_count": null,
2206
+ "metadata": {},
2207
+ "outputs": [],
2208
+ "source": [
2209
+ "vis.plot_transformed_image(\"test_images/BumperCarsParis1930.jpg\", render_factor=25, compare=True)"
2210
+ ]
2211
+ },
2212
+ {
2213
+ "cell_type": "code",
2214
+ "execution_count": null,
2215
+ "metadata": {},
2216
+ "outputs": [],
2217
+ "source": [
2218
+ "vis.plot_transformed_image(\"test_images/CafeTerrace1925Paris.jpg\", render_factor=24, compare=True)"
2219
+ ]
2220
+ },
2221
+ {
2222
+ "cell_type": "code",
2223
+ "execution_count": null,
2224
+ "metadata": {},
2225
+ "outputs": [],
2226
+ "source": [
2227
+ "vis.plot_transformed_image(\"test_images/CoalDeliveryParis1915.jpg\", render_factor=37, compare=True)"
2228
+ ]
2229
+ },
2230
+ {
2231
+ "cell_type": "code",
2232
+ "execution_count": null,
2233
+ "metadata": {},
2234
+ "outputs": [],
2235
+ "source": [
2236
+ "vis.plot_transformed_image(\"test_images/CorkKids1910.jpg\", render_factor=32, compare=True)"
2237
+ ]
2238
+ },
2239
+ {
2240
+ "cell_type": "code",
2241
+ "execution_count": null,
2242
+ "metadata": {},
2243
+ "outputs": [],
2244
+ "source": [
2245
+ "vis.plot_transformed_image(\"test_images/DeepSeaDiver1915.png\", render_factor=16, compare=True)"
2246
+ ]
2247
+ },
2248
+ {
2249
+ "cell_type": "code",
2250
+ "execution_count": null,
2251
+ "metadata": {},
2252
+ "outputs": [],
2253
+ "source": [
2254
+ "vis.plot_transformed_image(\"test_images/EastEndLondonStreetKids1901.jpg\", compare=True)"
2255
+ ]
2256
+ },
2257
+ {
2258
+ "cell_type": "code",
2259
+ "execution_count": null,
2260
+ "metadata": {},
2261
+ "outputs": [],
2262
+ "source": [
2263
+ "vis.plot_transformed_image(\"test_images/FreightTrainTeens1934.jpg\", compare=True)"
2264
+ ]
2265
+ },
2266
+ {
2267
+ "cell_type": "code",
2268
+ "execution_count": null,
2269
+ "metadata": {},
2270
+ "outputs": [],
2271
+ "source": [
2272
+ "vis.plot_transformed_image(\"test_images/HarrodsLondon1920.jpg\", render_factor=45, compare=True)"
2273
+ ]
2274
+ },
2275
+ {
2276
+ "cell_type": "code",
2277
+ "execution_count": null,
2278
+ "metadata": {},
2279
+ "outputs": [],
2280
+ "source": [
2281
+ "vis.plot_transformed_image(\"test_images/HerbSeller1899Paris.jpg\", render_factor=17, compare=True)"
2282
+ ]
2283
+ },
2284
+ {
2285
+ "cell_type": "code",
2286
+ "execution_count": null,
2287
+ "metadata": {},
2288
+ "outputs": [],
2289
+ "source": [
2290
+ "vis.plot_transformed_image(\"test_images/CalcuttaPoliceman1920.jpg\", render_factor=20, compare=True)"
2291
+ ]
2292
+ },
2293
+ {
2294
+ "cell_type": "code",
2295
+ "execution_count": null,
2296
+ "metadata": {},
2297
+ "outputs": [],
2298
+ "source": [
2299
+ "vis.plot_transformed_image(\"test_images/ElectricScooter1915.jpeg\", render_factor=20, compare=True)"
2300
+ ]
2301
+ },
2302
+ {
2303
+ "cell_type": "code",
2304
+ "execution_count": null,
2305
+ "metadata": {},
2306
+ "outputs": [],
2307
+ "source": [
2308
+ "vis.plot_transformed_image(\"test_images/GreatGrandparentsIrelandEarly1900s.jpg\", compare=True)"
2309
+ ]
2310
+ },
2311
+ {
2312
+ "cell_type": "code",
2313
+ "execution_count": null,
2314
+ "metadata": {},
2315
+ "outputs": [],
2316
+ "source": [
2317
+ "vis.plot_transformed_image(\"test_images/HalloweenEarly1900s.jpg\", render_factor=11, compare=True)"
2318
+ ]
2319
+ },
2320
+ {
2321
+ "cell_type": "code",
2322
+ "execution_count": null,
2323
+ "metadata": {},
2324
+ "outputs": [],
2325
+ "source": [
2326
+ "vis.plot_transformed_image(\"test_images/IceManLondon1919.jpg\", compare=True)"
2327
+ ]
2328
+ },
2329
+ {
2330
+ "cell_type": "code",
2331
+ "execution_count": null,
2332
+ "metadata": {},
2333
+ "outputs": [],
2334
+ "source": [
2335
+ "vis.plot_transformed_image(\"test_images/LeBonMarcheParis1875.jpg\", compare=True)"
2336
+ ]
2337
+ },
2338
+ {
2339
+ "cell_type": "code",
2340
+ "execution_count": null,
2341
+ "metadata": {},
2342
+ "outputs": [],
2343
+ "source": [
2344
+ "vis.plot_transformed_image(\"test_images/LittleAirplane1934.jpg\", render_factor=35, compare=True)"
2345
+ ]
2346
+ },
2347
+ {
2348
+ "cell_type": "code",
2349
+ "execution_count": null,
2350
+ "metadata": {},
2351
+ "outputs": [],
2352
+ "source": [
2353
+ "vis.plot_transformed_image(\"test_images/RoyalUniversityMedStudent1900Ireland.jpg\", render_factor=45, compare=True)"
2354
+ ]
2355
+ },
2356
+ {
2357
+ "cell_type": "code",
2358
+ "execution_count": null,
2359
+ "metadata": {},
2360
+ "outputs": [],
2361
+ "source": [
2362
+ "vis.plot_transformed_image(\"test_images/LewisTomalinLondon1895.png\", render_factor=25, compare=True)"
2363
+ ]
2364
+ },
2365
+ {
2366
+ "cell_type": "code",
2367
+ "execution_count": null,
2368
+ "metadata": {},
2369
+ "outputs": [],
2370
+ "source": [
2371
+ "vis.plot_transformed_image(\"test_images/SunHelmetsLondon1933.jpg\", render_factor=40, compare=True)"
2372
+ ]
2373
+ },
2374
+ {
2375
+ "cell_type": "code",
2376
+ "execution_count": null,
2377
+ "metadata": {},
2378
+ "outputs": [],
2379
+ "source": [
2380
+ "vis.plot_transformed_image(\"test_images/Killarney1910.jpg\", render_factor=45, compare=True)"
2381
+ ]
2382
+ },
2383
+ {
2384
+ "cell_type": "code",
2385
+ "execution_count": null,
2386
+ "metadata": {},
2387
+ "outputs": [],
2388
+ "source": [
2389
+ "vis.plot_transformed_image(\"test_images/LondonSheep1920s.png\", compare=True)"
2390
+ ]
2391
+ },
2392
+ {
2393
+ "cell_type": "code",
2394
+ "execution_count": null,
2395
+ "metadata": {},
2396
+ "outputs": [],
2397
+ "source": [
2398
+ "vis.plot_transformed_image(\"test_images/PostOfficeVermont1914.png\", compare=True)"
2399
+ ]
2400
+ },
2401
+ {
2402
+ "cell_type": "code",
2403
+ "execution_count": null,
2404
+ "metadata": {},
2405
+ "outputs": [],
2406
+ "source": [
2407
+ "vis.plot_transformed_image(\"test_images/ServantsBessboroughHouse1908Ireland.jpg\", compare=True)"
2408
+ ]
2409
+ },
2410
+ {
2411
+ "cell_type": "code",
2412
+ "execution_count": null,
2413
+ "metadata": {},
2414
+ "outputs": [],
2415
+ "source": [
2416
+ "vis.plot_transformed_image(\"test_images/WaterfordIreland1909.jpg\", render_factor=35, compare=True)"
2417
+ ]
2418
+ },
2419
+ {
2420
+ "cell_type": "code",
2421
+ "execution_count": null,
2422
+ "metadata": {},
2423
+ "outputs": [],
2424
+ "source": [
2425
+ "vis.plot_transformed_image(\"test_images/Lisbon1919.jpg\", compare=True)"
2426
+ ]
2427
+ },
2428
+ {
2429
+ "cell_type": "code",
2430
+ "execution_count": null,
2431
+ "metadata": {},
2432
+ "outputs": [],
2433
+ "source": [
2434
+ "vis.plot_transformed_image(\"test_images/London1918WartimeClothesManufacture.jpg\", render_factor=45, compare=True)"
2435
+ ]
2436
+ },
2437
+ {
2438
+ "cell_type": "code",
2439
+ "execution_count": null,
2440
+ "metadata": {},
2441
+ "outputs": [],
2442
+ "source": [
2443
+ "vis.plot_transformed_image(\"test_images/LondonHeatWave1935.png\", compare=True)"
2444
+ ]
2445
+ },
2446
+ {
2447
+ "cell_type": "code",
2448
+ "execution_count": null,
2449
+ "metadata": {},
2450
+ "outputs": [],
2451
+ "source": [
2452
+ "vis.plot_transformed_image(\"test_images/LondonsSmallestShop1900.jpg\", compare=True)"
2453
+ ]
2454
+ },
2455
+ {
2456
+ "cell_type": "code",
2457
+ "execution_count": null,
2458
+ "metadata": {},
2459
+ "outputs": [],
2460
+ "source": [
2461
+ "vis.plot_transformed_image(\"test_images/MetropolitanDistrictRailway1869London.jpg\", compare=True)"
2462
+ ]
2463
+ },
2464
+ {
2465
+ "cell_type": "code",
2466
+ "execution_count": null,
2467
+ "metadata": {},
2468
+ "outputs": [],
2469
+ "source": [
2470
+ "vis.plot_transformed_image(\"test_images/NativeWoman1926.jpg\", render_factor=21, compare=True)"
2471
+ ]
2472
+ },
2473
+ {
2474
+ "cell_type": "code",
2475
+ "execution_count": null,
2476
+ "metadata": {},
2477
+ "outputs": [],
2478
+ "source": [
2479
+ "vis.plot_transformed_image(\"test_images/PaddysMarketCork1900s.jpg\", compare=True)"
2480
+ ]
2481
+ },
2482
+ {
2483
+ "cell_type": "code",
2484
+ "execution_count": null,
2485
+ "metadata": {},
2486
+ "outputs": [],
2487
+ "source": [
2488
+ "vis.plot_transformed_image(\"test_images/Paris1920Cart.jpg\", compare=True)"
2489
+ ]
2490
+ },
2491
+ {
2492
+ "cell_type": "code",
2493
+ "execution_count": null,
2494
+ "metadata": {},
2495
+ "outputs": [],
2496
+ "source": [
2497
+ "vis.plot_transformed_image(\"test_images/ParisLadies1910.jpg\", render_factor=20, compare=True)"
2498
+ ]
2499
+ },
2500
+ {
2501
+ "cell_type": "code",
2502
+ "execution_count": null,
2503
+ "metadata": {},
2504
+ "outputs": [],
2505
+ "source": [
2506
+ "vis.plot_transformed_image(\"test_images/ParisLadies1930s.jpg\", compare=True)"
2507
+ ]
2508
+ },
2509
+ {
2510
+ "cell_type": "code",
2511
+ "execution_count": null,
2512
+ "metadata": {},
2513
+ "outputs": [],
2514
+ "source": [
2515
+ "vis.plot_transformed_image(\"test_images/Sphinx.jpeg\", compare=True)"
2516
+ ]
2517
+ },
2518
+ {
2519
+ "cell_type": "code",
2520
+ "execution_count": null,
2521
+ "metadata": {},
2522
+ "outputs": [],
2523
+ "source": [
2524
+ "vis.plot_transformed_image(\"test_images/TheatreGroupBombay1875.jpg\", render_factor=45, compare=True)"
2525
+ ]
2526
+ },
2527
+ {
2528
+ "cell_type": "code",
2529
+ "execution_count": null,
2530
+ "metadata": {},
2531
+ "outputs": [],
2532
+ "source": [
2533
+ "vis.plot_transformed_image(\"test_images/WorldsFair1900Paris.jpg\", compare=True)"
2534
+ ]
2535
+ },
2536
+ {
2537
+ "cell_type": "code",
2538
+ "execution_count": null,
2539
+ "metadata": {},
2540
+ "outputs": [],
2541
+ "source": [
2542
+ "vis.plot_transformed_image(\"test_images/London1850Coach.jpg\", render_factor=25, compare=True)"
2543
+ ]
2544
+ },
2545
+ {
2546
+ "cell_type": "code",
2547
+ "execution_count": null,
2548
+ "metadata": {},
2549
+ "outputs": [],
2550
+ "source": [
2551
+ "vis.plot_transformed_image(\"test_images/London1900EastEndBlacksmith.jpg\", compare=True)"
2552
+ ]
2553
+ },
2554
+ {
2555
+ "cell_type": "code",
2556
+ "execution_count": null,
2557
+ "metadata": {},
2558
+ "outputs": [],
2559
+ "source": [
2560
+ "vis.plot_transformed_image(\"test_images/London1930sCheetah.jpg\", render_factor=42, compare=True)"
2561
+ ]
2562
+ },
2563
+ {
2564
+ "cell_type": "code",
2565
+ "execution_count": null,
2566
+ "metadata": {},
2567
+ "outputs": [],
2568
+ "source": [
2569
+ "vis.plot_transformed_image(\"test_images/LondonFireBrigadeMember1926.jpg\", compare=True)"
2570
+ ]
2571
+ },
2572
+ {
2573
+ "cell_type": "code",
2574
+ "execution_count": null,
2575
+ "metadata": {},
2576
+ "outputs": [],
2577
+ "source": [
2578
+ "vis.plot_transformed_image(\"test_images/LondonGarbageTruck1910.jpg\", compare=True)"
2579
+ ]
2580
+ },
2581
+ {
2582
+ "cell_type": "code",
2583
+ "execution_count": null,
2584
+ "metadata": {},
2585
+ "outputs": [],
2586
+ "source": [
2587
+ "vis.plot_transformed_image(\"test_images/LondonRailwayWork1931.jpg\", render_factor=45, compare=True)"
2588
+ ]
2589
+ },
2590
+ {
2591
+ "cell_type": "code",
2592
+ "execution_count": null,
2593
+ "metadata": {},
2594
+ "outputs": [],
2595
+ "source": [
2596
+ "vis.plot_transformed_image(\"test_images/LondonStreets1900.jpg\", compare=True)"
2597
+ ]
2598
+ },
2599
+ {
2600
+ "cell_type": "code",
2601
+ "execution_count": null,
2602
+ "metadata": {},
2603
+ "outputs": [],
2604
+ "source": [
2605
+ "vis.plot_transformed_image(\"test_images/MuffinManlLondon1910.jpg\", render_factor=45, compare=True)"
2606
+ ]
2607
+ },
2608
+ {
2609
+ "cell_type": "code",
2610
+ "execution_count": null,
2611
+ "metadata": {},
2612
+ "outputs": [],
2613
+ "source": [
2614
+ "vis.plot_transformed_image(\"test_images/NativeCouple1912.jpg\", render_factor=21, compare=True)"
2615
+ ]
2616
+ },
2617
+ {
2618
+ "cell_type": "code",
2619
+ "execution_count": null,
2620
+ "metadata": {},
2621
+ "outputs": [],
2622
+ "source": [
2623
+ "vis.plot_transformed_image(\"test_images/NewspaperCivilWar1863.jpg\", compare=True)"
2624
+ ]
2625
+ },
2626
+ {
2627
+ "cell_type": "code",
2628
+ "execution_count": null,
2629
+ "metadata": {},
2630
+ "outputs": [],
2631
+ "source": [
2632
+ "vis.plot_transformed_image(\"test_images/PaddingtonStationLondon1907.jpg\", render_factor=45, compare=True)"
2633
+ ]
2634
+ },
2635
+ {
2636
+ "cell_type": "code",
2637
+ "execution_count": null,
2638
+ "metadata": {},
2639
+ "outputs": [],
2640
+ "source": [
2641
+ "vis.plot_transformed_image(\"test_images/Paris1899StreetDig.jpg\", compare=True)"
2642
+ ]
2643
+ },
2644
+ {
2645
+ "cell_type": "code",
2646
+ "execution_count": null,
2647
+ "metadata": {},
2648
+ "outputs": [],
2649
+ "source": [
2650
+ "vis.plot_transformed_image(\"test_images/Paris1926.jpg\", compare=True)"
2651
+ ]
2652
+ },
2653
+ {
2654
+ "cell_type": "code",
2655
+ "execution_count": null,
2656
+ "metadata": {},
2657
+ "outputs": [],
2658
+ "source": [
2659
+ "vis.plot_transformed_image(\"test_images/ParisWomenFurs1920s.jpg\", render_factor=21, compare=True)"
2660
+ ]
2661
+ },
2662
+ {
2663
+ "cell_type": "code",
2664
+ "execution_count": null,
2665
+ "metadata": {},
2666
+ "outputs": [],
2667
+ "source": [
2668
+ "vis.plot_transformed_image(\"test_images/PeddlerParis1899.jpg\", compare=True)"
2669
+ ]
2670
+ },
2671
+ {
2672
+ "cell_type": "code",
2673
+ "execution_count": null,
2674
+ "metadata": {},
2675
+ "outputs": [],
2676
+ "source": [
2677
+ "vis.plot_transformed_image(\"test_images/SchoolKidsConnemaraIreland1901.jpg\", compare=True)"
2678
+ ]
2679
+ },
2680
+ {
2681
+ "cell_type": "code",
2682
+ "execution_count": null,
2683
+ "metadata": {},
2684
+ "outputs": [],
2685
+ "source": [
2686
+ "vis.plot_transformed_image(\"test_images/SecondHandClothesLondonLate1800s.jpg\", render_factor=33, compare=True)"
2687
+ ]
2688
+ },
2689
+ {
2690
+ "cell_type": "code",
2691
+ "execution_count": null,
2692
+ "metadata": {},
2693
+ "outputs": [],
2694
+ "source": [
2695
+ "vis.plot_transformed_image(\"test_images/SoapBoxRacerParis1920s.jpg\", render_factor=40, compare=True)"
2696
+ ]
2697
+ },
2698
+ {
2699
+ "cell_type": "code",
2700
+ "execution_count": null,
2701
+ "metadata": {},
2702
+ "outputs": [],
2703
+ "source": [
2704
+ "vis.plot_transformed_image(\"test_images/SoccerMotorcycles1923London.jpg\", compare=True)"
2705
+ ]
2706
+ },
2707
+ {
2708
+ "cell_type": "code",
2709
+ "execution_count": null,
2710
+ "metadata": {},
2711
+ "outputs": [],
2712
+ "source": [
2713
+ "vis.plot_transformed_image(\"test_images/WalkingLibraryLondon1930.jpg\", compare=True)"
2714
+ ]
2715
+ },
2716
+ {
2717
+ "cell_type": "code",
2718
+ "execution_count": null,
2719
+ "metadata": {},
2720
+ "outputs": [],
2721
+ "source": [
2722
+ "vis.plot_transformed_image(\"test_images/LondonStreetDoctor1877.png\", render_factor=38, compare=True)"
2723
+ ]
2724
+ },
2725
+ {
2726
+ "cell_type": "code",
2727
+ "execution_count": null,
2728
+ "metadata": {},
2729
+ "outputs": [],
2730
+ "source": [
2731
+ "vis.plot_transformed_image(\"test_images/jacksonville.jpg\", compare=True)"
2732
+ ]
2733
+ },
2734
+ {
2735
+ "cell_type": "code",
2736
+ "execution_count": null,
2737
+ "metadata": {},
2738
+ "outputs": [],
2739
+ "source": [
2740
+ "vis.plot_transformed_image(\"test_images/ZebraCarriageLondon1900.jpg\", compare=True)"
2741
+ ]
2742
+ },
2743
+ {
2744
+ "cell_type": "code",
2745
+ "execution_count": null,
2746
+ "metadata": {},
2747
+ "outputs": [],
2748
+ "source": [
2749
+ "vis.plot_transformed_image(\"test_images/StreetGramaphonePlayerLondon1920s.png\", compare=True)"
2750
+ ]
2751
+ },
2752
+ {
2753
+ "cell_type": "code",
2754
+ "execution_count": null,
2755
+ "metadata": {},
2756
+ "outputs": [],
2757
+ "source": [
2758
+ "vis.plot_transformed_image(\"test_images/YaleBranchBarnardsExpress.jpg\", compare=True)"
2759
+ ]
2760
+ },
2761
+ {
2762
+ "cell_type": "code",
2763
+ "execution_count": null,
2764
+ "metadata": {},
2765
+ "outputs": [],
2766
+ "source": [
2767
+ "vis.plot_transformed_image(\"test_images/SynagogueInterior.PNG\", compare=True)"
2768
+ ]
2769
+ },
2770
+ {
2771
+ "cell_type": "code",
2772
+ "execution_count": null,
2773
+ "metadata": {},
2774
+ "outputs": [],
2775
+ "source": [
2776
+ "vis.plot_transformed_image(\"test_images/ArmisticeDay1918.jpg\", compare=True)"
2777
+ ]
2778
+ },
2779
+ {
2780
+ "cell_type": "code",
2781
+ "execution_count": null,
2782
+ "metadata": {},
2783
+ "outputs": [],
2784
+ "source": [
2785
+ "vis.plot_transformed_image(\"test_images/FlyingMachinesParis1909.jpg\", render_factor=25, compare=True)"
2786
+ ]
2787
+ },
2788
+ {
2789
+ "cell_type": "code",
2790
+ "execution_count": null,
2791
+ "metadata": {},
2792
+ "outputs": [],
2793
+ "source": [
2794
+ "vis.plot_transformed_image(\"test_images/GreatAunt1920.jpg\", compare=True)"
2795
+ ]
2796
+ },
2797
+ {
2798
+ "cell_type": "code",
2799
+ "execution_count": null,
2800
+ "metadata": {},
2801
+ "outputs": [],
2802
+ "source": [
2803
+ "vis.plot_transformed_image(\"test_images/NewBrunswick1915.jpg\", compare=True)"
2804
+ ]
2805
+ },
2806
+ {
2807
+ "cell_type": "code",
2808
+ "execution_count": null,
2809
+ "metadata": {},
2810
+ "outputs": [],
2811
+ "source": [
2812
+ "vis.plot_transformed_image(\"test_images/ShoeMakerLate1800s.jpg\", compare=True)"
2813
+ ]
2814
+ },
2815
+ {
2816
+ "cell_type": "code",
2817
+ "execution_count": null,
2818
+ "metadata": {},
2819
+ "outputs": [],
2820
+ "source": [
2821
+ "vis.plot_transformed_image(\"test_images/SpottedBull1908.jpg\", compare=True)"
2822
+ ]
2823
+ },
2824
+ {
2825
+ "cell_type": "code",
2826
+ "execution_count": null,
2827
+ "metadata": {},
2828
+ "outputs": [],
2829
+ "source": [
2830
+ "vis.plot_transformed_image(\"test_images/TouristsGermany1904.jpg\", render_factor=35, compare=True)"
2831
+ ]
2832
+ },
2833
+ {
2834
+ "cell_type": "code",
2835
+ "execution_count": null,
2836
+ "metadata": {},
2837
+ "outputs": [],
2838
+ "source": [
2839
+ "vis.plot_transformed_image(\"test_images/TunisianStudents1914.jpg\", compare=True)"
2840
+ ]
2841
+ },
2842
+ {
2843
+ "cell_type": "code",
2844
+ "execution_count": null,
2845
+ "metadata": {},
2846
+ "outputs": [],
2847
+ "source": [
2848
+ "vis.plot_transformed_image(\"test_images/Yorktown1862.jpg\", compare=True)"
2849
+ ]
2850
+ },
2851
+ {
2852
+ "cell_type": "code",
2853
+ "execution_count": null,
2854
+ "metadata": {},
2855
+ "outputs": [],
2856
+ "source": [
2857
+ "vis.plot_transformed_image(\"test_images/LondonFashion1911.png\", compare=True)"
2858
+ ]
2859
+ },
2860
+ {
2861
+ "cell_type": "code",
2862
+ "execution_count": null,
2863
+ "metadata": {},
2864
+ "outputs": [],
2865
+ "source": [
2866
+ "vis.plot_transformed_image(\"test_images/1939GypsyKids.jpg\", render_factor=37, compare=True)"
2867
+ ]
2868
+ },
2869
+ {
2870
+ "cell_type": "code",
2871
+ "execution_count": null,
2872
+ "metadata": {},
2873
+ "outputs": [],
2874
+ "source": [
2875
+ "vis.plot_transformed_image(\"test_images/1936OpiumShanghai.jpg\", compare=True)"
2876
+ ]
2877
+ },
2878
+ {
2879
+ "cell_type": "code",
2880
+ "execution_count": null,
2881
+ "metadata": {},
2882
+ "outputs": [],
2883
+ "source": [
2884
+ "vis.plot_transformed_image(\"test_images/1923HollandTunnel.jpg\", compare=True)"
2885
+ ]
2886
+ },
2887
+ {
2888
+ "cell_type": "code",
2889
+ "execution_count": null,
2890
+ "metadata": {},
2891
+ "outputs": [],
2892
+ "source": [
2893
+ "vis.plot_transformed_image(\"test_images/1939YakimaWAGirl.jpg\", compare=True)"
2894
+ ]
2895
+ },
2896
+ {
2897
+ "cell_type": "code",
2898
+ "execution_count": null,
2899
+ "metadata": {},
2900
+ "outputs": [],
2901
+ "source": [
2902
+ "vis.plot_transformed_image(\"test_images/GoldenGateConstruction.jpg\", render_factor=35, compare=True)"
2903
+ ]
2904
+ },
2905
+ {
2906
+ "cell_type": "code",
2907
+ "execution_count": null,
2908
+ "metadata": {},
2909
+ "outputs": [],
2910
+ "source": [
2911
+ "vis.plot_transformed_image(\"test_images/PostCivilWarAncestors.jpg\", compare=True)"
2912
+ ]
2913
+ },
2914
+ {
2915
+ "cell_type": "code",
2916
+ "execution_count": null,
2917
+ "metadata": {},
2918
+ "outputs": [],
2919
+ "source": [
2920
+ "vis.plot_transformed_image(\"test_images/1939SewingBike.png\", compare=True)"
2921
+ ]
2922
+ },
2923
+ {
2924
+ "cell_type": "code",
2925
+ "execution_count": null,
2926
+ "metadata": {},
2927
+ "outputs": [],
2928
+ "source": [
2929
+ "vis.plot_transformed_image(\"test_images/1930MaineSchoolBus.jpg\", compare=True)"
2930
+ ]
2931
+ },
2932
+ {
2933
+ "cell_type": "code",
2934
+ "execution_count": null,
2935
+ "metadata": {},
2936
+ "outputs": [],
2937
+ "source": [
2938
+ "vis.plot_transformed_image(\"test_images/1913NewYorkConstruction.jpg\", compare=True)"
2939
+ ]
2940
+ },
2941
+ {
2942
+ "cell_type": "code",
2943
+ "execution_count": null,
2944
+ "metadata": {},
2945
+ "outputs": [],
2946
+ "source": [
2947
+ "vis.plot_transformed_image(\"test_images/1945HiroshimaChild.jpg\", compare=True)"
2948
+ ]
2949
+ },
2950
+ {
2951
+ "cell_type": "code",
2952
+ "execution_count": null,
2953
+ "metadata": {},
2954
+ "outputs": [],
2955
+ "source": [
2956
+ "vis.plot_transformed_image(\"test_images/1941GeorgiaFarmhouse.jpg\", render_factor=43, compare=True)"
2957
+ ]
2958
+ },
2959
+ {
2960
+ "cell_type": "code",
2961
+ "execution_count": null,
2962
+ "metadata": {},
2963
+ "outputs": [],
2964
+ "source": [
2965
+ "vis.plot_transformed_image(\"test_images/1934UmbriaItaly.jpg\", render_factor=21) "
2966
+ ]
2967
+ },
2968
+ {
2969
+ "cell_type": "code",
2970
+ "execution_count": null,
2971
+ "metadata": {},
2972
+ "outputs": [],
2973
+ "source": [
2974
+ "vis.plot_transformed_image(\"test_images/1900sLadiesTeaParty.jpg\", compare=True)"
2975
+ ]
2976
+ },
2977
+ {
2978
+ "cell_type": "code",
2979
+ "execution_count": null,
2980
+ "metadata": {},
2981
+ "outputs": [],
2982
+ "source": [
2983
+ "vis.plot_transformed_image(\"test_images/1919WWIAviationOxygenMask.jpg\", compare=True)"
2984
+ ]
2985
+ },
2986
+ {
2987
+ "cell_type": "code",
2988
+ "execution_count": null,
2989
+ "metadata": {},
2990
+ "outputs": [],
2991
+ "source": [
2992
+ "vis.plot_transformed_image(\"test_images/1900NJThanksgiving.jpg\", compare=True)"
2993
+ ]
2994
+ },
2995
+ {
2996
+ "cell_type": "code",
2997
+ "execution_count": null,
2998
+ "metadata": {},
2999
+ "outputs": [],
3000
+ "source": [
3001
+ "vis.plot_transformed_image(\"test_images/1940Connecticut.jpg\", render_factor=43, compare=True)"
3002
+ ]
3003
+ },
3004
+ {
3005
+ "cell_type": "code",
3006
+ "execution_count": null,
3007
+ "metadata": {},
3008
+ "outputs": [],
3009
+ "source": [
3010
+ "vis.plot_transformed_image(\"test_images/1940Connecticut.jpg\", render_factor=i, compare=True)"
3011
+ ]
3012
+ },
3013
+ {
3014
+ "cell_type": "code",
3015
+ "execution_count": null,
3016
+ "metadata": {},
3017
+ "outputs": [],
3018
+ "source": [
3019
+ "vis.plot_transformed_image(\"test_images/1911ThanksgivingMaskers.jpg\", render_factor=35, compare=True)"
3020
+ ]
3021
+ },
3022
+ {
3023
+ "cell_type": "code",
3024
+ "execution_count": null,
3025
+ "metadata": {},
3026
+ "outputs": [],
3027
+ "source": [
3028
+ "vis.plot_transformed_image(\"test_images/1910ThanksgivingMaskersII.jpg\", compare=True)"
3029
+ ]
3030
+ },
3031
+ {
3032
+ "cell_type": "code",
3033
+ "execution_count": null,
3034
+ "metadata": {},
3035
+ "outputs": [],
3036
+ "source": [
3037
+ "vis.plot_transformed_image(\"test_images/1936PetToad.jpg\", compare=True)"
3038
+ ]
3039
+ },
3040
+ {
3041
+ "cell_type": "code",
3042
+ "execution_count": null,
3043
+ "metadata": {},
3044
+ "outputs": [],
3045
+ "source": [
3046
+ "vis.plot_transformed_image(\"test_images/1908RookeriesLondon.jpg\", compare=True)"
3047
+ ]
3048
+ },
3049
+ {
3050
+ "cell_type": "code",
3051
+ "execution_count": null,
3052
+ "metadata": {},
3053
+ "outputs": [],
3054
+ "source": [
3055
+ "vis.plot_transformed_image(\"test_images/1890sChineseImmigrants.jpg\", render_factor=25, compare=True)"
3056
+ ]
3057
+ },
3058
+ {
3059
+ "cell_type": "code",
3060
+ "execution_count": null,
3061
+ "metadata": {},
3062
+ "outputs": [],
3063
+ "source": [
3064
+ "vis.plot_transformed_image(\"test_images/1897VancouverAmberlamps.jpg\", compare=True)"
3065
+ ]
3066
+ },
3067
+ {
3068
+ "cell_type": "code",
3069
+ "execution_count": null,
3070
+ "metadata": {},
3071
+ "outputs": [],
3072
+ "source": [
3073
+ "vis.plot_transformed_image(\"test_images/1929VictorianCosplayLondon.jpg\", render_factor=35, compare=True)"
3074
+ ]
3075
+ },
3076
+ {
3077
+ "cell_type": "code",
3078
+ "execution_count": null,
3079
+ "metadata": {},
3080
+ "outputs": [],
3081
+ "source": [
3082
+ "vis.plot_transformed_image(\"test_images/1959ParisFriends.png\", render_factor=40, compare=True)"
3083
+ ]
3084
+ },
3085
+ {
3086
+ "cell_type": "code",
3087
+ "execution_count": null,
3088
+ "metadata": {},
3089
+ "outputs": [],
3090
+ "source": [
3091
+ "vis.plot_transformed_image(\"test_images/1925GypsyCampMaryland.jpg\", render_factor=40, compare=True)"
3092
+ ]
3093
+ },
3094
+ {
3095
+ "cell_type": "code",
3096
+ "execution_count": null,
3097
+ "metadata": {},
3098
+ "outputs": [],
3099
+ "source": [
3100
+ "vis.plot_transformed_image(\"test_images/1941PoolTableGeorgia.jpg\", render_factor=45, compare=True)"
3101
+ ]
3102
+ },
3103
+ {
3104
+ "cell_type": "code",
3105
+ "execution_count": null,
3106
+ "metadata": {},
3107
+ "outputs": [],
3108
+ "source": [
3109
+ "vis.plot_transformed_image(\"test_images/1900ParkDog.jpg\", compare=True)"
3110
+ ]
3111
+ },
3112
+ {
3113
+ "cell_type": "code",
3114
+ "execution_count": null,
3115
+ "metadata": {},
3116
+ "outputs": [],
3117
+ "source": [
3118
+ "vis.plot_transformed_image(\"test_images/1886Hoop.jpg\", compare=True)"
3119
+ ]
3120
+ },
3121
+ {
3122
+ "cell_type": "code",
3123
+ "execution_count": null,
3124
+ "metadata": {},
3125
+ "outputs": [],
3126
+ "source": [
3127
+ "vis.plot_transformed_image(\"test_images/1950sLondonPoliceChild.jpg\", compare=True)"
3128
+ ]
3129
+ },
3130
+ {
3131
+ "cell_type": "code",
3132
+ "execution_count": null,
3133
+ "metadata": {},
3134
+ "outputs": [],
3135
+ "source": [
3136
+ "vis.plot_transformed_image(\"test_images/1886ProspectPark.jpg\", compare=True)"
3137
+ ]
3138
+ },
3139
+ {
3140
+ "cell_type": "code",
3141
+ "execution_count": null,
3142
+ "metadata": {},
3143
+ "outputs": [],
3144
+ "source": [
3145
+ "vis.plot_transformed_image(\"test_images/1930sRooftopPoland.jpg\", compare=True)"
3146
+ ]
3147
+ },
3148
+ {
3149
+ "cell_type": "code",
3150
+ "execution_count": null,
3151
+ "metadata": {},
3152
+ "outputs": [],
3153
+ "source": [
3154
+ "vis.plot_transformed_image(\"test_images/1919RevereBeach.jpg\", compare=True)"
3155
+ ]
3156
+ },
3157
+ {
3158
+ "cell_type": "code",
3159
+ "execution_count": null,
3160
+ "metadata": {},
3161
+ "outputs": [],
3162
+ "source": [
3163
+ "vis.plot_transformed_image(\"test_images/1936ParisCafe.jpg\", render_factor=46, compare=True)"
3164
+ ]
3165
+ },
3166
+ {
3167
+ "cell_type": "code",
3168
+ "execution_count": null,
3169
+ "metadata": {},
3170
+ "outputs": [],
3171
+ "source": [
3172
+ "vis.plot_transformed_image(\"test_images/1902FrenchYellowBellies.jpg\", compare=True)"
3173
+ ]
3174
+ },
3175
+ {
3176
+ "cell_type": "code",
3177
+ "execution_count": null,
3178
+ "metadata": {},
3179
+ "outputs": [],
3180
+ "source": [
3181
+ "vis.plot_transformed_image(\"test_images/1940PAFamily.jpg\", render_factor=42, compare=True)"
3182
+ ]
3183
+ },
3184
+ {
3185
+ "cell_type": "code",
3186
+ "execution_count": null,
3187
+ "metadata": {},
3188
+ "outputs": [],
3189
+ "source": [
3190
+ "vis.plot_transformed_image(\"test_images/1910Finland.jpg\", render_factor=40, compare=True)"
3191
+ ]
3192
+ },
3193
+ {
3194
+ "cell_type": "code",
3195
+ "execution_count": null,
3196
+ "metadata": {},
3197
+ "outputs": [],
3198
+ "source": [
3199
+ "vis.plot_transformed_image(\"test_images/ZebraCarriageLondon1900.jpg\", compare=True)"
3200
+ ]
3201
+ },
3202
+ {
3203
+ "cell_type": "code",
3204
+ "execution_count": null,
3205
+ "metadata": {},
3206
+ "outputs": [],
3207
+ "source": [
3208
+ "vis.plot_transformed_image(\"test_images/1904ChineseMan.jpg\", compare=True)"
3209
+ ]
3210
+ },
3211
+ {
3212
+ "cell_type": "code",
3213
+ "execution_count": null,
3214
+ "metadata": {},
3215
+ "outputs": [],
3216
+ "source": [
3217
+ "vis.plot_transformed_image(\"test_images/CrystalPalaceLondon1854.PNG\", compare=True)"
3218
+ ]
3219
+ },
3220
+ {
3221
+ "cell_type": "code",
3222
+ "execution_count": null,
3223
+ "metadata": {},
3224
+ "outputs": [],
3225
+ "source": [
3226
+ "vis.plot_transformed_image(\"test_images/James1.jpg\", render_factor=15, compare=True)"
3227
+ ]
3228
+ },
3229
+ {
3230
+ "cell_type": "code",
3231
+ "execution_count": null,
3232
+ "metadata": {},
3233
+ "outputs": [],
3234
+ "source": [
3235
+ "vis.plot_transformed_image(\"test_images/James2.jpg\", render_factor=20, compare=True)"
3236
+ ]
3237
+ },
3238
+ {
3239
+ "cell_type": "code",
3240
+ "execution_count": null,
3241
+ "metadata": {},
3242
+ "outputs": [],
3243
+ "source": [
3244
+ "vis.plot_transformed_image(\"test_images/James3.jpg\", render_factor=19, compare=True)"
3245
+ ]
3246
+ },
3247
+ {
3248
+ "cell_type": "code",
3249
+ "execution_count": null,
3250
+ "metadata": {},
3251
+ "outputs": [],
3252
+ "source": [
3253
+ "vis.plot_transformed_image(\"test_images/James4.jpg\", render_factor=30, compare=True)"
3254
+ ]
3255
+ },
3256
+ {
3257
+ "cell_type": "code",
3258
+ "execution_count": null,
3259
+ "metadata": {},
3260
+ "outputs": [],
3261
+ "source": [
3262
+ "vis.plot_transformed_image(\"test_images/James5.jpg\", render_factor=32, compare=True)"
3263
+ ]
3264
+ },
3265
+ {
3266
+ "cell_type": "code",
3267
+ "execution_count": null,
3268
+ "metadata": {},
3269
+ "outputs": [],
3270
+ "source": [
3271
+ "vis.plot_transformed_image(\"test_images/James6.jpg\", render_factor=28, compare=True)"
3272
+ ]
3273
+ },
3274
+ {
3275
+ "cell_type": "code",
3276
+ "execution_count": null,
3277
+ "metadata": {},
3278
+ "outputs": [],
3279
+ "source": []
3280
+ },
3281
+ {
3282
+ "cell_type": "code",
3283
+ "execution_count": null,
3284
+ "metadata": {},
3285
+ "outputs": [],
3286
+ "source": []
3287
+ }
3288
+ ],
3289
+ "metadata": {
3290
+ "kernelspec": {
3291
+ "display_name": "Python 3",
3292
+ "language": "python",
3293
+ "name": "python3"
3294
+ },
3295
+ "language_info": {
3296
+ "codemirror_mode": {
3297
+ "name": "ipython",
3298
+ "version": 3
3299
+ },
3300
+ "file_extension": ".py",
3301
+ "mimetype": "text/x-python",
3302
+ "name": "python",
3303
+ "nbconvert_exporter": "python",
3304
+ "pygments_lexer": "ipython3",
3305
+ "version": "3.7.6"
3306
+ },
3307
+ "toc": {
3308
+ "colors": {
3309
+ "hover_highlight": "#DAA520",
3310
+ "navigate_num": "#000000",
3311
+ "navigate_text": "#333333",
3312
+ "running_highlight": "#FF0000",
3313
+ "selected_highlight": "#FFD700",
3314
+ "sidebar_border": "#EEEEEE",
3315
+ "wrapper_background": "#FFFFFF"
3316
+ },
3317
+ "moveMenuLeft": true,
3318
+ "nav_menu": {
3319
+ "height": "67px",
3320
+ "width": "252px"
3321
+ },
3322
+ "navigate_menu": true,
3323
+ "number_sections": true,
3324
+ "sideBar": true,
3325
+ "threshold": 4,
3326
+ "toc_cell": false,
3327
+ "toc_section_display": "block",
3328
+ "toc_window_display": false,
3329
+ "widenNotebook": false
3330
+ }
3331
+ },
3332
+ "nbformat": 4,
3333
+ "nbformat_minor": 4
3334
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2018 Jason Antic
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MANIFEST.in ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ include README.md
2
+ include LICENSE
3
+ include requirements.txt
README.md CHANGED
@@ -1,12 +1,538 @@
1
- ---
2
- title: Deoldify
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.29.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # DeOldify
3
+
4
+ **Quick Start**: The easiest way to colorize images using open source DeOldify
5
+ (for free!) is here: [DeOldify Image Colorization on DeepAI](https://deepai.org/machine-learning-model/colorizer)
6
+
7
+ **Desktop**: Want to run open source DeOldify for photos on Windows desktop?
8
+ ColorfulSoft made such a thing here and it really works- <https://github.com/ColorfulSoft/DeOldify.NET>.
9
+ No GPU required!
10
+
11
+ The **most advanced** version of DeOldify image colorization is available here,
12
+ exclusively. Try a few images for free! [MyHeritage In Color](https://www.myheritage.com/incolor)
13
+
14
+ **Huggingface Web Demo**: Integrated to [Huggingface Spaces](https://huggingface.co/spaces)
15
+ with [Gradio](https://github.com/gradio-app/gradio).
16
+ See demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/PaddlePaddle/deoldify)
17
+
18
+ ----------------------------
19
+
20
+ Image (artistic) [![Colab for images](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb)
21
+ | Video [![Colab for video](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/VideoColorizerColab.ipynb)
22
+
23
+ Having trouble with the default image colorizer, aka "artistic"? Try the
24
+ "stable" one below. It generally won't produce colors that are as interesting as
25
+ "artistic", but the glitches are noticeably reduced.
26
+
27
+ Image (stable) [![Colab for stable model](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColabStable.ipynb)
28
+
29
+ Instructions on how to use the Colabs above have been kindly provided in video
30
+ tutorial form by Old Ireland in Colour's John Breslin. It's great! Click video
31
+ image below to watch.
32
+
33
+ [![DeOldify Tutorial](http://img.youtube.com/vi/VaEl0faDw38/0.jpg)](http://www.youtube.com/watch?v=VaEl0faDw38)
34
+
35
+ Get more updates on [Twitter
36
+ ![Twitter logo](resource_images/twitter.svg)](https://twitter.com/DeOldify).
37
+
38
+ ## Table of Contents
39
+
40
+ - [About DeOldify](#about-deoldify)
41
+ - [Example Videos](#example-videos)
42
+ - [Example Images](#example-images)
43
+ - [Stuff That Should Probably Be In A Paper](#stuff-that-should-probably-be-in-a-paper)
44
+ - [How to Achieve Stable Video](#how-to-achieve-stable-video)
45
+ - [What is NoGAN?](#what-is-nogan)
46
+ - [Why Three Models?](#why-three-models)
47
+ - [Technical Details](#the-technical-details)
48
+ - [Going Forward](#this-project-going-forward)
49
+ - [Getting Started Yourself](#getting-started-yourself)
50
+ - [Easiest Approach](#easiest-approach)
51
+ - [Your Own Machine](#your-own-machine-not-as-easy)
52
+ - [Pretrained Weights](#pretrained-weights)
53
+
54
+ ## About DeOldify
55
+
56
+ Simply put, the mission of this project is to colorize and restore old images and
57
+ film footage. We'll get into the details in a bit, but first let's see some
58
+ pretty pictures and videos!
59
+
60
+ ### New and Exciting Stuff in DeOldify
61
+
62
+ - Glitches and artifacts are almost entirely eliminated
63
+ - Better skin (less zombies)
64
+ - More highly detailed and photorealistic renders
65
+ - Much less "blue bias"
66
+ - **Video** - it actually looks good!
67
+ - **NoGAN** - a new and weird but highly effective way to do GAN training for
68
+ image to image.
69
+
70
+ ## Example Videos
71
+
72
+ **Note:** Click images to watch
73
+
74
+ ### Facebook F8 Demo
75
+
76
+ [![DeOldify Facebook F8 Movie Colorization Demo](http://img.youtube.com/vi/l3UXXid04Ys/0.jpg)](http://www.youtube.com/watch?v=l3UXXid04Ys)
77
+
78
+ ### Silent Movie Examples
79
+
80
+ [![DeOldify Silent Movie Examples](http://img.youtube.com/vi/EXn-n2iqEjI/0.jpg)](http://www.youtube.com/watch?v=EXn-n2iqEjI)
81
+
82
+ ## Example Images
83
+
84
+ "Migrant Mother" by Dorothea Lange (1936)
85
+
86
+ ![Migrant Mother](https://i.imgur.com/Bt0vnke.jpg)
87
+
88
+ Woman relaxing in her livingroom in Sweden (1920)
89
+
90
+ ![Sweden Living Room](https://i.imgur.com/158d0oU.jpg)
91
+
92
+ "Toffs and Toughs" by Jimmy Sime (1937)
93
+
94
+ ![Class Divide](https://i.imgur.com/VYuav4I.jpg)
95
+
96
+ Thanksgiving Maskers (1911)
97
+
98
+ ![Thanksgiving Maskers](https://i.imgur.com/n8qVJ5c.jpg)
99
+
100
+ Glen Echo Madame Careta Gypsy Camp in Maryland (1925)
101
+
102
+ ![Gypsy Camp](https://i.imgur.com/1oYrJRI.jpg)
103
+
104
+ "Mr. and Mrs. Lemuel Smith and their younger children in their farm house,
105
+ Carroll County, Georgia." (1941)
106
+
107
+ ![Georgia Farmhouse](https://i.imgur.com/I2j8ynm.jpg)
108
+
109
+ "Building the Golden Gate Bridge" (est 1937)
110
+
111
+ ![Golden Gate Bridge](https://i.imgur.com/6SbFjfq.jpg)
112
+
113
+ > **Note:** What you might be wondering is while this render looks cool, are the
114
+ > colors accurate? The original photo certainly makes it look like the towers of
115
+ > the bridge could be white. We looked into this and it turns out the answer is
116
+ > no - the towers were already covered in red primer by this time. So that's
117
+ > something to keep in mind- historical accuracy remains a huge challenge!
118
+
119
+ "Terrasse de cafΓ©, Paris" (1925)
120
+
121
+ ![Cafe Paris](https://i.imgur.com/WprQwP5.jpg)
122
+
123
+ Norwegian Bride (est late 1890s)
124
+
125
+ ![Norwegian Bride](https://i.imgur.com/MmtvrZm.jpg)
126
+
127
+ ZitkΓ‘la-Ε Γ‘ (Lakota: Red Bird), also known as Gertrude Simmons Bonnin (1898)
128
+
129
+ ![Native Woman](https://i.imgur.com/zIGM043.jpg)
130
+
131
+ Chinese Opium Smokers (1880)
132
+
133
+ ![Opium Real](https://i.imgur.com/lVGq8Vq.jpg)
134
+
135
+ ## Stuff That Should Probably Be In A Paper
136
+
137
+ ### How to Achieve Stable Video
138
+
139
+ NoGAN training is crucial to getting the kind of stable and colorful images seen
140
+ in this iteration of DeOldify. NoGAN training combines the benefits of GAN
141
+ training (wonderful colorization) while eliminating the nasty side effects
142
+ (like flickering objects in video). Believe it or not, video is rendered using
143
+ isolated image generation without any sort of temporal modeling tacked on. The
144
+ process performs 30-60 minutes of the GAN portion of "NoGAN" training, using 1%
145
+ to 3% of imagenet data once. Then, as with still image colorization, we
146
+ "DeOldify" individual frames before rebuilding the video.
147
+
148
+ In addition to improved video stability, there is an interesting thing going on
149
+ here worth mentioning. It turns out the models I run, even different ones and
150
+ with different training structures, keep arriving at more or less the same
151
+ solution. That's even the case for the colorization of things you may think
152
+ would be arbitrary and unknowable, like the color of clothing, cars, and even
153
+ special effects (as seen in "Metropolis").
154
+
155
+ ![Metropolis Special FX](https://thumbs.gfycat.com/HeavyLoneBlowfish-size_restricted.gif)
156
+
157
+ My best guess is that the models are learning some interesting rules about how to
158
+ colorize based on subtle cues present in the black and white images that I
159
+ certainly wouldn't expect to exist. This result leads to nicely deterministic and
160
+ consistent results, and that means you don't have track model colorization
161
+ decisions because they're not arbitrary. Additionally, they seem remarkably
162
+ robust so that even in moving scenes the renders are very consistent.
163
+
164
+ ![Moving Scene Example](https://thumbs.gfycat.com/FamiliarJubilantAsp-size_restricted.gif)
165
+
166
+ Other ways to stabilize video add up as well. First, generally speaking rendering
167
+ at a higher resolution (higher render_factor) will increase stability of
168
+ colorization decisions. This stands to reason because the model has higher
169
+ fidelity image information to work with and will have a greater chance of making
170
+ the "right" decision consistently. Closely related to this is the use of
171
+ resnet101 instead of resnet34 as the backbone of the generator- objects are
172
+ detected more consistently and correctly with this. This is especially important
173
+ for getting good, consistent skin rendering. It can be particularly visually
174
+ jarring if you wind up with "zombie hands", for example.
175
+
176
+ ![Zombie Hand Example](https://thumbs.gfycat.com/ThriftyInferiorIsabellinewheatear-size_restricted.gif)
177
+
178
+ Additionally, gaussian noise augmentation during training appears to help but at
179
+ this point the conclusions as to just how much are bit more tenuous (I just
180
+ haven't formally measured this yet). This is loosely based on work done in style
181
+ transfer video, described here:
182
+ <https://medium.com/element-ai-research-lab/stabilizing-neural-style-transfer-for-video-62675e203e42>.
183
+
184
+ Special thanks go to Rani Horev for his contributions in implementing this noise
185
+ augmentation.
186
+
187
+ ### What is NoGAN?
188
+
189
+ This is a new type of GAN training that I've developed to solve some key problems
190
+ in the previous DeOldify model. It provides the benefits of GAN training while
191
+ spending minimal time doing direct GAN training. Instead, most of the training
192
+ time is spent pretraining the generator and critic separately with more
193
+ straight-forward, fast and reliable conventional methods. A key insight here is
194
+ that those more "conventional" methods generally get you most of the results you
195
+ need, and that GANs can be used to close the gap on realism. During the very
196
+ short amount of actual GAN training the generator not only gets the full
197
+ realistic colorization capabilities that used to take days of progressively
198
+ resized GAN training, but it also doesn't accrue nearly as much of the artifacts
199
+ and other ugly baggage of GANs. In fact, you can pretty much eliminate glitches
200
+ and artifacts almost entirely depending on your approach. As far as I know this
201
+ is a new technique. And it's incredibly effective.
202
+
203
+ #### Original DeOldify Model
204
+
205
+ ![Before Flicker](https://thumbs.gfycat.com/CoordinatedVeneratedHogget-size_restricted.gif)
206
+
207
+ #### NoGAN-Based DeOldify Model
208
+
209
+ ![After Flicker](https://thumbs.gfycat.com/OilyBlackArctichare-size_restricted.gif)
210
+
211
+ The steps are as follows: First train the generator in a conventional way by
212
+ itself with just the feature loss. Next, generate images from that, and train
213
+ the critic on distinguishing between those outputs and real images as a basic
214
+ binary classifier. Finally, train the generator and critic together in a GAN
215
+ setting (starting right at the target size of 192px in this case). Now for
216
+ the weird part: All the useful GAN training here only takes place within a very
217
+ small window of time. There's an inflection point where it appears the critic
218
+ has transferred everything it can that is useful to the generator. Past this
219
+ point, image quality oscillates between the best that you can get at the
220
+ inflection point, or bad in a predictable way (orangish skin, overly red lips,
221
+ etc). There appears to be no productive training after the inflection point.
222
+ And this point lies within training on just 1% to 3% of the Imagenet Data!
223
+ That amounts to about 30-60 minutes of training at 192px.
224
+
225
+ The hard part is finding this inflection point. So far, I've accomplished this
226
+ by making a whole bunch of model save checkpoints (every 0.1% of data iterated
227
+ on) and then just looking for the point where images look great before they go
228
+ totally bonkers with orange skin (always the first thing to go). Additionally,
229
+ generator rendering starts immediately getting glitchy and inconsistent at this
230
+ point, which is no good particularly for video. What I'd really like to figure
231
+ out is what the tell-tale sign of the inflection point is that can be easily
232
+ automated as an early stopping point. Unfortunately, nothing definitive is
233
+ jumping out at me yet. For one, it's happening in the middle of training loss
234
+ decreasing- not when it flattens out, which would seem more reasonable on the surface.
235
+
236
+ Another key thing about NoGAN training is you can repeat pretraining the critic
237
+ on generated images after the initial GAN training, then repeat the GAN training
238
+ itself in the same fashion. This is how I was able to get extra colorful results
239
+ with the "artistic" model. But this does come at a cost currently- the output of
240
+ the generator becomes increasingly inconsistent and you have to experiment with
241
+ render resolution (render_factor) to get the best result. But the renders are
242
+ still glitch free and way more consistent than I was ever able to achieve with
243
+ the original DeOldify model. You can do about five of these repeat cycles, give
244
+ or take, before you get diminishing returns, as far as I can tell.
245
+
246
+ Keep in mind- I haven't been entirely rigorous in figuring out what all is going
247
+ on in NoGAN- I'll save that for a paper. That means there's a good chance I'm
248
+ wrong about something. But I think it's definitely worth putting out there now
249
+ because I'm finding it very useful- it's solving basically much of my remaining
250
+ problems I had in DeOldify.
251
+
252
+ This builds upon a technique developed in collaboration with Jeremy Howard and
253
+ Sylvain Gugger for Fast.AI's Lesson 7 in version 3 of Practical Deep Learning
254
+ for Coders Part I. The particular lesson notebook can be found here:
255
+ <https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson7-superres-gan.ipynb>
256
+
257
+ ## Why Three Models?
258
+
259
+ There are now three models to choose from in DeOldify. Each of these has key
260
+ strengths and weaknesses, and so have different use cases. Video is for video
261
+ of course. But stable and artistic are both for images, and sometimes one will
262
+ do images better than the other.
263
+
264
+ More details:
265
+
266
+ - **Artistic** - This model achieves the highest quality results in image
267
+ coloration, in terms of interesting details and vibrance. The most notable
268
+ drawback however is that it's a bit of a pain to fiddle around with to get the
269
+ best results (you have to adjust the rendering resolution or render_factor to
270
+ achieve this). Additionally, the model does not do as well as stable in a few
271
+ key common scenarios- nature scenes and portraits. The model uses a resnet34
272
+ backbone on a UNet with an emphasis on depth of layers on the decoder side.
273
+ This model was trained with 5 critic pretrain/GAN cycle repeats via NoGAN, in
274
+ addition to the initial generator/critic pretrain/GAN NoGAN training, at 192px.
275
+ This adds up to a total of 32% of Imagenet data trained once (12.5 hours of
276
+ direct GAN training).
277
+
278
+ - **Stable** - This model achieves the best results with landscapes and
279
+ portraits. Notably, it produces less "zombies"- where faces or limbs stay gray
280
+ rather than being colored in properly. It generally has less weird
281
+ miscolorations than artistic, but it's also less colorful in general. This
282
+ model uses a resnet101 backbone on a UNet with an emphasis on width of layers on
283
+ the decoder side. This model was trained with 3 critic pretrain/GAN cycle
284
+ repeats via NoGAN, in addition to the initial generator/critic pretrain/GAN
285
+ NoGAN training, at 192px. This adds up to a total of 7% of Imagenet data
286
+ trained once (3 hours of direct GAN training).
287
+
288
+ - **Video** - This model is optimized for smooth, consistent and flicker-free
289
+ video. This would definitely be the least colorful of the three models, but
290
+ it's honestly not too far off from "stable". The model is the same as "stable"
291
+ in terms of architecture, but differs in training. It's trained for a mere 2.2%
292
+ of Imagenet data once at 192px, using only the initial generator/critic
293
+ pretrain/GAN NoGAN training (1 hour of direct GAN training).
294
+
295
+ Because the training of the artistic and stable models was done before the
296
+ "inflection point" of NoGAN training described in "What is NoGAN???" was
297
+ discovered, I believe this amount of training on them can be knocked down
298
+ considerably. As far as I can tell, the models were stopped at "good points"
299
+ that were well beyond where productive training was taking place. I'll be
300
+ looking into this in the future.
301
+
302
+ Ideally, eventually these three models will be consolidated into one that has all
303
+ these good desirable unified. I think there's a path there, but it's going to
304
+ require more work! So for now, the most practical solution appears to be to
305
+ maintain multiple models.
306
+
307
+ ## The Technical Details
308
+
309
+ This is a deep learning based model. More specifically, what I've done is
310
+ combined the following approaches:
311
+
312
+ ### [Self-Attention Generative Adversarial Network](https://arxiv.org/abs/1805.08318)
313
+
314
+ Except the generator is a **pretrained U-Net**, and I've just modified it to
315
+ have the spectral normalization and self-attention. It's a pretty
316
+ straightforward translation.
317
+
318
+ ### [Two Time-Scale Update Rule](https://arxiv.org/abs/1706.08500)
319
+
320
+ This is also very straightforward – it's just one to one generator/critic
321
+ iterations and higher critic learning rate.
322
+ This is modified to incorporate a "threshold" critic loss that makes sure that
323
+ the critic is "caught up" before moving on to generator training.
324
+ This is particularly useful for the "NoGAN" method described below.
325
+
326
+ ### NoGAN
327
+
328
+ There's no paper here! This is a new type of GAN training that I've developed to
329
+ solve some key problems in the previous DeOldify model.
330
+ The gist is that you get the benefits of GAN training while spending minimal time
331
+ doing direct GAN training.
332
+ More details are in the [What is NoGAN?](#what-is-nogan) section (it's a doozy).
333
+
334
+ ### Generator Loss
335
+
336
+ Loss during NoGAN learning is two parts: One is a basic Perceptual Loss (or
337
+ Feature Loss) based on VGG16 – this just biases the generator model to replicate
338
+ the input image.
339
+ The second is the loss score from the critic. For the curious – Perceptual Loss
340
+ isn't sufficient by itself to produce good results.
341
+ It tends to just encourage a bunch of brown/green/blue – you know, cheating to
342
+ the test, basically, which neural networks are really good at doing!
343
+ Key thing to realize here is that GANs essentially are learning the loss function
344
+ for you – which is really one big step closer to toward the ideal that we're
345
+ shooting for in machine learning.
346
+ And of course you generally get much better results when you get the machine to
347
+ learn something you were previously hand coding.
348
+ That's certainly the case here.
349
+
350
+ **Of note:** There's no longer any "Progressive Growing of GANs" type training
351
+ going on here. It's just not needed in lieu of the superior results obtained
352
+ by the "NoGAN" technique described above.
353
+
354
+ The beauty of this model is that it should be generally useful for all sorts of
355
+ image modification, and it should do it quite well.
356
+ What you're seeing above are the results of the colorization model, but that's
357
+ just one component in a pipeline that I'm developing with the exact same approach.
358
+
359
+ ## This Project, Going Forward
360
+
361
+ So that's the gist of this project – I'm looking to make old photos and film
362
+ look reeeeaaally good with GANs, and more importantly, make the project *useful*.
363
+ In the meantime though this is going to be my baby and I'll be actively updating
364
+ and improving the code over the foreseeable future.
365
+ I'll try to make this as user-friendly as possible, but I'm sure there's going
366
+ to be hiccups along the way.
367
+
368
+ Oh and I swear I'll document the code properly...eventually. Admittedly I'm
369
+ *one of those* people who believes in "self documenting code" (LOL).
370
+
371
+ ## Getting Started Yourself
372
+
373
+ ### Easiest Approach
374
+
375
+ The easiest way to get started is to go straight to the Colab notebooks:
376
+
377
+ Image [![Colab for images](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb)
378
+ | Video [![Colab for video](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jantic/DeOldify/blob/master/VideoColorizerColab.ipynb)
379
+
380
+ Special thanks to Matt Robinson and MarΓ­a Benavente for their image Colab notebook
381
+ contributions, and Robert Bell for the video Colab notebook work!
382
+
383
+ ### Your Own Machine (not as easy)
384
+
385
+ #### Hardware and Operating System Requirements
386
+
387
+ - **(Training Only) BEEFY Graphics card**. I'd really like to have more memory
388
+ than the 11 GB in my GeForce 1080TI (11GB). You'll have a tough time with less.
389
+ The Generators and Critic are ridiculously large.
390
+ - **(Colorization Alone) A decent graphics card**. Approximately 4GB+ memory
391
+ video cards should be sufficient.
392
+ - **Linux**. I'm using Ubuntu 18.04, and I know 16.04 works fine too. **Windows
393
+ is not supported and any issues brought up related to this will not be investigated.**
394
+
395
+ #### Easy Install
396
+
397
+ You should now be able to do a simple install with Anaconda. Here are the steps:
398
+
399
+ Open the command line and navigate to the root folder you wish to install. Then
400
+ type the following commands
401
+
402
+ ```console
403
+ git clone https://github.com/jantic/DeOldify.git DeOldify
404
+ cd DeOldify
405
+ conda env create -f environment.yml
406
+ ```
407
+
408
+ Then start running with these commands:
409
+
410
+ ```console
411
+ source activate deoldify
412
+ jupyter lab
413
+ ```
414
+
415
+ From there you can start running the notebooks in Jupyter Lab, via the url they
416
+ provide you in the console.
417
+
418
+ > **Note:** You can also now do "conda activate deoldify" if you have the latest
419
+ version of conda and in fact that's now recommended. But a lot of people don't
420
+ have that yet so I'm not going to make it the default instruction here yet.
421
+
422
+ **Alternative Install:** User daddyparodz has kindly created an installer script
423
+ for Ubuntu, and in particular Ubuntu on WSL, that may make things easier:
424
+ <https://github.com/daddyparodz/AutoDeOldifyLocal>
425
+
426
+ #### Note on test_images Folder
427
+
428
+ The images in the `test_images` folder have been removed because they were using
429
+ Git LFS and that costs a lot of money when GitHub actually charges for bandwidth
430
+ on a popular open source project (they had a billing bug for while that was
431
+ recently fixed). The notebooks that use them (the image test ones) still point
432
+ to images in that directory that I (Jason) have personally and I'd like to keep
433
+ it that way because, after all, I'm by far the primary and most active developer.
434
+ But they won't work for you. Still, those notebooks are a convenient template
435
+ for making your own tests if you're so inclined.
436
+
437
+ #### Typical training
438
+
439
+ The notebook `ColorizeTrainingWandb` has been created to log and monitor results
440
+ through [Weights & Biases](https://www.wandb.com/). You can find a description of
441
+ typical training by consulting [W&B Report](https://app.wandb.ai/borisd13/DeOldify/reports?view=borisd13%2FDeOldify).
442
+
443
+ ## Pretrained Weights
444
+
445
+ To start right away on your own machine with your own images or videos without
446
+ training the models yourself, you'll need to download the "Completed Generator
447
+ Weights" listed below and drop them in the /models/ folder.
448
+
449
+ The colorization inference notebooks should be able to guide you from here. The
450
+ notebooks to use are named ImageColorizerArtistic.ipynb,
451
+ ImageColorizerStable.ipynb, and VideoColorizer.ipynb.
452
+
453
+ ### Completed Generator Weights
454
+
455
+ - [Artistic](https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth)
456
+ - [Stable](https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0)
457
+ - [Video](https://data.deepai.org/deoldify/ColorizeVideo_gen.pth)
458
+
459
+ ### Completed Critic Weights
460
+
461
+ - [Artistic](https://www.dropbox.com/s/1qd663zbk6ntzuy/ColorizeArtistic_crit.pth?dl=0)
462
+ - [Stable](https://www.dropbox.com/s/wlqu6w88qwzcvfn/ColorizeStable_crit.pth?dl=0)
463
+ - [Video](https://www.dropbox.com/s/oyl6qmwpdvpm95d/ColorizeVideo_crit.pth?dl=0)
464
+
465
+ ### Pretrain Only Generator Weights
466
+
467
+ - [Artistic](https://www.dropbox.com/s/lbuv6911aivm9zi/ColorizeArtistic_PretrainOnly_gen.pth?dl=0)
468
+ - [Stable](https://www.dropbox.com/s/6ita1pwyqjsmx4p/ColorizeStable_PretrainOnly_gen.pth?dl=0)
469
+ - [Video](https://www.dropbox.com/s/tl4uzkwwapz68ca/ColorizeVideo_PretrainOnly_gen.pth?dl=0)
470
+
471
+ ### Pretrain Only Critic Weights
472
+
473
+ - [Artistic](https://www.dropbox.com/s/6td494kcjqfmh26/ColorizeArtistic_PretrainOnly_crit.pth?dl=0)
474
+ - [Stable](https://www.dropbox.com/s/houkmrdivbia7z8/ColorizeStable_PretrainOnly_crit.pth?dl=0)
475
+ - [Video](https://www.dropbox.com/s/80wpz16x7yudblh/ColorizeVideo_PretrainOnly_crit.pth?dl=0)
476
+
477
+ ## Want the Old DeOldify?
478
+
479
+ We suspect some of you are going to want access to the original DeOldify model
480
+ for various reasons. We have that archived here: <https://github.com/dana-kelley/DeOldify>
481
+
482
+ ## Want More?
483
+
484
+ Follow [#DeOldify](https://twitter.com/search?q=%23Deoldify) on Twitter.
485
+
486
+ ## License
487
+
488
+ All code in this repository is under the MIT license as specified by the LICENSE
489
+ file.
490
+
491
+ The model weights listed in this readme under the "Pretrained Weights" section
492
+ are trained by ourselves and are released under the MIT license.
493
+
494
+ ## A Statement on Open Source Support
495
+
496
+ We believe that open source has done a lot of good for the world.Β  After all,
497
+ DeOldify simply wouldn't exist without it. But we also believe that there needs
498
+ to be boundaries on just how much is reasonable to be expected from an open
499
+ source project maintained by just two developers.
500
+
501
+ Our stance is that we're providing the code and documentation on research that
502
+ we believe is beneficial to the world.Β  What we have provided are novel takes
503
+ on colorization, GANs, and video that are hopefully somewhat friendly for
504
+ developers and researchers to learn from and adopt. This is the culmination of
505
+ well over a year of continuous work, free for you. What wasn't free was
506
+ shouldered by us, the developers.Β  We left our jobs, bought expensive GPUs, and
507
+ had huge electric bills as a result of dedicating ourselves to this.
508
+
509
+ What we haven't provided here is a ready to use free "product" or "app", and we
510
+ don't ever intend on providing that.Β  It's going to remain a Linux based project
511
+ without Windows support, coded in Python, and requiring people to have some extra
512
+ technical background to be comfortable using it.Β  Others have stepped in with
513
+ their own apps made with DeOldify, some paid and some free, which is what we want!
514
+ We're instead focusing on what we believe we can do best- making better
515
+ commercial models that people will pay for.
516
+ Does that mean you're not getting the very best for free?Β  Of course. We simply
517
+ don't believe that we're obligated to provide that, nor is it feasible! We
518
+ compete on research and sell that.Β  Not a GUI or web service that wraps said
519
+ research- that part isn't something we're going to be great at anyways. We're not
520
+ about to shoot ourselves in the foot by giving away our actual competitive
521
+ advantage for free, quite frankly.
522
+
523
+ We're also not willing to go down the rabbit hole of providing endless, open
524
+ ended and personalized support on this open source project.Β  Our position is
525
+ this:Β  If you have the proper background and resources, the project provides
526
+ more than enough to get you started. We know this because we've seen plenty of
527
+ people using it and making money off of their own projects with it.
528
+
529
+ Thus, if you have an issue come up and it happens to be an actual bug that
530
+ having it be fixed will benefit users generally, then great- that's something
531
+ we'll be happy to look into.
532
+
533
+ In contrast, if you're asking about something that really amounts to asking for
534
+ personalized and time consuming support that won't benefit anybody else, we're
535
+ not going to help. It's simply not in our interest to do that. We have bills to
536
+ pay, after all. And if you're asking for help on something that can already be
537
+ derived from the documentation or code?Β  That's simply annoying, and we're not
538
+ going to pretend to be ok with that.
VideoColorizer.ipynb ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "#NOTE: This must be the first call in order to work properly!\n",
10
+ "from deoldify import device\n",
11
+ "from deoldify.device_id import DeviceId\n",
12
+ "#choices: CPU, GPU0...GPU7\n",
13
+ "device.set(device=DeviceId.GPU0)"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "from deoldify.visualize import *\n",
23
+ "plt.style.use('dark_background')\n",
24
+ "import warnings\n",
25
+ "warnings.filterwarnings(\"ignore\", category=UserWarning, message=\".*?Your .*? set is empty.*?\")"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "colorizer = get_video_colorizer()"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "metadata": {},
40
+ "source": [
41
+ "# Instructions\n",
42
+ "\n",
43
+ "### source_url\n",
44
+ "Type in a url hosting a video from YouTube, Imgur, Twitter, Reddit, Vimeo, etc. Many sources work! GIFs also work. Full list here: https://ytdl-org.github.io/youtube-dl/supportedsites.html NOTE: If you want to use your own video, you can set source_url to None and just upload the file to video/source/ in Jupyter. Just make sure that the file_name parameter matches the file you uploaded.\n",
45
+ "\n",
46
+ "\n",
47
+ "### file_name\n",
48
+ "Name this whatever sensible file name you want (minus extension)! It should actually exist in video/source if source_url=None\n",
49
+ "\n",
50
+ "\n",
51
+ "### render_factor\n",
52
+ "The default value of 21 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the color portion of the video is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality film in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality videos and inconsistencies (flashy render) will generally be reduced, but the colors may get slightly washed out. \n",
53
+ "\n",
54
+ "\n",
55
+ "### file_name_ext\n",
56
+ "There's no reason to changes this.\n",
57
+ "\n",
58
+ "\n",
59
+ "### result_path\n",
60
+ "Ditto- don't change.\n",
61
+ "\n",
62
+ "\n",
63
+ "### How to Download a Copy\n",
64
+ "Simply shift+right click on the displayed video and click \"Save video as...\"!\n",
65
+ "\n",
66
+ "\n",
67
+ "## Pro Tips\n",
68
+ "1. If a video takes a long time to render and you're wondering how well the frames will actually be colorized, you can preview how well the frames will be rendered at each render_factor by using the code at the bottom. Just stop the video rendering by hitting the stop button on the cell, then run that bottom cell under \"See how well render_factor values perform on a frame here\". It's not perfect and you may still need to experiment a bit especially when it comes to figuring out how to reduce frame inconsistency. But it'll go a long way in narrowing down what actually works.\n",
69
+ "\n",
70
+ "\n",
71
+ "## Troubleshooting\n",
72
+ "The video player may wind up not showing up, in which case- make sure to wait for the Jupyter cell to complete processing first (the play button will stop spinning). Then follow these alternative download instructions\n",
73
+ "\n",
74
+ "1. In the menu to the left, click Home icon.\n",
75
+ "2. By default, rendered video will be in /video/result/\n",
76
+ "\n",
77
+ "If a video you downloaded doesn't play, it's probably because the cell didn't complete processing and the video is in a half-finished state.\n",
78
+ "If you get a 'CUDA out of memory' error, you probably have the render_factor too high. The max is 44 on 11GB video cards."
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "metadata": {},
84
+ "source": [
85
+ "## Colorize!!"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "#NOTE: Max is 44 with 11GB video cards. 21 is a good default\n",
95
+ "render_factor=21\n",
96
+ "#NOTE: Make source_url None to just read from file at ./video/source/[file_name] directly without modification\n",
97
+ "source_url='https://twitter.com/silentmoviegifs/status/1116751583386034176'\n",
98
+ "file_name = 'DogShy1926'\n",
99
+ "file_name_ext = file_name + '.mp4'\n",
100
+ "result_path = None\n",
101
+ "\n",
102
+ "if source_url is not None:\n",
103
+ " result_path = colorizer.colorize_from_url(source_url, file_name_ext, render_factor=render_factor)\n",
104
+ "else:\n",
105
+ " result_path = colorizer.colorize_from_file_name(file_name_ext, render_factor=render_factor)\n",
106
+ "\n",
107
+ "show_video_in_notebook(result_path)"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "markdown",
112
+ "metadata": {},
113
+ "source": [
114
+ "## See how well render_factor values perform on a frame here"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "for i in range(10,45,2):\n",
124
+ " colorizer.vis.plot_transformed_image('video/bwframes/' + file_name + '/00001.jpg', render_factor=i, display_render_factor=True, figsize=(8,8))"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": []
133
+ }
134
+ ],
135
+ "metadata": {
136
+ "kernelspec": {
137
+ "display_name": "Python 3 (ipykernel)",
138
+ "language": "python",
139
+ "name": "python3"
140
+ },
141
+ "language_info": {
142
+ "codemirror_mode": {
143
+ "name": "ipython",
144
+ "version": 3
145
+ },
146
+ "file_extension": ".py",
147
+ "mimetype": "text/x-python",
148
+ "name": "python",
149
+ "nbconvert_exporter": "python",
150
+ "pygments_lexer": "ipython3",
151
+ "version": "3.8.0"
152
+ },
153
+ "toc": {
154
+ "colors": {
155
+ "hover_highlight": "#DAA520",
156
+ "navigate_num": "#000000",
157
+ "navigate_text": "#333333",
158
+ "running_highlight": "#FF0000",
159
+ "selected_highlight": "#FFD700",
160
+ "sidebar_border": "#EEEEEE",
161
+ "wrapper_background": "#FFFFFF"
162
+ },
163
+ "moveMenuLeft": true,
164
+ "nav_menu": {
165
+ "height": "67px",
166
+ "width": "252px"
167
+ },
168
+ "navigate_menu": true,
169
+ "number_sections": true,
170
+ "sideBar": true,
171
+ "threshold": 4,
172
+ "toc_cell": false,
173
+ "toc_section_display": "block",
174
+ "toc_window_display": false,
175
+ "widenNotebook": false
176
+ }
177
+ },
178
+ "nbformat": 4,
179
+ "nbformat_minor": 4
180
+ }
VideoColorizerColab.ipynb ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/VideoColorizerColab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {},
16
+ "source": [
17
+ "### **<font color='blue'> Video Colorizer </font>**"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {
23
+ "colab_type": "text",
24
+ "id": "663IVxfrpIAb"
25
+ },
26
+ "source": [
27
+ "#β—’ DeOldify - Colorize your own videos!\n",
28
+ "\n",
29
+ "\n",
30
+ "_FYI: This notebook is intended as a tool to colorize gifs and short videos, if you are trying to convert longer video you may hit the limit on processing space. Running the Jupyter notebook on your own machine is recommended (and faster) for larger video sizes._\n",
31
+ "\n",
32
+ "####**Credits:**\n",
33
+ "\n",
34
+ "Big special thanks to:\n",
35
+ "\n",
36
+ "Robert Bell for all his work on the video Colab notebook, and paving the way to video in DeOldify!\n",
37
+ "\n",
38
+ "Dana Kelley for doing things, breaking stuff & having an opinion on everything."
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "metadata": {
44
+ "colab_type": "text",
45
+ "id": "ZjPqTBNoohK9"
46
+ },
47
+ "source": [
48
+ "\n",
49
+ "\n",
50
+ "---\n",
51
+ "\n",
52
+ "\n",
53
+ "#β—’ Verify Correct Runtime Settings\n",
54
+ "\n",
55
+ "**<font color='#FF000'> IMPORTANT </font>**\n",
56
+ "\n",
57
+ "In the \"Runtime\" menu for the notebook window, select \"Change runtime type.\" Ensure that the following are selected:\n",
58
+ "* Runtime Type = Python 3\n",
59
+ "* Hardware Accelerator = GPU \n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "metadata": {
65
+ "colab_type": "text",
66
+ "id": "gaEJBGDlptEo"
67
+ },
68
+ "source": [
69
+ "#β—’ Git clone and install DeOldify"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {
76
+ "colab": {},
77
+ "colab_type": "code",
78
+ "id": "-T-svuHytJ-8"
79
+ },
80
+ "outputs": [],
81
+ "source": [
82
+ "!git clone https://github.com/jantic/DeOldify.git DeOldify"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "cd DeOldify"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "markdown",
96
+ "metadata": {
97
+ "colab_type": "text",
98
+ "id": "BDFjbNxaadNJ"
99
+ },
100
+ "source": [
101
+ "#β—’ Setup"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "metadata": {
108
+ "colab": {},
109
+ "colab_type": "code",
110
+ "id": "00_GcC_trpdE"
111
+ },
112
+ "outputs": [],
113
+ "source": [
114
+ "#NOTE: This must be the first call in order to work properly!\n",
115
+ "from deoldify import device\n",
116
+ "from deoldify.device_id import DeviceId\n",
117
+ "#choices: CPU, GPU0...GPU7\n",
118
+ "device.set(device=DeviceId.GPU0)\n",
119
+ "\n",
120
+ "import torch\n",
121
+ "\n",
122
+ "if not torch.cuda.is_available():\n",
123
+ " print('GPU not available.')\n",
124
+ "\n",
125
+ "from os import path"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {
132
+ "colab": {},
133
+ "colab_type": "code",
134
+ "id": "Lsx7xCXNSVt6"
135
+ },
136
+ "outputs": [],
137
+ "source": [
138
+ "!pip install -r requirements-colab.txt"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "metadata": {
145
+ "colab": {},
146
+ "colab_type": "code",
147
+ "id": "MsJa69CMwj3l"
148
+ },
149
+ "outputs": [],
150
+ "source": [
151
+ "import fastai\n",
152
+ "from deoldify.visualize import *\n",
153
+ "from pathlib import Path\n",
154
+ "torch.backends.cudnn.benchmark=True\n",
155
+ "import warnings\n",
156
+ "warnings.filterwarnings(\"ignore\", category=UserWarning, message=\".*?Your .*? set is empty.*?\")"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "!mkdir 'models'\n",
166
+ "!wget https://data.deepai.org/deoldify/ColorizeVideo_gen.pth -O ./models/ColorizeVideo_gen.pth"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "metadata": {
173
+ "colab": {},
174
+ "colab_type": "code",
175
+ "id": "tzHVnegp21hC"
176
+ },
177
+ "outputs": [],
178
+ "source": [
179
+ "colorizer = get_video_colorizer()"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "metadata": {},
185
+ "source": [
186
+ "#β—’ Instructions"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {},
192
+ "source": [
193
+ "### source_url\n",
194
+ "Type in a url hosting a video from YouTube, Imgur, Twitter, Reddit, Vimeo, etc. Many sources work! GIFs also work. Full list here: https://ytdl-org.github.io/youtube-dl/supportedsites.html NOTE: If you want to use your own video, upload it first to a site like YouTube. \n",
195
+ "\n",
196
+ "### render_factor\n",
197
+ "The default value of 21 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the color portion of the video is rendered. Lower resolution will render faster, and colors also tend to look more vibrant. Older and lower quality film in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality videos and inconsistencies (flashy render) will generally be reduced, but the colors may get slightly washed out.\n",
198
+ "\n",
199
+ "### watermarked\n",
200
+ "Selected by default, this places a watermark icon of a palette at the bottom left corner of the image. This is intended to be a standard way to convey to others viewing the image that it is colorized by AI. We want to help promote this as a standard, especially as the technology continues to improve and the distinction between real and fake becomes harder to discern. This palette watermark practice was initiated and lead by the company MyHeritage in the MyHeritage In Color feature (which uses a newer version of DeOldify than what you're using here).\n",
201
+ "\n",
202
+ "### How to Download a Copy\n",
203
+ "Simply right click on the displayed video and click \"Save video as...\"!\n",
204
+ "\n",
205
+ "## Pro Tips\n",
206
+ "1. If a video takes a long time to render and you're wondering how well the frames will actually be colorized, you can preview how well the frames will be rendered at each render_factor by using the code at the bottom. Just stop the video rendering by hitting the stop button on the cell, then run that bottom cell under \"See how well render_factor values perform on a frame here\". It's not perfect and you may still need to experiment a bit especially when it comes to figuring out how to reduce frame inconsistency. But it'll go a long way in narrowing down what actually works.\n",
207
+ "2. If videos are taking way too much time for your liking, running the Jupyter notebook VideoColorizer.ipynb on your own machine (with DeOldify installed) will generally be much faster (as long as you have the hardware for it). \n",
208
+ "3. Longer videos (running multiple minutes) are going to have a rough time on Colabs. You'll be much better off using a local install of DeOldify instead in this case.\n",
209
+ "\n",
210
+ "## Troubleshooting\n",
211
+ "The video player may wind up not showing up, in which case- make sure to wait for the Jupyter cell to complete processing first (the play button will stop spinning). Then follow these alternative download instructions\n",
212
+ "\n",
213
+ "1. In the menu to the left, click Files\n",
214
+ "2. If you don't see the 'DeOldify' folder, click \"Refresh\"\n",
215
+ "3. By default, rendered video will be in /DeOldify/video/result/\n",
216
+ "\n",
217
+ "If a video you downloaded doesn't play, it's probably because the cell didn't complete processing and the video is in a half-finished state."
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "metadata": {
223
+ "colab_type": "text",
224
+ "id": "sUQrbSYipiJn"
225
+ },
226
+ "source": [
227
+ "#β—’ Colorize!!"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": [
236
+ "source_url = '' #@param {type:\"string\"}\n",
237
+ "render_factor = 21 #@param {type: \"slider\", min: 5, max: 40}\n",
238
+ "watermarked = True #@param {type:\"boolean\"}\n",
239
+ "\n",
240
+ "if source_url is not None and source_url !='':\n",
241
+ " video_path = colorizer.colorize_from_url(source_url, 'video.mp4', render_factor, watermarked=watermarked)\n",
242
+ " show_video_in_notebook(video_path)\n",
243
+ "else:\n",
244
+ " print('Provide a video url and try again.')"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "markdown",
249
+ "metadata": {},
250
+ "source": [
251
+ "## See how well render_factor values perform on a frame here"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": [
260
+ "for i in range(10,40,2):\n",
261
+ " colorizer.vis.plot_transformed_image('video/bwframes/video/00001.jpg', render_factor=i, display_render_factor=True, figsize=(8,8))"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "markdown",
266
+ "metadata": {
267
+ "colab_type": "text",
268
+ "id": "X7Ycv_Y9xAHp"
269
+ },
270
+ "source": [
271
+ "---\n",
272
+ "#βš™ Recommended video and gif sources \n",
273
+ "* [/r/Nickelodeons/](https://www.reddit.com/r/Nickelodeons/)\n",
274
+ "* [r/silentmoviegifs](https://www.reddit.com/r/silentmoviegifs/)\n",
275
+ "* https://twitter.com/silentmoviegifs "
276
+ ]
277
+ }
278
+ ],
279
+ "metadata": {
280
+ "accelerator": "GPU",
281
+ "colab": {
282
+ "collapsed_sections": [],
283
+ "name": "VideoColorizerColab.ipynb",
284
+ "provenance": [],
285
+ "toc_visible": true,
286
+ "version": "0.3.2"
287
+ },
288
+ "kernelspec": {
289
+ "display_name": "Python 3",
290
+ "language": "python",
291
+ "name": "python3"
292
+ },
293
+ "language_info": {
294
+ "codemirror_mode": {
295
+ "name": "ipython",
296
+ "version": 3
297
+ },
298
+ "file_extension": ".py",
299
+ "mimetype": "text/x-python",
300
+ "name": "python",
301
+ "nbconvert_exporter": "python",
302
+ "pygments_lexer": "ipython3",
303
+ "version": "3.7.6"
304
+ }
305
+ },
306
+ "nbformat": 4,
307
+ "nbformat_minor": 4
308
+ }
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from deoldify import device
4
+ from deoldify.device_id import DeviceId
5
+ from deoldify.visualize import *
6
+ import tempfile
7
+
8
+ if not os.path.exists("./models/ColorizeArtistic_gen.pth"):
9
+ os.system(
10
+ 'wget "https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth" -O "./models/ColorizeArtistic_gen.pth"'
11
+ )
12
+
13
+ device.set(device=DeviceId.GPU0)
14
+
15
+ colorizer = get_image_colorizer(artistic=True)
16
+
17
+
18
+ def colorize(image):
19
+ tmp_folder = tempfile.TemporaryDirectory()
20
+ return colorizer.plot_transformed_image_from_url(
21
+ url=image,
22
+ path=f"{tmp_folder.name}/input.png",
23
+ render_factor=35,
24
+ compare=True,
25
+ results_dir=f"{tmp_folder.name}/output.png")
26
+
27
+
28
+ gr.Interface(colorize, ["text"],
29
+ ["image"]).queue(default_enabled=True).launch(show_api=True)
demo.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #NOTE: This must be the first call in order to work properly!
2
+ from deoldify import device
3
+ from deoldify.device_id import DeviceId
4
+ from deoldify.visualize import *
5
+ import argparse
6
+
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument("-i", "--input", type=str, required=True, help="the image path")
9
+ parser.add_argument("-o", "--output", type=str, required=True, help="the output path")
10
+ parser.add_argument("--artistic", action="store_true", default=True, help="enable artistic mode")
11
+ parser.add_argument("--compare", action="store_true", default=True, help="enable compare mode")
12
+ parser.add_argument("-r", "--render_factor", type=int, default=35, help="max is 45 for 11GB video cards. default is 35")
13
+ args = parser.parse_args()
14
+
15
+ device.set(device=DeviceId.GPU0)
16
+
17
+ colorizer = get_image_colorizer(artistic=args.artistic)
18
+
19
+ result_path = colorizer.plot_transformed_image(path=args.input, render_factor=args.render_factor, compare=args.compare, results_dir=args.output)
20
+
21
+ print(result_path)
deoldify/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import logging
3
+ logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
4
+ logging.getLogger().setLevel(logging.INFO)
5
+
6
+ from deoldify._device import _Device
7
+
8
+ device = _Device()
deoldify/_device.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from enum import Enum
3
+ from .device_id import DeviceId
4
+
5
+ #NOTE: This must be called first before any torch imports in order to work properly!
6
+
7
+ class DeviceException(Exception):
8
+ pass
9
+
10
+ class _Device:
11
+ def __init__(self):
12
+ self.set(DeviceId.CPU)
13
+
14
+ def is_gpu(self):
15
+ ''' Returns `True` if the current device is GPU, `False` otherwise. '''
16
+ return self.current() is not DeviceId.CPU
17
+
18
+ def current(self):
19
+ return self._current_device
20
+
21
+ def set(self, device:DeviceId):
22
+ if device == DeviceId.CPU:
23
+ os.environ['CUDA_VISIBLE_DEVICES']=''
24
+ else:
25
+ os.environ['CUDA_VISIBLE_DEVICES']=str(device.value)
26
+ import torch
27
+ torch.backends.cudnn.benchmark=False
28
+
29
+ self._current_device = device
30
+ return device
deoldify/augs.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from fastai.vision.image import TfmPixel
4
+
5
+ # Contributed by Rani Horev. Thank you!
6
+ def _noisify(
7
+ x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30
8
+ ):
9
+ if noise_range > 255 or noise_range < 0:
10
+ raise Exception("noise_range must be between 0 and 255, inclusively.")
11
+
12
+ h, w = x.shape[1:]
13
+ img_size = h * w
14
+ mult = 10000.0
15
+ pct_pixels = (
16
+ random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult
17
+ )
18
+ noise_count = int(img_size * pct_pixels)
19
+
20
+ for ii in range(noise_count):
21
+ yy = random.randrange(h)
22
+ xx = random.randrange(w)
23
+ noise = random.randrange(-noise_range, noise_range) / 255.0
24
+ x[:, yy, xx].add_(noise)
25
+
26
+ return x
27
+
28
+
29
+ noisify = TfmPixel(_noisify)
deoldify/critics.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.basic_train import Learner
2
+ from fastai.core import *
3
+ from fastai.layers import NormType, conv_layer
4
+ from fastai.torch_core import *
5
+ from fastai.vision import *
6
+ from fastai.vision.data import ImageDataBunch
7
+ from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
8
+
9
+ _conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
10
+
11
+
12
+ def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
13
+ return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
14
+
15
+
16
+ def custom_gan_critic(
17
+ n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15
18
+ ):
19
+ "Critic to train a `GAN`."
20
+ layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
21
+ for i in range(n_blocks):
22
+ layers += [
23
+ _conv(nf, nf, ks=3, stride=1),
24
+ nn.Dropout2d(p),
25
+ _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
26
+ ]
27
+ nf *= 2
28
+ layers += [
29
+ _conv(nf, nf, ks=3, stride=1),
30
+ _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
31
+ Flatten(),
32
+ ]
33
+ return nn.Sequential(*layers)
34
+
35
+
36
+ def colorize_crit_learner(
37
+ data: ImageDataBunch,
38
+ loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
39
+ nf: int = 256,
40
+ ) -> Learner:
41
+ return Learner(
42
+ data,
43
+ custom_gan_critic(nf=nf),
44
+ metrics=accuracy_thresh_expand,
45
+ loss_func=loss_critic,
46
+ wd=1e-3,
47
+ )
deoldify/dataset.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai import *
2
+ from fastai.core import *
3
+ from fastai.vision.transform import get_transforms
4
+ from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
5
+
6
+
7
+ def get_colorize_data(
8
+ sz: int,
9
+ bs: int,
10
+ crappy_path: Path,
11
+ good_path: Path,
12
+ random_seed: int = None,
13
+ keep_pct: float = 1.0,
14
+ num_workers: int = 8,
15
+ stats: tuple = imagenet_stats,
16
+ xtra_tfms=[],
17
+ ) -> ImageDataBunch:
18
+
19
+ src = (
20
+ ImageImageList.from_folder(crappy_path, convert_mode='RGB')
21
+ .use_partial_data(sample_pct=keep_pct, seed=random_seed)
22
+ .split_by_rand_pct(0.1, seed=random_seed)
23
+ )
24
+
25
+ data = (
26
+ src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
27
+ .transform(
28
+ get_transforms(
29
+ max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
30
+ ),
31
+ size=sz,
32
+ tfm_y=True,
33
+ )
34
+ .databunch(bs=bs, num_workers=num_workers, no_check=True)
35
+ .normalize(stats, do_y=True)
36
+ )
37
+
38
+ data.c = 3
39
+ return data
40
+
41
+
42
+ def get_dummy_databunch() -> ImageDataBunch:
43
+ path = Path('./dummy/')
44
+ return get_colorize_data(
45
+ sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
46
+ )
deoldify/device_id.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import IntEnum
2
+
3
+ class DeviceId(IntEnum):
4
+ GPU0 = 0,
5
+ GPU1 = 1,
6
+ GPU2 = 2,
7
+ GPU3 = 3,
8
+ GPU4 = 4,
9
+ GPU5 = 5,
10
+ GPU6 = 6,
11
+ GPU7 = 7,
12
+ CPU = 99
deoldify/filters.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.basic_data import DatasetType
2
+ from fastai.basic_train import Learner
3
+ from abc import ABC, abstractmethod
4
+ from fastai.core import *
5
+ from fastai.vision import *
6
+ from fastai.vision.image import *
7
+ from fastai.vision.data import *
8
+ from fastai import *
9
+ import cv2
10
+ from PIL import Image as PilImage
11
+ from deoldify import device as device_settings
12
+ import logging
13
+
14
+
15
+ class IFilter(ABC):
16
+ @abstractmethod
17
+ def filter(
18
+ self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
19
+ ) -> PilImage:
20
+ pass
21
+
22
+
23
+ class BaseFilter(IFilter):
24
+ def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
25
+ super().__init__()
26
+ self.learn = learn
27
+
28
+ if not device_settings.is_gpu():
29
+ self.learn.model = self.learn.model.cpu()
30
+
31
+ self.device = next(self.learn.model.parameters()).device
32
+ self.norm, self.denorm = normalize_funcs(*stats)
33
+
34
+ def _transform(self, image: PilImage) -> PilImage:
35
+ return image
36
+
37
+ def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
38
+ # a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
39
+ # I've tried padding to the square as well (reflect, symetric, constant, etc). Not as good!
40
+ targ_sz = (targ, targ)
41
+ return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)
42
+
43
+ def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
44
+ result = self._scale_to_square(orig, sz)
45
+ result = self._transform(result)
46
+ return result
47
+
48
+ def _model_process(self, orig: PilImage, sz: int) -> PilImage:
49
+ model_image = self._get_model_ready_image(orig, sz)
50
+ x = pil2tensor(model_image, np.float32)
51
+ x = x.to(self.device)
52
+ x.div_(255)
53
+ x, y = self.norm((x, x), do_x=True)
54
+
55
+ try:
56
+ result = self.learn.pred_batch(
57
+ ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
58
+ )
59
+ except RuntimeError as rerr:
60
+ if 'memory' not in str(rerr):
61
+ raise rerr
62
+ logging.warn('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')
63
+ return model_image
64
+
65
+ out = result[0]
66
+ out = self.denorm(out.px, do_x=False)
67
+ out = image2np(out * 255).astype(np.uint8)
68
+ return PilImage.fromarray(out)
69
+
70
+ def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
71
+ targ_sz = orig.size
72
+ image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
73
+ return image
74
+
75
+
76
+ class ColorizerFilter(BaseFilter):
77
+ def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
78
+ super().__init__(learn=learn, stats=stats)
79
+ self.render_base = 16
80
+
81
+ def filter(
82
+ self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
83
+ render_sz = render_factor * self.render_base
84
+ model_image = self._model_process(orig=filtered_image, sz=render_sz)
85
+ raw_color = self._unsquare(model_image, orig_image)
86
+
87
+ if post_process:
88
+ return self._post_process(raw_color, orig_image)
89
+ else:
90
+ return raw_color
91
+
92
+ def _transform(self, image: PilImage) -> PilImage:
93
+ return image.convert('LA').convert('RGB')
94
+
95
+ # This takes advantage of the fact that human eyes are much less sensitive to
96
+ # imperfections in chrominance compared to luminance. This means we can
97
+ # save a lot on memory and processing in the model, yet get a great high
98
+ # resolution result at the end. This is primarily intended just for
99
+ # inference
100
+ def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
101
+ color_np = np.asarray(raw_color)
102
+ orig_np = np.asarray(orig)
103
+ color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
104
+ # do a black and white transform first to get better luminance values
105
+ orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
106
+ hires = np.copy(orig_yuv)
107
+ hires[:, :, 1:3] = color_yuv[:, :, 1:3]
108
+ final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
109
+ final = PilImage.fromarray(final)
110
+ return final
111
+
112
+
113
+ class MasterFilter(BaseFilter):
114
+ def __init__(self, filters: List[IFilter], render_factor: int):
115
+ self.filters = filters
116
+ self.render_factor = render_factor
117
+
118
+ def filter(
119
+ self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
120
+ render_factor = self.render_factor if render_factor is None else render_factor
121
+ for filter in self.filters:
122
+ filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)
123
+
124
+ return filtered_image
deoldify/generators.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.basic_data import DataBunch
2
+ from fastai.basic_train import Learner
3
+ from fastai.layers import NormType
4
+ from fastai.torch_core import SplitFuncOrIdxList, apply_init, to_device
5
+ from fastai.vision import *
6
+ from fastai.vision.learner import cnn_config, create_body
7
+ from torch import nn
8
+ from .unet import DynamicUnetWide, DynamicUnetDeep
9
+ from .dataset import *
10
+
11
+ # Weights are implicitly read from ./models/ folder
12
+ def gen_inference_wide(
13
+ root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
14
+ data = get_dummy_databunch()
15
+ learn = gen_learner_wide(
16
+ data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
17
+ )
18
+ learn.path = root_folder
19
+ learn.load(weights_name)
20
+ learn.model.eval()
21
+ return learn
22
+
23
+
24
+ def gen_learner_wide(
25
+ data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
26
+ ) -> Learner:
27
+ return unet_learner_wide(
28
+ data,
29
+ arch=arch,
30
+ wd=1e-3,
31
+ blur=True,
32
+ norm_type=NormType.Spectral,
33
+ self_attention=True,
34
+ y_range=(-3.0, 3.0),
35
+ loss_func=gen_loss,
36
+ nf_factor=nf_factor,
37
+ )
38
+
39
+
40
+ # The code below is meant to be merged into fastaiv1 ideally
41
+ def unet_learner_wide(
42
+ data: DataBunch,
43
+ arch: Callable,
44
+ pretrained: bool = True,
45
+ blur_final: bool = True,
46
+ norm_type: Optional[NormType] = NormType,
47
+ split_on: Optional[SplitFuncOrIdxList] = None,
48
+ blur: bool = False,
49
+ self_attention: bool = False,
50
+ y_range: Optional[Tuple[float, float]] = None,
51
+ last_cross: bool = True,
52
+ bottle: bool = False,
53
+ nf_factor: int = 1,
54
+ **kwargs: Any
55
+ ) -> Learner:
56
+ "Build Unet learner from `data` and `arch`."
57
+ meta = cnn_config(arch)
58
+ body = create_body(arch, pretrained)
59
+ model = to_device(
60
+ DynamicUnetWide(
61
+ body,
62
+ n_classes=data.c,
63
+ blur=blur,
64
+ blur_final=blur_final,
65
+ self_attention=self_attention,
66
+ y_range=y_range,
67
+ norm_type=norm_type,
68
+ last_cross=last_cross,
69
+ bottle=bottle,
70
+ nf_factor=nf_factor,
71
+ ),
72
+ data.device,
73
+ )
74
+ learn = Learner(data, model, **kwargs)
75
+ learn.split(ifnone(split_on, meta['split']))
76
+ if pretrained:
77
+ learn.freeze()
78
+ apply_init(model[2], nn.init.kaiming_normal_)
79
+ return learn
80
+
81
+
82
+ # ----------------------------------------------------------------------
83
+
84
+ # Weights are implicitly read from ./models/ folder
85
+ def gen_inference_deep(
86
+ root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
87
+ data = get_dummy_databunch()
88
+ learn = gen_learner_deep(
89
+ data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
90
+ )
91
+ learn.path = root_folder
92
+ learn.load(weights_name)
93
+ learn.model.eval()
94
+ return learn
95
+
96
+
97
+ def gen_learner_deep(
98
+ data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
99
+ ) -> Learner:
100
+ return unet_learner_deep(
101
+ data,
102
+ arch,
103
+ wd=1e-3,
104
+ blur=True,
105
+ norm_type=NormType.Spectral,
106
+ self_attention=True,
107
+ y_range=(-3.0, 3.0),
108
+ loss_func=gen_loss,
109
+ nf_factor=nf_factor,
110
+ )
111
+
112
+
113
+ # The code below is meant to be merged into fastaiv1 ideally
114
+ def unet_learner_deep(
115
+ data: DataBunch,
116
+ arch: Callable,
117
+ pretrained: bool = True,
118
+ blur_final: bool = True,
119
+ norm_type: Optional[NormType] = NormType,
120
+ split_on: Optional[SplitFuncOrIdxList] = None,
121
+ blur: bool = False,
122
+ self_attention: bool = False,
123
+ y_range: Optional[Tuple[float, float]] = None,
124
+ last_cross: bool = True,
125
+ bottle: bool = False,
126
+ nf_factor: float = 1.5,
127
+ **kwargs: Any
128
+ ) -> Learner:
129
+ "Build Unet learner from `data` and `arch`."
130
+ meta = cnn_config(arch)
131
+ body = create_body(arch, pretrained)
132
+ model = to_device(
133
+ DynamicUnetDeep(
134
+ body,
135
+ n_classes=data.c,
136
+ blur=blur,
137
+ blur_final=blur_final,
138
+ self_attention=self_attention,
139
+ y_range=y_range,
140
+ norm_type=norm_type,
141
+ last_cross=last_cross,
142
+ bottle=bottle,
143
+ nf_factor=nf_factor,
144
+ ),
145
+ data.device,
146
+ )
147
+ learn = Learner(data, model, **kwargs)
148
+ learn.split(ifnone(split_on, meta['split']))
149
+ if pretrained:
150
+ learn.freeze()
151
+ apply_init(model[2], nn.init.kaiming_normal_)
152
+ return learn
153
+
154
+
155
+ # -----------------------------
deoldify/layers.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.layers import *
2
+ from fastai.torch_core import *
3
+
4
+
5
+ # The code below is meant to be merged into fastaiv1 ideally
6
+
7
+
8
+ def custom_conv_layer(
9
+ ni: int,
10
+ nf: int,
11
+ ks: int = 3,
12
+ stride: int = 1,
13
+ padding: int = None,
14
+ bias: bool = None,
15
+ is_1d: bool = False,
16
+ norm_type: Optional[NormType] = NormType.Batch,
17
+ use_activ: bool = True,
18
+ leaky: float = None,
19
+ transpose: bool = False,
20
+ init: Callable = nn.init.kaiming_normal_,
21
+ self_attention: bool = False,
22
+ extra_bn: bool = False,
23
+ ):
24
+ "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
25
+ if padding is None:
26
+ padding = (ks - 1) // 2 if not transpose else 0
27
+ bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
28
+ if bias is None:
29
+ bias = not bn
30
+ conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
31
+ conv = init_default(
32
+ conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
33
+ init,
34
+ )
35
+ if norm_type == NormType.Weight:
36
+ conv = weight_norm(conv)
37
+ elif norm_type == NormType.Spectral:
38
+ conv = spectral_norm(conv)
39
+ layers = [conv]
40
+ if use_activ:
41
+ layers.append(relu(True, leaky=leaky))
42
+ if bn:
43
+ layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
44
+ if self_attention:
45
+ layers.append(SelfAttention(nf))
46
+ return nn.Sequential(*layers)
deoldify/loss.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai import *
2
+ from fastai.core import *
3
+ from fastai.torch_core import *
4
+ from fastai.callbacks import hook_outputs
5
+ import torchvision.models as models
6
+
7
+
8
+ class FeatureLoss(nn.Module):
9
+ def __init__(self, layer_wgts=[20, 70, 10]):
10
+ super().__init__()
11
+
12
+ self.m_feat = models.vgg16_bn(True).features.cuda().eval()
13
+ requires_grad(self.m_feat, False)
14
+ blocks = [
15
+ i - 1
16
+ for i, o in enumerate(children(self.m_feat))
17
+ if isinstance(o, nn.MaxPool2d)
18
+ ]
19
+ layer_ids = blocks[2:5]
20
+ self.loss_features = [self.m_feat[i] for i in layer_ids]
21
+ self.hooks = hook_outputs(self.loss_features, detach=False)
22
+ self.wgts = layer_wgts
23
+ self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
24
+ self.base_loss = F.l1_loss
25
+
26
+ def _make_features(self, x, clone=False):
27
+ self.m_feat(x)
28
+ return [(o.clone() if clone else o) for o in self.hooks.stored]
29
+
30
+ def forward(self, input, target):
31
+ out_feat = self._make_features(target, clone=True)
32
+ in_feat = self._make_features(input)
33
+ self.feat_losses = [self.base_loss(input, target)]
34
+ self.feat_losses += [
35
+ self.base_loss(f_in, f_out) * w
36
+ for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
37
+ ]
38
+
39
+ self.metrics = dict(zip(self.metric_names, self.feat_losses))
40
+ return sum(self.feat_losses)
41
+
42
+ def __del__(self):
43
+ self.hooks.remove()
44
+
45
+
46
+ # Refactored code, originally from https://github.com/VinceMarron/style_transfer
47
+ class WassFeatureLoss(nn.Module):
48
+ def __init__(self, layer_wgts=[5, 15, 2], wass_wgts=[3.0, 0.7, 0.01]):
49
+ super().__init__()
50
+ self.m_feat = models.vgg16_bn(True).features.cuda().eval()
51
+ requires_grad(self.m_feat, False)
52
+ blocks = [
53
+ i - 1
54
+ for i, o in enumerate(children(self.m_feat))
55
+ if isinstance(o, nn.MaxPool2d)
56
+ ]
57
+ layer_ids = blocks[2:5]
58
+ self.loss_features = [self.m_feat[i] for i in layer_ids]
59
+ self.hooks = hook_outputs(self.loss_features, detach=False)
60
+ self.wgts = layer_wgts
61
+ self.wass_wgts = wass_wgts
62
+ self.metric_names = (
63
+ ['pixel']
64
+ + [f'feat_{i}' for i in range(len(layer_ids))]
65
+ + [f'wass_{i}' for i in range(len(layer_ids))]
66
+ )
67
+ self.base_loss = F.l1_loss
68
+
69
+ def _make_features(self, x, clone=False):
70
+ self.m_feat(x)
71
+ return [(o.clone() if clone else o) for o in self.hooks.stored]
72
+
73
+ def _calc_2_moments(self, tensor):
74
+ chans = tensor.shape[1]
75
+ tensor = tensor.view(1, chans, -1)
76
+ n = tensor.shape[2]
77
+ mu = tensor.mean(2)
78
+ tensor = (tensor - mu[:, :, None]).squeeze(0)
79
+ # Prevents nasty bug that happens very occassionally- divide by zero. Why such things happen?
80
+ if n == 0:
81
+ return None, None
82
+ cov = torch.mm(tensor, tensor.t()) / float(n)
83
+ return mu, cov
84
+
85
+ def _get_style_vals(self, tensor):
86
+ mean, cov = self._calc_2_moments(tensor)
87
+ if mean is None:
88
+ return None, None, None
89
+ eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
90
+ eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))
91
+ root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
92
+ tr_cov = eigvals.clamp(min=0).sum()
93
+ return mean, tr_cov, root_cov
94
+
95
+ def _calc_l2wass_dist(
96
+ self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
97
+ ):
98
+ tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
99
+ mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
100
+ cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
101
+ var_overlap = torch.sqrt(
102
+ torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0) + 1e-8
103
+ ).sum()
104
+ dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
105
+ return dist
106
+
107
+ def _single_wass_loss(self, pred, targ):
108
+ mean_test, tr_cov_test, root_cov_test = targ
109
+ mean_synth, cov_synth = self._calc_2_moments(pred)
110
+ loss = self._calc_l2wass_dist(
111
+ mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
112
+ )
113
+ return loss
114
+
115
+ def forward(self, input, target):
116
+ out_feat = self._make_features(target, clone=True)
117
+ in_feat = self._make_features(input)
118
+ self.feat_losses = [self.base_loss(input, target)]
119
+ self.feat_losses += [
120
+ self.base_loss(f_in, f_out) * w
121
+ for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
122
+ ]
123
+
124
+ styles = [self._get_style_vals(i) for i in out_feat]
125
+
126
+ if styles[0][0] is not None:
127
+ self.feat_losses += [
128
+ self._single_wass_loss(f_pred, f_targ) * w
129
+ for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
130
+ ]
131
+
132
+ self.metrics = dict(zip(self.metric_names, self.feat_losses))
133
+ return sum(self.feat_losses)
134
+
135
+ def __del__(self):
136
+ self.hooks.remove()
deoldify/save.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.basic_train import Learner, LearnerCallback
2
+ from fastai.vision.gan import GANLearner
3
+
4
+
5
+ class GANSaveCallback(LearnerCallback):
6
+ """A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."""
7
+
8
+ def __init__(
9
+ self,
10
+ learn: GANLearner,
11
+ learn_gen: Learner,
12
+ filename: str,
13
+ save_iters: int = 1000,
14
+ ):
15
+ super().__init__(learn)
16
+ self.learn_gen = learn_gen
17
+ self.filename = filename
18
+ self.save_iters = save_iters
19
+
20
+ def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
21
+ if iteration == 0:
22
+ return
23
+
24
+ if iteration % self.save_iters == 0:
25
+ self._save_gen_learner(iteration=iteration, epoch=epoch)
26
+
27
+ def _save_gen_learner(self, iteration: int, epoch: int):
28
+ filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
29
+ self.learn_gen.save(filename)
deoldify/unet.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.layers import *
2
+ from .layers import *
3
+ from fastai.torch_core import *
4
+ from fastai.callbacks.hooks import *
5
+ from fastai.vision import *
6
+
7
+
8
+ # The code below is meant to be merged into fastaiv1 ideally
9
+
10
+ __all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
11
+
12
+
13
+ def _get_sfs_idxs(sizes: Sizes) -> List[int]:
14
+ "Get the indexes of the layers where the size of the activation changes."
15
+ feature_szs = [size[-1] for size in sizes]
16
+ sfs_idxs = list(
17
+ np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
18
+ )
19
+ if feature_szs[0] != feature_szs[1]:
20
+ sfs_idxs = [0] + sfs_idxs
21
+ return sfs_idxs
22
+
23
+
24
+ class CustomPixelShuffle_ICNR(nn.Module):
25
+ "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
26
+
27
+ def __init__(
28
+ self,
29
+ ni: int,
30
+ nf: int = None,
31
+ scale: int = 2,
32
+ blur: bool = False,
33
+ leaky: float = None,
34
+ **kwargs
35
+ ):
36
+ super().__init__()
37
+ nf = ifnone(nf, ni)
38
+ self.conv = custom_conv_layer(
39
+ ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
40
+ )
41
+ icnr(self.conv[0].weight)
42
+ self.shuf = nn.PixelShuffle(scale)
43
+ # Blurring over (h*w) kernel
44
+ # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
45
+ # - https://arxiv.org/abs/1806.02658
46
+ self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
47
+ self.blur = nn.AvgPool2d(2, stride=1)
48
+ self.relu = relu(True, leaky=leaky)
49
+
50
+ def forward(self, x):
51
+ x = self.shuf(self.relu(self.conv(x)))
52
+ return self.blur(self.pad(x)) if self.blur else x
53
+
54
+
55
+ class UnetBlockDeep(nn.Module):
56
+ "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
57
+
58
+ def __init__(
59
+ self,
60
+ up_in_c: int,
61
+ x_in_c: int,
62
+ hook: Hook,
63
+ final_div: bool = True,
64
+ blur: bool = False,
65
+ leaky: float = None,
66
+ self_attention: bool = False,
67
+ nf_factor: float = 1.0,
68
+ **kwargs
69
+ ):
70
+ super().__init__()
71
+ self.hook = hook
72
+ self.shuf = CustomPixelShuffle_ICNR(
73
+ up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs
74
+ )
75
+ self.bn = batchnorm_2d(x_in_c)
76
+ ni = up_in_c // 2 + x_in_c
77
+ nf = int((ni if final_div else ni // 2) * nf_factor)
78
+ self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
79
+ self.conv2 = custom_conv_layer(
80
+ nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
81
+ )
82
+ self.relu = relu(leaky=leaky)
83
+
84
+ def forward(self, up_in: Tensor) -> Tensor:
85
+ s = self.hook.stored
86
+ up_out = self.shuf(up_in)
87
+ ssh = s.shape[-2:]
88
+ if ssh != up_out.shape[-2:]:
89
+ up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
90
+ cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
91
+ return self.conv2(self.conv1(cat_x))
92
+
93
+
94
+ class DynamicUnetDeep(SequentialEx):
95
+ "Create a U-Net from a given architecture."
96
+
97
+ def __init__(
98
+ self,
99
+ encoder: nn.Module,
100
+ n_classes: int,
101
+ blur: bool = False,
102
+ blur_final=True,
103
+ self_attention: bool = False,
104
+ y_range: Optional[Tuple[float, float]] = None,
105
+ last_cross: bool = True,
106
+ bottle: bool = False,
107
+ norm_type: Optional[NormType] = NormType.Batch,
108
+ nf_factor: float = 1.0,
109
+ **kwargs
110
+ ):
111
+ extra_bn = norm_type == NormType.Spectral
112
+ imsize = (256, 256)
113
+ sfs_szs = model_sizes(encoder, size=imsize)
114
+ sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
115
+ self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
116
+ x = dummy_eval(encoder, imsize).detach()
117
+
118
+ ni = sfs_szs[-1][1]
119
+ middle_conv = nn.Sequential(
120
+ custom_conv_layer(
121
+ ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
122
+ ),
123
+ custom_conv_layer(
124
+ ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
125
+ ),
126
+ ).eval()
127
+ x = middle_conv(x)
128
+ layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
129
+
130
+ for i, idx in enumerate(sfs_idxs):
131
+ not_final = i != len(sfs_idxs) - 1
132
+ up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
133
+ do_blur = blur and (not_final or blur_final)
134
+ sa = self_attention and (i == len(sfs_idxs) - 3)
135
+ unet_block = UnetBlockDeep(
136
+ up_in_c,
137
+ x_in_c,
138
+ self.sfs[i],
139
+ final_div=not_final,
140
+ blur=blur,
141
+ self_attention=sa,
142
+ norm_type=norm_type,
143
+ extra_bn=extra_bn,
144
+ nf_factor=nf_factor,
145
+ **kwargs
146
+ ).eval()
147
+ layers.append(unet_block)
148
+ x = unet_block(x)
149
+
150
+ ni = x.shape[1]
151
+ if imsize != sfs_szs[0][-2:]:
152
+ layers.append(PixelShuffle_ICNR(ni, **kwargs))
153
+ if last_cross:
154
+ layers.append(MergeLayer(dense=True))
155
+ ni += in_channels(encoder)
156
+ layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
157
+ layers += [
158
+ custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
159
+ ]
160
+ if y_range is not None:
161
+ layers.append(SigmoidRange(*y_range))
162
+ super().__init__(*layers)
163
+
164
+ def __del__(self):
165
+ if hasattr(self, "sfs"):
166
+ self.sfs.remove()
167
+
168
+
169
+ # ------------------------------------------------------
170
+ class UnetBlockWide(nn.Module):
171
+ "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
172
+
173
+ def __init__(
174
+ self,
175
+ up_in_c: int,
176
+ x_in_c: int,
177
+ n_out: int,
178
+ hook: Hook,
179
+ final_div: bool = True,
180
+ blur: bool = False,
181
+ leaky: float = None,
182
+ self_attention: bool = False,
183
+ **kwargs
184
+ ):
185
+ super().__init__()
186
+ self.hook = hook
187
+ up_out = x_out = n_out // 2
188
+ self.shuf = CustomPixelShuffle_ICNR(
189
+ up_in_c, up_out, blur=blur, leaky=leaky, **kwargs
190
+ )
191
+ self.bn = batchnorm_2d(x_in_c)
192
+ ni = up_out + x_in_c
193
+ self.conv = custom_conv_layer(
194
+ ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs
195
+ )
196
+ self.relu = relu(leaky=leaky)
197
+
198
+ def forward(self, up_in: Tensor) -> Tensor:
199
+ s = self.hook.stored
200
+ up_out = self.shuf(up_in)
201
+ ssh = s.shape[-2:]
202
+ if ssh != up_out.shape[-2:]:
203
+ up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
204
+ cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
205
+ return self.conv(cat_x)
206
+
207
+
208
+ class DynamicUnetWide(SequentialEx):
209
+ "Create a U-Net from a given architecture."
210
+
211
+ def __init__(
212
+ self,
213
+ encoder: nn.Module,
214
+ n_classes: int,
215
+ blur: bool = False,
216
+ blur_final=True,
217
+ self_attention: bool = False,
218
+ y_range: Optional[Tuple[float, float]] = None,
219
+ last_cross: bool = True,
220
+ bottle: bool = False,
221
+ norm_type: Optional[NormType] = NormType.Batch,
222
+ nf_factor: int = 1,
223
+ **kwargs
224
+ ):
225
+
226
+ nf = 512 * nf_factor
227
+ extra_bn = norm_type == NormType.Spectral
228
+ imsize = (256, 256)
229
+ sfs_szs = model_sizes(encoder, size=imsize)
230
+ sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
231
+ self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
232
+ x = dummy_eval(encoder, imsize).detach()
233
+
234
+ ni = sfs_szs[-1][1]
235
+ middle_conv = nn.Sequential(
236
+ custom_conv_layer(
237
+ ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
238
+ ),
239
+ custom_conv_layer(
240
+ ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
241
+ ),
242
+ ).eval()
243
+ x = middle_conv(x)
244
+ layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
245
+
246
+ for i, idx in enumerate(sfs_idxs):
247
+ not_final = i != len(sfs_idxs) - 1
248
+ up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
249
+ do_blur = blur and (not_final or blur_final)
250
+ sa = self_attention and (i == len(sfs_idxs) - 3)
251
+
252
+ n_out = nf if not_final else nf // 2
253
+
254
+ unet_block = UnetBlockWide(
255
+ up_in_c,
256
+ x_in_c,
257
+ n_out,
258
+ self.sfs[i],
259
+ final_div=not_final,
260
+ blur=blur,
261
+ self_attention=sa,
262
+ norm_type=norm_type,
263
+ extra_bn=extra_bn,
264
+ **kwargs
265
+ ).eval()
266
+ layers.append(unet_block)
267
+ x = unet_block(x)
268
+
269
+ ni = x.shape[1]
270
+ if imsize != sfs_szs[0][-2:]:
271
+ layers.append(PixelShuffle_ICNR(ni, **kwargs))
272
+ if last_cross:
273
+ layers.append(MergeLayer(dense=True))
274
+ ni += in_channels(encoder)
275
+ layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
276
+ layers += [
277
+ custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
278
+ ]
279
+ if y_range is not None:
280
+ layers.append(SigmoidRange(*y_range))
281
+ super().__init__(*layers)
282
+
283
+ def __del__(self):
284
+ if hasattr(self, "sfs"):
285
+ self.sfs.remove()
deoldify/visualize.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.core import *
2
+ from fastai.vision import *
3
+ from matplotlib.axes import Axes
4
+ from .filters import IFilter, MasterFilter, ColorizerFilter
5
+ from .generators import gen_inference_deep, gen_inference_wide
6
+ from PIL import Image
7
+ import ffmpeg
8
+ import yt_dlp as youtube_dl
9
+ import gc
10
+ import requests
11
+ from io import BytesIO
12
+ import base64
13
+ from IPython import display as ipythondisplay
14
+ from IPython.display import HTML
15
+ from IPython.display import Image as ipythonimage
16
+ import cv2
17
+ import logging
18
+
19
+ # adapted from https://www.pyimagesearch.com/2016/04/25/watermarking-images-with-opencv-and-python/
20
+ def get_watermarked(pil_image: Image) -> Image:
21
+ try:
22
+ image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
23
+ (h, w) = image.shape[:2]
24
+ image = np.dstack([image, np.ones((h, w), dtype="uint8") * 255])
25
+ pct = 0.05
26
+ full_watermark = cv2.imread(
27
+ 'C:/Users/galquja/Desktop/Tools/tools/DeOldify/resource_images/watermark.png', cv2.IMREAD_UNCHANGED
28
+ )
29
+ (fwH, fwW) = full_watermark.shape[:2]
30
+ wH = int(pct * h)
31
+ wW = int((pct * h / fwH) * fwW)
32
+ watermark = cv2.resize(full_watermark, (wH, wW), interpolation=cv2.INTER_AREA)
33
+ overlay = np.zeros((h, w, 4), dtype="uint8")
34
+ (wH, wW) = watermark.shape[:2]
35
+ overlay[h - wH - 10 : h - 10, 10 : 10 + wW] = watermark
36
+ # blend the two images together using transparent overlays
37
+ output = image.copy()
38
+ cv2.addWeighted(overlay, 0.5, output, 1.0, 0, output)
39
+ rgb_image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
40
+ final_image = Image.fromarray(rgb_image)
41
+ return final_image
42
+ except:
43
+ # Don't want this to crash everything, so let's just not watermark the image for now.
44
+ return pil_image
45
+
46
+
47
+ class ModelImageVisualizer:
48
+ def __init__(self, filter: IFilter, results_dir: str = None):
49
+ self.filter = filter
50
+ self.results_dir = None if results_dir is None else Path(results_dir)
51
+ # self.results_dir.mkdir(parents=True, exist_ok=True)
52
+
53
+ def _clean_mem(self):
54
+ torch.cuda.empty_cache()
55
+ # gc.collect()
56
+
57
+ def _open_pil_image(self, path: Path) -> Image:
58
+ return PIL.Image.open(path).convert('RGB')
59
+
60
+ def _get_image_from_url(self, url: str) -> Image:
61
+ response = requests.get(url, timeout=30, headers={'user-agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/62.0.3202.94 Safari/537.36'})
62
+ img = PIL.Image.open(BytesIO(response.content)).convert('RGB')
63
+ return img
64
+
65
+ def plot_transformed_image_from_url(
66
+ self,
67
+ url: str,
68
+ path: str = 'test_images/image.png',
69
+ results_dir:Path = None,
70
+ figsize: Tuple[int, int] = (20, 20),
71
+ render_factor: int = None,
72
+
73
+ display_render_factor: bool = False,
74
+ compare: bool = False,
75
+ post_process: bool = True,
76
+ watermarked: bool = True,
77
+ ) -> Path:
78
+ img = self._get_image_from_url(url)
79
+ img.save(path)
80
+ return self.plot_transformed_image(
81
+ path=path,
82
+ results_dir=results_dir,
83
+ figsize=figsize,
84
+ render_factor=render_factor,
85
+ display_render_factor=display_render_factor,
86
+ compare=compare,
87
+ post_process = post_process,
88
+ watermarked=watermarked,
89
+ )
90
+
91
+ def plot_transformed_image(
92
+ self,
93
+ path: str,
94
+ results_dir:Path = None,
95
+ figsize: Tuple[int, int] = (20, 20),
96
+ render_factor: int = None,
97
+ display_render_factor: bool = False,
98
+ compare: bool = False,
99
+ post_process: bool = True,
100
+ watermarked: bool = True,
101
+ ) -> Path:
102
+ path = Path(path)
103
+ if results_dir is None:
104
+ results_dir = Path(self.results_dir)
105
+ result = self.get_transformed_image(
106
+ path, render_factor, post_process=post_process,watermarked=watermarked
107
+ )
108
+ orig = self._open_pil_image(path)
109
+ if compare:
110
+ self._plot_comparison(
111
+ figsize, render_factor, display_render_factor, orig, result
112
+ )
113
+ else:
114
+ self._plot_solo(figsize, render_factor, display_render_factor, result)
115
+
116
+ orig.close()
117
+ result_path = self._save_result_image(path, result, results_dir=results_dir)
118
+ result.close()
119
+ return result_path
120
+
121
+ def _plot_comparison(
122
+ self,
123
+ figsize: Tuple[int, int],
124
+ render_factor: int,
125
+ display_render_factor: bool,
126
+ orig: Image,
127
+ result: Image,
128
+ ):
129
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
130
+ self._plot_image(
131
+ orig,
132
+ axes=axes[0],
133
+ figsize=figsize,
134
+ render_factor=render_factor,
135
+ display_render_factor=False,
136
+ )
137
+ self._plot_image(
138
+ result,
139
+ axes=axes[1],
140
+ figsize=figsize,
141
+ render_factor=render_factor,
142
+ display_render_factor=display_render_factor,
143
+ )
144
+
145
+ def _plot_solo(
146
+ self,
147
+ figsize: Tuple[int, int],
148
+ render_factor: int,
149
+ display_render_factor: bool,
150
+ result: Image,
151
+ ):
152
+ fig, axes = plt.subplots(1, 1, figsize=figsize)
153
+ self._plot_image(
154
+ result,
155
+ axes=axes,
156
+ figsize=figsize,
157
+ render_factor=render_factor,
158
+ display_render_factor=display_render_factor,
159
+ )
160
+
161
+ def _save_result_image(self, source_path: Path, image: Image, results_dir = None) -> Path:
162
+ if results_dir is None:
163
+ results_dir = Path(self.results_dir)
164
+ result_path = results_dir
165
+ image.save(result_path)
166
+ return result_path
167
+
168
+ def get_transformed_image(
169
+ self, path: Path, render_factor: int = None, post_process: bool = True,
170
+ watermarked: bool = True,
171
+ ) -> Image:
172
+ self._clean_mem()
173
+ orig_image = self._open_pil_image(path)
174
+ filtered_image = self.filter.filter(
175
+ orig_image, orig_image, render_factor=render_factor,post_process=post_process
176
+ )
177
+
178
+ if watermarked:
179
+ return get_watermarked(filtered_image)
180
+
181
+ return filtered_image
182
+
183
+ def _plot_image(
184
+ self,
185
+ image: Image,
186
+ render_factor: int,
187
+ axes: Axes = None,
188
+ figsize=(20, 20),
189
+ display_render_factor = False,
190
+ ):
191
+ if axes is None:
192
+ _, axes = plt.subplots(figsize=figsize)
193
+ axes.imshow(np.asarray(image) / 255)
194
+ axes.axis('off')
195
+ if render_factor is not None and display_render_factor:
196
+ plt.text(
197
+ 10,
198
+ 10,
199
+ 'render_factor: ' + str(render_factor),
200
+ color='white',
201
+ backgroundcolor='black',
202
+ )
203
+
204
+ def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
205
+ columns = min(num_images, max_columns)
206
+ rows = num_images // columns
207
+ rows = rows if rows * columns == num_images else rows + 1
208
+ return rows, columns
209
+
210
+
211
+ class VideoColorizer:
212
+ def __init__(self, vis: ModelImageVisualizer):
213
+ self.vis = vis
214
+ workfolder = Path('./video')
215
+ self.source_folder = workfolder / "source"
216
+ self.bwframes_root = workfolder / "bwframes"
217
+ self.audio_root = workfolder / "audio"
218
+ self.colorframes_root = workfolder / "colorframes"
219
+ self.result_folder = workfolder / "result"
220
+
221
+ def _purge_images(self, dir):
222
+ for f in os.listdir(dir):
223
+ if re.search('.*?\.jpg', f):
224
+ os.remove(os.path.join(dir, f))
225
+
226
+ def _get_ffmpeg_probe(self, path:Path):
227
+ try:
228
+ probe = ffmpeg.probe(str(path))
229
+ return probe
230
+ except ffmpeg.Error as e:
231
+ logging.error("ffmpeg error: {0}".format(e), exc_info=True)
232
+ logging.error('stdout:' + e.stdout.decode('UTF-8'))
233
+ logging.error('stderr:' + e.stderr.decode('UTF-8'))
234
+ raise e
235
+ except Exception as e:
236
+ logging.error('Failed to instantiate ffmpeg.probe. Details: {0}'.format(e), exc_info=True)
237
+ raise e
238
+
239
+ def _get_fps(self, source_path: Path) -> str:
240
+ probe = self._get_ffmpeg_probe(source_path)
241
+ stream_data = next(
242
+ (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
243
+ None,
244
+ )
245
+ return stream_data['avg_frame_rate']
246
+
247
+ def _download_video_from_url(self, source_url, source_path: Path):
248
+ if source_path.exists():
249
+ source_path.unlink()
250
+
251
+ ydl_opts = {
252
+ 'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',
253
+ 'outtmpl': str(source_path),
254
+ 'retries': 30,
255
+ 'fragment-retries': 30
256
+ }
257
+ with youtube_dl.YoutubeDL(ydl_opts) as ydl:
258
+ ydl.download([source_url])
259
+
260
+ def _extract_raw_frames(self, source_path: Path):
261
+ bwframes_folder = self.bwframes_root / (source_path.stem)
262
+ bwframe_path_template = str(bwframes_folder / '%5d.jpg')
263
+ bwframes_folder.mkdir(parents=True, exist_ok=True)
264
+ self._purge_images(bwframes_folder)
265
+
266
+ process = (
267
+ ffmpeg
268
+ .input(str(source_path))
269
+ .output(str(bwframe_path_template), format='image2', vcodec='mjpeg', **{'q:v':'0'})
270
+ .global_args('-hide_banner')
271
+ .global_args('-nostats')
272
+ .global_args('-loglevel', 'error')
273
+ )
274
+
275
+ try:
276
+ process.run()
277
+ except ffmpeg.Error as e:
278
+ logging.error("ffmpeg error: {0}".format(e), exc_info=True)
279
+ logging.error('stdout:' + e.stdout.decode('UTF-8'))
280
+ logging.error('stderr:' + e.stderr.decode('UTF-8'))
281
+ raise e
282
+ except Exception as e:
283
+ logging.error('Errror while extracting raw frames from source video. Details: {0}'.format(e), exc_info=True)
284
+ raise e
285
+
286
+ def _colorize_raw_frames(
287
+ self, source_path: Path, render_factor: int = None, post_process: bool = True,
288
+ watermarked: bool = True,
289
+ ):
290
+ colorframes_folder = self.colorframes_root / (source_path.stem)
291
+ colorframes_folder.mkdir(parents=True, exist_ok=True)
292
+ self._purge_images(colorframes_folder)
293
+ bwframes_folder = self.bwframes_root / (source_path.stem)
294
+
295
+ for img in progress_bar(os.listdir(str(bwframes_folder))):
296
+ img_path = bwframes_folder / img
297
+
298
+ if os.path.isfile(str(img_path)):
299
+ color_image = self.vis.get_transformed_image(
300
+ str(img_path), render_factor=render_factor, post_process=post_process,watermarked=watermarked
301
+ )
302
+ color_image.save(str(colorframes_folder / img))
303
+
304
+ def _build_video(self, source_path: Path) -> Path:
305
+ colorized_path = self.result_folder / (
306
+ source_path.name.replace('.mp4', '_no_audio.mp4')
307
+ )
308
+ colorframes_folder = self.colorframes_root / (source_path.stem)
309
+ colorframes_path_template = str(colorframes_folder / '%5d.jpg')
310
+ colorized_path.parent.mkdir(parents=True, exist_ok=True)
311
+ if colorized_path.exists():
312
+ colorized_path.unlink()
313
+ fps = self._get_fps(source_path)
314
+
315
+ process = (
316
+ ffmpeg
317
+ .input(str(colorframes_path_template), format='image2', vcodec='mjpeg', framerate=fps)
318
+ .output(str(colorized_path), crf=17, vcodec='libx264')
319
+ .global_args('-hide_banner')
320
+ .global_args('-nostats')
321
+ .global_args('-loglevel', 'error')
322
+ )
323
+
324
+ try:
325
+ process.run()
326
+ except ffmpeg.Error as e:
327
+ logging.error("ffmpeg error: {0}".format(e), exc_info=True)
328
+ logging.error('stdout:' + e.stdout.decode('UTF-8'))
329
+ logging.error('stderr:' + e.stderr.decode('UTF-8'))
330
+ raise e
331
+ except Exception as e:
332
+ logging.error('Errror while building output video. Details: {0}'.format(e), exc_info=True)
333
+ raise e
334
+
335
+ result_path = self.result_folder / source_path.name
336
+ if result_path.exists():
337
+ result_path.unlink()
338
+ # making copy of non-audio version in case adding back audio doesn't apply or fails.
339
+ shutil.copyfile(str(colorized_path), str(result_path))
340
+
341
+ # adding back sound here
342
+ audio_file = Path(str(source_path).replace('.mp4', '.aac'))
343
+ if audio_file.exists():
344
+ audio_file.unlink()
345
+
346
+ os.system(
347
+ 'ffmpeg -y -i "'
348
+ + str(source_path)
349
+ + '" -vn -acodec copy "'
350
+ + str(audio_file)
351
+ + '"'
352
+ + ' -hide_banner'
353
+ + ' -nostats'
354
+ + ' -loglevel error'
355
+ )
356
+
357
+ if audio_file.exists():
358
+ os.system(
359
+ 'ffmpeg -y -i "'
360
+ + str(colorized_path)
361
+ + '" -i "'
362
+ + str(audio_file)
363
+ + '" -shortest -c:v copy -c:a aac -b:a 256k "'
364
+ + str(result_path)
365
+ + '"'
366
+ + ' -hide_banner'
367
+ + ' -nostats'
368
+ + ' -loglevel error'
369
+ )
370
+ logging.info('Video created here: ' + str(result_path))
371
+ return result_path
372
+
373
+ def colorize_from_url(
374
+ self,
375
+ source_url,
376
+ file_name: str,
377
+ render_factor: int = None,
378
+ post_process: bool = True,
379
+ watermarked: bool = True,
380
+
381
+ ) -> Path:
382
+ source_path = self.source_folder / file_name
383
+ self._download_video_from_url(source_url, source_path)
384
+ return self._colorize_from_path(
385
+ source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
386
+ )
387
+
388
+ def colorize_from_file_name(
389
+ self, file_name: str, render_factor: int = None, watermarked: bool = True, post_process: bool = True,
390
+ ) -> Path:
391
+ source_path = self.source_folder / file_name
392
+ return self._colorize_from_path(
393
+ source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
394
+ )
395
+
396
+ def _colorize_from_path(
397
+ self, source_path: Path, render_factor: int = None, watermarked: bool = True, post_process: bool = True
398
+ ) -> Path:
399
+ if not source_path.exists():
400
+ raise Exception(
401
+ 'Video at path specfied, ' + str(source_path) + ' could not be found.'
402
+ )
403
+ self._extract_raw_frames(source_path)
404
+ self._colorize_raw_frames(
405
+ source_path, render_factor=render_factor,post_process=post_process,watermarked=watermarked
406
+ )
407
+ return self._build_video(source_path)
408
+
409
+
410
+ def get_video_colorizer(render_factor: int = 21) -> VideoColorizer:
411
+ return get_stable_video_colorizer(render_factor=render_factor)
412
+
413
+
414
+ def get_artistic_video_colorizer(
415
+ root_folder: Path = Path('C:/Users/galquja/Desktop/Tools/tools/DeOldify'),
416
+ weights_name: str = 'ColorizeArtistic_gen',
417
+ results_dir='result_images',
418
+ render_factor: int = 35
419
+ ) -> VideoColorizer:
420
+ learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
421
+ filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
422
+ vis = ModelImageVisualizer(filtr, results_dir=results_dir)
423
+ return VideoColorizer(vis)
424
+
425
+
426
+ def get_stable_video_colorizer(
427
+ root_folder: Path = Path('C:/Users/galquja/Desktop/Tools/tools/DeOldify'),
428
+ weights_name: str = 'ColorizeVideo_gen',
429
+ results_dir='result_images',
430
+ render_factor: int = 21
431
+ ) -> VideoColorizer:
432
+ learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
433
+ filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
434
+ vis = ModelImageVisualizer(filtr, results_dir=results_dir)
435
+ return VideoColorizer(vis)
436
+
437
+
438
+ def get_image_colorizer(
439
+ root_folder: Path = Path('C:/Users/galquja/Desktop/Tools/tools/DeOldify'), render_factor: int = 35, artistic: bool = True
440
+ ) -> ModelImageVisualizer:
441
+ if artistic:
442
+ return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
443
+ else:
444
+ return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)
445
+
446
+
447
+ def get_stable_image_colorizer(
448
+ root_folder: Path = Path('C:/Users/galquja/Desktop/Tools/tools/DeOldify'),
449
+ weights_name: str = 'ColorizeStable_gen',
450
+ results_dir='result_images',
451
+ render_factor: int = 35
452
+ ) -> ModelImageVisualizer:
453
+ learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
454
+ filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
455
+ vis = ModelImageVisualizer(filtr, results_dir=results_dir)
456
+ return vis
457
+
458
+
459
+ def get_artistic_image_colorizer(
460
+ root_folder: Path = Path('C:/Users/galquja/Desktop/Tools/tools/DeOldify'),
461
+ weights_name: str = 'ColorizeArtistic_gen',
462
+ results_dir='result_images',
463
+ render_factor: int = 35
464
+ ) -> ModelImageVisualizer:
465
+ learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
466
+ filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
467
+ vis = ModelImageVisualizer(filtr, results_dir=results_dir)
468
+ return vis
469
+
470
+
471
+ def show_image_in_notebook(image_path: Path):
472
+ ipythondisplay.display(ipythonimage(str(image_path)))
473
+
474
+
475
+ def show_video_in_notebook(video_path: Path):
476
+ video = io.open(video_path, 'r+b').read()
477
+ encoded = base64.b64encode(video)
478
+ ipythondisplay.display(
479
+ HTML(
480
+ data='''<video alt="test" autoplay
481
+ loop controls style="height: 400px;">
482
+ <source src="data:video/mp4;base64,{0}" type="video/mp4" />
483
+ </video>'''.format(
484
+ encoded.decode('ascii')
485
+ )
486
+ )
487
+ )
environment.yml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: deoldify
2
+ channels:
3
+ - fastai
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ - pip
8
+ - fastai=1.0.60
9
+ - python=3.10
10
+ - pytorch::pytorch=1.11.0
11
+ - pytorch::torchvision
12
+ - pytorch::torchaudio
13
+ - tensorboardX
14
+ - jupyterlab
15
+ - pillow>=9.0.0
16
+ - ipywidgets
17
+ - ffmpeg
18
+ - pip:
19
+ - ffmpeg-python
20
+ - opencv-python>=3.3.0.10
21
+ - wandb
22
+ - yt-dlp
fastai/LICENSE ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License, Version 2.0 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/
2
+
3
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
4
+
5
+ 1. Definitions.
6
+
7
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
8
+
9
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
10
+
11
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
12
+
13
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
14
+
15
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
16
+
17
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
18
+
19
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
20
+
21
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
22
+
23
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
24
+
25
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
26
+
27
+ 2. Grant of Copyright License.
28
+
29
+ Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
30
+
31
+ 3. Grant of Patent License.
32
+
33
+ Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
34
+
35
+ 4. Redistribution.
36
+
37
+ You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
38
+
39
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
40
+
41
+ 5. Submission of Contributions.
42
+
43
+ Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
44
+
45
+ 6. Trademarks.
46
+
47
+ This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
48
+
49
+ 7. Disclaimer of Warranty.
50
+
51
+ Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
52
+
53
+ 8. Limitation of Liability.
54
+
55
+ In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
56
+
57
+ 9. Accepting Warranty or Additional Liability.
58
+
59
+ While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
60
+
fastai/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .version import __version__
2
+
fastai/basic_data.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "`fastai.data` loads and manages datasets with `DataBunch`"
2
+ from .torch_core import *
3
+ from torch.utils.data.dataloader import default_collate
4
+
5
+ DatasetType = Enum('DatasetType', 'Train Valid Test Single Fix')
6
+ __all__ = ['DataBunch', 'DeviceDataLoader', 'DatasetType', 'load_data']
7
+
8
+ old_dl_init = torch.utils.data.DataLoader.__init__
9
+
10
+ def intercept_args(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
11
+ num_workers=0, collate_fn=default_collate, pin_memory=True, drop_last=False,
12
+ timeout=0, worker_init_fn=None):
13
+ self.init_kwargs = {'batch_size':batch_size, 'shuffle':shuffle, 'sampler':sampler, 'batch_sampler':batch_sampler,
14
+ 'num_workers':num_workers, 'collate_fn':collate_fn, 'pin_memory':pin_memory,
15
+ 'drop_last': drop_last, 'timeout':timeout, 'worker_init_fn':worker_init_fn}
16
+ old_dl_init(self, dataset, **self.init_kwargs)
17
+
18
+ torch.utils.data.DataLoader.__init__ = intercept_args
19
+
20
+ def DataLoader___getattr__(dl, k:str)->Any: return getattr(dl.dataset, k)
21
+ DataLoader.__getattr__ = DataLoader___getattr__
22
+
23
+ def DataLoader___setstate__(dl, data:Any): dl.__dict__.update(data)
24
+ DataLoader.__setstate__ = DataLoader___setstate__
25
+
26
+ @dataclass
27
+ class DeviceDataLoader():
28
+ "Bind a `DataLoader` to a `torch.device`."
29
+ dl: DataLoader
30
+ device: torch.device
31
+ tfms: List[Callable]=None
32
+ collate_fn: Callable=data_collate
33
+ def __post_init__(self):
34
+ self.dl.collate_fn=self.collate_fn
35
+ self.tfms = listify(self.tfms)
36
+
37
+ def __len__(self)->int: return len(self.dl)
38
+ def __getattr__(self,k:str)->Any: return getattr(self.dl, k)
39
+ def __setstate__(self,data:Any): self.__dict__.update(data)
40
+
41
+ @property
42
+ def batch_size(self): return self.dl.batch_size
43
+ @batch_size.setter
44
+ def batch_size(self,v):
45
+ new_kwargs = {**self.dl.init_kwargs, 'batch_size':v, 'collate_fn':self.collate_fn}
46
+ self.dl = self.dl.__class__(self.dl.dataset, **new_kwargs)
47
+ if hasattr(self.dl.dataset, 'bs'): self.dl.dataset.bs = v
48
+
49
+ @property
50
+ def num_workers(self): return self.dl.num_workers
51
+ @num_workers.setter
52
+ def num_workers(self,v): self.dl.num_workers = v
53
+
54
+ def add_tfm(self,tfm:Callable)->None:
55
+ "Add `tfm` to `self.tfms`."
56
+ self.tfms.append(tfm)
57
+ def remove_tfm(self,tfm:Callable)->None:
58
+ "Remove `tfm` from `self.tfms`."
59
+ if tfm in self.tfms: self.tfms.remove(tfm)
60
+
61
+ def new(self, **kwargs):
62
+ "Create a new copy of `self` with `kwargs` replacing current values."
63
+ new_kwargs = {**self.dl.init_kwargs, **kwargs}
64
+ return DeviceDataLoader(self.dl.__class__(self.dl.dataset, **new_kwargs), self.device, self.tfms,
65
+ self.collate_fn)
66
+
67
+ def proc_batch(self,b:Tensor)->Tensor:
68
+ "Process batch `b` of `TensorImage`."
69
+ b = to_device(b, self.device)
70
+ for f in listify(self.tfms): b = f(b)
71
+ return b
72
+
73
+ def __iter__(self):
74
+ "Process and returns items from `DataLoader`."
75
+ for b in self.dl: yield self.proc_batch(b)
76
+
77
+ @classmethod
78
+ def create(cls, dataset:Dataset, bs:int=64, shuffle:bool=False, device:torch.device=defaults.device,
79
+ tfms:Collection[Callable]=tfms, num_workers:int=defaults.cpus, collate_fn:Callable=data_collate, **kwargs:Any):
80
+ "Create DeviceDataLoader from `dataset` with `bs` and `shuffle`: process using `num_workers`."
81
+ return cls(DataLoader(dataset, batch_size=bs, shuffle=shuffle, num_workers=num_workers, **kwargs),
82
+ device=device, tfms=tfms, collate_fn=collate_fn)
83
+
84
+ class DataBunch():
85
+ "Bind `train_dl`,`valid_dl` and `test_dl` in a data object."
86
+
87
+ def __init__(self, train_dl:DataLoader, valid_dl:DataLoader, fix_dl:DataLoader=None, test_dl:Optional[DataLoader]=None,
88
+ device:torch.device=None, dl_tfms:Optional[Collection[Callable]]=None, path:PathOrStr='.',
89
+ collate_fn:Callable=data_collate, no_check:bool=False):
90
+ self.dl_tfms = listify(dl_tfms)
91
+ self.device = defaults.device if device is None else device
92
+ assert not isinstance(train_dl,DeviceDataLoader)
93
+ def _create_dl(dl, **kwargs):
94
+ if dl is None: return None
95
+ return DeviceDataLoader(dl, self.device, self.dl_tfms, collate_fn, **kwargs)
96
+ self.train_dl,self.valid_dl,self.fix_dl,self.test_dl = map(_create_dl, [train_dl,valid_dl,fix_dl,test_dl])
97
+ if fix_dl is None: self.fix_dl = self.train_dl.new(shuffle=False, drop_last=False)
98
+ self.single_dl = _create_dl(DataLoader(valid_dl.dataset, batch_size=1, num_workers=0))
99
+ self.path = Path(path)
100
+ if not no_check: self.sanity_check()
101
+
102
+ def __repr__(self)->str:
103
+ return f'{self.__class__.__name__};\n\nTrain: {self.train_ds};\n\nValid: {self.valid_ds};\n\nTest: {self.test_ds}'
104
+
105
+ @staticmethod
106
+ def _init_ds(train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None):
107
+ # train_ds, but without training tfms
108
+ fix_ds = valid_ds.new(train_ds.x, train_ds.y) if hasattr(valid_ds,'new') else train_ds
109
+ return [o for o in (train_ds,valid_ds,fix_ds,test_ds) if o is not None]
110
+
111
+ @classmethod
112
+ def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,
113
+ val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
114
+ device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, **dl_kwargs)->'DataBunch':
115
+ "Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. Passes `**dl_kwargs` to `DataLoader()`"
116
+ datasets = cls._init_ds(train_ds, valid_ds, test_ds)
117
+ val_bs = ifnone(val_bs, bs)
118
+ dls = [DataLoader(d, b, shuffle=s, drop_last=s, num_workers=num_workers, **dl_kwargs) for d,b,s in
119
+ zip(datasets, (bs,val_bs,val_bs,val_bs), (True,False,False,False)) if d is not None]
120
+ return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
121
+
122
+ def __getattr__(self,k:int)->Any: return getattr(self.train_dl, k)
123
+ def __setstate__(self,data:Any): self.__dict__.update(data)
124
+
125
+ def dl(self, ds_type:DatasetType=DatasetType.Valid)->DeviceDataLoader:
126
+ "Returns appropriate `Dataset` for validation, training, or test (`ds_type`)."
127
+ #TODO: refactor
128
+ return (self.train_dl if ds_type == DatasetType.Train else
129
+ self.test_dl if ds_type == DatasetType.Test else
130
+ self.valid_dl if ds_type == DatasetType.Valid else
131
+ self.single_dl if ds_type == DatasetType.Single else
132
+ self.fix_dl)
133
+
134
+ @property
135
+ def dls(self)->List[DeviceDataLoader]:
136
+ "Returns a list of all DeviceDataLoaders. If you need a specific DeviceDataLoader, access via the relevant property (`train_dl`, `valid_dl`, etc) as the index of DLs in this list is not guaranteed to remain constant."
137
+ res = [self.train_dl, self.fix_dl, self.single_dl]
138
+ # Preserve the original ordering of Train, Valid, Fix, Single, Test Data Loaders
139
+ # (Unknown/not verified as of 1.0.47 whether there are other methods explicitly using DLs their list index)
140
+ if self.valid_dl: res.insert(1, self.valid_dl)
141
+ return res if not self.test_dl else res + [self.test_dl]
142
+
143
+ def add_tfm(self,tfm:Callable)->None:
144
+ for dl in self.dls: dl.add_tfm(tfm)
145
+
146
+ def remove_tfm(self,tfm:Callable)->None:
147
+ for dl in self.dls: dl.remove_tfm(tfm)
148
+
149
+ def save(self, file:PathLikeOrBinaryStream= 'data_save.pkl')->None:
150
+ "Save the `DataBunch` in `self.path/file`. `file` can be file-like (file or buffer)"
151
+ if not getattr(self, 'label_list', False):
152
+ warn("Serializing the `DataBunch` only works when you created it using the data block API.")
153
+ return
154
+ try_save(self.label_list, self.path, file)
155
+
156
+ def add_test(self, items:Iterator, label:Any=None, tfms=None, tfm_y=None)->None:
157
+ "Add the `items` as a test set. Pass along `label` otherwise label them with `EmptyLabel`."
158
+ self.label_list.add_test(items, label=label, tfms=tfms, tfm_y=tfm_y)
159
+ vdl = self.valid_dl
160
+ dl = DataLoader(self.label_list.test, vdl.batch_size, shuffle=False, drop_last=False, num_workers=vdl.num_workers)
161
+ self.test_dl = DeviceDataLoader(dl, vdl.device, vdl.tfms, vdl.collate_fn)
162
+
163
+ def one_batch(self, ds_type:DatasetType=DatasetType.Train, detach:bool=True, denorm:bool=True, cpu:bool=True)->Collection[Tensor]:
164
+ "Get one batch from the data loader of `ds_type`. Optionally `detach` and `denorm`."
165
+ dl = self.dl(ds_type)
166
+ w = self.num_workers
167
+ self.num_workers = 0
168
+ try: x,y = next(iter(dl))
169
+ finally: self.num_workers = w
170
+ if detach: x,y = to_detach(x,cpu=cpu),to_detach(y,cpu=cpu)
171
+ norm = getattr(self,'norm',False)
172
+ if denorm and norm:
173
+ x = self.denorm(x)
174
+ if norm.keywords.get('do_y',False): y = self.denorm(y, do_x=True)
175
+ return x,y
176
+
177
+ def one_item(self, item, detach:bool=False, denorm:bool=False, cpu:bool=False):
178
+ "Get `item` into a batch. Optionally `detach` and `denorm`."
179
+ ds = self.single_ds
180
+ with ds.set_item(item):
181
+ return self.one_batch(ds_type=DatasetType.Single, detach=detach, denorm=denorm, cpu=cpu)
182
+
183
+ def show_batch(self, rows:int=5, ds_type:DatasetType=DatasetType.Train, reverse:bool=False, **kwargs)->None:
184
+ "Show a batch of data in `ds_type` on a few `rows`."
185
+ x,y = self.one_batch(ds_type, True, True)
186
+ if reverse: x,y = x.flip(0),y.flip(0)
187
+ n_items = rows **2 if self.train_ds.x._square_show else rows
188
+ if self.dl(ds_type).batch_size < n_items: n_items = self.dl(ds_type).batch_size
189
+ xs = [self.train_ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]
190
+ #TODO: get rid of has_arg if possible
191
+ if has_arg(self.train_ds.y.reconstruct, 'x'):
192
+ ys = [self.train_ds.y.reconstruct(grab_idx(y, i), x=x) for i,x in enumerate(xs)]
193
+ else : ys = [self.train_ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]
194
+ self.train_ds.x.show_xys(xs, ys, **kwargs)
195
+
196
+ def export(self, file:PathLikeOrBinaryStream='export.pkl'):
197
+ "Export the minimal state of `self` for inference in `self.path/file`. `file` can be file-like (file or buffer)"
198
+ xtra = dict(normalize=self.norm.keywords) if getattr(self, 'norm', False) else {}
199
+ try_save(self.valid_ds.get_state(**xtra), self.path, file)
200
+
201
+ def _grab_dataset(self, dl:DataLoader):
202
+ ds = dl.dl.dataset
203
+ while hasattr(ds, 'dataset'): ds = ds.dataset
204
+ return ds
205
+
206
+ @property
207
+ def train_ds(self)->Dataset: return self._grab_dataset(self.train_dl)
208
+ @property
209
+ def valid_ds(self)->Dataset: return self._grab_dataset(self.valid_dl)
210
+ @property
211
+ def single_ds(self)->Dataset: return self._grab_dataset(self.single_dl)
212
+ @property
213
+ def loss_func(self)->OptLossFunc:
214
+ return getattr(self.train_ds.y, 'loss_func', F.nll_loss) if hasattr(self.train_ds, 'y') else F.nll_loss
215
+
216
+ @property
217
+ def test_ds(self)->Dataset:
218
+ return self._grab_dataset(self.test_dl) if self.test_dl is not None else None
219
+
220
+ @property
221
+ def empty_val(self)->bool:
222
+ if not hasattr(self, 'valid_dl') or self.valid_dl is None: return True
223
+ if hasattr(self.valid_ds, 'items') and len(self.valid_ds.items) == 0: return True
224
+ return (len(self.valid_ds) == 0)
225
+
226
+ @property
227
+ def is_empty(self)->bool:
228
+ return not ((self.train_dl and len(self.train_ds.items) != 0) or
229
+ (self.valid_dl and len(self.valid_ds.items) != 0) or
230
+ (self.test_dl and len(self.test_ds.items) != 0))
231
+
232
+ @property
233
+ def batch_size(self): return self.train_dl.batch_size
234
+ @batch_size.setter
235
+ def batch_size(self,v):
236
+ self.train_dl.batch_size,self.valid_dl.batch_size = v,v
237
+ if self.test_dl is not None: self.test_dl.batch_size = v
238
+
239
+ def sanity_check(self):
240
+ "Check the underlying data in the training set can be properly loaded."
241
+ final_message = "You can deactivate this warning by passing `no_check=True`."
242
+ if not hasattr(self.train_ds, 'items') or len(self.train_ds.items) == 0 or not hasattr(self.train_dl, 'batch_sampler'): return
243
+ if len(self.train_dl) == 0:
244
+ warn(f"""Your training dataloader is empty, you have only {len(self.train_dl.dataset)} items in your training set.
245
+ Your batch size is {self.train_dl.batch_size}, you should lower it.""")
246
+ print(final_message)
247
+ return
248
+ idx = next(iter(self.train_dl.batch_sampler))
249
+ samples,fails = [],[]
250
+ for i in idx:
251
+ try: samples.append(self.train_dl.dataset[i])
252
+ except: fails.append(i)
253
+ if len(fails) > 0:
254
+ warn_msg = "There seems to be something wrong with your dataset, for example, in the first batch can't access"
255
+ if len(fails) == len(idx):
256
+ warn_msg += f" any element of self.train_ds.\nTried: {show_some(idx)}"
257
+ else:
258
+ warn_msg += f" these elements in self.train_ds: {show_some(fails)}"
259
+ warn(warn_msg)
260
+ print(final_message)
261
+ return
262
+ try: batch = self.collate_fn(samples)
263
+ except:
264
+ message = "It's not possible to collate samples of your dataset together in a batch."
265
+ try:
266
+ shapes = [[o[i].data.shape for o in samples] for i in range(2)]
267
+ message += f'\nShapes of the inputs/targets:\n{shapes}'
268
+ except: pass
269
+ warn(message)
270
+ print(final_message)
271
+
272
+ def load_data(path:PathOrStr, file:PathLikeOrBinaryStream='data_save.pkl', bs:int=64, val_bs:int=None, num_workers:int=defaults.cpus,
273
+ dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None, collate_fn:Callable=data_collate,
274
+ no_check:bool=False, **kwargs)->DataBunch:
275
+ "Load a saved `DataBunch` from `path/file`. `file` can be file-like (file or buffer)"
276
+ source = Path(path)/file if is_pathlike(file) else file
277
+ ll = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
278
+ return ll.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, dl_tfms=dl_tfms, device=device,
279
+ collate_fn=collate_fn, no_check=no_check, **kwargs)
fastai/basic_train.py ADDED
@@ -0,0 +1,629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Provides basic training and validation with `Learner`"
2
+ from .torch_core import *
3
+ from .basic_data import *
4
+ from .callback import *
5
+ from .data_block import *
6
+ from .utils.ipython import gpu_mem_restore
7
+ import inspect
8
+ from fastprogress.fastprogress import format_time, IN_NOTEBOOK
9
+ from time import time
10
+ from fastai.sixel import plot_sixel
11
+
12
+ __all__ = ['Learner', 'LearnerCallback', 'Recorder', 'RecordOnCPU', 'fit', 'loss_batch', 'train_epoch', 'validate',
13
+ 'get_preds', 'load_learner']
14
+
15
+ defaults.lr = slice(3e-3)
16
+ defaults.wd = 1e-2
17
+ defaults.extra_callbacks = None
18
+ defaults.extra_callback_fns = None
19
+
20
+ def loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,
21
+ cb_handler:Optional[CallbackHandler]=None, count:[int]=[1], batch_multiplier:int=1)->Tuple[Union[Tensor,int,float,str]]:
22
+ "Calculate loss and metrics for a batch, call out to callbacks as necessary."
23
+ cb_handler = ifnone(cb_handler, CallbackHandler())
24
+ if not is_listy(xb): xb = [xb]
25
+ if not is_listy(yb): yb = [yb]
26
+ out = model(*xb)
27
+
28
+ if not loss_func: return to_detach(out), yb[0].detach()
29
+ out = cb_handler.on_loss_begin(out)
30
+ loss = loss_func(out, *yb)/batch_multiplier
31
+ count[0]-=1
32
+
33
+ if opt is not None:
34
+ loss,skip_bwd = cb_handler.on_backward_begin(loss)
35
+ if not skip_bwd: loss.backward()
36
+ if count[0] == 0:
37
+ if not cb_handler.on_backward_end(): opt.step()
38
+ if not cb_handler.on_step_end(): opt.zero_grad()
39
+ count[0] = batch_multiplier
40
+
41
+ return loss.detach().cpu()
42
+
43
+ def get_preds(model:nn.Module, dl:DataLoader, pbar:Optional[PBar]=None, cb_handler:Optional[CallbackHandler]=None,
44
+ activ:nn.Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) -> List[Tensor]:
45
+ "Tuple of predictions and targets, and optional losses (if `loss_func`) using `dl`, max batches `n_batch`."
46
+ res = [torch.cat(o).cpu() for o in
47
+ zip(*validate(model, dl, cb_handler=cb_handler, pbar=pbar, average=False, n_batch=n_batch))]
48
+ if loss_func is not None:
49
+ with NoneReduceOnCPU(loss_func) as lf: res.append(lf(res[0], res[1]))
50
+ if activ is not None: res[0] = activ(res[0])
51
+ return res
52
+
53
+ def validate(model:nn.Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None,
54
+ pbar:Optional[PBar]=None, average=True, n_batch:Optional[int]=None)->Iterator[Tuple[Union[Tensor,int],...]]:
55
+ "Calculate `loss_func` of `model` on `dl` in evaluation mode."
56
+ model.eval()
57
+ with torch.no_grad():
58
+ val_losses,nums = [],[]
59
+ if cb_handler: cb_handler.set_dl(dl)
60
+ for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
61
+ if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
62
+ val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)
63
+ val_losses.append(val_loss)
64
+ if not is_listy(yb): yb = [yb]
65
+ nums.append(first_el(yb).shape[0])
66
+ if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
67
+ if n_batch and (len(nums)>=n_batch): break
68
+ nums = np.array(nums, dtype=np.float32)
69
+ if average: return (to_np(torch.stack(val_losses)) * nums).sum() / nums.sum()
70
+ else: return val_losses
71
+
72
+ def train_epoch(model:nn.Module, dl:DataLoader, opt:optim.Optimizer, loss_func:LossFunction)->None:
73
+ "Simple training of `model` for 1 epoch of `dl` using optim `opt` and loss function `loss_func`."
74
+ model.train()
75
+ for xb,yb in dl:
76
+ loss = loss_func(model(xb), yb)
77
+ loss.backward()
78
+ opt.step()
79
+ opt.zero_grad()
80
+
81
+ @dataclass
82
+ class BasicLearner():
83
+ model:nn.Module
84
+ loss_func:LossFunction
85
+ opt:optim.Optimizer
86
+ data:DataBunch
87
+
88
+ def fit(epochs:int, learn:BasicLearner, callbacks:Optional[CallbackList]=None, metrics:OptMetrics=None, batch_multiplier:int=1)->None:
89
+ "Fit the `model` on `data` and learn using `loss_func` and `opt`."
90
+ assert len(learn.data.train_dl) != 0, f"""Your training dataloader is empty, can't train a model.
91
+ Use a smaller batch size (batch size={learn.data.train_dl.batch_size} for {len(learn.data.train_dl.dataset)} elements)."""
92
+ cb_handler = CallbackHandler(callbacks, metrics)
93
+ pbar = master_bar(range(epochs))
94
+ cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)
95
+
96
+ exception=False
97
+ try:
98
+ for epoch in pbar:
99
+ learn.model.train()
100
+ cb_handler.set_dl(learn.data.train_dl)
101
+ cb_handler.on_epoch_begin()
102
+ count = [batch_multiplier]
103
+ for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
104
+ xb, yb = cb_handler.on_batch_begin(xb, yb)
105
+ loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler, count=count, batch_multiplier=batch_multiplier)
106
+ if cb_handler.on_batch_end(loss): break
107
+
108
+ if not cb_handler.skip_validate and not learn.data.empty_val:
109
+ val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
110
+ cb_handler=cb_handler, pbar=pbar)
111
+ else: val_loss=None
112
+ if cb_handler.on_epoch_end(val_loss): break
113
+ except Exception as e:
114
+ exception = e
115
+ raise
116
+ finally: cb_handler.on_train_end(exception)
117
+
118
+ loss_func_name2activ = {'cross_entropy_loss': F.softmax, 'nll_loss': torch.exp, 'poisson_nll_loss': torch.exp,
119
+ 'kl_div_loss': torch.exp, 'bce_with_logits_loss': torch.sigmoid, 'cross_entropy': F.softmax,
120
+ 'kl_div': torch.exp, 'binary_cross_entropy_with_logits': torch.sigmoid,
121
+ }
122
+
123
+ def _loss_func_name2activ(name:str, axis:int=-1):
124
+ res = loss_func_name2activ[name]
125
+ if res == F.softmax: res = partial(F.softmax, dim=axis)
126
+ return res
127
+
128
+ def _loss_func2activ(loss_func):
129
+ if getattr(loss_func,'keywords',None):
130
+ if not loss_func.keywords.get('log_input', True): return
131
+ axis = getattr(loss_func, 'axis', -1)
132
+ # flattened loss
133
+ loss_func = getattr(loss_func, 'func', loss_func)
134
+ # could have a partial inside flattened loss! Duplicate on purpose.
135
+ loss_func = getattr(loss_func, 'func', loss_func)
136
+ cls_name = camel2snake(loss_func.__class__.__name__)
137
+ if cls_name == 'mix_up_loss':
138
+ loss_func = loss_func.crit
139
+ cls_name = camel2snake(loss_func.__class__.__name__)
140
+ if cls_name in loss_func_name2activ:
141
+ if cls_name == 'poisson_nll_loss' and (not getattr(loss_func, 'log_input', True)): return
142
+ return _loss_func_name2activ(cls_name, axis)
143
+ if getattr(loss_func,'__name__','') in loss_func_name2activ:
144
+ return _loss_func_name2activ(loss_func.__name__, axis)
145
+ return noop
146
+
147
+ @dataclass
148
+ class Learner():
149
+ "Trainer for `model` using `data` to minimize `loss_func` with optimizer `opt_func`."
150
+ data:DataBunch
151
+ model:nn.Module
152
+ opt_func:Callable=AdamW
153
+ loss_func:Callable=None
154
+ metrics:Collection[Callable]=None
155
+ true_wd:bool=True
156
+ bn_wd:bool=True
157
+ wd:Floats=defaults.wd
158
+ train_bn:bool=True
159
+ path:str = None
160
+ model_dir:PathOrStr = 'models'
161
+ callback_fns:Collection[Callable]=None
162
+ callbacks:Collection[Callback]=field(default_factory=list)
163
+ layer_groups:Collection[nn.Module]=None
164
+ add_time:bool=True
165
+ silent:bool=None
166
+ def __post_init__(self)->None:
167
+ "Setup path,metrics, callbacks and ensure model directory exists."
168
+ self.path = Path(ifnone(self.path, self.data.path))
169
+ self.model = self.model.to(self.data.device)
170
+ self.loss_func = self.loss_func or self.data.loss_func
171
+ self.metrics=listify(self.metrics)
172
+ if not self.layer_groups: self.layer_groups = [nn.Sequential(*flatten_model(self.model))]
173
+ self.callbacks = listify(self.callbacks)
174
+ if self.silent is None: self.silent = defaults.silent
175
+ self.callback_fns = [partial(Recorder, add_time=self.add_time, silent=self.silent)] + listify(self.callback_fns)
176
+
177
+ def init(self, init): apply_init(self.model, init)
178
+
179
+ def _test_writeable_path(self):
180
+ path = self.path/self.model_dir
181
+ try:
182
+ path.mkdir(parents=True, exist_ok=True)
183
+ tmp_file = get_tmp_file(path)
184
+ except OSError as e:
185
+ raise Exception(f"{e}\nCan't write to '{path}', set `learn.model_dir` attribute in Learner to a full libpath path that is writable") from None
186
+ os.remove(tmp_file)
187
+
188
+ def lr_range(self, lr:Union[float,slice])->np.ndarray:
189
+ "Build differential learning rates from `lr`."
190
+ if not isinstance(lr,slice): return lr
191
+ if lr.start: res = even_mults(lr.start, lr.stop, len(self.layer_groups))
192
+ else: res = [lr.stop/10]*(len(self.layer_groups)-1) + [lr.stop]
193
+ return np.array(res)
194
+
195
+ def fit(self, epochs:int, lr:Union[Floats,slice]=defaults.lr,
196
+ wd:Floats=None, callbacks:Collection[Callback]=None, batch_multiplier:int=1)->None:
197
+ "Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`."
198
+ lr = self.lr_range(lr)
199
+ if wd is None: wd = self.wd
200
+ if not getattr(self, 'opt', False): self.create_opt(lr, wd)
201
+ else: self.opt.lr,self.opt.wd = lr,wd
202
+ callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
203
+ if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
204
+ fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks, batch_multiplier=batch_multiplier)
205
+
206
+ def create_opt(self, lr:Floats, wd:Floats=0.)->None:
207
+ "Create optimizer with `lr` learning rate and `wd` weight decay."
208
+ self.opt = OptimWrapper.create(self.opt_func, lr, self.layer_groups, wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
209
+
210
+ def split(self, split_on:SplitFuncOrIdxList)->None:
211
+ "Split the model at `split_on`."
212
+ if isinstance(split_on,Callable): split_on = split_on(self.model)
213
+ self.layer_groups = split_model(self.model, split_on)
214
+ return self
215
+
216
+ def freeze_to(self, n:int)->None:
217
+ "Freeze layers up to layer group `n`."
218
+ for g in self.layer_groups[:n]:
219
+ for l in g:
220
+ if not self.train_bn or not isinstance(l, bn_types): requires_grad(l, False)
221
+ for g in self.layer_groups[n:]: requires_grad(g, True)
222
+ self.create_opt(defaults.lr)
223
+
224
+ def freeze(self)->None:
225
+ "Freeze up to last layer group."
226
+ assert(len(self.layer_groups)>1)
227
+ self.freeze_to(-1)
228
+
229
+ def unfreeze(self):
230
+ "Unfreeze entire model."
231
+ self.freeze_to(0)
232
+
233
+ def export(self, file:PathLikeOrBinaryStream='export.pkl', destroy=False):
234
+ "Export the state of the `Learner` in `self.path/file`. `file` can be file-like (file or buffer)"
235
+ if rank_distrib(): return # don't save if slave proc
236
+ args = ['opt_func', 'loss_func', 'metrics', 'true_wd', 'bn_wd', 'wd', 'train_bn', 'model_dir', 'callback_fns']
237
+ state = {a:getattr(self,a) for a in args}
238
+ state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
239
+ #layer_groups -> need to find a way
240
+ #TO SEE: do we save model structure and weights separately?
241
+ with ModelOnCPU(self.model) as m:
242
+ state['model'] = m
243
+ xtra = dict(normalize=self.data.norm.keywords) if getattr(self.data, 'norm', False) else {}
244
+ state['data'] = self.data.valid_ds.get_state(**xtra)
245
+ state['cls'] = self.__class__
246
+ try_save(state, self.path, file)
247
+ if destroy: self.destroy()
248
+
249
+ def save(self, file:PathLikeOrBinaryStream=None, return_path:bool=False, with_opt:bool=True):
250
+ "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
251
+ if is_pathlike(file): self._test_writeable_path()
252
+ if rank_distrib(): return # don't save if slave proc
253
+ target = self.path/self.model_dir/f'{file}.pth' if is_pathlike(file) else file
254
+ if not hasattr(self, 'opt'): with_opt=False
255
+ if not with_opt: state = get_model(self.model).state_dict()
256
+ else: state = {'model': get_model(self.model).state_dict(), 'opt':self.opt.state_dict()}
257
+ torch.save(state, target)
258
+ if return_path: return target
259
+
260
+ def dl(self, ds_type:DatasetType=DatasetType.Valid):
261
+ "Return DataLoader for DatasetType `ds_type`."
262
+ return self.data.dl(ds_type)
263
+
264
+ def load(self, file:PathLikeOrBinaryStream=None, device:torch.device=None, strict:bool=True,
265
+ with_opt:bool=None, purge:bool=True, remove_module:bool=False):
266
+ "Load model and optimizer state (if `with_opt`) `file` from `self.model_dir` using `device`. `file` can be file-like (file or buffer)"
267
+ if purge: self.purge(clear_opt=ifnone(with_opt, False))
268
+ if device is None: device = self.data.device
269
+ elif isinstance(device, int): device = torch.device('cuda', device)
270
+ source = self.path/self.model_dir/f'{file}.pth' if is_pathlike(file) else file
271
+ state = torch.load(source, map_location=device)
272
+ if set(state.keys()) == {'model', 'opt'}:
273
+ model_state = state['model']
274
+ if remove_module: model_state = remove_module_load(model_state)
275
+ get_model(self.model).load_state_dict(model_state, strict=strict)
276
+ if ifnone(with_opt,True):
277
+ if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)
278
+ try: self.opt.load_state_dict(state['opt'])
279
+ except: pass
280
+ else:
281
+ if with_opt: warn("Saved filed doesn't contain an optimizer state.")
282
+ if remove_module: state = remove_module_load(state)
283
+ get_model(self.model).load_state_dict(state, strict=strict)
284
+ del state
285
+ gc.collect()
286
+ return self
287
+
288
+ def destroy(self):
289
+ "Free the Learner internals, leaving just an empty shell that consumes no memory"
290
+
291
+ class ZombieLearner(Learner):
292
+ msg = "this object has been destroyed"
293
+ def __getattr__(self, item): print(ZombieLearner.msg); return None
294
+ def destroyed(*args, **kwargs): print(ZombieLearner.msg)
295
+
296
+ attrs = [k for k in self.__dict__.keys() if not k.startswith("__")]
297
+ for a in attrs: delattr(self, a)
298
+ # the instance methods can still be called, but will just give a message
299
+ methods = [k for k in dir(self) if not k.startswith("__") and inspect.isroutine(getattr(self, k))]
300
+ for m in methods: setattr(self, m, ZombieLearner.destroyed)
301
+ self.__class__ = ZombieLearner
302
+ gc.collect()
303
+ print("this Learner object self-destroyed - it still exists, but no longer usable")
304
+
305
+ def purge(self, clear_opt:bool=True):
306
+ "Purge the `Learner` of all cached attributes to release some GPU memory."
307
+ self._test_writeable_path()
308
+ attrs_all = [k for k in self.__dict__.keys() if not k.startswith("__")]
309
+ attrs_pkl = ['bn_wd', 'callback_fns', 'layer_groups', 'loss_func', 'metrics', 'model',
310
+ 'model_dir', 'opt_func', 'path', 'train_bn', 'true_wd', 'wd']
311
+ # +callbacks: get pickled too, but not directly
312
+ attrs_keep = ['data', 'recorder']
313
+ attrs_del = list(set(attrs_all) - set(attrs_keep))
314
+ state = {a:getattr(self, a) for a in attrs_pkl}
315
+ state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
316
+ if hasattr(self, 'opt'): state['opt'] = self.opt.get_state()
317
+
318
+ tmp_file = get_tmp_file(self.path/self.model_dir)
319
+ torch.save(state, open(tmp_file, 'wb'))
320
+ for a in attrs_del: delattr(self, a)
321
+ gc.collect()
322
+ state = torch.load(tmp_file)
323
+ os.remove(tmp_file)
324
+
325
+ for a in attrs_pkl: setattr(self, a, state[a])
326
+ cb_state = state.pop('cb_state')
327
+ self.callbacks = [load_callback(c,s, self) for c,s in cb_state.items()]
328
+ if not clear_opt and 'opt' in state:
329
+ try: self.opt = OptimWrapper.load_with_state_and_layer_group(state['opt'], self.layer_groups)
330
+ except: warn("Wasn't able to properly load the optimizer state again.")
331
+ del state
332
+ gc.collect()
333
+ return self
334
+
335
+ def get_preds(self, ds_type:DatasetType=DatasetType.Valid, with_loss:bool=False, n_batch:Optional[int]=None,
336
+ pbar:Optional[PBar]=None) -> List[Tensor]:
337
+ "Return predictions and targets on `ds_type` dataset."
338
+ lf = self.loss_func if with_loss else None
339
+ return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
340
+ activ=_loss_func2activ(self.loss_func), loss_func=lf, n_batch=n_batch, pbar=pbar)
341
+
342
+ def pred_batch(self, ds_type:DatasetType=DatasetType.Valid, batch:Tuple=None, reconstruct:bool=False, with_dropout:bool=False) -> List[Tensor]:
343
+ with torch.no_grad():
344
+ training = self.model.training
345
+ self.model.train(False)
346
+ "Return output of the model on one batch from `ds_type` dataset."
347
+ if batch is not None: xb,yb = batch
348
+ else: xb,yb = self.data.one_batch(ds_type, detach=False, denorm=False)
349
+ cb_handler = CallbackHandler(self.callbacks)
350
+ #xb,yb = cb_handler.on_batch_begin(xb,yb, train=False)
351
+ if not with_dropout:
352
+ preds = loss_batch(self.model.eval(), xb, yb, cb_handler=cb_handler)
353
+ else:
354
+ preds = loss_batch(self.model.eval().apply(self.apply_dropout), xb, yb, cb_handler=cb_handler)
355
+ res = _loss_func2activ(self.loss_func)(preds[0])
356
+ self.model.train(training)
357
+ if not reconstruct: return res
358
+ res = res.detach().cpu()
359
+ ds = self.dl(ds_type).dataset
360
+ norm = getattr(self.data, 'norm', False)
361
+ if norm and norm.keywords.get('do_y',False):
362
+ res = self.data.denorm(res, do_x=True)
363
+ return [ds.reconstruct(o) for o in res]
364
+
365
+ def backward(self, item):
366
+ "Pass `item` through the model and computes the gradient. Useful if `backward_hooks` are attached."
367
+ xb,yb = self.data.one_item(item)
368
+ loss = loss_batch(self.model.eval(), xb, yb, self.loss_func, opt=FakeOptimizer(),
369
+ cb_handler=CallbackHandler(self.callbacks))
370
+ return loss
371
+
372
+ def predict(self, item:ItemBase, return_x:bool=False, batch_first:bool=True, with_dropout:bool=False, **kwargs):
373
+ "Return predicted class, label and probabilities for `item`."
374
+ batch = self.data.one_item(item)
375
+ res = self.pred_batch(batch=batch, with_dropout=with_dropout)
376
+ raw_pred,x = grab_idx(res,0,batch_first=batch_first),batch[0]
377
+ norm = getattr(self.data,'norm',False)
378
+ if norm:
379
+ x = self.data.denorm(x)
380
+ if norm.keywords.get('do_y',False): raw_pred = self.data.denorm(raw_pred)
381
+ ds = self.data.single_ds
382
+ pred = ds.y.analyze_pred(raw_pred, **kwargs)
383
+ x = ds.x.reconstruct(grab_idx(x, 0))
384
+ y = ds.y.reconstruct(pred, x) if has_arg(ds.y.reconstruct, 'x') else ds.y.reconstruct(pred)
385
+ return (x, y, pred, raw_pred) if return_x else (y, pred, raw_pred)
386
+
387
+ def validate(self, dl=None, callbacks=None, metrics=None):
388
+ "Validate on `dl` with potential `callbacks` and `metrics`."
389
+ dl = ifnone(dl, self.data.valid_dl)
390
+ metrics = ifnone(metrics, self.metrics)
391
+ cb_handler = CallbackHandler(self.callbacks + ifnone(callbacks, []), metrics)
392
+ cb_handler.on_epoch_begin()
393
+ val_metrics = validate(self.model, dl, self.loss_func, cb_handler)
394
+ cb_handler.on_epoch_end(val_metrics)
395
+ return cb_handler.state_dict['last_metrics']
396
+
397
+ def show_results(self, ds_type=DatasetType.Valid, rows:int=5, **kwargs):
398
+ "Show `rows` result of predictions on `ds_type` dataset."
399
+ #TODO: get read of has_arg x and split_kwargs_by_func if possible
400
+ #TODO: simplify this and refactor with pred_batch(...reconstruct=True)
401
+ n_items = rows ** 2 if self.data.train_ds.x._square_show_res else rows
402
+ if self.dl(ds_type).batch_size < n_items: n_items = self.dl(ds_type).batch_size
403
+ ds = self.dl(ds_type).dataset
404
+ self.callbacks.append(RecordOnCPU())
405
+ preds = self.pred_batch(ds_type)
406
+ *self.callbacks,rec_cpu = self.callbacks
407
+ x,y = rec_cpu.input,rec_cpu.target
408
+ norm = getattr(self.data,'norm',False)
409
+ if norm:
410
+ x = self.data.denorm(x)
411
+ if norm.keywords.get('do_y',False):
412
+ y = self.data.denorm(y, do_x=True)
413
+ preds = self.data.denorm(preds, do_x=True)
414
+ analyze_kwargs,kwargs = split_kwargs_by_func(kwargs, ds.y.analyze_pred)
415
+ preds = [ds.y.analyze_pred(grab_idx(preds, i), **analyze_kwargs) for i in range(n_items)]
416
+ xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]
417
+ if has_arg(ds.y.reconstruct, 'x'):
418
+ ys = [ds.y.reconstruct(grab_idx(y, i), x=x) for i,x in enumerate(xs)]
419
+ zs = [ds.y.reconstruct(z, x=x) for z,x in zip(preds,xs)]
420
+ else :
421
+ ys = [ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]
422
+ zs = [ds.y.reconstruct(z) for z in preds]
423
+ ds.x.show_xyzs(xs, ys, zs, **kwargs)
424
+
425
+ def apply_dropout(self, m):
426
+ "If a module contains 'dropout' in it's name, it will be switched to .train() mode."
427
+ if 'dropout' in m.__class__.__name__.lower(): m.train()
428
+
429
+ def predict_with_mc_dropout(self, item:ItemBase, with_dropout:bool=True, n_times=10, **kwargs):
430
+ "Make predictions with dropout turned on for n_times (default 10)."
431
+ return [self.predict(item, with_dropout=with_dropout) for _ in range(n_times)]
432
+
433
+ class RecordOnCPU(Callback):
434
+ "Store the `input` and `target` going through the model on the CPU."
435
+ def on_batch_begin(self, last_input,last_target,**kwargs):
436
+ self.input,self.target = to_cpu(last_input),to_cpu(last_target)
437
+
438
+ class LearnerCallback(Callback):
439
+ "Base class for creating callbacks for a `Learner`."
440
+ def __init__(self, learn):
441
+ self._learn = weakref.ref(learn)
442
+ self.exclude,self.not_min = ['_learn'],[]
443
+ setattr(self.learn, self.cb_name, self)
444
+
445
+ def __getattr__(self,k): return getattr(self.learn, k)
446
+ def __setstate__(self,data:Any): self.__dict__.update(data)
447
+
448
+ @property
449
+ def learn(self) -> Learner: return self._learn()
450
+ @learn.setter
451
+ def learn(self, learn: Learner) -> None: self._learn = weakref.ref(learn)
452
+
453
+ @property
454
+ def cb_name(self): return camel2snake(self.__class__.__name__)
455
+
456
+ class Recorder(LearnerCallback):
457
+ "A `LearnerCallback` that records epoch, loss, opt and metric data during training."
458
+ _order=-10
459
+ def __init__(self, learn:Learner, add_time:bool=True, silent:bool=False):
460
+ super().__init__(learn)
461
+ self.opt = self.learn.opt
462
+ self.train_dl = self.learn.data.train_dl
463
+ self.no_val,self.silent,self.add_time = False,silent,add_time
464
+
465
+ def on_train_begin(self, pbar:PBar, metrics_names:Collection[str], **kwargs:Any)->None:
466
+ "Initialize recording status at beginning of training."
467
+ self.pbar = pbar
468
+ self.names = ['epoch', 'train_loss'] if self.no_val else ['epoch', 'train_loss', 'valid_loss']
469
+ self.metrics_names = metrics_names
470
+ if hasattr(self, '_added_met_names'): self.metrics_names += self._added_met_names
471
+ self.names += self.metrics_names
472
+ if self.add_time: self.names.append('time')
473
+ if not self.silent: self.pbar.write(self.names, table=True)
474
+ self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]
475
+
476
+ def on_epoch_begin(self, **kwargs:Any)->None:
477
+ if self.add_time: self.start_epoch = time()
478
+
479
+ def on_batch_begin(self, train, **kwargs:Any)->None:
480
+ "Record learning rate and momentum at beginning of batch."
481
+ if train:
482
+ self.lrs.append(self.opt.lr)
483
+ self.moms.append(self.opt.mom)
484
+
485
+ def on_backward_begin(self, smooth_loss:Tensor, **kwargs:Any)->None:
486
+ "Record the loss before any other callback has a chance to modify it."
487
+ self.losses.append(smooth_loss)
488
+ if self.pbar is not None and hasattr(self.pbar,'child'):
489
+ self.pbar.child.comment = f'{smooth_loss:.4f}'
490
+
491
+ def on_epoch_end(self, epoch:int, num_batch:int, smooth_loss:Tensor,
492
+ last_metrics=MetricsList, **kwargs:Any)->bool:
493
+ "Save epoch info: num_batch, smooth_loss, metrics."
494
+ self.nb_batches.append(num_batch)
495
+ if last_metrics is not None: self.val_losses.append(last_metrics[0])
496
+ else: last_metrics = [] if self.no_val else [None]
497
+ if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])
498
+ self.format_stats([epoch, smooth_loss] + last_metrics)
499
+
500
+ def format_stats(self, stats:TensorOrNumList)->None:
501
+ "Format stats before printing."
502
+ str_stats = []
503
+ for name,stat in zip(self.names,stats):
504
+ str_stats.append('#na#' if stat is None else str(stat) if isinstance(stat, int) else f'{stat:.6f}')
505
+ if self.add_time: str_stats.append(format_time(time() - self.start_epoch))
506
+ if not self.silent: self.pbar.write(str_stats, table=True)
507
+
508
+ def add_metric_names(self, names):
509
+ "Add `names` to the inner metric names."
510
+ if hasattr(self, '_added_met_names'): self._added_met_names += names
511
+ else: self._added_met_names = names
512
+
513
+ def plot_lr(self, show_moms=False, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:
514
+ "Plot learning rate, `show_moms` to include momentum."
515
+ lrs = self._split_list(self.lrs, skip_start, skip_end)
516
+ iterations = self._split_list(range_of(self.lrs), skip_start, skip_end)
517
+ if show_moms:
518
+ moms = self._split_list(self.moms, skip_start, skip_end)
519
+ fig, axs = plt.subplots(1,2, figsize=(12,4))
520
+ axs[0].plot(iterations, lrs)
521
+ axs[0].set_xlabel('Iterations')
522
+ axs[0].set_ylabel('Learning Rate')
523
+ axs[1].plot(iterations, moms)
524
+ axs[1].set_xlabel('Iterations')
525
+ axs[1].set_ylabel('Momentum')
526
+ else:
527
+ fig, ax = plt.subplots()
528
+ ax.plot(iterations, lrs)
529
+ ax.set_xlabel('Iterations')
530
+ ax.set_ylabel('Learning Rate')
531
+ if ifnone(return_fig, defaults.return_fig): return fig
532
+ if not IN_NOTEBOOK: plot_sixel(fig)
533
+
534
+ @staticmethod
535
+ def smoothen_by_spline(xs, ys, **kwargs):
536
+ xs = np.arange(len(ys))
537
+ spl = scipy.interpolate.UnivariateSpline(xs, ys, **kwargs)
538
+ ys = spl(xs)
539
+ return ys
540
+
541
+ def plot(self, skip_start:int=10, skip_end:int=5, suggestion:bool=False, return_fig:bool=None,
542
+ **kwargs)->Optional[plt.Figure]:
543
+ "Plot learning rate and losses, trimmed between `skip_start` and `skip_end`. Optionally plot and return min gradient"
544
+ lrs = self._split_list(self.lrs, skip_start, skip_end)
545
+ losses = self._split_list(self.losses, skip_start, skip_end)
546
+ losses = [x.item() for x in losses]
547
+ if 'k' in kwargs: losses = self.smoothen_by_spline(lrs, losses, **kwargs)
548
+ fig, ax = plt.subplots(1,1)
549
+ ax.plot(lrs, losses)
550
+ ax.set_ylabel("Loss")
551
+ ax.set_xlabel("Learning Rate")
552
+ ax.set_xscale('log')
553
+ ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
554
+ if suggestion:
555
+ try: mg = (np.gradient(np.array(losses))).argmin()
556
+ except:
557
+ print("Failed to compute the gradients, there might not be enough points.")
558
+ return
559
+ print(f"Min numerical gradient: {lrs[mg]:.2E}")
560
+ ax.plot(lrs[mg],losses[mg],markersize=10,marker='o',color='red')
561
+ self.min_grad_lr = lrs[mg]
562
+ ml = np.argmin(losses)
563
+ print(f"Min loss divided by 10: {lrs[ml]/10:.2E}")
564
+ if ifnone(return_fig, defaults.return_fig): return fig
565
+ if not IN_NOTEBOOK: plot_sixel(fig)
566
+
567
+ def plot_losses(self, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:
568
+ "Plot training and validation losses."
569
+ fig, ax = plt.subplots(1,1)
570
+ losses = self._split_list(self.losses, skip_start, skip_end)
571
+ iterations = self._split_list(range_of(self.losses), skip_start, skip_end)
572
+ ax.plot(iterations, losses, label='Train')
573
+ val_iter = self._split_list_val(np.cumsum(self.nb_batches), skip_start, skip_end)
574
+ val_losses = self._split_list_val(self.val_losses, skip_start, skip_end)
575
+ ax.plot(val_iter, val_losses, label='Validation')
576
+ ax.set_ylabel('Loss')
577
+ ax.set_xlabel('Batches processed')
578
+ ax.legend()
579
+ if ifnone(return_fig, defaults.return_fig): return fig
580
+ if not IN_NOTEBOOK: plot_sixel(fig)
581
+
582
+ def plot_metrics(self, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:
583
+ "Plot metrics collected during training."
584
+ assert len(self.metrics) != 0, "There are no metrics to plot."
585
+ fig, axes = plt.subplots(len(self.metrics[0]),1,figsize=(6, 4*len(self.metrics[0])))
586
+ val_iter = self._split_list_val(np.cumsum(self.nb_batches), skip_start, skip_end)
587
+ axes = axes.flatten() if len(self.metrics[0]) != 1 else [axes]
588
+ for i, ax in enumerate(axes):
589
+ values = [met[i] for met in self.metrics]
590
+ values = self._split_list_val(values, skip_start, skip_end)
591
+ ax.plot(val_iter, values)
592
+ ax.set_ylabel(str(self.metrics_names[i]))
593
+ ax.set_xlabel('Batches processed')
594
+ if ifnone(return_fig, defaults.return_fig): return fig
595
+ if not IN_NOTEBOOK: plot_sixel(fig)
596
+
597
+ def _split_list(self, vals:Collection[float], skip_start:int, skip_end:int):
598
+ return vals[skip_start:-skip_end] if skip_end > 0 else vals[skip_start:]
599
+
600
+ def _split_list_val(self, vals:Collection[float], skip_start:int, skip_end:int):
601
+ val_iter = np.cumsum(self.nb_batches)
602
+ start_val = (val_iter - skip_start >= 0).nonzero()[0].min()
603
+ end_val = (val_iter[-1] - val_iter - skip_end >= 0).nonzero()[0].max()+1
604
+ return vals[start_val:end_val] if skip_end > 0 else vals[start_val:]
605
+
606
+ class FakeOptimizer():
607
+ def step(self): pass
608
+ def zero_grad(self): pass
609
+
610
+ def load_callback(class_func, state, learn:Learner):
611
+ init_kwargs, others = split_kwargs_by_func(state, class_func.__init__)
612
+ res = class_func(learn, **init_kwargs) if issubclass(class_func, LearnerCallback) else class_func(**init_kwargs)
613
+ for k,v in others.items(): setattr(res, k, v)
614
+ return res
615
+
616
+ def load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, **db_kwargs):
617
+ "Load a `Learner` object saved with `export_state` in `path/file` with empty data, optionally add `test` and load on `cpu`. `file` can be file-like (file or buffer)"
618
+ source = Path(path)/file if is_pathlike(file) else file
619
+ state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
620
+ model = state.pop('model')
621
+ src = LabelLists.load_state(path, state.pop('data'))
622
+ if test is not None: src.add_test(test)
623
+ data = src.databunch(**db_kwargs)
624
+ cb_state = state.pop('cb_state')
625
+ clas_func = state.pop('cls')
626
+ res = clas_func(data, model, **state)
627
+ res.callback_fns = state['callback_fns'] #to avoid duplicates
628
+ res.callbacks = [load_callback(c,s, res) for c,s in cb_state.items()]
629
+ return res
fastai/basics.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .basic_train import *
2
+ from .callback import *
3
+ from .core import *
4
+ from .basic_data import *
5
+ from .data_block import *
6
+ from .layers import *
7
+ from .metrics import *
8
+ from .torch_core import *
9
+ from .train import *
10
+ from .datasets import *
11
+ from .version import *
12
+ from . import callbacks
13
+
14
+ """
15
+ from . import core,torch_core,basic_data,basic_train,callback,data_block,layers,metrics,train,datasets,callbacks
16
+
17
+ __all__ = [o for o in dir(core) if not o.startswith('_')]
18
+ __all__ += [o for o in dir(torch_core) if not o.startswith('_')]
19
+ __all__ += [*basic_train.__all__, *callback.__all__, 'core', 'torch_core', 'callbacks',
20
+ *basic_data.__all__, *data_block.__all__, *layers.__all__, *metrics.__all__,
21
+ *train.__all__, *datasets.__all__, '__version__']
22
+ """
23
+
24
+ try: from .gen_doc.nbdoc import doc
25
+ except: pass # Optional if jupyter is present
26
+ #__all__.append('doc')
27
+
28
+ __all__ = [o for o in dir(sys.modules[__name__]) if not o.startswith('_')] + ['__version__']
29
+
fastai/callback.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Callbacks provides extensibility to the `basic_train` loop. See `train` for examples of custom callbacks."
2
+ from .basic_data import *
3
+ from .torch_core import *
4
+ import torch.distributed as dist
5
+
6
+ __all__ = ['AverageMetric', 'Callback', 'CallbackHandler', 'OptimWrapper', 'SmoothenValue', 'Scheduler', 'annealing_cos', 'CallbackList',
7
+ 'annealing_exp', 'annealing_linear', 'annealing_no', 'annealing_poly']
8
+
9
+ class OptimWrapper():
10
+ "Basic wrapper around `opt` to simplify hyper-parameters changes."
11
+ def __init__(self, opt:optim.Optimizer, wd:Floats=0., true_wd:bool=False, bn_wd:bool=True):
12
+ assert not isinstance(opt, OptimWrapper)
13
+ self.opt,self.true_wd,self.bn_wd = opt,true_wd,bn_wd
14
+ self.opt_keys = list(self.opt.param_groups[0].keys())
15
+ self.opt_keys.remove('params')
16
+ self.read_defaults()
17
+ self.wd = wd
18
+
19
+ @classmethod
20
+ def create(cls, opt_func:Union[type,Callable], lr:Union[float,Tuple,List], layer_groups:ModuleList, wd:Floats=0.,
21
+ true_wd:bool=False, bn_wd:bool=True)->optim.Optimizer:
22
+ "Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`."
23
+ split_params = split_no_wd_params(layer_groups)
24
+ opt = opt_func([{'params': p, 'lr':0} for p in split_params])
25
+ opt = cls(opt, wd=wd, true_wd=true_wd, bn_wd=bn_wd)
26
+ opt.lr,opt.opt_func = listify(lr, layer_groups),opt_func
27
+ return opt
28
+
29
+ def new(self, layer_groups:Collection[nn.Module], split_no_wd:bool=True):
30
+ "Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters."
31
+ opt_func = getattr(self, 'opt_func', self.opt.__class__)
32
+ res = self.create(opt_func, self.lr, layer_groups, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
33
+ res.mom,res.beta = self.mom,self.beta
34
+ return res
35
+
36
+ def new_with_params(self, param_groups:Collection[Collection[nn.Parameter]]):
37
+ "Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters."
38
+ opt_func = getattr(self, 'opt_func', self.opt.__class__)
39
+ opt = opt_func([{'params': p, 'lr':0} for p in param_groups])
40
+ opt = self.__class__(opt, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
41
+ opt.lr,opt.opt_func,opt.mom,opt.beta = self.lr,opt_func,self.mom,self.beta
42
+ return opt
43
+
44
+ def __repr__(self)->str:
45
+ return f'OptimWrapper over {repr(self.opt)}.\nTrue weight decay: {self.true_wd}'
46
+
47
+ #Pytorch optimizer methods
48
+ def step(self)->None:
49
+ "Set weight decay and step optimizer."
50
+ # weight decay outside of optimizer step (AdamW)
51
+ if self.true_wd:
52
+ for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
53
+ for p in pg1['params']: p.data.mul_(1 - wd*lr)
54
+ if self.bn_wd:
55
+ for p in pg2['params']: p.data.mul_(1 - wd*lr)
56
+ self.set_val('weight_decay', listify(0, self._wd))
57
+ self.opt.step()
58
+
59
+ def zero_grad(self)->None:
60
+ "Clear optimizer gradients."
61
+ self.opt.zero_grad()
62
+
63
+ #Passthrough to the inner opt.
64
+ def __getattr__(self, k:str)->Any: return getattr(self.opt, k, None)
65
+ def __setstate__(self,data:Any): self.__dict__.update(data)
66
+
67
+ def clear(self):
68
+ "Reset the state of the inner optimizer."
69
+ sd = self.state_dict()
70
+ sd['state'] = {}
71
+ self.load_state_dict(sd)
72
+
73
+ @property
74
+ def n_params(self): return sum([len(pg['params']) for pg in self.opt.param_groups])
75
+
76
+ #Hyperparameters as properties
77
+ @property
78
+ def lr(self)->float: return self._lr[-1]
79
+ @lr.setter
80
+ def lr(self, val:float)->None:
81
+ self._lr = self.set_val('lr', listify(val, self._lr))
82
+
83
+ @property
84
+ def mom(self)->float:return self._mom[-1]
85
+ @mom.setter
86
+ def mom(self, val:float)->None:
87
+ if 'momentum' in self.opt_keys: self.set_val('momentum', listify(val, self._mom))
88
+ elif 'betas' in self.opt_keys: self.set_val('betas', (listify(val, self._mom), self._beta))
89
+ self._mom = listify(val, self._mom)
90
+
91
+ @property
92
+ def beta(self)->float: return None if self._beta is None else self._beta[-1]
93
+ @beta.setter
94
+ def beta(self, val:float)->None:
95
+ "Set beta (or alpha as makes sense for given optimizer)."
96
+ if val is None: return
97
+ if 'betas' in self.opt_keys: self.set_val('betas', (self._mom, listify(val, self._beta)))
98
+ elif 'alpha' in self.opt_keys: self.set_val('alpha', listify(val, self._beta))
99
+ self._beta = listify(val, self._beta)
100
+
101
+ @property
102
+ def wd(self)->float: return self._wd[-1]
103
+ @wd.setter
104
+ def wd(self, val:float)->None:
105
+ "Set weight decay."
106
+ if not self.true_wd: self.set_val('weight_decay', listify(val, self._wd), bn_groups=self.bn_wd)
107
+ self._wd = listify(val, self._wd)
108
+
109
+ #Helper functions
110
+ def read_defaults(self)->None:
111
+ "Read the values inside the optimizer for the hyper-parameters."
112
+ self._beta = None
113
+ if 'lr' in self.opt_keys: self._lr = self.read_val('lr')
114
+ if 'momentum' in self.opt_keys: self._mom = self.read_val('momentum')
115
+ if 'alpha' in self.opt_keys: self._beta = self.read_val('alpha')
116
+ if 'betas' in self.opt_keys: self._mom,self._beta = self.read_val('betas')
117
+ if 'weight_decay' in self.opt_keys: self._wd = self.read_val('weight_decay')
118
+ reserved_names = ['params', 'lr', 'momentum', 'alpha', 'betas', 'weight_decay']
119
+ stat_names = [n for n in self.opt_keys if n not in reserved_names]
120
+ self._stats = {n:self.read_val(n) for n in stat_names}
121
+
122
+ def get_stat(self, name:str)->float:
123
+ if name in ['lr', 'mom', 'beta', 'wd']: return getattr(self, name)
124
+ else: return self._stats[name][-1]
125
+ def set_stat(self, name:str, value:Union[float, Collection[float]])->None:
126
+ if name in ['lr', 'mom', 'beta', 'wd']: setattr(self, name, value)
127
+ else:
128
+ val = listify(value, self._stats[name])
129
+ self.set_val(name, val)
130
+ self._stats[name] = val
131
+
132
+ def set_val(self, key:str, val:Any, bn_groups:bool=True)->Any:
133
+ "Set `val` inside the optimizer dictionary at `key`."
134
+ if is_tuple(val): val = [(v1,v2) for v1,v2 in zip(*val)]
135
+ for v,pg1,pg2 in zip(val,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
136
+ pg1[key] = v
137
+ if bn_groups: pg2[key] = v
138
+ return val
139
+
140
+ def read_val(self, key:str) -> Union[List[float],Tuple[List[float],List[float]]]:
141
+ "Read a hyperparameter `key` in the optimizer dictionary."
142
+ val = [pg[key] for pg in self.opt.param_groups[::2]]
143
+ if is_tuple(val[0]): val = [o[0] for o in val], [o[1] for o in val]
144
+ return val
145
+
146
+ def get_state(self):
147
+ "Return the inner state minus the layer groups."
148
+ return {'opt_state':self.opt.state_dict(), 'lr':self._lr, 'wd':self._wd, 'beta':self._beta, 'mom':self._mom,
149
+ 'opt_func':self.opt_func, 'true_wd':self.true_wd, 'bn_wd':self.bn_wd}
150
+
151
+ @classmethod
152
+ def load_with_state_and_layer_group(cls, state:dict, layer_groups:Collection[nn.Module]):
153
+ res = cls.create(state['opt_func'], state['lr'], layer_groups, wd=state['wd'], true_wd=state['true_wd'],
154
+ bn_wd=state['bn_wd'])
155
+ res._mom,res._beta = state['mom'],state['beta']
156
+ res.load_state_dict(state['opt_state'])
157
+ return res
158
+
159
+ class Callback():
160
+ "Base class for callbacks that want to record values, dynamically change learner params, etc."
161
+ _order=0
162
+ def on_train_begin(self, **kwargs:Any)->None:
163
+ "To initialize constants in the callback."
164
+ pass
165
+ def on_epoch_begin(self, **kwargs:Any)->None:
166
+ "At the beginning of each epoch."
167
+ pass
168
+ def on_batch_begin(self, **kwargs:Any)->None:
169
+ "Set HP before the output and loss are computed."
170
+ pass
171
+ def on_loss_begin(self, **kwargs:Any)->None:
172
+ "Called after forward pass but before loss has been computed."
173
+ pass
174
+ def on_backward_begin(self, **kwargs:Any)->None:
175
+ "Called after the forward pass and the loss has been computed, but before backprop."
176
+ pass
177
+ def on_backward_end(self, **kwargs:Any)->None:
178
+ "Called after backprop but before optimizer step. Useful for true weight decay in AdamW."
179
+ pass
180
+ def on_step_end(self, **kwargs:Any)->None:
181
+ "Called after the step of the optimizer but before the gradients are zeroed."
182
+ pass
183
+ def on_batch_end(self, **kwargs:Any)->None:
184
+ "Called at the end of the batch."
185
+ pass
186
+ def on_epoch_end(self, **kwargs:Any)->None:
187
+ "Called at the end of an epoch."
188
+ pass
189
+ def on_train_end(self, **kwargs:Any)->None:
190
+ "Useful for cleaning up things and saving files/models."
191
+ pass
192
+ def jump_to_epoch(self, epoch)->None:
193
+ "To resume training at `epoch` directly."
194
+ pass
195
+
196
+ def get_state(self, minimal:bool=True):
197
+ "Return the inner state of the `Callback`, `minimal` or not."
198
+ to_remove = ['exclude', 'not_min'] + getattr(self, 'exclude', []).copy()
199
+ if minimal: to_remove += getattr(self, 'not_min', []).copy()
200
+ return {k:v for k,v in self.__dict__.items() if k not in to_remove}
201
+
202
+ def __repr__(self):
203
+ attrs = func_args(self.__init__)
204
+ to_remove = getattr(self, 'exclude', [])
205
+ list_repr = [self.__class__.__name__] + [f'{k}: {getattr(self, k)}' for k in attrs if k != 'self' and k not in to_remove]
206
+ return '\n'.join(list_repr)
207
+
208
+ class SmoothenValue():
209
+ "Create a smooth moving average for a value (loss, etc) using `beta`."
210
+ def __init__(self, beta:float):
211
+ self.beta,self.n,self.mov_avg = beta,0,0
212
+
213
+ def add_value(self, val:float)->None:
214
+ "Add `val` to calculate updated smoothed value."
215
+ self.n += 1
216
+ self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
217
+ self.smooth = self.mov_avg / (1 - self.beta ** self.n)
218
+
219
+ CallbackList = Collection[Callback]
220
+
221
+ def _get_init_state(): return {'epoch':0, 'iteration':0, 'num_batch':0, 'skip_validate': False}
222
+
223
+ @dataclass
224
+ class CallbackHandler():
225
+ "Manage all of the registered `callbacks` and `metrics`, smoothing loss by momentum `beta`."
226
+ callbacks:CallbackList=None
227
+ metrics:CallbackList=None
228
+ beta:float=0.98
229
+
230
+ def __post_init__(self)->None:
231
+ "Initialize smoother and learning stats."
232
+ self.callbacks = ifnone(self.callbacks, [])
233
+ self.metrics = ifnone(self.metrics, [])
234
+ self.metrics = [(met if isinstance(met, Callback) else AverageMetric(met)) for met in self.metrics]
235
+ self.callbacks = sorted(self.callbacks, key=lambda o: getattr(o, '_order', 0))
236
+ self.smoothener = SmoothenValue(self.beta)
237
+ self.state_dict:Dict[str,Union[int,float,Tensor]]=_get_init_state()
238
+
239
+ def _call_and_update(self, cb, cb_name, **kwargs)->None:
240
+ "Call `cb_name` on `cb` and update the inner state."
241
+ new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
242
+ for k,v in new.items():
243
+ if k not in self.state_dict:
244
+ raise Exception(f"{k} isn't a valid key in the state of the callbacks.")
245
+ else: self.state_dict[k] = v
246
+
247
+ def __call__(self, cb_name, call_mets=True, **kwargs)->None:
248
+ "Call through to all of the `CallbakHandler` functions."
249
+ if call_mets:
250
+ for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
251
+ for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
252
+
253
+ def set_dl(self, dl:DataLoader):
254
+ "Set the current `dl` used."
255
+ if hasattr(self, 'cb_dl'): self.callbacks.remove(self.cb_dl)
256
+ if isinstance(dl.dataset, Callback):
257
+ self.callbacks.append(dl.dataset)
258
+ self.cb_dl = dl.dataset
259
+
260
+ def on_train_begin(self, epochs:int, pbar:PBar, metrics:MetricFuncList)->None:
261
+ "About to start learning."
262
+ self.state_dict = _get_init_state()
263
+ self.state_dict.update(dict(n_epochs=epochs, pbar=pbar, metrics=metrics))
264
+ names = [(met.name if hasattr(met, 'name') else camel2snake(met.__class__.__name__)) for met in self.metrics]
265
+ self('train_begin', metrics_names=names)
266
+ if self.state_dict['epoch'] != 0:
267
+ self.state_dict['pbar'].first_bar.total -= self.state_dict['epoch']
268
+ for cb in self.callbacks: cb.jump_to_epoch(self.state_dict['epoch'])
269
+
270
+ def on_epoch_begin(self)->None:
271
+ "Handle new epoch."
272
+ self.state_dict['num_batch'],self.state_dict['stop_training'] = 0,False
273
+ self('epoch_begin')
274
+
275
+ def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->Tuple[Any,Any]:
276
+ "Handle new batch `xb`,`yb` in `train` or validation."
277
+ self.state_dict.update(dict(last_input=xb, last_target=yb, train=train,
278
+ stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
279
+ self('batch_begin', mets = not self.state_dict['train'])
280
+ return self.state_dict['last_input'], self.state_dict['last_target']
281
+
282
+ def on_loss_begin(self, out:Tensor)->Any:
283
+ "Handle start of loss calculation with model output `out`."
284
+ self.state_dict['last_output'] = out
285
+ self('loss_begin', call_mets=False)
286
+ return self.state_dict['last_output']
287
+
288
+ def on_backward_begin(self, loss:Tensor)->Tuple[Any,Any]:
289
+ "Handle gradient calculation on `loss`."
290
+ self.smoothener.add_value(loss.detach().cpu())
291
+ self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth
292
+ self('backward_begin', call_mets=False)
293
+ return self.state_dict['last_loss'], self.state_dict['skip_bwd']
294
+
295
+ def on_backward_end(self)->Any:
296
+ "Handle end of gradient calculation."
297
+ self('backward_end', call_mets=False)
298
+ return self.state_dict['skip_step']
299
+
300
+ def on_step_end(self)->Any:
301
+ "Handle end of optimization step."
302
+ self('step_end', call_mets=False)
303
+ return self.state_dict['skip_zero']
304
+
305
+ def on_batch_end(self, loss:Tensor)->Any:
306
+ "Handle end of processing one batch with `loss`."
307
+ self.state_dict['last_loss'] = loss
308
+ self('batch_end', call_mets = not self.state_dict['train'])
309
+ if self.state_dict['train']:
310
+ self.state_dict['iteration'] += 1
311
+ self.state_dict['num_batch'] += 1
312
+ return self.state_dict['stop_epoch']
313
+
314
+ def on_epoch_end(self, val_loss:Tensor)->bool:
315
+ "Epoch is done, process `val_loss`."
316
+ self.state_dict['last_metrics'] = [val_loss] if val_loss is not None else [None]
317
+ self('epoch_end', call_mets = val_loss is not None)
318
+ self.state_dict['epoch'] += 1
319
+ return self.state_dict['stop_training']
320
+
321
+ def on_train_end(self, exception:Union[bool,Exception])->None:
322
+ "Handle end of training, `exception` is an `Exception` or False if no exceptions during training."
323
+ self('train_end', exception=exception)
324
+
325
+ @property
326
+ def skip_validate(self): return self.state_dict['skip_validate']
327
+
328
+ class AverageMetric(Callback):
329
+ "Wrap a `func` in a callback for metrics computation."
330
+ def __init__(self, func):
331
+ # If func has a __name__ use this one else it should be a partial
332
+ name = func.__name__ if hasattr(func, '__name__') else func.func.__name__
333
+ self.func, self.name = func, name
334
+ self.world = num_distrib()
335
+
336
+ def on_epoch_begin(self, **kwargs):
337
+ "Set the inner value to 0."
338
+ self.val, self.count = 0.,0
339
+
340
+ def on_batch_end(self, last_output, last_target, **kwargs):
341
+ "Update metric computation with `last_output` and `last_target`."
342
+ if not is_listy(last_target): last_target=[last_target]
343
+ self.count += first_el(last_target).size(0)
344
+ val = self.func(last_output, *last_target)
345
+ if self.world:
346
+ val = val.clone()
347
+ dist.all_reduce(val, op=dist.ReduceOp.SUM)
348
+ val /= self.world
349
+ self.val += first_el(last_target).size(0) * val.detach().cpu()
350
+
351
+ def on_epoch_end(self, last_metrics, **kwargs):
352
+ "Set the final result in `last_metrics`."
353
+ return add_metrics(last_metrics, self.val/self.count)
354
+
355
+ def annealing_no(start:Number, end:Number, pct:float)->Number:
356
+ "No annealing, always return `start`."
357
+ return start
358
+ def annealing_linear(start:Number, end:Number, pct:float)->Number:
359
+ "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
360
+ return start + pct * (end-start)
361
+ def annealing_exp(start:Number, end:Number, pct:float)->Number:
362
+ "Exponentially anneal from `start` to `end` as pct goes from 0.0 to 1.0."
363
+ return start * (end/start) ** pct
364
+ def annealing_cos(start:Number, end:Number, pct:float)->Number:
365
+ "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
366
+ cos_out = np.cos(np.pi * pct) + 1
367
+ return end + (start-end)/2 * cos_out
368
+
369
+ def do_annealing_poly(start:Number, end:Number, pct:float, degree:Number)->Number:
370
+ "Helper function for `anneal_poly`."
371
+ return end + (start-end) * (1-pct)**degree
372
+ def annealing_poly(degree:Number)->Number:
373
+ "Anneal polynomically from `start` to `end` as pct goes from 0.0 to 1.0."
374
+ return functools.partial(do_annealing_poly, degree=degree)
375
+
376
+ class Scheduler():
377
+ "Used to \"step\" from start,end (`vals`) over `n_iter` iterations on a schedule defined by `func`"
378
+ def __init__(self, vals:StartOptEnd, n_iter:int, func:Optional[AnnealFunc]=None):
379
+ self.start,self.end = (vals[0],vals[1]) if is_tuple(vals) else (vals,0)
380
+ self.n_iter = max(1,n_iter)
381
+ if func is None: self.func = annealing_linear if is_tuple(vals) else annealing_no
382
+ else: self.func = func
383
+ self.n = 0
384
+
385
+ def restart(self): self.n = 0
386
+
387
+ def step(self)->Number:
388
+ "Return next value along annealed schedule."
389
+ self.n += 1
390
+ return self.func(self.start, self.end, self.n/self.n_iter)
391
+
392
+ @property
393
+ def is_done(self)->bool:
394
+ "Return `True` if schedule completed."
395
+ return self.n >= self.n_iter
396
+
fastai/callbacks/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .lr_finder import *
2
+ from .one_cycle import *
3
+ from .fp16 import *
4
+ from .general_sched import *
5
+ from .hooks import *
6
+ from .mixup import *
7
+ from .rnn import *
8
+ from .tracker import *
9
+ from .csv_logger import *
10
+ from .loss_metrics import *
11
+ from .oversampling import *
fastai/callbacks/csv_logger.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "A `Callback` that saves tracked metrics into a persistent file."
2
+ #Contribution from devforfu: https://nbviewer.jupyter.org/gist/devforfu/ea0b3fcfe194dad323c3762492b05cae
3
+ from ..torch_core import *
4
+ from ..basic_data import DataBunch
5
+ from ..callback import *
6
+ from ..basic_train import Learner, LearnerCallback
7
+ from time import time
8
+ from fastprogress.fastprogress import format_time
9
+
10
+ __all__ = ['CSVLogger']
11
+
12
+ class CSVLogger(LearnerCallback):
13
+ "A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."
14
+ def __init__(self, learn:Learner, filename: str = 'history', append: bool = False):
15
+ super().__init__(learn)
16
+ self.filename,self.path,self.append = filename,self.learn.path/f'{filename}.csv',append
17
+ self.add_time = True
18
+
19
+ def read_logged_file(self):
20
+ "Read the content of saved file"
21
+ return pd.read_csv(self.path)
22
+
23
+ def on_train_begin(self, **kwargs: Any) -> None:
24
+ "Prepare file with metric names."
25
+ self.path.parent.mkdir(parents=True, exist_ok=True)
26
+ self.file = self.path.open('a') if self.append else self.path.open('w')
27
+ self.file.write(','.join(self.learn.recorder.names[:(None if self.add_time else -1)]) + '\n')
28
+
29
+ def on_epoch_begin(self, **kwargs:Any)->None:
30
+ if self.add_time: self.start_epoch = time()
31
+
32
+ def on_epoch_end(self, epoch: int, smooth_loss: Tensor, last_metrics: MetricsList, **kwargs: Any) -> bool:
33
+ "Add a line with `epoch` number, `smooth_loss` and `last_metrics`."
34
+ last_metrics = ifnone(last_metrics, [])
35
+ stats = [str(stat) if isinstance(stat, int) else '#na#' if stat is None else f'{stat:.6f}'
36
+ for name, stat in zip(self.learn.recorder.names, [epoch, smooth_loss] + last_metrics)]
37
+ if self.add_time: stats.append(format_time(time() - self.start_epoch))
38
+ str_stats = ','.join(stats)
39
+ self.file.write(str_stats + '\n')
40
+
41
+ def on_train_end(self, **kwargs: Any) -> None:
42
+ "Close the file."
43
+ self.file.close()
fastai/callbacks/fp16.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Callback support for half precision (fp16) training. Increases training speed."
2
+ from ..torch_core import *
3
+ from ..callback import *
4
+ from ..basic_train import *
5
+ from torch._utils import _unflatten_dense_tensors
6
+ from torch.nn.utils import parameters_to_vector
7
+
8
+ __all__ = ['MixedPrecision']
9
+
10
+ def get_master(layer_groups:ModuleList, flat_master:bool=False) -> Tuple[List[List[Tensor]], List[List[Tensor]]]:
11
+ "Return two lists, one for the model parameters in FP16 and one for the master parameters in FP32."
12
+ split_params = split_no_wd_params(layer_groups)
13
+ model_params = [[param for param in pg if param.requires_grad] for pg in split_params]
14
+ if flat_master:
15
+ master_params = []
16
+ for lg in model_params:
17
+ if len(lg) !=0 :
18
+ mp = parameters_to_vector([param.data.float() for param in lg])
19
+ mp = torch.nn.Parameter(mp, requires_grad=True)
20
+ if mp.grad is None: mp.grad = mp.new(*mp.size())
21
+ master_params.append([mp])
22
+ else: master_params.append([])
23
+ return model_params, master_params
24
+ else:
25
+ master_params = [[param.clone().float().detach() for param in lg] for lg in model_params]
26
+ for mp in master_params:
27
+ for param in mp: param.requires_grad = True
28
+ return model_params, master_params
29
+
30
+ def model_g2master_g(model_params:Sequence[Tensor], master_params:Sequence[Tensor], flat_master:bool=False)->None:
31
+ "Copy the `model_params` gradients to `master_params` for the optimizer step."
32
+ if flat_master:
33
+ for model_group,master_group in zip(model_params,master_params):
34
+ if len(master_group) != 0:
35
+ if master_group[0].grad is None: master_group[0].grad = master_group[0].data.new(*master_group[0].data.size())
36
+ master_group[0].grad.data.copy_(parameters_to_vector([p.grad.data.float() for p in model_group]))
37
+ else:
38
+ for model_group,master_group in zip(model_params,master_params):
39
+ for model, master in zip(model_group, master_group):
40
+ if model.grad is not None:
41
+ if master.grad is None: master.grad = master.data.new(*master.data.size())
42
+ master.grad.data.copy_(model.grad.data)
43
+ else: master.grad = None
44
+
45
+ def master2model(model_params:Sequence[Tensor], master_params:Sequence[Tensor], flat_master:bool=False)->None:
46
+ "Copy `master_params` to `model_params`."
47
+ if flat_master:
48
+ for model_group,master_group in zip(model_params,master_params):
49
+ if len(model_group) != 0:
50
+ for model, master in zip(model_group, _unflatten_dense_tensors(master_group[0].data, model_group)):
51
+ model.data.copy_(master)
52
+ else:
53
+ for model_group,master_group in zip(model_params,master_params):
54
+ for model, master in zip(model_group, master_group): model.data.copy_(master.data)
55
+
56
+ def grad_overflow(param_group):
57
+ for group in param_group:
58
+ for p in group:
59
+ if p.grad is not None:
60
+ s = float(p.grad.data.float().sum())
61
+ if s == float('inf') or s == float('-inf') or s != s: return True
62
+ return False
63
+
64
+ class MixedPrecision(LearnerCallback):
65
+ _order = 999 #Need to run after things that could call on_backward_begin and change the loss
66
+ "Callback that handles mixed-precision training."
67
+ def __init__(self, learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,
68
+ flat_master:bool=False, max_scale:float=2**24):
69
+ super().__init__(learn)
70
+ self.flat_master,self.dynamic,self.max_noskip,self.clip,self.max_scale = flat_master,dynamic,max_noskip,clip,max_scale
71
+ self.loss_scale = ifnone(loss_scale, 2**16 if dynamic else 512)
72
+ self.not_min += ['model_params', 'master_params']
73
+ assert torch.backends.cudnn.enabled, "Mixed precision training requires cudnn."
74
+ self.opt = None
75
+
76
+ def on_train_begin(self, **kwargs:Any)->None:
77
+ "Prepare the master model."
78
+ #Get a copy of the model params in FP32
79
+ self.model_params, self.master_params = get_master(self.learn.layer_groups, self.flat_master)
80
+ #Changes the optimizer so that the optimization step is done in FP32.
81
+ new_opt = self.learn.opt.new_with_params(self.master_params)
82
+ if self.opt is not None:
83
+ self.opt.lr,self.opt.wd = self.learn.opt.lr,self.learn.opt.wd
84
+ new_opt.load_state_dict(self.opt)
85
+ self.learn.opt.opt = new_opt.opt
86
+ self.noskip = 0
87
+
88
+ def on_loss_begin(self, last_output:Tensor, **kwargs:Any) -> Tensor:
89
+ "Convert half precision output to FP32 to avoid reduction overflow."
90
+ return {'last_output': to_float(last_output)}
91
+
92
+ def on_backward_begin(self, last_loss:Rank0Tensor, **kwargs:Any) -> Rank0Tensor:
93
+ "Scale gradients up by `self.loss_scale` to prevent underflow."
94
+ #To avoid gradient underflow, we scale the gradients
95
+ ret_loss = last_loss * self.loss_scale
96
+ return {'last_loss': ret_loss}
97
+
98
+ def on_backward_end(self, **kwargs:Any)->None:
99
+ "Convert the gradients back to FP32 and divide them by the scale."
100
+ if self.dynamic and grad_overflow(self.model_params) and self.loss_scale > 1:
101
+ self.loss_scale /= 2
102
+ self.noskip = 0
103
+ #The step will be skipped since we don't update the master grads so they are all None or zero
104
+ else:
105
+ model_g2master_g(self.model_params, self.master_params, self.flat_master)
106
+ for group in self.master_params:
107
+ for param in group:
108
+ if param.grad is not None: param.grad.div_(self.loss_scale)
109
+ if self.clip is not None:
110
+ for group in self.master_params: nn.utils.clip_grad_norm_(group, self.clip)
111
+ if not self.dynamic: return
112
+ self.noskip += 1
113
+ if self.noskip >= self.max_noskip and self.loss_scale < self.max_scale:
114
+ self.loss_scale *= 2
115
+ self.noskip = 0
116
+
117
+ def on_step_end(self, **kwargs:Any)->None:
118
+ "Update the params from master to model and zero grad."
119
+ #Zeros the gradients of the model since the optimizer is disconnected.
120
+ self.learn.model.zero_grad()
121
+ #Update the params from master to model.
122
+ master2model(self.model_params, self.master_params, self.flat_master)
fastai/callbacks/general_sched.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..core import *
2
+ from ..callback import *
3
+ from ..basic_train import Learner, LearnerCallback
4
+
5
+ __all__ = ['GeneralScheduler', 'TrainingPhase']
6
+
7
+ @dataclass
8
+ class TrainingPhase():
9
+ "Schedule hyper-parameters for a phase of `length` iterations."
10
+ length:int
11
+
12
+ def __post_init__(self): self.scheds = dict()
13
+ def schedule_hp(self, name, vals, anneal=None):
14
+ "Adds a schedule for `name` between `vals` using `anneal`."
15
+ self.scheds[name] = Scheduler(vals, self.length, anneal)
16
+ return self
17
+
18
+ class GeneralScheduler(LearnerCallback):
19
+ "Schedule multiple `TrainingPhase` for a `Learner`."
20
+ def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):
21
+ super().__init__(learn)
22
+ self.phases,self.start_epoch = phases,start_epoch
23
+
24
+ def on_train_begin(self, epoch:int, **kwargs:Any)->None:
25
+ "Initialize the schedulers for training."
26
+ res = {'epoch':self.start_epoch} if self.start_epoch is not None else None
27
+ self.start_epoch = ifnone(self.start_epoch, epoch)
28
+ self.scheds = [p.scheds for p in self.phases]
29
+ self.opt = self.learn.opt
30
+ for k,v in self.scheds[0].items():
31
+ v.restart()
32
+ self.opt.set_stat(k, v.start)
33
+ self.idx_s = 0
34
+ return res
35
+
36
+ def jump_to_epoch(self, epoch:int)->None:
37
+ for _ in range(len(self.learn.data.train_dl) * epoch):
38
+ self.on_batch_end(True)
39
+
40
+ def on_batch_end(self, train, **kwargs:Any)->None:
41
+ "Take a step in lr,mom sched, start next stepper when the current one is complete."
42
+ if train:
43
+ if self.idx_s >= len(self.scheds): return {'stop_training': True, 'stop_epoch': True}
44
+ sched = self.scheds[self.idx_s]
45
+ for k,v in sched.items(): self.opt.set_stat(k, v.step())
46
+ if list(sched.values())[0].is_done: self.idx_s += 1
fastai/callbacks/hooks.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Hooks provide extensibility at the model level."
2
+ from ..torch_core import *
3
+ from ..callback import *
4
+ from ..basic_train import *
5
+ from ..basic_data import *
6
+
7
+ __all__ = ['ActivationStats', 'Hook', 'HookCallback', 'Hooks', 'hook_output', 'hook_outputs',
8
+ 'model_sizes', 'num_features_model', 'model_summary', 'dummy_eval', 'dummy_batch']
9
+
10
+ class Hook():
11
+ "Create a hook on `m` with `hook_func`."
12
+ def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
13
+ self.hook_func,self.detach,self.stored = hook_func,detach,None
14
+ f = m.register_forward_hook if is_forward else m.register_backward_hook
15
+ self.hook = f(self.hook_fn)
16
+ self.removed = False
17
+
18
+ def hook_fn(self, module:nn.Module, input:Tensors, output:Tensors):
19
+ "Applies `hook_func` to `module`, `input`, `output`."
20
+ if self.detach:
21
+ input = (o.detach() for o in input ) if is_listy(input ) else input.detach()
22
+ output = (o.detach() for o in output) if is_listy(output) else output.detach()
23
+ self.stored = self.hook_func(module, input, output)
24
+
25
+ def remove(self):
26
+ "Remove the hook from the model."
27
+ if not self.removed:
28
+ self.hook.remove()
29
+ self.removed=True
30
+
31
+ def __enter__(self, *args): return self
32
+ def __exit__(self, *args): self.remove()
33
+
34
+ class Hooks():
35
+ "Create several hooks on the modules in `ms` with `hook_func`."
36
+ def __init__(self, ms:Collection[nn.Module], hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
37
+ self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]
38
+
39
+ def __getitem__(self,i:int)->Hook: return self.hooks[i]
40
+ def __len__(self)->int: return len(self.hooks)
41
+ def __iter__(self): return iter(self.hooks)
42
+ @property
43
+ def stored(self): return [o.stored for o in self]
44
+
45
+ def remove(self):
46
+ "Remove the hooks from the model."
47
+ for h in self.hooks: h.remove()
48
+
49
+ def __enter__(self, *args): return self
50
+ def __exit__ (self, *args): self.remove()
51
+
52
+ def _hook_inner(m,i,o): return o if isinstance(o,Tensor) else o if is_listy(o) else list(o)
53
+
54
+ def hook_output (module:nn.Module, detach:bool=True, grad:bool=False)->Hook:
55
+ "Return a `Hook` that stores activations of `module` in `self.stored`"
56
+ return Hook(module, _hook_inner, detach=detach, is_forward=not grad)
57
+
58
+ def hook_outputs(modules:Collection[nn.Module], detach:bool=True, grad:bool=False)->Hooks:
59
+ "Return `Hooks` that store activations of all `modules` in `self.stored`"
60
+ return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)
61
+
62
+ class HookCallback(LearnerCallback):
63
+ "Callback that can be used to register hooks on `modules`. Implement the corresponding function in `self.hook`."
64
+ def __init__(self, learn:Learner, modules:Sequence[nn.Module]=None, do_remove:bool=True):
65
+ super().__init__(learn)
66
+ self.modules,self.do_remove = modules,do_remove
67
+
68
+ def on_train_begin(self, **kwargs):
69
+ "Register the `Hooks` on `self.modules`."
70
+ if not self.modules:
71
+ self.modules = [m for m in flatten_model(self.learn.model)
72
+ if hasattr(m, 'weight')]
73
+ self.hooks = Hooks(self.modules, self.hook)
74
+
75
+ def on_train_end(self, **kwargs):
76
+ "Remove the `Hooks`."
77
+ if self.do_remove: self.remove()
78
+
79
+ def remove(self):
80
+ if getattr(self, 'hooks', None): self.hooks.remove()
81
+ def __del__(self): self.remove()
82
+
83
+ class ActivationStats(HookCallback):
84
+ "Callback that record the mean and std of activations."
85
+
86
+ def on_train_begin(self, **kwargs):
87
+ "Initialize stats."
88
+ super().on_train_begin(**kwargs)
89
+ self.stats = []
90
+
91
+ def hook(self, m:nn.Module, i:Tensors, o:Tensors)->Tuple[Rank0Tensor,Rank0Tensor]:
92
+ "Take the mean and std of `o`."
93
+ return o.mean().item(),o.std().item()
94
+ def on_batch_end(self, train, **kwargs):
95
+ "Take the stored results and puts it in `self.stats`"
96
+ if train: self.stats.append(self.hooks.stored)
97
+ def on_train_end(self, **kwargs):
98
+ "Polish the final result."
99
+ super().on_train_end(**kwargs)
100
+ self.stats = tensor(self.stats).permute(2,1,0)
101
+
102
+ def dummy_batch(m: nn.Module, size:tuple=(64,64))->Tensor:
103
+ "Create a dummy batch to go through `m` with `size`."
104
+ ch_in = in_channels(m)
105
+ return one_param(m).new(1, ch_in, *size).requires_grad_(False).uniform_(-1.,1.)
106
+
107
+ def dummy_eval(m:nn.Module, size:tuple=(64,64)):
108
+ "Pass a `dummy_batch` in evaluation mode in `m` with `size`."
109
+ m.eval()
110
+ return m(dummy_batch(m, size))
111
+ #return m.eval()(dummy_batch(m, size))
112
+
113
+ def model_sizes(m:nn.Module, size:tuple=(64,64))->Tuple[Sizes,Tensor,Hooks]:
114
+ "Pass a dummy input through the model `m` to get the various sizes of activations."
115
+ with hook_outputs(m) as hooks:
116
+ x = dummy_eval(m, size)
117
+ return [o.stored.shape for o in hooks]
118
+
119
+ def num_features_model(m:nn.Module)->int:
120
+ "Return the number of output features for `model`."
121
+ sz = 64
122
+ while True:
123
+ try: return model_sizes(m, size=(sz,sz))[-1][1]
124
+ except Exception as e:
125
+ sz *= 2
126
+ if sz > 2048: raise
127
+
128
+ def total_params(m:nn.Module)->int:
129
+ params, trainable = 0, False
130
+ if hasattr(m, "weight") and hasattr(m.weight, "size"):
131
+ params += m.weight.numel()
132
+ trainable = m.weight.requires_grad
133
+ if hasattr(m, "bias") and hasattr(m.bias, "size"): params += m.bias.numel()
134
+ return params, trainable
135
+
136
+ def hook_params(modules:Collection[nn.Module])->Hooks:
137
+ return Hooks(modules, lambda m, i, o: total_params(m))
138
+
139
+ def params_size(m: Union[nn.Module,Learner], size: tuple = (3, 64, 64))->Tuple[Sizes, Tensor, Hooks]:
140
+ "Pass a dummy input through the model to get the various sizes. Returns (res,x,hooks) if `full`"
141
+ if isinstance(m, Learner):
142
+ if m.data.is_empty:
143
+ raise Exception("This is an empty `Learner` and `Learner.summary` requires some data to pass through the model.")
144
+ ds_type = DatasetType.Train if m.data.train_dl else (DatasetType.Valid if m.data.valid_dl else DatasetType.Test)
145
+ x = m.data.one_batch(ds_type=ds_type, detach=False, denorm=False)[0]
146
+ x = [o[:1] for o in x] if is_listy(x) else x[:1]
147
+ m = m.model
148
+ elif isinstance(m, nn.Module): x = next(m.parameters()).new(1, *size)
149
+ else: raise TypeError('You should either pass in a Learner or nn.Module')
150
+ with hook_outputs(flatten_model(m)) as hook_o:
151
+ with hook_params(flatten_model(m))as hook_p:
152
+ x = m.eval()(*x) if is_listy(x) else m.eval()(x)
153
+ output_size = [((o.stored.shape[1:]) if o.stored is not None else None) for o in hook_o]
154
+ params = [(o.stored if o.stored is not None else (None,None)) for o in hook_p]
155
+ params, trainables = map(list,zip(*params))
156
+ return output_size, params, trainables
157
+
158
+ def get_layer_name(layer:nn.Module)->str:
159
+ return str(layer.__class__).split(".")[-1].split("'")[0]
160
+
161
+ def layers_info(m:Collection[nn.Module]) -> Collection[namedtuple]:
162
+ func = lambda m:list(map(get_layer_name, flatten_model(m)))
163
+ layers_names = func(m.model) if isinstance(m, Learner) else func(m)
164
+ layers_sizes, layers_params, layers_trainable = params_size(m)
165
+ layer_info = namedtuple('Layer_Information', ['Layer', 'OutputSize', 'Params', 'Trainable'])
166
+ return list(map(layer_info, layers_names, layers_sizes, layers_params, layers_trainable))
167
+
168
+ def model_summary(m:Learner, n:int=70):
169
+ "Print a summary of `m` using a output text width of `n` chars"
170
+ info = layers_info(m)
171
+ header = ["Layer (type)", "Output Shape", "Param #", "Trainable"]
172
+ res = m.model.__class__.__name__ + "\n"
173
+ res += "=" * n + "\n"
174
+ res += f"{header[0]:<20} {header[1]:<20} {header[2]:<10} {header[3]:<10}\n"
175
+ res += "=" * n + "\n"
176
+ total_params = 0
177
+ total_trainable_params = 0
178
+ for layer, size, params, trainable in info:
179
+ if size is None: continue
180
+ total_params += int(params)
181
+ total_trainable_params += int(params) * trainable
182
+ size, trainable = str(list(size)), str(trainable)
183
+ res += f"{layer:<20} {size:<20} {int(params):<10,} {trainable:<10}\n"
184
+ res += "_" * n + "\n"
185
+ res += f"\nTotal params: {total_params:,}\n"
186
+ res += f"Total trainable params: {total_trainable_params:,}\n"
187
+ res += f"Total non-trainable params: {total_params - total_trainable_params:,}\n"
188
+
189
+ res += f"Optimized with {str(m.opt_func)[25:-1].replace('>', '')}\n"
190
+ if m.true_wd: res += f"Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ \n"
191
+ if "wd" in str(m.opt_func) or "weight_decay" in str(m.opt_func): res += f"\x1b[1;31m Specifying weight decay in the optimizer has no effect, Learner will overwrite \x1b[0m \n"
192
+ if "lr" in str(m.opt_func) or "learning_rate" in str(m.opt_func): res += f"\x1b[1;31m Specifying lr in the optimizer has no effect, pass it to fit or the defaults.lr will apply \x1b[0m \n"
193
+ res += f"Loss function : {m.loss_func.__class__.__name__}\n"
194
+ res += "=" * n + "\n"
195
+ res += "Callbacks functions applied \n"
196
+ res += "\n".join([f" {cbs.__class__.__name__}" for cbs in m.callbacks])
197
+
198
+ return PrettyString(res)
199
+
200
+ Learner.summary = model_summary
fastai/callbacks/loss_metrics.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..torch_core import *
2
+ from ..callback import *
3
+ from ..basic_train import Learner, LearnerCallback
4
+
5
+ __all__ = ['LossMetrics']
6
+
7
+ class LossMetrics(LearnerCallback):
8
+ "Add `loss_func.metrics` to metrics named by `loss_func.metric_names`"
9
+ _order = -20 #Needs to run before the recorder
10
+
11
+ def on_train_begin(self, **kwargs):
12
+ "Add the metrics names to the `Recorder`."
13
+ self.names = ifnone(self.learn.loss_func.metric_names, [])
14
+ if not self.names: warn('LossMetrics requested but no loss_func.metric_names provided')
15
+ self.learn.recorder.add_metric_names(self.names)
16
+
17
+ def on_epoch_begin(self, **kwargs):
18
+ "Initialize the metrics for this epoch."
19
+ self.metrics = {name:0. for name in self.names}
20
+ self.nums = 0
21
+
22
+ def on_batch_end(self, last_target, train, **kwargs):
23
+ "Update the metrics if not `train`"
24
+ if train: return
25
+ bs = last_target.size(0)
26
+ for name in self.names:
27
+ self.metrics[name] += bs * self.learn.loss_func.metrics[name].detach().cpu()
28
+ self.nums += bs
29
+
30
+ def on_epoch_end(self, last_metrics, **kwargs):
31
+ "Finish the computation and sends the result to the Recorder."
32
+ if not self.nums: return
33
+ metrics = [self.metrics[name]/self.nums for name in self.names]
34
+ return {'last_metrics': last_metrics+metrics}