diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b35e164c0a6a72efe2270409ea6e0e78098e33a9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,141 @@ +log/ +output/ +.vscode/ +workspace/ +tmp_occupy_memory_saves/ +run*.sh +py-thin-plate-spline +wandb/ +pretrain/ +Pytorch-Correlation-extension/ +result + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..cbe5ad1670406e4402217edfb82d2c56af7e8631 --- /dev/null +++ b/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. 
The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. 
BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. 
produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. 
identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. 
Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dataset/range_transform.py b/dataset/range_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..990bd338f661f443145d10954dc5d2074a01f07f --- /dev/null +++ b/dataset/range_transform.py @@ -0,0 +1,44 @@ +import torchvision.transforms as transforms +import util.functional as F +import numpy as np +from skimage import color + +im_mean = (124, 116, 104) + +im_normalization = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + +inv_im_trans = transforms.Normalize( + mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], + std=[1/0.229, 1/0.224, 1/0.225]) + +# tensor l[-1, 1] ab[-1, 1] +# numpy l[0 100] ab[-127 128] +# transforms.Normalize: x_new = (x-mean) / std +inv_lll2rgb_trans = transforms.Normalize( + mean=[-1, 0, 0], + std=[1/50., 1/110., 1/110.]) + +im_rgb2lab_normalization = transforms.Normalize( + mean=[50, 0, 0], + std=[50, 110, 110]) + +class ToTensor(object): + def __init__(self): + pass + + def __call__(self, inputs): + return F.to_mytensor(inputs) + +class RGB2Lab(object): + def __init__(self): + pass + + def __call__(self, inputs): + # default return float64 + # return color.rgb2lab(inputs) + + # return float32 + return np.float32(color.rgb2lab(inputs)) \ No newline at end of file diff --git a/dataset/reseed.py b/dataset/reseed.py new file mode 100644 index 0000000000000000000000000000000000000000..600c998fa33485c073af7f9e13e885350a5c6940 --- /dev/null +++ b/dataset/reseed.py @@ -0,0 +1,6 @@ +import torch +import random + +def reseed(seed): + random.seed(seed) + torch.manual_seed(seed) \ No newline at end of file diff --git a/dataset/static_dataset.py b/dataset/static_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f487cb2fe6b34d0db799f5fa7ab9a6459fdfbc26 --- /dev/null +++ b/dataset/static_dataset.py @@ -0,0 +1,179 @@ +import os +from os import path + +import torch +from torch.utils.data.dataset import Dataset +from torchvision import 
transforms
+from torchvision.transforms import InterpolationMode
+from PIL import Image
+import numpy as np
+
+from dataset.range_transform import im_normalization, im_mean
+from dataset.tps import random_tps_warp
+from dataset.reseed import reseed
+
+
+class StaticTransformDataset(Dataset):
+    """
+    Generate pseudo VOS data by applying random transforms on static images.
+    Single-object only.
+
+    Method 0 - FSS style (class/1.jpg class/1.png)
+    Method 1 - Others style (XXX.jpg XXX.png)
+    """
+    def __init__(self, parameters, num_frames=3, max_num_obj=1):
+        self.num_frames = num_frames
+        self.max_num_obj = max_num_obj
+
+        self.im_list = []
+        for parameter in parameters:
+            root, method, multiplier = parameter
+            if method == 0:
+                # Get images
+                classes = os.listdir(root)
+                for c in classes:
+                    imgs = os.listdir(path.join(root, c))
+                    jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
+
+                    joint_list = [path.join(root, c, im) for im in jpg_list]
+                    self.im_list.extend(joint_list * multiplier)
+
+            elif method == 1:
+                self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)
+
+        print(f'{len(self.im_list)} images found.')
+
+        # These transforms are shared within each im/gt pair, but differ among the 3 sampled frames
+        self.pair_im_lone_transform = transforms.Compose([
+            transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic
+        ])
+
+        self.pair_im_dual_transform = transforms.Compose([
+            transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean),
+            transforms.Resize(384, InterpolationMode.BICUBIC),
+            transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean),
+        ])
+
+        self.pair_gt_dual_transform = transforms.Compose([
+            transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0),
+            transforms.Resize(384, InterpolationMode.NEAREST),
+            transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0),
+        ])
+
+
+        # These transforms are the same for all pairs in the sampled sequence
+        self.all_im_lone_transform = transforms.Compose([
+            transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
+            transforms.RandomGrayscale(0.05),
+        ])
+
+        self.all_im_dual_transform = transforms.Compose([
+            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
+            transforms.RandomHorizontalFlip(),
+        ])
+
+        self.all_gt_dual_transform = transforms.Compose([
+            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
+            transforms.RandomHorizontalFlip(),
+        ])
+
+        # Final transform without randomness
+        self.final_im_transform = transforms.Compose([
+            transforms.ToTensor(),
+            im_normalization,
+        ])
+
+        self.final_gt_transform = transforms.Compose([
+            transforms.ToTensor(),
+        ])
+
+    def _get_sample(self, idx):
+        im = Image.open(self.im_list[idx]).convert('RGB')
+        gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
+
+        sequence_seed = np.random.randint(2147483647)
+
+        images = []
+        masks = []
+        for _ in range(self.num_frames):
+            reseed(sequence_seed)
+            this_im = self.all_im_dual_transform(im)
+            this_im = self.all_im_lone_transform(this_im)
+            reseed(sequence_seed)
+            this_gt = self.all_gt_dual_transform(gt)
+
+            pairwise_seed = np.random.randint(2147483647)
+            reseed(pairwise_seed)
+            this_im = self.pair_im_dual_transform(this_im)
+            this_im = self.pair_im_lone_transform(this_im)
+            reseed(pairwise_seed)
+            this_gt = self.pair_gt_dual_transform(this_gt)
+
+            # Use TPS only some of the time
+            # Not because TPS is bad -- just that it is too slow and I need to speed up data loading
+            if np.random.rand() < 0.33:
+                this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)
+
+            this_im = self.final_im_transform(this_im)
+            this_gt = self.final_gt_transform(this_gt)
+
+            images.append(this_im)
+            masks.append(this_gt)
+
+        images = torch.stack(images, 0)
+        masks = torch.stack(masks, 0)
+
+        return images, masks.numpy()
+
+    def __getitem__(self, idx):
+        additional_objects = np.random.randint(self.max_num_obj)
+        indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]
+
+        merged_images = None
+        # np.int was removed in NumPy 1.24; use a concrete dtype instead
+        merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
+
+        for i, list_id in enumerate(indices):
+            images, masks = self._get_sample(list_id)
+            if merged_images is None:
+                merged_images = images
+            else:
+                merged_images = merged_images*(1-masks) + images*masks
+            merged_masks[masks[:,0]>0.5] = (i+1)
+
+        masks = merged_masks
+
+        labels = np.unique(masks[0])
+        # Remove background
+        labels = labels[labels!=0]
+        target_objects = labels.tolist()
+
+        # Generate one-hot ground-truth
+        cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
+        first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
+        for i, l in enumerate(target_objects):
+            this_mask = (masks==l)
+            cls_gt[this_mask] = i+1
+            first_frame_gt[0,i] = (this_mask[0])
+        cls_gt = np.expand_dims(cls_gt, 1)
+
+        info = {}
+        info['name'] = self.im_list[idx]
+        info['num_objects'] = max(1, len(target_objects))
+
+        # 1 if object exists, 0 otherwise
+        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
+        selector = torch.FloatTensor(selector)
+
+        data = {
+            'rgb': merged_images,
+            'first_frame_gt': first_frame_gt,
+            'cls_gt': cls_gt,
+            'selector': selector,
+            'info': info
+        }
+
+        return data
+
+
+    def __len__(self):
+        return len(self.im_list)
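The dataset above keeps image and mask augmentations in sync with a reseeding trick: recent torchvision transforms draw their random parameters from torch's global RNG (and Python's `random`), so re-seeding both generators before each call replays the same parameters. A minimal sketch of the idea, with hypothetical toy inputs:

```python
import random
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

def reseed(seed):
    random.seed(seed)
    torch.manual_seed(seed)

# Toy stand-ins for a real image/mask pair
im = Image.fromarray(np.uint8(np.random.rand(64, 64, 3) * 255))
gt = Image.fromarray(np.uint8((np.random.rand(64, 64) > 0.5) * 255))

affine_im = transforms.RandomAffine(degrees=20, interpolation=InterpolationMode.BICUBIC)
affine_gt = transforms.RandomAffine(degrees=20, interpolation=InterpolationMode.NEAREST)

seed = np.random.randint(2147483647)
reseed(seed)
warped_im = affine_im(im)   # both calls see the same RNG state,
reseed(seed)
warped_gt = affine_gt(gt)   # so they sample identical affine parameters
```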
diff --git a/dataset/tps.py b/dataset/tps.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee3747c110a8ca03169e3ece5654ba4e8abd7fe
--- /dev/null
+++ b/dataset/tps.py
@@ -0,0 +1,37 @@
+import numpy as np
+from PIL import Image
+import cv2
+import thinplate as tps
+
+cv2.setNumThreads(0)
+
+def pick_random_points(h, w, n_samples):
+    y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False)
+    x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False)
+    return y_idx/h, x_idx/w
+
+
+def warp_dual_cv(img, mask, c_src, c_dst):
+    dshape = img.shape
+    theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
+    grid = tps.tps_grid(theta, c_dst, dshape)
+    mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)
+    return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST)
+
+
+def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
+    """
+    Apply a random TPS warp to the input image and mask.
+    Uses randomness from numpy.
+    """
+    img = np.asarray(img)
+    mask = np.asarray(mask)
+
+    h, w = mask.shape
+    points = pick_random_points(h, w, n_ctrl_pts)
+    c_src = np.stack(points, 1)
+    c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape)
+    warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst)
+
+    return Image.fromarray(warp_im), Image.fromarray(warp_gt)
+
diff --git a/dataset/util.py b/dataset/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f1dc551cb940b095a55d15f3dfef8f77513df22
--- /dev/null
+++ b/dataset/util.py
@@ -0,0 +1,12 @@
+import numpy as np
+
+def all_to_onehot(masks, labels):
+    if len(masks.shape) == 3:
+        Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
+    else:
+        Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
+
+    for ni, l in enumerate(labels):
+        Ms[ni] = (masks == l).astype(np.uint8)
+
+    return Ms
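To make the helper above concrete, here is what `all_to_onehot` produces for a small index mask (values chosen for illustration):

```python
import numpy as np
from dataset.util import all_to_onehot

masks = np.array([[0, 1],
                  [2, 1]], dtype=np.uint8)   # H*W index mask
onehot = all_to_onehot(masks, labels=[1, 2])

# onehot.shape == (2, 2, 2): one binary channel per label
# onehot[0] -> [[0, 1], [0, 1]]   (pixels equal to label 1)
# onehot[1] -> [[0, 0], [1, 0]]   (pixels equal to label 2)
```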
diff --git a/dataset/vos_dataset.py b/dataset/vos_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6140d42beff54772ca484271256eb0821233e44a
--- /dev/null
+++ b/dataset/vos_dataset.py
@@ -0,0 +1,210 @@
+import os
+from os import path
+
+import torch
+from torch.utils.data.dataset import Dataset
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from PIL import Image
+import numpy as np
+
+from dataset.range_transform import im_normalization, im_mean, im_rgb2lab_normalization, ToTensor, RGB2Lab
+from dataset.reseed import reseed
+
+class VOSDataset_221128_TransColorization_batch(Dataset):
+    """
+    Works for DAVIS/YouTubeVOS/BL30K training.
+    For each sequence:
+    - Pick three frames
+    - Pick two objects
+    - Apply random transforms that are shared by all frames
+    - Apply a separate random transform to each frame
+    - The distance between frames is controlled
+    """
+    def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=2, finetune=False):
+        self.im_root = im_root
+        self.gt_root = gt_root
+        self.max_jump = max_jump
+        self.is_bl = is_bl
+        self.num_frames = num_frames
+        self.max_num_obj = max_num_obj
+
+        self.videos = []
+        self.frames = {}
+        vid_list = sorted(os.listdir(self.im_root))
+        # Pre-filtering
+        for vid in vid_list:
+            if subset is not None:
+                if vid not in subset:
+                    continue
+            frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
+            if len(frames) < num_frames:
+                continue
+            self.frames[vid] = frames
+            self.videos.append(vid)
+
+        print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))
+
+        # These transforms are shared within each im/gt pair, but differ among the 3 sampled frames
+        self.pair_im_lone_transform = transforms.Compose([
+            transforms.ColorJitter(0.01, 0.01, 0.01, 0),
+        ])
+
+        self.pair_im_dual_transform = transforms.Compose([
+            transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.BILINEAR, fill=im_mean),
+        ])
+
+        self.pair_gt_dual_transform = transforms.Compose([
+            transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.NEAREST, fill=0),
+        ])
+
+        # These transforms are the same for all pairs in the sampled sequence
+        self.all_im_lone_transform = transforms.Compose([
+            transforms.ColorJitter(0.1, 0.03, 0.03, 0),
+            # transforms.RandomGrayscale(0.05),
+        ])
+
+        patchsz = 448 # 224
+        self.all_im_dual_transform = transforms.Compose([
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomResizedCrop((patchsz, patchsz), scale=(0.36,1.00), interpolation=InterpolationMode.BILINEAR)
+        ])
+
+        self.all_gt_dual_transform = transforms.Compose([
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomResizedCrop((patchsz, patchsz), scale=(0.36,1.00), interpolation=InterpolationMode.NEAREST)
+        ])
+
+        # Final transform without randomness
+        self.final_im_transform = transforms.Compose([
+            RGB2Lab(),
+            ToTensor(),
+            im_rgb2lab_normalization,
+        ])
+
+    def __getitem__(self, idx):
+        video = self.videos[idx]
+        info = {}
+        info['name'] = video
+
+        vid_im_path = path.join(self.im_root, video)
+        vid_gt_path = path.join(self.gt_root, video)
+        frames = self.frames[video]
+
+        trials = 0
+        while trials < 5:
+            info['frames'] = [] # Appended with actual frames
+
+            num_frames = self.num_frames
+            length = len(frames)
+            this_max_jump = min(len(frames), self.max_jump)
+
+            # Iterative sampling: each new frame must lie within this_max_jump
+            # of a frame that is already chosen
+            frames_idx = [np.random.randint(length)]
+            acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))).difference(set(frames_idx))
+            while len(frames_idx) < num_frames:
+                idx = np.random.choice(list(acceptable_set))
+                frames_idx.append(idx)
+                new_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1)))
+                acceptable_set = acceptable_set.union(new_set).difference(set(frames_idx))
+
+            frames_idx = sorted(frames_idx)
+            if np.random.rand() < 0.5:
+                # Reverse time
+                frames_idx = frames_idx[::-1]
+
+            sequence_seed = np.random.randint(2147483647)
+            images = []
+            masks = []
+            for f_idx in frames_idx:
+                jpg_name = frames[f_idx]
+                png_name = jpg_name.replace('.jpg', '.png')
+                info['frames'].append(jpg_name)
+
+                reseed(sequence_seed)
+                this_im = Image.open(path.join(vid_im_path, jpg_name)).convert('RGB')
+                this_im = self.all_im_dual_transform(this_im)
+                this_im = self.all_im_lone_transform(this_im)
+
+                reseed(sequence_seed)
+                this_gt = Image.open(path.join(vid_gt_path, png_name)).convert('P')
+                this_gt = self.all_gt_dual_transform(this_gt)
+
+                pairwise_seed = np.random.randint(2147483647)
+                reseed(pairwise_seed)
+                this_im = self.pair_im_dual_transform(this_im)
+                this_im = self.pair_im_lone_transform(this_im)
+
+                reseed(pairwise_seed)
+                this_gt = self.pair_gt_dual_transform(this_gt)
+
+                this_im = self.final_im_transform(this_im)
+
+                this_gt = np.array(this_gt)
+
+                # Split normalized Lab into luminance (conditioning input)
+                # and chrominance (prediction target)
+                this_im_l = this_im[:1,:,:]
+                this_im_ab = this_im[1:3,:,:]
+
+                # Replicate L to 3 channels so the backbone keeps its RGB stem
+                this_im_lll = this_im_l.repeat(3,1,1)
+                images.append(this_im_lll)
+                masks.append(this_im_ab)
+
+            images = torch.stack(images, 0)
+
+            break
+
+        first_frame_gt = masks[0].unsqueeze(0)
+
+        info['num_objects'] = 2
+
+        masks = np.stack(masks, 0)
+
+        # The ab channels serve directly as the dense ground-truth
+        cls_gt = masks
+
+        # 1 if object exists, 0 otherwise
+        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
+        selector = torch.FloatTensor(selector)
+
+        # Shapes: rgb [T, 3, H, W]; first_frame_gt [1, 2, H, W]; cls_gt [T, 2, H, W]
+        data = {
+            'rgb': images,
+            'first_frame_gt': first_frame_gt,
+            'cls_gt': cls_gt,
+            'selector': selector,
+            'info': info,
+        }
+
+        return data
+
+    def __len__(self):
+        return len(self.videos)
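The sampling loop in `__getitem__` above is what controls temporal spacing: it seeds the clip with one random frame, then repeatedly draws new frames from windows of at most `max_jump` around the frames already chosen. Extracted as a standalone sketch (the function name is illustrative, not part of the codebase):

```python
import numpy as np

def sample_frame_indices(length, num_frames, max_jump):
    # Mirrors the loop above: grow the index set, each new frame within
    # max_jump of one that is already chosen.
    frames_idx = [np.random.randint(length)]
    acceptable = set(range(max(0, frames_idx[-1] - max_jump),
                           min(length, frames_idx[-1] + max_jump + 1))) - set(frames_idx)
    while len(frames_idx) < num_frames:
        idx = np.random.choice(list(acceptable))
        frames_idx.append(idx)
        new_window = set(range(max(0, idx - max_jump), min(length, idx + max_jump + 1)))
        acceptable = (acceptable | new_window) - set(frames_idx)
    return sorted(frames_idx)

# sample_frame_indices(60, num_frames=3, max_jump=10) might return e.g. [14, 17, 23]
```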
diff --git a/inference/__init__.py b/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inference/data/__init__.py b/inference/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inference/data/mask_mapper.py b/inference/data/mask_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..686466c59cf21ea319f1ee933d7c67d1b63b7b74
--- /dev/null
+++ b/inference/data/mask_mapper.py
@@ -0,0 +1,67 @@
+import numpy as np
+import torch
+
+from dataset.util import all_to_onehot
+
+
+class MaskMapper:
+    """
+    This class is used to convert an indexed mask to a one-hot representation.
+    It also takes care of remapping non-continuous indices.
+    It has two modes:
+    1. Default. Only masks with new indices are supposed to go into the remapper.
+       This is also the case for YouTubeVOS.
+       i.e., regions with index 0 are not "background", but "don't care".
+
+    2. Exhaustive. Regions with index 0 are considered "background".
+       Every single pixel is considered to be "labeled".
+    """
+    def __init__(self):
+        self.labels = []
+        self.remappings = {}
+
+        # if coherent, no mapping is required
+        self.coherent = True
+
+    def convert_mask(self, mask, exhaustive=False):
+        # mask is in index representation, H*W numpy array
+        labels = np.unique(mask).astype(np.uint8)
+        labels = labels[labels!=0].tolist()
+
+        new_labels = list(set(labels) - set(self.labels))
+        if not exhaustive:
+            assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'
+
+        # add new remappings
+        for i, l in enumerate(new_labels):
+            self.remappings[l] = i+len(self.labels)+1
+            if self.coherent and i+len(self.labels)+1 != l:
+                self.coherent = False
+
+        if exhaustive:
+            new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1)
+        else:
+            if self.coherent:
+                new_mapped_labels = new_labels
+            else:
+                new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1)
+
+        self.labels.extend(new_labels)
+        mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
+
+        # mask: num_objects*H*W; new_mapped_labels: [num_objects]
+        return mask, new_mapped_labels
+
+    def remap_index_mask(self, mask):
+        # mask is in index representation, H*W numpy array
+        if self.coherent:
+            return mask
+
+        new_mask = np.zeros_like(mask)
+        for l, i in self.remappings.items():
+            new_mask[mask==i] = l
+        return new_mask
\ No newline at end of file
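A short usage sketch for `MaskMapper` (the 255 palette value is typical of single-object DAVIS-style PNG annotations; the exact values here are illustrative):

```python
import numpy as np
from inference.data.mask_mapper import MaskMapper

mapper = MaskMapper()

first_mask = np.zeros((4, 4), dtype=np.uint8)
first_mask[1:3, 1:3] = 255                 # one object stored under palette index 255
onehot, mapped = mapper.convert_mask(first_mask)
# onehot.shape == (1, 4, 4); label 255 is remapped to internal index 1,
# so list(mapped) == [1] and mapper.coherent is False

pred = np.zeros((4, 4), dtype=np.uint8)    # a network output in internal indices
pred[1:3, 1:3] = 1
restored = mapper.remap_index_mask(pred)   # pixels with index 1 map back to 255
```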
diff --git a/inference/data/test_datasets.py b/inference/data/test_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..231852a338cabbc2102f44d89e780846ce5dc367
--- /dev/null
+++ b/inference/data/test_datasets.py
@@ -0,0 +1,29 @@
+import os
+from os import path
+
+from inference.data.video_reader import VideoReader_221128_TransColorization
+
+class DAVISTestDataset_221128_TransColorization_batch:
+    def __init__(self, data_root, imset='2017/val.txt', size=-1):
+        self.image_dir = data_root
+        self.mask_dir = imset
+        self.size_dir = data_root
+        self.size = size
+
+        self.vid_list = [clip_name for clip_name in sorted(os.listdir(data_root)) if clip_name != '.DS_Store']
+
+    def get_datasets(self):
+        for video in self.vid_list:
+            yield VideoReader_221128_TransColorization(video,
+                path.join(self.image_dir, video),
+                path.join(self.mask_dir, video),
+                size=self.size,
+                size_dir=path.join(self.size_dir, video),
+            )
+
+    def __len__(self):
+        return len(self.vid_list)
diff --git a/inference/data/video_reader.py b/inference/data/video_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6e05950761a9218a530e67ee4546dc2e22f4b60
--- /dev/null
+++ b/inference/data/video_reader.py
@@ -0,0 +1,107 @@
+import os
+from os import path
+
+from torch.utils.data.dataset import Dataset
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+import torch.nn.functional as Ff
+from PIL import Image
+import numpy as np
+
+from dataset.range_transform import im_normalization, im_rgb2lab_normalization, ToTensor, RGB2Lab
+
+class VideoReader_221128_TransColorization(Dataset):
+    """
+    This class is used to read a video, one frame at a time
+    """
+    def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None):
+        """
+        image_dir - points to a directory of jpg images
+        mask_dir - points to a directory of png masks
+        size - resize min. side to size. Does nothing if <0.
+        to_save - optionally contains a list of file names without extensions
+            where the segmentation mask is required
+        use_all_mask - when true, read all available masks in mask_dir.
+            Default false. Set to true for YouTubeVOS validation.
+        """
+        self.vid_name = vid_name
+        self.image_dir = image_dir
+        self.mask_dir = mask_dir
+        self.to_save = to_save
+        self.use_all_mask = use_all_mask
+        if size_dir is None:
+            self.size_dir = self.image_dir
+        else:
+            self.size_dir = size_dir
+
+        self.frames = [img for img in sorted(os.listdir(self.image_dir)) if img.endswith('.jpg') or img.endswith('.png')]
+        self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette()
+        self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0])
+        self.suffix = self.first_gt_path.split('.')[-1]
+
+        if size < 0:
+            self.im_transform = transforms.Compose([
+                RGB2Lab(),
+                ToTensor(),
+                im_rgb2lab_normalization,
+            ])
+        else:
+            self.im_transform = transforms.Compose([
+                transforms.ToTensor(),
+                im_normalization,
+                transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
+            ])
+        self.size = size
+
+
+    def __getitem__(self, idx):
+        frame = self.frames[idx]
+        info = {}
+        data = {}
+        info['frame'] = frame
+        info['vid_name'] = self.vid_name
+        info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
+
+        im_path = path.join(self.image_dir, frame)
+        img = Image.open(im_path).convert('RGB')
+
+        if self.image_dir == self.size_dir:
+            shape = np.array(img).shape[:2]
+        else:
+            size_path = path.join(self.size_dir, frame)
+            size_im = Image.open(size_path).convert('RGB')
+            shape = np.array(size_im).shape[:2]
+
+        gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[idx]) if idx < len(os.listdir(self.mask_dir)) else None
+
+        img = self.im_transform(img)
+        # Keep only the L channel and replicate it to three channels
+        img_l = img[:1,:,:]
+        img_lll = img_l.repeat(3,1,1)
+
+        load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
+        if load_mask and gt_path is not None and path.exists(gt_path):
+            mask = Image.open(gt_path).convert('RGB')
+            mask = self.im_transform(mask)
+            # The ab channels of the reference frame act as the "mask"
+            mask_ab = mask[1:3,:,:]
+            data['mask'] = mask_ab
+
+        info['shape'] = shape
+        info['need_resize'] = not (self.size < 0)
+        data['rgb'] = img_lll
+        data['info'] = info
+
+        return data
+
+    def resize_mask(self, mask):
+        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
+        h, w = mask.shape[-2:]
+        min_hw = min(h, w)
+        return Ff.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
+                    mode='nearest')
+
+    def get_palette(self):
+        return self.palette
+
+    def __len__(self):
+        return len(self.frames)
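`InferenceCore.step()` below pads every frame so that its height and width are divisible by 112 before encoding, and crops the prediction back afterwards. The `pad_divide_by`/`unpad` helpers live in `util/tensor_util.py`, which is not part of this diff; a minimal sketch of their assumed behavior:

```python
import torch
import torch.nn.functional as F

def pad_divide_by(x, d):
    # Assumed behavior: center-pad H and W up to the next multiple of d,
    # returning the padding so unpad() can crop the output back.
    h, w = x.shape[-2:]
    pad_h, pad_w = (d - h % d) % d, (d - w % d) % d
    pad = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
    return F.pad(x, pad), pad

img = torch.rand(3, 480, 854)
padded, pad = pad_divide_by(img, 112)
print(padded.shape)   # torch.Size([3, 560, 896])
```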
diff --git a/inference/inference_core.py b/inference/inference_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0323f90ec28860afc3f67bade617a3691d04f87
--- /dev/null
+++ b/inference/inference_core.py
@@ -0,0 +1,111 @@
+from inference.memory_manager import MemoryManager
+from model.network import ColorMNet
+from model.aggregate import aggregate
+
+from util.tensor_util import pad_divide_by, unpad
+import torch
+
+class InferenceCore:
+    def __init__(self, network:ColorMNet, config):
+        self.config = config
+        self.network = network
+        self.mem_every = config['mem_every']
+        self.deep_update_every = config['deep_update_every']
+        self.enable_long_term = config['enable_long_term']
+
+        # if deep_update_every < 0, synchronize deep update with memory frame
+        self.deep_update_sync = (self.deep_update_every < 0)
+
+        self.clear_memory()
+        self.all_labels = None
+
+        self.last_ti_key = None
+        self.last_ti_value = None
+
+    def clear_memory(self):
+        self.curr_ti = -1
+        self.last_mem_ti = 0
+        if not self.deep_update_sync:
+            self.last_deep_update_ti = -self.deep_update_every
+        self.memory = MemoryManager(config=self.config)
+
+    def update_config(self, config):
+        self.mem_every = config['mem_every']
+        self.deep_update_every = config['deep_update_every']
+        self.enable_long_term = config['enable_long_term']
+
+        # if deep_update_every < 0, synchronize deep update with memory frame
+        self.deep_update_sync = (self.deep_update_every < 0)
+        self.memory.update_config(config)
+
+    def set_all_labels(self, all_labels):
+        self.all_labels = all_labels
+
+    def step(self, image, mask=None, valid_labels=None, end=False):
+        # image: 3*H*W
+        # mask: num_objects*H*W or None
+        self.curr_ti += 1
+        divide_by = 112 # 16
+        image, self.pad = pad_divide_by(image, divide_by)
+        image = image.unsqueeze(0) # add the batch dimension
+
+        is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end)
+        need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels)))
+        is_deep_update = (
+            (self.deep_update_sync and is_mem_frame) or  # synchronized
+            (not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync
+        ) and (not end)
+        is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end)
+
+        key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image,
+            need_ek=(self.enable_long_term or need_segment),
+            need_sk=is_mem_frame)
+        multi_scale_features = (f16, f8, f4)
+
+        # segment the current frame if needed
+        if need_segment:
+            memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
+
+            # short term memory
+            batch, num_objects, value_dim, h, w = self.last_ti_value.shape
+            last_ti_value = self.last_ti_value.flatten(start_dim=1, end_dim=2)
+            memory_value_short, _ = self.network.short_term_attn(key, self.last_ti_key, last_ti_value, None, key.shape[-2:])
+            memory_value_short = memory_value_short.permute(1, 2, 0).view(batch, num_objects, value_dim, h, w)
+            memory_readout += memory_value_short
+
+            hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
+                    self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
+            # remove batch dim
+            pred_prob_with_bg = pred_prob_with_bg[0]
+            pred_prob_no_bg = pred_prob_with_bg
+            if is_normal_update:
+                self.memory.set_hidden(hidden)
+        else:
+            pred_prob_no_bg = pred_prob_with_bg = None
+
+        # use the input mask if any
+        if mask is not None:
+            mask, _ = pad_divide_by(mask, divide_by)
+
+            pred_prob_with_bg = mask
+
+            self.memory.create_hidden_state(2, key)
+
+        # save as memory if needed
+        if is_mem_frame:
+            value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(),
+                    pred_prob_with_bg.unsqueeze(0), is_deep_update=is_deep_update)
+
+            self.memory.add_memory(key, shrinkage, value, self.all_labels,
+                    selection=selection if self.enable_long_term else None)
+            self.last_mem_ti = self.curr_ti
+
+            self.last_ti_key = key
+            self.last_ti_value = value
+
+            if is_deep_update:
+                self.memory.set_hidden(hidden)
+                self.last_deep_update_ti = self.curr_ti
+
+        return unpad(pred_prob_with_bg, self.pad)
diff --git
a/inference/interact/__init__.py b/inference/interact/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/interact/fbrs/LICENSE b/inference/interact/fbrs/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..fa0086a952236971ab37901954d596efae9f4af6 --- /dev/null +++ b/inference/interact/fbrs/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. 
License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. 
Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. 
However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. 
* +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. 
\ No newline at end of file diff --git a/inference/interact/fbrs/__init__.py b/inference/interact/fbrs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/interact/fbrs/controller.py b/inference/interact/fbrs/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..57a0a9b7fec9a7bc9d0b6bc605b268b662fef77b --- /dev/null +++ b/inference/interact/fbrs/controller.py @@ -0,0 +1,103 @@ +import torch + +from ..fbrs.inference import clicker +from ..fbrs.inference.predictors import get_predictor + + +class InteractiveController: + def __init__(self, net, device, predictor_params, prob_thresh=0.5): + self.net = net.to(device) + self.prob_thresh = prob_thresh + self.clicker = clicker.Clicker() + self.states = [] + self.probs_history = [] + self.object_count = 0 + self._result_mask = None + + self.image = None + self.predictor = None + self.device = device + self.predictor_params = predictor_params + self.reset_predictor() + + def set_image(self, image): + self.image = image + self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8) + self.object_count = 0 + self.reset_last_object() + + def add_click(self, x, y, is_positive): + self.states.append({ + 'clicker': self.clicker.get_state(), + 'predictor': self.predictor.get_states() + }) + + click = clicker.Click(is_positive=is_positive, coords=(y, x)) + self.clicker.add_click(click) + pred = self.predictor.get_prediction(self.clicker) + torch.cuda.empty_cache() + + if self.probs_history: + self.probs_history.append((self.probs_history[-1][0], pred)) + else: + self.probs_history.append((torch.zeros_like(pred), pred)) + + def undo_click(self): + if not self.states: + return + + prev_state = self.states.pop() + self.clicker.set_state(prev_state['clicker']) + self.predictor.set_states(prev_state['predictor']) + self.probs_history.pop() + + def partially_finish_object(self): + object_prob = self.current_object_prob + if object_prob is None: + return + + self.probs_history.append((object_prob, torch.zeros_like(object_prob))) + self.states.append(self.states[-1]) + + self.clicker.reset_clicks() + self.reset_predictor() + + def finish_object(self): + object_prob = self.current_object_prob + if object_prob is None: + return + + self.object_count += 1 + object_mask = object_prob > self.prob_thresh + self._result_mask[object_mask] = self.object_count + self.reset_last_object() + + def reset_last_object(self): + self.states = [] + self.probs_history = [] + self.clicker.reset_clicks() + self.reset_predictor() + + def reset_predictor(self, predictor_params=None): + if predictor_params is not None: + self.predictor_params = predictor_params + self.predictor = get_predictor(self.net, device=self.device, + **self.predictor_params) + if self.image is not None: + self.predictor.set_input_image(self.image) + + @property + def current_object_prob(self): + if self.probs_history: + current_prob_total, current_prob_additive = self.probs_history[-1] + return torch.maximum(current_prob_total, current_prob_additive) + else: + return None + + @property + def is_incomplete_mask(self): + return len(self.probs_history) > 0 + + @property + def result_mask(self): + return self._result_mask.clone() diff --git a/inference/interact/fbrs/inference/__init__.py b/inference/interact/fbrs/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/inference/interact/fbrs/inference/clicker.py b/inference/interact/fbrs/inference/clicker.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1ea9cf319f88639fa0af45088cdf79c8954f83a
--- /dev/null
+++ b/inference/interact/fbrs/inference/clicker.py
@@ -0,0 +1,103 @@
+from collections import namedtuple
+
+import numpy as np
+from copy import deepcopy
+from scipy.ndimage import distance_transform_edt
+
+Click = namedtuple('Click', ['is_positive', 'coords'])
+
+
+class Clicker(object):
+    def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1):
+        if gt_mask is not None:
+            self.gt_mask = gt_mask == 1
+            self.not_ignore_mask = gt_mask != ignore_label
+        else:
+            self.gt_mask = None
+
+        self.reset_clicks()
+
+        if init_clicks is not None:
+            for click in init_clicks:
+                self.add_click(click)
+
+    def make_next_click(self, pred_mask):
+        assert self.gt_mask is not None
+        click = self._get_click(pred_mask)
+        self.add_click(click)
+
+    def get_clicks(self, clicks_limit=None):
+        return self.clicks_list[:clicks_limit]
+
+    def _get_click(self, pred_mask, padding=True):
+        fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
+        fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
+
+        if padding:
+            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
+            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
+
+        fn_mask_dt = distance_transform_edt(fn_mask)
+        fp_mask_dt = distance_transform_edt(fp_mask)
+
+        if padding:
+            fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
+            fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
+
+        fn_mask_dt = fn_mask_dt * self.not_clicked_map
+        fp_mask_dt = fp_mask_dt * self.not_clicked_map
+
+        fn_max_dist = np.max(fn_mask_dt)
+        fp_max_dist = np.max(fp_mask_dt)
+
+        is_positive = fn_max_dist > fp_max_dist
+        if is_positive:
+            coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist)  # coords is [y, x]
+        else:
+            coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist)  # coords is [y, x]
+
+        return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
+
+    def add_click(self, click):
+        coords = click.coords
+
+        if click.is_positive:
+            self.num_pos_clicks += 1
+        else:
+            self.num_neg_clicks += 1
+
+        self.clicks_list.append(click)
+        if self.gt_mask is not None:
+            self.not_clicked_map[coords[0], coords[1]] = False
+
+    def _remove_last_click(self):
+        click = self.clicks_list.pop()
+        coords = click.coords
+
+        if click.is_positive:
+            self.num_pos_clicks -= 1
+        else:
+            self.num_neg_clicks -= 1
+
+        if self.gt_mask is not None:
+            self.not_clicked_map[coords[0], coords[1]] = True
+
+    def reset_clicks(self):
+        if self.gt_mask is not None:
+            self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)  # np.bool alias was removed in NumPy 1.24
+
+        self.num_pos_clicks = 0
+        self.num_neg_clicks = 0
+
+        self.clicks_list = []
+
+    def get_state(self):
+        return deepcopy(self.clicks_list)
+
+    def set_state(self, state):
+        self.reset_clicks()
+        for click in state:
+            self.add_click(click)
+
+    def __len__(self):
+        return len(self.clicks_list)
diff --git a/inference/interact/fbrs/inference/evaluation.py b/inference/interact/fbrs/inference/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..6be3ed813eb257309f433ece0035e0890a82207e
--- /dev/null
+++ b/inference/interact/fbrs/inference/evaluation.py
@@ -0,0 +1,56 @@
+from time import time
+
+import numpy as np
+import torch
+
+from ..inference import utils
+from ..inference.clicker import Clicker
+
+try:
+    get_ipython()
+    from tqdm import tqdm_notebook as tqdm
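+    # Inside a Jupyter session get_ipython() is defined, so the notebook-style
+    # progress bar is used; in a plain interpreter the call raises NameError
+    # and the console tqdm is imported in the except branch below.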
+except NameError: + from tqdm import tqdm + + +def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs): + all_ious = [] + + start_time = time() + for index in tqdm(range(len(dataset)), leave=False): + sample = dataset.get_sample(index) + item = dataset[index] + + if oracle_eval: + gt_mask = torch.tensor(sample['instances_mask'], dtype=torch.float32) + gt_mask = gt_mask.unsqueeze(0).unsqueeze(0) + predictor.opt_functor.mask_loss.set_gt_mask(gt_mask) + _, sample_ious, _ = evaluate_sample(item['images'], sample['instances_mask'], predictor, **kwargs) + all_ious.append(sample_ious) + end_time = time() + elapsed_time = end_time - start_time + + return all_ious, elapsed_time + + +def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr, + pred_thr=0.49, max_clicks=20): + clicker = Clicker(gt_mask=instances_mask) + pred_mask = np.zeros_like(instances_mask) + ious_list = [] + + with torch.no_grad(): + predictor.set_input_image(image_nd) + + for click_number in range(max_clicks): + clicker.make_next_click(pred_mask) + pred_probs = predictor.get_prediction(clicker) + pred_mask = pred_probs > pred_thr + + iou = utils.get_iou(instances_mask, pred_mask) + ious_list.append(iou) + + if iou >= max_iou_thr: + break + + return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs diff --git a/inference/interact/fbrs/inference/predictors/__init__.py b/inference/interact/fbrs/inference/predictors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04b8b8618cd33efabdaec69328de2f5a8a58d2f9 --- /dev/null +++ b/inference/interact/fbrs/inference/predictors/__init__.py @@ -0,0 +1,95 @@ +from .base import BasePredictor +from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor +from .brs_functors import InputOptimizer, ScaleBiasOptimizer +from ..transforms import ZoomIn +from ...model.is_hrnet_model import DistMapsHRNetModel + + +def get_predictor(net, brs_mode, device, + prob_thresh=0.49, + with_flip=True, + zoom_in_params=dict(), + predictor_params=None, + brs_opt_func_params=None, + lbfgs_params=None): + lbfgs_params_ = { + 'm': 20, + 'factr': 0, + 'pgtol': 1e-8, + 'maxfun': 20, + } + + predictor_params_ = { + 'optimize_after_n_clicks': 1 + } + + if zoom_in_params is not None: + zoom_in = ZoomIn(**zoom_in_params) + else: + zoom_in = None + + if lbfgs_params is not None: + lbfgs_params_.update(lbfgs_params) + lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] + + if brs_opt_func_params is None: + brs_opt_func_params = dict() + + if brs_mode == 'NoBRS': + if predictor_params is not None: + predictor_params_.update(predictor_params) + predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) + elif brs_mode.startswith('f-BRS'): + predictor_params_.update({ + 'net_clicks_limit': 8, + }) + if predictor_params is not None: + predictor_params_.update(predictor_params) + + insertion_mode = { + 'f-BRS-A': 'after_c4', + 'f-BRS-B': 'after_aspp', + 'f-BRS-C': 'after_deeplab' + }[brs_mode] + + opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params) + + if isinstance(net, DistMapsHRNetModel): + FeaturePredictor = HRNetFeatureBRSPredictor + insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode] + else: + FeaturePredictor = FeatureBRSPredictor + + predictor = FeaturePredictor(net, device, + opt_functor=opt_functor, + with_flip=with_flip, + insertion_mode=insertion_mode, + 
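+                                      # insertion_mode selects where f-BRS applies its per-channel
+                                      # scale/bias: 'after_c4' (f-BRS-A), 'after_aspp' (f-BRS-B) or
+                                      # 'after_deeplab' (f-BRS-C) for DeepLab models, remapped to
+                                      # 'A'/'C' for the HRNet predictor just above.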
zoom_in=zoom_in, + **predictor_params_) + elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': + use_dmaps = brs_mode == 'DistMap-BRS' + + predictor_params_.update({ + 'net_clicks_limit': 5, + }) + if predictor_params is not None: + predictor_params_.update(predictor_params) + + opt_functor = InputOptimizer(prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params) + + predictor = InputBRSPredictor(net, device, + optimize_target='dmaps' if use_dmaps else 'rgb', + opt_functor=opt_functor, + with_flip=with_flip, + zoom_in=zoom_in, + **predictor_params_) + else: + raise NotImplementedError + + return predictor diff --git a/inference/interact/fbrs/inference/predictors/base.py b/inference/interact/fbrs/inference/predictors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3776506328ef9457afdad047fb4219c5e25c3ab6 --- /dev/null +++ b/inference/interact/fbrs/inference/predictors/base.py @@ -0,0 +1,100 @@ +import torch +import torch.nn.functional as F + +from ..transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide + + +class BasePredictor(object): + def __init__(self, net, device, + net_clicks_limit=None, + with_flip=False, + zoom_in=None, + max_size=None, + **kwargs): + self.net = net + self.with_flip = with_flip + self.net_clicks_limit = net_clicks_limit + self.original_image = None + self.device = device + self.zoom_in = zoom_in + + self.transforms = [zoom_in] if zoom_in is not None else [] + if max_size is not None: + self.transforms.append(LimitLongestSide(max_size=max_size)) + self.transforms.append(SigmoidForPred()) + if with_flip: + self.transforms.append(AddHorizontalFlip()) + + def set_input_image(self, image_nd): + for transform in self.transforms: + transform.reset() + self.original_image = image_nd.to(self.device) + if len(self.original_image.shape) == 3: + self.original_image = self.original_image.unsqueeze(0) + + def get_prediction(self, clicker): + clicks_list = clicker.get_clicks() + + image_nd, clicks_lists, is_image_changed = self.apply_transforms( + self.original_image, [clicks_list] + ) + + pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed) + prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, + size=image_nd.size()[2:]) + + for t in reversed(self.transforms): + prediction = t.inv_transform(prediction) + + if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(): + print('zooming') + return self.get_prediction(clicker) + + # return prediction.cpu().numpy()[0, 0] + return prediction + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + return self.net(image_nd, points_nd)['instances'] + + def _get_transform_states(self): + return [x.get_state() for x in self.transforms] + + def _set_transform_states(self, states): + assert len(states) == len(self.transforms) + for state, transform in zip(states, self.transforms): + transform.set_state(state) + + def apply_transforms(self, image_nd, clicks_lists): + is_image_changed = False + for t in self.transforms: + image_nd, clicks_lists = t.transform(image_nd, clicks_lists) + is_image_changed |= t.image_changed + + return image_nd, clicks_lists, is_image_changed + + def get_points_nd(self, clicks_lists): + total_clicks = [] + num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] + num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, 
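+                          # Each clicks_list is packed into one fixed-size tensor below:
+                          # positive coords first, then negative coords, each group
+                          # right-padded with the sentinel (-1, -1) up to num_max_points.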
num_pos_clicks)] + num_max_points = max(num_pos_clicks + num_neg_clicks) + if self.net_clicks_limit is not None: + num_max_points = min(self.net_clicks_limit, num_max_points) + num_max_points = max(1, num_max_points) + + for clicks_list in clicks_lists: + clicks_list = clicks_list[:self.net_clicks_limit] + pos_clicks = [click.coords for click in clicks_list if click.is_positive] + pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1)] + + neg_clicks = [click.coords for click in clicks_list if not click.is_positive] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1)] + total_clicks.append(pos_clicks + neg_clicks) + + return torch.tensor(total_clicks, device=self.device) + + def get_states(self): + return {'transform_states': self._get_transform_states()} + + def set_states(self, states): + self._set_transform_states(states['transform_states']) diff --git a/inference/interact/fbrs/inference/predictors/brs.py b/inference/interact/fbrs/inference/predictors/brs.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc7296e52d5e575956eec8a614682a35cff9cd7 --- /dev/null +++ b/inference/interact/fbrs/inference/predictors/brs.py @@ -0,0 +1,280 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.optimize import fmin_l_bfgs_b + +from .base import BasePredictor +from ...model.is_hrnet_model import DistMapsHRNetModel + + +class BRSBasePredictor(BasePredictor): + def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs): + super().__init__(model, device, **kwargs) + self.optimize_after_n_clicks = optimize_after_n_clicks + self.opt_functor = opt_functor + + self.opt_data = None + self.input_data = None + + def set_input_image(self, image_nd): + super().set_input_image(image_nd) + self.opt_data = None + self.input_data = None + + def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1): + pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + + for list_indx, clicks_list in enumerate(clicks_lists): + for click in clicks_list: + y, x = click.coords + y, x = int(round(y)), int(round(x)) + y1, x1 = y - radius, x - radius + y2, x2 = y + radius + 1, x + radius + 1 + + if click.is_positive: + pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True + else: + neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True + + with torch.no_grad(): + pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device) + neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device) + + return pos_clicks_map, neg_clicks_map + + def get_states(self): + return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data} + + def set_states(self, states): + self._set_transform_states(states['transform_states']) + self.opt_data = states['opt_data'] + + +class FeatureBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.insertion_mode = insertion_mode + self._c1_features = None + + if self.insertion_mode == 'after_deeplab': + self.num_channels = model.feature_extractor.ch + elif self.insertion_mode == 'after_c4': + self.num_channels = model.feature_extractor.aspp_in_channels + elif self.insertion_mode == 'after_aspp': + self.num_channels = model.feature_extractor.ch + 32 + else: + raise NotImplementedError + + def _get_prediction(self, image_nd, 
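+                        # image_nd may already contain horizontally flipped copies
+                        # appended by AddHorizontalFlip when with_flip=True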
clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + + num_clicks = len(clicks_lists[0]) + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + + if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) + + if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + self.input_data = self._get_head_input(image_nd, points_nd) + + def get_prediction_logits(scale, bias): + scale = scale.view(bs, -1, 1, 1) + bias = bias.view(bs, -1, 1, 1) + if self.with_flip: + scale = scale.repeat(2, 1, 1, 1) + bias = bias.repeat(2, 1, 1, 1) + + scaled_backbone_features = self.input_data * scale + scaled_backbone_features = scaled_backbone_features + bias + if self.insertion_mode == 'after_c4': + x = self.net.feature_extractor.aspp(scaled_backbone_features) + x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:], + align_corners=True) + x = torch.cat((x, self._c1_features), dim=1) + scaled_backbone_features = self.net.feature_extractor.head(x) + elif self.insertion_mode == 'after_aspp': + scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features) + + pred_logits = self.net.head(scaled_backbone_features) + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', + align_corners=True) + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, + **self.opt_functor.optimizer_params) + self.opt_data = opt_result[0] + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) + opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits + + def _get_head_input(self, image_nd, points): + with torch.no_grad(): + coord_features = self.net.dist_maps(image_nd, points) + x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) + if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp': + c1, _, c3, c4 = self.net.feature_extractor.backbone(x) + c1 = self.net.feature_extractor.skip_project(c1) + + if self.insertion_mode == 'after_aspp': + x = self.net.feature_extractor.aspp(c4) + x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x, c1), dim=1) + backbone_features = x + else: + backbone_features = c4 + self._c1_features = c1 + else: + backbone_features = self.net.feature_extractor(x)[0] + + return backbone_features + + +class HRNetFeatureBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.insertion_mode = insertion_mode + self._c1_features = None + + if self.insertion_mode == 'A': + self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8]) + elif self.insertion_mode == 'C': + self.num_channels = 2 * model.feature_extractor.ocr_width + else: + raise NotImplementedError + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) 
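+        # f-BRS in brief: opt_data is a flat vector of per-channel (scale, bias)
+        # pairs (bs * 2 * num_channels floats). fmin_l_bfgs_b below tunes it so
+        # that the rescaled backbone features yield a mask consistent with the
+        # clicks, instead of backpropagating through the whole network per click.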
+ pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + num_clicks = len(clicks_lists[0]) + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + + if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) + + if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + self.input_data = self._get_head_input(image_nd, points_nd) + + def get_prediction_logits(scale, bias): + scale = scale.view(bs, -1, 1, 1) + bias = bias.view(bs, -1, 1, 1) + if self.with_flip: + scale = scale.repeat(2, 1, 1, 1) + bias = bias.repeat(2, 1, 1, 1) + + scaled_backbone_features = self.input_data * scale + scaled_backbone_features = scaled_backbone_features + bias + if self.insertion_mode == 'A': + out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features) + feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features) + + context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) + feats = self.net.feature_extractor.ocr_distri_head(feats, context) + pred_logits = self.net.feature_extractor.cls_head(feats) + elif self.insertion_mode == 'C': + pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features) + else: + raise NotImplementedError + + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', + align_corners=True) + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, + **self.opt_functor.optimizer_params) + self.opt_data = opt_result[0] + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) + opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits + + def _get_head_input(self, image_nd, points): + with torch.no_grad(): + coord_features = self.net.dist_maps(image_nd, points) + x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) + feats = self.net.feature_extractor.compute_hrnet_feats(x) + if self.insertion_mode == 'A': + backbone_features = feats + elif self.insertion_mode == 'C': + out_aux = self.net.feature_extractor.aux_head(feats) + feats = self.net.feature_extractor.conv3x3_ocr(feats) + + context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) + backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context) + else: + raise NotImplementedError + + return backbone_features + + +class InputBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.optimize_target = optimize_target + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + num_clicks = len(clicks_lists[0]) + + if self.opt_data is None or is_image_changed: + opt_channels = 2 if self.optimize_target == 'dmaps' else 3 + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]), + 
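+                                        # 2 optimizable channels for 'DistMap-BRS' (the pos/neg
+                                        # click distance maps), 3 for 'RGB-BRS' (an additive
+                                        # bias over the input image)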
device=self.device, dtype=torch.float32) + + def get_prediction_logits(opt_bias): + input_image = image_nd + if self.optimize_target == 'rgb': + input_image = input_image + opt_bias + dmaps = self.net.dist_maps(input_image, points_nd) + if self.optimize_target == 'dmaps': + dmaps = dmaps + opt_bias + + x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1)) + if self.optimize_target == 'all': + x = x + opt_bias + + if isinstance(self.net, DistMapsHRNetModel): + pred_logits = self.net.feature_extractor(x)[0] + else: + backbone_features = self.net.feature_extractor(x) + pred_logits = self.net.head(backbone_features[0]) + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True) + + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device, + shape=self.opt_data.shape) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(), + **self.opt_functor.optimizer_params) + + self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device) + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits diff --git a/inference/interact/fbrs/inference/predictors/brs_functors.py b/inference/interact/fbrs/inference/predictors/brs_functors.py new file mode 100644 index 0000000000000000000000000000000000000000..92a5d9915ac0516144430db4ce06d1477319e694 --- /dev/null +++ b/inference/interact/fbrs/inference/predictors/brs_functors.py @@ -0,0 +1,109 @@ +import torch +import numpy as np + +from ...model.metrics import _compute_iou +from .brs_losses import BRSMaskLoss + + +class BaseOptimizer: + def __init__(self, optimizer_params, + prob_thresh=0.49, + reg_weight=1e-3, + min_iou_diff=0.01, + brs_loss=BRSMaskLoss(), + with_flip=False, + flip_average=False, + **kwargs): + self.brs_loss = brs_loss + self.optimizer_params = optimizer_params + self.prob_thresh = prob_thresh + self.reg_weight = reg_weight + self.min_iou_diff = min_iou_diff + self.with_flip = with_flip + self.flip_average = flip_average + + self.best_prediction = None + self._get_prediction_logits = None + self._opt_shape = None + self._best_loss = None + self._click_masks = None + self._last_mask = None + self.device = None + + def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None): + self.best_prediction = None + self._get_prediction_logits = get_prediction_logits + self._click_masks = (pos_mask, neg_mask) + self._opt_shape = shape + self._last_mask = None + self.device = device + + def __call__(self, x): + opt_params = torch.from_numpy(x).float().to(self.device) + opt_params.requires_grad_(True) + + with torch.enable_grad(): + opt_vars, reg_loss = self.unpack_opt_params(opt_params) + result_before_sigmoid = self._get_prediction_logits(*opt_vars) + result = torch.sigmoid(result_before_sigmoid) + + pos_mask, neg_mask = self._click_masks + if self.with_flip and self.flip_average: + result, result_flipped = torch.chunk(result, 2, dim=0) + result = 0.5 * (result + torch.flip(result_flipped, dims=[3])) + pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]] + + loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask) + loss = loss + reg_loss + + f_val = 
loss.detach().cpu().numpy()
+        if self.best_prediction is None or f_val < self._best_loss:
+            self.best_prediction = result_before_sigmoid.detach()
+            self._best_loss = f_val
+
+        if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
+            return [f_val, np.zeros_like(x)]
+
+        current_mask = result > self.prob_thresh
+        if self._last_mask is not None and self.min_iou_diff > 0:
+            diff_iou = _compute_iou(current_mask, self._last_mask)
+            if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
+                return [f_val, np.zeros_like(x)]
+        self._last_mask = current_mask
+
+        loss.backward()
+        f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float64)  # np.float alias was removed in NumPy 1.24; L-BFGS-B expects float64
+
+        return [f_val, f_grad]
+
+    def unpack_opt_params(self, opt_params):
+        raise NotImplementedError
+
+
+class InputOptimizer(BaseOptimizer):
+    def unpack_opt_params(self, opt_params):
+        opt_params = opt_params.view(self._opt_shape)
+        if self.with_flip:
+            opt_params_flipped = torch.flip(opt_params, dims=[3])
+            opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
+        reg_loss = self.reg_weight * torch.sum(opt_params**2)
+
+        return (opt_params,), reg_loss
+
+
+class ScaleBiasOptimizer(BaseOptimizer):
+    def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.scale_act = scale_act
+        self.reg_bias_weight = reg_bias_weight
+
+    def unpack_opt_params(self, opt_params):
+        scale, bias = torch.chunk(opt_params, 2, dim=0)
+        reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))
+
+        if self.scale_act == 'tanh':
+            scale = torch.tanh(scale)
+        elif self.scale_act == 'sin':
+            scale = torch.sin(scale)
+
+        return (1 + scale, bias), reg_loss
diff --git a/inference/interact/fbrs/inference/predictors/brs_losses.py b/inference/interact/fbrs/inference/predictors/brs_losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d9998ab120b9987e79509d0ee594e8b6c431a9f
--- /dev/null
+++ b/inference/interact/fbrs/inference/predictors/brs_losses.py
@@ -0,0 +1,58 @@
+import torch
+
+from ...model.losses import SigmoidBinaryCrossEntropyLoss
+
+
+class BRSMaskLoss(torch.nn.Module):
+    def __init__(self, eps=1e-5):
+        super().__init__()
+        self._eps = eps
+
+    def forward(self, result, pos_mask, neg_mask):
+        pos_diff = (1 - result) * pos_mask
+        pos_target = torch.sum(pos_diff ** 2)
+        pos_target = pos_target / (torch.sum(pos_mask) + self._eps)
+
+        neg_diff = result * neg_mask
+        neg_target = torch.sum(neg_diff ** 2)
+        neg_target = neg_target / (torch.sum(neg_mask) + self._eps)
+
+        loss = pos_target + neg_target
+
+        with torch.no_grad():
+            f_max_pos = torch.max(torch.abs(pos_diff)).item()
+            f_max_neg = torch.max(torch.abs(neg_diff)).item()
+
+        return loss, f_max_pos, f_max_neg
+
+
+class OracleMaskLoss(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gt_mask = None
+        self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
+        self.predictor = None
+        self.history = []
+
+    def set_gt_mask(self, gt_mask):
+        self.gt_mask = gt_mask
+        self.history = []
+
+    def forward(self, result, pos_mask, neg_mask):
+        gt_mask = self.gt_mask.to(result.device)
+        if self.predictor.object_roi is not None:
+            r1, r2, c1, c2 = self.predictor.object_roi[:4]
+            gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
+            gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)
+
+        if result.shape[0] == 2:
+            gt_mask_flipped = torch.flip(gt_mask, dims=[3])
+            gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0)
+
+        loss
= self.loss(result, gt_mask) + self.history.append(loss.detach().cpu().numpy()[0]) + + if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5: + return 0, 0, 0 + + return loss, 1.0, 1.0 diff --git a/inference/interact/fbrs/inference/transforms/__init__.py b/inference/interact/fbrs/inference/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd54e38a2f84b3fef481672a7ceab070eb01b82 --- /dev/null +++ b/inference/interact/fbrs/inference/transforms/__init__.py @@ -0,0 +1,5 @@ +from .base import SigmoidForPred +from .flip import AddHorizontalFlip +from .zoom_in import ZoomIn +from .limit_longest_side import LimitLongestSide +from .crops import Crops diff --git a/inference/interact/fbrs/inference/transforms/base.py b/inference/interact/fbrs/inference/transforms/base.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5a2deb3c44f5aed7530fd1e299fff1273737b8 --- /dev/null +++ b/inference/interact/fbrs/inference/transforms/base.py @@ -0,0 +1,38 @@ +import torch + + +class BaseTransform(object): + def __init__(self): + self.image_changed = False + + def transform(self, image_nd, clicks_lists): + raise NotImplementedError + + def inv_transform(self, prob_map): + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def get_state(self): + raise NotImplementedError + + def set_state(self, state): + raise NotImplementedError + + +class SigmoidForPred(BaseTransform): + def transform(self, image_nd, clicks_lists): + return image_nd, clicks_lists + + def inv_transform(self, prob_map): + return torch.sigmoid(prob_map) + + def reset(self): + pass + + def get_state(self): + return None + + def set_state(self, state): + pass diff --git a/inference/interact/fbrs/inference/transforms/crops.py b/inference/interact/fbrs/inference/transforms/crops.py new file mode 100644 index 0000000000000000000000000000000000000000..0910a2825608cf3fa761212d182dc1e8e5c242c4 --- /dev/null +++ b/inference/interact/fbrs/inference/transforms/crops.py @@ -0,0 +1,97 @@ +import math + +import torch +import numpy as np + +from ...inference.clicker import Click +from .base import BaseTransform + + +class Crops(BaseTransform): + def __init__(self, crop_size=(320, 480), min_overlap=0.2): + super().__init__() + self.crop_height, self.crop_width = crop_size + self.min_overlap = min_overlap + + self.x_offsets = None + self.y_offsets = None + self._counts = None + + def transform(self, image_nd, clicks_lists): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_height, image_width = image_nd.shape[2:4] + self._counts = None + + if image_height < self.crop_height or image_width < self.crop_width: + return image_nd, clicks_lists + + self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) + self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) + self._counts = np.zeros((image_height, image_width)) + + image_crops = [] + for dy in self.y_offsets: + for dx in self.x_offsets: + self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 + image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] + image_crops.append(image_crop) + image_crops = torch.cat(image_crops, dim=0) + self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) + + clicks_list = clicks_lists[0] + clicks_lists = [] + for dy in self.y_offsets: + for dx in self.x_offsets: + crop_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] - dy, x.coords[1] 
- dx)) + for x in clicks_list] + clicks_lists.append(crop_clicks) + + return image_crops, clicks_lists + + def inv_transform(self, prob_map): + if self._counts is None: + return prob_map + + new_prob_map = torch.zeros((1, 1, *self._counts.shape), + dtype=prob_map.dtype, device=prob_map.device) + + crop_indx = 0 + for dy in self.y_offsets: + for dx in self.x_offsets: + new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] + crop_indx += 1 + new_prob_map = torch.div(new_prob_map, self._counts) + + return new_prob_map + + def get_state(self): + return self.x_offsets, self.y_offsets, self._counts + + def set_state(self, state): + self.x_offsets, self.y_offsets, self._counts = state + + def reset(self): + self.x_offsets = None + self.y_offsets = None + self._counts = None + + +def get_offsets(length, crop_size, min_overlap_ratio=0.2): + if length == crop_size: + return [0] + + N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) + N = math.ceil(N) + + overlap_ratio = (N - length / crop_size) / (N - 1) + overlap_width = int(crop_size * overlap_ratio) + + offsets = [0] + for i in range(1, N): + new_offset = offsets[-1] + crop_size - overlap_width + if new_offset + crop_size > length: + new_offset = length - crop_size + + offsets.append(new_offset) + + return offsets diff --git a/inference/interact/fbrs/inference/transforms/flip.py b/inference/interact/fbrs/inference/transforms/flip.py new file mode 100644 index 0000000000000000000000000000000000000000..c1543cb65f8d3892054dc96f39a8196987fb6bfd --- /dev/null +++ b/inference/interact/fbrs/inference/transforms/flip.py @@ -0,0 +1,37 @@ +import torch + +from ..clicker import Click +from .base import BaseTransform + + +class AddHorizontalFlip(BaseTransform): + def transform(self, image_nd, clicks_lists): + assert len(image_nd.shape) == 4 + image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) + + image_width = image_nd.shape[3] + clicks_lists_flipped = [] + for clicks_list in clicks_lists: + clicks_list_flipped = [Click(is_positive=click.is_positive, + coords=(click.coords[0], image_width - click.coords[1] - 1)) + for click in clicks_list] + clicks_lists_flipped.append(clicks_list_flipped) + clicks_lists = clicks_lists + clicks_lists_flipped + + return image_nd, clicks_lists + + def inv_transform(self, prob_map): + assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 + num_maps = prob_map.shape[0] // 2 + prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] + + return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3])) + + def get_state(self): + return None + + def set_state(self, state): + pass + + def reset(self): + pass diff --git a/inference/interact/fbrs/inference/transforms/limit_longest_side.py b/inference/interact/fbrs/inference/transforms/limit_longest_side.py new file mode 100644 index 0000000000000000000000000000000000000000..50c5a53d2670df52285621dc0d33e86df520d77c --- /dev/null +++ b/inference/interact/fbrs/inference/transforms/limit_longest_side.py @@ -0,0 +1,22 @@ +from .zoom_in import ZoomIn, get_roi_image_nd + + +class LimitLongestSide(ZoomIn): + def __init__(self, max_size=800): + super().__init__(target_size=max_size, skip_clicks=0) + + def transform(self, image_nd, clicks_lists): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_max_size = max(image_nd.shape[2:4]) + self.image_changed = False + + if image_max_size <= self.target_size: + return image_nd, clicks_lists + self._input_image = image_nd + + 
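+        # Reuse ZoomIn's machinery: declaring the whole image as the object ROI
+        # makes get_roi_image_nd below rescale the longest side to target_size
+        # while preserving the aspect ratio.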
self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) + self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) + self.image_changed = True + + tclicks_lists = [self._transform_clicks(clicks_lists[0])] + return self._roi_image, tclicks_lists diff --git a/inference/interact/fbrs/inference/transforms/zoom_in.py b/inference/interact/fbrs/inference/transforms/zoom_in.py new file mode 100644 index 0000000000000000000000000000000000000000..6c11ecc241570fe2429e85bdccbb713a70d9ffd6 --- /dev/null +++ b/inference/interact/fbrs/inference/transforms/zoom_in.py @@ -0,0 +1,171 @@ +import torch + +from ..clicker import Click +from ...utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox +from .base import BaseTransform + + +class ZoomIn(BaseTransform): + def __init__(self, + target_size=400, + skip_clicks=1, + expansion_ratio=1.4, + min_crop_size=200, + recompute_thresh_iou=0.5, + prob_thresh=0.50): + super().__init__() + self.target_size = target_size + self.min_crop_size = min_crop_size + self.skip_clicks = skip_clicks + self.expansion_ratio = expansion_ratio + self.recompute_thresh_iou = recompute_thresh_iou + self.prob_thresh = prob_thresh + + self._input_image_shape = None + self._prev_probs = None + self._object_roi = None + self._roi_image = None + + def transform(self, image_nd, clicks_lists): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + self.image_changed = False + + clicks_list = clicks_lists[0] + if len(clicks_list) <= self.skip_clicks: + return image_nd, clicks_lists + + self._input_image_shape = image_nd.shape + + current_object_roi = None + if self._prev_probs is not None: + current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] + if current_pred_mask.sum() > 0: + current_object_roi = get_object_roi(current_pred_mask, clicks_list, + self.expansion_ratio, self.min_crop_size) + + if current_object_roi is None: + return image_nd, clicks_lists + + update_object_roi = False + if self._object_roi is None: + update_object_roi = True + elif not check_object_roi(self._object_roi, clicks_list): + update_object_roi = True + elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou: + update_object_roi = True + + if update_object_roi: + self._object_roi = current_object_roi + self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) + self.image_changed = True + + tclicks_lists = [self._transform_clicks(clicks_list)] + return self._roi_image.to(image_nd.device), tclicks_lists + + def inv_transform(self, prob_map): + if self._object_roi is None: + self._prev_probs = prob_map.cpu().numpy() + return prob_map + + assert prob_map.shape[0] == 1 + rmin, rmax, cmin, cmax = self._object_roi + prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1), + mode='bilinear', align_corners=True) + + if self._prev_probs is not None: + new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype) + new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map + else: + new_prob_map = prob_map + + self._prev_probs = new_prob_map.cpu().numpy() + + return new_prob_map + + def check_possible_recalculation(self): + if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0: + return False + + pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] + if pred_mask.sum() > 0: + possible_object_roi = get_object_roi(pred_mask, [], + self.expansion_ratio, self.min_crop_size) + image_roi = (0, 
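+                         # full-image ROI in (rmin, rmax, cmin, cmax) order; a zoom-in
+                         # recalculation is only suggested when the object's bbox covers
+                         # well under half of it (the IoU < 0.5 check below)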
self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1) + if get_bbox_iou(possible_object_roi, image_roi) < 0.50: + return True + return False + + def get_state(self): + roi_image = self._roi_image.cpu() if self._roi_image is not None else None + return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed + + def set_state(self, state): + self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state + + def reset(self): + self._input_image_shape = None + self._object_roi = None + self._prev_probs = None + self._roi_image = None + self.image_changed = False + + def _transform_clicks(self, clicks_list): + if self._object_roi is None: + return clicks_list + + rmin, rmax, cmin, cmax = self._object_roi + crop_height, crop_width = self._roi_image.shape[2:] + + transformed_clicks = [] + for click in clicks_list: + new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1) + new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1) + transformed_clicks.append(Click(is_positive=click.is_positive, coords=(new_r, new_c))) + return transformed_clicks + + +def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size): + pred_mask = pred_mask.copy() + + for click in clicks_list: + if click.is_positive: + pred_mask[int(click.coords[0]), int(click.coords[1])] = 1 + + bbox = get_bbox_from_mask(pred_mask) + bbox = expand_bbox(bbox, expansion_ratio, min_crop_size) + h, w = pred_mask.shape[0], pred_mask.shape[1] + bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1) + + return bbox + + +def get_roi_image_nd(image_nd, object_roi, target_size): + rmin, rmax, cmin, cmax = object_roi + + height = rmax - rmin + 1 + width = cmax - cmin + 1 + + if isinstance(target_size, tuple): + new_height, new_width = target_size + else: + scale = target_size / max(height, width) + new_height = int(round(height * scale)) + new_width = int(round(width * scale)) + + with torch.no_grad(): + roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1] + roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width), + mode='bilinear', align_corners=True) + + return roi_image_nd + + +def check_object_roi(object_roi, clicks_list): + for click in clicks_list: + if click.is_positive: + if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]: + return False + if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]: + return False + + return True diff --git a/inference/interact/fbrs/inference/utils.py b/inference/interact/fbrs/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d29b890294d16a9b780f79a2629505c10ff1cee --- /dev/null +++ b/inference/interact/fbrs/inference/utils.py @@ -0,0 +1,177 @@ +from datetime import timedelta +from pathlib import Path + +import torch +import numpy as np + +from ..model.is_deeplab_model import get_deeplab_model +from ..model.is_hrnet_model import get_hrnet_model + + +def get_time_metrics(all_ious, elapsed_time): + n_images = len(all_ious) + n_clicks = sum(map(len, all_ious)) + + mean_spc = elapsed_time / n_clicks + mean_spi = elapsed_time / n_images + + return mean_spc, mean_spi + + +def load_is_model(checkpoint, device, backbone='auto', **kwargs): + if isinstance(checkpoint, (str, Path)): + state_dict = torch.load(checkpoint, map_location='cpu') + else: + state_dict = checkpoint + + if backbone == 'auto': + for k in state_dict.keys(): + if 'feature_extractor.stage2.0.branches' in k: + return 
load_hrnet_is_model(state_dict, device, backbone, **kwargs)
+        return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
+    elif 'resnet' in backbone:
+        return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
+    elif 'hrnet' in backbone:
+        return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
+    else:
+        raise NotImplementedError('Unknown backbone')
+
+
+def load_hrnet_is_model(state_dict, device, backbone='auto', width=48, ocr_width=256,
+                        small=False, cpu_dist_maps=False, norm_radius=260):
+    if backbone == 'auto':
+        num_fe_weights = len([x for x in state_dict.keys() if 'feature_extractor.' in x])
+        small = num_fe_weights < 1800
+
+        ocr_f_down = [v for k, v in state_dict.items() if 'object_context_block.f_down.1.0.bias' in k]
+        assert len(ocr_f_down) == 1
+        ocr_width = ocr_f_down[0].shape[0]
+
+        s2_conv1_w = [v for k, v in state_dict.items() if 'stage2.0.branches.0.0.conv1.weight' in k]
+        assert len(s2_conv1_w) == 1
+        width = s2_conv1_w[0].shape[0]
+
+    model = get_hrnet_model(width=width, ocr_width=ocr_width, small=small,
+                            with_aux_output=False, cpu_dist_maps=cpu_dist_maps,
+                            norm_radius=norm_radius)
+
+    model.load_state_dict(state_dict, strict=False)
+    for param in model.parameters():
+        param.requires_grad = False
+    model.to(device)
+    model.eval()
+
+    return model
+
+
+def load_deeplab_is_model(state_dict, device, backbone='auto', deeplab_ch=128, aspp_dropout=0.2,
+                          cpu_dist_maps=False, norm_radius=260):
+    if backbone == 'auto':
+        num_backbone_params = len([x for x in state_dict.keys()
+                                   if 'feature_extractor.backbone' in x and not('num_batches_tracked' in x)])
+
+        if num_backbone_params <= 181:
+            backbone = 'resnet34'
+        elif num_backbone_params <= 276:
+            backbone = 'resnet50'
+        elif num_backbone_params <= 531:
+            backbone = 'resnet101'
+        else:
+            raise NotImplementedError('Unknown backbone')
+
+    if 'aspp_dropout' in state_dict:
+        aspp_dropout = float(state_dict['aspp_dropout'].cpu().numpy())
+    else:
+        aspp_project_weight = [v for k, v in state_dict.items() if 'aspp.project.0.weight' in k][0]
+        deeplab_ch = aspp_project_weight.size(0)
+        if deeplab_ch == 256:
+            aspp_dropout = 0.5
+
+    model = get_deeplab_model(backbone=backbone, deeplab_ch=deeplab_ch,
+                              aspp_dropout=aspp_dropout, cpu_dist_maps=cpu_dist_maps,
+                              norm_radius=norm_radius)
+
+    model.load_state_dict(state_dict, strict=False)
+    for param in model.parameters():
+        param.requires_grad = False
+    model.to(device)
+    model.eval()
+
+    return model
+
+
+def get_iou(gt_mask, pred_mask, ignore_label=-1):
+    ignore_gt_mask_inv = gt_mask != ignore_label
+    obj_gt_mask = gt_mask == 1
+
+    intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
+    union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
+
+    return intersection / union
+
+
+def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
+    def _get_noc(iou_arr, iou_thr):
+        vals = iou_arr >= iou_thr
+        return np.argmax(vals) + 1 if np.any(vals) else max_clicks
+
+    noc_list = []
+    over_max_list = []
+    for iou_thr in iou_thrs:
+        scores_arr = np.array([_get_noc(iou_arr, iou_thr)
+                               for iou_arr in all_ious], dtype=int)  # np.int alias was removed in NumPy 1.24
+
+        score = scores_arr.mean()
+        over_max = (scores_arr == max_clicks).sum()
+
+        noc_list.append(score)
+        over_max_list.append(over_max)
+
+    return noc_list, over_max_list
+
+
+def find_checkpoint(weights_folder, checkpoint_name):
+    weights_folder = Path(weights_folder)
+    if ':' in checkpoint_name:
+        model_name, checkpoint_name = checkpoint_name.split(':')
+        models_candidates =
[x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()] + assert len(models_candidates) == 1 + model_folder = models_candidates[0] + else: + model_folder = weights_folder + + if checkpoint_name.endswith('.pth'): + if Path(checkpoint_name).exists(): + checkpoint_path = checkpoint_name + else: + checkpoint_path = weights_folder / checkpoint_name + else: + model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth')) + assert len(model_checkpoints) == 1 + checkpoint_path = model_checkpoints[0] + + return str(checkpoint_path) + + +def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, + n_clicks=20, model_name=None): + table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|' + f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' + f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' + f'{"SPC,s":^7}|{"Time":^9}|') + row_width = len(table_header) + + header = f'Eval results for model: {model_name}\n' if model_name is not None else '' + header += '-' * row_width + '\n' + header += table_header + '\n' + '-' * row_width + + eval_time = str(timedelta(seconds=int(elapsed_time))) + table_row = f'|{brs_type:^13}|{dataset_name:^11}|' + table_row += f'{noc_list[0]:^9.2f}|' + table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|' + + return header, table_row \ No newline at end of file diff --git a/inference/interact/fbrs/model/__init__.py b/inference/interact/fbrs/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/interact/fbrs/model/initializer.py b/inference/interact/fbrs/model/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..470c7df4659bc1e80ceec80a170b3b2e0302fb84 --- /dev/null +++ b/inference/interact/fbrs/model/initializer.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import numpy as np + + +class Initializer(object): + def __init__(self, local_init=True, gamma=None): + self.local_init = local_init + self.gamma = gamma + + def __call__(self, m): + if getattr(m, '__initialized', False): + return + + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: + if m.weight is not None: + self._init_gamma(m.weight.data) + if m.bias is not None: + self._init_beta(m.bias.data) + else: + if getattr(m, 'weight', None) is not None: + self._init_weight(m.weight.data) + if getattr(m, 'bias', None) is not None: + self._init_bias(m.bias.data) + + if self.local_init: + object.__setattr__(m, '__initialized', True) + + def _init_weight(self, data): + nn.init.uniform_(data, -0.07, 0.07) + + def _init_bias(self, data): + nn.init.constant_(data, 0) + + def _init_gamma(self, data): + if self.gamma is None: + nn.init.constant_(data, 1.0) + else: + nn.init.normal_(data, 1.0, self.gamma) + + def _init_beta(self, data): + nn.init.constant_(data, 0) + + +class Bilinear(Initializer): + def __init__(self, scale, groups, in_channels, **kwargs): + super().__init__(**kwargs) + self.scale = scale + self.groups = groups + self.in_channels = in_channels + + def 
_init_weight(self, data): + """Reset the weight and bias.""" + bilinear_kernel = self.get_bilinear_kernel(self.scale) + weight = torch.zeros_like(data) + for i in range(self.in_channels): + if self.groups == 1: + j = i + else: + j = 0 + weight[i, j] = bilinear_kernel + data[:] = weight + + @staticmethod + def get_bilinear_kernel(scale): + """Generate a bilinear upsampling kernel.""" + kernel_size = 2 * scale - scale % 2 + scale = (kernel_size + 1) // 2 + center = scale - 0.5 * (1 + kernel_size % 2) + + og = np.ogrid[:kernel_size, :kernel_size] + kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) + + return torch.tensor(kernel, dtype=torch.float32) + + +class XavierGluon(Initializer): + def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): + super().__init__(**kwargs) + + self.rnd_type = rnd_type + self.factor_type = factor_type + self.magnitude = float(magnitude) + + def _init_weight(self, arr): + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) + + if self.factor_type == 'avg': + factor = (fan_in + fan_out) / 2.0 + elif self.factor_type == 'in': + factor = fan_in + elif self.factor_type == 'out': + factor = fan_out + else: + raise ValueError('Incorrect factor type') + scale = np.sqrt(self.magnitude / factor) + + if self.rnd_type == 'uniform': + nn.init.uniform_(arr, -scale, scale) + elif self.rnd_type == 'gaussian': + nn.init.normal_(arr, 0, scale) + else: + raise ValueError('Unknown random type') diff --git a/inference/interact/fbrs/model/is_deeplab_model.py b/inference/interact/fbrs/model/is_deeplab_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a75cc0f56c1a068dc742f65a42a6ec85e9ad83 --- /dev/null +++ b/inference/interact/fbrs/model/is_deeplab_model.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn + +from .ops import DistMaps +from .modeling.deeplab_v3 import DeepLabV3Plus +from .modeling.basic_blocks import SepConvHead + + +def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, + norm_layer=nn.BatchNorm2d, backbone_norm_layer=None, + use_rgb_conv=True, cpu_dist_maps=False, + norm_radius=260): + model = DistMapsModel( + feature_extractor=DeepLabV3Plus(backbone=backbone, + ch=deeplab_ch, + project_dropout=aspp_dropout, + norm_layer=norm_layer, + backbone_norm_layer=backbone_norm_layer), + head=SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, + num_layers=2, norm_layer=norm_layer), + use_rgb_conv=use_rgb_conv, + norm_layer=norm_layer, + norm_radius=norm_radius, + cpu_dist_maps=cpu_dist_maps + ) + + return model + + +class DistMapsModel(nn.Module): + def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True, + cpu_dist_maps=False, norm_radius=260): + super(DistMapsModel, self).__init__() + + if use_rgb_conv: + self.rgb_conv = nn.Sequential( + nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1), + nn.LeakyReLU(negative_slope=0.2), + norm_layer(8), + nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1), + ) + else: + self.rgb_conv = None + + self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, + cpu_mode=cpu_dist_maps) + self.feature_extractor = feature_extractor + self.head = head + + def forward(self, image, points): + coord_features = self.dist_maps(image, points) + + if self.rgb_conv is not None: + x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) + else: + c1, c2 = torch.chunk(coord_features, 2, dim=1) + c3 = torch.ones_like(c1) + coord_features = torch.cat((c1, c2, 
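+                                         # c1/c2 are the positive and negative click maps; the
+                                         # constant third channel keeps coord_features at three
+                                         # channels so the elementwise modulation of the RGB
+                                         # image below stays shape-compatible.
+ 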
c3), dim=1) + x = 0.8 * image * coord_features + 0.2 * image + + backbone_features = self.feature_extractor(x) + instance_out = self.head(backbone_features[0]) + instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:], + mode='bilinear', align_corners=True) + + return {'instances': instance_out} + + def load_weights(self, path_to_weights): + current_state_dict = self.state_dict() + new_state_dict = torch.load(path_to_weights, map_location='cpu') + current_state_dict.update(new_state_dict) + self.load_state_dict(current_state_dict) + + def get_trainable_params(self): + backbone_params = nn.ParameterList() + other_params = nn.ParameterList() + + for name, param in self.named_parameters(): + if param.requires_grad: + if 'backbone' in name: + backbone_params.append(param) + else: + other_params.append(param) + return backbone_params, other_params + + diff --git a/inference/interact/fbrs/model/is_hrnet_model.py b/inference/interact/fbrs/model/is_hrnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ced540a782c7b6e5b498d2e345faa95cb4015f4c --- /dev/null +++ b/inference/interact/fbrs/model/is_hrnet_model.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn + +from .ops import DistMaps +from .modeling.hrnet_ocr import HighResolutionNet + + +def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260, + use_rgb_conv=True, with_aux_output=False, cpu_dist_maps=False, + norm_layer=nn.BatchNorm2d): + model = DistMapsHRNetModel( + feature_extractor=HighResolutionNet(width=width, ocr_width=ocr_width, small=small, + num_classes=1, norm_layer=norm_layer), + use_rgb_conv=use_rgb_conv, + with_aux_output=with_aux_output, + norm_layer=norm_layer, + norm_radius=norm_radius, + cpu_dist_maps=cpu_dist_maps + ) + + return model + + +class DistMapsHRNetModel(nn.Module): + def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_output=False, + norm_layer=nn.BatchNorm2d, norm_radius=260, cpu_dist_maps=False): + super(DistMapsHRNetModel, self).__init__() + self.with_aux_output = with_aux_output + + if use_rgb_conv: + self.rgb_conv = nn.Sequential( + nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1), + nn.LeakyReLU(negative_slope=0.2), + norm_layer(8), + nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1), + ) + else: + self.rgb_conv = None + + self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps) + self.feature_extractor = feature_extractor + + def forward(self, image, points): + coord_features = self.dist_maps(image, points) + + if self.rgb_conv is not None: + x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) + else: + c1, c2 = torch.chunk(coord_features, 2, dim=1) + c3 = torch.ones_like(c1) + coord_features = torch.cat((c1, c2, c3), dim=1) + x = 0.8 * image * coord_features + 0.2 * image + + feature_extractor_out = self.feature_extractor(x) + instance_out = feature_extractor_out[0] + instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:], + mode='bilinear', align_corners=True) + outputs = {'instances': instance_out} + if self.with_aux_output: + instance_aux_out = feature_extractor_out[1] + instance_aux_out = nn.functional.interpolate(instance_aux_out, size=image.size()[2:], + mode='bilinear', align_corners=True) + outputs['instances_aux'] = instance_aux_out + + return outputs + + def load_weights(self, path_to_weights): + current_state_dict = self.state_dict() + new_state_dict = torch.load(path_to_weights) + current_state_dict.update(new_state_dict) + 
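# Merging the checkpoint into the current state dict (rather than loading it
+        # directly) lets a partial checkpoint override only the keys it contains,
+        # while every missing key keeps its freshly initialized value.
+        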
self.load_state_dict(current_state_dict) + + def get_trainable_params(self): + backbone_params = nn.ParameterList() + other_params = nn.ParameterList() + other_params_keys = [] + nonbackbone_keywords = ['rgb_conv', 'aux_head', 'cls_head', 'conv3x3_ocr', 'ocr_distri_head'] + + for name, param in self.named_parameters(): + if param.requires_grad: + if any(x in name for x in nonbackbone_keywords): + other_params.append(param) + other_params_keys.append(name) + else: + backbone_params.append(param) + print('Nonbackbone params:', sorted(other_params_keys)) + return backbone_params, other_params diff --git a/inference/interact/fbrs/model/losses.py b/inference/interact/fbrs/model/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..fd89bf02b108533bc8c5639f233549d7387d3dbc --- /dev/null +++ b/inference/interact/fbrs/model/losses.py @@ -0,0 +1,134 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import misc + + +class NormalizedFocalLossSigmoid(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, + from_logits=False, batch_axis=0, + weight=None, size_average=True, detach_delimeter=True, + eps=1e-12, scale=1.0, + ignore_label=-1): + super(NormalizedFocalLossSigmoid, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + self._scale = scale + self._from_logits = from_logits + self._eps = eps + self._size_average = size_average + self._detach_delimeter = detach_delimeter + self._k_sum = 0 + + def forward(self, pred, label, sample_weight=None): + one_hot = label > 0 + sample_weight = label != self._ignore_label + + if not self._from_logits: + pred = torch.sigmoid(pred) + + alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) + pt = torch.where(one_hot, pred, 1 - pred) + pt = torch.where(sample_weight, pt, torch.ones_like(pt)) + + beta = (1 - pt) ** self._gamma + + sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True) + beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) + mult = sw_sum / (beta_sum + self._eps) + if self._detach_delimeter: + mult = mult.detach() + beta = beta * mult + + ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() + sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() + if np.any(ignore_area == 0): + self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() + + loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = self._weight * (loss * sample_weight) + + if self._size_average: + bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) + else: + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + + return self._scale * loss + + def log_states(self, sw, name, global_step): + sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) + + +class FocalLoss(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, + from_logits=False, batch_axis=0, + weight=None, num_class=None, + eps=1e-9, size_average=True, scale=1.0): + super(FocalLoss, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + 
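# alpha balances positive vs. negative examples; gamma down-weights
+        # well-classified pixels, as in the focal loss of Lin et al. (2017).
+        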
self._weight = weight if weight is not None else 1.0
+        self._batch_axis = batch_axis
+
+        self._scale = scale
+        self._num_class = num_class
+        self._from_logits = from_logits
+        self._eps = eps
+        self._size_average = size_average
+
+    def forward(self, pred, label, sample_weight=None):
+        if not self._from_logits:
+            pred = torch.sigmoid(pred)  # F.sigmoid is deprecated; torch.sigmoid matches the class above
+
+        one_hot = label > 0
+        pt = torch.where(one_hot, pred, 1 - pred)
+
+        t = label != -1
+        alpha = torch.where(one_hot, self._alpha * t, (1 - self._alpha) * t)
+        beta = (1 - pt) ** self._gamma
+
+        loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
+        sample_weight = label != -1
+
+        loss = self._weight * (loss * sample_weight)
+
+        if self._size_average:
+            tsum = torch.sum(label == 1, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
+            loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
+        else:
+            loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
+
+        return self._scale * loss
+
+
+class SigmoidBinaryCrossEntropyLoss(nn.Module):
+    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
+        super(SigmoidBinaryCrossEntropyLoss, self).__init__()
+        self._from_sigmoid = from_sigmoid
+        self._ignore_label = ignore_label
+        self._weight = weight if weight is not None else 1.0
+        self._batch_axis = batch_axis
+
+    def forward(self, pred, label):
+        label = label.view(pred.size())
+        sample_weight = label != self._ignore_label
+        label = torch.where(sample_weight, label, torch.zeros_like(label))
+
+        if not self._from_sigmoid:
+            loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
+        else:
+            eps = 1e-12
+            loss = -(torch.log(pred + eps) * label
+                     + torch.log(1. - pred + eps) * (1. 
- label)) + + loss = self._weight * (loss * sample_weight) + return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) diff --git a/inference/interact/fbrs/model/metrics.py b/inference/interact/fbrs/model/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..9944feb1cf76cfb8707122c7a6ea7a830c02070a --- /dev/null +++ b/inference/interact/fbrs/model/metrics.py @@ -0,0 +1,101 @@ +import torch +import numpy as np + +from ..utils import misc + + +class TrainMetric(object): + def __init__(self, pred_outputs, gt_outputs): + self.pred_outputs = pred_outputs + self.gt_outputs = gt_outputs + + def update(self, *args, **kwargs): + raise NotImplementedError + + def get_epoch_value(self): + raise NotImplementedError + + def reset_epoch_stats(self): + raise NotImplementedError + + def log_states(self, sw, tag_prefix, global_step): + pass + + @property + def name(self): + return type(self).__name__ + + +class AdaptiveIoU(TrainMetric): + def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, + ignore_label=-1, from_logits=True, + pred_output='instances', gt_output='instances'): + super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) + self._ignore_label = ignore_label + self._from_logits = from_logits + self._iou_thresh = init_thresh + self._thresh_step = thresh_step + self._thresh_beta = thresh_beta + self._iou_beta = iou_beta + self._ema_iou = 0.0 + self._epoch_iou_sum = 0.0 + self._epoch_batch_count = 0 + + def update(self, pred, gt): + gt_mask = gt > 0 + if self._from_logits: + pred = torch.sigmoid(pred) + + gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() + if np.all(gt_mask_area == 0): + return + + ignore_mask = gt == self._ignore_label + max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() + best_thresh = self._iou_thresh + for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: + temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() + if temp_iou > max_iou: + max_iou = temp_iou + best_thresh = t + + self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh + self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou + self._epoch_iou_sum += max_iou + self._epoch_batch_count += 1 + + def get_epoch_value(self): + if self._epoch_batch_count > 0: + return self._epoch_iou_sum / self._epoch_batch_count + else: + return 0.0 + + def reset_epoch_stats(self): + self._epoch_iou_sum = 0.0 + self._epoch_batch_count = 0 + + def log_states(self, sw, tag_prefix, global_step): + sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) + sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) + + @property + def iou_thresh(self): + return self._iou_thresh + + +def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): + if ignore_mask is not None: + pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) + + reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) + union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + nonzero = union > 0 + + iou = intersection[nonzero] / union[nonzero] + if not keep_ignore: + return iou + else: + result = np.full_like(intersection, -1) + result[nonzero] = iou + return result diff --git 
a/inference/interact/fbrs/model/modeling/__init__.py b/inference/interact/fbrs/model/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/interact/fbrs/model/modeling/basic_blocks.py b/inference/interact/fbrs/model/modeling/basic_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..35946e8b6639460d5822b46a3e82a85bc4f1060e --- /dev/null +++ b/inference/interact/fbrs/model/modeling/basic_blocks.py @@ -0,0 +1,71 @@ +import torch.nn as nn + +from ...model import ops + + +class ConvHead(nn.Module): + def __init__(self, out_channels, in_channels=32, num_layers=1, + kernel_size=3, padding=1, + norm_layer=nn.BatchNorm2d): + super(ConvHead, self).__init__() + convhead = [] + + for i in range(num_layers): + convhead.extend([ + nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), + nn.ReLU(), + norm_layer(in_channels) if norm_layer is not None else nn.Identity() + ]) + convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + self.convhead = nn.Sequential(*convhead) + + def forward(self, *inputs): + return self.convhead(inputs[0]) + + +class SepConvHead(nn.Module): + def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, + kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, + norm_layer=nn.BatchNorm2d): + super(SepConvHead, self).__init__() + + sepconvhead = [] + + for i in range(num_layers): + sepconvhead.append( + SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, + out_channels=mid_channels, + dw_kernel=kernel_size, dw_padding=padding, + norm_layer=norm_layer, activation='relu') + ) + if dropout_ratio > 0 and dropout_indx == i: + sepconvhead.append(nn.Dropout(dropout_ratio)) + + sepconvhead.append( + nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) + ) + + self.layers = nn.Sequential(*sepconvhead) + + def forward(self, *inputs): + x = inputs[0] + + return self.layers(x) + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, + activation=None, use_bias=False, norm_layer=None): + super(SeparableConv2d, self).__init__() + _activation = ops.select_activation_function(activation) + self.body = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, + padding=dw_padding, bias=use_bias, groups=in_channels), + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), + norm_layer(out_channels) if norm_layer is not None else nn.Identity(), + _activation() + ) + + def forward(self, x): + return self.body(x) diff --git a/inference/interact/fbrs/model/modeling/deeplab_v3.py b/inference/interact/fbrs/model/modeling/deeplab_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..8e863862c48a75a2ba9d9aa8a8025ee4333308d5 --- /dev/null +++ b/inference/interact/fbrs/model/modeling/deeplab_v3.py @@ -0,0 +1,176 @@ +from contextlib import ExitStack + +import torch +from torch import nn +import torch.nn.functional as F + +from .basic_blocks import SeparableConv2d +from .resnet import ResNetBackbone +from ...model import ops + + +class DeepLabV3Plus(nn.Module): + def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, + backbone_norm_layer=None, + ch=256, + project_dropout=0.5, + inference_mode=False, + **kwargs): + super(DeepLabV3Plus, self).__init__() + if backbone_norm_layer is None: + backbone_norm_layer = norm_layer + + 
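# The backbone may use a different normalization layer (e.g. a synchronized
+        # variant) than the decoder; it falls back to norm_layer when none is given.
+        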
self.backbone_name = backbone + self.norm_layer = norm_layer + self.backbone_norm_layer = backbone_norm_layer + self.inference_mode = False + self.ch = ch + self.aspp_in_channels = 2048 + self.skip_project_in_channels = 256 # layer 1 out_channels + + self._kwargs = kwargs + if backbone == 'resnet34': + self.aspp_in_channels = 512 + self.skip_project_in_channels = 64 + + self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, + norm_layer=self.backbone_norm_layer, **kwargs) + + self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch, + norm_layer=self.norm_layer) + self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) + self.aspp = _ASPP(in_channels=self.aspp_in_channels, + atrous_rates=[12, 24, 36], + out_channels=ch, + project_dropout=project_dropout, + norm_layer=self.norm_layer) + + if inference_mode: + self.set_prediction_mode() + + def load_pretrained_weights(self): + pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, + norm_layer=self.backbone_norm_layer, **self._kwargs) + backbone_state_dict = self.backbone.state_dict() + pretrained_state_dict = pretrained.state_dict() + + backbone_state_dict.update(pretrained_state_dict) + self.backbone.load_state_dict(backbone_state_dict) + + if self.inference_mode: + for param in self.backbone.parameters(): + param.requires_grad = False + + def set_prediction_mode(self): + self.inference_mode = True + self.eval() + + def forward(self, x): + with ExitStack() as stack: + if self.inference_mode: + stack.enter_context(torch.no_grad()) + + c1, _, c3, c4 = self.backbone(x) + c1 = self.skip_project(c1) + + x = self.aspp(c4) + x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x, c1), dim=1) + x = self.head(x) + + return x, + + +class _SkipProject(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): + super(_SkipProject, self).__init__() + _activation = ops.select_activation_function("relu") + + self.skip_project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + _activation() + ) + + def forward(self, x): + return self.skip_project(x) + + +class _DeepLabHead(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): + super(_DeepLabHead, self).__init__() + + self.block = nn.Sequential( + SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, + dw_padding=1, activation='relu', norm_layer=norm_layer), + SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, + dw_padding=1, activation='relu', norm_layer=norm_layer), + nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) + ) + + def forward(self, x): + return self.block(x) + + +class _ASPP(nn.Module): + def __init__(self, in_channels, atrous_rates, out_channels=256, + project_dropout=0.5, norm_layer=nn.BatchNorm2d): + super(_ASPP, self).__init__() + + b0 = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + rate1, rate2, rate3 = tuple(atrous_rates) + b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer) + b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer) + b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer) + b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer) + + self.concurent = 
nn.ModuleList([b0, b1, b2, b3, b4]) + + project = [ + nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, + kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ] + if project_dropout > 0: + project.append(nn.Dropout(project_dropout)) + self.project = nn.Sequential(*project) + + def forward(self, x): + x = torch.cat([block(x) for block in self.concurent], dim=1) + + return self.project(x) + + +class _AsppPooling(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer): + super(_AsppPooling, self).__init__() + + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + def forward(self, x): + pool = self.gap(x) + return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) + + +def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): + block = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, padding=atrous_rate, + dilation=atrous_rate, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + return block diff --git a/inference/interact/fbrs/model/modeling/hrnet_ocr.py b/inference/interact/fbrs/model/modeling/hrnet_ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..c24a367dfd2d7b648018cc34366c0376f297b91e --- /dev/null +++ b/inference/interact/fbrs/model/modeling/hrnet_ocr.py @@ -0,0 +1,399 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +from .ocr import SpatialOCR_Module, SpatialGather_Module +from .resnetv1b import BasicBlockV1b, BottleneckV1b + +relu_inplace = True + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method,multi_scale_output=True, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + self.norm_layer = norm_layer + self.align_corners = align_corners + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=relu_inplace) + + def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(num_channels[branch_index] * block.expansion), + ) 
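+            # The 1x1 projection above matches channel count (and stride) between the
+            # branch input and the block output, so the residual addition stays valid.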
+ + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, + downsample=downsample, norm_layer=self.norm_layer)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], + norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(in_channels=num_inchannels[j], + out_channels=num_inchannels[i], + kernel_size=1, + bias=False), + self.norm_layer(num_inchannels[i]))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=relu_inplace))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode='bilinear', align_corners=self.align_corners) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HighResolutionNet(nn.Module): + def __init__(self, width, num_classes, ocr_width=256, small=False, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionNet, self).__init__() + self.norm_layer = norm_layer + self.width = width + self.ocr_width = ocr_width + self.align_corners = align_corners + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(64) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = norm_layer(64) + self.relu = nn.ReLU(inplace=relu_inplace) + + num_blocks = 2 if small else 4 + + stage1_num_channels = 64 + self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) + stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels + + self.stage2_num_branches = 2 + num_channels = [width, 2 * width] + num_inchannels = [ + 
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
+        self.transition1 = self._make_transition_layer(
+            [stage1_out_channel], num_inchannels)
+        self.stage2, pre_stage_channels = self._make_stage(
+            BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
+            num_blocks=2 * [num_blocks], num_channels=num_channels)
+
+        self.stage3_num_branches = 3
+        num_channels = [width, 2 * width, 4 * width]
+        num_inchannels = [
+            num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
+        self.transition2 = self._make_transition_layer(
+            pre_stage_channels, num_inchannels)
+        self.stage3, pre_stage_channels = self._make_stage(
+            BasicBlockV1b, num_inchannels=num_inchannels,
+            num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
+            num_blocks=3 * [num_blocks], num_channels=num_channels)
+
+        self.stage4_num_branches = 4
+        num_channels = [width, 2 * width, 4 * width, 8 * width]
+        num_inchannels = [
+            num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
+        self.transition3 = self._make_transition_layer(
+            pre_stage_channels, num_inchannels)
+        self.stage4, pre_stage_channels = self._make_stage(
+            BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
+            num_branches=self.stage4_num_branches,
+            num_blocks=4 * [num_blocks], num_channels=num_channels)
+
+        last_inp_channels = int(np.sum(pre_stage_channels))  # the np.int alias was removed in NumPy 1.24
+        ocr_mid_channels = 2 * ocr_width
+        ocr_key_channels = ocr_width
+
+        self.conv3x3_ocr = nn.Sequential(
+            nn.Conv2d(last_inp_channels, ocr_mid_channels,
+                      kernel_size=3, stride=1, padding=1),
+            norm_layer(ocr_mid_channels),
+            nn.ReLU(inplace=relu_inplace),
+        )
+        self.ocr_gather_head = SpatialGather_Module(num_classes)
+
+        self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
+                                                 key_channels=ocr_key_channels,
+                                                 out_channels=ocr_mid_channels,
+                                                 scale=1,
+                                                 dropout=0.05,
+                                                 norm_layer=norm_layer,
+                                                 align_corners=align_corners)
+        self.cls_head = nn.Conv2d(
+            ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
+
+        self.aux_head = nn.Sequential(
+            nn.Conv2d(last_inp_channels, last_inp_channels,
+                      kernel_size=1, stride=1, padding=0),
+            norm_layer(last_inp_channels),
+            nn.ReLU(inplace=relu_inplace),
+            nn.Conv2d(last_inp_channels, num_classes,
+                      kernel_size=1, stride=1, padding=0, bias=True)
+        )
+
+    def _make_transition_layer(
+            self, num_channels_pre_layer, num_channels_cur_layer):
+        num_branches_cur = len(num_channels_cur_layer)
+        num_branches_pre = len(num_channels_pre_layer)
+
+        transition_layers = []
+        for i in range(num_branches_cur):
+            if i < num_branches_pre:
+                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                    transition_layers.append(nn.Sequential(
+                        nn.Conv2d(num_channels_pre_layer[i],
+                                  num_channels_cur_layer[i],
+                                  kernel_size=3,
+                                  stride=1,
+                                  padding=1,
+                                  bias=False),
+                        self.norm_layer(num_channels_cur_layer[i]),
+                        nn.ReLU(inplace=relu_inplace)))
+                else:
+                    transition_layers.append(None)
+            else:
+                conv3x3s = []
+                for j in range(i + 1 - num_branches_pre):
+                    inchannels = num_channels_pre_layer[-1]
+                    outchannels = num_channels_cur_layer[i] \
+                        if j == i - num_branches_pre else inchannels
+                    conv3x3s.append(nn.Sequential(
+                        nn.Conv2d(inchannels, outchannels,
+                                  kernel_size=3, stride=2, padding=1, bias=False),
+                        self.norm_layer(outchannels),
+                        nn.ReLU(inplace=relu_inplace)))
+                transition_layers.append(nn.Sequential(*conv3x3s))
+
+        return nn.ModuleList(transition_layers)
+
+    def _make_layer(self, block, inplanes, planes, 
blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, + downsample=downsample, norm_layer=self.norm_layer)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, block, num_inchannels, + num_modules, num_branches, num_blocks, num_channels, + fuse_method='SUM', + multi_scale_output=True): + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer, + align_corners=self.align_corners) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + feats = self.compute_hrnet_feats(x) + out_aux = self.aux_head(feats) + feats = self.conv3x3_ocr(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + out = self.cls_head(feats) + + return [out, out_aux] + + def compute_hrnet_feats(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_num_branches): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_num_branches): + if self.transition2[i] is not None: + if i < self.stage2_num_branches: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_num_branches): + if self.transition3[i] is not None: + if i < self.stage3_num_branches: + x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate(x[1], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x2 = F.interpolate(x[2], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x3 = F.interpolate(x[3], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + + return torch.cat([x[0], x1, x2, x3], 1) + + def load_pretrained_weights(self, pretrained_path=''): + model_dict = self.state_dict() + + if not os.path.exists(pretrained_path): + print(f'\nFile "{pretrained_path}" does not exist.') + print('You need to specify the correct path to the pre-trained weights.\n' + 'You can download the weights for HRNet from the repository:\n' + 'https://github.com/HRNet/HRNet-Image-Classification') + exit(1) + pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'}) + pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in + pretrained_dict.items()} + + 
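# The ImageNet checkpoints name their classifier 'last_layer' and may prefix
+        # keys with 'model.'; the renames align those keys with this module, so the
+        # diagnostic prints below list only genuinely missing entries.
+        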
print('model_dict-pretrained_dict:', sorted(list(set(model_dict) - set(pretrained_dict)))) + print('pretrained_dict-model_dict:', sorted(list(set(pretrained_dict) - set(model_dict)))) + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) diff --git a/inference/interact/fbrs/model/modeling/ocr.py b/inference/interact/fbrs/model/modeling/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..df3b4f67959fc6a088b93ee7a34b15c1e07402df --- /dev/null +++ b/inference/interact/fbrs/model/modeling/ocr.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F + + +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, cls_num=0, scale=1): + super(SpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = torch.matmul(probs, feats) \ + .permute(0, 2, 1).unsqueeze(3) # batch x k x c + return ocr_context + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. + """ + + def __init__(self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, + norm_layer, align_corners) + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), + nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + return output + + +class ObjectAttentionBlock2D(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + bn_type : specify the bn type + Return: + N X C X H X W + ''' + + def __init__(self, + in_channels, + key_channels, + scale=1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(ObjectAttentionBlock2D, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.align_corners = align_corners + + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, 
bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)) + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels ** -.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), + mode='bilinear', align_corners=self.align_corners) + + return context diff --git a/inference/interact/fbrs/model/modeling/resnet.py b/inference/interact/fbrs/model/modeling/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..349ea1cbd882a9b0daa1d6146b634e9baf3726e0 --- /dev/null +++ b/inference/interact/fbrs/model/modeling/resnet.py @@ -0,0 +1,39 @@ +import torch +from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s + + +class ResNetBackbone(torch.nn.Module): + def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): + super(ResNetBackbone, self).__init__() + + if backbone == 'resnet34': + pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet50': + pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet101': + pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet152': + pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + else: + raise RuntimeError(f'unknown backbone: {backbone}') + + self.conv1 = pretrained.conv1 + self.bn1 = pretrained.bn1 + self.relu = pretrained.relu + self.maxpool = pretrained.maxpool + self.layer1 = pretrained.layer1 + self.layer2 = pretrained.layer2 + self.layer3 = pretrained.layer3 + self.layer4 = pretrained.layer4 + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + c1 = self.layer1(x) + c2 = self.layer2(c1) + c3 = self.layer3(c2) + c4 = self.layer4(c3) + + return c1, c2, c3, c4 diff --git a/inference/interact/fbrs/model/modeling/resnetv1b.py 
b/inference/interact/fbrs/model/modeling/resnetv1b.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad24cef5bde19f2627cfd3f755636f37cfb39ac --- /dev/null +++ b/inference/interact/fbrs/model/modeling/resnetv1b.py @@ -0,0 +1,276 @@ +import torch +import torch.nn as nn +GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet' + + +class BasicBlockV1b(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BasicBlockV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn1 = norm_layer(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, + padding=previous_dilation, dilation=previous_dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class BottleneckV1b(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BottleneckV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(planes) + + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class ResNetV1b(nn.Module): + """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5. + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockV1, BottleneckV1. + layers : list of int + Numbers of layers in each block + classes : int, default 1000 + Number of classification classes. + dilated : bool, default False + Applying dilation strategy to pretrained ResNet yielding a stride-8 model, + typically used in Semantic Segmentation. + norm_layer : object + Normalization layer used (default: :class:`nn.BatchNorm2d`) + deep_stem : bool, default False + Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. + final_drop : float, default 0.0 + Dropout ratio before the final classification layer. + + Reference: + - He, Kaiming, et al. "Deep residual learning for image recognition." + Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. + + - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
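+
+    Examples
+    --------
+    A minimal sketch (constructor arguments as defined above) that builds a
+    stride-8 ResNet-50-style classifier:
+
+    >>> model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], dilated=True)
+    >>> logits = model(torch.randn(1, 3, 224, 224))  # shape (1, 1000)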
+ """ + def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32, + avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d): + self.inplanes = stem_width*2 if deep_stem else 64 + super(ResNetV1b, self).__init__() + if not deep_stem: + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False) + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(True) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, + norm_layer=norm_layer) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, + norm_layer=norm_layer) + if dilated: + self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, + avg_down=avg_down, norm_layer=norm_layer) + else: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.drop = None + if final_drop > 0.0: + self.drop = nn.Dropout(final_drop) + self.fc = nn.Linear(512 * block.expansion, classes) + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, + avg_down=False, norm_layer=nn.BatchNorm2d): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = [] + if avg_down: + if dilation == 1: + downsample.append( + nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) + ) + else: + downsample.append( + nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False) + ) + downsample.extend([ + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=1, bias=False), + norm_layer(planes * block.expansion) + ]) + downsample = nn.Sequential(*downsample) + else: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + norm_layer(planes * block.expansion) + ) + + layers = [] + if dilation in (1, 2): + layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + elif dilation == 4: + layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + else: + raise RuntimeError("=> unknown dilation size: {}".format(dilation)) + + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation, + previous_dilation=dilation, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + if self.drop is not None: + x = self.drop(x) + 
x = self.fc(x) + + return x + + +def _safe_state_dict_filtering(orig_dict, model_dict_keys): + filtered_orig_dict = {} + for k, v in orig_dict.items(): + if k in model_dict_keys: + filtered_orig_dict[k] = v + else: + print(f"[ERROR] Failed to load <{k}> in backbone") + return filtered_orig_dict + + +def resnet34_v1b(pretrained=False, **kwargs): + model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet50_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet101_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet152_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model diff --git a/inference/interact/fbrs/model/ops.py b/inference/interact/fbrs/model/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f46ae39aeb14cdb0ca6d9922b67f4562c40be8df --- /dev/null +++ b/inference/interact/fbrs/model/ops.py @@ -0,0 +1,83 @@ +import torch +from torch import nn as nn +import numpy as np + +from . 
import initializer as initializer +from ..utils.cython import get_dist_maps + + +def select_activation_function(activation): + if isinstance(activation, str): + if activation.lower() == 'relu': + return nn.ReLU + elif activation.lower() == 'softplus': + return nn.Softplus + else: + raise ValueError(f"Unknown activation type {activation}") + elif isinstance(activation, nn.Module): + return activation + else: + raise ValueError(f"Unknown activation type {activation}") + + +class BilinearConvTranspose2d(nn.ConvTranspose2d): + def __init__(self, in_channels, out_channels, scale, groups=1): + kernel_size = 2 * scale - scale % 2 + self.scale = scale + + super().__init__( + in_channels, out_channels, + kernel_size=kernel_size, + stride=scale, + padding=1, + groups=groups, + bias=False) + + self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) + + +class DistMaps(nn.Module): + def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False): + super(DistMaps, self).__init__() + self.spatial_scale = spatial_scale + self.norm_radius = norm_radius + self.cpu_mode = cpu_mode + + def get_coord_features(self, points, batchsize, rows, cols): + if self.cpu_mode: + coords = [] + for i in range(batchsize): + norm_delimeter = self.spatial_scale * self.norm_radius + coords.append(get_dist_maps(points[i].cpu().float().numpy(), rows, cols, + norm_delimeter)) + coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() + else: + num_points = points.shape[1] // 2 + points = points.view(-1, 2) + invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 + row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) + col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) + + coord_rows, coord_cols = torch.meshgrid(row_array, col_array) + coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) + + add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) + coords.add_(-add_xy) + coords.div_(self.norm_radius * self.spatial_scale) + coords.mul_(coords) + + coords[:, 0] += coords[:, 1] + coords = coords[:, :1] + + coords[invalid_points, :, :, :] = 1e6 + + coords = coords.view(-1, num_points, 1, rows, cols) + coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w + coords = coords.view(-1, 2, rows, cols) + + coords.sqrt_().mul_(2).tanh_() + + return coords + + def forward(self, x, coords): + return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) diff --git a/inference/interact/fbrs/model/syncbn/LICENSE b/inference/interact/fbrs/model/syncbn/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..fec54698d35926513ca1ddb7b6cee791daca834e --- /dev/null +++ b/inference/interact/fbrs/model/syncbn/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Tamaki Kojima + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
diff --git a/inference/interact/fbrs/model/syncbn/LICENSE b/inference/interact/fbrs/model/syncbn/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..fec54698d35926513ca1ddb7b6cee791daca834e
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 Tamaki Kojima
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/inference/interact/fbrs/model/syncbn/README.md b/inference/interact/fbrs/model/syncbn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d9a9ea21ca73d08dbac027aea3a4909d6b67ace3
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/README.md
@@ -0,0 +1,127 @@
+# pytorch-syncbn
+
+Tamaki Kojima(tamakoji@gmail.com)
+
+## Announcement
+
+**Pytorch 1.0 support**
+
+## Overview
+This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization" which computes global statistics across GPUs instead of locally computed ones. SyncBN matters when the input images are large and multiple GPUs must be used to reach a reasonable minibatch size during training.
+
+The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn).
+
+## Remarks
+- Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need a custom `nn.DataParallel`
+- Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace your `nn.BatchNorm2d` with this module implementation, since it does not mark tensors for in-place operation
+- You can plug it into an arbitrary module written in PyTorch to enable Synchronized BatchNorm
+- Backward computation is rewritten and tested against the behavior of `nn.BatchNorm2d`
+
+## Requirements
+For PyTorch, please refer to https://pytorch.org/
+
+NOTE : The code is tested only with PyTorch v1.0.0, CUDA10/CuDNN7.4.2 on Ubuntu 18.04
+
+It utilizes the PyTorch JIT mechanism to compile the extension seamlessly, using ninja. Please install ninja-build before use.
+
+```
+sudo apt-get install ninja-build
+```
+
+Also install all dependencies for python. For pip, run:
+
+```
+pip install -U -r requirements.txt
+```
+
+## Build
+
+There is no need to build; just run, and the JIT will take care of compilation. JIT cpp extensions are supported since PyTorch 0.4, but PyTorch > 1.0 is highly recommended due to large design changes.
+
+## Usage
+
+Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d`
+
+```
+import torch
+from torch import nn
+from modules import nn as NN
+num_gpu = torch.cuda.device_count()
+model = nn.Sequential(
+    nn.Conv2d(3, 3, 1, 1, bias=False),
+    NN.BatchNorm2d(3),
+    nn.ReLU(inplace=True),
+    nn.Conv2d(3, 3, 1, 1, bias=False),
+    NN.BatchNorm2d(3),
+).cuda()
+model = nn.DataParallel(model, device_ids=range(num_gpu))
+x = torch.rand(num_gpu, 3, 2, 2).cuda()
+z = model(x)
+```
+
+## Math
+
+### Forward
+1. compute $\sum x_i$ and $\sum x_i^2$ in each gpu
+2. gather all $\sum x_i$ and $\sum x_i^2$ from workers to master and compute
+
+   $\mu = \frac{1}{N}\sum x_i$
+
+   and
+
+   $\sigma^2 = \frac{1}{N}\sum x_i^2 - \mu^2$
+
+   and then above global stats to be shared to all gpus, update running_mean and running_var by moving average using global stats.
+
+3. forward batchnorm using global stats by
+
+   $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$
+
+   and then
+
+   $y_i = \gamma \cdot \hat{x}_i + \beta$
+
+   where $\gamma$ is the weight parameter and $\beta$ is the bias parameter.
+
+4. save $x$, $\gamma$, $\beta$, $\mu$, $\sigma^2$ for backward
+
+### Backward
+
+1. Restore saved $x$, $\gamma$, $\beta$, $\mu$, $\sigma^2$
+
+2. Compute below sums on each gpu
+
+   $\sum_{i=1}^{N_j} \frac{dJ}{dy_i}$
+
+   and
+
+   $\sum_{i=1}^{N_j} \frac{dJ}{dy_i} \hat{x}_i$
+
+   where $N_j$ is the number of elements per channel on gpu $j$,
+
+   then gather them at master node to sum up global, and normalize with N where N is total number of elements for each channel. Global sums are then shared among all gpus.
+
+3. compute gradients using global stats
+
+   $\frac{dJ}{dx_i} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \left( \frac{dJ}{dy_i} - \frac{1}{N}\sum_{j=1}^{N}\frac{dJ}{dy_j} - \frac{\hat{x}_i}{N}\sum_{j=1}^{N}\frac{dJ}{dy_j}\hat{x}_j \right)$
+
+   where
+
+   $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$
+
+   and finally,
+
+   $\frac{dJ}{d\gamma} = \sum_{i=1}^{N} \frac{dJ}{dy_i}\hat{x}_i$
+
+   $\frac{dJ}{d\beta} = \sum_{i=1}^{N} \frac{dJ}{dy_i}$
+
+   Note that in the implementation, normalization with N is performed at step (2), so the equations above and the implementation are not exactly the same, but they are mathematically equivalent.
+
+   You can go deeper on above explanation at [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/)
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/syncbn/__init__.py b/inference/interact/fbrs/model/syncbn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inference/interact/fbrs/model/syncbn/modules/__init__.py b/inference/interact/fbrs/model/syncbn/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py b/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8eb83a9d88b25cb8f1faebc9236da929a7722c7
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py
@@ -0,0 +1 @@
+from .syncbn import batchnorm2d_sync
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py b/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c14098f0cfa422920f01fe4985dbeb7fedc2d1
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py
@@ -0,0 +1,54 @@
+"""
+/*****************************************************************************/
+
+Extension module loader
+
+code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark
+
+/*****************************************************************************/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os.path
+
+import torch
+
+try:
+    from torch.utils.cpp_extension import load
+    from torch.utils.cpp_extension import CUDA_HOME
+except ImportError:
+    raise ImportError(
+        "The cpp layer extensions require PyTorch 0.4 or higher")
+
+
+def _load_C_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    this_dir = os.path.join(this_dir, "csrc")
+
+    main_file = glob.glob(os.path.join(this_dir, "*.cpp"))
+    sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp"))
+    sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu"))
+
+    sources = main_file + sources_cpu
+
+    extra_cflags = []
+    extra_cuda_cflags = []
+    if torch.cuda.is_available() and CUDA_HOME is not None:
+        sources.extend(sources_cuda)
+        extra_cflags = ["-O3", "-DWITH_CUDA"]
+        extra_cuda_cflags = ["--expt-extended-lambda"]
+    sources = [os.path.join(this_dir, s) for s in sources]
+    extra_include_paths = [this_dir]
+    return load(
+        name="ext_lib",
+        sources=sources,
+        extra_cflags=extra_cflags,
+        extra_include_paths=extra_include_paths,
+        extra_cuda_cflags=extra_cuda_cflags,
+    )
+
+
+_backend = _load_C_extensions()
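The statistics bookkeeping described in the README above is easy to verify without any GPU or compiled extension: per-device `sum(x)` and `sum(x^2)` are sufficient to recover the global batch mean and (biased) variance. The following standalone check uses plain PyTorch; all names in it are illustrative, not part of the extension's API:

```python
# Numerical check of the gather-and-combine rule used by SyncBN.
import torch

x = torch.randn(8, 3, 16, 16)                 # full batch, C = 3 channels
chunks = x.chunk(4, dim=0)                    # pretend each chunk lives on one GPU

N = x.numel() // x.size(1)                    # elements per channel (matches _count_samples)
xsum = sum(c.sum(dim=(0, 2, 3)) for c in chunks)          # gathered sum(x)
xsqsum = sum((c * c).sum(dim=(0, 2, 3)) for c in chunks)  # gathered sum(x^2)

mean = xsum / N
var = xsqsum / N - mean * mean                # biased variance, as in the forward pass

ref_mean = x.mean(dim=(0, 2, 3))
ref_var = x.var(dim=(0, 2, 3), unbiased=False)
assert torch.allclose(mean, ref_mean, atol=1e-5)
assert torch.allclose(var, ref_var, atol=1e-5)
```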
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h
new file mode 100644
index 0000000000000000000000000000000000000000..52567a478633aa043ad86624253763e594121bd1
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h
@@ -0,0 +1,70 @@
+/*****************************************************************************
+
+SyncBN
+
+*****************************************************************************/
+#pragma once
+
+#ifdef WITH_CUDA
+#include "cuda/ext_lib.h"
+#endif
+
+/// SyncBN
+
+std::vector<at::Tensor> syncbn_sum_sqsum(const at::Tensor& x) {
+  if (x.is_cuda()) {
+#ifdef WITH_CUDA
+    return syncbn_sum_sqsum_cuda(x);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("CPU implementation not supported");
+  }
+}
+
+at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight,
+                          const at::Tensor& bias, const at::Tensor& mean,
+                          const at::Tensor& var, bool affine, float eps) {
+  if (x.is_cuda()) {
+#ifdef WITH_CUDA
+    return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("CPU implementation not supported");
+  }
+}
+
+std::vector<at::Tensor> syncbn_backward_xhat(const at::Tensor& dz,
+                                             const at::Tensor& x,
+                                             const at::Tensor& mean,
+                                             const at::Tensor& var, float eps) {
+  if (dz.is_cuda()) {
+#ifdef WITH_CUDA
+    return syncbn_backward_xhat_cuda(dz, x, mean, var, eps);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("CPU implementation not supported");
+  }
+}
+
+std::vector<at::Tensor> syncbn_backward(
+    const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
+    const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
+    const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
+    float eps) {
+  if (dz.is_cuda()) {
+#ifdef WITH_CUDA
+    return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz,
+                                sum_dz_xhat, affine, eps);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("CPU implementation not supported");
+  }
+}
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9458eba4f4715673ba480fae2c318f4745e8fe78
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu
@@ -0,0 +1,280 @@
+/*****************************************************************************
+
+CUDA SyncBN code
+
+code referenced from : https://github.com/mapillary/inplace_abn
+
+*****************************************************************************/
+#include <ATen/ATen.h>
+#include <thrust/device_ptr.h>
+#include <thrust/transform.h>
+#include <vector>
+#include "cuda/common.h"
+
+// Utilities
+void get_dims(at::Tensor x, int64_t &num, int64_t &chn, int64_t &sp) {
+  num = x.size(0);
+  chn = x.size(1);
+  sp = 1;
+  for (int64_t i = 2; i < x.ndimension(); ++i) sp *= x.size(i);
+}
+
+/// SyncBN
+
+template <typename T>
+struct SqSumOp {
+  __device__ SqSumOp(const T *t, int c, int s) : tensor(t), chn(c), sp(s) {}
+  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
+    T x = tensor[(batch * chn + plane) * sp + n];
+    return Pair<T>(x, x * x);  // x, x^2
+  }
+  const T *tensor;
+  const int chn;
+  const int sp;
+};
+
+template <typename T>
+__global__ void syncbn_sum_sqsum_kernel(const T *x, T *sum, T *sqsum,
+                                        int num, int chn, int sp) {
+  int plane = blockIdx.x;
+  Pair<T> res =
+      reduce<Pair<T>, SqSumOp<T>>(SqSumOp<T>(x, chn, sp), plane, num, chn, sp);
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    sum[plane] = res.v1;
+    sqsum[plane] = res.v2;
+  }
+}
+
+std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor &x) {
+  CHECK_INPUT(x);
+
+  // Extract dimensions
+  int64_t num, chn, sp;
+  get_dims(x, num, chn, sp);
+
+  // Prepare output tensors
+  auto sum = at::empty({chn}, x.options());
+  auto sqsum = at::empty({chn}, x.options());
+
+  // Run kernel
+  dim3 blocks(chn);
+  dim3 threads(getNumThreads(sp));
+  AT_DISPATCH_FLOATING_TYPES(
+      x.type(), "syncbn_sum_sqsum_cuda", ([&] {
+        syncbn_sum_sqsum_kernel<scalar_t><<<blocks, threads>>>(
+            x.data<scalar_t>(), sum.data<scalar_t>(),
+            sqsum.data<scalar_t>(), num, chn, sp);
+      }));
+  return {sum, sqsum};
+}
+
+template <typename T>
+__global__ void syncbn_forward_kernel(T *z, const T *x, const T *weight,
+                                      const T *bias, const T *mean,
+                                      const T *var, bool affine, float eps,
+                                      int num, int chn, int sp) {
+  int plane = blockIdx.x;
+  T _mean = mean[plane];
+  T _var = var[plane];
+  T _weight = affine ? weight[plane] : T(1);
+  T _bias = affine ? bias[plane] : T(0);
+  float _invstd = T(0);
+  if (_var || eps) {
+    _invstd = rsqrt(_var + eps);
+  }
+  for (int batch = 0; batch < num; ++batch) {
+    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+      T _x = x[(batch * chn + plane) * sp + n];
+      T _xhat = (_x - _mean) * _invstd;
+      T _z = _xhat * _weight + _bias;
+      z[(batch * chn + plane) * sp + n] = _z;
+    }
+  }
+}
+
+at::Tensor syncbn_forward_cuda(const at::Tensor &x, const at::Tensor &weight,
+                               const at::Tensor &bias, const at::Tensor &mean,
+                               const at::Tensor &var, bool affine, float eps) {
+  CHECK_INPUT(x);
+  CHECK_INPUT(weight);
+  CHECK_INPUT(bias);
+  CHECK_INPUT(mean);
+  CHECK_INPUT(var);
+
+  // Extract dimensions
+  int64_t num, chn, sp;
+  get_dims(x, num, chn, sp);
+
+  auto z = at::zeros_like(x);
+
+  // Run kernel
+  dim3 blocks(chn);
+  dim3 threads(getNumThreads(sp));
+  AT_DISPATCH_FLOATING_TYPES(
+      x.type(), "syncbn_forward_cuda", ([&] {
+        syncbn_forward_kernel<scalar_t><<<blocks, threads>>>(
+            z.data<scalar_t>(), x.data<scalar_t>(),
+            weight.data<scalar_t>(), bias.data<scalar_t>(),
+            mean.data<scalar_t>(), var.data<scalar_t>(),
+            affine, eps, num, chn, sp);
+      }));
+  return z;
+}
+
+template <typename T>
+struct XHatOp {
+  __device__ XHatOp(T _weight, T _bias, const T *_dz, const T *_x, int c, int s)
+      : weight(_weight), bias(_bias), x(_x), dz(_dz), chn(c), sp(s) {}
+  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
+    // xhat = (x - bias) * weight
+    T _xhat = (x[(batch * chn + plane) * sp + n] - bias) * weight;
+    // dxhat * x_hat
+    T _dz = dz[(batch * chn + plane) * sp + n];
+    return Pair<T>(_dz, _dz * _xhat);
+  }
+  const T weight;
+  const T bias;
+  const T *dz;
+  const T *x;
+  const int chn;
+  const int sp;
+};
+
+template <typename T>
+__global__ void syncbn_backward_xhat_kernel(const T *dz, const T *x,
+                                            const T *mean, const T *var,
+                                            T *sum_dz, T *sum_dz_xhat,
+                                            float eps, int num, int chn,
+                                            int sp) {
+  int plane = blockIdx.x;
+  T _mean = mean[plane];
+  T _var = var[plane];
+  T _invstd = T(0);
+  if (_var || eps) {
+    _invstd = rsqrt(_var + eps);
+  }
+  Pair<T> res = reduce<Pair<T>, XHatOp<T>>(
+      XHatOp<T>(_invstd, _mean, dz, x, chn, sp), plane, num, chn, sp);
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    // \sum(\frac{dJ}{dy_i})
+    sum_dz[plane] = res.v1;
+    // \sum(\frac{dJ}{dy_i}*\hat{x_i})
+    sum_dz_xhat[plane] = res.v2;
+  }
+}
+
+std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor &dz,
+                                                  const at::Tensor &x,
+                                                  const at::Tensor &mean,
+                                                  const at::Tensor &var,
+                                                  float eps) {
+  CHECK_INPUT(dz);
+  CHECK_INPUT(x);
+  CHECK_INPUT(mean);
+  CHECK_INPUT(var);
+  // Extract dimensions
+  int64_t num, chn, sp;
+  get_dims(x, num, chn, sp);
+  // Prepare output tensors
+  auto sum_dz = at::empty({chn}, x.options());
+  auto sum_dz_xhat = at::empty({chn}, x.options());
+  // Run kernel
+  dim3 blocks(chn);
+  dim3 threads(getNumThreads(sp));
+  AT_DISPATCH_FLOATING_TYPES(
+      x.type(), "syncbn_backward_xhat_cuda", ([&] {
+        syncbn_backward_xhat_kernel<scalar_t><<<blocks, threads>>>(
+            dz.data<scalar_t>(), x.data<scalar_t>(), mean.data<scalar_t>(),
+            var.data<scalar_t>(), sum_dz.data<scalar_t>(),
+            sum_dz_xhat.data<scalar_t>(), eps, num, chn, sp);
+      }));
+  return {sum_dz, sum_dz_xhat};
+}
+
+template <typename T>
+__global__ void syncbn_backward_kernel(const T *dz, const T *x, const T *weight,
+                                       const T *bias, const T *mean,
+                                       const T *var, const T *sum_dz,
+                                       const T *sum_dz_xhat, T *dx, T *dweight,
+                                       T *dbias, bool affine, float eps,
+                                       int num, int chn, int sp) {
+  int plane = blockIdx.x;
+  T _mean = mean[plane];
+  T _var = var[plane];
+  T _weight = affine ? weight[plane] : T(1);
+  T _sum_dz = sum_dz[plane];
+  T _sum_dz_xhat = sum_dz_xhat[plane];
+  T _invstd = T(0);
+  if (_var || eps) {
+    _invstd = rsqrt(_var + eps);
+  }
+  /*
+    \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} (
+      N\frac{dJ}{d\hat{x_i}} -
+      \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) -
+      \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j})
+    )
+    Note : N is omitted here since it will be accumulated and
+    _sum_dz and _sum_dz_xhat expected to be already normalized
+    before the call.
+  */
+  if (dx) {
+    T _mul = _weight * _invstd;
+    for (int batch = 0; batch < num; ++batch) {
+      for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+        T _dz = dz[(batch * chn + plane) * sp + n];
+        T _xhat = (x[(batch * chn + plane) * sp + n] - _mean) * _invstd;
+        T _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul;
+        dx[(batch * chn + plane) * sp + n] = _dx;
+      }
+    }
+  }
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (affine) {
+      T _norm = num * sp;
+      dweight[plane] += _sum_dz_xhat * _norm;
+      dbias[plane] += _sum_dz * _norm;
+    }
+  }
+}
+
+std::vector<at::Tensor> syncbn_backward_cuda(
+    const at::Tensor &dz, const at::Tensor &x, const at::Tensor &weight,
+    const at::Tensor &bias, const at::Tensor &mean, const at::Tensor &var,
+    const at::Tensor &sum_dz, const at::Tensor &sum_dz_xhat, bool affine,
+    float eps) {
+  CHECK_INPUT(dz);
+  CHECK_INPUT(x);
+  CHECK_INPUT(weight);
+  CHECK_INPUT(bias);
+  CHECK_INPUT(mean);
+  CHECK_INPUT(var);
+  CHECK_INPUT(sum_dz);
+  CHECK_INPUT(sum_dz_xhat);
+
+  // Extract dimensions
+  int64_t num, chn, sp;
+  get_dims(x, num, chn, sp);
+
+  // Prepare output tensors
+  auto dx = at::zeros_like(dz);
+  auto dweight = at::zeros_like(weight);
+  auto dbias = at::zeros_like(bias);
+
+  // Run kernel
+  dim3 blocks(chn);
+  dim3 threads(getNumThreads(sp));
+  AT_DISPATCH_FLOATING_TYPES(
+      x.type(), "syncbn_backward_cuda", ([&] {
+        syncbn_backward_kernel<scalar_t><<<blocks, threads>>>(
+            dz.data<scalar_t>(), x.data<scalar_t>(), weight.data<scalar_t>(),
+            bias.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(),
+            sum_dz.data<scalar_t>(), sum_dz_xhat.data<scalar_t>(),
+            dx.data<scalar_t>(), dweight.data<scalar_t>(),
+            dbias.data<scalar_t>(), affine, eps, num, chn, sp);
+      }));
+  return {dx, dweight, dbias};
+}
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6cb2debeea3b8caa0f7c640601a94dce4e629cb
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h
@@ -0,0 +1,124 @@
+/*****************************************************************************
+
+CUDA utility funcs
+
+code referenced from : https://github.com/mapillary/inplace_abn
+
+*****************************************************************************/
+#pragma once
+
+#include <ATen/ATen.h>
+
+// Checks
+#ifndef AT_CHECK
+  #define AT_CHECK AT_ASSERT
+#endif
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+/*
+ * General settings
+ */
+const int WARP_SIZE = 32;
+const int MAX_BLOCK_SIZE = 512;
+
+template <typename T>
+struct Pair {
+  T v1, v2;
+  __device__ Pair() {}
+  __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
+  __device__ Pair(T v) : v1(v), v2(v) {}
+  __device__ Pair(int v) : v1(v), v2(v) {}
+  __device__ Pair<T> &operator+=(const Pair<T> &a) {
+    v1 += a.v1;
+    v2 += a.v2;
+    return *this;
+  }
+};
+
+/*
+ * Utility functions
+ */
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask,
+                                           int width = warpSize,
+                                           unsigned int mask = 0xffffffff) {
+#if CUDART_VERSION >= 9000
+  return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+  return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
+
+static int getNumThreads(int nElem) {
+  int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
+  for (int i = 0; i != 5; ++i) {
+    if (nElem <= threadSizes[i]) {
+      return threadSizes[i];
+    }
+  }
+  return MAX_BLOCK_SIZE;
+}
+
+template <typename T>
+static __device__ __forceinline__ T warpSum(T val) {
+#if __CUDA_ARCH__ >= 300
+  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
+    val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
+  }
+#else
+  __shared__ T values[MAX_BLOCK_SIZE];
+  values[threadIdx.x] = val;
+  __threadfence_block();
+  const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
+  for (int i = 1; i < WARP_SIZE; i++) {
+    val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
+  }
+#endif
+  return val;
+}
+
+template <typename T>
+static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
+  value.v1 = warpSum(value.v1);
+  value.v2 = warpSum(value.v2);
+  return value;
+}
+
+template <typename T, typename Op>
+__device__ T reduce(Op op, int plane, int N, int C, int S) {
+  T sum = (T)0;
+  for (int batch = 0; batch < N; ++batch) {
+    for (int x = threadIdx.x; x < S; x += blockDim.x) {
+      sum += op(batch, plane, x);
+    }
+  }
+
+  // sum over NumThreads within a warp
+  sum = warpSum(sum);
+
+  // 'transpose', and reduce within warp again
+  __shared__ T shared[32];
+  __syncthreads();
+  if (threadIdx.x % WARP_SIZE == 0) {
+    shared[threadIdx.x / WARP_SIZE] = sum;
+  }
+  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
+    // zero out the other entries in shared
+    shared[threadIdx.x] = (T)0;
+  }
+  __syncthreads();
+  if (threadIdx.x / WARP_SIZE == 0) {
+    sum = warpSum(shared[threadIdx.x]);
+    if (threadIdx.x == 0) {
+      shared[0] = sum;
+    }
+  }
+  __syncthreads();
+
+  // Everyone picks it up, should be broadcast into the whole gradInput
+  return shared[0];
+}
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h
new file mode 100644
index 0000000000000000000000000000000000000000..1d707615ffcf5ad7dcabc60de8c9a0cfe035bf14
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h
@@ -0,0 +1,24 @@
+/*****************************************************************************
+
+CUDA SyncBN code
+
+*****************************************************************************/
+#pragma once
+#include <torch/extension.h>
+#include <vector>
+
+/// Sync-BN
+std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor& x);
+at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight,
+                               const at::Tensor& bias, const at::Tensor& mean,
+                               const at::Tensor& var, bool affine, float eps);
+std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor& dz,
+                                                  const at::Tensor& x,
+                                                  const at::Tensor& mean,
+                                                  const at::Tensor& var,
+                                                  float eps);
+std::vector<at::Tensor> syncbn_backward_cuda(
+    const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
+    const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
+    const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
+    float eps);
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9c2ecf142dd70de8a3bdaf9b04470c4cacee3086
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp
@@ -0,0 +1,10 @@
+#include "bn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation");
+  m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation");
+  m.def("syncbn_backward_xhat", &syncbn_backward_xhat,
+        "First part of SyncBN backward computation");
+  m.def("syncbn_backward", &syncbn_backward,
+        "Second part of SyncBN backward computation");
+}
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py b/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py
new file mode 100644
index 0000000000000000000000000000000000000000..867a432d14f4f28c25075caa85b22726424293ae
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py
@@ -0,0 +1,137 @@
+"""
+/*****************************************************************************/
+
+BatchNorm2dSync with multi-gpu
+
+code referenced from : https://github.com/mapillary/inplace_abn
+
+/*****************************************************************************/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch.cuda.comm as comm
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from ._csrc import _backend
+
+
+def _count_samples(x):
+    count = 1
+    for i, s in enumerate(x.size()):
+        if i != 1:
+            count *= s
+    return count
+
+
+class BatchNorm2dSyncFunc(Function):
+
+    @staticmethod
+    def forward(ctx, x, weight, bias, running_mean, running_var,
+                extra, compute_stats=True, momentum=0.1, eps=1e-05):
+        def _parse_extra(ctx, extra):
+            ctx.is_master = extra["is_master"]
+            if ctx.is_master:
+                ctx.master_queue = extra["master_queue"]
+                ctx.worker_queues = extra["worker_queues"]
+                ctx.worker_ids = extra["worker_ids"]
+            else:
+                ctx.master_queue = extra["master_queue"]
+                ctx.worker_queue = extra["worker_queue"]
+        # Save context
+        if extra is not None:
+            _parse_extra(ctx, extra)
+        ctx.compute_stats = compute_stats
+        ctx.momentum = momentum
+        ctx.eps = eps
+        ctx.affine = weight is not None and bias is not None
+        if ctx.compute_stats:
+            N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
+            assert N > 1
+            # 1. compute sum(x) and sum(x^2)
+            xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
+            if ctx.is_master:
+                xsums, xsqsums = [xsum], [xsqsum]
+                # master : gather all sum(x) and sum(x^2) from slaves
+                for _ in range(ctx.master_queue.maxsize):
+                    xsum_w, xsqsum_w = ctx.master_queue.get()
+                    ctx.master_queue.task_done()
+                    xsums.append(xsum_w)
+                    xsqsums.append(xsqsum_w)
+                xsum = comm.reduce_add(xsums)
+                xsqsum = comm.reduce_add(xsqsums)
+                mean = xsum / N
+                sumvar = xsqsum - xsum * mean
+                var = sumvar / N
+                uvar = sumvar / (N - 1)
+                # master : broadcast global mean, variance to all slaves
+                tensors = comm.broadcast_coalesced(
+                    (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
+                for ts, queue in zip(tensors[1:], ctx.worker_queues):
+                    queue.put(ts)
+            else:
+                # slave : send sum(x) and sum(x^2) to master
+                ctx.master_queue.put((xsum, xsqsum))
+                # slave : get global mean and variance
+                mean, uvar, var = ctx.worker_queue.get()
+                ctx.worker_queue.task_done()
+
+            # Update running stats
+            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
+            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
+            ctx.N = N
+            ctx.save_for_backward(x, weight, bias, mean, var)
+        else:
+            mean, var = running_mean, running_var
+
+        # do batch norm forward
+        z = _backend.syncbn_forward(x, weight, bias, mean, var,
+                                    ctx.affine, ctx.eps)
+        return z
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, dz):
+        x, weight, bias, mean, var = ctx.saved_tensors
+        dz = dz.contiguous()
+
+        # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
+        sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat(
+            dz, x, mean, var, ctx.eps)
+        if ctx.is_master:
+            sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
+            # master : gather from slaves
+            for _ in range(ctx.master_queue.maxsize):
+                sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
+                ctx.master_queue.task_done()
+                sum_dzs.append(sum_dz_w)
+                sum_dz_xhats.append(sum_dz_xhat_w)
+            # master : compute global stats
+            sum_dz = comm.reduce_add(sum_dzs)
+            sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
+            sum_dz /= ctx.N
+            sum_dz_xhat /= ctx.N
+            # master : broadcast global stats
+            tensors = comm.broadcast_coalesced(
+                (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
+            for ts, queue in zip(tensors[1:], ctx.worker_queues):
+                queue.put(ts)
+        else:
+            # slave : send to master
+            ctx.master_queue.put((sum_dz, sum_dz_xhat))
+            # slave : get global stats
+            sum_dz, sum_dz_xhat = ctx.worker_queue.get()
+            ctx.worker_queue.task_done()
+
+        # do batch norm backward
+        dx, dweight, dbias = _backend.syncbn_backward(
+            dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat,
+            ctx.affine, ctx.eps)
+
+        return dx, dweight, dbias, \
+            None, None, None, None, None, None
+
+batchnorm2d_sync = BatchNorm2dSyncFunc.apply
+
+__all__ = ["batchnorm2d_sync"]
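The master/worker queue handshake in `BatchNorm2dSyncFunc` is easy to misread, so here is a stripped-down sketch of the same gather, reduce, broadcast pattern with plain Python threads and `queue.Queue` standing in for GPU workers. This is illustrative only; the real code moves CUDA tensors with `torch.cuda.comm` and runs inside autograd:

```python
# Sketch of the gather -> reduce -> broadcast handshake used above.
import threading
from queue import Queue

n_workers = 3
master_queue = Queue(n_workers)                       # workers -> master
worker_queues = [Queue(1) for _ in range(n_workers)]  # master -> each worker

def worker(rank, local_stat):
    master_queue.put(local_stat)              # send local statistic to master
    global_stat = worker_queues[rank].get()   # block until the reduced result arrives
    print(f"worker {rank}: global stat = {global_stat}")

threads = [threading.Thread(target=worker, args=(r, float(r + 1)))
           for r in range(n_workers)]
for t in threads:
    t.start()

# master: gather one item per worker, reduce, then broadcast
local_stats = [10.0]                          # master's own local statistic
for _ in range(n_workers):
    local_stats.append(master_queue.get())
global_stat = sum(local_stats) / len(local_stats)
for q in worker_queues:
    q.put(global_stat)

for t in threads:
    t.join()
```

Note that, as in the module above, the master's queue capacity equals the number of workers, so the gather loop consumes exactly one message per worker and the protocol cannot deadlock as long as every device participates each step.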
diff --git a/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py b/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c5aca9879273811b681baddc5755e20e838a361
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py
@@ -0,0 +1 @@
+from .syncbn import *
diff --git a/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py b/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b118c9d4aac3ee86821797bc9f794cd9aa38b1b2
--- /dev/null
+++ b/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py
@@ -0,0 +1,148 @@
+"""
+/*****************************************************************************/
+
+BatchNorm2dSync with multi-gpu
+
+/*****************************************************************************/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+try:
+    # python 3
+    from queue import Queue
+except ImportError:
+    # python 2
+    from Queue import Queue
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn.parameter import Parameter
+from ..functional import batchnorm2d_sync
+
+
+class _BatchNorm(nn.Module):
+    """
+    Customized BatchNorm based on nn.BatchNorm
+    >> adds a `freezed` attribute to enable BN freezing.
+    """
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 track_running_stats=True):
+        super(_BatchNorm, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        self.freezed = False
+        if self.affine:
+            self.weight = Parameter(torch.Tensor(num_features))
+            self.bias = Parameter(torch.Tensor(num_features))
+        else:
+            self.register_parameter('weight', None)
+            self.register_parameter('bias', None)
+        if self.track_running_stats:
+            self.register_buffer('running_mean', torch.zeros(num_features))
+            self.register_buffer('running_var', torch.ones(num_features))
+        else:
+            self.register_parameter('running_mean', None)
+            self.register_parameter('running_var', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.track_running_stats:
+            self.running_mean.zero_()
+            self.running_var.fill_(1)
+        if self.affine:
+            self.weight.data.uniform_()
+            self.bias.data.zero_()
+
+    def _check_input_dim(self, input):
+        return NotImplemented
+
+    def forward(self, input):
+        self._check_input_dim(input)
+
+        compute_stats = not self.freezed and \
+            self.training and self.track_running_stats
+
+        ret = F.batch_norm(input, self.running_mean, self.running_var,
+                           self.weight, self.bias, compute_stats,
+                           self.momentum, self.eps)
+        return ret
+
+    def extra_repr(self):
+        return '{num_features}, eps={eps}, momentum={momentum}, '\
+            'affine={affine}, ' \
+            'track_running_stats={track_running_stats}'.format(
+                **self.__dict__)
+
+
+class BatchNorm2dNoSync(_BatchNorm):
+    """
+    Equivalent to nn.BatchNorm2d
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4:
+            raise ValueError('expected 4D input (got {}D input)'
+                             .format(input.dim()))
+
+
+class BatchNorm2dSync(BatchNorm2dNoSync):
+    """
+    BatchNorm2d with automatic multi-GPU Sync
+    """
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 track_running_stats=True):
+        super(BatchNorm2dSync, self).__init__(
+            num_features, eps=eps, momentum=momentum, affine=affine,
+            track_running_stats=track_running_stats)
+        self.sync_enabled = True
+        self.devices = list(range(torch.cuda.device_count()))
+        if len(self.devices) > 1:
+            # Initialize queues
+            self.worker_ids = self.devices[1:]
+            self.master_queue = Queue(len(self.worker_ids))
+            self.worker_queues = [Queue(1) for _ in self.worker_ids]
+
+    def forward(self, x):
+        compute_stats = not self.freezed and \
+            self.training and self.track_running_stats
+        if self.sync_enabled and compute_stats and len(self.devices) > 1:
+            if x.get_device() == self.devices[0]:
+                # Master mode
+                extra = {
+                    "is_master": True,
+                    "master_queue": self.master_queue,
+                    "worker_queues": self.worker_queues,
+                    "worker_ids": self.worker_ids
+                }
+            else:
+                # Worker mode
+                extra = {
+                    "is_master": False,
+                    "master_queue": self.master_queue,
+                    "worker_queue": self.worker_queues[
+                        self.worker_ids.index(x.get_device())]
+                }
+            return batchnorm2d_sync(x, self.weight, self.bias,
+                                    self.running_mean, self.running_var,
+                                    extra, compute_stats, self.momentum,
+                                    self.eps)
+        return super(BatchNorm2dSync, self).forward(x)
+
+    def __repr__(self):
+        """repr"""
+        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
+            'affine={affine}, ' \
+            'track_running_stats={track_running_stats},' \
+            'devices={devices})'
+        return rep.format(name=self.__class__.__name__, **self.__dict__)
+
+#BatchNorm2d = BatchNorm2dNoSync
+BatchNorm2d = BatchNorm2dSync
diff --git a/inference/interact/fbrs/utils/__init__.py b/inference/interact/fbrs/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inference/interact/fbrs/utils/cython/__init__.py b/inference/interact/fbrs/utils/cython/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb66bdbba883b9477bbc1a52d8355131d32a04cb
--- /dev/null
+++ b/inference/interact/fbrs/utils/cython/__init__.py
@@ -0,0 +1,2 @@
+# noinspection PyUnresolvedReferences
+from .dist_maps import get_dist_maps
\ No newline at end of file
diff --git a/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx b/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..779a7f02ad7c2ba25e68302c6fc6683cd4ab54f7
--- /dev/null
+++ b/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx
@@ -0,0 +1,63 @@
+import numpy as np
+cimport cython
+cimport numpy as np
+from libc.stdlib cimport malloc, free
+
+ctypedef struct qnode:
+    int row
+    int col
+    int layer
+    int orig_row
+    int orig_col
+
+@cython.infer_types(True)
+@cython.boundscheck(False)
+@cython.wraparound(False)
+@cython.nonecheck(False)
+def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points,
+                  int height, int width, float norm_delimeter):
+    cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \
+        np.full((2, height, width), 1e6, dtype=np.float32, order="C")
+
+    cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0]
+    cdef int i, j, x, y, dx, dy
+    cdef qnode v
+    cdef qnode *q = <qnode *> malloc((4 * height * width + 1) * sizeof(qnode))
+    cdef int qhead = 0, qtail = -1
+    cdef float ndist
+
+    for i in range(points.shape[0]):
+        x, y = round(points[i, 0]), round(points[i, 1])
+        if x >= 0:
+            qtail += 1
+            q[qtail].row = x
+            q[qtail].col = y
+            q[qtail].orig_row = x
+            q[qtail].orig_col = y
+            if i >= points.shape[0] / 2:
+                q[qtail].layer = 1
+            else:
+                q[qtail].layer = 0
+            dist_maps[q[qtail].layer, x, y] = 0
+
+    while qtail - qhead + 1 > 0:
+        v = q[qhead]
+        qhead += 1
+
+        for k in range(4):
+            x = v.row + dxy[2 * k]
+            y = v.col + dxy[2 * k + 1]
+
+            ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2
+            if (x >= 0 and y >= 0 and x < height and y < width and
+                    dist_maps[v.layer, x, y] > ndist):
+                qtail += 1
+                q[qtail].orig_col = v.orig_col
+                q[qtail].orig_row = v.orig_row
+                q[qtail].layer = v.layer
+                q[qtail].row = x
+                q[qtail].col = y
+                dist_maps[v.layer, x, y] = ndist
+
+    free(q)
+    return dist_maps
diff --git a/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld b/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld
new file mode 100644
index 0000000000000000000000000000000000000000..bd4451729201b5ebc6bbbd8f392389ab6b530636
--- /dev/null
+++ 
b/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld @@ -0,0 +1,7 @@ +import numpy + +def make_ext(modname, pyxfilename): + from distutils.extension import Extension + return Extension(modname, [pyxfilename], + include_dirs=[numpy.get_include()], + extra_compile_args=['-O3'], language='c++') diff --git a/inference/interact/fbrs/utils/cython/dist_maps.py b/inference/interact/fbrs/utils/cython/dist_maps.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffa1e3f25231cd7c48b66ef8ef5167235c3ea4e --- /dev/null +++ b/inference/interact/fbrs/utils/cython/dist_maps.py @@ -0,0 +1,3 @@ +import pyximport; pyximport.install(pyximport=True, language_level=3) +# noinspection PyUnresolvedReferences +from ._get_dist_maps import get_dist_maps \ No newline at end of file diff --git a/inference/interact/fbrs/utils/misc.py b/inference/interact/fbrs/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..65ce96dc5667494446110fda75e29243338e2b88 --- /dev/null +++ b/inference/interact/fbrs/utils/misc.py @@ -0,0 +1,62 @@ +from functools import partial + +import torch +import numpy as np + + +def get_dims_with_exclusion(dim, exclude=None): + dims = list(range(dim)) + if exclude is not None: + dims.remove(exclude) + + return dims + + +def get_unique_labels(mask): + return np.nonzero(np.bincount(mask.flatten() + 1))[0] - 1 + + +def get_bbox_from_mask(mask): + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + return rmin, rmax, cmin, cmax + + +def expand_bbox(bbox, expand_ratio, min_crop_size=None): + rmin, rmax, cmin, cmax = bbox + rcenter = 0.5 * (rmin + rmax) + ccenter = 0.5 * (cmin + cmax) + height = expand_ratio * (rmax - rmin + 1) + width = expand_ratio * (cmax - cmin + 1) + if min_crop_size is not None: + height = max(height, min_crop_size) + width = max(width, min_crop_size) + + rmin = int(round(rcenter - 0.5 * height)) + rmax = int(round(rcenter + 0.5 * height)) + cmin = int(round(ccenter - 0.5 * width)) + cmax = int(round(ccenter + 0.5 * width)) + + return rmin, rmax, cmin, cmax + + +def clamp_bbox(bbox, rmin, rmax, cmin, cmax): + return (max(rmin, bbox[0]), min(rmax, bbox[1]), + max(cmin, bbox[2]), min(cmax, bbox[3])) + + +def get_bbox_iou(b1, b2): + h_iou = get_segments_iou(b1[:2], b2[:2]) + w_iou = get_segments_iou(b1[2:4], b2[2:4]) + return h_iou * w_iou + + +def get_segments_iou(s1, s2): + a, b = s1 + c, d = s2 + intersection = max(0, min(b, d) - max(a, c) + 1) + union = max(1e-6, max(b, d) - min(a, c) + 1) + return intersection / union diff --git a/inference/interact/fbrs/utils/vis.py b/inference/interact/fbrs/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1a291306453c15bdfe5117302beb62e0fe7248 --- /dev/null +++ b/inference/interact/fbrs/utils/vis.py @@ -0,0 +1,129 @@ +from functools import lru_cache + +import cv2 +import numpy as np + + +def visualize_instances(imask, bg_color=255, + boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): + num_objects = imask.max() + 1 + palette = get_palette(num_objects) + if bg_color is not None: + palette[0] = bg_color + + result = palette[imask].astype(np.uint8) + if boundaries_color is not None: + boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width) + tresult = result.astype(np.float32) + tresult[boundaries_mask] = boundaries_color + tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result + result = tresult.astype(np.uint8) + + 
return result
+
+
+@lru_cache(maxsize=16)
+def get_palette(num_cls):
+    palette = np.zeros(3 * num_cls, dtype=np.int32)
+
+    for j in range(0, num_cls):
+        lab = j
+        i = 0
+
+        while lab > 0:
+            palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i))
+            palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i))
+            palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i))
+            i = i + 1
+            lab >>= 3
+
+    return palette.reshape((-1, 3))
+
+
+def visualize_mask(mask, num_cls):
+    palette = get_palette(num_cls)
+    mask[mask == -1] = 0
+
+    return palette[mask].astype(np.uint8)
+
+
+def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1):
+    proposal_map, colors, candidates = proposals_info
+
+    proposal_map = draw_probmap(proposal_map)
+    for x, y in candidates:
+        proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1)
+
+    return proposal_map
+
+
+def draw_probmap(x):
+    return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT)
+
+
+def draw_points(image, points, color, radius=3):
+    image = image.copy()
+    for p in points:
+        image = cv2.circle(image, (int(p[1]), int(p[0])), radius, color, -1)
+
+    return image
+
+
+def draw_instance_map(x, palette=None):
+    num_colors = x.max() + 1
+    if palette is None:
+        palette = get_palette(num_colors)
+
+    return palette[x].astype(np.uint8)
+
+
+def blend_mask(image, mask, alpha=0.6):
+    if mask.min() == -1:
+        mask = mask.copy() + 1
+
+    imap = draw_instance_map(mask)
+    result = (image * (1 - alpha) + alpha * imap).astype(np.uint8)
+    return result
+
+
+def get_boundaries(instances_masks, boundaries_width=1):
+    # plain `bool` instead of the removed np.bool alias
+    boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=bool)
+
+    for obj_id in np.unique(instances_masks.flatten()):
+        if obj_id == 0:
+            continue
+
+        obj_mask = instances_masks == obj_id
+        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
+        inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(bool)
+
+        obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask))
+        boundaries = np.logical_or(boundaries, obj_boundary)
+    return boundaries
+
+
+def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0),
+                               neg_color=(255, 0, 0), radius=4):
+    result = img.copy()
+
+    if mask is not None:
+        palette = get_palette(np.max(mask) + 1)
+        rgb_mask = palette[mask.astype(np.uint8)]
+
+        mask_region = (mask > 0).astype(np.uint8)
+        result = result * (1 - mask_region[:, :, np.newaxis]) + \
+            (1 - alpha) * mask_region[:, :, np.newaxis] * result + \
+            alpha * rgb_mask
+        result = result.astype(np.uint8)
+
+        # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8)
+
+    if clicks_list is not None and len(clicks_list) > 0:
+        pos_points = [click.coords for click in clicks_list if click.is_positive]
+        neg_points = [click.coords for click in clicks_list if not click.is_positive]
+
+        result = draw_points(result, pos_points, pos_color, radius=radius)
+        result = draw_points(result, neg_points, neg_color, radius=radius)
+
+    return result
+
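As an aside, `get_palette` above is the standard PASCAL VOC palette generator: the bits of each label index are spread across the R, G and B channels, three bits per round, starting from the most significant bit position. A quick check of the first few entries (illustrative, runnable against the function above):

```python
# Verify the VOC-style bit-spreading palette produced by get_palette above.
palette = get_palette(4)
print(palette[0])  # [0 0 0]     background
print(palette[1])  # [128 0 0]   label 1: bit 0 goes to R at bit position 7
print(palette[2])  # [0 128 0]   label 2: bit 1 goes to G
print(palette[3])  # [128 128 0] label 3: bits 0 and 1 go to R and G
```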
diff --git a/inference/interact/fbrs_controller.py b/inference/interact/fbrs_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe9ca496193829990b3db7b0f141aabeb61fd35
--- /dev/null
+++ b/inference/interact/fbrs_controller.py
@@ -0,0 +1,53 @@
+import torch
+from .fbrs.controller import InteractiveController
+from .fbrs.inference import utils
+
+
+class FBRSController:
+    def __init__(self, checkpoint_path, device='cuda:0', max_size=800):
+        model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True, norm_radius=260)
+
+        # Predictor params
+        zoomin_params = {
+            'skip_clicks': 1,
+            'target_size': 480,
+            'expansion_ratio': 1.4,
+        }
+
+        predictor_params = {
+            'brs_mode': 'f-BRS-B',
+            'prob_thresh': 0.5,
+            'zoom_in_params': zoomin_params,
+            'predictor_params': {
+                'net_clicks_limit': 8,
+                'max_size': max_size,
+            },
+            'brs_opt_func_params': {'min_iou_diff': 1e-3},
+            'lbfgs_params': {'maxfun': 20}
+        }
+
+        self.controller = InteractiveController(model, device, predictor_params)
+        self.anchored = False
+        self.device = device
+
+    def unanchor(self):
+        self.anchored = False
+
+    def interact(self, image, x, y, is_positive):
+        image = image.to(self.device, non_blocking=True)
+        if not self.anchored:
+            self.controller.set_image(image)
+            self.controller.reset_predictor()
+            self.anchored = True
+
+        self.controller.add_click(x, y, is_positive)
+        # return self.controller.result_mask
+        # return self.controller.probs_history[-1][1]
+        return (self.controller.probs_history[-1][1] > 0.5).float()
+
+    def undo(self):
+        self.controller.undo_click()
+        if len(self.controller.probs_history) == 0:
+            return None
+        else:
+            return (self.controller.probs_history[-1][1] > 0.5).float()
\ No newline at end of file
diff --git a/inference/interact/gui.py b/inference/interact/gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..039a382bda5b5a892723df894c4dffab356e99c4
--- /dev/null
+++ b/inference/interact/gui.py
@@ -0,0 +1,933 @@
+"""
+Based on https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN
+(which is based on https://github.com/seoungwugoh/ivs-demo)
+
+This version is much simplified.
+In this repo, we don't have
+- local control
+- fusion module
+- undo
+- timers
+
+Instead, it uses XMem as the backbone and is friendlier to both CPU and GPU
+memory.
+"""
+
+import functools
+
+import os
+import cv2
+# fix conflicts between qt5 and cv2; the variable may be absent, hence the default
+os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH", None)
+
+import numpy as np
+import torch
+
+from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox,
+    QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog,
+    QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton)
+
+from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon
+from PyQt5.QtCore import Qt, QTimer
+
+from model.network import XMem
+
+from inference.inference_core import InferenceCore
+from .s2m_controller import S2MController
+from .fbrs_controller import FBRSController
+
+from .interactive_utils import *
+from .interaction import *
+from .resource_manager import ResourceManager
+from .gui_utils import *
+
+
+class App(QWidget):
+    def __init__(self, net: XMem,
+                 resource_manager: ResourceManager,
+                 s2m_ctrl:S2MController,
+                 fbrs_ctrl:FBRSController, config):
+        super().__init__()
+
+        self.initialized = False
+        self.num_objects = config['num_objects']
+        self.s2m_controller = s2m_ctrl
+        self.fbrs_controller = fbrs_ctrl
+        self.config = config
+        self.processor = InferenceCore(net, config)
+        self.processor.set_all_labels(list(range(1, self.num_objects+1)))
+        self.res_man = resource_manager
+
+        self.num_frames = len(self.res_man)
+        self.height, self.width = self.res_man.h, self.res_man.w
+
+        # set window
+        self.setWindowTitle('XMem Demo')
+        self.setGeometry(100, 100, self.width, self.height+100)
+        self.setWindowIcon(QIcon('docs/icon.png'))
+
+        # some buttons
+        self.play_button = QPushButton('Play Video')
+        self.play_button.clicked.connect(self.on_play_video)
+        self.commit_button = 
QPushButton('Commit') + self.commit_button.clicked.connect(self.on_commit) + + self.forward_run_button = QPushButton('Forward Propagate') + self.forward_run_button.clicked.connect(self.on_forward_propagation) + self.forward_run_button.setMinimumWidth(200) + + self.backward_run_button = QPushButton('Backward Propagate') + self.backward_run_button.clicked.connect(self.on_backward_propagation) + self.backward_run_button.setMinimumWidth(200) + + self.reset_button = QPushButton('Reset Frame') + self.reset_button.clicked.connect(self.on_reset_mask) + + # LCD + self.lcd = QTextEdit() + self.lcd.setReadOnly(True) + self.lcd.setMaximumHeight(28) + self.lcd.setMaximumWidth(120) + self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1)) + + # timeline slider + self.tl_slider = QSlider(Qt.Horizontal) + self.tl_slider.valueChanged.connect(self.tl_slide) + self.tl_slider.setMinimum(0) + self.tl_slider.setMaximum(self.num_frames-1) + self.tl_slider.setValue(0) + self.tl_slider.setTickPosition(QSlider.TicksBelow) + self.tl_slider.setTickInterval(1) + + # brush size slider + self.brush_label = QLabel() + self.brush_label.setAlignment(Qt.AlignCenter) + self.brush_label.setMinimumWidth(100) + + self.brush_slider = QSlider(Qt.Horizontal) + self.brush_slider.valueChanged.connect(self.brush_slide) + self.brush_slider.setMinimum(1) + self.brush_slider.setMaximum(100) + self.brush_slider.setValue(3) + self.brush_slider.setTickPosition(QSlider.TicksBelow) + self.brush_slider.setTickInterval(2) + self.brush_slider.setMinimumWidth(300) + + # combobox + self.combo = QComboBox(self) + self.combo.addItem("davis") + self.combo.addItem("fade") + self.combo.addItem("light") + self.combo.addItem("popup") + self.combo.addItem("layered") + self.combo.currentTextChanged.connect(self.set_viz_mode) + + self.save_visualization_checkbox = QCheckBox(self) + self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle) + self.save_visualization_checkbox.setChecked(False) + self.save_visualization = False + + # Radio buttons for type of interactions + self.curr_interaction = 'Click' + self.interaction_group = QButtonGroup() + self.radio_fbrs = QRadioButton('Click') + self.radio_s2m = QRadioButton('Scribble') + self.radio_free = QRadioButton('Free') + self.interaction_group.addButton(self.radio_fbrs) + self.interaction_group.addButton(self.radio_s2m) + self.interaction_group.addButton(self.radio_free) + self.radio_fbrs.toggled.connect(self.interaction_radio_clicked) + self.radio_s2m.toggled.connect(self.interaction_radio_clicked) + self.radio_free.toggled.connect(self.interaction_radio_clicked) + self.radio_fbrs.toggle() + + # Main canvas -> QLabel + self.main_canvas = QLabel() + self.main_canvas.setSizePolicy(QSizePolicy.Expanding, + QSizePolicy.Expanding) + self.main_canvas.setAlignment(Qt.AlignCenter) + self.main_canvas.setMinimumSize(100, 100) + + self.main_canvas.mousePressEvent = self.on_mouse_press + self.main_canvas.mouseMoveEvent = self.on_mouse_motion + self.main_canvas.setMouseTracking(True) # Required for all-time tracking + self.main_canvas.mouseReleaseEvent = self.on_mouse_release + + # Minimap -> Also a QLbal + self.minimap = QLabel() + self.minimap.setSizePolicy(QSizePolicy.Expanding, + QSizePolicy.Expanding) + self.minimap.setAlignment(Qt.AlignTop) + self.minimap.setMinimumSize(100, 100) + + # Zoom-in buttons + self.zoom_p_button = QPushButton('Zoom +') + self.zoom_p_button.clicked.connect(self.on_zoom_plus) + self.zoom_m_button = QPushButton('Zoom -') + 
self.zoom_m_button.clicked.connect(self.on_zoom_minus) + + # Parameters setting + self.clear_mem_button = QPushButton('Clear memory') + self.clear_mem_button.clicked.connect(self.on_clear_memory) + + self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size') + self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size') + self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)') + self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)') + + self.update_memory_size() + self.update_gpu_usage() + + self.work_mem_min, self.work_mem_min_layout = create_parameter_box(1, 100, 'Min. working memory frames', + callback=self.on_work_min_change) + self.work_mem_max, self.work_mem_max_layout = create_parameter_box(2, 100, 'Max. working memory frames', + callback=self.on_work_max_change) + self.long_mem_max, self.long_mem_max_layout = create_parameter_box(1000, 100000, + 'Max. long-term memory size', step=1000, callback=self.update_config) + self.num_prototypes_box, self.num_prototypes_box_layout = create_parameter_box(32, 1280, + 'Number of prototypes', step=32, callback=self.update_config) + self.mem_every_box, self.mem_every_box_layout = create_parameter_box(1, 100, 'Memory frame every (r)', + callback=self.update_config) + + self.work_mem_min.setValue(self.processor.memory.min_mt_frames) + self.work_mem_max.setValue(self.processor.memory.max_mt_frames) + self.long_mem_max.setValue(self.processor.memory.max_long_elements) + self.num_prototypes_box.setValue(self.processor.memory.num_prototypes) + self.mem_every_box.setValue(self.processor.mem_every) + + # import mask/layer + self.import_mask_button = QPushButton('Import mask') + self.import_mask_button.clicked.connect(self.on_import_mask) + self.import_layer_button = QPushButton('Import layer') + self.import_layer_button.clicked.connect(self.on_import_layer) + + # Console on the GUI + self.console = QPlainTextEdit() + self.console.setReadOnly(True) + self.console.setMinimumHeight(100) + self.console.setMaximumHeight(100) + + # navigator + navi = QHBoxLayout() + navi.addWidget(self.lcd) + navi.addWidget(self.play_button) + + interact_subbox = QVBoxLayout() + interact_topbox = QHBoxLayout() + interact_botbox = QHBoxLayout() + interact_topbox.setAlignment(Qt.AlignCenter) + interact_topbox.addWidget(self.radio_s2m) + interact_topbox.addWidget(self.radio_fbrs) + interact_topbox.addWidget(self.radio_free) + interact_topbox.addWidget(self.brush_label) + interact_botbox.addWidget(self.brush_slider) + interact_subbox.addLayout(interact_topbox) + interact_subbox.addLayout(interact_botbox) + navi.addLayout(interact_subbox) + + navi.addStretch(1) + navi.addWidget(self.reset_button) + + navi.addStretch(1) + navi.addWidget(QLabel('Overlay Mode')) + navi.addWidget(self.combo) + navi.addWidget(QLabel('Save overlay during propagation')) + navi.addWidget(self.save_visualization_checkbox) + navi.addStretch(1) + navi.addWidget(self.commit_button) + navi.addWidget(self.forward_run_button) + navi.addWidget(self.backward_run_button) + + # Drawing area, main canvas and minimap + draw_area = QHBoxLayout() + draw_area.addWidget(self.main_canvas, 4) + + # Minimap area + minimap_area = QVBoxLayout() + minimap_area.setAlignment(Qt.AlignTop) + mini_label = QLabel('Minimap') + mini_label.setAlignment(Qt.AlignTop) + minimap_area.addWidget(mini_label) + + # Minimap zooming + minimap_ctrl = QHBoxLayout() + 
minimap_ctrl.setAlignment(Qt.AlignTop) + minimap_ctrl.addWidget(self.zoom_p_button) + minimap_ctrl.addWidget(self.zoom_m_button) + minimap_area.addLayout(minimap_ctrl) + minimap_area.addWidget(self.minimap) + + # Parameters + minimap_area.addLayout(self.work_mem_gauge_layout) + minimap_area.addLayout(self.long_mem_gauge_layout) + minimap_area.addLayout(self.gpu_mem_gauge_layout) + minimap_area.addLayout(self.torch_mem_gauge_layout) + minimap_area.addWidget(self.clear_mem_button) + minimap_area.addLayout(self.work_mem_min_layout) + minimap_area.addLayout(self.work_mem_max_layout) + minimap_area.addLayout(self.long_mem_max_layout) + minimap_area.addLayout(self.num_prototypes_box_layout) + minimap_area.addLayout(self.mem_every_box_layout) + + # import mask/layer + import_area = QHBoxLayout() + import_area.setAlignment(Qt.AlignTop) + import_area.addWidget(self.import_mask_button) + import_area.addWidget(self.import_layer_button) + minimap_area.addLayout(import_area) + + # console + minimap_area.addWidget(self.console) + + draw_area.addLayout(minimap_area, 1) + + layout = QVBoxLayout() + layout.addLayout(draw_area) + layout.addWidget(self.tl_slider) + layout.addLayout(navi) + self.setLayout(layout) + + # timer to play video + self.timer = QTimer() + self.timer.setSingleShot(False) + + # timer to update GPU usage + self.gpu_timer = QTimer() + self.gpu_timer.setSingleShot(False) + self.gpu_timer.timeout.connect(self.on_gpu_timer) + self.gpu_timer.setInterval(2000) + self.gpu_timer.start() + + # current frame info + self.curr_frame_dirty = False + self.current_image = np.zeros((self.height, self.width, 3), dtype=np.uint8) + self.current_image_torch = None + self.current_mask = np.zeros((self.height, self.width), dtype=np.uint8) + self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).cuda() + + # initialize visualization + self.viz_mode = 'davis' + self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8) + self.vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32) + self.brush_vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8) + self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32) + self.cursur = 0 + self.on_showing = None + + # Zoom parameters + self.zoom_pixels = 150 + + # initialize action + self.interaction = None + self.pressed = False + self.right_click = False + self.current_object = 1 + self.last_ex = self.last_ey = 0 + + self.propagating = False + + # Objects shortcuts + for i in range(1, self.num_objects+1): + QShortcut(QKeySequence(str(i)), self).activated.connect(functools.partial(self.hit_number_key, i)) + + # <- and -> shortcuts + QShortcut(QKeySequence(Qt.Key_Left), self).activated.connect(self.on_prev_frame) + QShortcut(QKeySequence(Qt.Key_Right), self).activated.connect(self.on_next_frame) + + self.interacted_prob = None + self.overlay_layer = None + self.overlay_layer_torch = None + + # the object id used for popup/layered overlay + self.vis_target_objects = [1] + # try to load the default overlay + self._try_load_layer('./docs/ECCV-logo.png') + + self.load_current_image_mask() + self.show_current_frame() + self.show() + + self.console_push_text('Initialized.') + self.initialized = True + + def resizeEvent(self, event): + self.show_current_frame() + + def console_push_text(self, text): + self.console.moveCursor(QTextCursor.End) + self.console.insertPlainText(text+'\n') + + def interaction_radio_clicked(self, event): + self.last_interaction = self.curr_interaction + if 
self.radio_s2m.isChecked(): + self.curr_interaction = 'Scribble' + self.brush_size = 3 + self.brush_slider.setDisabled(True) + elif self.radio_fbrs.isChecked(): + self.curr_interaction = 'Click' + self.brush_size = 3 + self.brush_slider.setDisabled(True) + elif self.radio_free.isChecked(): + self.brush_slider.setDisabled(False) + self.brush_slide() + self.curr_interaction = 'Free' + if self.curr_interaction == 'Scribble': + self.commit_button.setEnabled(True) + else: + self.commit_button.setEnabled(False) + + def load_current_image_mask(self, no_mask=False): + self.current_image = self.res_man.get_image(self.cursur) + self.current_image_torch = None + + if not no_mask: + loaded_mask = self.res_man.get_mask(self.cursur) + if loaded_mask is None: + self.current_mask.fill(0) + else: + self.current_mask = loaded_mask.copy() + self.current_prob = None + + def load_current_torch_image_mask(self, no_mask=False): + if self.current_image_torch is None: + self.current_image_torch, self.current_image_torch_no_norm = image_to_torch(self.current_image) + + if self.current_prob is None and not no_mask: + self.current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda() + + def compose_current_im(self): + self.viz = get_visualization(self.viz_mode, self.current_image, self.current_mask, + self.overlay_layer, self.vis_target_objects) + + def update_interact_vis(self): + # Update the interactions without re-computing the overlay + height, width, channel = self.viz.shape + bytesPerLine = 3 * width + + vis_map = self.vis_map + vis_alpha = self.vis_alpha + brush_vis_map = self.brush_vis_map + brush_vis_alpha = self.brush_vis_alpha + + self.viz_with_stroke = self.viz*(1-vis_alpha) + vis_map*vis_alpha + self.viz_with_stroke = self.viz_with_stroke*(1-brush_vis_alpha) + brush_vis_map*brush_vis_alpha + self.viz_with_stroke = self.viz_with_stroke.astype(np.uint8) + + qImg = QImage(self.viz_with_stroke.data, width, height, bytesPerLine, QImage.Format_RGB888) + self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(), + Qt.KeepAspectRatio, Qt.FastTransformation))) + + self.main_canvas_size = self.main_canvas.size() + self.image_size = qImg.size() + + def update_minimap(self): + ex, ey = self.last_ex, self.last_ey + r = self.zoom_pixels//2 + ex = int(round(max(r, min(self.width-r, ex)))) + ey = int(round(max(r, min(self.height-r, ey)))) + + patch = self.viz_with_stroke[ey-r:ey+r, ex-r:ex+r, :].astype(np.uint8) + + height, width, channel = patch.shape + bytesPerLine = 3 * width + qImg = QImage(patch.data, width, height, bytesPerLine, QImage.Format_RGB888) + self.minimap.setPixmap(QPixmap(qImg.scaled(self.minimap.size(), + Qt.KeepAspectRatio, Qt.FastTransformation))) + + def update_current_image_fast(self): + # fast path, uses gpu. 
Changes the image in-place to avoid copying + self.viz = get_visualization_torch(self.viz_mode, self.current_image_torch_no_norm, + self.current_prob, self.overlay_layer_torch, self.vis_target_objects) + if self.save_visualization: + self.res_man.save_visualization(self.cursur, self.viz) + + height, width, channel = self.viz.shape + bytesPerLine = 3 * width + + qImg = QImage(self.viz.data, width, height, bytesPerLine, QImage.Format_RGB888) + self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(), + Qt.KeepAspectRatio, Qt.FastTransformation))) + + def show_current_frame(self, fast=False): + # Re-compute overlay and show the image + if fast: + self.update_current_image_fast() + else: + self.compose_current_im() + self.update_interact_vis() + self.update_minimap() + + self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1)) + self.tl_slider.setValue(self.cursur) + + def pixel_pos_to_image_pos(self, x, y): + # Un-scale and un-pad the label coordinates into image coordinates + oh, ow = self.image_size.height(), self.image_size.width() + nh, nw = self.main_canvas_size.height(), self.main_canvas_size.width() + + h_ratio = nh/oh + w_ratio = nw/ow + dominate_ratio = min(h_ratio, w_ratio) + + # Solve scale + x /= dominate_ratio + y /= dominate_ratio + + # Solve padding + fh, fw = nh/dominate_ratio, nw/dominate_ratio + x -= (fw-ow)/2 + y -= (fh-oh)/2 + + return x, y + + def is_pos_out_of_bound(self, x, y): + x, y = self.pixel_pos_to_image_pos(x, y) + + out_of_bound = ( + (x < 0) or + (y < 0) or + (x > self.width-1) or + (y > self.height-1) + ) + + return out_of_bound + + def get_scaled_pos(self, x, y): + x, y = self.pixel_pos_to_image_pos(x, y) + + x = max(0, min(self.width-1, x)) + y = max(0, min(self.height-1, y)) + + return x, y + + def clear_visualization(self): + self.vis_map.fill(0) + self.vis_alpha.fill(0) + + def reset_this_interaction(self): + self.complete_interaction() + self.clear_visualization() + self.interaction = None + if self.fbrs_controller is not None: + self.fbrs_controller.unanchor() + + def set_viz_mode(self): + self.viz_mode = self.combo.currentText() + self.show_current_frame() + + def save_current_mask(self): + # save mask to hard disk + self.res_man.save_mask(self.cursur, self.current_mask) + + def tl_slide(self): + # if we are propagating, the on_run function will take care of everything + # don't do duplicate work here + if not self.propagating: + if self.curr_frame_dirty: + self.save_current_mask() + self.curr_frame_dirty = False + + self.reset_this_interaction() + self.cursur = self.tl_slider.value() + self.load_current_image_mask() + self.show_current_frame() + + def brush_slide(self): + self.brush_size = self.brush_slider.value() + self.brush_label.setText('Brush size: %d' % self.brush_size) + try: + if type(self.interaction) == FreeInteraction: + self.interaction.set_size(self.brush_size) + except AttributeError: + # Initialization, forget about it + pass + + def on_forward_propagation(self): + if self.propagating: + # acts as a pause button + self.propagating = False + else: + self.propagate_fn = self.on_next_frame + self.backward_run_button.setEnabled(False) + self.forward_run_button.setText('Pause Propagation') + self.on_propagation() + + def on_backward_propagation(self): + if self.propagating: + # acts as a pause button + self.propagating = False + else: + self.propagate_fn = self.on_prev_frame + self.forward_run_button.setEnabled(False) + self.backward_run_button.setText('Pause Propagation') + self.on_propagation() + + def 
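pixel_pos_to_image_pos undoes the letterboxing that Qt.KeepAspectRatio introduces: the canvas scales the image by the smaller of the two axis ratios and centers it, so the inverse divides by that ratio and then subtracts half of the leftover padding, expressed in image coordinates. The same arithmetic as a standalone, Qt-free sketch:

def widget_to_image(x, y, img_w, img_h, canvas_w, canvas_h):
    # KeepAspectRatio scales by the smaller axis ratio
    ratio = min(canvas_w / img_w, canvas_h / img_h)
    # undo the scale
    x, y = x / ratio, y / ratio
    # undo the centering padding
    x -= (canvas_w / ratio - img_w) / 2
    y -= (canvas_h / ratio - img_h) / 2
    return x, y

# a 640x480 image shown on an 800x480 canvas is padded by 80 px on each side
print(widget_to_image(80, 0, 640, 480, 800, 480))  # (0.0, 0.0): the image's top-left corner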
on_pause(self): + self.propagating = False + self.forward_run_button.setEnabled(True) + self.backward_run_button.setEnabled(True) + self.clear_mem_button.setEnabled(True) + self.forward_run_button.setText('Forward Propagate') + self.backward_run_button.setText('Backward Propagate') + self.console_push_text('Propagation stopped.') + + def on_propagation(self): + # start to propagate + self.load_current_torch_image_mask() + self.show_current_frame(fast=True) + + self.console_push_text('Propagation started.') + self.current_prob = self.processor.step(self.current_image_torch, self.current_prob[1:]) + self.current_mask = torch_prob_to_numpy_mask(self.current_prob) + # clear + self.interacted_prob = None + self.reset_this_interaction() + + self.propagating = True + self.clear_mem_button.setEnabled(False) + # propagate till the end + while self.propagating: + self.propagate_fn() + + self.load_current_image_mask(no_mask=True) + self.load_current_torch_image_mask(no_mask=True) + + self.current_prob = self.processor.step(self.current_image_torch) + self.current_mask = torch_prob_to_numpy_mask(self.current_prob) + + self.save_current_mask() + self.show_current_frame(fast=True) + + self.update_memory_size() + QApplication.processEvents() + + if self.cursur == 0 or self.cursur == self.num_frames-1: + break + + self.propagating = False + self.curr_frame_dirty = False + self.on_pause() + self.tl_slide() + QApplication.processEvents() + + def pause_propagation(self): + self.propagating = False + + def on_commit(self): + self.complete_interaction() + self.update_interacted_mask() + + def on_prev_frame(self): + # self.tl_slide will trigger on setValue + self.cursur = max(0, self.cursur-1) + self.tl_slider.setValue(self.cursur) + + def on_next_frame(self): + # self.tl_slide will trigger on setValue + self.cursur = min(self.cursur+1, self.num_frames-1) + self.tl_slider.setValue(self.cursur) + + def on_play_video_timer(self): + self.cursur += 1 + if self.cursur > self.num_frames-1: + self.cursur = 0 + self.tl_slider.setValue(self.cursur) + + def on_play_video(self): + if self.timer.isActive(): + self.timer.stop() + self.play_button.setText('Play Video') + else: + self.timer.start(1000 / 30) + self.play_button.setText('Stop Video') + + def on_reset_mask(self): + self.current_mask.fill(0) + if self.current_prob is not None: + self.current_prob.fill_(0) + self.curr_frame_dirty = True + self.save_current_mask() + self.reset_this_interaction() + self.show_current_frame() + + def on_zoom_plus(self): + self.zoom_pixels -= 25 + self.zoom_pixels = max(50, self.zoom_pixels) + self.update_minimap() + + def on_zoom_minus(self): + self.zoom_pixels += 25 + self.zoom_pixels = min(self.zoom_pixels, 300) + self.update_minimap() + + def set_navi_enable(self, boolean): + self.zoom_p_button.setEnabled(boolean) + self.zoom_m_button.setEnabled(boolean) + self.run_button.setEnabled(boolean) + self.tl_slider.setEnabled(boolean) + self.play_button.setEnabled(boolean) + self.lcd.setEnabled(boolean) + + def hit_number_key(self, number): + if number == self.current_object: + return + self.current_object = number + if self.fbrs_controller is not None: + self.fbrs_controller.unanchor() + self.console_push_text(f'Current object changed to {number}.') + self.clear_brush() + self.vis_brush(self.last_ex, self.last_ey) + self.update_interact_vis() + self.show_current_frame() + + def clear_brush(self): + self.brush_vis_map.fill(0) + self.brush_vis_alpha.fill(0) + + def vis_brush(self, ex, ey): + self.brush_vis_map = 
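Propagation here runs on the GUI thread and stays responsive by pumping the event loop once per frame; pressing the run button again only flips self.propagating, and the while loop notices on its next pass. The control flow, reduced to a GUI-free sketch (class and method names are illustrative):

class Propagator:
    def __init__(self, num_frames):
        self.propagating = False
        self.frame = 0
        self.num_frames = num_frames

    def toggle(self):
        # the run button: starts propagation, or requests a pause
        if self.propagating:
            self.propagating = False
        else:
            self.propagating = True
            self.run()

    def run(self):
        while self.propagating:
            self.frame += 1  # stand-in for one propagation step
            # in the GUI, QApplication.processEvents() runs here so that a
            # second button press can flip self.propagating mid-loop
            if self.frame == self.num_frames - 1:
                break
        self.propagating = False

p = Propagator(30)
p.toggle()
print(p.frame)  # 29: ran to the last frame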
cv2.circle(self.brush_vis_map, + (int(round(ex)), int(round(ey))), self.brush_size//2+1, color_map[self.current_object], thickness=-1) + self.brush_vis_alpha = cv2.circle(self.brush_vis_alpha, + (int(round(ex)), int(round(ey))), self.brush_size//2+1, 0.5, thickness=-1) + + def on_mouse_press(self, event): + if self.is_pos_out_of_bound(event.x(), event.y()): + return + + # mid-click + if (event.button() == Qt.MidButton): + ex, ey = self.get_scaled_pos(event.x(), event.y()) + target_object = self.current_mask[int(ey),int(ex)] + if target_object in self.vis_target_objects: + self.vis_target_objects.remove(target_object) + else: + self.vis_target_objects.append(target_object) + self.console_push_text(f'Target objects for visualization changed to {self.vis_target_objects}') + self.show_current_frame() + return + + self.right_click = (event.button() == Qt.RightButton) + self.pressed = True + + h, w = self.height, self.width + + self.load_current_torch_image_mask() + image = self.current_image_torch + + last_interaction = self.interaction + new_interaction = None + if self.curr_interaction == 'Scribble': + if last_interaction is None or type(last_interaction) != ScribbleInteraction: + self.complete_interaction() + new_interaction = ScribbleInteraction(image, torch.from_numpy(self.current_mask).float().cuda(), + (h, w), self.s2m_controller, self.num_objects) + elif self.curr_interaction == 'Free': + if last_interaction is None or type(last_interaction) != FreeInteraction: + self.complete_interaction() + new_interaction = FreeInteraction(image, self.current_mask, (h, w), + self.num_objects) + new_interaction.set_size(self.brush_size) + elif self.curr_interaction == 'Click': + if (last_interaction is None or type(last_interaction) != ClickInteraction + or last_interaction.tar_obj != self.current_object): + self.complete_interaction() + self.fbrs_controller.unanchor() + new_interaction = ClickInteraction(image, self.current_prob, (h, w), + self.fbrs_controller, self.current_object) + + if new_interaction is not None: + self.interaction = new_interaction + + # Just motion it as the first step + self.on_mouse_motion(event) + + def on_mouse_motion(self, event): + ex, ey = self.get_scaled_pos(event.x(), event.y()) + self.last_ex, self.last_ey = ex, ey + self.clear_brush() + # Visualize + self.vis_brush(ex, ey) + if self.pressed: + if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free': + obj = 0 if self.right_click else self.current_object + self.vis_map, self.vis_alpha = self.interaction.push_point( + ex, ey, obj, (self.vis_map, self.vis_alpha) + ) + self.update_interact_vis() + self.update_minimap() + + def update_interacted_mask(self): + self.current_prob = self.interacted_prob + self.current_mask = torch_prob_to_numpy_mask(self.interacted_prob) + self.show_current_frame() + self.save_current_mask() + self.curr_frame_dirty = False + + def complete_interaction(self): + if self.interaction is not None: + self.clear_visualization() + self.interaction = None + + def on_mouse_release(self, event): + if not self.pressed: + # this can happen when the initial press is out-of-bound + return + + ex, ey = self.get_scaled_pos(event.x(), event.y()) + + self.console_push_text('%s interaction at frame %d.' 
% (self.curr_interaction, self.cursur)) + interaction = self.interaction + + if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free': + self.on_mouse_motion(event) + interaction.end_path() + if self.curr_interaction == 'Free': + self.clear_visualization() + elif self.curr_interaction == 'Click': + ex, ey = self.get_scaled_pos(event.x(), event.y()) + self.vis_map, self.vis_alpha = interaction.push_point(ex, ey, + self.right_click, (self.vis_map, self.vis_alpha)) + + self.interacted_prob = interaction.predict() + self.update_interacted_mask() + self.update_gpu_usage() + + self.pressed = self.right_click = False + + def wheelEvent(self, event): + ex, ey = self.get_scaled_pos(event.x(), event.y()) + if self.curr_interaction == 'Free': + self.brush_slider.setValue(self.brush_slider.value() + event.angleDelta().y()//30) + self.clear_brush() + self.vis_brush(ex, ey) + self.update_interact_vis() + self.update_minimap() + + def update_gpu_usage(self): + info = torch.cuda.mem_get_info() + global_free, global_total = info + global_free /= (2**30) + global_total /= (2**30) + global_used = global_total - global_free + + self.gpu_mem_gauge.setFormat(f'{global_used:.01f} GB / {global_total:.01f} GB') + self.gpu_mem_gauge.setValue(round(global_used/global_total*100)) + + used_by_torch = torch.cuda.max_memory_allocated() / (2**20) + self.torch_mem_gauge.setFormat(f'{used_by_torch:.0f} MB / {global_total:.01f} GB') + self.torch_mem_gauge.setValue(round(used_by_torch/global_total*100/1024)) + + def on_gpu_timer(self): + self.update_gpu_usage() + + def update_memory_size(self): + try: + max_work_elements = self.processor.memory.max_work_elements + max_long_elements = self.processor.memory.max_long_elements + + curr_work_elements = self.processor.memory.work_mem.size + curr_long_elements = self.processor.memory.long_mem.size + + self.work_mem_gauge.setFormat(f'{curr_work_elements} / {max_work_elements}') + self.work_mem_gauge.setValue(round(curr_work_elements/max_work_elements*100)) + + self.long_mem_gauge.setFormat(f'{curr_long_elements} / {max_long_elements}') + self.long_mem_gauge.setValue(round(curr_long_elements/max_long_elements*100)) + + except AttributeError: + self.work_mem_gauge.setFormat('Unknown') + self.long_mem_gauge.setFormat('Unknown') + self.work_mem_gauge.setValue(0) + self.long_mem_gauge.setValue(0) + + def on_work_min_change(self): + if self.initialized: + self.work_mem_min.setValue(min(self.work_mem_min.value(), self.work_mem_max.value()-1)) + self.update_config() + + def on_work_max_change(self): + if self.initialized: + self.work_mem_max.setValue(max(self.work_mem_max.value(), self.work_mem_min.value()+1)) + self.update_config() + + def update_config(self): + if self.initialized: + self.config['min_mid_term_frames'] = self.work_mem_min.value() + self.config['max_mid_term_frames'] = self.work_mem_max.value() + self.config['max_long_term_elements'] = self.long_mem_max.value() + self.config['num_prototypes'] = self.num_prototypes_box.value() + self.config['mem_every'] = self.mem_every_box.value() + + self.processor.update_config(self.config) + + def on_clear_memory(self): + self.processor.clear_memory() + torch.cuda.empty_cache() + self.update_gpu_usage() + self.update_memory_size() + + def _open_file(self, prompt): + options = QFileDialog.Options() + file_name, _ = QFileDialog.getOpenFileName(self, prompt, "", "Image files (*)", options=options) + return file_name + + def on_import_mask(self): + file_name = self._open_file('Mask') + if len(file_name) == 0: + return + + 
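update_gpu_usage deliberately reports two different numbers: torch.cuda.mem_get_info returns device-wide figures that include every process, while torch.cuda.max_memory_allocated covers only this process's tensors, and reports a peak rather than the current value. A standalone version of the same bookkeeping, guarded so it degrades gracefully on CPU-only machines (mem_get_info requires a reasonably recent PyTorch):

import torch

def gpu_usage_summary():
    if not torch.cuda.is_available():
        return 'no CUDA device'
    free_b, total_b = torch.cuda.mem_get_info()           # device-wide, in bytes
    used_gb = (total_b - free_b) / 2**30
    torch_mb = torch.cuda.max_memory_allocated() / 2**20  # this process, peak
    return f'{used_gb:.1f} GB / {total_b / 2**30:.1f} GB used; torch peak {torch_mb:.0f} MB'

print(gpu_usage_summary())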
mask = self.res_man.read_external_image(file_name, size=(self.height, self.width)) + + shape_condition = ( + (len(mask.shape) == 2) and + (mask.shape[-1] == self.width) and + (mask.shape[-2] == self.height) + ) + + object_condition = ( + mask.max() <= self.num_objects + ) + + if not shape_condition: + self.console_push_text(f'Expected ({self.height}, {self.width}). Got {mask.shape} instead.') + elif not object_condition: + self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.') + else: + self.console_push_text(f'Mask file {file_name} loaded.') + self.current_image_torch = self.current_prob = None + self.current_mask = mask + self.show_current_frame() + self.save_current_mask() + + def on_import_layer(self): + file_name = self._open_file('Layer') + if len(file_name) == 0: + return + + self._try_load_layer(file_name) + + def _try_load_layer(self, file_name): + try: + layer = self.res_man.read_external_image(file_name, size=(self.height, self.width)) + + if layer.shape[-1] == 3: + layer = np.concatenate([layer, np.ones_like(layer[:,:,0:1])*255], axis=-1) + + condition = ( + (len(layer.shape) == 3) and + (layer.shape[-1] == 4) and + (layer.shape[-2] == self.width) and + (layer.shape[-3] == self.height) + ) + + if not condition: + self.console_push_text(f'Expected ({self.height}, {self.width}, 4). Got {layer.shape}.') + else: + self.console_push_text(f'Layer file {file_name} loaded.') + self.overlay_layer = layer + self.overlay_layer_torch = torch.from_numpy(layer).float().cuda()/255 + self.show_current_frame() + except FileNotFoundError: + self.console_push_text(f'{file_name} not found.') + + def on_save_visualization_toggle(self): + self.save_visualization = self.save_visualization_checkbox.isChecked() diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..daf852b30a84893c836d7c3350b727aeed5d0a6b --- /dev/null +++ b/inference/interact/gui_utils.py @@ -0,0 +1,40 @@ +from PyQt5.QtCore import Qt +from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar) + + +def create_parameter_box(min_val, max_val, text, step=1, callback=None): + layout = QHBoxLayout() + + dial = QSpinBox() + dial.setMaximumHeight(28) + dial.setMaximumWidth(150) + dial.setMinimum(min_val) + dial.setMaximum(max_val) + dial.setAlignment(Qt.AlignRight) + dial.setSingleStep(step) + dial.valueChanged.connect(callback) + + label = QLabel(text) + label.setAlignment(Qt.AlignRight) + + layout.addWidget(label) + layout.addWidget(dial) + + return dial, layout + + +def create_gauge(text): + layout = QHBoxLayout() + + gauge = QProgressBar() + gauge.setMaximumHeight(28) + gauge.setMaximumWidth(200) + gauge.setAlignment(Qt.AlignCenter) + + label = QLabel(text) + label.setAlignment(Qt.AlignRight) + + layout.addWidget(label) + layout.addWidget(gauge) + + return gauge, layout diff --git a/inference/interact/interaction.py b/inference/interact/interaction.py new file mode 100644 index 0000000000000000000000000000000000000000..19f83f9d58a00cac079a7ba5c239196378603b64 --- /dev/null +++ b/inference/interact/interaction.py @@ -0,0 +1,252 @@ +""" +Contains all the types of interaction related to the GUI +Not related to automatic evaluation in the DAVIS dataset + +You can inherit the Interaction class to create new interaction types +undo is (sometimes partially) supported +""" + + +import torch +import torch.nn.functional as F +import numpy as np +import cv2 +import time +from 
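The two helpers in gui_utils.py return (widget, layout) pairs so the caller keeps a handle to the control while dropping the labelled row into a panel, which is how the gauges and spin-boxes above are wired in. A sketch of that usage, assuming a running QApplication with a display; the label strings and callback are illustrative, not taken from the codebase:

from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout
from inference.interact.gui_utils import create_parameter_box, create_gauge

app = QApplication([])

def on_mem_every_changed(value):
    print('mem_every is now', value)  # illustrative callback

mem_every_box, mem_every_layout = create_parameter_box(
    1, 100, 'Memory frame every', callback=on_mem_every_changed)
gpu_gauge, gpu_gauge_layout = create_gauge('GPU memory')

panel = QWidget()
column = QVBoxLayout(panel)
column.addLayout(mem_every_layout)
column.addLayout(gpu_gauge_layout)

mem_every_box.setValue(10)  # fires the callback through valueChanged
gpu_gauge.setValue(42)      # the gauge is a plain QProgressBar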
.interactive_utils import color_map, index_numpy_to_one_hot_torch + + +def aggregate_sbg(prob, keep_bg=False, hard=False): + device = prob.device + k, h, w = prob.shape + ex_prob = torch.zeros((k+1, h, w), device=device) + ex_prob[0] = 0.5 + ex_prob[1:] = prob + ex_prob = torch.clamp(ex_prob, 1e-7, 1-1e-7) + logits = torch.log((ex_prob /(1-ex_prob))) + + if hard: + # Very low temperature o((⊙﹏⊙))o 🥶 + logits *= 1000 + + if keep_bg: + return F.softmax(logits, dim=0) + else: + return F.softmax(logits, dim=0)[1:] + +def aggregate_wbg(prob, keep_bg=False, hard=False): + k, h, w = prob.shape + new_prob = torch.cat([ + torch.prod(1-prob, dim=0, keepdim=True), + prob + ], 0).clamp(1e-7, 1-1e-7) + logits = torch.log((new_prob /(1-new_prob))) + + if hard: + # Very low temperature o((⊙﹏⊙))o 🥶 + logits *= 1000 + + if keep_bg: + return F.softmax(logits, dim=0) + else: + return F.softmax(logits, dim=0)[1:] + +class Interaction: + def __init__(self, image, prev_mask, true_size, controller): + self.image = image + self.prev_mask = prev_mask + self.controller = controller + self.start_time = time.time() + + self.h, self.w = true_size + + self.out_prob = None + self.out_mask = None + + def predict(self): + pass + + +class FreeInteraction(Interaction): + def __init__(self, image, prev_mask, true_size, num_objects): + """ + prev_mask should be index format numpy array + """ + super().__init__(image, prev_mask, true_size, None) + + self.K = num_objects + + self.drawn_map = self.prev_mask.copy() + self.curr_path = [[] for _ in range(self.K + 1)] + + self.size = None + + def set_size(self, size): + self.size = size + + """ + k - object id + vis - a tuple (visualization map, pass through alpha). None if not needed. + """ + def push_point(self, x, y, k, vis=None): + if vis is not None: + vis_map, vis_alpha = vis + selected = self.curr_path[k] + selected.append((x, y)) + if len(selected) >= 2: + cv2.line(self.drawn_map, + (int(round(selected[-2][0])), int(round(selected[-2][1]))), + (int(round(selected[-1][0])), int(round(selected[-1][1]))), + k, thickness=self.size) + + # Plot visualization + if vis is not None: + # Visualization for drawing + if k == 0: + vis_map = cv2.line(vis_map, + (int(round(selected[-2][0])), int(round(selected[-2][1]))), + (int(round(selected[-1][0])), int(round(selected[-1][1]))), + color_map[k], thickness=self.size) + else: + vis_map = cv2.line(vis_map, + (int(round(selected[-2][0])), int(round(selected[-2][1]))), + (int(round(selected[-1][0])), int(round(selected[-1][1]))), + color_map[k], thickness=self.size) + # Visualization on/off boolean filter + vis_alpha = cv2.line(vis_alpha, + (int(round(selected[-2][0])), int(round(selected[-2][1]))), + (int(round(selected[-1][0])), int(round(selected[-1][1]))), + 0.75, thickness=self.size) + + if vis is not None: + return vis_map, vis_alpha + + def end_path(self): + # Complete the drawing + self.curr_path = [[] for _ in range(self.K + 1)] + + def predict(self): + self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda() + # self.out_prob = torch.from_numpy(self.drawn_map).float().cuda() + # self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:]) + # self.out_prob = aggregate_sbg(self.out_prob, keep_bg=True) + return self.out_prob + +class ScribbleInteraction(Interaction): + def __init__(self, image, prev_mask, true_size, controller, num_objects): + """ + prev_mask should be in an indexed form + """ + super().__init__(image, prev_mask, true_size, controller) + + self.K = num_objects + + self.drawn_map = 
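aggregate_wbg turns K independent per-object probability maps into one (K+1)-way distribution: the background gets prod(1-p_k), the probability that no object claims the pixel, everything is converted to logits, and a softmax renormalizes. With hard=True the logits are multiplied by 1000, a near-zero temperature that collapses the softmax to an effective argmax while staying a tensor op. A small numeric check of both behaviours:

import torch
import torch.nn.functional as F

def aggregate_wbg_demo(prob, hard=False):
    new_prob = torch.cat([
        torch.prod(1 - prob, dim=0, keepdim=True),  # background probability
        prob
    ], 0).clamp(1e-7, 1 - 1e-7)
    logits = torch.log(new_prob / (1 - new_prob))
    if hard:
        logits *= 1000  # near-zero temperature: softmax becomes one-hot
    return F.softmax(logits, dim=0)

p = torch.tensor([[[0.7]], [[0.2]]])               # two objects, one pixel
print(aggregate_wbg_demo(p).squeeze())             # soft: mass on all three classes
print(aggregate_wbg_demo(p, hard=True).squeeze())  # ~[0, 1, 0]: object 1 wins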
np.empty((self.h, self.w), dtype=np.uint8) + self.drawn_map.fill(255) + # background + k + self.curr_path = [[] for _ in range(self.K + 1)] + self.size = 3 + + """ + k - object id + vis - a tuple (visualization map, pass through alpha). None if not needed. + """ + def push_point(self, x, y, k, vis=None): + if vis is not None: + vis_map, vis_alpha = vis + selected = self.curr_path[k] + selected.append((x, y)) + if len(selected) >= 2: + self.drawn_map = cv2.line(self.drawn_map, + (int(round(selected[-2][0])), int(round(selected[-2][1]))), + (int(round(selected[-1][0])), int(round(selected[-1][1]))), + k, thickness=self.size) + + # Plot visualization + if vis is not None: + # Visualization for drawing + if k == 0: + vis_map = cv2.line(vis_map, + (int(round(selected[-2][0])), int(round(selected[-2][1]))), + (int(round(selected[-1][0])), int(round(selected[-1][1]))), + color_map[k], thickness=self.size) + else: + vis_map = cv2.line(vis_map, + (int(round(selected[-2][0])), int(round(selected[-2][1]))), + (int(round(selected[-1][0])), int(round(selected[-1][1]))), + color_map[k], thickness=self.size) + # Visualization on/off boolean filter + vis_alpha = cv2.line(vis_alpha, + (int(round(selected[-2][0])), int(round(selected[-2][1]))), + (int(round(selected[-1][0])), int(round(selected[-1][1]))), + 0.75, thickness=self.size) + + # Optional vis return + if vis is not None: + return vis_map, vis_alpha + + def end_path(self): + # Complete the drawing + self.curr_path = [[] for _ in range(self.K + 1)] + + def predict(self): + self.out_prob = self.controller.interact(self.image.unsqueeze(0), self.prev_mask, self.drawn_map) + self.out_prob = aggregate_wbg(self.out_prob, keep_bg=True, hard=True) + return self.out_prob + + +class ClickInteraction(Interaction): + def __init__(self, image, prev_mask, true_size, controller, tar_obj): + """ + prev_mask in a prob. form + """ + super().__init__(image, prev_mask, true_size, controller) + self.tar_obj = tar_obj + + # negative/positive for each object + self.pos_clicks = [] + self.neg_clicks = [] + + self.out_prob = self.prev_mask.clone() + + """ + neg - Negative interaction or not + vis - a tuple (visualization map, pass through alpha). None if not needed. 
+    """
+    def push_point(self, x, y, neg, vis=None):
+        # Clicks
+        if neg:
+            self.neg_clicks.append((x, y))
+        else:
+            self.pos_clicks.append((x, y))
+
+        # Do the prediction
+        self.obj_mask = self.controller.interact(self.image.unsqueeze(0), x, y, not neg)
+
+        # Plot visualization
+        if vis is not None:
+            vis_map, vis_alpha = vis
+            # Visualization for clicks
+            if neg:
+                vis_map = cv2.circle(vis_map,
+                        (int(round(x)), int(round(y))),
+                        2, color_map[0], thickness=-1)
+            else:
+                vis_map = cv2.circle(vis_map,
+                        (int(round(x)), int(round(y))),
+                        2, color_map[self.tar_obj], thickness=-1)
+
+            vis_alpha = cv2.circle(vis_alpha,
+                        (int(round(x)), int(round(y))),
+                        2, 1, thickness=-1)
+
+        # Optional vis return
+        return vis_map, vis_alpha
+
+    def predict(self):
+        self.out_prob = self.prev_mask.clone()
+        # a small hack to allow the interacting object to overwrite existing masks
+        # without remembering all the object probabilities
+        self.out_prob = torch.clamp(self.out_prob, max=0.9)
+        self.out_prob[self.tar_obj] = self.obj_mask
+        self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True)
+        return self.out_prob
diff --git a/inference/interact/interactive_utils.py b/inference/interact/interactive_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9961f63aab4f59323454f34d76241a743190198f
--- /dev/null
+++ b/inference/interact/interactive_utils.py
@@ -0,0 +1,175 @@
+# Modified from https://github.com/seoungwugoh/ivs-demo
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from util.palette import davis_palette
+from dataset.range_transform import im_normalization
+
+def image_to_torch(frame: np.ndarray, device='cuda'):
+    # frame: H*W*3 numpy array
+    frame = frame.transpose(2, 0, 1)
+    frame = torch.from_numpy(frame).float().to(device)/255
+    frame_norm = im_normalization(frame)
+    return frame_norm, frame
+
+def torch_prob_to_numpy_mask(prob):
+    mask = torch.argmax(prob, dim=0)
+    mask = mask.cpu().numpy().astype(np.uint8)
+    return mask
+
+def index_numpy_to_one_hot_torch(mask, num_classes):
+    mask = torch.from_numpy(mask).long()
+    return F.one_hot(mask, num_classes=num_classes).permute(2, 0, 1).float()
+
+"""
+Some constants for visualization
+"""
+color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
+# scales for better visualization
+color_map_np = (color_map_np.astype(np.float32)*1.5).clip(0, 255).astype(np.uint8)
+color_map = color_map_np.tolist()
+if torch.cuda.is_available():
+    color_map_torch = torch.from_numpy(color_map_np).cuda() / 255
+
+grayscale_weights = np.array([[0.3,0.59,0.11]]).astype(np.float32)
+if torch.cuda.is_available():
+    grayscale_weights_torch = torch.from_numpy(grayscale_weights).cuda().unsqueeze(0)
+
+def get_visualization(mode, image, mask, layer, target_object):
+    if mode == 'fade':
+        return overlay_davis(image, mask, fade=True)
+    elif mode == 'davis':
+        return overlay_davis(image, mask)
+    elif mode == 'light':
+        return overlay_davis(image, mask, 0.9)
+    elif mode == 'popup':
+        return overlay_popup(image, mask, target_object)
+    elif mode == 'layered':
+        if layer is None:
+            print('Layer file not given. Defaulting to DAVIS.')
+            return overlay_davis(image, mask)
+        else:
+            return overlay_layer(image, mask, layer, target_object)
+    else:
+        raise NotImplementedError
+
+def get_visualization_torch(mode, image, prob, layer, target_object):
+    if mode == 'fade':
+        return overlay_davis_torch(image, prob, fade=True)
+    elif mode == 'davis':
+        return overlay_davis_torch(image, prob)
+    elif mode == 'light':
+        return overlay_davis_torch(image, prob, 0.9)
+    elif mode == 'popup':
+        return overlay_popup_torch(image, prob, target_object)
+    elif mode == 'layered':
+        if layer is None:
+            print('Layer file not given. Defaulting to DAVIS.')
+            return overlay_davis_torch(image, prob)
+        else:
+            return overlay_layer_torch(image, prob, layer, target_object)
+    else:
+        raise NotImplementedError
+
+def overlay_davis(image, mask, alpha=0.5, fade=False):
+    """ Overlay segmentation on top of RGB image. From the official DAVIS code. """
+    im_overlay = image.copy()
+
+    colored_mask = color_map_np[mask]
+    foreground = image*alpha + (1-alpha)*colored_mask
+    binary_mask = (mask > 0)
+    # Compose image
+    im_overlay[binary_mask] = foreground[binary_mask]
+    if fade:
+        im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
+    return im_overlay.astype(image.dtype)
+
+def overlay_popup(image, mask, target_object):
+    # Keep foreground colored. Convert background to grayscale.
+    im_overlay = image.copy()
+
+    binary_mask = ~(np.isin(mask, target_object))
+    gray_region = (im_overlay[binary_mask]*grayscale_weights).sum(-1, keepdims=True)
+    im_overlay[binary_mask] = gray_region
+    return im_overlay.astype(image.dtype)
+
+def overlay_layer(image, mask, layer, target_object):
+    # insert a layer between foreground and background
+    # The CPU version is less accurate because we are using the hard mask
+    # The GPU version has softer edges as it uses soft probabilities
+    obj_mask = (np.isin(mask, target_object)).astype(np.float32)
+    layer_alpha = layer[:, :, 3].astype(np.float32) / 255
+    layer_rgb = layer[:, :, :3]
+    background_alpha = np.maximum(obj_mask, layer_alpha)[:,:,np.newaxis]
+    obj_mask = obj_mask[:,:,np.newaxis]
+    im_overlay = (image*(1-background_alpha) + layer_rgb*(1-obj_mask) + image*obj_mask).clip(0, 255)
+    return im_overlay.astype(image.dtype)
+
+def overlay_davis_torch(image, mask, alpha=0.5, fade=False):
+    """ Overlay segmentation on top of RGB image. From the official DAVIS code. """
+    # Changes the image in-place to avoid copying
+    image = image.permute(1, 2, 0)
+    im_overlay = image
+    mask = torch.argmax(mask, dim=0)
+
+    colored_mask = color_map_torch[mask]
+    foreground = image*alpha + (1-alpha)*colored_mask
+    binary_mask = (mask > 0)
+    # Compose image
+    im_overlay[binary_mask] = foreground[binary_mask]
+    if fade:
+        im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
+
+    im_overlay = (im_overlay*255).cpu().numpy()
+    im_overlay = im_overlay.astype(np.uint8)
+
+    return im_overlay
+
+def overlay_popup_torch(image, mask, target_object):
+    # Keep foreground colored. Convert background to grayscale.
+    image = image.permute(1, 2, 0)
+
+    if len(target_object) == 0:
+        obj_mask = torch.zeros_like(mask[0]).unsqueeze(2)
+    else:
+        # I should not need to convert this to numpy.
+        # Using a list works most of the time but consistently fails
+        # if I include the first object -> exclude it -> include it again.
+        # I checked everywhere and it makes absolutely no sense.
+        # I am blaming this on PyTorch and calling it a day
+        obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0).unsqueeze(2)
+    gray_image = (image*grayscale_weights_torch).sum(-1, keepdim=True)
+    im_overlay = obj_mask*image + (1-obj_mask)*gray_image
+
+    im_overlay = (im_overlay*255).cpu().numpy()
+    im_overlay = im_overlay.astype(np.uint8)
+
+    return im_overlay
+
+def overlay_layer_torch(image, mask, layer, target_object):
+    # insert a layer between foreground and background
+    # The CPU version is less accurate because we are using the hard mask
+    # The GPU version has softer edges as it uses soft probabilities
+    image = image.permute(1, 2, 0)
+
+    if len(target_object) == 0:
+        obj_mask = torch.zeros_like(mask[0])
+    else:
+        # I should not need to convert this to numpy.
+        # Using a list works most of the time but consistently fails
+        # if I include the first object -> exclude it -> include it again.
+        # I checked everywhere and it makes absolutely no sense.
+        # I am blaming this on PyTorch and calling it a day
+        obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0)
+    layer_alpha = layer[:, :, 3]
+    layer_rgb = layer[:, :, :3]
+    background_alpha = torch.maximum(obj_mask, layer_alpha).unsqueeze(2)
+    obj_mask = obj_mask.unsqueeze(2)
+    im_overlay = (image*(1-background_alpha) + layer_rgb*(1-obj_mask) + image*obj_mask).clip(0, 1)
+
+    im_overlay = (im_overlay*255).cpu().numpy()
+    im_overlay = im_overlay.astype(np.uint8)
+
+    return im_overlay
diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f28af2e35a3ea29958e5eee4e19b26f1fa010b
--- /dev/null
+++ b/inference/interact/resource_manager.py
@@ -0,0 +1,206 @@
+import os
+from os import path
+import shutil
+import collections
+
+import cv2
+from PIL import Image
+if not hasattr(Image, 'Resampling'):  # Pillow<9.0
+    Image.Resampling = Image
+import numpy as np
+
+from util.palette import davis_palette
+import progressbar
+
+
+# https://bugs.python.org/issue28178
+# ah python ah why
+class LRU:
+    def __init__(self, func, maxsize=128):
+        self.cache = collections.OrderedDict()
+        self.func = func
+        self.maxsize = maxsize
+
+    def __call__(self, *args):
+        cache = self.cache
+        if args in cache:
+            cache.move_to_end(args)
+            return cache[args]
+        result = self.func(*args)
+        cache[args] = result
+        if len(cache) > self.maxsize:
+            cache.popitem(last=False)
+        return result
+
+    def invalidate(self, key):
+        self.cache.pop(key, None)
+
+
+class ResourceManager:
+    def __init__(self, config):
+        # determine inputs
+        images = config['images']
+        video = config['video']
+        self.workspace = config['workspace']
+        self.size = config['size']
+        self.palette = davis_palette
+
+        # create temporary workspace if not specified
+        if self.workspace is None:
+            if images is not None:
+                basename = path.basename(images)
+            elif video is not None:
+                basename = path.basename(video)[:-4]
+            else:
+                raise NotImplementedError(
+                    'Either images, video, or workspace has to be specified')
+
+            self.workspace = path.join('./workspace', basename)
+
+        print(f'Workspace is in: {self.workspace}')
+
+        # determine the location of input images
+        need_decoding = False
+        need_resizing = False
+        if path.exists(path.join(self.workspace, 'images')):
+            pass
+        elif images is not None:
+            need_resizing = True
+        elif video is not None:
+            # will decode video into frames later
+            need_decoding = True
+
+        # create workspace subdirectories
+        self.image_dir = path.join(self.workspace, 'images')
+        self.mask_dir
= path.join(self.workspace, 'masks') + os.makedirs(self.image_dir, exist_ok=True) + os.makedirs(self.mask_dir, exist_ok=True) + + # convert read functions to be buffered + self.get_image = LRU(self._get_image_unbuffered, maxsize=config['buffer_size']) + self.get_mask = LRU(self._get_mask_unbuffered, maxsize=config['buffer_size']) + + # extract frames from video + if need_decoding: + self._extract_frames(video) + + # copy/resize existing images to the workspace + if need_resizing: + self._copy_resize_frames(images) + + # read all frame names + self.names = sorted(os.listdir(self.image_dir)) + self.names = [f[:-4] for f in self.names] # remove extensions + self.length = len(self.names) + + assert self.length > 0, f'No images found! Check {self.workspace}/images. Remove folder if necessary.' + + print(f'{self.length} images found.') + + self.height, self.width = self.get_image(0).shape[:2] + self.visualization_init = False + + def _extract_frames(self, video): + cap = cv2.VideoCapture(video) + frame_index = 0 + print(f'Extracting frames from {video} into {self.image_dir}...') + bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength) + while(cap.isOpened()): + _, frame = cap.read() + if frame is None: + break + if self.size > 0: + h, w = frame.shape[:2] + new_w = (w*self.size//min(w, h)) + new_h = (h*self.size//min(w, h)) + if new_w != w or new_h != h: + frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA) + cv2.imwrite(path.join(self.image_dir, f'{frame_index:07d}.jpg'), frame) + frame_index += 1 + bar.update(frame_index) + bar.finish() + print('Done!') + + def _copy_resize_frames(self, images): + image_list = os.listdir(images) + print(f'Copying/resizing frames into {self.image_dir}...') + for image_name in progressbar.progressbar(image_list): + if self.size < 0: + # just copy + shutil.copy2(path.join(images, image_name), self.image_dir) + else: + frame = cv2.imread(path.join(images, image_name)) + h, w = frame.shape[:2] + new_w = (w*self.size//min(w, h)) + new_h = (h*self.size//min(w, h)) + if new_w != w or new_h != h: + frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA) + cv2.imwrite(path.join(self.image_dir, image_name), frame) + print('Done!') + + def save_mask(self, ti, mask): + # mask should be uint8 H*W without channels + assert 0 <= ti < self.length + assert isinstance(mask, np.ndarray) + + mask = Image.fromarray(mask) + mask.putpalette(self.palette) + mask.save(path.join(self.mask_dir, self.names[ti]+'.png')) + self.invalidate(ti) + + def save_visualization(self, ti, image): + # image should be uint8 3*H*W + assert 0 <= ti < self.length + assert isinstance(image, np.ndarray) + if not self.visualization_init: + self.visualization_dir = path.join(self.workspace, 'visualization') + os.makedirs(self.visualization_dir, exist_ok=True) + self.visualization_init = True + + image = Image.fromarray(image) + image.save(path.join(self.visualization_dir, self.names[ti]+'.jpg')) + + def _get_image_unbuffered(self, ti): + # returns H*W*3 uint8 array + assert 0 <= ti < self.length + + image = Image.open(path.join(self.image_dir, self.names[ti]+'.jpg')) + image = np.array(image) + return image + + def _get_mask_unbuffered(self, ti): + # returns H*W uint8 array + assert 0 <= ti < self.length + + mask_path = path.join(self.mask_dir, self.names[ti]+'.png') + if path.exists(mask_path): + mask = Image.open(mask_path) + mask = np.array(mask) + return mask + else: + return None + + def read_external_image(self, file_name, size=None): + image = 
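Both the video decoder and the image copier above resize so that the shorter side becomes self.size while the aspect ratio is preserved, using integer arithmetic. The formula in isolation:

def shorter_side_resize(w, h, size):
    new_w = w * size // min(w, h)
    new_h = h * size // min(w, h)
    return new_w, new_h

print(shorter_side_resize(1920, 1080, 480))  # (853, 480): shorter side pinned to 480
print(shorter_side_resize(720, 1280, 480))   # (480, 853): works for portrait input too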
Image.open(file_name) + is_mask = image.mode in ['L', 'P'] + if size is not None: + # PIL uses (width, height) + image = image.resize((size[1], size[0]), + resample=Image.Resampling.NEAREST if is_mask else Image.Resampling.BICUBIC) + image = np.array(image) + return image + + def invalidate(self, ti): + # the image buffer is never invalidated + self.get_mask.invalidate((ti,)) + + def __len__(self): + return self.length + + @property + def h(self): + return self.height + + @property + def w(self): + return self.width diff --git a/inference/interact/s2m/__init__.py b/inference/interact/s2m/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/interact/s2m/_deeplab.py b/inference/interact/s2m/_deeplab.py new file mode 100644 index 0000000000000000000000000000000000000000..e663007dde9a56add1aa540be76cf2f5d81de82f --- /dev/null +++ b/inference/interact/s2m/_deeplab.py @@ -0,0 +1,180 @@ +# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import _SimpleSegmentationModel + + +__all__ = ["DeepLabV3"] + + +class DeepLabV3(_SimpleSegmentationModel): + """ + Implements DeepLabV3 model from + `"Rethinking Atrous Convolution for Semantic Image Segmentation" + `_. + + Arguments: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "out" for the last feature map used, and "aux" if an auxiliary classifier + is used. + classifier (nn.Module): module that takes the "out" element returned from + the backbone and returns a dense prediction. + aux_classifier (nn.Module, optional): auxiliary classifier used during training + """ + pass + +class DeepLabHeadV3Plus(nn.Module): + def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): + super(DeepLabHeadV3Plus, self).__init__() + self.project = nn.Sequential( + nn.Conv2d(low_level_channels, 48, 1, bias=False), + nn.BatchNorm2d(48), + nn.ReLU(inplace=True), + ) + + self.aspp = ASPP(in_channels, aspp_dilate) + + self.classifier = nn.Sequential( + nn.Conv2d(304, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, num_classes, 1) + ) + self._init_weight() + + def forward(self, feature): + low_level_feature = self.project( feature['low_level'] ) + output_feature = self.aspp(feature['out']) + output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) + return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) ) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + +class DeepLabHead(nn.Module): + def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]): + super(DeepLabHead, self).__init__() + + self.classifier = nn.Sequential( + ASPP(in_channels, aspp_dilate), + nn.Conv2d(256, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, num_classes, 1) + ) + self._init_weight() + + def forward(self, feature): + return self.classifier( feature['out'] ) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, 
(nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + +class AtrousSeparableConvolution(nn.Module): + """ Atrous Separable Convolution + """ + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, bias=True): + super(AtrousSeparableConvolution, self).__init__() + self.body = nn.Sequential( + # Separable Conv + nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ), + # PointWise Conv + nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), + ) + + self._init_weight() + + def forward(self, x): + return self.body(x) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + modules = [ + nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ] + super(ASPPConv, self).__init__(*modules) + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels, out_channels): + super(ASPPPooling, self).__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True)) + + def forward(self, x): + size = x.shape[-2:] + x = super(ASPPPooling, self).forward(x) + return F.interpolate(x, size=size, mode='bilinear', align_corners=False) + +class ASPP(nn.Module): + def __init__(self, in_channels, atrous_rates): + super(ASPP, self).__init__() + out_channels = 256 + modules = [] + modules.append(nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True))) + + rate1, rate2, rate3 = tuple(atrous_rates) + modules.append(ASPPConv(in_channels, out_channels, rate1)) + modules.append(ASPPConv(in_channels, out_channels, rate2)) + modules.append(ASPPConv(in_channels, out_channels, rate3)) + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Dropout(0.1),) + + def forward(self, x): + res = [] + for conv in self.convs: + res.append(conv(x)) + res = torch.cat(res, dim=1) + return self.project(res) + + + +def convert_to_separable_conv(module): + new_module = module + if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1: + new_module = AtrousSeparableConvolution(module.in_channels, + module.out_channels, + module.kernel_size, + module.stride, + module.padding, + module.dilation, + module.bias) + for name, child in module.named_children(): + new_module.add_module(name, convert_to_separable_conv(child)) + return new_module \ No newline at end of file diff --git a/inference/interact/s2m/s2m_network.py b/inference/interact/s2m/s2m_network.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f9a3fc4fcc9cc4210485fe24e4d740464d3f8a --- /dev/null +++ b/inference/interact/s2m/s2m_network.py @@ -0,0 +1,65 @@ +# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch + +from .utils import IntermediateLayerGetter +from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3 +from . 
import s2m_resnet + +def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone): + + if output_stride==8: + replace_stride_with_dilation=[False, True, True] + aspp_dilate = [12, 24, 36] + else: + replace_stride_with_dilation=[False, False, True] + aspp_dilate = [6, 12, 18] + + backbone = s2m_resnet.__dict__[backbone_name]( + pretrained=pretrained_backbone, + replace_stride_with_dilation=replace_stride_with_dilation) + + inplanes = 2048 + low_level_planes = 256 + + if name=='deeplabv3plus': + return_layers = {'layer4': 'out', 'layer1': 'low_level'} + classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) + elif name=='deeplabv3': + return_layers = {'layer4': 'out'} + classifier = DeepLabHead(inplanes , num_classes, aspp_dilate) + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + model = DeepLabV3(backbone, classifier) + return model + +def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone): + + if backbone.startswith('resnet'): + model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) + else: + raise NotImplementedError + return model + + +# Deeplab v3 +def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False): + """Constructs a DeepLabV3 model with a ResNet-50 backbone. + + Args: + num_classes (int): number of classes. + output_stride (int): output stride for deeplab. + pretrained_backbone (bool): If True, use the pretrained backbone. + """ + return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) + + +# Deeplab v3+ +def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False): + """Constructs a DeepLabV3 model with a ResNet-50 backbone. + + Args: + num_classes (int): number of classes. + output_stride (int): output stride for deeplab. + pretrained_backbone (bool): If True, use the pretrained backbone. 
+ """ + return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) + diff --git a/inference/interact/s2m/s2m_resnet.py b/inference/interact/s2m/s2m_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..89f1ce042c69daa9b18172a0aadf9bc1de6f300e --- /dev/null +++ b/inference/interact/s2m/s2m_resnet.py @@ -0,0 +1,182 @@ +import torch +import torch.nn as nn +try: + from torchvision.models.utils import load_state_dict_from_url +except ModuleNotFoundError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +__all__ = ['ResNet', 'resnet50'] + + +model_urls = { + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(6, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], 
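Note that conv1 in this ResNet takes 6 input channels instead of the usual 3: the scribble-to-mask controller (s2m_controller.py below) feeds it RGB plus the previous mask plus positive/negative scribble channels. A smoke-test sketch of the assembled network, assuming the repository root is on PYTHONPATH:

import torch
from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50

net = deeplabv3plus_resnet50(num_classes=1).eval()

# 3 image + 1 previous-mask + 2 scribble channels; H and W divisible by 16
x = torch.zeros(1, 6, 384, 384)
with torch.no_grad():
    logits = net(x)
print(logits.shape)  # torch.Size([1, 1, 384, 384]): per-pixel mask logits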
stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) diff --git a/inference/interact/s2m/utils.py b/inference/interact/s2m/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c2adecf63baa9c2db4cc70b04c25200f6bc0a6a6 --- /dev/null +++ b/inference/interact/s2m/utils.py @@ -0,0 +1,78 @@ +# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch + +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +from collections import OrderedDict + +class _SimpleSegmentationModel(nn.Module): + def __init__(self, backbone, classifier): + super(_SimpleSegmentationModel, self).__init__() + self.backbone = backbone + self.classifier = classifier + + def forward(self, x): + input_shape = x.shape[-2:] + features = self.backbone(x) + x = self.classifier(features) + x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + return x + + +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns 
intermediate layers from a model + + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work. + + Additionally, it is only able to query submodules that are directly + assigned to the model. So if `model` is passed, `model.feature1` can + be returned, but not `model.feature1.layer2`. + + Arguments: + model (nn.Module): model on which we will extract the features + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). + + Examples:: + + >>> m = torchvision.models.resnet18(pretrained=True) + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = new_m(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + """ + def __init__(self, model, return_layers): + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model") + + orig_return_layers = return_layers + return_layers = {k: v for k, v in return_layers.items()} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super(IntermediateLayerGetter, self).__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + out = OrderedDict() + for name, module in self.named_children(): + x = module(x) + if name in self.return_layers: + out_name = self.return_layers[name] + out[out_name] = x + return out diff --git a/inference/interact/s2m_controller.py b/inference/interact/s2m_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..e222259eebdf3938290c476f85ba8c8d79fb626d --- /dev/null +++ b/inference/interact/s2m_controller.py @@ -0,0 +1,39 @@ +import torch +import numpy as np +from ..interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M + +from util.tensor_util import pad_divide_by, unpad + + +class S2MController: + """ + A controller for Scribble-to-Mask (for user interaction, not for DAVIS) + Takes the image, previous mask, and scribbles to produce a new mask + ignore_class is usually 255 + 0 is NOT the ignore class -- it is the label for the background + """ + def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'): + self.s2m_net = s2m_net + self.num_objects = num_objects + self.ignore_class = ignore_class + self.device = device + + def interact(self, image, prev_mask, scr_mask): + image = image.to(self.device, non_blocking=True) + prev_mask = prev_mask.unsqueeze(0) + + h, w = image.shape[-2:] + unaggre_mask = torch.zeros((self.num_objects, h, w), dtype=torch.float32, device=image.device) + + for ki in range(1, self.num_objects+1): + p_srb = (scr_mask==ki).astype(np.uint8) + n_srb = ((scr_mask!=ki) * (scr_mask!=self.ignore_class)).astype(np.uint8) + + Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device) + + inputs = torch.cat([image, (prev_mask==ki).float().unsqueeze(0), Rs], 1) + inputs, pads = 
pad_divide_by(inputs, 16) + + unaggre_mask[ki-1] = unpad(torch.sigmoid(self.s2m_net(inputs)), pads) + + return unaggre_mask \ No newline at end of file diff --git a/inference/interact/timer.py b/inference/interact/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..d134aa180275528c0d485e6d237cd6832f62d77e --- /dev/null +++ b/inference/interact/timer.py @@ -0,0 +1,33 @@ +import time + +class Timer: + def __init__(self): + self._acc_time = 0 + self._paused = True + + def start(self): + if self._paused: + self.last_time = time.time() + self._paused = False + return self + + def pause(self): + self.count() + self._paused = True + return self + + def count(self): + if self._paused: + return self._acc_time + t = time.time() + self._acc_time += t - self.last_time + self.last_time = t + return self._acc_time + + def format(self): + # count = int(self.count()*100) + # return '%02d:%02d:%02d' % (count//6000, (count//100)%60, count%100) + return '%03.2f' % self.count() + + def __str__(self): + return self.format() \ No newline at end of file diff --git a/inference/kv_memory_store.py b/inference/kv_memory_store.py new file mode 100644 index 0000000000000000000000000000000000000000..33a332625f03b39f38f4b7162dcaddc8bafa262e --- /dev/null +++ b/inference/kv_memory_store.py @@ -0,0 +1,215 @@ +import torch +from typing import List + +class KeyValueMemoryStore: + """ + Works for key/value pairs type storage + e.g., working and long-term memory + """ + + """ + An object group is created when new objects enter the video + Objects in the same group share the same temporal extent + i.e., objects initialized in the same frame are in the same group + For DAVIS/interactive, there is only one object group + For YouTubeVOS, there can be multiple object groups + """ + + def __init__(self, count_usage: bool): + self.count_usage = count_usage + + # keys are stored in a single tensor and are shared between groups/objects + # values are stored as a list indexed by object groups + self.k = None + self.v = [] + self.obj_groups = [] + # for debugging only + self.all_objects = [] + + # shrinkage and selection are also single tensors + self.s = self.e = None + + # usage + if self.count_usage: + self.use_count = self.life_count = None + + def add(self, key, value, shrinkage, selection, objects: List[int]): + new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7 + + # add the key + if self.k is None: + self.k = key + self.s = shrinkage + self.e = selection + if self.count_usage: + self.use_count = new_count + self.life_count = new_life + else: + self.k = torch.cat([self.k, key], -1) + if shrinkage is not None: + self.s = torch.cat([self.s, shrinkage], -1) + if selection is not None: + self.e = torch.cat([self.e, selection], -1) + if self.count_usage: + self.use_count = torch.cat([self.use_count, new_count], -1) + self.life_count = torch.cat([self.life_count, new_life], -1) + + # add the value + if objects is not None: + # When objects is given, v is a tensor; used in working memory + assert isinstance(value, torch.Tensor) + # First consume objects that are already in the memory bank + # cannot use set here because we need to preserve order + # shift by one as background is not part of value + remaining_objects = [obj-1 for obj in objects] + for gi, group in enumerate(self.obj_groups): + for obj in group: + # should properly raise an error if there are 
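S2MController.interact derives, for each object id ki, a positive scribble channel (pixels drawn with that id) and a negative channel (pixels drawn with any other id, excluding the 255 ignore class that untouched pixels keep). A small numpy/cv2 sketch of that encoding, mirroring the drawn_map convention used by ScribbleInteraction:

import numpy as np
import cv2

IGNORE = 255
drawn_map = np.full((120, 160), IGNORE, dtype=np.uint8)
cv2.line(drawn_map, (10, 20), (100, 20), 1, thickness=3)  # stroke for object 1
cv2.line(drawn_map, (10, 80), (100, 80), 0, thickness=3)  # background stroke

print(np.unique(drawn_map))  # [  0   1 255]
ki = 1
p_srb = (drawn_map == ki).astype(np.uint8)                            # positive channel
n_srb = ((drawn_map != ki) & (drawn_map != IGNORE)).astype(np.uint8)  # negative channel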
diff --git a/inference/memory_manager.py b/inference/memory_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3ab693497586d84e848ce8261abfdf319c555ad
--- /dev/null
+++ b/inference/memory_manager.py
@@ -0,0 +1,284 @@
+import torch
+import warnings
+
+from inference.kv_memory_store import KeyValueMemoryStore
+from model.memory_util import *
+
+
+class MemoryManager:
+    """
+    Manages all three memory stores and the transition between working/long-term memory
+    """
+    def __init__(self, config):
+        self.hidden_dim = config['hidden_dim']
+        self.top_k = config['top_k']
+
+        self.enable_long_term = config['enable_long_term']
+        self.enable_long_term_usage = config['enable_long_term_count_usage']
+        if self.enable_long_term:
+            self.max_mt_frames = config['max_mid_term_frames']
+            self.min_mt_frames = config['min_mid_term_frames']
+            self.num_prototypes = config['num_prototypes']
+            self.max_long_elements = config['max_long_term_elements']
+
+        # dimensions will be inferred from input later
+        self.CK = self.CV = None
+        self.H = self.W = None
+
+        # The hidden state will be stored in a single tensor for all objects
+        # B x num_objects x CH x H x W
+        self.hidden = None
+
+        self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
+        if self.enable_long_term:
+            self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)
+
+        self.reset_config = True
+
+    def update_config(self, config):
+        self.reset_config = True
+        self.hidden_dim = config['hidden_dim']
+        self.top_k = config['top_k']
+
+        assert self.enable_long_term == config['enable_long_term'], 'cannot update this'
+        assert self.enable_long_term_usage == config['enable_long_term_count_usage'], 'cannot update this'
+
+        self.enable_long_term_usage = config['enable_long_term_count_usage']
+        if self.enable_long_term:
+            self.max_mt_frames = config['max_mid_term_frames']
+            self.min_mt_frames = config['min_mid_term_frames']
+            self.num_prototypes = config['num_prototypes']
+            self.max_long_elements = config['max_long_term_elements']
+
+    def _readout(self, affinity, v):
+        # this function is for a single object group
+        return v @ affinity
+
+    def match_memory(self, query_key, selection):
+        # query_key: B x C^k x H x W
+        # selection: B x C^k x H x W
+        num_groups = self.work_mem.num_groups
+        h, w = query_key.shape[-2:]
+
+        query_key = query_key.flatten(start_dim=2)
+        selection = selection.flatten(start_dim=2) if selection is not None else None
+
+        """
+        Memory readout using keys
+        """
+
+        if self.enable_long_term and self.long_mem.engaged():
+            # Use long-term memory
+            long_mem_size = self.long_mem.size
+            memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
+            shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1)
+
+            similarity = get_similarity(memory_key, shrinkage, query_key, selection)
+            work_mem_similarity = similarity[:, long_mem_size:]
+            long_mem_similarity = similarity[:, :long_mem_size]
+
+            # get the usage with the first group
+            # the first group always has all the keys valid
+            affinity, usage = do_softmax(
+                torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1),
+                top_k=self.top_k, inplace=True, return_usage=True)
+            affinity = [affinity]
+
+            # compute affinity group by group as later groups only have a subset of keys
+            for gi in range(1, num_groups):
+                if gi < self.long_mem.num_groups:
+                    # merge working and lt similarities before softmax
+                    affinity_one_group = do_softmax(
+                        torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):],
+                                   work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1),
+                        top_k=self.top_k, inplace=True)
+                else:
+                    # no long-term memory for this group
+                    affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):],
+                                                    top_k=self.top_k, inplace=(gi==num_groups-1))
+                affinity.append(affinity_one_group)
+
+            all_memory_value = []
+            for gi, gv in enumerate(self.work_mem.value):
+                # merge the working and lt values before readout
+                if gi < self.long_mem.num_groups:
+                    all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1))
+                else:
+                    all_memory_value.append(gv)
+
+            """
+            Record memory usage for working and long-term memory
+            """
+            # ignore the index return for long-term memory
+            work_usage = usage[:, long_mem_size:]
+            self.work_mem.update_usage(work_usage.flatten())
+
+            if self.enable_long_term_usage:
+                # ignore the index return for working memory
+                long_usage = usage[:, :long_mem_size]
+                self.long_mem.update_usage(long_usage.flatten())
+        else:
+            # No long-term memory
+            similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection)
+
+            if self.enable_long_term:
+                affinity, usage = do_softmax(similarity, inplace=(num_groups==1),
+                                             top_k=self.top_k, return_usage=True)
+
+                # Record memory usage for working memory
+                self.work_mem.update_usage(usage.flatten())
+            else:
+                affinity = do_softmax(similarity, inplace=(num_groups==1),
+                                      top_k=self.top_k, return_usage=False)
+
+            affinity = [affinity]
+
+            # compute affinity group by group as later groups only have a subset of keys
+            for gi in range(1, num_groups):
+                affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):],
+                                                top_k=self.top_k, inplace=(gi==num_groups-1))
+                affinity.append(affinity_one_group)
+
+            all_memory_value = self.work_mem.value
+
+        # Shared affinity within each group
+        all_readout_mem = torch.cat([
+            self._readout(affinity[gi], gv)
+            for gi, gv in enumerate(all_memory_value)
+        ], 0)
+
+        return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)
+
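The readout above funnels every similarity matrix through `do_softmax` from `model.memory_util`. A standalone sketch of the top-k softmax it is assumed to perform (illustrative only; the shipped version also supports `inplace` and `return_usage`):

    import torch

    def topk_softmax(similarity, top_k):
        # similarity: B x N_mem x N_query. Keep only the top-k memory logits
        # for each query location, then normalize over the memory dimension.
        values, indices = torch.topk(similarity, k=top_k, dim=1)
        x_exp = (values - values.max(dim=1, keepdim=True)[0]).exp()
        x_exp = x_exp / x_exp.sum(dim=1, keepdim=True)
        return torch.zeros_like(similarity).scatter_(1, indices, x_exp)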
+    def add_memory(self, key, shrinkage, value, objects, selection=None):
+        # key: 1*C*H*W
+        # value: 1*num_objects*C*H*W
+        # objects contains a list of object indices
+        if self.H is None or self.reset_config:
+            self.reset_config = False
+            self.H, self.W = key.shape[-2:]
+            self.HW = self.H*self.W
+            if self.enable_long_term:
+                # convert from num. frames to num. nodes
+                self.min_work_elements = self.min_mt_frames*self.HW
+                self.max_work_elements = self.max_mt_frames*self.HW
+
+        # key: 1*C*N
+        # value: num_objects*C*N
+        key = key.flatten(start_dim=2)
+        shrinkage = shrinkage.flatten(start_dim=2)
+        value = value[0].flatten(start_dim=2)
+
+        self.CK = key.shape[1]
+        self.CV = value.shape[1]
+
+        if selection is not None:
+            if not self.enable_long_term:
+                warnings.warn('the selection factor is only needed in long-term mode', UserWarning)
+            selection = selection.flatten(start_dim=2)
+
+        self.work_mem.add(key, value, shrinkage, selection, objects)
+
+        # long-term memory cleanup
+        if self.enable_long_term:
+            # Compress the working memory if needed
+            if self.work_mem.size >= self.max_work_elements:
+                # Remove obsolete features if needed
+                if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
+                    self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)
+
+                self.compress_features()
+
+    def create_hidden_state(self, n, sample_key):
+        # n is the TOTAL number of objects
+        h, w = sample_key.shape[-2:]
+        if self.hidden is None:
+            self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device)
+        elif self.hidden.shape[1] != n:
+            self.hidden = torch.cat([
+                self.hidden,
+                torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device)
+            ], 1)
+
+        assert(self.hidden.shape[1] == n)
+
+    def set_hidden(self, hidden):
+        self.hidden = hidden
+
+    def get_hidden(self):
+        return self.hidden
+
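To make the element budget concrete, a hypothetical example (numbers are not from any shipped config): with a 24x24 key map, HW = 576, so min_mid_term_frames = 5 and max_mid_term_frames = 10 give min_work_elements = 2880 and max_work_elements = 5760; compress_features() below then fires once the working memory holds about ten frames of keys and shrinks it back toward five.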
+    def compress_features(self):
+        HW = self.HW
+        candidate_value = []
+        total_work_mem_size = self.work_mem.size
+        for gv in self.work_mem.value:
+            # Some object groups might be added later in the video
+            # So not all keys have values associated with all objects
+            # We need to keep track of the key->value validity
+            mem_size_in_this_group = gv.shape[-1]
+            if mem_size_in_this_group == total_work_mem_size:
+                # full LT
+                candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
+            else:
+                # mem_size is smaller than total_work_mem_size, but at least HW
+                assert HW <= mem_size_in_this_group < total_work_mem_size
+                if mem_size_in_this_group > self.min_work_elements+HW:
+                    # part of this object group still goes into LT
+                    candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
+                else:
+                    # this object group cannot go to the LT at all
+                    candidate_value.append(None)
+
+        # perform memory consolidation
+        prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
+            *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)
+
+        # remove consolidated working memory
+        self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)
+
+        # add to long-term memory
+        self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
+
+    def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
+        # keys: 1*C*N
+        # values: num_objects*C*N
+        N = candidate_key.shape[-1]
+
+        # find the indices with max usage
+        _, max_usage_indices = torch.topk(usage, k=self.num_prototypes, dim=-1, sorted=True)
+        prototype_indices = max_usage_indices.flatten()
+
+        # Prototypes are invalid for out-of-bound groups
+        validity = [prototype_indices >= (N-gv.shape[2]) if gv is not None else None for gv in candidate_value]
+
+        prototype_key = candidate_key[:, :, prototype_indices]
+        prototype_selection = candidate_selection[:, :, prototype_indices] if candidate_selection is not None else None
+
+        """
+        Potentiation step
+        """
+        similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, prototype_selection)
+
+        # convert similarity to affinity
+        # need to do it group by group since the softmax normalization would be different
+        affinity = [
+            do_softmax(similarity[:, -gv.shape[2]:, validity[gi]]) if gv is not None else None
+            for gi, gv in enumerate(candidate_value)
+        ]
+
+        # some values can have all-False validity. Weed them out.
+        affinity = [
+            aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
+        ]
+
+        # readout the values
+        prototype_value = [
+            self._readout(affinity[gi], gv) if affinity[gi] is not None else None
+            for gi, gv in enumerate(candidate_value)
+        ]
+
+        # readout the shrinkage term
+        prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None
+
+        return prototype_key, prototype_value, prototype_shrinkage
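In sketch form, the consolidation above picks the most-used key positions as prototypes and summarizes every candidate value as an attention-weighted average over them. A simplified standalone version (illustrative only; the real code uses the anisotropic similarity from `model.memory_util` and the per-group validity masks):

    import torch

    def consolidate_sketch(candidate_key, candidate_value, usage, num_prototypes):
        # candidate_key: 1 x C x N, candidate_value: num_objects x Cv x N, usage: N
        _, idx = torch.topk(usage, k=num_prototypes)
        prototype_key = candidate_key[:, :, idx]              # 1 x C x P
        # plain dot-product affinity, normalized over the N candidates
        sim = candidate_key.transpose(1, 2) @ prototype_key   # 1 x N x P
        affinity = torch.softmax(sim, dim=1)
        prototype_value = candidate_value @ affinity          # num_objects x Cv x P
        return prototype_key, prototype_value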
diff --git a/input/blackswan/00000.png b/input/blackswan/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..3daeddee9b13c2d98cae65bcd584e67a7a5aa4ad Binary files /dev/null and b/input/blackswan/00000.png differ
diff --git a/input/blackswan/00001.png b/input/blackswan/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..4e3d3dd49677f83d8c0cfb2e44520808cfae9bb1 Binary files /dev/null and b/input/blackswan/00001.png differ
diff --git a/input/blackswan/00002.png b/input/blackswan/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..8cff513acad481f397dca72a9abab7937396c6a9 Binary files /dev/null and b/input/blackswan/00002.png differ
diff --git a/input/blackswan/00003.png b/input/blackswan/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..4de4854720fce434d73ca6c3dd08c39207fbec6d Binary files /dev/null and b/input/blackswan/00003.png differ
diff --git a/input/blackswan/00004.png b/input/blackswan/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..687b05e9149267a9323e18d8ce3f8a8e5e6414ee Binary files /dev/null and b/input/blackswan/00004.png differ
diff --git a/input/blackswan/00005.png b/input/blackswan/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..7627fdff9669c24fa466bf0b5446484e555f3110 Binary files /dev/null and b/input/blackswan/00005.png differ
diff --git a/input/blackswan/00006.png b/input/blackswan/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..6008c7e885819a783e46be59e7fa1afdfb39440a Binary files /dev/null and b/input/blackswan/00006.png differ
diff --git a/input/blackswan/00007.png b/input/blackswan/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..09a0bfb317ede78eec273ed227bd2fefe1d1289b Binary files /dev/null and b/input/blackswan/00007.png differ
diff --git a/input/blackswan/00008.png b/input/blackswan/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..8befb43c7e113711b5a604b4a14f1c0280b0d222 Binary files /dev/null and b/input/blackswan/00008.png differ
diff --git a/input/blackswan/00009.png b/input/blackswan/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..3a7fd79e302e17cf72635e36195518eda9207759 Binary files /dev/null and b/input/blackswan/00009.png differ
diff --git a/input/blackswan/00010.png b/input/blackswan/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..30e77e82ea7b8994df1cd86b9a78d2ae8d8ca805 Binary files /dev/null and b/input/blackswan/00010.png 
differ diff --git a/input/blackswan/00011.png b/input/blackswan/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..281508d8ef82eb584bcfe605bd362dc5dfc2c9d4 Binary files /dev/null and b/input/blackswan/00011.png differ diff --git a/input/blackswan/00012.png b/input/blackswan/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..d8191811bbec2bc0fec65a107d7e7a12dbb7c894 Binary files /dev/null and b/input/blackswan/00012.png differ diff --git a/input/blackswan/00013.png b/input/blackswan/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..6d8eb2ea540bc860801751c2de81743f79c6f117 Binary files /dev/null and b/input/blackswan/00013.png differ diff --git a/input/blackswan/00014.png b/input/blackswan/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..1a00bee60bddb2b209ada6101a7dabe367fe95b0 Binary files /dev/null and b/input/blackswan/00014.png differ diff --git a/input/blackswan/00015.png b/input/blackswan/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..525d8adf20c44d5457897eced67663ee31f2ecc2 Binary files /dev/null and b/input/blackswan/00015.png differ diff --git a/input/blackswan/00016.png b/input/blackswan/00016.png new file mode 100644 index 0000000000000000000000000000000000000000..926fbe7154c9ea50da38f8da6b95ce818f6cac5f Binary files /dev/null and b/input/blackswan/00016.png differ diff --git a/input/blackswan/00017.png b/input/blackswan/00017.png new file mode 100644 index 0000000000000000000000000000000000000000..871bad747eafeec29a94d4bd4b9242027050fb12 Binary files /dev/null and b/input/blackswan/00017.png differ diff --git a/input/blackswan/00018.png b/input/blackswan/00018.png new file mode 100644 index 0000000000000000000000000000000000000000..8855a4a77ecbdc8236a4e21baa985b0453f7f94a Binary files /dev/null and b/input/blackswan/00018.png differ diff --git a/input/blackswan/00019.png b/input/blackswan/00019.png new file mode 100644 index 0000000000000000000000000000000000000000..54052ca5f3ce592a77375c1c283a515216ad738e Binary files /dev/null and b/input/blackswan/00019.png differ diff --git a/input/blackswan/00020.png b/input/blackswan/00020.png new file mode 100644 index 0000000000000000000000000000000000000000..f73b15e01ec5fd759a08225ead1b57dec0e3625c Binary files /dev/null and b/input/blackswan/00020.png differ diff --git a/input/blackswan/00021.png b/input/blackswan/00021.png new file mode 100644 index 0000000000000000000000000000000000000000..b7a9b440e309692fd076d43bc938025d7e3009ad Binary files /dev/null and b/input/blackswan/00021.png differ diff --git a/input/blackswan/00022.png b/input/blackswan/00022.png new file mode 100644 index 0000000000000000000000000000000000000000..f8828e56e54d45a87338f9f1a9f9b007536dbfc7 Binary files /dev/null and b/input/blackswan/00022.png differ diff --git a/input/blackswan/00023.png b/input/blackswan/00023.png new file mode 100644 index 0000000000000000000000000000000000000000..3b0a26ea83f9d7e4bceaef7b59e0474f9438f1ed Binary files /dev/null and b/input/blackswan/00023.png differ diff --git a/input/blackswan/00024.png b/input/blackswan/00024.png new file mode 100644 index 0000000000000000000000000000000000000000..7c3800fecfc1d33d147e7fff9535f9c83782a03b Binary files /dev/null and b/input/blackswan/00024.png differ diff --git a/input/blackswan/00025.png b/input/blackswan/00025.png new file mode 100644 index 0000000000000000000000000000000000000000..805fddc4c09de2465fb5104307174a457a386382 Binary files /dev/null 
and b/input/blackswan/00025.png differ diff --git a/input/blackswan/00026.png b/input/blackswan/00026.png new file mode 100644 index 0000000000000000000000000000000000000000..f3ffa44062a393a8adb2d15db5b104b4b967a119 Binary files /dev/null and b/input/blackswan/00026.png differ diff --git a/input/blackswan/00027.png b/input/blackswan/00027.png new file mode 100644 index 0000000000000000000000000000000000000000..b491fe2a6bb7604f7324fed9e7b6cfeabf711f1e Binary files /dev/null and b/input/blackswan/00027.png differ diff --git a/input/blackswan/00028.png b/input/blackswan/00028.png new file mode 100644 index 0000000000000000000000000000000000000000..b6df5eede90ab15197877877f54b63110622b661 Binary files /dev/null and b/input/blackswan/00028.png differ diff --git a/input/blackswan/00029.png b/input/blackswan/00029.png new file mode 100644 index 0000000000000000000000000000000000000000..3e134241a7b72a71260574a25305057a75fdf653 Binary files /dev/null and b/input/blackswan/00029.png differ diff --git a/input/blackswan/00030.png b/input/blackswan/00030.png new file mode 100644 index 0000000000000000000000000000000000000000..2f7a438871076b04c2e9f8589d7c5decf113af17 Binary files /dev/null and b/input/blackswan/00030.png differ diff --git a/input/blackswan/00031.png b/input/blackswan/00031.png new file mode 100644 index 0000000000000000000000000000000000000000..06c9985e1a3017147d6821bb04c1feff330df770 Binary files /dev/null and b/input/blackswan/00031.png differ diff --git a/input/blackswan/00032.png b/input/blackswan/00032.png new file mode 100644 index 0000000000000000000000000000000000000000..65181c862bcdf6b046c6c8ff8b7298b17f80cf00 Binary files /dev/null and b/input/blackswan/00032.png differ diff --git a/input/blackswan/00033.png b/input/blackswan/00033.png new file mode 100644 index 0000000000000000000000000000000000000000..2b1897362f7936a208cdc53393f6ea4b2a6205f8 Binary files /dev/null and b/input/blackswan/00033.png differ diff --git a/input/blackswan/00034.png b/input/blackswan/00034.png new file mode 100644 index 0000000000000000000000000000000000000000..29da4528c83eaa36635a6ef6434d149ca48976a5 Binary files /dev/null and b/input/blackswan/00034.png differ diff --git a/input/blackswan/00035.png b/input/blackswan/00035.png new file mode 100644 index 0000000000000000000000000000000000000000..25c7d53a9fcff3c1496e78adc735e0e0b2596814 Binary files /dev/null and b/input/blackswan/00035.png differ diff --git a/input/blackswan/00036.png b/input/blackswan/00036.png new file mode 100644 index 0000000000000000000000000000000000000000..49daf9e8feb78c6ff5de5fab40ac44dd8d847b4a Binary files /dev/null and b/input/blackswan/00036.png differ diff --git a/input/blackswan/00037.png b/input/blackswan/00037.png new file mode 100644 index 0000000000000000000000000000000000000000..a584a712cdef876c0ff6e436702d8264c98bf13e Binary files /dev/null and b/input/blackswan/00037.png differ diff --git a/input/blackswan/00038.png b/input/blackswan/00038.png new file mode 100644 index 0000000000000000000000000000000000000000..4e55ec2761f64a751c0dd20b236046de9766c9ad Binary files /dev/null and b/input/blackswan/00038.png differ diff --git a/input/blackswan/00039.png b/input/blackswan/00039.png new file mode 100644 index 0000000000000000000000000000000000000000..15f547c33bfd454dcdb8b6c95bf4fe85f3a0b636 Binary files /dev/null and b/input/blackswan/00039.png differ diff --git a/input/blackswan/00040.png b/input/blackswan/00040.png new file mode 100644 index 
0000000000000000000000000000000000000000..75b00d03b9e00562adf755afa79bb5f9603910d7 Binary files /dev/null and b/input/blackswan/00040.png differ diff --git a/input/blackswan/00041.png b/input/blackswan/00041.png new file mode 100644 index 0000000000000000000000000000000000000000..1d918ba53affeb5ab2ce9829b2dc59e263a61b8b Binary files /dev/null and b/input/blackswan/00041.png differ diff --git a/input/blackswan/00042.png b/input/blackswan/00042.png new file mode 100644 index 0000000000000000000000000000000000000000..70395b42de9dbf440eddd845c16f5c8cb44fce09 Binary files /dev/null and b/input/blackswan/00042.png differ diff --git a/input/blackswan/00043.png b/input/blackswan/00043.png new file mode 100644 index 0000000000000000000000000000000000000000..dc90f09eb0997efeec7a056b7706849a6bea5471 Binary files /dev/null and b/input/blackswan/00043.png differ diff --git a/input/blackswan/00044.png b/input/blackswan/00044.png new file mode 100644 index 0000000000000000000000000000000000000000..c1ca7a3198d7c8822f444be0725847cacff59482 Binary files /dev/null and b/input/blackswan/00044.png differ diff --git a/input/blackswan/00045.png b/input/blackswan/00045.png new file mode 100644 index 0000000000000000000000000000000000000000..089ad8c08b6f377835a43ee225e6cbc9458456ad Binary files /dev/null and b/input/blackswan/00045.png differ diff --git a/input/blackswan/00046.png b/input/blackswan/00046.png new file mode 100644 index 0000000000000000000000000000000000000000..7b8681a108f4b1bcaa434e42b546490e295ef4ca Binary files /dev/null and b/input/blackswan/00046.png differ diff --git a/input/blackswan/00047.png b/input/blackswan/00047.png new file mode 100644 index 0000000000000000000000000000000000000000..f737576977efdab060b2ea861441bf32a7a6181e Binary files /dev/null and b/input/blackswan/00047.png differ diff --git a/input/blackswan/00048.png b/input/blackswan/00048.png new file mode 100644 index 0000000000000000000000000000000000000000..fa06c0db7b1c7956e081b553f03d938377e74f7a Binary files /dev/null and b/input/blackswan/00048.png differ diff --git a/input/blackswan/00049.png b/input/blackswan/00049.png new file mode 100644 index 0000000000000000000000000000000000000000..1505fa35e8731f872ec83cb60c1d1a0b7c3f98da Binary files /dev/null and b/input/blackswan/00049.png differ diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/aggregate.py b/model/aggregate.py new file mode 100644 index 0000000000000000000000000000000000000000..7622391fb3ac9aa8b515df88cf3ea5297b367538 --- /dev/null +++ b/model/aggregate.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + + +# Soft aggregation from STM +def aggregate(prob, dim, return_logits=False): + new_prob = torch.cat([ + torch.prod(1-prob, dim=dim, keepdim=True), + prob + ], dim).clamp(1e-7, 1-1e-7) + logits = torch.log((new_prob /(1-new_prob))) + prob = F.softmax(logits, dim=dim) + + if return_logits: + return logits, prob + else: + return prob \ No newline at end of file diff --git a/model/attention.py b/model/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c593383fbbc9568301bc63b4a89f8af4f92f05f1 --- /dev/null +++ b/model/attention.py @@ -0,0 +1,916 @@ +import math +import pdb + +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.basic import DropOutLogit, ScaleOffset, DWConv2d + + +def multiply_by_ychunks(x, y, chunks=1): + if chunks <= 1: + return x @ y + else: + 
return torch.cat([x @ _y for _y in y.chunk(chunks, dim=-1)], dim=-1) + + +def multiply_by_xchunks(x, y, chunks=1): + if chunks <= 1: + return x @ y + else: + return torch.cat([_x @ y for _x in x.chunk(chunks, dim=-2)], dim=-2) + + +# Long-term attention +class MultiheadAttention(nn.Module): + def __init__(self, + d_model, + num_head=8, + dropout=0., + use_linear=True, + d_att=None, + use_dis=False, + qk_chunks=1, + max_mem_len_ratio=-1, + top_k=-1): + super().__init__() + self.d_model = d_model + self.num_head = num_head + self.use_dis = use_dis + self.qk_chunks = qk_chunks + self.max_mem_len_ratio = float(max_mem_len_ratio) + self.top_k = top_k + + self.hidden_dim = d_model // num_head + self.d_att = self.hidden_dim if d_att is None else d_att + self.T = self.d_att**0.5 + self.use_linear = use_linear + + if use_linear: + self.linear_Q = nn.Linear(d_model, d_model) + self.linear_K = nn.Linear(d_model, d_model) + self.linear_V = nn.Linear(d_model, d_model) + + self.dropout = nn.Dropout(dropout) + self.drop_prob = dropout + self.projection = nn.Linear(d_model, d_model) + self._init_weight() + + def forward(self, Q, K, V): + """ + :param Q: A 3d tensor with shape of [T_q, bs, C_q] + :param K: A 3d tensor with shape of [T_k, bs, C_k] + :param V: A 3d tensor with shape of [T_v, bs, C_v] + """ + num_head = self.num_head + hidden_dim = self.hidden_dim + + bs = Q.size()[1] + + # Linear projections + if self.use_linear: + Q = self.linear_Q(Q) + K = self.linear_K(K) + V = self.linear_V(V) + + # Scale + Q = Q / self.T + + if not self.training and self.max_mem_len_ratio > 0: + mem_len_ratio = float(K.size(0)) / Q.size(0) + if mem_len_ratio > self.max_mem_len_ratio: + scaling_ratio = math.log(mem_len_ratio) / math.log( + self.max_mem_len_ratio) + Q = Q * scaling_ratio + + # Multi-head + Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3) + K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0) + V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3) + + # Multiplication + QK = multiply_by_ychunks(Q, K, self.qk_chunks) + if self.use_dis: + QK = 2 * QK - K.pow(2).sum(dim=-2, keepdim=True) + + # Activation + if not self.training and self.top_k > 0 and self.top_k < QK.size()[-1]: + top_QK, indices = torch.topk(QK, k=self.top_k, dim=-1) + top_attn = torch.softmax(top_QK, dim=-1) + attn = torch.zeros_like(QK).scatter_(-1, indices, top_attn) + else: + attn = torch.softmax(QK, dim=-1) + + # Dropouts + attn = self.dropout(attn) + + # Weighted sum + outputs = multiply_by_xchunks(attn, V, + self.qk_chunks).permute(2, 0, 1, 3) + + # Restore shape + outputs = outputs.reshape(-1, bs, self.d_model) + + outputs = self.projection(outputs) + + return outputs, attn + + def _init_weight(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +# Short-term attention +class MultiheadLocalAttentionV1(nn.Module): + def __init__(self, + d_model, + num_head, + dropout=0., + max_dis=7, + dilation=1, + use_linear=True, + enable_corr=True): + super().__init__() + self.dilation = dilation + self.window_size = 2 * max_dis + 1 + self.max_dis = max_dis + self.num_head = num_head + self.T = ((d_model / num_head)**0.5) + + self.use_linear = use_linear + if use_linear: + self.linear_Q = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_K = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_V = nn.Conv2d(d_model, d_model, kernel_size=1) + + self.relative_emb_k = nn.Conv2d(d_model, + num_head * self.window_size * + self.window_size, + kernel_size=1, + groups=num_head) + 
self.relative_emb_v = nn.Parameter( + torch.zeros([ + self.num_head, d_model // self.num_head, + self.window_size * self.window_size + ])) + + self.enable_corr = enable_corr + + if enable_corr: + from spatial_correlation_sampler import SpatialCorrelationSampler + self.correlation_sampler = SpatialCorrelationSampler( + kernel_size=1, + patch_size=self.window_size, + stride=1, + padding=0, + dilation=1, + dilation_patch=self.dilation) + + self.projection = nn.Linear(d_model, d_model) + + self.dropout = nn.Dropout(dropout) + self.drop_prob = dropout + + def forward(self, q, k, v): + n, c, h, w = v.size() + + if self.use_linear: + q = self.linear_Q(q) + k = self.linear_K(k) + v = self.linear_V(v) + + hidden_dim = c // self.num_head + + relative_emb = self.relative_emb_k(q) + memory_mask = torch.ones((1, 1, h, w), device=v.device).float() + + # Scale + q = q / self.T + + q = q.view(-1, hidden_dim, h, w) + k = k.reshape(-1, hidden_dim, h, w).contiguous() + unfolded_vu = self.pad_and_unfold(v).view( + n, self.num_head, hidden_dim, self.window_size * self.window_size, + h * w) + self.relative_emb_v.unsqueeze(0).unsqueeze(-1) + + relative_emb = relative_emb.view(n, self.num_head, + self.window_size * self.window_size, + h * w) + unfolded_k_mask = self.pad_and_unfold(memory_mask).bool().view( + 1, 1, self.window_size * self.window_size, + h * w).expand(n, self.num_head, -1, -1) + + if self.enable_corr: + qk = self.correlation_sampler(q, k).view( + n, self.num_head, self.window_size * self.window_size, + h * w) + relative_emb + else: + unfolded_k = self.pad_and_unfold(k).view( + n * self.num_head, hidden_dim, + self.window_size * self.window_size, h, w) + qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view( + n, self.num_head, self.window_size * self.window_size, + h * w) + relative_emb + + qk_mask = 1 - unfolded_k_mask + + qk -= qk_mask * 1e+8 if qk.dtype == torch.float32 else qk_mask * 1e+4 + + local_attn = torch.softmax(qk, dim=2) + + local_attn = self.dropout(local_attn) + + output = (local_attn.unsqueeze(2) * unfolded_vu).sum(dim=3).permute( + 3, 0, 1, 2).view(h * w, n, c) + + output = self.projection(output) + + return output, local_attn + + def pad_and_unfold(self, x): + pad_pixel = self.max_dis * self.dilation + x = F.pad(x, (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=0) + x = F.unfold(x, + kernel_size=(self.window_size, self.window_size), + stride=(1, 1), + dilation=self.dilation) + return x + + +class MultiheadLocalAttentionV2(nn.Module): + def __init__(self, + d_model, + num_head, + dropout=0., + max_dis=7, + dilation=1, + use_linear=True, + enable_corr=True, + d_att=None, + use_dis=False): + super().__init__() + self.dilation = dilation + self.window_size = 2 * max_dis + 1 + self.max_dis = max_dis + self.num_head = num_head + self.hidden_dim = d_model // num_head + self.d_att = self.hidden_dim if d_att is None else d_att + self.T = self.d_att**0.5 + self.use_dis = use_dis + + self.use_linear = use_linear + if use_linear: + self.linear_Q = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_K = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_V = nn.Conv2d(d_model, d_model, kernel_size=1) + + self.relative_emb_k = nn.Conv2d(self.d_att * self.num_head, + num_head * self.window_size * + self.window_size, + kernel_size=1, + groups=num_head) + self.relative_emb_v = nn.Parameter( + torch.zeros([ + self.num_head, d_model // self.num_head, + self.window_size * self.window_size + ])) + + self.enable_corr = enable_corr + + if enable_corr: + from 
spatial_correlation_sampler import SpatialCorrelationSampler + self.correlation_sampler = SpatialCorrelationSampler( + kernel_size=1, + patch_size=self.window_size, + stride=1, + padding=0, + dilation=1, + dilation_patch=self.dilation) + + self.projection = nn.Linear(d_model, d_model) + + self.dropout = nn.Dropout(dropout) + + self.drop_prob = dropout + + self.local_mask = None + self.last_size_2d = None + self.qk_mask = None + + def forward(self, q, k, v): + n, c, h, w = v.size() + + if self.use_linear: + q = self.linear_Q(q) + k = self.linear_K(k) + v = self.linear_V(v) + + hidden_dim = self.hidden_dim + + if self.qk_mask is not None and (h, w) == self.last_size_2d: + qk_mask = self.qk_mask + else: + memory_mask = torch.ones((1, 1, h, w), device=v.device).float() + unfolded_k_mask = self.pad_and_unfold(memory_mask).view( + 1, 1, self.window_size * self.window_size, h * w) + qk_mask = 1 - unfolded_k_mask + self.qk_mask = qk_mask + + relative_emb = self.relative_emb_k(q) + + # Scale + q = q / self.T + + q = q.view(-1, self.d_att, h, w) + k = k.view(-1, self.d_att, h, w) + v = v.view(-1, self.num_head, hidden_dim, h * w) + + relative_emb = relative_emb.view(n, self.num_head, + self.window_size * self.window_size, + h * w) + + if self.enable_corr: + qk = self.correlation_sampler(q, k).view( + n, self.num_head, self.window_size * self.window_size, h * w) + else: + unfolded_k = self.pad_and_unfold(k).view( + n * self.num_head, hidden_dim, + self.window_size * self.window_size, h, w) + qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view( + n, self.num_head, self.window_size * self.window_size, h * w) + if self.use_dis: + qk = 2 * qk - self.pad_and_unfold( + k.pow(2).sum(dim=1, keepdim=True)).view( + n, self.num_head, self.window_size * self.window_size, + h * w) + + qk = qk + relative_emb + + qk -= qk_mask * 1e+8 if qk.dtype == torch.float32 else qk_mask * 1e+4 + + local_attn = torch.softmax(qk, dim=2) + + local_attn = self.dropout(local_attn) + + agg_bias = torch.einsum('bhwn,hcw->bhnc', local_attn, + self.relative_emb_v) + + global_attn = self.local2global(local_attn, h, w) + + agg_value = (global_attn @ v.transpose(-2, -1)) + + output = (agg_value + agg_bias).permute(2, 0, 1, + 3).reshape(h * w, n, c) + + output = self.projection(output) + + self.last_size_2d = (h, w) + return output, local_attn + + def local2global(self, local_attn, height, width): + batch_size = local_attn.size()[0] + + pad_height = height + 2 * self.max_dis + pad_width = width + 2 * self.max_dis + + if self.local_mask is not None and (height, + width) == self.last_size_2d: + local_mask = self.local_mask + else: + ky, kx = torch.meshgrid([ + torch.arange(0, pad_height, device=local_attn.device), + torch.arange(0, pad_width, device=local_attn.device) + ]) + qy, qx = torch.meshgrid([ + torch.arange(0, height, device=local_attn.device), + torch.arange(0, width, device=local_attn.device) + ]) + + offset_y = qy.reshape(-1, 1) - ky.reshape(1, -1) + self.max_dis + offset_x = qx.reshape(-1, 1) - kx.reshape(1, -1) + self.max_dis + + local_mask = (offset_y.abs() <= self.max_dis) & (offset_x.abs() <= + self.max_dis) + local_mask = local_mask.view(1, 1, height * width, pad_height, + pad_width) + self.local_mask = local_mask + + global_attn = torch.zeros( + (batch_size, self.num_head, height * width, pad_height, pad_width), + device=local_attn.device) + global_attn = global_attn.type(torch.HalfTensor).cuda() + global_attn[local_mask.expand(batch_size, self.num_head, + -1, -1, -1)] = local_attn.transpose( + -1, -2).reshape(-1) + 
global_attn = global_attn[:, :, :, self.max_dis:-self.max_dis, + self.max_dis:-self.max_dis].reshape( + batch_size, self.num_head, + height * width, height * width) + + return global_attn + + def pad_and_unfold(self, x): + pad_pixel = self.max_dis * self.dilation + x = F.pad(x, (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=0) + x = F.unfold(x, + kernel_size=(self.window_size, self.window_size), + stride=(1, 1), + dilation=self.dilation) + return x + + +class MultiheadLocalAttentionV3(nn.Module): + def __init__(self, + d_model, + num_head, + dropout=0., + max_dis=7, + dilation=1, + use_linear=True): + super().__init__() + self.dilation = dilation + self.window_size = 2 * max_dis + 1 + self.max_dis = max_dis + self.num_head = num_head + self.T = ((d_model / num_head)**0.5) + + self.use_linear = use_linear + if use_linear: + self.linear_Q = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_K = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_V = nn.Conv2d(d_model, d_model, kernel_size=1) + + self.relative_emb_k = nn.Conv2d(d_model, + num_head * self.window_size * + self.window_size, + kernel_size=1, + groups=num_head) + self.relative_emb_v = nn.Parameter( + torch.zeros([ + self.num_head, d_model // self.num_head, + self.window_size * self.window_size + ])) + + self.projection = nn.Linear(d_model, d_model) + self.dropout = DropOutLogit(dropout) + + self.padded_local_mask = None + self.local_mask = None + self.last_size_2d = None + self.qk_mask = None + + def forward(self, q, k, v): + n, c, h, w = q.size() + + if self.use_linear: + q = self.linear_Q(q) + k = self.linear_K(k) + v = self.linear_V(v) + + hidden_dim = c // self.num_head + + relative_emb = self.relative_emb_k(q) + relative_emb = relative_emb.view(n, self.num_head, + self.window_size * self.window_size, + h * w) + padded_local_mask, local_mask = self.compute_mask(h, + w, + device=q.device) + qk_mask = (~padded_local_mask).float() + + # Scale + q = q / self.T + + q = q.view(-1, self.num_head, hidden_dim, h * w) + k = k.view(-1, self.num_head, hidden_dim, h * w) + v = v.view(-1, self.num_head, hidden_dim, h * w) + + qk = q.transpose(-1, -2) @ k # [B, nH, kL, qL] + + pad_pixel = self.max_dis * self.dilation + + padded_qk = F.pad(qk.view(-1, self.num_head, h * w, h, w), + (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=-1e+8 if qk.dtype == torch.float32 else -1e+4) + + qk_mask = qk_mask * 1e+8 if (padded_qk.dtype + == torch.float32) else qk_mask * 1e+4 + padded_qk = padded_qk - qk_mask + + padded_qk[padded_local_mask.expand(n, self.num_head, -1, -1, + -1)] += relative_emb.transpose( + -1, -2).reshape(-1) + padded_qk = self.dropout(padded_qk) + + local_qk = padded_qk[padded_local_mask.expand(n, self.num_head, -1, -1, + -1)] + + global_qk = padded_qk[:, :, :, self.max_dis:-self.max_dis, + self.max_dis:-self.max_dis].reshape( + n, self.num_head, h * w, h * w) + + local_attn = torch.softmax(local_qk.reshape( + n, self.num_head, h * w, self.window_size * self.window_size), + dim=3) + global_attn = torch.softmax(global_qk, dim=3) + + agg_bias = torch.einsum('bhnw,hcw->nbhc', local_attn, + self.relative_emb_v).reshape(h * w, n, c) + + agg_value = (global_attn @ v.transpose(-2, -1)) + + output = agg_value + agg_bias + + output = self.projection(output) + + self.last_size_2d = (h, w) + return output, local_attn + + def compute_mask(self, height, width, device=None): + pad_height = height + 2 * self.max_dis + pad_width = width + 2 * self.max_dis + + if self.padded_local_mask is not 
None and (height, + width) == self.last_size_2d: + padded_local_mask = self.padded_local_mask + local_mask = self.local_mask + + else: + ky, kx = torch.meshgrid([ + torch.arange(0, pad_height, device=device), + torch.arange(0, pad_width, device=device) + ]) + qy, qx = torch.meshgrid([ + torch.arange(0, height, device=device), + torch.arange(0, width, device=device) + ]) + + qy = qy.reshape(-1, 1) + qx = qx.reshape(-1, 1) + offset_y = qy - ky.reshape(1, -1) + self.max_dis + offset_x = qx - kx.reshape(1, -1) + self.max_dis + padded_local_mask = (offset_y.abs() <= self.max_dis) & ( + offset_x.abs() <= self.max_dis) + padded_local_mask = padded_local_mask.view(1, 1, height * width, + pad_height, pad_width) + local_mask = padded_local_mask[:, :, :, self.max_dis:-self.max_dis, + self.max_dis:-self.max_dis] + pad_pixel = self.max_dis * self.dilation + local_mask = F.pad(local_mask.float(), + (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=0).view(1, 1, height * width, pad_height, + pad_width) + self.padded_local_mask = padded_local_mask + self.local_mask = local_mask + + return padded_local_mask, local_mask + + +def linear_gate(x, dim=-1): + # return F.relu_(x).pow(2.) / x.size()[dim] + return torch.softmax(x, dim=dim) + + +def silu(x): + return x * torch.sigmoid(x) + + +class GatedPropagation(nn.Module): + def __init__(self, + d_qk, + d_vu, + num_head=8, + dropout=0., + use_linear=True, + d_att=None, + use_dis=False, + qk_chunks=1, + max_mem_len_ratio=-1, + top_k=-1, + expand_ratio=2.): + super().__init__() + expand_ratio = expand_ratio + self.expand_d_vu = int(d_vu * expand_ratio) + self.d_vu = d_vu + self.d_qk = d_qk + self.num_head = num_head + self.use_dis = use_dis + self.qk_chunks = qk_chunks + self.max_mem_len_ratio = float(max_mem_len_ratio) + self.top_k = top_k + + self.hidden_dim = self.expand_d_vu // num_head + self.d_att = d_qk // num_head if d_att is None else d_att + self.T = self.d_att**0.5 + self.use_linear = use_linear + self.d_middle = self.d_att * self.num_head + + if use_linear: + self.linear_QK = nn.Linear(d_qk, self.d_middle) + half_d_vu = self.hidden_dim * num_head // 2 + self.linear_V1 = nn.Linear(d_vu // 2, half_d_vu) + self.linear_V2 = nn.Linear(d_vu // 2, half_d_vu) + self.linear_U1 = nn.Linear(d_vu // 2, half_d_vu) + self.linear_U2 = nn.Linear(d_vu // 2, half_d_vu) + + self.dropout = nn.Dropout(dropout) + self.drop_prob = dropout + + self.dw_conv = DWConv2d(self.expand_d_vu) + self.projection = nn.Linear(self.expand_d_vu, d_vu) + + self._init_weight() + + def forward(self, Q, K, V, U, size_2d): + """ + :param Q: A 3d tensor with shape of [T_q, bs, C_q] + :param K: A 3d tensor with shape of [T_k, bs, C_k] + :param V: A 3d tensor with shape of [T_v, bs, C_v] + """ + num_head = self.num_head + hidden_dim = self.hidden_dim + + l, bs, _ = Q.size() + + # Linear projections + if self.use_linear: + Q = K = self.linear_QK(Q) + + def cat(X1, X2): + if num_head > 1: + X1 = X1.view(-1, bs, num_head, hidden_dim // 2) + X2 = X2.view(-1, bs, num_head, hidden_dim // 2) + X = torch.cat([X1, X2], + dim=-1).view(-1, bs, num_head * hidden_dim) + else: + X = torch.cat([X1, X2], dim=-1) + return X + + V1, V2 = torch.split(V, self.d_vu // 2, dim=-1) + V1 = self.linear_V1(V1) + V2 = self.linear_V2(V2) + V = silu(cat(V1, V2)) + + U1, U2 = torch.split(U, self.d_vu // 2, dim=-1) + U1 = self.linear_U1(U1) + U2 = self.linear_U2(U2) + U = silu(cat(U1, U2)) + + # Scale + Q = Q / self.T + + if not self.training and self.max_mem_len_ratio > 0: + mem_len_ratio = 
float(K.size(0)) / Q.size(0) + if mem_len_ratio > self.max_mem_len_ratio: + scaling_ratio = math.log(mem_len_ratio) / math.log( + self.max_mem_len_ratio) + Q = Q * scaling_ratio + + # Multi-head + Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3) + K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0) + V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3) + + # Multiplication + QK = multiply_by_ychunks(Q, K, self.qk_chunks) + if self.use_dis: + QK = 2 * QK - K.pow(2).sum(dim=-2, keepdim=True) + + # Activation + if not self.training and self.top_k > 0 and self.top_k < QK.size()[-1]: + top_QK, indices = torch.topk(QK, k=self.top_k, dim=-1) + top_attn = linear_gate(top_QK, dim=-1) + attn = torch.zeros_like(QK).scatter_(-1, indices, top_attn) + else: + attn = linear_gate(QK, dim=-1) + + # Dropouts + attn = self.dropout(attn) + + # Weighted sum + outputs = multiply_by_xchunks(attn, V, + self.qk_chunks).permute(2, 0, 1, 3) + + # Restore shape + outputs = outputs.reshape(l, bs, -1) * U + + outputs = self.dw_conv(outputs, size_2d) + outputs = self.projection(outputs) + + return outputs, attn + + def _init_weight(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class LocalGatedPropagation(nn.Module): + def __init__(self, + d_qk, + d_vu, + num_head, + dropout=0., + max_dis=7, + dilation=1, + use_linear=True, + enable_corr=True, + d_att=None, + use_dis=False, + expand_ratio=2.): + super().__init__() + expand_ratio = expand_ratio + self.expand_d_vu = int(d_vu * expand_ratio) + self.d_qk = d_qk + self.d_vu = d_vu + self.dilation = dilation + self.window_size = 2 * max_dis + 1 + self.max_dis = max_dis + self.num_head = num_head + self.hidden_dim = self.expand_d_vu // num_head + self.d_att = d_qk // num_head if d_att is None else d_att + self.T = self.d_att**0.5 + self.use_dis = use_dis + + self.d_middle = self.d_att * self.num_head + self.use_linear = use_linear + if use_linear: + self.linear_QK = nn.Conv2d(d_qk, self.d_middle, kernel_size=1) + self.linear_V = nn.Conv2d(d_vu, + self.expand_d_vu, + kernel_size=1, + groups=2) + self.linear_U = nn.Conv2d(d_vu, + self.expand_d_vu, + kernel_size=1, + groups=2) + + self.relative_emb_k = nn.Conv2d(self.d_middle, + num_head * self.window_size * + self.window_size, + kernel_size=1, + groups=num_head) + + self.enable_corr = enable_corr + + if enable_corr: + from spatial_correlation_sampler import SpatialCorrelationSampler + self.correlation_sampler = SpatialCorrelationSampler( + kernel_size=1, + patch_size=self.window_size, + stride=1, + padding=0, + dilation=1, + dilation_patch=self.dilation) + + self.dw_conv = DWConv2d(self.expand_d_vu) + self.projection = nn.Linear(self.expand_d_vu, d_vu) + + self.dropout = nn.Dropout(dropout) + + self.drop_prob = dropout + + self.local_mask = None + self.last_size_2d = None + self.qk_mask = None + + def forward(self, q, k, v, u, size_2d): + n, c, h, w = v.size() + hidden_dim = self.hidden_dim + + if self.use_linear: + q = k = self.linear_QK(q) + v = silu(self.linear_V(v)) + # u = silu(self.linear_U(u)) + if self.num_head > 1: + v = v.view(-1, 2, self.num_head, hidden_dim // 2, + h * w).permute(0, 2, 1, 3, 4).reshape(n, -1, h, w) + # u = u.view(-1, 2, self.num_head, hidden_dim // 2, + # h * w).permute(4, 0, 2, 1, 3).reshape(h * w, n, -1) + # else: + # u = u.permute(2, 3, 0, 1).reshape(h * w, n, -1) + + if self.qk_mask is not None and (h, w) == self.last_size_2d: + qk_mask = self.qk_mask + else: + memory_mask = torch.ones((1, 1, h, w), device=v.device).float() + 
unfolded_k_mask = self.pad_and_unfold(memory_mask).view( + 1, 1, self.window_size * self.window_size, h * w) + qk_mask = 1 - unfolded_k_mask + self.qk_mask = qk_mask + + relative_emb = self.relative_emb_k(q) + + # Scale + q = q / self.T + + # print(q.shape) + # print(self.d_att, h, w) + # pdb.set_trace() + q = q.view(-1, self.d_att, h, w) + k = k.view(-1, self.d_att, h, w).contiguous() + v = v.view(-1, self.num_head, hidden_dim, h * w) + # print([n,c,h,w], q.shape, k.shape, v.shape);assert 0 + # [4, 1024, 24, 24] torch.Size([4, 64, 24, 24]) torch.Size([4, 64, 24, 24]) torch.Size([2, 1, 2048, 576]) + + + relative_emb = relative_emb.view(n, self.num_head, + self.window_size * self.window_size, + h * w) + + if self.enable_corr: + qk = self.correlation_sampler(q, k).view( + n, self.num_head, self.window_size * self.window_size, h * w) + else: + unfolded_k = self.pad_and_unfold(k).view( + n * self.num_head, hidden_dim, + self.window_size * self.window_size, h, w) + qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view( + n, self.num_head, self.window_size * self.window_size, h * w) + if self.use_dis: + qk = 2 * qk - self.pad_and_unfold( + k.pow(2).sum(dim=1, keepdim=True)).view( + n, self.num_head, self.window_size * self.window_size, + h * w) + + qk = qk + relative_emb + + qk -= qk_mask * 1e+8 if qk.dtype == torch.float32 else qk_mask * 1e+4 + + local_attn = linear_gate(qk, dim=2) + + local_attn = self.dropout(local_attn) + + global_attn = self.local2global(local_attn, h, w) + + agg_value = (global_attn @ v.transpose(-2, -1)).permute( + 2, 0, 1, 3).reshape(h * w, n, -1) + + # output = agg_value * u + output = agg_value + + output = self.dw_conv(output, size_2d) + output = self.projection(output) + + self.last_size_2d = (h, w) + return output, local_attn + + def local2global(self, local_attn, height, width): + batch_size = local_attn.size()[0] + + pad_height = height + 2 * self.max_dis + pad_width = width + 2 * self.max_dis + + if self.local_mask is not None and (height, + width) == self.last_size_2d: + local_mask = self.local_mask + else: + ky, kx = torch.meshgrid([ + torch.arange(0, pad_height, device=local_attn.device), + torch.arange(0, pad_width, device=local_attn.device) + ]) + qy, qx = torch.meshgrid([ + torch.arange(0, height, device=local_attn.device), + torch.arange(0, width, device=local_attn.device) + ]) + + offset_y = qy.reshape(-1, 1) - ky.reshape(1, -1) + self.max_dis + offset_x = qx.reshape(-1, 1) - kx.reshape(1, -1) + self.max_dis + + local_mask = (offset_y.abs() <= self.max_dis) & (offset_x.abs() <= + self.max_dis) + local_mask = local_mask.view(1, 1, height * width, pad_height, + pad_width) + self.local_mask = local_mask + + global_attn = torch.zeros( + (batch_size, self.num_head, height * width, pad_height, pad_width), + device=local_attn.device) + # global_attn = global_attn.type(torch.HalfTensor).cuda() + global_attn[local_mask.expand(batch_size, self.num_head, + -1, -1, -1)] = local_attn.transpose( + -1, -2).reshape(-1) + global_attn = global_attn[:, :, :, self.max_dis:-self.max_dis, + self.max_dis:-self.max_dis].reshape( + batch_size, self.num_head, + height * width, height * width) + + return global_attn + + def pad_and_unfold(self, x): + pad_pixel = self.max_dis * self.dilation + x = F.pad(x, (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=0) + x = F.unfold(x, + kernel_size=(self.window_size, self.window_size), + stride=(1, 1), + dilation=self.dilation) + return x diff --git a/model/basic.py b/model/basic.py new file mode 100644 index 
0000000000000000000000000000000000000000..d09b2efa0801bd3ead0f7baf9388aa21a7ea79fe --- /dev/null +++ b/model/basic.py @@ -0,0 +1,205 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class ResBlock(nn.Module): + def __init__(self, indim, outdim=None, stride=1): + super(ResBlock, self).__init__() + if outdim == None: + outdim = indim + if indim == outdim and stride == 1: + self.downsample = None + else: + self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride) + self.gn3 = nn.GroupNorm(8, outdim) + + self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1) + + self.gn1 = nn.GroupNorm(8, outdim) + self.gn2 = nn.GroupNorm(8, outdim) + + def forward(self, x): + r = self.conv1(F.relu(x)) + r = self.gn1(r) + + r = self.conv2(F.relu(r)) + r = self.gn2(r) + + if self.downsample is not None: + x = self.gn3(self.downsample(x)) + + return x + r + +class ResGN(nn.Module): + def __init__(self, indim, outdim): + super().__init__() + self.res1 = ResBlock(indim, outdim) + self.res2 = ResBlock(outdim, outdim) + def forward(self, x): + return self.res2(self.res1(x)) + +class GroupNorm1D(nn.Module): + def __init__(self, indim, groups=8): + super().__init__() + self.gn = nn.GroupNorm(groups, indim) + + def forward(self, x): + return self.gn(x.permute(1, 2, 0)).permute(2, 0, 1) + + +class GNActDWConv2d(nn.Module): + def __init__(self, indim, gn_groups=32): + super().__init__() + self.gn = nn.GroupNorm(gn_groups, indim) + self.conv = nn.Conv2d(indim, + indim, + 5, + dilation=1, + padding=2, + groups=indim, + bias=False) + + def forward(self, x, size_2d): + h, w = size_2d + _, bs, c = x.size() + x = x.view(h, w, bs, c).permute(2, 3, 0, 1) + x = self.gn(x) + x = F.gelu(x) + x = self.conv(x) + x = x.view(bs, c, h * w).permute(2, 0, 1) + return x + + +class DWConv2d(nn.Module): + def __init__(self, indim, dropout=0.1): + super().__init__() + self.conv = nn.Conv2d(indim, + indim, + 5, + dilation=1, + padding=2, + groups=indim, + bias=False) + self.dropout = nn.Dropout2d(p=dropout, inplace=True) + + def forward(self, x, size_2d): + h, w = size_2d + _, bs, c = x.size() + x = x.view(h, w, bs, c).permute(2, 3, 0, 1) + x = self.conv(x) + x = self.dropout(x) + x = x.view(bs, c, h * w).permute(2, 0, 1) + return x + + +class ScaleOffset(nn.Module): + def __init__(self, indim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(indim)) + # torch.nn.init.normal_(self.gamma, std=0.02) + self.beta = nn.Parameter(torch.zeros(indim)) + + def forward(self, x): + if len(x.size()) == 3: + return x * self.gamma + self.beta + else: + return x * self.gamma.view(1, -1, 1, 1) + self.beta.view( + 1, -1, 1, 1) + + +class ConvGN(nn.Module): + def __init__(self, indim, outdim, kernel_size, gn_groups=8): + super().__init__() + self.conv = nn.Conv2d(indim, + outdim, + kernel_size, + padding=kernel_size // 2) + self.gn = nn.GroupNorm(gn_groups, outdim) + + def forward(self, x): + return self.gn(self.conv(x)) + + +def seq_to_2d(tensor, size_2d): + h, w = size_2d + _, n, c = tensor.size() + tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous() + return tensor + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = ( + x.shape[0], + x.shape[1], + ) + (1, ) * (x.ndim - 2 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +def mask_out(x, y, mask_rate=0.15, training=False): + if mask_rate == 0. or not training: + return x + + keep_prob = 1 - mask_rate + shape = ( + x.shape[0], + x.shape[1], + ) + (1, ) * (x.ndim - 2 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x * random_tensor + y * (1 - random_tensor) + + return output + + +class DropPath(nn.Module): + def __init__(self, drop_prob=None, batch_dim=0): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.batch_dim = batch_dim + + def forward(self, x): + return self.drop_path(x, self.drop_prob) + + def drop_path(self, x, drop_prob): + if drop_prob == 0. or not self.training: + return x + keep_prob = 1 - drop_prob + shape = [1 for _ in range(x.ndim)] + shape[self.batch_dim] = x.shape[self.batch_dim] + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropOutLogit(nn.Module): + def __init__(self, drop_prob=None): + super(DropOutLogit, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return self.drop_logit(x, self.drop_prob) + + def drop_logit(self, x, drop_prob): + if drop_prob == 0. or not self.training: + return x + random_tensor = drop_prob + torch.rand( + x.shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + mask = random_tensor * 1e+8 if ( + x.dtype == torch.float32) else random_tensor * 1e+4 + output = x - mask + return output diff --git a/model/cbam.py b/model/cbam.py new file mode 100644 index 0000000000000000000000000000000000000000..6423358429e2843b1f36ceb2bc1a485ea72b8eb4 --- /dev/null +++ b/model/cbam.py @@ -0,0 +1,77 @@ +# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + + def forward(self, x): + x = self.conv(x) + return x + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type=='avg': + avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( avg_pool ) + elif pool_type=='max': + max_pool = F.max_pool2d( x, 
(x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+                channel_att_raw = self.mlp( max_pool )
+
+            if channel_att_sum is None:
+                channel_att_sum = channel_att_raw
+            else:
+                channel_att_sum = channel_att_sum + channel_att_raw
+
+        scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
+        return x * scale
+
+class ChannelPool(nn.Module):
+    def forward(self, x):
+        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
+
+class SpatialGate(nn.Module):
+    def __init__(self):
+        super(SpatialGate, self).__init__()
+        kernel_size = 7
+        self.compress = ChannelPool()
+        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
+    def forward(self, x):
+        x_compress = self.compress(x)
+        x_out = self.spatial(x_compress)
+        scale = torch.sigmoid(x_out) # broadcasting
+        return x * scale
+
+class CBAM(nn.Module):
+    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
+        super(CBAM, self).__init__()
+        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
+        self.no_spatial=no_spatial
+        if not no_spatial:
+            self.SpatialGate = SpatialGate()
+    def forward(self, x):
+        x_out = self.ChannelGate(x)
+        if not self.no_spatial:
+            x_out = self.SpatialGate(x_out)
+        return x_out
diff --git a/model/group_modules.py b/model/group_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd22bedb429aaed6adf37e5455e0732af7e52dab
--- /dev/null
+++ b/model/group_modules.py
@@ -0,0 +1,93 @@
+"""
+Group-specific modules
+They handle features that also depend on the mask.
+Features are typically of shape
+    batch_size * num_objects * num_channels * H * W
+
+All of them are permutation-equivariant w.r.t. the num_objects dimension
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def interpolate_groups(g, ratio, mode, align_corners):
+    if len(g.shape) == 4:
+        g = F.interpolate(g, scale_factor=ratio, mode=mode, align_corners=align_corners)
+    elif len(g.shape) == 5:
+        batch_size, num_objects = g.shape[:2]
+        g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
+                          scale_factor=ratio, mode=mode, align_corners=align_corners)
+        g = g.view(batch_size, num_objects, *g.shape[1:])
+    return g
+
+def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False):
+    return interpolate_groups(g, ratio, mode, align_corners)
+
+def downsample_groups(g, ratio=1/2, mode='area', align_corners=None):
+    return interpolate_groups(g, ratio, mode, align_corners)
+
+
+class GConv2D(nn.Conv2d):
+    def forward(self, g):
+        batch_size, num_objects = g.shape[:2]
+        g = super().forward(g.flatten(start_dim=0, end_dim=1))
+        return g.view(batch_size, num_objects, *g.shape[1:])
+
+
+class GroupResBlock(nn.Module):
+    def __init__(self, in_dim, out_dim):
+        super().__init__()
+
+        if in_dim == out_dim:
+            self.downsample = None
+        else:
+            self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
+
+        self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
+        self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)
+
+    def forward(self, g):
+        out_g = self.conv1(F.relu(g))
+        out_g = self.conv2(F.relu(out_g))
+
+        if self.downsample is not None:
+            g = self.downsample(g)
+
+        return out_g + g
+
+
+class MainToGroupDistributor(nn.Module):
+    def __init__(self, x_transform=None, method='cat', reverse_order=False):
+        super().__init__()
+
+        self.x_transform = x_transform
+        self.method = method
+        self.reverse_order = reverse_order
+
+    def forward(self, x, g):
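+        # x: image-level features, B x C x H x W
+        # g: per-object ("group") features, B x num_objects x C' x H x W
+        # 'cat' broadcasts x over the object dimension and concatenates along
+        # channels (e.g. x [8, 3, 224, 224] with g [8, 2, 2, 224, 224] gives
+        # [8, 2, 5, 224, 224]); 'add' broadcast-adds x to g instead.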
+        num_objects = g.shape[1]
+
+        if self.x_transform is not None:
+            x = self.x_transform(x)
+
+        if self.method == 'cat':
+            if self.reverse_order:
+                g = torch.cat([g, x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1)], 2)
+            else:
+                g = torch.cat([x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1), g], 2)
+        elif self.method == 'add':
+            g = x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1) + g
+        else:
+            raise NotImplementedError
+
+        return g
diff --git a/model/losses.py b/model/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5d901cb95f12b90699b2a5d7c081e18eab7a1a
--- /dev/null
+++ b/model/losses.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from collections import defaultdict
+
+
+def dice_loss(input_mask, cls_gt):
+    num_objects = input_mask.shape[1]
+    losses = []
+    for i in range(num_objects):
+        mask = input_mask[:,i].flatten(start_dim=1)
+        # background not in mask, so we add one to cls_gt
+        gt = (cls_gt==(i+1)).float().flatten(start_dim=1)
+        numerator = 2 * (mask * gt).sum(-1)
+        denominator = mask.sum(-1) + gt.sum(-1)
+        loss = 1 - (numerator + 1) / (denominator + 1)
+        losses.append(loss)
+    return torch.cat(losses).mean()
+
+def l1_loss(input, target):
+    out = torch.abs(input - target)
+    return out.mean()
+
+
+# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
+class BootstrappedCE(nn.Module):
+    def __init__(self, start_warm, end_warm, top_p=0.15):
+        super().__init__()
+
+        self.start_warm = start_warm
+        self.end_warm = end_warm
+        self.top_p = top_p
+
+    def forward(self, input, target, it):
+        if it < self.start_warm:
+            return F.cross_entropy(input, target), 1.0
+
+        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
+        num_pixels = raw_loss.numel()
+
+        if it > self.end_warm:
+            this_p = self.top_p
+        else:
+            this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
+        loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
+        return loss.mean(), this_p
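+
+# A minimal usage sketch of BootstrappedCE (the warm-up values and tensor
+# shapes below are illustrative assumptions, not the training defaults):
+#   criterion = BootstrappedCE(start_warm=20000, end_warm=70000)
+#   logits = torch.randn(2, 4, 64, 64)           # B x num_classes x H x W
+#   target = torch.randint(0, 4, (2, 64, 64))    # B x H x W
+#   loss, p = criterion(logits, target, it=50000)
+# Before start_warm this is plain cross-entropy (p = 1.0); between start_warm
+# and end_warm, p anneals from 1.0 down to top_p, so eventually only the
+# hardest top-p fraction of per-pixel losses is averaged.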
+
+
+class LossComputer:
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.bce = BootstrappedCE(config['start_warm'], config['end_warm'])
+
+    def compute(self, data, num_objects, it):
+        losses = defaultdict(int)
+
+        b, t = data['rgb'].shape[:2]
+
+        losses['total_loss'] = 0
+        for ti in range(1, t):
+            for bi in range(b):
+                loss, p = self.bce(data[f'logits_{ti}'][bi:bi+1, :num_objects[bi]+1], data['cls_gt'][bi:bi+1,ti,0], it)
+                losses['p'] += p / b / (t-1)
+                losses[f'ce_loss_{ti}'] += loss / b
+
+            losses['total_loss'] += losses['ce_loss_%d'%ti]
+            losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti,0])
+            losses['total_loss'] += losses[f'dice_loss_{ti}']
+
+        return losses
+
+    def compute_l1loss(self, data, num_objects, it):
+        # L1 regression variant (used for colorization); the CE terms are kept
+        # as zero placeholders so logging stays compatible with compute()
+        losses = defaultdict(int)
+
+        b, t = data['rgb'].shape[:2]
+
+        losses['total_loss'] = 0
+        for ti in range(1, t):
+            losses['p'] = 0
+            losses[f'ce_loss_{ti}'] = 0
+
+            losses['total_loss'] += losses['ce_loss_%d'%ti]
+            losses[f'dice_loss_{ti}'] = l1_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti])
+            losses['total_loss'] += losses[f'dice_loss_{ti}']
+
+        return losses
diff --git a/model/memory_util.py b/model/memory_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..faf6197b8c4ea990317476e2e3aeb8952a78aedf
--- /dev/null
+++ b/model/memory_util.py
@@ -0,0 +1,80 @@
+import math
+import numpy as np
+import torch
+from typing import Optional
+
+
+def get_similarity(mk, ms, qk, qe):
+    # used for training/inference and memory reading/memory potentiation
+    # mk: B x CK x [N]    - Memory keys
+    # ms: B x 1  x [N]    - Memory shrinkage
+    # qk: B x CK x [HW/P] - Query keys
+    # qe: B x CK x [HW/P] - Query selection
+    # Dimensions in [] are flattened
+    CK = mk.shape[1]
+    mk = mk.flatten(start_dim=2)
+    ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
+    qk = qk.flatten(start_dim=2)
+    qe = qe.flatten(start_dim=2) if qe is not None else None
+
+    if qe is not None:
+        # See appendix for derivation
+        # or you can just trust me ヽ(ー_ー )ノ
+        mk = mk.transpose(1, 2)
+        a_sq = (mk.pow(2) @ qe)
+        two_ab = 2 * (mk @ (qk * qe))
+        b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
+        similarity = (-a_sq+two_ab-b_sq)
+    else:
+        # similar to STCN if we don't have the selection term
+        a_sq = mk.pow(2).sum(1).unsqueeze(2)
+        two_ab = 2 * (mk.transpose(1, 2) @ qk)
+        similarity = (-a_sq+two_ab)
+
+    if ms is not None:
+        similarity = similarity * ms / math.sqrt(CK)   # B*N*HW
+    else:
+        similarity = similarity / math.sqrt(CK)   # B*N*HW
+
+    return similarity
+
+def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False):
+    # normalize similarity with top-k softmax
+    # similarity: B x N x [HW/P]
+    # use inplace with care
+    if top_k is not None:
+        values, indices = torch.topk(similarity, k=top_k, dim=1)
+
+        x_exp = values.exp_()
+        x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
+        if inplace:
+            similarity.zero_().scatter_(1, indices, x_exp)   # B*N*HW
+            affinity = similarity
+        else:
+            affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp)   # B*N*HW
+    else:
+        maxes = torch.max(similarity, dim=1, keepdim=True)[0]
+        x_exp = torch.exp(similarity - maxes)
+        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
+        affinity = x_exp / x_exp_sum
+        indices = None
+
+    if return_usage:
+        return affinity, affinity.sum(dim=2)
+
+    return affinity
+
+def get_affinity(mk, ms, qk, qe):
+    # shorthand used in training with no top-k
+    similarity = get_similarity(mk, ms, qk, qe)
+    affinity = do_softmax(similarity)
+    return affinity
+
+def readout(affinity, mv):
+    B, CV, T, H, W = mv.shape
+
+    mo = mv.view(B, CV, T*H*W)
+    mem = torch.bmm(mo, affinity)
+    mem = mem.view(B, CV, H, W)
+
+    return mem
diff --git a/model/modules.py b/model/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea6986e89e05c8e701753800891b79ff327ff73
--- /dev/null
+++ b/model/modules.py
@@ -0,0 +1,271 @@
+"""
+modules.py - This file stores the rather boring network blocks.
+
+x - usually means features that only depend on the image
+g - usually means features that also depend on the mask.
+ They might have an extra "group" or "num_objects" dimension, hence + batch_size * num_objects * num_channels * H * W + +The trailing number of a variable usually denote the stride + +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.group_modules import * +from model import resnet +from model.cbam import CBAM + + +class FeatureFusionBlock(nn.Module): + def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim): + super().__init__() + + self.distributor = MainToGroupDistributor() + self.block1 = GroupResBlock(x_in_dim+g_in_dim, g_mid_dim) + self.attention = CBAM(g_mid_dim) + self.block2 = GroupResBlock(g_mid_dim, g_out_dim) + + def forward(self, x, g): + batch_size, num_objects = g.shape[:2] + + g = self.distributor(x, g) + g = self.block1(g) + r = self.attention(g.flatten(start_dim=0, end_dim=1)) + r = r.view(batch_size, num_objects, *r.shape[1:]) + + g = self.block2(g+r) + + return g + + +class HiddenUpdater(nn.Module): + # Used in the decoder, multi-scale feature + GRU + def __init__(self, g_dims, mid_dim, hidden_dim, ratio=1/2): + super().__init__() + self.hidden_dim = hidden_dim + self.ratio = ratio + + self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1) + self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1) + self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1) + + self.transform = GConv2D(mid_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g, h): + g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ + self.g4_conv(downsample_groups(g[2], ratio=1/4)) + # g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], self.ratio)) + \ + # self.g4_conv(downsample_groups(g[2], ratio=self.ratio)) + + g = torch.cat([g, h], 2) + + # defined slightly differently than standard GRU, + # namely the new value is generated before the forget gate. + # might provide better gradient but frankly it was initially just an + # implementation error that I never bothered fixing + values = self.transform(g) + forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim]) + update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2]) + new_value = torch.tanh(values[:,:,self.hidden_dim*2:]) + new_h = forget_gate*h*(1-update_gate) + update_gate*new_value + + return new_h + + +class HiddenReinforcer(nn.Module): + # Used in the value encoder, a single GRU + def __init__(self, g_dim, hidden_dim): + super().__init__() + self.hidden_dim = hidden_dim + self.transform = GConv2D(g_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g, h): + g = torch.cat([g, h], 2) + + # defined slightly differently than standard GRU, + # namely the new value is generated before the forget gate. 
+ # might provide better gradient but frankly it was initially just an + # implementation error that I never bothered fixing + values = self.transform(g) + forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim]) + update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2]) + new_value = torch.tanh(values[:,:,self.hidden_dim*2:]) + new_h = forget_gate*h*(1-update_gate) + update_gate*new_value + + return new_h + + +class ValueEncoder(nn.Module): + def __init__(self, value_dim, hidden_dim, single_object=False): + super().__init__() + + self.single_object = single_object + network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2) + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu # 1/2, 64 + self.maxpool = network.maxpool + + self.layer1 = network.layer1 # 1/4, 64 + self.layer2 = network.layer2 # 1/8, 128 + self.layer3 = network.layer3 # 1/16, 256 + + self.distributor = MainToGroupDistributor() + self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim) # (1024 256) -> (384 256) -> (384 256) + if hidden_dim > 0: + self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim) + else: + self.hidden_reinforce = None + + def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True): + # image_feat_f16 is the feature from the key encoder + if not self.single_object: + g = torch.stack([masks, others], 2) + else: + g = masks.unsqueeze(2) + g = self.distributor(image, g) + + batch_size, num_objects = g.shape[:2] + g = g.flatten(start_dim=0, end_dim=1) + + g = self.conv1(g) + g = self.bn1(g) # 1/2, 64 + g = self.maxpool(g) # 1/4, 64 + g = self.relu(g) + + g = self.layer1(g) # 1/4 + g = self.layer2(g) # 1/8 + g = self.layer3(g) # 1/16 + + # handle dim problem raised by vit + g = F.interpolate(g, image_feat_f16.shape[2:], mode='bilinear', align_corners=False) + + g = g.view(batch_size, num_objects, *g.shape[1:]) + g = self.fuser(image_feat_f16, g) + + if is_deep_update and self.hidden_reinforce is not None: + h = self.hidden_reinforce(g, h) + + return g, h + +class KeyEncoder_DINOv2_v6(nn.Module): + def __init__(self): + super().__init__() + network = resnet.resnet50(pretrained=True) + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu # 1/2, 64 + self.maxpool = network.maxpool + + self.res2 = network.layer1 # 1/4, 256 + self.layer2 = network.layer2 # 1/8, 512 + self.layer3 = network.layer3 # 1/16, 1024 + + self.network2 = resnet.Segmentor() + + self.fuse1 = resnet.Fuse(384 * 4, 1024) # n = [8, 9, 10, 11] + self.fuse2 = resnet.Fuse(384 * 4, 512) + self.fuse3 = resnet.Fuse(384 * 4, 256) + + self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear') + self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear') + + def forward(self, f): + x = self.conv1(f) + x = self.bn1(x) + x = self.relu(x) # 1/2, 64 + x = self.maxpool(x) # 1/4, 64 + f4 = self.res2(x) # 1/4, 256 + f8 = self.layer2(f4) # 1/8, 512 + f16 = self.layer3(f8) # 1/16, 1024 + + f16_dino = self.network2(f) # 1/14, 384 -> interp to 1/16 + + g16 = self.fuse1(f16_dino, f16) + g8 = self.fuse2(self.upsample2(f16_dino), f8) + g4 = self.fuse3(self.upsample4(f16_dino), f4) + + return g16, g8, g4 + +class UpsampleBlock(nn.Module): + def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2): + super().__init__() + self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1) + self.distributor = MainToGroupDistributor(method='add') + self.out_conv = GroupResBlock(g_up_dim, g_out_dim) + self.scale_factor = 
scale_factor + + def forward(self, skip_f, up_g): + skip_f = self.skip_conv(skip_f) + g = upsample_groups(up_g, ratio=self.scale_factor) + g = self.distributor(skip_f, g) + g = self.out_conv(g) + return g + + +class KeyProjection(nn.Module): + def __init__(self, in_dim, keydim): + super().__init__() + + self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) + # shrinkage + self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1) + # selection + self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) + + nn.init.orthogonal_(self.key_proj.weight.data) + nn.init.zeros_(self.key_proj.bias.data) + + def forward(self, x, need_s, need_e): + shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None + selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None + + return self.key_proj(x), shrinkage, selection + + +class Decoder(nn.Module): + def __init__(self, val_dim, hidden_dim): + super().__init__() + + self.fuser = FeatureFusionBlock(1024, val_dim+hidden_dim, 512, 512) # (1024, val_dim+hidden_dim, 512, 512) -> (384, val_dim+hidden_dim, 512, 512) -> (1536, val_dim+hidden_dim, 512, 512) + if hidden_dim > 0: + self.hidden_update = HiddenUpdater([512, 256, 256+1], 256, hidden_dim, 1) + else: + self.hidden_update = None + + self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8 + self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4 + + self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1) + + def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True): + batch_size, num_objects = memory_readout.shape[:2] + + if self.hidden_update is not None: + g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2)) + else: + g16 = self.fuser(f16, memory_readout) + + g8 = self.up_16_8(f8, g16) + + g4 = self.up_8_4(f4, g8) + + logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1))) + + if h_out and self.hidden_update is not None: + g4 = torch.cat([g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2) + hidden_state = self.hidden_update([g16, g8, g4], hidden_state) + else: + hidden_state = None + + logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False) + logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) + + return hidden_state, logits diff --git a/model/network.py b/model/network.py new file mode 100644 index 0000000000000000000000000000000000000000..8994c43681ba624b82a445af8614cbb642adbcae --- /dev/null +++ b/model/network.py @@ -0,0 +1,225 @@ +""" +This file defines XMem, the highest level nn.Module interface +During training, it is used by trainer.py +During evaluation, it is used by inference_core.py + +It further depends on modules.py which gives more detailed implementations of sub-modules +""" + +import torch +import torch.nn as nn + +from model.aggregate import aggregate +from model.modules import * +from model.memory_util import * + +from model.attention import LocalGatedPropagation + +class ColorMNet(nn.Module): + def __init__(self, config, model_path=None, map_location=None): + """ + model_path/map_location are used in evaluation only + map_location is for converting models saved in cuda to cpu + """ + super().__init__() + model_weights = self.init_hyperparameters(config, model_path, map_location) + + self.single_object = config.get('single_object', False) + print(f'Single object mode: {self.single_object}') + + self.key_encoder = KeyEncoder_DINOv2_v6() + + self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object) + + # Projection from f16 
feature space to key/value space + self.key_proj = KeyProjection(1024, self.key_dim) # 1024 -> 384 -> 3072 + + self.short_term_attn = LocalGatedPropagation(d_qk=64, # 256 + d_vu=512 * 2, + num_head=1, + dilation=1, + use_linear=False, + dropout=0, + d_att=64, # 128 + max_dis=7, + expand_ratio=1) + + self.decoder = Decoder(self.value_dim, self.hidden_dim) + + if model_weights is not None: + self.load_weights(model_weights, init_as_zero_if_needed=True) + + def encode_key(self, frame, need_sk=True, need_ek=True): + # Determine input shape + if len(frame.shape) == 5: + # shape is b*t*c*h*w + need_reshape = True + b, t = frame.shape[:2] + # flatten so that we can feed them into a 2D CNN + frame = frame.flatten(start_dim=0, end_dim=1) + elif len(frame.shape) == 4: + # shape is b*c*h*w + need_reshape = False + else: + raise NotImplementedError + + f16, f8, f4 = self.key_encoder(frame) + key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek) + + if need_reshape: + # B*C*T*H*W + key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous() + if shrinkage is not None: + shrinkage = shrinkage.view(b, t, *shrinkage.shape[-3:]).transpose(1, 2).contiguous() + if selection is not None: + selection = selection.view(b, t, *selection.shape[-3:]).transpose(1, 2).contiguous() + + # B*T*C*H*W + f16 = f16.view(b, t, *f16.shape[-3:]) + f8 = f8.view(b, t, *f8.shape[-3:]) + f4 = f4.view(b, t, *f4.shape[-3:]) + + return key, shrinkage, selection, f16, f8, f4 + + def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True): + num_objects = masks.shape[1] + if num_objects != 1: + others = torch.cat([ + torch.sum( + masks[:, [j for j in range(num_objects) if i!=j]] + , dim=1, keepdim=True) + for i in range(num_objects)], 1) + else: + others = torch.zeros_like(masks) + + g16, h16 = self.value_encoder(frame, image_feat_f16, h16, masks, others, is_deep_update) + + return g16, h16 + + # Used in training only. 
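+    # (reads memory via a softmax-normalized L2 affinity between query and
+    #  memory keys, then returns a per-object weighted sum of memory values;
+    #  see model/memory_util.py)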
+ # This step is replaced by MemoryManager in test time + def read_memory(self, query_key, query_selection, memory_key, + memory_shrinkage, memory_value): + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + memory_value : B * num_objects * CV * T * H * W + """ + batch_size, num_objects = memory_value.shape[:2] + memory_value = memory_value.flatten(start_dim=1, end_dim=2) + + affinity = get_affinity(memory_key, memory_shrinkage, query_key, query_selection) + memory = readout(affinity, memory_value) + memory = memory.view(batch_size, num_objects, self.value_dim, *memory.shape[-2:]) + + return memory + + def read_memory_short(self, query_key, memory_key, memory_value): + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + memory_value : B * num_objects * CV * T * H * W + """ + batch_size, num_objects = memory_value.shape[:2] + memory_value = memory_value.flatten(start_dim=1, end_dim=2) + + size_2d = query_key.shape[-2:] + memory_value_short, _ = self.short_term_attn(query_key, memory_key, memory_value, None, size_2d) + + memory_value_short = memory_value_short.permute(1, 2, 0).view(batch_size, num_objects, self.value_dim, *memory_value.shape[-2:]) + + return memory_value_short + + def segment(self, multi_scale_features, memory_readout, + hidden_state, selector=None, h_out=True, strip_bg=True): + + hidden_state, logits = self.decoder(*multi_scale_features, hidden_state, memory_readout, h_out=h_out) + + prob = torch.tanh(logits) + logits = prob + + return hidden_state, logits, prob + + def forward(self, mode, *args, **kwargs): + if mode == 'encode_key': + return self.encode_key(*args, **kwargs) + elif mode == 'encode_value': + return self.encode_value(*args, **kwargs) + elif mode == 'read_memory': + return self.read_memory(*args, **kwargs) + elif mode == 'read_memory_short': + return self.read_memory_short(*args, **kwargs) + elif mode == 'segment': + return self.segment(*args, **kwargs) + else: + raise NotImplementedError + + def init_hyperparameters(self, config, model_path=None, map_location=None): + """ + Init three hyperparameters: key_dim, value_dim, and hidden_dim + If model_path is provided, we load these from the model weights + The actual parameters are then updated to the config in-place + + Otherwise we load it either from the config or default + """ + if model_path is not None: + # load the model and key/value/hidden dimensions with some hacks + # config is updated with the loaded parameters + model_weights = torch.load(model_path, map_location=map_location) + self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0] + self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0] + self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights + if self.disable_hidden: + self.hidden_dim = 0 + else: + self.hidden_dim = model_weights['decoder.hidden_update.transform.weight'].shape[0]//3 + print(f'Hyperparameters read from the model weights: ' + f'C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}') + else: + model_weights = None + # load dimensions from config or default + if 'key_dim' not in config: + self.key_dim = 64 + print(f'key_dim not found in config. Set to default {self.key_dim}') + else: + self.key_dim = config['key_dim'] + + if 'value_dim' not in config: + self.value_dim = 512 + print(f'value_dim not found in config. 
Set to default {self.value_dim}') + else: + self.value_dim = config['value_dim'] + + if 'hidden_dim' not in config: + self.hidden_dim = 64 + print(f'hidden_dim not found in config. Set to default {self.hidden_dim}') + else: + self.hidden_dim = config['hidden_dim'] + + self.disable_hidden = (self.hidden_dim <= 0) + + config['key_dim'] = self.key_dim + config['value_dim'] = self.value_dim + config['hidden_dim'] = self.hidden_dim + + return model_weights + + def load_weights(self, src_dict, init_as_zero_if_needed=False): + # Maps SO weight (without other_mask) to MO weight (with other_mask) + for k in list(src_dict.keys()): + if k == 'value_encoder.conv1.weight': + if src_dict[k].shape[1] == 4: + print('Converting weights from single object to multiple objects.') + pads = torch.zeros((64,1,7,7), device=src_dict[k].device) + if not init_as_zero_if_needed: + print('Randomly initialized padding.') + nn.init.orthogonal_(pads) + else: + print('Zero-initialized padding.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + + self.load_state_dict(src_dict) diff --git a/model/resnet.py b/model/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ea3acb36a9879fffe55c3a78b5ce4d23704c5f --- /dev/null +++ b/model/resnet.py @@ -0,0 +1,399 @@ +""" +resnet.py - A modified ResNet structure +We append extra channels to the first conv by some network surgery +""" + +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +from torch.utils import model_zoo + +from torch.hub import load +import torchvision.models as models +import warnings +warnings.filterwarnings("ignore") +import torch.nn.functional as F + +from einops import rearrange + +def load_weights_add_extra_dim(target, source_state, extra_dim=1): + new_dict = OrderedDict() + + for k1, v1 in target.state_dict().items(): + if not 'num_batches_tracked' in k1: + if k1 in source_state: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + pads = torch.zeros((c,extra_dim,w,h), device=tar_v.device) + nn.init.orthogonal_(pads) + tar_v = torch.cat([tar_v, pads], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + 
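+        # standard ResNet bottleneck: 1x1 reduce -> 3x3 (stride/dilation) -> 1x1 expand (4x)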
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, + padding=dilation, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + +def resnet18(pretrained=True, extra_dim=0): + model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim) + return model + +def resnet50(pretrained=True, extra_dim=0): + model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim) + return model + +dino_backbones = { + 'dinov2_s':{ + 'name':'dinov2_vits14', + 'embedding_size':384, + 'patch_size':14 + }, + 'dinov2_b':{ + 'name':'dinov2_vitb14', + 'embedding_size':768, + 'patch_size':14 + }, + 'dinov2_l':{ + 'name':'dinov2_vitl14', + 'embedding_size':1024, + 'patch_size':14 + }, + 'dinov2_g':{ + 'name':'dinov2_vitg14', + 'embedding_size':1536, + 'patch_size':14 + }, +} + +class conv_head(nn.Module): + def __init__(self, embedding_size = 384, num_classes = 5): + super(conv_head, self).__init__() + self.segmentation_conv = nn.Sequential( + nn.Upsample(scale_factor=2), + nn.Conv2d(embedding_size, 64, (3,3), padding=(1,1)), + nn.Upsample(scale_factor=2), + nn.Conv2d(64, num_classes, (3,3), padding=(1,1)), + ) + + def forward(self, x): + x = self.segmentation_conv(x) + x = torch.sigmoid(x) + return x + +class Segmentor(nn.Module): + 
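+    """
+    Frozen DINOv2 feature extractor. Concatenates the last four intermediate
+    ViT-S/14 layers (4 x 384 = 1536 channels), refines them with a 1x1 conv +
+    BN + ReLU, and resamples from stride 14 to stride 16 to match the CNN branch.
+    """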
def __init__(self, num_classes=5, backbone = 'dinov2_s', head = 'conv', backbones = dino_backbones): + super(Segmentor, self).__init__() + self.heads = { + 'conv':conv_head + } + # internet + self.backbones = dino_backbones + self.backbone = load('facebookresearch/dinov2', self.backbones[backbone]['name']) # add trust_repo to + self.backbone.eval() + + # # local + # self.backbones = dino_backbones + # self.backbone = load('/root/.cache/torch/hub/facebookresearch_dinov2_main', self.backbones[backbone]['name'], source='local', pretrained=False) # add trust_repo to + # self.backbone.load_state_dict(torch.load('/root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth')) + # self.backbone.eval() + + self.conv3 = nn.Conv2d(1536, 1536, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(1536) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + with torch.no_grad(): + tokens = self.backbone.get_intermediate_layers(x, n=[8, 9, 10, 11], reshape=True) # last n=4 [8, 9, 10, 11] + + f16 = torch.cat(tokens, dim=1) + + f16 = self.conv3(f16) + f16 = self.bn3(f16) + f16 = self.relu(f16) + + old_size = (f16.shape[2], f16.shape[3]) + new_size = (int(old_size[0]*14/16), int(old_size[1]*14/16)) + f16 = F.interpolate(f16, size=new_size, mode='bilinear', align_corners=False) # scale_factor=3.5 + + return f16 + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1. 
/ torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( + dim=0), None + +class LayerNorm2d(nn.Module): + + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + +class CrossChannelAttention(nn.Module): + def __init__(self, dim, heads=8): + super().__init__() + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.heads = heads + + self.to_q = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=True) + self.to_q_dw = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=True) + + self.to_k = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=True) + self.to_k_dw = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=True) + + self.to_v = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=True) + self.to_v_dw = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=True) + + self.to_out = nn.Sequential( + nn.Conv2d(dim*2, dim,1,1,0), + ) + + def forward(self, encoder, decoder): + # h = self.heads + b, c, h, w = encoder.shape + + q = self.to_q_dw(self.to_q(encoder)) + + k = self.to_k_dw(self.to_k(decoder)) + v = self.to_v_dw(self.to_v(decoder)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.heads, h=h, w=w) + + return self.to_out(out) + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + +@torch.jit.script +def swish(x): + return x * torch.sigmoid(x) + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + + x = swish(x) + # x = x * torch.sigmoid(x) + + x = self.conv1(x) + x = self.norm2(x) + + x = swish(x) + # x = x * torch.sigmoid(x) + + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + +class Fuse(nn.Module): + def __init__(self, dine_feat, out_feat): + # need to key same channel and HW for enc / dnc + super(Fuse, self).__init__() + + self.encode_enc = nn.Conv2d(dine_feat, out_feat, kernel_size=3, stride=1, padding=1) + + self.dim = out_feat + self.norm1 = LayerNorm2d(self.dim) + self.norm2 = LayerNorm2d(self.dim) + + self.dine_feat = dine_feat + 
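+        # enc: projected DINOv2 features; dnc: ResNet features at the same
+        # resolution; fused in forward() by cross-channel attention plus a
+        # residual connection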
self.out_feat = out_feat + self.crossattn = CrossChannelAttention(dim=out_feat) + + self.norm3 = LayerNorm2d(self.dim) + self.relu3 = nn.ReLU(inplace=True) + + def forward(self, enc, dnc): + enc = self.encode_enc(enc) + + res = enc + enc = self.norm1(enc) + dnc = self.norm2(dnc) + output = self.crossattn(enc, dnc) + res + + output = self.norm3(output) + output = self.relu3(output) + + return output \ No newline at end of file diff --git a/ref/blackswan/00000.png b/ref/blackswan/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..3daeddee9b13c2d98cae65bcd584e67a7a5aa4ad Binary files /dev/null and b/ref/blackswan/00000.png differ diff --git a/saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth b/saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth new file mode 100644 index 0000000000000000000000000000000000000000..265a8b64426936a298bc12e216aaed3a0393c101 --- /dev/null +++ b/saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eaf6301d1a088c0d7133008079a83b5fac1fc0f791061b2cf1b657602013457a +size 494884817 diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/download_bl30k.py b/scripts/download_bl30k.py new file mode 100644 index 0000000000000000000000000000000000000000..501a7e2ffa4b7913df7fe8f3c865d17e5b1fa371 --- /dev/null +++ b/scripts/download_bl30k.py @@ -0,0 +1,50 @@ +import os +import gdown +import tarfile + + +LICENSE = """ +This dataset is a derivative of ShapeNet. +Please read and respect their licenses and terms before use. +Textures and skybox image are obtained from Google image search with the "non-commercial reuse" flag. +Do not use this dataset for commercial purposes. +You should cite both ShapeNet and our paper if you use this dataset. +""" + +print(LICENSE) +print('Datasets will be downloaded and extracted to ../BL30K') +print('The script will download and extract the segment one by one') +print('You are going to need ~1TB of free disk space') +reply = input('[y] to confirm, others to exit: ') +if reply != 'y': + exit() + +links = [ + 'https://drive.google.com/uc?id=1z9V5zxLOJLNt1Uj7RFqaP2FZWKzyXvVc', + 'https://drive.google.com/uc?id=11-IzgNwEAPxgagb67FSrBdzZR7OKAEdJ', + 'https://drive.google.com/uc?id=1ZfIv6GTo-OGpXpoKen1fUvDQ0A_WoQ-Q', + 'https://drive.google.com/uc?id=1G4eXgYS2kL7_Cc0x3N1g1x7Zl8D_aU_-', + 'https://drive.google.com/uc?id=1Y8q0V_oBwJIY27W_6-8CD1dRqV2gNTdE', + 'https://drive.google.com/uc?id=1nawBAazf_unMv46qGBHhWcQ4JXZ5883r', +] + +names = [ + 'BL30K_a.tar', + 'BL30K_b.tar', + 'BL30K_c.tar', + 'BL30K_d.tar', + 'BL30K_e.tar', + 'BL30K_f.tar', +] + +for i, link in enumerate(links): + print('Downloading segment %d/%d ...' 
% (i, len(links))) + gdown.download(link, output='../%s' % names[i], quiet=False) + print('Extracting...') + with tarfile.open('../%s' % names[i], 'r') as tar_file: + tar_file.extractall('../%s' % names[i]) + print('Cleaning up...') + os.remove('../%s' % names[i]) + + +print('Done.') \ No newline at end of file diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..7537aead493306d869fc7f0605cfaef994c68673 --- /dev/null +++ b/scripts/download_datasets.py @@ -0,0 +1,149 @@ +import os +import gdown +import zipfile +from scripts import resize_youtube + + +LICENSE = """ +These are either re-distribution of the original datasets or derivatives (through simple processing) of the original datasets. +Please read and respect their licenses and terms before use. +You should cite the original papers if you use any of the datasets. + +For BL30K, see download_bl30k.py + +Links: +DUTS: http://saliencydetection.net/duts +HRSOD: https://github.com/yi94code/HRSOD +FSS: https://github.com/HKUSTCV/FSS-1000 +ECSSD: https://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html +BIG: https://github.com/hkchengrex/CascadePSP + +YouTubeVOS: https://youtube-vos.org +DAVIS: https://davischallenge.org/ +BL30K: https://github.com/hkchengrex/MiVOS +Long-Time Video: https://github.com/xmlyqing00/AFB-URR +""" + +print(LICENSE) +print('Datasets will be downloaded and extracted to ../YouTube, ../YouTube2018, ../static, ../DAVIS, ../long_video_set') +reply = input('[y] to confirm, others to exit: ') +if reply != 'y': + exit() + + +""" +Static image data +""" +os.makedirs('../static', exist_ok=True) +print('Downloading static datasets...') +gdown.download('https://drive.google.com/uc?id=1wUJq3HcLdN-z1t4CsUhjeZ9BVDb9YKLd', output='../static/static_data.zip', quiet=False) +print('Extracting static datasets...') +with zipfile.ZipFile('../static/static_data.zip', 'r') as zip_file: + zip_file.extractall('../static/') +print('Cleaning up static datasets...') +os.remove('../static/static_data.zip') + + +""" +DAVIS dataset +""" +# Google drive mirror: https://drive.google.com/drive/folders/1hEczGHw7qcMScbCJukZsoOW4Q9byx16A?usp=sharing +os.makedirs('../DAVIS/2017', exist_ok=True) + +print('Downloading DAVIS 2016...') +gdown.download('https://drive.google.com/uc?id=198aRlh5CpAoFz0hfRgYbiNenn_K8DxWD', output='../DAVIS/DAVIS-data.zip', quiet=False) + +print('Downloading DAVIS 2017 trainval...') +gdown.download('https://drive.google.com/uc?id=1kiaxrX_4GuW6NmiVuKGSGVoKGWjOdp6d', output='../DAVIS/2017/DAVIS-2017-trainval-480p.zip', quiet=False) + +print('Downloading DAVIS 2017 testdev...') +gdown.download('https://drive.google.com/uc?id=1fmkxU2v9cQwyb62Tj1xFDdh2p4kDsUzD', output='../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', quiet=False) + +print('Downloading DAVIS 2017 scribbles...') +gdown.download('https://drive.google.com/uc?id=1JzIQSu36h7dVM8q0VoE4oZJwBXvrZlkl', output='../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', quiet=False) + +print('Extracting DAVIS datasets...') +with zipfile.ZipFile('../DAVIS/DAVIS-data.zip', 'r') as zip_file: + zip_file.extractall('../DAVIS/') +os.rename('../DAVIS/DAVIS', '../DAVIS/2016') + +with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-trainval-480p.zip', 'r') as zip_file: + zip_file.extractall('../DAVIS/2017/') +with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', 'r') as zip_file: + zip_file.extractall('../DAVIS/2017/') +os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/trainval') + +with 
zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', 'r') as zip_file: + zip_file.extractall('../DAVIS/2017/') +os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/test-dev') + +print('Cleaning up DAVIS datasets...') +os.remove('../DAVIS/2017/DAVIS-2017-trainval-480p.zip') +os.remove('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip') +os.remove('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip') +os.remove('../DAVIS/DAVIS-data.zip') + + +""" +YouTubeVOS dataset +""" +os.makedirs('../YouTube', exist_ok=True) +os.makedirs('../YouTube/all_frames', exist_ok=True) + +print('Downloading YouTubeVOS train...') +gdown.download('https://drive.google.com/uc?id=13Eqw0gVK-AO5B-cqvJ203mZ2vzWck9s4', output='../YouTube/train.zip', quiet=False) +print('Downloading YouTubeVOS val...') +gdown.download('https://drive.google.com/uc?id=1o586Wjya-f2ohxYf9C1RlRH-gkrzGS8t', output='../YouTube/valid.zip', quiet=False) +print('Downloading YouTubeVOS all frames valid...') +gdown.download('https://drive.google.com/uc?id=1rWQzZcMskgpEQOZdJPJ7eTmLCBEIIpEN', output='../YouTube/all_frames/valid.zip', quiet=False) + +print('Extracting YouTube datasets...') +with zipfile.ZipFile('../YouTube/train.zip', 'r') as zip_file: + zip_file.extractall('../YouTube/') +with zipfile.ZipFile('../YouTube/valid.zip', 'r') as zip_file: + zip_file.extractall('../YouTube/') +with zipfile.ZipFile('../YouTube/all_frames/valid.zip', 'r') as zip_file: + zip_file.extractall('../YouTube/all_frames') + +print('Cleaning up YouTubeVOS datasets...') +os.remove('../YouTube/train.zip') +os.remove('../YouTube/valid.zip') +os.remove('../YouTube/all_frames/valid.zip') + +print('Resizing YouTubeVOS to 480p...') +resize_youtube.resize_all('../YouTube/train', '../YouTube/train_480p') + +# YouTubeVOS 2018 +os.makedirs('../YouTube2018', exist_ok=True) +os.makedirs('../YouTube2018/all_frames', exist_ok=True) + +print('Downloading YouTubeVOS2018 val...') +gdown.download('https://drive.google.com/uc?id=1-QrceIl5sUNTKz7Iq0UsWC6NLZq7girr', output='../YouTube2018/valid.zip', quiet=False) +print('Downloading YouTubeVOS2018 all frames valid...') +gdown.download('https://drive.google.com/uc?id=1yVoHM6zgdcL348cFpolFcEl4IC1gorbV', output='../YouTube2018/all_frames/valid.zip', quiet=False) + +print('Extracting YouTube2018 datasets...') +with zipfile.ZipFile('../YouTube2018/valid.zip', 'r') as zip_file: + zip_file.extractall('../YouTube2018/') +with zipfile.ZipFile('../YouTube2018/all_frames/valid.zip', 'r') as zip_file: + zip_file.extractall('../YouTube2018/all_frames') + +print('Cleaning up YouTubeVOS2018 datasets...') +os.remove('../YouTube2018/valid.zip') +os.remove('../YouTube2018/all_frames/valid.zip') + + +""" +Long-Time Video dataset +""" +os.makedirs('../long_video_set', exist_ok=True) +print('Downloading long video dataset...') +gdown.download('https://drive.google.com/uc?id=100MxAuV0_UL20ca5c-5CNpqQ5QYPDSoz', output='../long_video_set/LongTimeVideo.zip', quiet=False) +print('Extracting long video dataset...') +with zipfile.ZipFile('../long_video_set/LongTimeVideo.zip', 'r') as zip_file: + zip_file.extractall('../long_video_set/') +print('Cleaning up long video dataset...') +os.remove('../long_video_set/LongTimeVideo.zip') + + +print('Done.') \ No newline at end of file diff --git a/scripts/download_models.sh b/scripts/download_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..ba669b54cec4ef7344326fc8b3e340501ca1bc90 --- /dev/null +++ b/scripts/download_models.sh @@ -0,0 +1,2 @@ +wget -P ./saves/ 
https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth +wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth \ No newline at end of file diff --git a/scripts/download_models_demo.sh b/scripts/download_models_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..a63f700c6d3daf5e97dad00dfcacb8e029bbce12 --- /dev/null +++ b/scripts/download_models_demo.sh @@ -0,0 +1,3 @@ +wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth +wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/fbrs.pth +wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/s2m.pth \ No newline at end of file diff --git a/scripts/expand_long_vid.py b/scripts/expand_long_vid.py new file mode 100644 index 0000000000000000000000000000000000000000..ae237bc39d04f520b706769ac74139631cc88012 --- /dev/null +++ b/scripts/expand_long_vid.py @@ -0,0 +1,36 @@ +import sys +import os +from os import path +from shutil import copy2 + +input_path = sys.argv[1] +output_path = sys.argv[2] +multiplier = int(sys.argv[3]) +image_path = path.join(input_path, 'JPEGImages') +gt_path = path.join(input_path, 'Annotations') + +videos = sorted(os.listdir(image_path)) + +for vid in videos: + os.makedirs(path.join(output_path, 'JPEGImages', vid), exist_ok=True) + os.makedirs(path.join(output_path, 'Annotations', vid), exist_ok=True) + frames = sorted(os.listdir(path.join(image_path, vid))) + + num_frames = len(frames) + counter = 0 + output_counter = 0 + direction = 1 + for _ in range(multiplier): + for _ in range(num_frames): + copy2(path.join(image_path, vid, frames[counter]), + path.join(output_path, 'JPEGImages', vid, f'{output_counter:05d}.jpg')) + + mask_path = path.join(gt_path, vid, frames[counter].replace('.jpg', '.png')) + if path.exists(mask_path): + copy2(mask_path, + path.join(output_path, 'Annotations', vid, f'{output_counter:05d}.png')) + + counter += direction + output_counter += 1 + if counter == 0 or counter == len(frames) - 1: + direction *= -1 diff --git a/scripts/resize_youtube.py b/scripts/resize_youtube.py new file mode 100644 index 0000000000000000000000000000000000000000..501e9f54cde1a44ff11f2d7e640dd7f0748ce78e --- /dev/null +++ b/scripts/resize_youtube.py @@ -0,0 +1,77 @@ +import sys +import os +from os import path + +from PIL import Image +import numpy as np +from progressbar import progressbar +from multiprocessing import Pool + +new_min_size = 480 + +def resize_vid_jpeg(inputs): + vid_name, folder_path, out_path = inputs + + vid_path = path.join(folder_path, vid_name) + vid_out_path = path.join(out_path, 'JPEGImages', vid_name) + os.makedirs(vid_out_path, exist_ok=True) + + for im_name in os.listdir(vid_path): + hr_im = Image.open(path.join(vid_path, im_name)) + w, h = hr_im.size + + ratio = new_min_size / min(w, h) + + lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.BICUBIC) + lr_im.save(path.join(vid_out_path, im_name)) + +def resize_vid_anno(inputs): + vid_name, folder_path, out_path = inputs + + vid_path = path.join(folder_path, vid_name) + vid_out_path = path.join(out_path, 'Annotations', vid_name) + os.makedirs(vid_out_path, exist_ok=True) + + for im_name in os.listdir(vid_path): + hr_im = Image.open(path.join(vid_path, im_name)).convert('P') + w, h = hr_im.size + + ratio = new_min_size / min(w, h) + + lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.NEAREST) + lr_im.save(path.join(vid_out_path, im_name)) + + +def resize_all(in_path, out_path): + for folder in 
os.listdir(in_path): + + if folder not in ['JPEGImages', 'Annotations']: + continue + folder_path = path.join(in_path, folder) + videos = os.listdir(folder_path) + + videos = [(v, folder_path, out_path) for v in videos] + + if folder == 'JPEGImages': + print('Processing images') + os.makedirs(path.join(out_path, 'JPEGImages'), exist_ok=True) + + pool = Pool(processes=8) + for _ in progressbar(pool.imap_unordered(resize_vid_jpeg, videos), max_value=len(videos)): + pass + else: + print('Processing annotations') + os.makedirs(path.join(out_path, 'Annotations'), exist_ok=True) + + pool = Pool(processes=8) + for _ in progressbar(pool.imap_unordered(resize_vid_anno, videos), max_value=len(videos)): + pass + + +if __name__ == '__main__': + in_path = sys.argv[1] + out_path = sys.argv[2] + + resize_all(in_path, out_path) + + print('Done.') \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..adfad1999984c9292daa78fb99bd1e865afff348 --- /dev/null +++ b/test.py @@ -0,0 +1,230 @@ +import os +from os import path +from argparse import ArgumentParser +import shutil + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from PIL import Image + +from inference.data.test_datasets import DAVISTestDataset_221128_TransColorization_batch +from inference.data.mask_mapper import MaskMapper +from model.network import ColorMNet +from inference.inference_core import InferenceCore + +from progressbar import progressbar + +from dataset.range_transform import inv_im_trans, inv_lll2rgb_trans + +from skimage import color, io +import cv2 + +try: + import hickle as hkl +except ImportError: + print('Failed to import hickle. Fine if not using multi-scale testing.') + + +""" +Arguments loading +""" +parser = ArgumentParser() +parser.add_argument('--model', default='saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth') + +# dataset setting +parser.add_argument('--d16_batch_path', default='input') +parser.add_argument('--deoldify_path', default='ref') +parser.add_argument('--output', default='result') + +# For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations" +parser.add_argument('--generic_path') +parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D16_batch') +parser.add_argument('--split', help='val/test', default='val') +parser.add_argument('--save_all', action='store_true', + help='Save all frames. Useful only in YouTubeVOS/long-time video', ) +parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking') + +# Long-term memory options +parser.add_argument('--disable_long_term', action='store_true') +parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10) +parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5) +parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time', + type=int, default=10000) +parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128) + +parser.add_argument('--top_k', type=int, default=30) +parser.add_argument('--mem_every', help='r in paper. 
Increase to improve running speed.', type=int, default=5) +parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1) + +# Multi-scale options +parser.add_argument('--save_scores', action='store_true') +parser.add_argument('--flip', action='store_true') +parser.add_argument('--size', default=-1, type=int, + help='Resize the shorter side to this size. -1 to use original resolution. ') + +args = parser.parse_args() +config = vars(args) +config['enable_long_term'] = not config['disable_long_term'] + +def detach_to_cpu(x): + return x.detach().cpu() + +def tensor_to_np_float(image): + image_np = image.numpy().astype('float32') + return image_np + +def lab2rgb_transform_PIL(mask): + mask_d = detach_to_cpu(mask) + mask_d = inv_lll2rgb_trans(mask_d) + im = tensor_to_np_float(mask_d) + + if len(im.shape) == 3: + im = im.transpose((1, 2, 0)) + else: + im = im[:, :, None] + + im = color.lab2rgb(im) + + return im.clip(0, 1) + +if args.output is None: + args.output = f'.output/{args.dataset}_{args.split}' + print(f'Output path not provided. Defaulting to {args.output}') + +""" +Data preparation +""" +is_youtube = args.dataset.startswith('Y') +is_davis = args.dataset.startswith('D') +is_lv = args.dataset.startswith('LV') + +if is_youtube or args.save_scores: + out_path = path.join(args.output, 'Annotations') +else: + out_path = args.output + +if args.split == 'val': + # Set up Dataset, a small hack to use the image set in the 2017 folder because the 2016 one is of a different format + meta_dataset = DAVISTestDataset_221128_TransColorization_batch(args.d16_batch_path, imset=args.deoldify_path, size=args.size) +else: + raise NotImplementedError +palette = None + +torch.autograd.set_grad_enabled(False) + +# Set up loader +meta_loader = meta_dataset.get_datasets() + +# Load our checkpoint +network = ColorMNet(config, args.model).cuda().eval() +if args.model is not None: + model_weights = torch.load(args.model) + network.load_weights(model_weights, init_as_zero_if_needed=True) +else: + print('No model loaded.') + +total_process_time = 0 +total_frames = 0 + +# Start eval +for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True): + + loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2) + vid_name = vid_reader.vid_name + vid_length = len(loader) + # no need to count usage for LT if the video is not that long anyway + config['enable_long_term_count_usage'] = ( + config['enable_long_term'] and + (vid_length + / (config['max_mid_term_frames']-config['min_mid_term_frames']) + * config['num_prototypes']) + >= config['max_long_term_elements'] + ) + + mapper = MaskMapper() + processor = InferenceCore(network, config=config) + first_mask_loaded = False + + for ti, data in enumerate(loader): + with torch.cuda.amp.autocast(enabled=not args.benchmark): + rgb = data['rgb'].cuda()[0] + msk = data.get('mask') + info = data['info'] + frame = info['frame'][0] + shape = info['shape'] + need_resize = info['need_resize'][0] + + """ + For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964 + Seems to be very similar in testing as my previous timing method + with two cuda sync + time.time() in STCN though + """ + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + + if not first_mask_loaded: + if msk is not None: + first_mask_loaded = True + else: + # no point to do anything without a mask + continue + + if args.flip: + rgb = 
+
+    mapper = MaskMapper()
+    processor = InferenceCore(network, config=config)
+    first_mask_loaded = False
+
+    for ti, data in enumerate(loader):
+        with torch.cuda.amp.autocast(enabled=not args.benchmark):
+            rgb = data['rgb'].cuda()[0]
+            msk = data.get('mask')
+            info = data['info']
+            frame = info['frame'][0]
+            shape = info['shape']
+            need_resize = info['need_resize'][0]
+
+            """
+            For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
+            In testing it is very similar to my previous timing method
+            (two cuda syncs + time.time()) in STCN
+            """
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+
+            if not first_mask_loaded:
+                if msk is not None:
+                    first_mask_loaded = True
+                else:
+                    # no point in doing anything without a mask
+                    continue
+
+            if args.flip:
+                rgb = torch.flip(rgb, dims=[-1])
+                msk = torch.flip(msk, dims=[-1]) if msk is not None else None
+
+            # Map possibly non-continuous labels to continuous ones
+            if msk is not None:
+                msk = torch.Tensor(msk[0]).cuda()
+                if need_resize:
+                    msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
+                processor.set_all_labels(list(range(1,3)))
+                labels = range(1,3)
+            else:
+                labels = None
+
+            # Run the model on this frame
+            prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1))
+
+            # Upsample to original size if needed
+            if need_resize:
+                prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
+
+            end.record()
+            torch.cuda.synchronize()
+            total_process_time += (start.elapsed_time(end)/1000)
+            total_frames += 1
+
+            if args.flip:
+                prob = torch.flip(prob, dims=[-1])
+
+            if args.save_scores:
+                prob = (prob.detach().cpu().numpy()*255).astype(np.uint8)
+
+            # Save the mask
+            if args.save_all or info['save'][0]:
+                this_out_path = path.join(out_path, vid_name)
+                os.makedirs(this_out_path, exist_ok=True)
+
+                out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1,:,:], prob], dim=0))
+                out_mask_final = out_mask_final * 255
+                out_mask_final = out_mask_final.astype(np.uint8)
+
+                out_img = Image.fromarray(out_mask_final)
+                out_img.save(os.path.join(this_out_path, frame[:-4]+'.png'))
+
+print(f'Total processing time: {total_process_time}')
+print(f'Total processed frames: {total_frames}')
+print(f'FPS: {total_frames / total_process_time}')
+print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
+
+if not args.save_scores:
+    if is_youtube:
+        print('Making zip for YouTubeVOS...')
+        shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
+    elif is_davis and args.split == 'test':
+        print('Making zip for DAVIS test-dev...')
+        shutil.make_archive(args.output, 'zip', args.output)
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/util/configuration.py b/util/configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7f7c8345d285100a84c9a72f84976454f0e4106
--- /dev/null
+++ b/util/configuration.py
@@ -0,0 +1,143 @@
+from argparse import ArgumentParser
+
+
+def none_or_default(x, default):
+    return x if x is not None else default
+
+class Configuration():
+    def parse(self, unknown_arg_ok=False):
+        parser = ArgumentParser()
+
+        # Enable torch.backends.cudnn.benchmark -- faster in some cases; test in your own environment
+        parser.add_argument('--benchmark', action='store_true')
+        parser.add_argument('--no_amp', action='store_true')
+
+        # Data parameters
+        parser.add_argument('--static_root', help='Static training data root', default='../static')
+        parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K')
+        parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../YouTube')
+        parser.add_argument('--davis_root', help='DAVIS data root', default='/data2/yangyixin/data/DAVIS')
+        parser.add_argument('--num_workers', help='Total number of dataloader workers across all GPU processes', type=int, default=16)
+
+        parser.add_argument('--key_dim', default=64, type=int)
+        parser.add_argument('--value_dim', default=512, type=int)
+        parser.add_argument('--hidden_dim', default=64, help='Set to =0 to disable', type=int)
+
+        parser.add_argument('--deep_update_prob', default=0.2, type=float)
+
+        parser.add_argument('--stages', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS)', default='02')
+
+        parser.add_argument('--server', help='Server to train on (0-pjs-04, 1-A6000Local)', default='0')
+        parser.add_argument('--savepath', help='Root path for saving logs and checkpoints', default='/data1/yangyixin/code/xmem')
+
+        """
+        Stage-specific learning parameters
+        Batch sizes are effective -- you don't have to scale them when you scale the number of processes
+        """
+        # Stage 0, static images
+        parser.add_argument('--s0_batch_size', default=16, type=int)
+        parser.add_argument('--s0_iterations', default=150000, type=int)
+        parser.add_argument('--s0_finetune', default=0, type=int)
+        parser.add_argument('--s0_steps', nargs="*", default=[], type=int)
+        parser.add_argument('--s0_lr', help='Initial learning rate', default=1e-5, type=float)
+        parser.add_argument('--s0_num_ref_frames', default=2, type=int)
+        parser.add_argument('--s0_num_frames', default=3, type=int)
+        parser.add_argument('--s0_start_warm', default=20000, type=int)
+        parser.add_argument('--s0_end_warm', default=70000, type=int)
+
+        # Stage 1, BL30K
+        parser.add_argument('--s1_batch_size', default=8, type=int)
+        parser.add_argument('--s1_iterations', default=250000, type=int)
+        # fine-tune means fewer augmentations to train the sensory memory
+        parser.add_argument('--s1_finetune', default=0, type=int)
+        parser.add_argument('--s1_steps', nargs="*", default=[200000], type=int)
+        parser.add_argument('--s1_lr', help='Initial learning rate', default=1e-5, type=float)
+        parser.add_argument('--s1_num_ref_frames', default=3, type=int)
+        parser.add_argument('--s1_num_frames', default=8, type=int)
+        parser.add_argument('--s1_start_warm', default=20000, type=int)
+        parser.add_argument('--s1_end_warm', default=70000, type=int)
+
+        # Stage 2, DAVIS+YoutubeVOS, longer
+        parser.add_argument('--s2_batch_size', default=2, type=int)
+        parser.add_argument('--s2_iterations', default=150000, type=int)
+        # fine-tune means fewer augmentations to train the sensory memory
+        parser.add_argument('--s2_finetune', default=10000, type=int)
+        parser.add_argument('--s2_steps', nargs="*", default=[120000], type=int)
+
+        # parser.add_argument('--s2_lr', help='Initial learning rate', default=1e-5, type=float)
+        parser.add_argument('--s2_lr', help='Initial learning rate', default=2e-5, type=float)
+
+        parser.add_argument('--s2_num_ref_frames', default=3, type=int)
+        parser.add_argument('--s2_num_frames', default=8, type=int)
+        parser.add_argument('--s2_start_warm', default=20000, type=int)
+        parser.add_argument('--s2_end_warm', default=70000, type=int)
+
+        # Stage 3, DAVIS+YoutubeVOS, shorter
+        parser.add_argument('--s3_batch_size', default=8, type=int)
+        parser.add_argument('--s3_iterations', default=100000, type=int)
+        # fine-tune means fewer augmentations to train the sensory memory
+        parser.add_argument('--s3_finetune', default=10000, type=int)
+        parser.add_argument('--s3_steps', nargs="*", default=[80000], type=int)
+        parser.add_argument('--s3_lr', help='Initial learning rate', default=1e-5, type=float)
+        parser.add_argument('--s3_num_ref_frames', default=3, type=int)
+        parser.add_argument('--s3_num_frames', default=8, type=int)
+        parser.add_argument('--s3_start_warm', default=20000, type=int)
+        parser.add_argument('--s3_end_warm', default=70000, type=int)
+
+        parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float)
+        parser.add_argument('--weight_decay', default=0.05, type=float)
+
+        # Loading
+        parser.add_argument('--load_network', help='Path to pretrained network weights only')
+        parser.add_argument('--load_checkpoint', help='Path to the checkpoint file, including network, optimizer and such')
+
+        # Logging information
+        parser.add_argument('--log_text_interval', default=100, type=int)
+
+        parser.add_argument('--log_image_interval', default=100, type=int)
+
+        parser.add_argument('--save_network_interval', default=2500, type=int)
+        parser.add_argument('--save_checkpoint_interval', default=999999999999999999999, type=int)
+        parser.add_argument('--exp_id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL')
+        parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true')
+
+        # # Multiprocessing parameters, not set by users
+        # parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process')
+
+        if unknown_arg_ok:
+            args, _ = parser.parse_known_args()
+            self.args = vars(args)
+        else:
+            self.args = vars(parser.parse_args())
+
+        self.args['amp'] = not self.args['no_amp']
+
+        # check if the stages are valid
+        stage_to_perform = list(self.args['stages'])
+        for s in stage_to_perform:
+            if s not in ['0', '1', '2', '3']:
+                raise NotImplementedError
+
+    def get_stage_parameters(self, stage):
+        parameters = {
+            'batch_size': self.args['s%s_batch_size'%stage],
+            'iterations': self.args['s%s_iterations'%stage],
+            'finetune': self.args['s%s_finetune'%stage],
+            'steps': self.args['s%s_steps'%stage],
+            'lr': self.args['s%s_lr'%stage],
+            'num_ref_frames': self.args['s%s_num_ref_frames'%stage],
+            'num_frames': self.args['s%s_num_frames'%stage],
+            'start_warm': self.args['s%s_start_warm'%stage],
+            'end_warm': self.args['s%s_end_warm'%stage],
+        }
+
+        return parameters
+
+    def __getitem__(self, key):
+        return self.args[key]
+
+    def __setitem__(self, key, value):
+        self.args[key] = value
+
+    def __str__(self):
+        return str(self.args)
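+
+# Minimal usage sketch (added for illustration; running this module directly is
+# an assumed convenience, not an official entry point). Stage hyperparameters
+# are stored under 's<stage>_' keys, so get_stage_parameters takes the stage as
+# a string: '0', '1', '2' or '3'.
+if __name__ == '__main__':
+    config = Configuration()
+    config.parse(unknown_arg_ok=True)
+    # e.g. {'batch_size': 2, 'iterations': 150000, ..., 'lr': 2e-05, ...}
+    print(config.get_stage_parameters('2'))
+    print(config['num_workers'])  # __getitem__ forwards to the parsed args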
diff --git a/util/davis_subset.txt b/util/davis_subset.txt
new file mode 100644
index 0000000000000000000000000000000000000000..875c2409d2cc4cfc4491ebf7703cb432b26678d8
--- /dev/null
+++ b/util/davis_subset.txt
@@ -0,0 +1,60 @@
+bear
+bmx-bumps
+boat
+boxing-fisheye
+breakdance-flare
+bus
+car-turn
+cat-girl
+classic-car
+color-run
+crossing
+dance-jump
+dancing
+disc-jockey
+dog-agility
+dog-gooses
+dogs-scale
+drift-turn
+drone
+elephant
+flamingo
+hike
+hockey
+horsejump-low
+kid-football
+kite-walk
+koala
+lady-running
+lindy-hop
+longboard
+lucia
+mallard-fly
+mallard-water
+miami-surf
+motocross-bumps
+motorbike
+night-race
+paragliding
+planes-water
+rallye
+rhino
+rollerblade
+schoolgirls
+scooter-board
+scooter-gray
+sheep
+skate-park
+snowboard
+soccerball
+stroller
+stunt
+surf
+swing
+tennis
+tractor-sand
+train
+tuk-tuk
+upside-down
+varanus-cage
+walking
\ No newline at end of file
diff --git a/util/functional.py b/util/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b276eee81513b51baf2d5d9afd0a42366050dc2
--- /dev/null
+++ b/util/functional.py
@@ -0,0 +1,605 @@
+from __future__ import division
+
+import math
+import random
+
+import torch
+from PIL import Image, ImageEnhance, ImageOps
+
+try:
+    import accimage
+except ImportError:
+    accimage = None
+import collections
+import collections.abc
+import numbers
+import types
+import warnings
+
+import numpy as np
+
+
+def _is_pil_image(img):
+    if accimage is not None:
+        return isinstance(img, (Image.Image, accimage.Image))
+    else:
+        return isinstance(img, Image.Image)
+
+
+def _is_tensor_image(img):
+    return torch.is_tensor(img) and img.ndimension() == 3
+
+
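+# Quick reference for the type guards in this module (illustrative): a PIL
+# image such as Image.new("RGB", (4, 4)) passes _is_pil_image, a (C, H, W)
+# tensor such as torch.zeros(3, 4, 4) passes _is_tensor_image, and a NumPy
+# array of shape (H, W) or (H, W, C) passes _is_numpy_image below.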
+def _is_numpy_image(img):
+    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
+def to_tensor(pic):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+    See ``ToTensor`` for more details.
+
+    Args:
+        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+    Returns:
+        Tensor: Converted image.
+    """
+    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
+        raise TypeError("pic should be PIL Image or ndarray. Got {}".format(type(pic)))
+
+    if isinstance(pic, np.ndarray):
+        # handle numpy array
+        img = torch.from_numpy(pic.transpose((2, 0, 1)))
+        # backward compatibility
+        return img.float().div(255)
+
+    if accimage is not None and isinstance(pic, accimage.Image):
+        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
+        pic.copyto(nppic)
+        return torch.from_numpy(nppic)
+
+    # handle PIL Image
+    if pic.mode == "I":
+        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+    elif pic.mode == "I;16":
+        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+    else:
+        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+    if pic.mode == "YCbCr":
+        nchannel = 3
+    elif pic.mode == "I;16":
+        nchannel = 1
+    else:
+        nchannel = len(pic.mode)
+    img = img.view(pic.size[1], pic.size[0], nchannel)
+    # put it from HWC to CHW format
+    # yikes, this transpose takes 80% of the loading time/CPU
+    img = img.transpose(0, 1).transpose(0, 2).contiguous()
+    if isinstance(img, torch.ByteTensor):
+        return img.float().div(255)
+    else:
+        return img
+
+
+def to_mytensor(pic):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+    See ``ToTensor`` for more details.
+
+    Args:
+        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+    Returns:
+        Tensor: Converted image.
+    """
+    pic_arr = np.array(pic)
+    if pic_arr.ndim == 2:
+        pic_arr = pic_arr[..., np.newaxis]
+    img = torch.from_numpy(pic_arr.transpose((2, 0, 1)))
+    if not isinstance(img, torch.FloatTensor):
+        return img.float()  # no normalize .div(255)
+    else:
+        return img
+
+
+def to_pil_image(pic, mode=None):
+    """Convert a tensor or an ndarray to PIL Image.
+
+    See :class:`~torchvision.transforms.ToPILImage` for more details.
+
+    Args:
+        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
+        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
+
+    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
+
+    Returns:
+        PIL Image: Image converted to PIL Image.
+    """
+    if not (_is_numpy_image(pic) or _is_tensor_image(pic)):
+        raise TypeError("pic should be Tensor or ndarray. Got {}.".format(type(pic)))
+
+    npimg = pic
+    if isinstance(pic, torch.FloatTensor):
+        pic = pic.mul(255).byte()
+    if torch.is_tensor(pic):
+        npimg = np.transpose(pic.numpy(), (1, 2, 0))
+
+    if not isinstance(npimg, np.ndarray):
+        raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, "
+                        "not {}".format(type(npimg)))
+
+    if npimg.shape[2] == 1:
+        expected_mode = None
+        npimg = npimg[:, :, 0]
+        if npimg.dtype == np.uint8:
+            expected_mode = "L"
+        if npimg.dtype == np.int16:
+            expected_mode = "I;16"
+        if npimg.dtype == np.int32:
+            expected_mode = "I"
+        elif npimg.dtype == np.float32:
+            expected_mode = "F"
+        if mode is not None and mode != expected_mode:
+            raise ValueError(
+                "Incorrect mode ({}) supplied for input type {}. Should be {}".format(mode, npimg.dtype, expected_mode)
+            )
+        mode = expected_mode
+
+    elif npimg.shape[2] == 4:
+        permitted_4_channel_modes = ["RGBA", "CMYK"]
+        if mode is not None and mode not in permitted_4_channel_modes:
+            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
+
+        if mode is None and npimg.dtype == np.uint8:
+            mode = "RGBA"
+    else:
+        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
+        if mode is not None and mode not in permitted_3_channel_modes:
+            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
+        if mode is None and npimg.dtype == np.uint8:
+            mode = "RGB"
+
+    if mode is None:
+        raise TypeError("Input type {} is not supported".format(npimg.dtype))
+
+    return Image.fromarray(npimg, mode=mode)
+
+
+def normalize(tensor, mean, std):
+    """Normalize a tensor image with mean and standard deviation.
+
+    See ``Normalize`` for more details.
+
+    Args:
+        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+        mean (sequence): Sequence of means for each channel.
+        std (sequence): Sequence of standard deviations for each channel.
+
+    Returns:
+        Tensor: Normalized Tensor image.
+    """
+    if not _is_tensor_image(tensor):
+        raise TypeError("tensor is not a torch image.")
+    # TODO: make efficient
+    if tensor.size(0) == 1:
+        tensor.sub_(mean).div_(std)
+    else:
+        for t, m, s in zip(tensor, mean, std):
+            t.sub_(m).div_(s)
+    return tensor
+
+
+def resize(img, size, interpolation=Image.BILINEAR):
+    """Resize the input PIL Image to the given size.
+
+    Args:
+        img (PIL Image): Image to be resized.
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio, i.e., if height > width, then the image will be rescaled to
+            (size * height / width, size)
+        interpolation (int, optional): Desired interpolation. Default is
+            ``PIL.Image.BILINEAR``
+
+    Returns:
+        PIL Image: Resized image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+    if not isinstance(size, int) and (not isinstance(size, collections.abc.Iterable) or len(size) != 2):
+        raise TypeError("Got inappropriate size arg: {}".format(size))
+
+    if not isinstance(size, int):
+        return img.resize(size[::-1], interpolation)
+
+    w, h = img.size
+    if (w <= h and w == size) or (h <= w and h == size):
+        return img
+    if w < h:
+        ow = size
+        oh = int(round(size * h / w))
+    else:
+        oh = size
+        ow = int(round(size * w / h))
+    return img.resize((ow, oh), interpolation)
+
+
+def scale(*args, **kwargs):
+    warnings.warn("The use of the transforms.Scale transform is deprecated, "
+                  "please use transforms.Resize instead.")
+    return resize(*args, **kwargs)
+
+
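+# Worked example for resize() above (illustrative): with size=256 and a
+# 640x480 (w x h) input, the shorter edge h is matched, so oh = 256 and
+# ow = round(256 * 640 / 480) = 341, giving a 341x256 output; passing a
+# sequence (h, w) instead resizes to exactly that shape.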
+def pad(img, padding, fill=0):
+    """Pad the given PIL Image on all sides with the given "pad" value.
+
+    Args:
+        img (PIL Image): Image to be padded.
+        padding (int or tuple): Padding on each border. If a single int is provided this
+            is used to pad all borders. If a tuple of length 2 is provided this is the padding
+            on left/right and top/bottom respectively. If a tuple of length 4 is provided
+            this is the padding for the left, top, right and bottom borders
+            respectively.
+        fill: Pixel fill value. Default is 0. If a tuple of
+            length 3, it is used to fill R, G, B channels respectively.
+
+    Returns:
+        PIL Image: Padded image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    if not isinstance(padding, (numbers.Number, tuple)):
+        raise TypeError("Got inappropriate padding arg")
+    if not isinstance(fill, (numbers.Number, str, tuple)):
+        raise TypeError("Got inappropriate fill arg")
+
+    if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
+        raise ValueError(
+            "Padding must be an int or a 2- or 4-element tuple, not a "
+            "{}-element tuple".format(len(padding))
+        )
+
+    return ImageOps.expand(img, border=padding, fill=fill)
+
+
+def crop(img, i, j, h, w):
+    """Crop the given PIL Image.
+
+    Args:
+        img (PIL Image): Image to be cropped.
+        i: Upper pixel coordinate.
+        j: Left pixel coordinate.
+        h: Height of the cropped image.
+        w: Width of the cropped image.
+
+    Returns:
+        PIL Image: Cropped image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    return img.crop((j, i, j + w, i + h))
+
+
+def center_crop(img, output_size):
+    if isinstance(output_size, numbers.Number):
+        output_size = (int(output_size), int(output_size))
+    w, h = img.size
+    th, tw = output_size
+    i = int(round((h - th) / 2.0))
+    j = int(round((w - tw) / 2.0))
+    return crop(img, i, j, th, tw)
+
+
+def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
+    """Crop the given PIL Image and resize it to desired size.
+
+    Notably used in RandomResizedCrop.
+
+    Args:
+        img (PIL Image): Image to be cropped.
+        i: Upper pixel coordinate.
+        j: Left pixel coordinate.
+        h: Height of the cropped image.
+        w: Width of the cropped image.
+        size (sequence or int): Desired output size. Same semantics as ``scale``.
+        interpolation (int, optional): Desired interpolation. Default is
+            ``PIL.Image.BILINEAR``.
+    Returns:
+        PIL Image: Cropped image.
+    """
+    assert _is_pil_image(img), "img should be PIL Image"
+    img = crop(img, i, j, h, w)
+    img = resize(img, size, interpolation)
+    return img
+
+
+def hflip(img):
+    """Horizontally flip the given PIL Image.
+
+    Args:
+        img (PIL Image): Image to be flipped.
+
+    Returns:
+        PIL Image: Horizontally flipped image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    return img.transpose(Image.FLIP_LEFT_RIGHT)
+
+
+def vflip(img):
+    """Vertically flip the given PIL Image.
+
+    Args:
+        img (PIL Image): Image to be flipped.
+
+    Returns:
+        PIL Image: Vertically flipped image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    return img.transpose(Image.FLIP_TOP_BOTTOM)
+
+
+def five_crop(img, size):
+    """Crop the given PIL Image into four corners and the central crop.
+
+    .. Note::
+        This transform returns a tuple of images and there may be a
+        mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center) corresponding top left,
+            top right, bottom left, bottom right and center crop.
+    """
+    if isinstance(size, numbers.Number):
+        size = (int(size), int(size))
+    else:
+        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+    w, h = img.size
+    crop_h, crop_w = size
+    if crop_w > w or crop_h > h:
+        raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (h, w)))
+    tl = img.crop((0, 0, crop_w, crop_h))
+    tr = img.crop((w - crop_w, 0, w, crop_h))
+    bl = img.crop((0, h - crop_h, crop_w, h))
+    br = img.crop((w - crop_w, h - crop_h, w, h))
+    center = center_crop(img, (crop_h, crop_w))
+    return (tl, tr, bl, br, center)
+
+
+def ten_crop(img, size, vertical_flip=False):
+    """Crop the given PIL Image into four corners and the central crop plus the
+    flipped version of these (horizontal flipping is used by default).
+
+    .. Note::
+        This transform returns a tuple of images and there may be a
+        mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        vertical_flip (bool): Use vertical flipping instead of horizontal
+
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
+            br_flip, center_flip) corresponding top left, top right,
+            bottom left, bottom right and center crop and same for the
+            flipped image.
+    """
+    if isinstance(size, numbers.Number):
+        size = (int(size), int(size))
+    else:
+        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+    first_five = five_crop(img, size)
+
+    if vertical_flip:
+        img = vflip(img)
+    else:
+        img = hflip(img)
+
+    second_five = five_crop(img, size)
+    return first_five + second_five
+
+
+def adjust_brightness(img, brightness_factor):
+    """Adjust brightness of an Image.
+
+    Args:
+        img (PIL Image): PIL Image to be adjusted.
+        brightness_factor (float): How much to adjust the brightness. Can be
+            any non-negative number. 0 gives a black image, 1 gives the
+            original image while 2 increases the brightness by a factor of 2.
+
+    Returns:
+        PIL Image: Brightness adjusted image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    enhancer = ImageEnhance.Brightness(img)
+    img = enhancer.enhance(brightness_factor)
+    return img
+
+
+def adjust_contrast(img, contrast_factor):
+    """Adjust contrast of an Image.
+
+    Args:
+        img (PIL Image): PIL Image to be adjusted.
+        contrast_factor (float): How much to adjust the contrast. Can be any
+            non-negative number. 0 gives a solid gray image, 1 gives the
+            original image while 2 increases the contrast by a factor of 2.
+
+    Returns:
+        PIL Image: Contrast adjusted image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    enhancer = ImageEnhance.Contrast(img)
+    img = enhancer.enhance(contrast_factor)
+    return img
+
+
+def adjust_saturation(img, saturation_factor):
+    """Adjust color saturation of an image.
+
+    Args:
+        img (PIL Image): PIL Image to be adjusted.
+        saturation_factor (float): How much to adjust the saturation. 0 will
+            give a black and white image, 1 will give the original image while
+            2 will enhance the saturation by a factor of 2.
+
+    Returns:
+        PIL Image: Saturation adjusted image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    enhancer = ImageEnhance.Color(img)
+    img = enhancer.enhance(saturation_factor)
+    return img
+
+
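+# Worked example for adjust_hue() below (illustrative): hue_factor=0.25 adds
+# np.uint8(0.25 * 255) = 63 to the H channel, so a pixel with H=230 wraps to
+# (230 + 63) % 256 = 37 -- the cyclic shift described in the docstring.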
+def adjust_hue(img, hue_factor):
+    """Adjust hue of an image.
+
+    The image hue is adjusted by converting the image to HSV and
+    cyclically shifting the intensities in the hue channel (H).
+    The image is then converted back to the original image mode.
+
+    `hue_factor` is the amount of shift in the H channel and must be in the
+    interval `[-0.5, 0.5]`.
+
+    See https://en.wikipedia.org/wiki/Hue for more details on Hue.
+
+    Args:
+        img (PIL Image): PIL Image to be adjusted.
+        hue_factor (float): How much to shift the hue channel. Should be in
+            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
+            HSV space in positive and negative direction respectively.
+            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
+            with complementary colors while 0 gives the original image.
+
+    Returns:
+        PIL Image: Hue adjusted image.
+    """
+    if not (-0.5 <= hue_factor <= 0.5):
+        raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor))
+
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    input_mode = img.mode
+    if input_mode in {"L", "1", "I", "F"}:
+        return img
+
+    h, s, v = img.convert("HSV").split()
+
+    np_h = np.array(h, dtype=np.uint8)
+    # uint8 addition takes care of wrap-around across boundaries
+    with np.errstate(over="ignore"):
+        np_h += np.uint8(hue_factor * 255)
+    h = Image.fromarray(np_h, "L")
+
+    img = Image.merge("HSV", (h, s, v)).convert(input_mode)
+    return img
+
+
+def adjust_gamma(img, gamma, gain=1):
+    """Perform gamma correction on an image.
+
+    Also known as Power Law Transform. Intensities in RGB mode are adjusted
+    based on the following equation:
+
+        I_out = 255 * gain * ((I_in / 255) ** gamma)
+
+    See https://en.wikipedia.org/wiki/Gamma_correction for more details.
+
+    Args:
+        img (PIL Image): PIL Image to be adjusted.
+        gamma (float): Non-negative real number. gamma larger than 1 makes the
+            shadows darker, while gamma smaller than 1 makes dark regions
+            lighter.
+        gain (float): The constant multiplier.
+    """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    if gamma < 0:
+        raise ValueError("Gamma should be a non-negative real number")
+
+    input_mode = img.mode
+    img = img.convert("RGB")
+
+    np_img = np.array(img, dtype=np.float32)
+    np_img = 255 * gain * ((np_img / 255) ** gamma)
+    np_img = np.uint8(np.clip(np_img, 0, 255))
+
+    img = Image.fromarray(np_img, "RGB").convert(input_mode)
+    return img
+
+
+def rotate(img, angle, resample=False, expand=False, center=None):
+    """Rotate the given PIL Image by angle.
+
+    Args:
+        img (PIL Image): PIL Image to be rotated.
+        angle ({float, int}): Rotation angle in degrees, counter-clockwise.
+        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
+            An optional resampling filter.
+            See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
+            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output image to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+            Note that the expand flag assumes rotation around the center and no translation.
+        center (2-tuple, optional): Optional center of rotation.
+            Origin is the upper left corner.
+            Default is the center of the image.
+    """
+
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    return img.rotate(angle, resample, expand, center)
+
+
+def to_grayscale(img, num_output_channels=1):
+    """Convert image to grayscale version of image.
+ + Args: + img (PIL Image): Image to be converted to grayscale. + + Returns: + PIL Image: Grayscale version of the image. + if num_output_channels == 1 : returned image is single channel + if num_output_channels == 3 : returned image is 3 channel with r == g == b + """ + if not _is_pil_image(img): + raise TypeError("img should be PIL Image. Got {}".format(type(img))) + + if num_output_channels == 1: + img = img.convert("L") + elif num_output_channels == 3: + img = img.convert("L") + np_img = np.array(img, dtype=np.uint8) + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, "RGB") + else: + raise ValueError("num_output_channels should be either 1 or 3") + + return img diff --git a/util/image_saver.py b/util/image_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..619471921c328beac88734bed36c507ee4e4290b --- /dev/null +++ b/util/image_saver.py @@ -0,0 +1,302 @@ +import cv2 +import numpy as np + +import torch +from dataset.range_transform import inv_im_trans, inv_lll2rgb_trans +from collections import defaultdict + +from PIL import Image +from skimage import color, io + +import util.functional as F +class Normalize(object): + def __init__(self): + pass + + def __call__(self, inputs): + inputs[0:1, :, :] = F.normalize(inputs[0:1, :, :], 50, 1) + inputs[1:3, :, :] = F.normalize(inputs[1:3, :, :], (0, 0), (1, 1)) + return inputs + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + +def tensor_to_np_float(image): + image_np = image.numpy().astype('float32') + return image_np + +def detach_to_cpu(x): + return x.detach().cpu() + +def transpose_np(x): + return np.transpose(x, [1,2,0]) + +def tensor_to_gray_im(x): + x = detach_to_cpu(x) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + +def tensor_to_im(x): + x = detach_to_cpu(x) + x = inv_im_trans(x).clamp(0, 1) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + +# Predefined key <-> caption dict +key_captions = { + 'im': 'Image', + 'gt': 'GT', +} + +""" +Return an image array with captions +keys in dictionary will be used as caption if not provided +values should contain lists of cv2 images +""" +def get_image_array(images, grid_shape, captions={}): + h, w = grid_shape + cate_counts = len(images) + rows_counts = len(next(iter(images.values()))) + + font = cv2.FONT_HERSHEY_SIMPLEX + + output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8) + col_cnt = 0 + for k, v in images.items(): + + # Default as key value itself + caption = captions.get(k, k) + + # Handles new line character + dy = 40 + for i, line in enumerate(caption.split('\n')): + cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy), + font, 0.8, (255,255,255), 2, cv2.LINE_AA) + + # Put images + for row_cnt, img in enumerate(v): + im_shape = img.shape + if len(im_shape) == 2: + img = img[..., np.newaxis] + + img = (img * 255).astype('uint8') + + output_image[(col_cnt+0)*w:(col_cnt+1)*w, + (row_cnt+1)*h:(row_cnt+2)*h, :] = img + + col_cnt += 1 + + return output_image + +def base_transform(im, size): + im = tensor_to_np_float(im) + if len(im.shape) == 3: + im = im.transpose((1, 2, 0)) + else: + im = im[:, :, None] + + # Resize + if im.shape[1] != size: + im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST) + + return im.clip(0, 1) + +def im_transform(im, size): + return base_transform(inv_im_trans(detach_to_cpu(im)), size=size) + +def mask_transform(mask, size): + return base_transform(detach_to_cpu(mask), size=size) + +def 
out_transform(mask, size): + return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size) + +def lll2rgb_transform(mask, size): + flag_test = False + + mask_d = detach_to_cpu(mask) + + mask_d[1:3,:,:] = 0 + + if flag_test: print('before inv', mask_d.size(), torch.min(mask_d), torch.max(mask_d)) + mask_d = inv_lll2rgb_trans(mask_d) + if flag_test: print('after inv', mask_d.size(), torch.min(mask_d), torch.max(mask_d));assert 1==0 + + im = tensor_to_np_float(mask_d) + + if len(im.shape) == 3: + im = im.transpose((1, 2, 0)) + else: + im = im[:, :, None] + + im = color.lab2rgb(im) + + # Resize + if im.shape[1] != size: + im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST) + + return im.clip(0, 1) + +def lab2rgb_transform(mask, size): + flag_test = False + + mask_d = detach_to_cpu(mask) + + if flag_test: print('before inv', mask_d.size(), torch.max(mask_d), torch.min(mask_d)) + mask_d = inv_lll2rgb_trans(mask_d) + if flag_test: print('after inv', mask_d.size(), torch.max(mask_d), torch.min(mask_d));assert 1==0 + + im = tensor_to_np_float(mask_d) + + if len(im.shape) == 3: + im = im.transpose((1, 2, 0)) + else: + im = im[:, :, None] + + im = color.lab2rgb(im) + + # Resize + if im.shape[1] != size: + im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST) + + return im.clip(0, 1) + + + +def pool_pairs_221128_TransColorization(images, size, num_objects): + req_images = defaultdict(list) + + b, t = images['rgb'].shape[:2] + + # limit the number of images saved + b = min(2, b) + + # find max num objects + + # max_num_objects = max(num_objects[:b]) + max_num_objects = 1 + + GT_suffix = '' + for bi in range(b): + GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4] + + # print(images['rgb'].size(), b, max_num_objects, images['info']['name'], GT_suffix) + # print(images['info']['name'][0][-25:-4]) + # print(images['info']['name'][1][-25:-4]) + # assert 1==0 + + for bi in range(b): + for ti in range(t): + + req_images['RGB'].append(lll2rgb_transform(images['rgb'][bi,ti], size)) + + for oi in range(max_num_objects): + if ti == 0 or oi >= num_objects[bi]: + + # req_images['Mask_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size)) + # print(images['rgb'][bi,ti][:1,:,:].size(), images['first_frame_gt'][bi][0,:].size());assert 1==0 + req_images['Mask_%d'%oi].append(lab2rgb_transform(torch.cat([images['rgb'][bi,ti][:1,:,:], images['first_frame_gt'][bi][0,:]], dim=0), size)) + + + else: + # req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi], size)) + req_images['Mask_%d'%oi].append(lab2rgb_transform(torch.cat([images['rgb'][bi,ti][:1,:,:], images['masks_%d'%ti][bi][:]], dim=0), size)) + + # req_images['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size)) + # print(images['cls_gt'][bi,ti,:,:].size());assert 1==0 + req_images['GT_%d_%s'%(oi, GT_suffix)].append(lab2rgb_transform(torch.cat([images['rgb'][bi,ti][:1,:,:], images['cls_gt'][bi,ti,:,:]], dim=0), size)) + + # print((images['cls_gt'][bi,ti,0]==(oi+1)).shape) + # print(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size).shape) + + + return get_image_array(req_images, size, key_captions) + + +def pool_pairs_221128_TransColorization_val(images, size, num_objects): + req_images = defaultdict(list) + + b, t = images['rgb'].shape[:2] + + # limit the number of images saved + b = min(2, b) + + # find max num objects + + # max_num_objects = max(num_objects[:b]) + max_num_objects = 1 + + GT_suffix = '' + for bi in range(b): + GT_suffix += ' \n%s' % 
images['info']['name'][bi][-25:-4] + + # print(images['rgb'].size(), b, max_num_objects, images['info']['name'], GT_suffix) + # print(images['info']['name'][0][-25:-4]) + # print(images['info']['name'][1][-25:-4]) + # assert 1==0 + + for bi in range(b): + for ti in range(t): + + req_images['RGB'].append(lll2rgb_transform(images['rgb'][bi,ti], size)) + + for oi in range(max_num_objects): + if ti == 0 or oi >= num_objects[bi]: + + # req_images['Mask_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size)) + # print(images['rgb'][bi,ti][:1,:,:].size(), images['first_frame_gt'][bi][0,:].size());assert 1==0 + req_images['Mask_%d'%oi].append(lab2rgb_transform(torch.cat([images['rgb'][bi,ti][:1,:,:], images['first_frame_gt'][bi][0,:]], dim=0), size)) + + + else: + # req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi], size)) + req_images['Mask_%d'%oi].append(lab2rgb_transform(torch.cat([images['rgb'][bi,ti][:1,:,:], images['masks_%d'%ti][bi][:]], dim=0), size)) + + # req_images['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size)) + # print(images['cls_gt'][bi,ti,:,:].size());assert 1==0 + req_images['GT_%d_%s'%(oi, GT_suffix)].append(lab2rgb_transform(torch.cat([images['rgb'][bi,ti][:1,:,:], images['cls_gt'][bi,ti,:,:]], dim=0), size)) + + # print((images['cls_gt'][bi,ti,0]==(oi+1)).shape) + # print(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size).shape) + + + return get_image_array(req_images, size, key_captions) + + + +def pool_pairs(images, size, num_objects): + req_images = defaultdict(list) + + b, t = images['rgb'].shape[:2] + + # limit the number of images saved + b = min(2, b) + + # find max num objects + max_num_objects = max(num_objects[:b]) + + GT_suffix = '' + for bi in range(b): + GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4] + + for bi in range(b): + for ti in range(t): + req_images['RGB'].append(im_transform(images['rgb'][bi,ti], size)) + for oi in range(max_num_objects): + if ti == 0 or oi >= num_objects[bi]: + req_images['Mask_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size)) + # req_images['Mask_X8_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size)) + # req_images['Mask_X16_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size)) + else: + req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi], size)) + # req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][2], size)) + # req_images['Mask_X8_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][1], size)) + # req_images['Mask_X16_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][0], size)) + req_images['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size)) + # print((images['cls_gt'][bi,ti,0]==(oi+1)).shape) + # print(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size).shape) + + + return get_image_array(req_images, size, key_captions) \ No newline at end of file diff --git a/util/load_subset.py b/util/load_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..3191f4fef05cec04a11eafdfa42b34b98a35549e --- /dev/null +++ b/util/load_subset.py @@ -0,0 +1,16 @@ +""" +load_subset.py - Presents a subset of data +DAVIS - only the training set +YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all +""" + + +def load_sub_davis(path='util/davis_subset.txt'): + with open(path, mode='r') as f: + subset = set(f.read().splitlines()) + 
return subset
+
+def load_sub_yv(path='util/yv_subset.txt'):
+    with open(path, mode='r') as f:
+        subset = set(f.read().splitlines())
+    return subset
diff --git a/util/log_integrator.py b/util/log_integrator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4b26d53de98b16e145090bcddf2041a3f2d1394
--- /dev/null
+++ b/util/log_integrator.py
@@ -0,0 +1,80 @@
+"""
+Integrate numerical values for some iterations
+Typically used for loss computation / logging to tensorboard
+Call finalize and create a new Integrator when you want to display/log
+"""
+
+import torch
+
+
+class Integrator:
+    def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
+        self.values = {}
+        self.counts = {}
+        self.hooks = []  # List is used here to maintain insertion order
+
+        self.logger = logger
+
+        self.distributed = distributed
+        self.local_rank = local_rank
+        self.world_size = world_size
+
+    def add_tensor(self, key, tensor):
+        if key not in self.values:
+            self.counts[key] = 1
+            if type(tensor) == float or type(tensor) == int:
+                self.values[key] = tensor
+            else:
+                self.values[key] = tensor.mean().item()
+        else:
+            self.counts[key] += 1
+            if type(tensor) == float or type(tensor) == int:
+                self.values[key] += tensor
+            else:
+                self.values[key] += tensor.mean().item()
+
+    def add_dict(self, tensor_dict):
+        for k, v in tensor_dict.items():
+            self.add_tensor(k, v)
+
+    def add_hook(self, hook):
+        """
+        Adds a custom hook, i.e. compute new metrics using values in the dict
+        The hook takes the dict as argument, and returns a (k, v) tuple
+        e.g. for computing IoU
+        """
+        if type(hook) == list:
+            self.hooks.extend(hook)
+        else:
+            self.hooks.append(hook)
+
+    def reset_except_hooks(self):
+        self.values = {}
+        self.counts = {}
+
+    # Average and output the metrics
+    def finalize(self, prefix, it, f=None):
+
+        for hook in self.hooks:
+            k, v = hook(self.values)
+            self.add_tensor(k, v)
+
+        for k, v in self.values.items():
+
+            if k[:4] == 'hide':
+                continue
+
+            avg = v / self.counts[k]
+
+            if self.distributed:
+                # Inplace operation
+                avg = torch.tensor(avg).cuda()
+                torch.distributed.reduce(avg, dst=0)
+
+                if self.local_rank == 0:
+                    avg = (avg/self.world_size).cpu().item()
+                    self.logger.log_metrics(prefix, k, avg, it, f)
+            else:
+                # Simple does it
+                self.logger.log_metrics(prefix, k, avg, it, f)
diff --git a/util/logger.py b/util/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dd9fb29bb3f5b0e9ba9e2ee425f420e743d10e7
--- /dev/null
+++ b/util/logger.py
@@ -0,0 +1,101 @@
+"""
+Dumps things to tensorboard and console
+"""
+
+import os
+import warnings
+
+import torchvision.transforms as transforms
+from torch.utils.tensorboard import SummaryWriter
+
+
+def tensor_to_numpy(image):
+    image_np = (image.numpy() * 255).astype('uint8')
+    return image_np
+
+def detach_to_cpu(x):
+    return x.detach().cpu()
+
+def fix_width_trunc(x):
+    return ('{:.9s}'.format('{:0.9f}'.format(x)))
+
+class TensorboardLogger:
+    def __init__(self, short_id, id, git_info, flag_occupy_memory, savepath='.'):
+        self.short_id = short_id
+        if self.short_id == 'NULL':
+            self.short_id = 'DEBUG'
+
+        if id is None:
+            self.no_log = True
+            warnings.warn('Logging has been disabled.')
+        else:
+            self.no_log = False
+
+        self.inv_im_trans = transforms.Normalize(
+            mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
+            std=[1/0.229, 1/0.224, 1/0.225])
+
+        self.inv_seg_trans = transforms.Normalize(
+            mean=[-0.5/0.5],
+            std=[1/0.5])
+
+        log_path = os.path.join('.', 'tmp_occupy_memory_saves', '%s' % id) if 
flag_occupy_memory else os.path.join(savepath, 'saves', '%s' % id) + self.logger = SummaryWriter(log_path) + + self.log_string('git', git_info) + + def log_scalar(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + self.logger.add_scalar(tag, x, step) + + def log_metrics(self, l1_tag, l2_tag, val, step, f=None): + tag = l1_tag + '/' + l2_tag + text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val)) + print(text) + if f is not None: + f.write(text + '\n') + f.flush() + self.log_scalar(tag, val, step) + + def log_im(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = self.inv_im_trans(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_cv2(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = x.transpose((2, 0, 1)) + self.logger.add_image(tag, x, step) + + def log_seg(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = self.inv_seg_trans(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_gray(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_string(self, tag, x): + print(tag, x) + if self.no_log: + warnings.warn('Logging has been disabled.') + return + self.logger.add_text(tag, x) + \ No newline at end of file diff --git a/util/palette.py b/util/palette.py new file mode 100644 index 0000000000000000000000000000000000000000..d2541659563056b015b3d6e4c2b0accef3b4e831 --- /dev/null +++ b/util/palette.py @@ -0,0 +1,3 @@ +davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 
\xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0' + +youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f' diff --git a/util/tensor_util.py b/util/tensor_util.py new file mode 100644 index 0000000000000000000000000000000000000000..05189d38e2b0b0d1d08bd7804b8e43418d6da637 --- /dev/null +++ b/util/tensor_util.py @@ -0,0 +1,47 @@ +import torch.nn.functional as F + + +def compute_tensor_iu(seg, gt): + intersection = (seg & gt).float().sum() + union = (seg | gt).float().sum() + + return intersection, union + +def compute_tensor_iou(seg, gt): + intersection, union = compute_tensor_iu(seg, gt) + iou = (intersection + 1e-6) / (union + 1e-6) + + return iou + +# STM +def pad_divide_by(in_img, d): + h, w = in_img.shape[-2:] + + if h % d > 0: + new_h = h + d - h % d + else: + new_h = h + if w % d > 0: + new_w = w + d - w % d + else: + new_w = w + lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) + lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) + pad_array = (int(lw), int(uw), int(lh), int(uh)) + out = F.pad(in_img, pad_array) + return out, pad_array + +def unpad(img, pad): + if len(img.shape) == 4: + if pad[2]+pad[3] > 0: + img = img[:,:,pad[2]:-pad[3],:] + if pad[0]+pad[1] > 0: + img = img[:,:,:,pad[0]:-pad[1]] + elif len(img.shape) == 3: + if pad[2]+pad[3] > 0: + img = img[:,pad[2]:-pad[3],:] + if pad[0]+pad[1] > 0: + img = img[:,:,pad[0]:-pad[1]] + else: + raise NotImplementedError + return img \ No newline at end of file diff --git a/util/transforms.py b/util/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..df7a54559a329a84aec339182d2aa07f3ee4fdb0 --- /dev/null +++ b/util/transforms.py @@ -0,0 +1,61 @@ + +from dataset.range_transform import inv_im_trans, inv_lll2rgb_trans +from skimage import color, io +import cv2 +import numpy as np + +def detach_to_cpu(x): + return x.detach().cpu() + +def tensor_to_np_float(image): + image_np = image.numpy().astype('float32') + return image_np + +def lab2rgb_transform_PIL(mask): + flag_test = False + + mask_d = detach_to_cpu(mask) + + if flag_test: print('before inv', mask_d.size(), torch.max(mask_d), torch.min(mask_d)) + mask_d = inv_lll2rgb_trans(mask_d) + if flag_test: print('after inv', mask_d.size(), torch.max(mask_d), torch.min(mask_d));assert 1==0 + + im = tensor_to_np_float(mask_d) + + if len(im.shape) == 3: + im = im.transpose((1, 2, 0)) + else: + im = im[:, :, None] + + im = color.lab2rgb(im) + + return im.clip(0, 1) + +def calculate_psnr(img1, img2): + mse_value = ((img1 - img2)**2).mean() + if mse_value == 0: + result = float('inf') + else: + result = 20. * np.log10(255. 
/ np.sqrt(mse_value)) + return result + +def calculate_psnr_for_folder(gt_folder, result_folder): + result_clips = sorted(os.listdir(result_folder)) + + psnr_values = [] + for clip in result_clips: + path_clip = os.path.join(result_folder, clip) + test_files = sorted(os.listdir(path_clip)) + + for img in test_files: + gt_path = os.path.join(gt_folder, clip, img) + result_path = os.path.join(path_clip, img) + + gt_img = np.array(Image.open(gt_path)) + result_img = np.array(Image.open(result_path)) + + psnr = calculate_psnr(gt_img, result_img) + psnr_values.append(psnr) + + avg_psnr = np.mean(psnr_values) + return avg_psnr \ No newline at end of file diff --git a/util/yv_subset.txt b/util/yv_subset.txt new file mode 100644 index 0000000000000000000000000000000000000000..a26e50a7b8e6233bf17c542b540765cd8a1c5716 --- /dev/null +++ b/util/yv_subset.txt @@ -0,0 +1,3464 @@ +003234408d +0043f083b5 +0044fa5fba +005a527edd +0065b171f9 +00917dcfc4 +00a23ccf53 +00ad5016a4 +01082ae388 +011ac0a06f +013099c098 +0155498c85 +01694ad9c8 +017ac35701 +01b80e8e1a +01baa5a4e1 +01c3111683 +01c4cb5ffe +01c76f0a82 +01c783268c +01ed275c6e +01ff60d1fa +020cd28cd2 +02264db755 +0248626d9a +02668dbffa +0274193026 +02d28375aa +02f3a5c4df +031ccc99b1 +0321b18c10 +0348a45bca +0355e92655 +0358b938c1 +0368107cf1 +0379ddf557 +038b2cc71d +038c15a5dd +03a06cc98a +03a63e187f +03c95b4dae +03e2b57b0e +04194e1248 +0444918a5f +04460a7a52 +04474174a4 +0450095513 +045f00aed2 +04667fabaa +04735c5030 +04990d1915 +04d62d9d98 +04f21da964 +04fbad476e +04fe256562 +0503bf89c9 +0536c9eed0 +054acb238f +05579ca250 +056c200404 +05774f3a2c +058a7592c8 +05a0a513df +05a569d8aa +05aa652648 +05d7715782 +05e0b0f28f +05fdbbdd7a +05ffcfed85 +0630391881 +06840b2bbe +068f7dce6f +0693719753 +06ce2b51fb +06e224798e +06ee361788 +06fbb3fa2c +0700264286 +070c918ca7 +07129e14a4 +07177017e9 +07238ffc58 +07353b2a89 +0738493cbf +075926c651 +075c701292 +0762ea9a30 +07652ee4af +076f206928 +077d32af19 +079049275c +07913cdda7 +07a11a35e8 +07ac33b6df +07b6e8fda8 +07c62c3d11 +07cc1c7d74 +080196ef01 +081207976e +081ae4fa44 +081d8250cb +082900c5d4 +0860df21e2 +0866d4c5e3 +0891ac2eb6 +08931bc458 +08aa2705d5 +08c8450db7 +08d50b926c +08e1e4de15 +08e48c1a48 +08f561c65e +08feb87790 +09049f6fe3 +092e4ff450 +09338adea8 +093c335ccc +0970d28339 +0974a213dc +097b471ed8 +0990941758 +09a348f4fa +09a6841288 +09c5bad17b +09c9ce80c7 +09ff54fef4 +0a23765d15 +0a275e7f12 +0a2f2bd294 +0a7a2514aa +0a7b27fde9 +0a8c467cc3 +0ac8c560ae +0b1627e896 +0b285c47f6 +0b34ec1d55 +0b5b5e8e5a +0b68535614 +0b6f9105fc +0b7dbfa3cb +0b9cea51ca +0b9d012be8 +0bcfc4177d +0bd37b23c1 +0bd864064c +0c11c6bf7b +0c26bc77ac +0c3a04798c +0c44a9d545 +0c817cc390 +0ca839ee9a +0cd7ac0ac0 +0ce06e0121 +0cfe974a89 +0d2fcc0dcd +0d3aad05d2 +0d40b015f4 +0d97fba242 +0d9cc80d7e +0dab85b6d3 +0db5c427a5 +0dbaf284f1 +0de4923598 +0df28a9101 +0e04f636c4 +0e05f0e232 +0e0930474b +0e27472bea +0e30020549 +0e621feb6c +0e803c7d73 +0e9ebe4e3c +0e9f2785ec +0ea68d418b +0eb403a222 +0ee92053d6 +0eefca067f +0f17fa6fcb +0f1ac8e9a3 +0f202e9852 +0f2ab8b1ff +0f51a78756 +0f5fbe16b0 +0f6072077b +0f6b69b2f4 +0f6c2163de +0f74ec5599 +0f9683715b +0fa7b59356 +0fb173695b +0fc958cde2 +0fe7b1a621 +0ffcdb491c +101caff7d4 +1022fe8417 +1032e80b37 +103f501680 +104e64565f +104f1ab997 +106242403f +10b31f5431 +10eced835e +110d26fa3a +1122c1d16a +1145b49a5f +11485838c2 +114e7676ec +1157472b95 +115ee1072c +1171141012 +117757b4b8 +1178932d2f +117cc76bda +1180cbf814 +1187bbd0e3 +1197e44b26 +119cf20728 +119dd54871 +11a0c3b724 +11a6ba8c94 +11c722a456 +11cbcb0b4d +11ccf5e99d 
+11ce6f452e +11e53de6f2 +11feabe596 +120cb9514d +12156b25b3 +122896672d +1232b2f1d4 +1233ac8596 +1239c87234 +1250423f7c +1257a1bc67 +125d1b19dd +126d203967 +1295e19071 +12ad198c54 +12bddb2bcb +12ec9b93ee +12eebedc35 +132852e094 +1329409f2a +13325cfa14 +134d06dbf9 +135625b53d +13870016f9 +13960b3c84 +13adaad9d9 +13ae097e20 +13e3070469 +13f6a8c20d +1416925cf2 +142d2621f5 +145d5d7c03 +145fdc3ac5 +1471274fa7 +14a6b5a139 +14c21cea0d +14dae0dc93 +14f9bd22b5 +14fd28ae99 +15097d5d4e +150ea711f2 +1514e3563f +152aaa3a9e +152b7d3bd7 +15617297cc +15abbe0c52 +15d1fb3de5 +15f67b0fab +161eb59aad +16288ea47f +164410ce62 +165c3c8cd4 +165c42b41b +165ec9e22b +1669502269 +16763cccbb +16adde065e +16af445362 +16afd538ad +16c3fa4d5d +16d1d65c27 +16e8599e94 +16fe9fb444 +1705796b02 +1724db7671 +17418e81ea +175169edbb +17622326fd +17656bae77 +17b0d94172 +17c220e4f6 +17c7bcd146 +17cb4afe89 +17cd79a434 +17d18604c3 +17d8ca1a37 +17e33f4330 +17f7a6d805 +180abc8378 +183ba3d652 +185bf64702 +18913cc690 +1892651815 +189ac8208a +189b44e92c +18ac264b76 +18b245ab49 +18b5cebc34 +18bad52083 +18bb5144d5 +18c6f205c5 +1903f9ea15 +1917b209f2 +191e74c01d +19367bb94e +193ffaa217 +19696b67d3 +197f3ab6f3 +1981e763cc +198afe39ae +19a6e62b9b +19b60d5335 +19c00c11f9 +19e061eb88 +19e8bc6178 +19ee80dac6 +1a25a9170a +1a359a6c1a +1a3e87c566 +1a5fe06b00 +1a6c0fbd1e +1a6f3b5a4b +1a8afbad92 +1a8bdc5842 +1a95752aca +1a9c131cb7 +1aa3da3ee3 +1ab27ec7ea +1abf16d21d +1acd0f993b +1ad202e499 +1af8d2395d +1afd39a1fa +1b2d31306f +1b3fa67f0e +1b43fa74b4 +1b73ea9fc2 +1b7e8bb255 +1b8680f8cd +1b883843c0 +1b8898785b +1b88ba1aa4 +1b96a498e5 +1bbc4c274f +1bd87fe9ab +1c4090c75b +1c41934f84 +1c72b04b56 +1c87955a3a +1c9f9eb792 +1ca240fede +1ca5673803 +1cada35274 +1cb44b920d +1cd10e62be +1d3087d5e5 +1d3685150a +1d6ff083aa +1d746352a6 +1da256d146 +1da4e956b1 +1daf812218 +1dba687bce +1dce57d05d +1de4a9e537 +1dec5446c8 +1dfbe6f586 +1e1a18c45a +1e1e42529d +1e4be70796 +1eb60959c8 +1ec8b2566b +1ecdc2941c +1ee0ac70ff +1ef8e17def +1f1a2a9fc0 +1f1beb8daa +1f2609ee13 +1f3876f8d0 +1f4ec0563d +1f64955634 +1f7d31b5b2 +1f8014b7fd +1f9c7d10f1 +1fa350df76 +1fc9538993 +1fe2f0ec59 +2000c02f9d +20142b2f05 +201a8d75e5 +2023b3ee4f +202b767bbc +203594a418 +2038987336 +2039c3aecb +204a90d81f +207bc6cf01 +208833d1d1 +20c6d8b362 +20e3e52e0a +2117fa0c14 +211bc5d102 +2120d9c3c3 +2125235a49 +21386f5978 +2142af8795 +215dfc0f73 +217bae91e5 +217c0d44e4 +219057c87b +21d0edbf81 +21df87ad76 +21f1d089f5 +21f4019116 +222597030f +222904eb5b +223a0e0657 +223bd973ab +22472f7395 +224e7c833e +225aba51d9 +2261d421ea +2263a8782b +2268cb1ffd +2268e93b0a +2293c99f3f +22a1141970 +22b13084b2 +22d9f5ab0c +22f02efe3a +232c09b75b +2350d71b4b +2376440551 +2383d8aafd +238b84e67f +238d4b86f6 +238d947c6b +23993ce90d +23b0c8a9ab +23b3beafcc +23d80299fe +23f404a9fc +240118e58a +2431dec2fd +24440e0ac7 +2457274dbc +2465bf515d +246b142c4d +247d729e36 +2481ceafeb +24866b4e6a +2489d78320 +24ab0b83e8 +24b0868d92 +24b5207cd9 +24ddf05c03 +250116161c +256ad2e3fc +256bd83d5e +256dcc8ab8 +2589956baa +258b3b33c6 +25ad437e29 +25ae395636 +25c750c6db +25d2c3fe5d +25dc80db7c +25f97e926f +26011bc28b +260846ffbe +260dd9ad33 +267964ee57 +2680861931 +268ac7d3fc +26b895d91e +26bc786d4f +26ddd2ef12 +26de3d18ca +26f7784762 +2703e52a6a +270ed80c12 +2719b742ab +272f4163d0 +27303333e1 +27659fa7d6 +279214115d +27a5f92a9c +27cf2af1f3 +27f0d5f8a2 +28075f33c1 +281629cb41 +282b0d51f5 +282fcab00b +28449fa0dc +28475208ca +285580b7c4 +285b69e223 +288c117201 +28a8eb9623 +28bf9c3cf3 +28c6b8f86a +28c972dacd +28d9fa6016 +28e392de91 +28f4a45190 +298c844fc9 
+29a0356a2b +29d779f9e3 +29dde5f12b +29de7b6579 +29e630bdd0 +29f2332d30 +2a18873352 +2a3824ff31 +2a559dd27f +2a5c09acbd +2a63eb1524 +2a6a30a4ea +2a6d9099d1 +2a821394e3 +2a8c5b1342 +2abc8d66d2 +2ac9ef904a +2b08f37364 +2b351bfd7d +2b659a49d7 +2b69ee5c26 +2b6c30bbbd +2b88561cf2 +2b8b14954e +2ba621c750 +2bab50f9a7 +2bb00c2434 +2bbde474ef +2bdd82fb86 +2be06fb855 +2bf545c2f5 +2bffe4cf9a +2c04b887b7 +2c05209105 +2c0ad8cf39 +2c11fedca8 +2c1a94ebfb +2c1e8c8e2f +2c29fabcf1 +2c2c076c01 +2c3ea7ee7d +2c41fa0648 +2c44bb6d1c +2c54cfbb78 +2c5537eddf +2c6e63b7de +2cb10c6a7e +2cbcd5ccd1 +2cc5d9c5f6 +2cd01cf915 +2cdbf5f0a7 +2ce660f123 +2cf114677e +2d01eef98e +2d03593bdc +2d183ac8c4 +2d33ad3935 +2d3991d83e +2d4333577b +2d4d015c64 +2d8f5e5025 +2d900bdb8e +2d9a1a1d49 +2db0576a5c +2dc0838721 +2dcc417f82 +2df005b843 +2df356de14 +2e00393d96 +2e03b8127a +2e0f886168 +2e2bf37e6d +2e42410932 +2ea78f46e4 +2ebb017a26 +2ee2edba2a +2efb07554a +2f17e4fc1e +2f2c65c2f3 +2f2d9b33be +2f309c206b +2f53822e88 +2f53998171 +2f5b0c89b1 +2f680909e6 +2f710f66bd +2f724132b9 +2f7e3517ae +2f96f5fc6f +2f97d9fecb +2fbfa431ec +2fc9520b53 +2fcd9f4c62 +2feb30f208 +2ff7f5744f +30085a2cc6 +30176e3615 +301f72ee11 +3026bb2f61 +30318465dc +3054ca937d +306121e726 +3064ad91e8 +307444a47f +307bbb7409 +30a20194ab +30c35c64a4 +30dbdb2cd6 +30fc77d72f +310021b58b +3113140ee8 +3150b2ee57 +31539918c4 +318dfe2ce2 +3193da4835 +319f725ad9 +31bbd0d793 +322505c47f +322b237865 +322da43910 +3245e049fb +324c4c38f6 +324e35111a +3252398f09 +327dc4cabf +328d918c7d +3290c0de97 +3299ae3116 +32a7cd687b +33098cedb4 +3332334ac4 +334cb835ac +3355e056eb +33639a2847 +3373891cdc +337975816b +33e29d7e91 +34046fe4f2 +3424f58959 +34370a710f +343bc6a65a +3450382ef7 +3454303a08 +346aacf439 +346e92ff37 +34a5ece7dd +34b109755a +34d1b37101 +34dd2c70a7 +34efa703df +34fbee00a6 +3504df2fda +35195a56a1 +351c822748 +351cfd6bc5 +3543d8334c +35573455c7 +35637a827f +357a710863 +358bf16f9e +35ab34cc34 +35c6235b8d +35d01a438a +3605019d3b +3609bc3f88 +360e25da17 +36299c687c +362c5bc56e +3649228783 +365b0501ea +365f459863 +369893f3ad +369c9977e1 +369dde050a +36c7dac02f +36d5b1493b +36f5cc68fd +3735480d18 +374b479880 +375a49d38f +375a5c0e09 +376bda9651 +377db65f60 +37c19d1087 +37d4ae24fc +37ddce7f8b +37e10d33af +37e45c6247 +37fa0001e8 +3802d458c0 +382caa3cb4 +383bb93111 +388843df90 +38924f4a7f +38b00f93d7 +38c197c10e +38c9c3d801 +38eb2bf67f +38fe9b3ed1 +390352cced +390c51b987 +390ca6f1d6 +392bc0f8a1 +392ecb43bd +3935291688 +3935e63b41 +394454fa9c +394638fc8b +39545e20b7 +397abeae8f +3988074b88 +398f5d5f19 +39bc49a28c +39befd99fb +39c3c7bf55 +39d584b09f +39f6f6ffb1 +3a079fb484 +3a0d3a81b7 +3a1d55d22b +3a20a7583e +3a2c1f66e5 +3a33f4d225 +3a3bf84b13 +3a4565e5ec +3a4e32ed5e +3a7ad86ce0 +3a7bdde9b8 +3a98867cbe +3aa3f1c9e8 +3aa7fce8b6 +3aa876887d +3ab807ded6 +3ab9b1a85a +3adac8d7da +3ae1a4016f +3ae2deaec2 +3ae81609d6 +3af847e62f +3b23792b84 +3b3b0af2ee +3b512dad74 +3b6c7988f6 +3b6e983b5b +3b74a0fc20 +3b7a50b80d +3b96d3492f +3b9ad0c5a9 +3b9ba0894a +3bb4e10ed7 +3bd9a9b515 +3beef45388 +3c019c0a24 +3c090704aa +3c2784fc0d +3c47ab95f8 +3c4db32d74 +3c5ff93faf +3c700f073e +3c713cbf2f +3c8320669c +3c90d225ee +3cadbcc404 +3cb9be84a5 +3cc37fd487 +3cc6f90cb2 +3cd5e035ef +3cdf03531b +3cdf828f59 +3d254b0bca +3d5aeac5ba +3d690473e1 +3d69fed2fb +3d8997aeb6 +3db0d6b07e +3db1ddb8cf +3db907ac77 +3dcbc0635b +3dd48ed55f +3de4ac4ec4 +3decd63d88 +3e04a6be11 +3e108fb65a +3e1448b01c +3e16c19634 +3e2845307e +3e38336da5 +3e3a819865 +3e3e4be915 +3e680622d7 +3e7d2aeb07 +3e7d8f363d +3e91f10205 +3ea4c49bbe +3eb39d11ab +3ec273c8d5 
+3ed3f91271 +3ee062a2fd +3eede9782c +3ef2fa99cb +3efc6e9892 +3f0b0dfddd +3f0c860359 +3f18728586 +3f3b15f083 +3f45a470ad +3f4f3bc803 +3fd96c5267 +3fea675fab +3fee8cbc9f +3fff16d112 +401888b36c +4019231330 +402316532d +402680df52 +404d02e0c0 +40709263a8 +4083cfbe15 +40a96c5cb1 +40b8e50f82 +40f4026bf5 +4100b57a3a +41059fdd0b +41124e36de +4122aba5f9 +413bab0f0d +4164faee0b +418035eec9 +4182d51532 +418bb97e10 +41a34c20e7 +41dab05200 +41ff6d5e2a +420caf0859 +42264230ba +425a0c96e0 +42da96b87c +42eb5a5b0f +42f17cd14d +42f5c61c49 +42ffdcdee9 +432f9884f9 +43326d9940 +4350f3ab60 +4399ffade3 +43a6c21f37 +43b5555faa +43d63b752a +4416bdd6ac +4444753edd +444aa274e7 +444d4e0596 +446b8b5f7a +4478f694bb +44b1da0d87 +44b4dad8c9 +44b5ece1b9 +44d239b24e +44eaf8f51e +44f4f57099 +44f7422af2 +450787ac97 +4523656564 +4536c882e5 +453b65daa4 +454f227427 +45636d806a +456fb9362e +457e717a14 +45a89f35e1 +45bf0e947d +45c36a9eab +45d9fc1357 +45f8128b97 +4607f6c03c +46146dfd39 +4620e66b1e +4625f3f2d3 +462b22f263 +4634736113 +463c0f4fdd +46565a75f8 +46630b55ae +466839cb37 +466ba4ae0c +4680236c9d +46bf4e8709 +46e18e42f1 +46f5093c59 +47269e0499 +472da1c484 +47354fab09 +4743bb84a7 +474a796272 +4783d2ab87 +479cad5da3 +479f5d7ef6 +47a05fbd1d +4804ee2767 +4810c3fbca +482fb439c2 +48375af288 +484ab44de4 +485f3944cd +4867b84887 +486a8ac57e +486e69c5bd +48812cf33e +4894b3b9ea +48bd66517d +48d83b48a4 +49058178b8 +4918d10ff0 +4932911f80 +49405b7900 +49972c2d14 +499bf07002 +49b16e9377 +49c104258e +49c879f82d +49e7326789 +49ec3e406a +49fbf0c98a +4a0255c865 +4a088fe99a +4a341402d0 +4a3471bdf5 +4a4b50571c +4a50f3d2e9 +4a6e3faaa1 +4a7191f08a +4a86fcfc30 +4a885fa3ef +4a8af115de +4aa2e0f865 +4aa9d6527f +4abb74bb52 +4ae13de1cd +4af8cb323f +4b02c272b3 +4b19c529fb +4b2974eff4 +4b3154c159 +4b54d2587f +4b556740ff +4b67aa9ef6 +4b97cc7b8d +4baa1ed4aa +4bc8c676bb +4beaea4dbe +4bf5763d24 +4bffa92b67 +4c25dfa8ec +4c397b6fd4 +4c51e75d66 +4c7710908f +4c9b5017be +4ca2ffc361 +4cad2e93bc +4cd427b535 +4cd9a4b1ef +4cdfe3c2b2 +4cef87b649 +4cf208e9b3 +4cf5bc3e60 +4cfdd73249 +4cff5c9e42 +4d26d41091 +4d5c23c554 +4d67c59727 +4d983cad9f +4da0d00b55 +4daa179861 +4dadd57153 +4db117e6c5 +4de4ce4dea +4dfaee19e5 +4dfdd7fab0 +4e3f346aa5 +4e49c2a9c7 +4e4e06a749 +4e70279712 +4e72856cc7 +4e752f8075 +4e7a28907f +4e824b9247 +4e82b1df57 +4e87a639bc +4ea77bfd15 +4eb6fc23a2 +4ec9da329e +4efb9a0720 +4f062fbc63 +4f35be0e0b +4f37e86797 +4f414dd6e7 +4f424abded +4f470cc3ae +4f601d255a +4f7386a1ab +4f824d3dcd +4f827b0751 +4f8db33a13 +4fa160f8a3 +4fa9c30a45 +4facd8f0e8 +4fca07ad01 +4fded94004 +4fdfef4dea +4feb3ac01f +4fffec8479 +500c835a86 +50168342bf +50243cffdc +5031d5a036 +504dd9c0fd +50568fbcfb +5069c7c5b3 +508189ac91 +50b6b3d4b7 +50c6f4fe3e +50cce40173 +50efbe152f +50f290b95d +5104aa1fea +5110dc72c0 +511e8ecd7f +513aada14e +5158d6e985 +5161e1fa57 +51794ddd58 +517d276725 +51a597ee04 +51b37b6d97 +51b5dc30a0 +51e85b347b +51eea1fdac +51eef778af +51f384721c +521cfadcb4 +52355da42f +5247d4b160 +524b470fd0 +524cee1534 +5252195e8a +5255c9ca97 +525928f46f +526df007a7 +529b12de78 +52c7a3d653 +52c8ec0373 +52d225ed52 +52ee406d9e +52ff1ccd4a +53143511e8 +5316d11eb7 +53253f2362 +534a560609 +5352c4a70e +536096501f +536b17bcea +5380eaabff +5390a43a54 +53af427bb2 +53bf5964ce +53c30110b5 +53cad8e44a +53d9c45013 +53e274f1b5 +53e32d21ea +540850e1c7 +540cb31cfe +541c4da30f +541d7935d7 +545468262b +5458647306 +54657855cd +547b3fb23b +5497dc3712 +549c56f1d4 +54a4260bb1 +54b98b8d5e +54e1054b0f +54e8867b83 +54ebe34f6e +5519b4ad13 +551acbffd5 +55341f42da +5566ab97e1 +556c79bbf2 +5589637cc4 +558aa072f0 
+559824b6f6 +55c1764e90 +55eda6c77e +562d173565 +5665c024cb +566cef4959 +5675d78833 +5678a91bd8 +567a2b4bd0 +569c282890 +56cc449917 +56e71f3e07 +56f09b9d92 +56fc0e8cf9 +571ca79c71 +57243657cf +57246af7d1 +57427393e9 +574b682c19 +578f211b86 +5790ac295d +579393912d +57a344ab1a +57bd3bcda4 +57bfb7fa4c +57c010175e +57c457cc75 +57c7fc2183 +57d5289a01 +58045fde85 +58163c37cd +582d463e5c +5851739c15 +585dd0f208 +587250f3c3 +589e4cc1de +589f65f5d5 +58a07c17d5 +58adc6d8b6 +58b9bcf656 +58c374917e +58fc75fd42 +5914c30f05 +59323787d5 +5937b08d69 +594065ddd7 +595a0ceea6 +59623ec40b +597ff7ef78 +598935ef05 +598c2ad3b2 +59a6459751 +59b175e138 +59bf0a149f +59d53d1649 +59e3e6fae7 +59fe33e560 +5a13a73fe5 +5a25c22770 +5a4a785006 +5a50640995 +5a75f7a1cf +5a841e59ad +5a91c5ab6d +5ab49d9de0 +5aba1057fe +5abe46ba6d +5ac7c88d0c +5aeb95cc7d +5af15e4fc3 +5afe381ae4 +5b07b4229d +5b1001cc4f +5b1df237d2 +5b263013bf +5b27d19f0b +5b48ae16c5 +5b5babc719 +5baaebdf00 +5bab55cdbe +5bafef6e79 +5bd1f84545 +5bddc3ba25 +5bdf7c20d2 +5bf23bc9d3 +5c01f6171a +5c021681b7 +5c185cff1d +5c42aba280 +5c44bf8ab6 +5c4c574894 +5c52fa4662 +5c6ea7dac3 +5c74315dc2 +5c7668855e +5c83e96778 +5ca36173e4 +5cac477371 +5cb0cb1b2f +5cb0cfb98f +5cb49a19cf +5cbf7dc388 +5d0e07d126 +5d1e24b6e3 +5d663000ff +5da6b2dc5d +5de9b90f24 +5e08de0ed7 +5e1011df9a +5e1ce354fd +5e35512dd7 +5e418b25f9 +5e4849935a +5e4ee19663 +5e886ef78f +5e8d00b974 +5e8d59dc31 +5ed838bd5c +5edda6ee5a +5ede4d2f7a +5ede9767da +5eec4d9fe5 +5eecf07824 +5eef7ed4f4 +5ef5860ac6 +5ef6573a99 +5f1193e72b +5f29ced797 +5f32cf521e +5f51876986 +5f6ebe94a9 +5f6f14977c +5f808d0d2d +5fb8aded6a +5fba90767d +5fd1c7a3df +5fd3da9f68 +5fee2570ae +5ff66140d6 +5ff8b85b53 +600803c0f6 +600be7f53e +6024888af8 +603189a03c +6057307f6e +6061ddbb65 +606c86c455 +60c61cc2e5 +60e51ff1ae +610e38b751 +61344be2f6 +6135e27185 +614afe7975 +614e571886 +614e7078db +619812a1a7 +61b481a78b +61c7172650 +61cf7e40d2 +61d08ef5a1 +61da008958 +61ed178ecb +61f5d1282c +61fd977e49 +621584cffe +625817a927 +625892cf0b +625b89d28a +629995af95 +62a0840bb5 +62ad6e121c +62d6ece152 +62ede7b2da +62f025e1bc +6316faaebc +63281534dc +634058dda0 +6353f09384 +6363c87314 +636e4872e0 +637681cd6b +6376d49f31 +6377809ec2 +63936d7de5 +639bddef11 +63d37e9fd3 +63d90c2bae +63e544a5d6 +63ebbcf874 +63fff40b31 +6406c72e4d +64148128be +6419386729 +643092bc41 +644081b88d +64453cf61d +644bad9729 +6454f548fd +645913b63a +64750b825f +64a43876b7 +64dd6c83e3 +64e05bf46e +64f55f1478 +650b0165e4 +651066ed39 +652b67d960 +653821d680 +6538d00d73 +65866dce22 +6589565c8c +659832db64 +65ab7e1d98 +65b7dda462 +65bd5eb4f5 +65dcf115ab +65e9825801 +65f9afe51c +65ff12bcb5 +666b660284 +6671643f31 +668364b372 +66852243cb +6693a52081 +669b572898 +66e98e78f5 +670f12e88f +674c12c92d +675c27208a +675ed3e1ca +67741db50a +678a2357eb +67b0f4d562 +67cfbff9b1 +67e717d6bd +67ea169a3b +67ea809e0e +681249baa3 +683de643d9 +6846ac20df +6848e012ef +684bcd8812 +684dc1c40c +685a1fa9cf +686dafaac9 +68807d8601 +6893778c77 +6899d2dabe +68a2fad4ab +68cb45fda3 +68cc4a1970 +68dcb40675 +68ea4a8c3d +68f6e7fbf0 +68fa8300b4 +69023db81f +6908ccf557 +691a111e7c +6927723ba5 +692ca0e1a2 +692eb57b63 +69340faa52 +693cbf0c9d +6942f684ad +6944fc833b +69491c0ebf +695b61a2b0 +6979b4d83f +697d4fdb02 +69910460a4 +6997636670 +69a436750b +69aebf7669 +69b8c17047 +69c67f109f +69e0e7b868 +69ea9c09d1 +69f0af42a6 +6a078cdcc7 +6a37a91708 +6a42176f2e +6a48e4aea8 +6a5977be3a +6a5de0535f +6a80d2e2e5 +6a96c8815d +6a986084e2 +6aa8e50445 +6ab9dce449 +6abf0ba6b2 +6acc6049d9 +6adb31756c +6ade215eb0 +6afb7d50e4 +6afd692f1a +6b0b1044fe 
+6b17c67633 +6b1b6ef28b +6b1e04d00d +6b2261888d +6b25d6528a +6b3a24395c +6b685eb75b +6b79be238c +6b928b7ba6 +6b9c43c25a +6ba99cc41f +6bdab62bcd +6bf2e853b1 +6bf584200f +6bf95df2b9 +6c0949c51c +6c11a5f11f +6c23d89189 +6c4387daf5 +6c4ce479a4 +6c5123e4bc +6c54265f16 +6c56848429 +6c623fac5f +6c81b014e9 +6c99ea7c31 +6c9d29d509 +6c9e3b7d1a +6ca006e283 +6caeb928d6 +6cb2ee722a +6cbfd32c5e +6cc791250b +6cccc985e0 +6d12e30c48 +6d4bf200ad +6d6d2b8843 +6d6eea5682 +6d7a3d0c21 +6d7efa9b9e +6da21f5c91 +6da6adabc0 +6dd2827fbb +6dd36705b9 +6df3637557 +6dfe55e9e5 +6e1a21ba55 +6e2f834767 +6e36e4929a +6e4f460caf +6e618d26b6 +6ead4670f7 +6eaff19b9f +6eb2e1cd9e +6eb30b3b5a +6eca26c202 +6ecad29e52 +6ef0b44654 +6efcfe9275 +6f4789045c +6f49f522ef +6f67d7c4c4 +6f96e91d81 +6fc6fce380 +6fc9b44c00 +6fce7f3226 +6fdf1ca888 +702fd8b729 +70405185d2 +7053e4f41e +707bf4ce41 +7082544248 +708535b72a +7094ac0f60 +70a6b875fa +70c3e97e41 +7106b020ab +711dce6fe2 +7136a4453f +7143fb084f +714d902095 +7151c53b32 +715357be94 +7163b8085f +716df1aa59 +71caded286 +71d2665f35 +71d67b9e19 +71e06dda39 +720b398b9c +720e3fa04c +720e7a5f1e +721bb6f2cb +722803f4f2 +72552a07c9 +726243a205 +72690ef572 +728cda9b65 +728e81c319 +72a810a799 +72acb8cdf6 +72b01281f9 +72cac683e4 +72cadebbce +72cae058a5 +72d8dba870 +72e8d1c1ff +72edc08285 +72f04f1a38 +731b825695 +7320b49b13 +732626383b +732df1eb05 +73329902ab +733798921e +733824d431 +734ea0d7fb +735a7cf7b9 +7367a42892 +7368d5c053 +73c6ae7711 +73e1852735 +73e4e5cc74 +73eac9156b +73f8441a88 +7419e2ab3f +74267f68b9 +7435690c8c +747c44785c +747f1b1f2f +748b2d5c01 +74d4cee0a4 +74ec2b3073 +74ef677020 +750be4c4d8 +75172d4ac8 +75285a7eb1 +75504539c3 +7550949b1d +7551cbd537 +75595b453d +7559b4b0ec +755bd1fbeb +756f76f74d +7570ca7f3c +757a69746e +757cac96c6 +7584129dc3 +75a058dbcd +75b09ce005 +75cae39a8f +75cee6caf0 +75cf58fb2c +75d5c2f32a +75eaf5669d +75f7937438 +75f99bd3b3 +75fa586876 +7613df1f84 +762e1b3487 +76379a3e69 +764271f0f3 +764503c499 +7660005554 +7666351b84 +76693db153 +767856368b +768671f652 +768802b80d +76962c7ed2 +76a75f4eee +76b90809f7 +770a441457 +772a0fa402 +772f2ffc3e +774f6c2175 +77610860e0 +777e58ff3d +77920f1708 +7799df28e7 +779e847a9a +77ba4edc72 +77c834dc43 +77d8aa8691 +77e7f38f4d +77eea6845e +7806308f33 +78254660ea +7828af8bff +784398620a +784d201b12 +78613981ed +78896c6baf +78aff3ebc0 +78c7c03716 +78d3676361 +78e29dd4c3 +78f1a1a54f +79208585cd +792218456c +7923bad550 +794e6fc49f +796e6762ce +797cd21f71 +79921b21c2 +79a5778027 +79bc006280 +79bf95e624 +79d9e00c55 +79e20fc008 +79e9db913e +79f014085e +79fcbb433a +7a13a5dfaa +7a14bc9a36 +7a3c535f70 +7a446a51e9 +7a56e759c5 +7a5f46198d +7a626ec98d +7a802264c4 +7a8b5456ca +7abdff3086 +7aecf9f7ac +7b0fd09c28 +7b18b3db87 +7b39fe7371 +7b49e03d4c +7b5388c9f1 +7b5cf7837f +7b733d31d8 +7b74fd7b98 +7b918ccb8a +7ba3ce3485 +7bb0abc031 +7bb5bb25cd +7bb7dac673 +7bc7761b8c +7bf3820566 +7c03a18ec1 +7c078f211b +7c37d7991a +7c4ec17eff +7c649c2aaf +7c73340ab7 +7c78a2266d +7c88ce3c5b +7ca6843a72 +7cc9258dee +7cec7296ae +7d0ffa68a4 +7d11b4450f +7d1333fcbe +7d18074fef +7d18c8c716 +7d508fb027 +7d55f791f0 +7d74e3c2f6 +7d783f67a9 +7d83a5d854 +7dd409947e +7de45f75e5 +7e0cd25696 +7e1922575c +7e1e3bbcc1 +7e24023274 +7e2f212fd3 +7e6d1cc1f4 +7e7cdcb284 +7e9b6bef69 +7ea5b49283 +7eb2605d96 +7eb26b8485 +7ecd1f0c69 +7f02b3cfe2 +7f1723f0d5 +7f21063c3a +7f3658460e +7f54132e48 +7f559f9d4a +7f5faedf8b +7f838baf2b +7fa5f527e3 +7ff84d66dd +802b45c8c4 +804382b1ad +804c558adb +804f6338a4 +8056117b89 +806b6223ab +8088bda461 +80b790703b +80c4a94706 +80ce2e351b +80db581acd +80e12193df 
+80e41b608f +80f16b016d +81541b3725 +8175486e6a +8179095000 +8193671178 +81a58d2c6b +81aa1286fb +81dffd30fb +8200245704 +823e7a86e8 +824973babb +824ca5538f +827171a845 +8273a03530 +827cf4f886 +82b865c7dd +82c1517708 +82d15514d6 +82e117b900 +82fec06574 +832b5ef379 +83424c9fbf +8345358fb8 +834b50b31b +835e3b67d7 +836ea92b15 +837c618777 +838eb3bd89 +839381063f +839bc71489 +83a8151377 +83ae88d217 +83ca8bcad0 +83ce590d7f +83d3130ba0 +83d40bcba5 +83daba503a +83de906ec0 +84044f37f3 +84696b5a5e +84752191a3 +847eeeb2e0 +848e7835a0 +84a4b29286 +84a4bf147d +84be115c09 +84d95c4350 +84e0922cf7 +84f0cfc665 +8515f6db22 +851f2f32c1 +852a4d6067 +854c48b02a +857a387c86 +859633d56a +85a4f4a639 +85ab85510c +85b1eda0d9 +85dc1041c6 +85e081f3c7 +85f75187ad +8604bb2b75 +860745b042 +863b4049d7 +8643de22d0 +8647d06439 +864ffce4fe +8662d9441a +8666521b13 +868d6a0685 +869fa45998 +86a40b655d +86a8ae4223 +86b2180703 +86c85d27df +86d3755680 +86e61829a1 +871015806c +871e409c5c +8744b861ce +8749369ba0 +878a299541 +8792c193a0 +8799ab0118 +87d1f7d741 +882b9e4500 +885673ea17 +8859dedf41 +8873ab2806 +887a93b198 +8883e991a9 +8891aa6dfa +8899d8cbcd +88b8274d67 +88d3b80af6 +88ede83da2 +88f345941b +890976d6da +8909bde9ab +8929c7d5d9 +89363acf76 +89379487e0 +8939db6354 +893f658345 +8953138465 +895c96d671 +895cbf96f9 +895e8b29a7 +898fa256c8 +89986c60be +89b874547b +89bdb021d5 +89c802ff9c +89d6336c2b +89ebb27334 +8a27e2407c +8a31f7bca5 +8a4a2fc105 +8a5d6c619c +8a75ad7924 +8aa817e4ed +8aad0591eb +8aca214360 +8ae168c71b +8b0cfbab97 +8b3645d826 +8b3805dbd4 +8b473f0f5d +8b4f6d1186 +8b4fb018b7 +8b518ee936 +8b523bdfd6 +8b52fb5fba +8b91036e5c +8b99a77ac5 +8ba04b1e7b +8ba782192f +8bbeaad78b +8bd1b45776 +8bd7a2dda6 +8bdb091ccf +8be56f165d +8be950d00f +8bf84e7d45 +8bffc4374b +8bfff50747 +8c09867481 +8c0a3251c3 +8c3015cccb +8c469815cf +8c9ccfedc7 +8ca1af9f3c +8ca3f6e6c1 +8ca6a4f60f +8cac6900fe +8cba221a1e +8cbbe62ccd +8d064b29e2 +8d167e7c08 +8d4ab94e1c +8d81f6f899 +8d87897d66 +8dcccd2bd2 +8dcfb878a8 +8dd3ab71b9 +8dda6bf10f +8ddd51ca94 +8dea22c533 +8def5bd3bf +8e1848197c +8e3a83cf2d +8e478e73f3 +8e98ae3c84 +8ea6687ab0 +8eb0d315c1 +8ec10891f9 +8ec3065ec2 +8ecf51a971 +8eddbab9f7 +8ee198467a +8ee2368f40 +8ef595ce82 +8f0a653ad7 +8f1204a732 +8f1600f7f6 +8f16366707 +8f1ce0a411 +8f2e05e814 +8f320d0e09 +8f3b4a84ad +8f3fdad3da +8f5d3622d8 +8f62a2c633 +8f81c9405a +8f8c974d53 +8f918598b6 +8ff61619f6 +9002761b41 +90107941f3 +90118a42ee +902bc16b37 +903e87e0d6 +9041a0f489 +9047bf3222 +9057bfa502 +90617b0954 +9076f4b6db +9077e69b08 +909655b4a6 +909c2eca88 +909dbd1b76 +90bc4a319a +90c7a87887 +90cc785ddd +90d300f09b +9101ea9b1b +9108130458 +911ac9979b +9151cad9b5 +9153762797 +91634ee0c9 +916942666f +9198cfb4ea +919ac864d6 +91b67d58d4 +91bb8df281 +91be106477 +91c33b4290 +91ca7dd9f3 +91d095f869 +91f107082e +920329dd5e +920c959958 +92128fbf4b +9223dacb40 +923137bb7f +9268e1f88a +927647fe08 +9276f5ba47 +92a28cd233 +92b5c1fc6d +92c46be756 +92dabbe3a0 +92e3159361 +92ebab216a +934bdc2893 +9359174efc +935d97dd2f +935feaba1b +93901858ee +939378f6d6 +939bdf742e +93a22bee7e +93da9aeddf +93e2feacce +93e6f1fdf9 +93e811e393 +93e85d8fd3 +93f623d716 +93ff35e801 +94031f12f2 +94091a4873 +94125907e3 +9418653742 +941c870569 +94209c86f0 +9437c715eb +9445c3eca2 +9467c8617c +946d71fb5d +948f3ae6fb +9498baa359 +94a33abeab +94bf1af5e3 +94cf3a8025 +94db712ac8 +94e4b66cff +94e76cbaf6 +950be91db1 +952058e2d0 +952633c37f +952ec313fe +9533fc037c +9574b81269 +9579b73761 +957f7bc48b +958073d2b0 +9582e0eb33 +9584092d0b +95b58b8004 +95bd88da55 +95f74a9959 +962781c601 +962f045bf5 +964ad23b44 
+967b90590e +967bffe201 +96825c4714 +968492136a +9684ef9d64 +968c41829e +96a856ef9a +96dfc49961 +96e1a5b4f8 +96e6ff0917 +96fb88e9d7 +96fbe5fc23 +96fc924050 +9715cc83dc +9720eff40f +972c187c0d +97476eb38d +97659ed431 +9773492949 +97756b264f +977bff0d10 +97ab569ff3 +97ba838008 +97d9d008c7 +97e59f09fa +97eb642e56 +98043e2d14 +981ff580cf +983e66cbfc +984f0f1c36 +98595f2bb4 +985c3be474 +9869a12362 +986b5a5e18 +9877af5063 +98911292da +9893a3cf77 +9893d9202d +98a8b06e7f +98ac6f93d9 +98b6974d12 +98ba3c9417 +98c7c00a19 +98d044f206 +98e909f9d1 +98fe7f0410 +990f2742c7 +992bd0779a +994b9b47ba +9955b76bf5 +9966f3adac +997117a654 +999d53d841 +99c04108d3 +99c4277aee +99c6b1acf2 +99dc8bb20b +99fcba71e5 +99fecd4efb +9a02c70ba2 +9a08e7a6f8 +9a2f2c0f86 +9a3254a76e +9a3570a020 +9a39112493 +9a4e9fd399 +9a50af4bfb +9a68631d24 +9a72318dbf +9a767493b7 +9a7fc1548b +9a84ccf6a7 +9a9c0e15b7 +9adf06d89b +9b22b54ee4 +9b473fc8fe +9b4f081782 +9b997664ba +9bc454e109 +9bccfd04de +9bce4583a2 +9bebf1b87f +9bfc50d261 +9c166c86ff +9c293ef4d7 +9c29c047b0 +9c3bc2e2a7 +9c3ce23bd1 +9c404cac0c +9c5180d23a +9c7feca6e4 +9caa49d3ff +9cb2f1b646 +9ce6f765c3 +9cfee34031 +9d01f08ec6 +9d04c280b8 +9d12ceaddc +9d15f8cb3c +9d2101e9bf +9d407c3aeb +9ddefc6165 +9df0b1e298 +9e16f115d8 +9e249b4982 +9e29b1982c +9e493e4773 +9e4c752cd0 +9e4de40671 +9e6319faeb +9e6ddbb52d +9eadcea74f +9ecec5f8ea +9efb47b595 +9f30bfe61e +9f3734c3a4 +9f5b858101 +9f66640cda +9f913803e9 +9f97bc74c8 +9fbad86e20 +9fc2bad316 +9fc5c3af78 +9fcb310255 +9fcc256871 +9fd2fd4d47 +a0071ae316 +a023141022 +a046399a74 +a066e739c1 +a06722ba82 +a07a15dd64 +a07b47f694 +a09c39472e +a0b208fe2e +a0b61c959e +a0bc6c611d +a0e6da5ba2 +a1193d6490 +a14ef483ff +a14f709908 +a15ccc5658 +a16062456f +a174e8d989 +a177c2733c +a17c62e764 +a18ad065fc +a1aaf63216 +a1bb65fb91 +a1bd8e5349 +a1dfdd0cac +a2052e4f6c +a20fd34693 +a21ffe4d81 +a22349e647 +a235d01ec1 +a24f63e8a2 +a2554c9f6d +a263ce8a87 +a29bfc29ec +a2a80072d4 +a2a800ab63 +a2bcd10a33 +a2bdaff3b0 +a2c146ab0d +a2c996e429 +a2dc51ebe8 +a2e6608bfa +a2f2a55f01 +a301869dea +a31fccd2cc +a34f440f33 +a35e0206da +a36bdc4cab +a36e8c79d8 +a378053b20 +a37db3a2b3 +a38950ebc2 +a39a0eb433 +a39c9bca52 +a3a945dc8c +a3b40a0c1e +a3b8588550 +a3c502bec3 +a3f2878017 +a3f4d58010 +a3f51855c3 +a402dc0dfe +a4065a7eda +a412bb2fef +a416b56b53 +a41ec95906 +a43299e362 +a4757bd7af +a48c53c454 +a49dcf9ad5 +a4a506521f +a4ba7753d9 +a4bac06849 +a4f05d681c +a50c10060f +a50eb5a0ea +a5122c6ec6 +a522b1aa79 +a590915345 +a5b5b59139 +a5b77abe43 +a5c2b2c3e1 +a5cd17bb11 +a5da03aef1 +a5dd11de0d +a5ea2b93b6 +a5eaeac80b +a5ec5b0265 +a5f350a87e +a5f472caf4 +a6027a53cf +a61715bb1b +a61cf4389d +a61d9bbd9b +a6470dbbf5 +a64a40f3eb +a653d5c23b +a65bd23cb5 +a66e0b7ad4 +a66fc5053c +a68259572b +a6a810a92c +a6bc36937f +a6c3a374e9 +a6d8a4228d +a6f4e0817f +a71e0481f5 +a7203deb2d +a7392d4438 +a73d3c3902 +a7491f1578 +a74b9ca19c +a77b7a91df +a78195a5f5 +a78758d4ce +a7e6d6c29a +a800d85e88 +a832fa8790 +a83d06410d +a8999af004 +a8f78125b9 +a907b18df1 +a919392446 +a965504e88 +a96b84b8d2 +a973f239cd +a977126596 +a9804f2a08 +a984e56893 +a99738f24c +a99bdd0079 +a9c9c1517e +a9cbf9c41b +a9e42e3c0c +aa07b7c1c0 +aa175e5ec7 +aa1a338630 +aa27d7b868 +aa45f1caaf +aa49e46432 +aa51934e1b +aa6287bb6c +aa6d999971 +aa85278334 +aab33f0e2a +aaba004362 +aade4cf385 +aae78feda4 +aaed233bf3 +aaff16c2db +ab199e8dfb +ab23b78715 +ab2e1b5577 +ab33a18ded +ab45078265 +ab56201494 +ab90f0d24b +abab2e6c20 +abb50c8697 +abbe2d15a0 +abbe73cd21 +abe61a11bb +abeae8ce21 +ac2b431d5f +ac2cb1b9eb +ac31fcd6d0 +ac3d3a126d +ac46bd8087 +ac783ef388 +acb73e4297 
+acbf581760 +accafc3531 +acf2c4b745 +acf44293a2 +acf736a27b +acff336758 +ad1fe56886 +ad28f9b9d9 +ad2de9f80e +ad397527b2 +ad3d1cfbcb +ad3fada9d9 +ad4108ee8e +ad54468654 +ad573f7d31 +ad6255bc29 +ad65ebaa07 +ad97cc064a +adabbd1cc4 +adb0b5a270 +adc648f890 +add21ee467 +adfd15ceef +adfdd52eac +ae01cdab63 +ae0b50ff4f +ae13ee3d70 +ae1bcbd423 +ae20d09dea +ae2cecf5f6 +ae3bc4a0ef +ae499c7514 +ae628f2cd4 +ae8545d581 +ae93214fe6 +ae9cd16dbf +aeba9ac967 +aebb242b5c +aed4e0b4c4 +aedd71f125 +aef3e2cb0e +af0b54cee3 +af3de54c7a +af5fd24a36 +af8826d084 +af8ad72057 +afb71e22c5 +afcb331e1f +afe1a35c1e +b01080b5d3 +b05ad0d345 +b0623a6232 +b064dbd4b7 +b06ed37831 +b06f5888e6 +b08dcc490e +b0a68228dc +b0aece727f +b0b0731606 +b0c7f11f9f +b0cca8b830 +b0dd580a89 +b0de66ca08 +b0df7c5c5c +b0f5295608 +b11099eb09 +b132a53086 +b1399fac64 +b13abc0c69 +b1457e3b5e +b15bf4453b +b179c4a82d +b17ee70e8c +b190b1aa65 +b19b3e22c0 +b19c561fab +b1d1cd2e6e +b1d7c03927 +b1d7fe2753 +b1f540a4bd +b1fc9c64e1 +b1fcbb3ced +b220939e93 +b22099b419 +b241e95235 +b2432ae86d +b2456267df +b247940d01 +b24af1c35c +b24f600420 +b24fe36b2a +b258fb0b7d +b26b219919 +b26d9904de +b274456ce1 +b27b28d581 +b2a26bc912 +b2a9c51e1b +b2b0baf470 +b2b2756fe7 +b2ce7699e3 +b2edc76bd2 +b2f6b52100 +b30bf47bcd +b34105a4e9 +b372a82edf +b3779a1962 +b379ab4ff5 +b37a1d69e3 +b37c01396e +b382b09e25 +b3996e4ba5 +b3d9ca2aee +b3dde1e1e9 +b3eb7f05eb +b40b25055c +b41e0f1f19 +b44e32a42b +b4805ae9cd +b4807569a5 +b48efceb3e +b493c25c7f +b4b565aba1 +b4b715a15b +b4d0c90bf4 +b4d84bc371 +b4e5ad97aa +b4eaea9e6b +b50f4b90d5 +b53f675641 +b54278cd43 +b554843889 +b573c0677a +b58d853734 +b5943b18ab +b5a09a83f3 +b5aae1fe25 +b5b9da5364 +b5eb64d419 +b5ebb1d000 +b5f1c0c96a +b5f7fece90 +b6070de1bb +b60a76fe73 +b61f998772 +b62c943664 +b63094ba0c +b64fca8100 +b673e7dcfb +b678b7db00 +b68fc1b217 +b69926d9fa +b6a1df3764 +b6a4859528 +b6b4738b78 +b6b4f847b7 +b6b8d502d4 +b6bb00e366 +b6d65a9eef +b6d79a0845 +b6e9ec577f +b6ec609f7b +b6f92a308d +b70a2c0ab1 +b70a5a0d50 +b70c052f2f +b70d231781 +b72ac6e10b +b7302d8226 +b73867d769 +b751e767f2 +b76df6e059 +b77e5eddef +b7a2c2c83c +b7bcbe6466 +b7c2a469c4 +b7d69da8f0 +b7f31b7c36 +b7f675fb98 +b7fb871660 +b82e5ad1c9 +b841cfb932 +b84b8ae665 +b85b78ac2b +b86c17caa6 +b86e50d82d +b871db031a +b87d56925a +b8aaa59b75 +b8c03d1091 +b8c3210036 +b8e16df00b +b8f34cf72e +b8fb75864e +b9004db86c +b9166cbae9 +b920b256a6 +b938d79dff +b93963f214 +b941aef1a0 +b94d34d14e +b964c57da4 +b96a95bc7a +b96c57d2c7 +b9b6bdde0c +b9bcb3e0f2 +b9d3b92169 +b9dd4b306c +b9f43ef41e +ba1f03c811 +ba3a775d7b +ba3c7f2a31 +ba3fcd417d +ba5e1f4faa +ba795f3089 +ba8a291e6a +ba98512f97 +bac9db04f5 +baedae3442 +baff40d29d +bb04e28695 +bb1b0ee89f +bb1c770fe7 +bb1fc34f99 +bb2d220506 +bb334e5cdb +bb337f9830 +bb721eb9aa +bb87ff58bd +bb89a6b18a +bbaa9a036a +bbb4302dda +bbd31510cf +bbe0256a75 +bc141b9ad5 +bc17ab8a99 +bc318160de +bc3b9ee033 +bc4240b43c +bc4ce49105 +bc4f71372d +bc6b8d6371 +bcaad44ad7 +bcc241b081 +bcc5d8095e +bcd1d39afb +bd0d849da4 +bd0e9ed437 +bd2c94730f +bd321d2be6 +bd3ec46511 +bd5b2e2848 +bd7e02b139 +bd96f9943a +bda224cb25 +bda4a82837 +bdb74e333f +bdccd69dde +bddcc15521 +be116aab29 +be15e18f1e +be1a284edb +be2a367a7b +be376082d0 +be3e3cffbd +be5d1d89a0 +be8b72fe37 +be9b29e08e +bea1f6e62c +bea83281b5 +beb921a4c9 +bec5e9edcd +beeb8a3f92 +bf2232b58d +bf28751739 +bf443804e8 +bf461df850 +bf5374f122 +bf551a6f60 +bf8d0f5ada +bf961167a6 +bfab1ad8f9 +bfcb05d88d +bfd8f6e6c9 +bfd91d0742 +bfe262322f +c013f42ed7 +c01878083f +c01faff1ed +c046fd0edb +c053e35f97 +c079a6482d +c0847b521a +c0a1e06710 +c0e8d4635c +c0e973ad85 
+c0f49c6579 +c0f5b222d7 +c10d07c90d +c1268d998c +c130c3fc0c +c14826ad5e +c15b922281 +c16f09cb63 +c18e19d922 +c1c830a735 +c1e8aeea45 +c20a5ccc99 +c20fd5e597 +c219d6f8dc +c2406ae462 +c26f7b5824 +c279e641ee +c27adaeac5 +c2a35c1cda +c2a9903b8b +c2b62567c1 +c2b974ec8c +c2baaff7bf +c2be6900f2 +c304dd44d5 +c307f33da2 +c30a7b62c9 +c3128733ee +c31fa6c598 +c325c8201e +c32d4aa5d1 +c33f28249a +c34365e2d7 +c3457af795 +c34d120a88 +c3509e728d +c35e4fa6c4 +c36240d96f +c3641dfc5a +c37b17a4a9 +c39559ddf6 +c3b0c6e180 +c3b3d82e6c +c3be369fdb +c3bf1e40c2 +c3c760b015 +c3dd38bf98 +c3e4274614 +c3edc48cbd +c41e6587f5 +c4272227b0 +c42917fe82 +c438858117 +c44676563f +c44beb7472 +c45411dacb +c4571bedc8 +c46deb2956 +c479ee052e +c47d551843 +c49f07d46d +c4cc40c1fc +c4f256f5d5 +c4f5b1ddcc +c4ff9b4885 +c52bce43db +c544da6854 +c55784c766 +c557b69fbf +c593a3f7ab +c598faa682 +c5ab1f09c8 +c5b6da8602 +c5b9128d94 +c5e845c6b7 +c5fba7b341 +c60897f093 +c61fe6ed7c +c62188c536 +c64035b2e2 +c69689f177 +c6a12c131f +c6bb6d2d5c +c6c18e860f +c6d9526e0d +c6e55c33f0 +c7030b28bd +c70682c7cc +c70f9be8c5 +c71f30d7b6 +c73c8e747f +c760eeb8b3 +c7637cab0a +c7a1a17308 +c7bf937af5 +c7c2860db3 +c7cef4aee2 +c7ebfc5d57 +c813dcf13c +c82235a49a +c82a7619a1 +c82ecb90cb +c844f03dc7 +c8557963f3 +c89147e6e8 +c8a46ff0c8 +c8ab107dd5 +c8b869a04a +c8c7b306a6 +c8c8b28781 +c8d79e3163 +c8edab0415 +c8f494f416 +c8f6cba9fd +c909ceea97 +c9188f4980 +c922365dd4 +c92c8c3c75 +c937eb0b83 +c94b31b5e5 +c95cd17749 +c96379c03c +c96465ee65 +c965afa713 +c9734b451f +c9862d82dc +c98b6fe013 +c9999b7c48 +c99e92aaf0 +c9b3a8fbda +c9bf64e965 +c9c3cb3797 +c9d1c60cd0 +c9de9c22c4 +ca1828fa54 +ca346f17eb +ca3787d3d3 +ca4b99cbac +ca91c69e3b +ca91e99105 +caa8e97f81 +caac5807f8 +cabba242c2 +cad5a656a9 +cad673e375 +cad8a85930 +cae7b0a02b +cae7ef3184 +caeb6b6cbb +caecf0a5db +cb15312003 +cb2e35d610 +cb35a87504 +cb3f22b0cf +cbb410da64 +cc8728052e +cc892997b8 +cce03c2a9b +cd47a23e31 +cd4dc03dc0 +cd5ae611da +cd603bb9d1 +cd8f49734c +cdc6b1c032 +cdcfe008ad +cdd57027c2 +ce1af99b4b +ce1bc5743a +ce25872021 +ce2776f78f +ce49b1f474 +ce4f0a266f +ce5641b195 +ce6866aa19 +ce712ed3c9 +ce7d1c8117 +ce7dbeaa88 +ce9b015a5e +cea7697b25 +cebbd826cf +cec3415361 +cec41ad4f4 +ced49d26df +ced7705ab2 +cef824a1e1 +cf13f5c95a +cf4376a52d +cf85ab28b5 +cfc2e50b9d +cfcd571fff +cfd9d4ae47 +cfda2dcce5 +cff035928b +cff8191891 +d01608c2a5 +d01a8f1f83 +d021d68bca +d04258ca14 +d0483573dc +d04a90aaff +d05279c0bd +d0696bd5fc +d072fda75b +d0a83bcd9f +d0ab39112e +d0acde820f +d0b4442c71 +d0c65e9e95 +d0fb600c73 +d107a1457c +d123d674c1 +d14d1e9289 +d154e3388e +d177e9878a +d1802f69f8 +d182c4483a +d195d31128 +d200838929 +d205e3cff5 +d247420c4c +d2484bff33 +d26f6ed9b0 +d280fcd1cb +d2857f0faa +d292a50c7f +d295ea2dc7 +d2a58b4fa6 +d2b026739a +d2ebe0890f +d2ede5d862 +d301ca58cc +d3069da8bb +d343d4a77d +d355e634ef +d367fb5253 +d36d16358e +d38bc77e2c +d38d1679e2 +d3932ad4bd +d3987b2930 +d39934abe3 +d3ae1c3f4c +d3b088e593 +d3e6e05e16 +d3eefae7c5 +d3f55f5ab8 +d3f5c309cc +d4034a7fdf +d4193011f3 +d429c67630 +d42c0ff975 +d44a764409 +d44e6acd1d +d45158c175 +d454e8444f +d45f62717e +d48ebdcf74 +d49ab52a25 +d4a607ad81 +d4b063c7db +d4da13e9ba +d4dd1a7d00 +d4f4f7c9c3 +d521aba02e +d535bb1b97 +d53b955f78 +d55cb7a205 +d55f247a45 +d5695544d8 +d5853d9b8b +d5b6c6d94a +d5cae12834 +d5df027f0c +d5ee40e5d0 +d600046f73 +d632fd3510 +d6476cad55 +d65a7bae86 +d664c89912 +d689658f06 +d6917db4be +d69967143e +d699d3d798 +d69f757a3f +d6ac0e065c +d6c02bfda5 +d6c1b5749e +d6e12ef6cc +d6eed152c4 +d6faaaf726 +d704766646 +d708e1350c +d7135cf104 +d7157a9f44 +d719cf9316 +d724134cfd 
+d73a60a244 +d7411662da +d74875ea7c +d756f5a694 +d7572b7d8a +d763bd6d96 +d7697c8b13 +d7797196b4 +d79c834768 +d7b34e5d73 +d7bb6b37a7 +d7c7e064a6 +d7fbf545b3 +d82a0aa15b +d847e24abd +d8596701b7 +d86101499c +d87069ba86 +d87160957b +d874654b52 +d88a403092 +d8aee40f3f +d8e77a222d +d8eb07c381 +d9010348a1 +d90e3cf281 +d92532c7b2 +d927fae122 +d95707bca8 +d973b31c00 +d991cb471d +d992c69d37 +d99d770820 +d9b63abc11 +d9db6f1983 +d9e52be2d2 +d9edc82650 +da01070697 +da070ea4b7 +da080507b9 +da0e944cc4 +da28d94ff4 +da5d78b9d1 +da6003fc72 +da690fee9f +da6c68708f +da7a816676 +dac361e828 +dac71659b8 +dad980385d +daebc12b77 +db0968cdd3 +db231a7100 +db59282ace +db7f267c3f +dba35b87fd +dbba735a50 +dbca076acd +dbd66dc3ac +dbdc3c292b +dbf4a5b32b +dbfc417d28 +dc1745e0a2 +dc32a44804 +dc34b35e30 +dc504a4f79 +dc704dd647 +dc71bc6918 +dc7771b3be +dcf8c93617 +dd0f4c9fb9 +dd415df125 +dd601f9a3f +dd61d903df +dd77583736 +dd8636bd8b +dd9fe6c6ac +ddb2da4c14 +ddcd450d47 +dde8e67fb4 +ddfc3f04d3 +de2ab79dfa +de2f35b2fd +de30990a51 +de36b216da +de37403340 +de46e4943b +de4ddbccb1 +de5e480f05 +de6a9382ca +de74a601d3 +de827c510d +ded6069f7b +defb71c741 +df01f277f1 +df05214b82 +df0638b0a0 +df11931ffe +df1b0e4620 +df20a8650d +df2bc56d7c +df365282c6 +df39a0d9df +df3c430c24 +df5536cfb9 +df59cfd91d +df5e2152b3 +df741313c9 +df7626172f +df8ad5deb9 +df96aa609a +df9705605c +df9c91c4da +dfc0d3d27a +dfdbf91a99 +e00baaae9b +e0a938c6e7 +e0b2ceee6f +e0bdb5dfae +e0be1f6e17 +e0c478f775 +e0de82caa7 +e0f217dd59 +e0f7208874 +e0fb58395e +e1194c2e9d +e11adcd05d +e128124b9d +e1495354e4 +e1561d6d4b +e158805399 +e16945b951 +e19edcd34b +e1a1544285 +e1ab7957f4 +e1d26d35be +e1e957085b +e1f14510fa +e214b160f4 +e2167379b8 +e21acb20ab +e221105579 +e22ddf8a1b +e22de45950 +e22ffc469b +e23cca5244 +e252f46f0b +e25fa6cf39 +e26e486026 +e275760245 +e27bbedbfe +e29e9868a8 +e2b37ff8af +e2b608d309 +e2bef4da9a +e2c87a6421 +e2ea25542c +e2fb1d6497 +e2fcc99117 +e33c18412a +e348377191 +e352cb59c8 +e36ac982f0 +e391bc981e +e39e3e0a06 +e3bf38265f +e3d5b2cd21 +e3d60e82d5 +e3e3245492 +e3e4134877 +e3f4635e03 +e4004ee048 +e402d1afa5 +e415093d27 +e41ceb5d81 +e424653b78 +e42b6d3dbb +e42d60f0d4 +e436d0ff1e +e43d7ae2c5 +e4428801bc +e44e0b4917 +e470345ede +e48e8b4263 +e4922e3726 +e4936852bb +e495f32c60 +e499228f26 +e4af66e163 +e4b2095f58 +e4d19c8283 +e4d4872dab +e4e2983570 +e4eaa63aab +e4ef0a3a34 +e4f8e5f46e +e4ffb6d0dd +e53e21aa02 +e57f4f668b +e588433c1e +e597442c99 +e5abc0e96b +e5be628030 +e5ce96a55d +e5d6b70a9f +e5fde1574c +e625e1d27b +e6261d2348 +e6267d46bc +e6295f223f +e63463d8c6 +e6387bd1e0 +e653883384 +e65f134e0b +e668ef5664 +e672ccd250 +e674510b20 +e676107765 +e699da0cdf +e6be243065 +e6deab5e0b +e6f065f2b9 +e71629e7b5 +e72a7d7b0b +e72f6104e1 +e75a466eea +e76c55933f +e7784ec8ad +e78922e5e6 +e78d450a9c +e7c6354e77 +e7c8de1fce +e7ea10db28 +e803918710 +e8073a140b +e828dd02db +e845994987 +e8485a2615 +e85c5118a7 +e88b6736e4 +e8962324e3 +e8b3018d36 +e8cee8bf0b +e8d97ebece +e8da49ea6a +e8ed1a3ccf +e8f7904326 +e8f8341dec +e8fa21eb13 +e90c10fc4c +e914b8cac8 +e92b6bfea4 +e92e1b7623 +e93f83e512 +e9422ad240 +e9460b55f9 +e9502628f6 +e950befd5f +e9582bdd1b +e95e5afe0f +e97cfac475 +e98d57d99c +e98eda8978 +e99706b555 +e9bc0760ba +e9d3c78bf3 +e9ec1b7ea8 +ea065cc205 +ea138b6617 +ea16d3fd48 +ea2545d64b +ea286a581c +ea320da917 +ea345f3627 +ea3b94a591 +ea444a37eb +ea4a01216b +ea5672ffa8 +eaa99191cb +eaab4d746c +eac7a59bc1 +ead5d3835a +eaec65cfa7 +eaed1a87be +eb2f821c6f +eb383cb82e +eb6992fe02 +eb6ac20a01 +eb6d7ab39e +eb7921facd +eb8fce51a6 +ebbb90e9f9 +ebbf5c9ee1 +ebc4ec32e6 +ebe56e5ef8 +ec1299aee4 
+ec139ff675 +ec193e1a01 +ec28252938 +ec387be051 +ec3d4fac00 +ec4186ce12 +ec579c2f96 +ecae59b782 +ecb33a0448 +ece6bc9e92 +ecfedd4035 +ecfff22fd6 +ed3291c3d6 +ed3cd5308d +ed3e6fc1a5 +ed72ae8825 +ed7455da68 +ed844e879f +ed8f814b2b +ed911a1f63 +ed9ff4f649 +eda8ab984b +edb8878849 +edbfdfe1b4 +edd22c46a2 +edd663afa3 +ede3552eae +edeab61ee0 +ee07583fc0 +ee316eaed6 +ee3f509537 +ee40a1e491 +ee4bf100f1 +ee6f9b01f9 +ee947ed771 +ee9706ac7f +ee9a7840ae +eeb90cb569 +eebf45e5c5 +eeed0c7d73 +ef0061a309 +ef07f1a655 +ef0a8e8f35 +ef232a2aed +ef308ad2e9 +ef44945428 +ef45ce3035 +ef5dde449d +ef5e770988 +ef6359cea3 +ef65268834 +ef6cb5eae0 +ef78972bc2 +ef8cfcfc4f +ef96501dd0 +ef9a2e976b +efb24f950f +efce0c1868 +efe5ac6901 +efe828affa +efea4e0523 +f0268aa627 +f0483250c8 +f04cf99ee6 +f05b189097 +f08928c6d3 +f09d74856f +f0a7607d63 +f0ad38da27 +f0c34e1213 +f0c7f86c29 +f0dfa18ba7 +f0eb3179f7 +f119bab27d +f14409b6a3 +f1489baff4 +f14c18cf6a +f15c607b92 +f1af214222 +f1b77bd309 +f1ba9e1a3e +f1d99239eb +f1dc710cf4 +f1ec5c08fa +f22648fe12 +f22d21f1f1 +f233257395 +f23e95dbe5 +f2445b1572 +f253b3486d +f277c7a6a4 +f2ab2b84d6 +f2b7c9b1f3 +f2b83d5ce5 +f2c276018f +f2cfd94d64 +f2dd6e3add +f2e7653f16 +f2f333ad06 +f2f55d6713 +f2fdb6abec +f305a56d9f +f3085d6570 +f3325c3338 +f3400f1204 +f34497c932 +f34a56525e +f36483c824 +f3704d5663 +f3734c4913 +f38e5aa5b4 +f3986fba44 +f3a0ffc7d9 +f3b24a7d28 +f3e6c35ec3 +f3fc0ea80b +f40a683fbe +f4207ca554 +f4377499c2 +f46184f393 +f46c2d0a6d +f46c364dca +f46f7a0b63 +f46fe141b0 +f470b9aeb0 +f47eb7437f +f48b535719 +f49e4866ac +f4aa882cfd +f4daa3dbd5 +f4dd51ac35 +f507a1b9dc +f51c5ac84b +f52104164b +f54c67b9bb +f5966cadd2 +f5bddf5598 +f5d85cfd17 +f5e2e7d6a0 +f5f051e9b4 +f5f8a93a76 +f6283e8af5 +f635e9568b +f6474735be +f659251be2 +f66981af4e +f6708fa398 +f697fe8e8f +f6adb12c42 +f6c7906ca4 +f6cd0a8016 +f6d6f15ae7 +f6e501892c +f6f59d986f +f6fe8c90a5 +f714160545 +f74c3888d7 +f7782c430e +f7783ae5f2 +f77ab47923 +f788a98327 +f7961ac1f0 +f7a71e7574 +f7a8521432 +f7afbf4947 +f7b7cd5f44 +f7cf4b4a39 +f7d49799ad +f7e0c9bb83 +f7e5b84928 +f7e6bd58be +f7f2a38ac6 +f7f6cb2d6d +f83f19e796 +f85796a921 +f8603c26b2 +f8819b42ec +f891f8eaa1 +f89288d10c +f895ae8cc1 +f8b4ac12f1 +f8c3fb2b01 +f8c8de2764 +f8db369b40 +f8fcb6a78c +f94aafdeef +f95d217b70 +f9681d5103 +f9750192a4 +f9823a32c2 +f991ddb4c2 +f99d535567 +f9ae3d98b7 +f9b6217959 +f9bd1fabf5 +f9c68eaa64 +f9d3e04c4f +f9daf64494 +f9e4cc5a0a +f9ea6b7f31 +f9f3852526 +fa04c615cf +fa08e00a56 +fa4370d74d +fa67744af3 +fa88d48a92 +fa8b904cc9 +fa9526bdf1 +fa9b9d2426 +fad633fbe1 +faf5222dc3 +faff0e15f1 +fb08c64e8c +fb23455a7f +fb2e19fa6e +fb34dfbb77 +fb47fcea1e +fb49738155 +fb4cbc514b +fb4e6062f7 +fb5ba7ad6e +fb63cd1236 +fb81157a07 +fb92abdaeb +fba22a6848 +fbaca0c9df +fbc645f602 +fbd77444cd +fbe53dc8e8 +fbe541dd73 +fbe8488798 +fbfd25174f +fc28cb305e +fc33b1ffd6 +fc6186f0bb +fc918e3a40 +fc96cda9d8 +fc9832eea4 +fcb10d0f81 +fcd20a2509 +fcf637e3ab +fcfd81727f +fd31890379 +fd33551c28 +fd542da05e +fd6789b3fe +fd77828200 +fd7af75f4d +fdb28d0fbb +fdb3d1fb1e +fdb8b04124 +fdc6e3d581 +fdfce7e6fc +fe0f76d41b +fe24b0677d +fe3c02699d +fe58b48235 +fe6a5596b8 +fe6c244f63 +fe7afec086 +fe985d510a +fe9db35d15 +fea8ffcd36 +feb1080388 +fed208bfca +feda5ad1c2 +feec95b386 +ff15a5eff6 +ff204daf4b +ff25f55852 +ff2ada194f +ff2ce142e8 +ff49d36d20 +ff5a1ec4f3 +ff66152b25 +ff692fdc56 +ff773b1a1e +ff97129478 +ffb904207d +ffc43fc345 +fffe5f8df6